In [1]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn
from ipynb.fs.full.model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image  
from torch.utils.data import DataLoader, RandomSampler
import numpy as np
import matplotlib.pyplot as plt
import os
from mpl_toolkits import mplot3d
import cv2
from moviepy.editor import ImageSequenceClip

from typing import Any, Dict, List, Optional, Tuple, Union

### Importing the model

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Image geometry and model hyper-parameters.
INPUT_DIM = 256    # images are resized to INPUT_DIM x INPUT_DIM
INIT_DIM = 8       # passed to the model as `init_dim`
LATENT_DIM = 3     # passed to the model as `latent_dim`
KERNEL_SIZE = 4    # passed to the model as `kernel_size`

# Training / data-loading settings.
NUM_EPOCHS = 50
BATCH_SIZE = 1
LR_RATE = 3e-4

In [4]:
# --- Data pipeline -----------------------------------------------------------
data_path = 'test_set'  # root folder handed to ImageFolder

# Per-image preprocessing, applied in the listed order.
transform = transforms.Compose([
    transforms.Resize((INPUT_DIM, INPUT_DIM)),    # 1. resize to INPUT_DIM x INPUT_DIM
    transforms.Grayscale(num_output_channels=1),  # 2. reduce to a single channel
    transforms.ToTensor(),                        # 3. convert to a tensor
])

# Read the images from the folder and wrap them in a shuffling loader.
dataset = datasets.ImageFolder(root=data_path, transform=transform)
train_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# --- Model, optimizer and loss -----------------------------------------------
model = VariationalAutoEncoder(
    init_dim=INIT_DIM,
    latent_dim=LATENT_DIM,
    kernel_size=KERNEL_SIZE,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction='sum')  # binary cross-entropy, summed rather than averaged

In [5]:
# Rebuild the network and restore the trained weights from disk.
# NOTE: this replaces the `model` created in the previous cell. The new instance
# is not moved to DEVICE, so it stays wherever PyTorch creates parameters by
# default (the CPU unless the default device was changed).
model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE)
# The rest of the notebook only runs inference, so put dropout / batch-norm
# layers (if the model has any) into evaluation mode.
model.eval()
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; load_state_dict then copies the values into the model.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load('models/model_256x', map_location='cpu'))

<All keys matched successfully>

### Generating stack from image A to image B

Defining some helper functions:

In [6]:
def get_random_sample(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Draw one random sample ``mu + sigma * epsilon`` with ``epsilon ~ N(0, 1)``.

    Args:
        mu: Mean of the distribution.
        sigma: Standard deviation; the noise is created with the same shape,
            dtype and device as this tensor.

    Returns:
        The sampled tensor ``mu + sigma * epsilon``.
    """
    epsilon = torch.randn_like(sigma)  # standard-normal noise shaped like sigma
    return mu + sigma * epsilon

def reconstruct(z: torch.Tensor) -> np.ndarray:
    """Decode a latent sample with the global ``model`` and return it as an image array.

    The decoded tensor is written to a PNG with ``save_image`` and read back
    with ``cv2.imread``. The original author kept this round trip because
    converting the tensor directly to a NumPy array gave a worse image.

    Args:
        z: Latent sample accepted by ``model.decode``.

    Returns:
        The image as returned by ``cv2.imread`` for the saved PNG.

    Raises:
        RuntimeError: If the PNG could not be read back.
    """
    import tempfile  # local import: only needed for the scratch PNG below

    with torch.no_grad():  # inference only, no autograd graph required
        out = model.decode(z)  # decode sample
    out = out.view(-1, 1, 256, 256)  # reshape to (N, 1, 256, 256) image batch

    # Use a private scratch directory instead of a fixed 'temp.png' in the
    # working directory: nothing of the user's can be overwritten, and the file
    # is removed even when saving or reading raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_png = os.path.join(tmp_dir, 'temp.png')
        save_image(out, tmp_png)
        img = cv2.imread(tmp_png)

    if img is None:  # cv2.imread signals failure by returning None
        raise RuntimeError('could not read back the decoded image')
    return img

In [7]:
def function(img_A: torch.Tensor, img_B: torch.Tensor, n_steps: int, path: str) -> list[np.ndarray]:
    """Generate binarized frames that interpolate in latent space from img_A to img_B.

    Both images are encoded with the global ``model``, one latent sample is
    drawn for each, and ``n_steps`` latent vectors are taken on the straight
    line between the two samples (both end points included). Every latent
    vector is decoded with ``reconstruct`` and thresholded at 127.

    Args:
        img_A: Start image; viewed as ``(1, 256, 256)`` before encoding.
        img_B: End image; viewed as ``(1, 256, 256)`` before encoding.
        n_steps: Number of frames to generate.
        path: Not used by this function; kept so existing callers keep working.

    Returns:
        List of ``n_steps`` binarized frames (pixel values 0 or 255).
    """
    # Encode both end points; gradients are not needed here.
    with torch.no_grad():
        mu_A, sigma_A = model.encode(img_A.view(1, 256, 256))
        mu_B, sigma_B = model.encode(img_B.view(1, 256, 256))

    # The end points are random samples, so repeated calls give different frames.
    z_A = get_random_sample(mu_A, sigma_A)
    z_B = get_random_sample(mu_B, sigma_B)

    thresh = 127  # binarization threshold, identical for every frame
    frames = []
    for a in np.linspace(0, 1, n_steps):
        z = (1 - a) * z_A + a * z_B  # latent vector for this step
        img = reconstruct(z)  # decode the latent vector into an image
        frames.append(cv2.threshold(img, thresh, 255, cv2.THRESH_BINARY)[1])
    return frames

In [55]:
def generate_stack(img_index: list[int], n_steps: int, path: str, fps: int = 60) -> list[np.ndarray]:
    """Interpolate between consecutive dataset images and save the frames and a GIF.

    For every pair of neighbouring entries in ``img_index`` the images are taken
    from the global ``dataset`` and ``function`` generates ``n_steps`` frames
    between them. Each frame is saved as ``{path}/generated_images/{i}.png`` and
    all frames are written to ``{path}/animation.gif``.

    Args:
        img_index: Dataset indices of the anchor images (at least two).
        n_steps: Number of frames generated per pair of anchors (at least 1).
        path: Output folder; it and its ``generated_images`` / ``real_images``
            sub-folders are created when missing.
        fps: Frame rate of the GIF.

    Returns:
        All generated frames in order. The frame number has been drawn onto
        each returned array (the saved PNGs are written before labelling).

    Raises:
        ValueError: If fewer than two indices are given or ``n_steps < 1``.
    """
    if len(img_index) < 2:
        raise ValueError('img_index must contain at least two image indices')
    if n_steps < 1:
        raise ValueError('n_steps must be at least 1')

    # Create the output folders. makedirs also creates missing parents and
    # still creates the sub-folders when `path` itself already exists.
    already_there = os.path.isdir(path)
    for folder in (path, os.path.join(path, 'generated_images'), os.path.join(path, 'real_images')):
        os.makedirs(folder, exist_ok=True)
    if already_there:
        print("Folder already exists")
    else:
        print(f"Folder {path} was created")

    # generating slices between the passed images
    images = [dataset[i][0] for i in img_index]
    frames = []
    for i in tqdm(range(1, len(images))):
        frames += function(images[i - 1], images[i], n_steps, path)

    print(f'{len(frames)} images were generated')

    pos = (0, 250)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for i, frame in enumerate(frames):
        # save the clean frame first ...
        plt.imsave(f'{path}/generated_images/{i}.png', frame)
        # ... then label it for the gif. cv2.putText draws directly on the array
        # it is given, so the entries of `frames` carry the label afterwards.
        cv2.putText(frame, f'{i}', org=pos, color=(255, 0, 0), fontFace=font, fontScale=.5)

    # writing gif
    clip = ImageSequenceClip(frames, fps=fps)
    clip.write_gif(f'{path}/animation.gif', fps=fps)

    print('Gif was created')

    return frames

In [56]:
# Interpolate between consecutive indices in `steps` (20 frames per pair) and
# write the frames plus the GIF below output/gifs/test_0_175_25.
# NOTE: `steps` is assigned in a cell further down this notebook (In [42]);
# on a fresh kernel that cell has to be run before this one.
slices = generate_stack(
    img_index=steps,
    n_steps=20,
    path='output/gifs/test_0_175_25',
)

Folder output/gifs/test_0_175_25 was created


100%|█████████████████████████████████████████████| 9/9 [00:01<00:00,  4.55it/s]


180 images were generated
MoviePy - Building file output/gifs/test_0_175_25/animation.gif with imageio.


                                                                                

Gif was created


In [58]:
# Show the dataset indices that were passed to generate_stack.
steps

[0, 20, 40, 60, 80, 100, 120, 140, 160, 180]

In [42]:
# Keep every k-th index out of 0..199 so that ten anchor indices remain;
# k = ceil(200 / 10) = 20.
# The stride used to be computed from `len(x)`, but `x` is not defined anywhere
# in the visible notebook, so the cell raised NameError on a fresh kernel.
steps = list(range(200))
steps = steps[::int(np.ceil(len(steps) / 10))]
steps

[0, 20, 40, 60, 80, 100, 120, 140, 160, 180]

In [57]:
# Build the reference animation from the real images: the first 200 files of
# test_set/0 (sorted by file name) are resized, binarized, saved as PNG and
# collected, labelled with their position, into a GIF.
real_dir = 'output/gifs/test_0_175_25/real_images'
os.makedirs(real_dir, exist_ok=True)  # do not rely on an earlier cell having created it

paths = sorted(os.listdir('test_set/0'))[:200]

# Settings shared by every frame.
thresh = 127
pos = (0, 250)
font = cv2.FONT_HERSHEY_SIMPLEX

frames = []
for i, path in enumerate(paths):
    img = cv2.imread(f'test_set/0/{path}')
    if img is None:  # cv2.imread returns None for files it cannot decode
        raise ValueError(f'test_set/0/{path} could not be read as an image')
    img = cv2.resize(img, (256, 256))
    im_bw = cv2.threshold(img, thresh, 255, cv2.THRESH_BINARY)[1]
    plt.imsave(f'{real_dir}/{i}.png', im_bw)  # saved before the label is drawn

    # labeling the frame for the gif
    frame = cv2.putText(im_bw, f'{i}', org=pos, color=(255, 0, 0), fontFace=font, fontScale=.5)
    frames.append(frame)

clip = ImageSequenceClip(frames, fps=60)
clip.write_gif(f'{real_dir}/animation.gif', fps=60)

MoviePy - Building file output/gifs/test_0_175_25/real_images/animation.gif with imageio.


                                                                                

In [51]:
# Scratch check: indices 0, 25, ..., 200, i.e. a spacing of 25 instead of the
# spacing of 20 used for `steps`.
np.arange(1, 200+2, 25) - 1

array([  0,  25,  50,  75, 100, 125, 150, 175, 200])