In [43]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn
from ipynb.fs.full.model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image  
from torch.utils.data import DataLoader, RandomSampler
import numpy as np

### configuration

In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_DIM = 64
INIT_DIM = 8
LATENT_DIM = 3
NUM_EPOCHS = 200
BATCH_SIZE = 1
LR_RATE = 3e-4
KERNEL_SIZE = 4

### loading dataset

In [None]:
# Dataset Loading
data_path = 'dataset' # setting path
transform = transforms.Compose([transforms.Resize((64, 64)),                 # sequence of transformations to be done
                                transforms.Grayscale(num_output_channels=1), # on each image (resize, greyscale,
                                transforms.ToTensor()])                      # convert to tensor)
dataset = datasets.ImageFolder(root=data_path, transform=transform) # read data from folder
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True) # create dataloader object
model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE).to(DEVICE) # initializing model object
optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE) # defining optimizer
loss_fn = nn.BCELoss(reduction='sum') # define loss function

### training model

In [16]:
# Start Training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    print(f'Epoch: {epoch}')
    for i, (x, _) in loop:
        # forward pass
        x = x.to(DEVICE).view(1, INPUT_DIM, INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)
        
        # compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        
        # backpropagation
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

6it [00:00, 56.00it/s, loss=2.79e+3]

Epoch: 0


800it [00:13, 60.54it/s, loss=2.78e+3]
5it [00:00, 48.33it/s, loss=2.79e+3]

Epoch: 1


800it [00:13, 58.73it/s, loss=2.78e+3]
6it [00:00, 53.98it/s, loss=2.78e+3]

Epoch: 2


800it [00:14, 55.25it/s, loss=2.78e+3]
6it [00:00, 49.84it/s, loss=2.79e+3]

Epoch: 3


800it [00:13, 59.50it/s, loss=2.78e+3]
5it [00:00, 47.75it/s, loss=2.78e+3]

Epoch: 4


800it [00:14, 56.21it/s, loss=2.79e+3]
7it [00:00, 63.33it/s, loss=2.78e+3]

Epoch: 5


800it [00:14, 56.91it/s, loss=2.79e+3]
6it [00:00, 58.12it/s, loss=2.78e+3]

Epoch: 6


800it [00:13, 57.15it/s, loss=2.78e+3]
5it [00:00, 44.31it/s, loss=2.78e+3]

Epoch: 7


800it [00:13, 57.20it/s, loss=2.78e+3]
6it [00:00, 59.56it/s, loss=2.79e+3]

Epoch: 8


800it [00:14, 56.52it/s, loss=2.79e+3]
6it [00:00, 59.51it/s, loss=2.79e+3]

Epoch: 9


800it [00:13, 58.42it/s, loss=2.78e+3]
6it [00:00, 48.75it/s, loss=2.78e+3]

Epoch: 10


800it [00:14, 56.05it/s, loss=2.78e+3]
7it [00:00, 59.94it/s, loss=2.78e+3]

Epoch: 11


800it [00:14, 56.10it/s, loss=2.78e+3]
6it [00:00, 57.31it/s, loss=2.78e+3]

Epoch: 12


800it [00:14, 55.28it/s, loss=2.79e+3]
5it [00:00, 44.13it/s, loss=2.79e+3]

Epoch: 13


800it [00:14, 55.78it/s, loss=2.78e+3]
6it [00:00, 50.00it/s, loss=2.78e+3]

Epoch: 14


800it [00:14, 55.21it/s, loss=2.78e+3]
5it [00:00, 48.62it/s, loss=2.79e+3]

Epoch: 15


800it [00:14, 54.80it/s, loss=2.79e+3]
7it [00:00, 65.40it/s, loss=2.78e+3]

Epoch: 16


800it [00:14, 56.00it/s, loss=2.78e+3]
5it [00:00, 47.70it/s, loss=2.79e+3]

Epoch: 17


800it [00:14, 54.24it/s, loss=2.78e+3]
6it [00:00, 54.53it/s, loss=2.78e+3]

Epoch: 18


800it [00:14, 56.17it/s, loss=2.78e+3]
5it [00:00, 44.60it/s, loss=2.79e+3]

Epoch: 19


800it [00:14, 54.72it/s, loss=2.78e+3]
6it [00:00, 57.83it/s, loss=2.78e+3]

Epoch: 20


800it [00:14, 55.65it/s, loss=2.78e+3]
6it [00:00, 55.85it/s, loss=2.78e+3]

Epoch: 21


800it [00:14, 56.10it/s, loss=2.79e+3]
6it [00:00, 53.20it/s, loss=2.77e+3]

Epoch: 22


800it [00:14, 55.09it/s, loss=2.77e+3]
7it [00:00, 60.75it/s, loss=2.78e+3]

Epoch: 23


800it [00:14, 54.88it/s, loss=2.77e+3]
7it [00:00, 66.22it/s, loss=2.79e+3]

Epoch: 24


800it [00:14, 56.37it/s, loss=2.78e+3]
5it [00:00, 44.51it/s, loss=2.78e+3]

Epoch: 25


800it [00:14, 56.03it/s, loss=2.78e+3]
5it [00:00, 48.40it/s, loss=2.77e+3]

Epoch: 26


800it [00:13, 57.55it/s, loss=2.78e+3]
7it [00:00, 65.86it/s, loss=2.77e+3]

Epoch: 27


800it [00:14, 54.11it/s, loss=2.78e+3]
7it [00:00, 61.81it/s, loss=2.78e+3]

Epoch: 28


800it [00:14, 54.04it/s, loss=2.78e+3]
6it [00:00, 55.40it/s, loss=2.77e+3]

Epoch: 29


800it [00:14, 53.63it/s, loss=2.76e+3]
6it [00:00, 59.40it/s, loss=2.75e+3]

Epoch: 30


800it [00:14, 56.05it/s, loss=2.78e+3]
7it [00:00, 57.61it/s, loss=2.78e+3]

Epoch: 31


800it [00:14, 55.91it/s, loss=2.75e+3]
6it [00:00, 54.74it/s, loss=2.76e+3]

Epoch: 32


800it [00:15, 52.44it/s, loss=2.79e+3]
7it [00:00, 66.60it/s, loss=2.75e+3]

Epoch: 33


800it [00:14, 54.07it/s, loss=2.75e+3]
5it [00:00, 47.18it/s, loss=2.79e+3]

Epoch: 34


800it [00:14, 55.60it/s, loss=2.78e+3]
5it [00:00, 39.71it/s, loss=2.75e+3]

Epoch: 35


800it [00:14, 54.03it/s, loss=2.76e+3]
6it [00:00, 56.43it/s, loss=2.76e+3]

Epoch: 36


800it [00:14, 53.82it/s, loss=2.77e+3]
6it [00:00, 55.98it/s, loss=2.74e+3]

Epoch: 37


800it [00:15, 52.56it/s, loss=2.78e+3]
6it [00:00, 59.05it/s, loss=2.75e+3]

Epoch: 38


800it [00:14, 54.20it/s, loss=2.75e+3]
4it [00:00, 39.09it/s, loss=2.7e+3] 

Epoch: 39


800it [00:14, 53.87it/s, loss=2.74e+3]
7it [00:00, 62.75it/s, loss=2.76e+3]

Epoch: 40


800it [00:14, 54.01it/s, loss=2.7e+3] 
7it [00:00, 65.97it/s, loss=2.75e+3]

Epoch: 41


800it [00:14, 55.22it/s, loss=2.72e+3]
5it [00:00, 45.73it/s, loss=2.71e+3]

Epoch: 42


800it [00:14, 55.76it/s, loss=2.71e+3]
6it [00:00, 50.28it/s, loss=2.73e+3]

Epoch: 43


800it [00:14, 53.95it/s, loss=2.75e+3]
5it [00:00, 45.70it/s, loss=2.69e+3]

Epoch: 44


800it [00:14, 55.66it/s, loss=2.76e+3]
6it [00:00, 56.35it/s, loss=2.77e+3]

Epoch: 45


800it [00:14, 54.50it/s, loss=2.74e+3]
5it [00:00, 46.54it/s, loss=2.75e+3]

Epoch: 46


800it [00:14, 54.80it/s, loss=2.7e+3] 
5it [00:00, 43.68it/s, loss=2.76e+3]

Epoch: 47


800it [00:15, 53.01it/s, loss=2.73e+3]
6it [00:00, 48.11it/s, loss=2.74e+3]

Epoch: 48


800it [00:14, 55.92it/s, loss=2.71e+3]
5it [00:00, 42.63it/s, loss=2.75e+3]

Epoch: 49


800it [00:14, 56.60it/s, loss=2.74e+3]
7it [00:00, 63.97it/s, loss=2.75e+3]

Epoch: 50


800it [00:14, 55.15it/s, loss=2.76e+3]
6it [00:00, 51.54it/s, loss=2.7e+3] 

Epoch: 51


800it [00:14, 55.07it/s, loss=2.72e+3]
5it [00:00, 44.99it/s, loss=2.74e+3]

Epoch: 52


800it [00:14, 54.29it/s, loss=2.75e+3]
4it [00:00, 39.26it/s, loss=2.69e+3]

Epoch: 53


800it [00:14, 53.92it/s, loss=2.71e+3]
6it [00:00, 53.19it/s, loss=2.7e+3] 

Epoch: 54


800it [00:14, 56.07it/s, loss=2.72e+3]
6it [00:00, 56.70it/s, loss=2.7e+3] 

Epoch: 55


800it [00:15, 53.03it/s, loss=2.71e+3]
7it [00:00, 61.38it/s, loss=2.76e+3]

Epoch: 56


800it [00:14, 55.36it/s, loss=2.74e+3]
5it [00:00, 48.69it/s, loss=2.75e+3]

Epoch: 57


800it [00:14, 55.87it/s, loss=2.76e+3]
5it [00:00, 43.51it/s, loss=2.76e+3]

Epoch: 58


800it [00:14, 55.67it/s, loss=2.72e+3]
5it [00:00, 49.79it/s, loss=2.7e+3] 

Epoch: 59


800it [00:14, 56.06it/s, loss=2.73e+3]
7it [00:00, 65.03it/s, loss=2.69e+3]

Epoch: 60


800it [00:14, 54.47it/s, loss=2.72e+3]
5it [00:00, 49.53it/s, loss=2.7e+3] 

Epoch: 61


800it [00:14, 54.51it/s, loss=2.71e+3]
7it [00:00, 65.93it/s, loss=2.74e+3]

Epoch: 62


800it [00:14, 54.08it/s, loss=2.74e+3]
5it [00:00, 41.57it/s, loss=2.7e+3] 

Epoch: 63


800it [00:14, 55.09it/s, loss=2.74e+3]
7it [00:00, 60.83it/s, loss=2.74e+3]

Epoch: 64


800it [00:14, 55.23it/s, loss=2.71e+3]
7it [00:00, 67.30it/s, loss=2.71e+3]

Epoch: 65


800it [00:14, 53.88it/s, loss=2.73e+3]
6it [00:00, 58.14it/s, loss=2.74e+3]

Epoch: 66


800it [00:14, 55.82it/s, loss=2.71e+3]
7it [00:00, 64.35it/s, loss=2.7e+3] 

Epoch: 67


800it [00:14, 54.75it/s, loss=2.71e+3]
5it [00:00, 44.59it/s, loss=2.71e+3]

Epoch: 68


800it [00:14, 54.18it/s, loss=2.7e+3] 
6it [00:00, 51.68it/s, loss=2.72e+3]

Epoch: 69


800it [00:14, 54.95it/s, loss=2.75e+3]
5it [00:00, 46.18it/s, loss=2.71e+3]

Epoch: 70


800it [00:14, 56.85it/s, loss=2.73e+3]
6it [00:00, 59.09it/s, loss=2.69e+3]

Epoch: 71


800it [00:15, 52.94it/s, loss=2.73e+3]
4it [00:00, 39.80it/s, loss=2.72e+3]

Epoch: 72


800it [00:14, 53.66it/s, loss=2.72e+3]
6it [00:00, 56.05it/s, loss=2.74e+3]

Epoch: 73


800it [00:15, 53.22it/s, loss=2.71e+3]
7it [00:00, 66.37it/s, loss=2.73e+3]

Epoch: 74


800it [00:14, 56.44it/s, loss=2.73e+3]
7it [00:00, 60.57it/s, loss=2.74e+3]

Epoch: 75


800it [00:14, 54.30it/s, loss=2.73e+3]
5it [00:00, 43.85it/s, loss=2.69e+3]

Epoch: 76


800it [00:14, 54.85it/s, loss=2.7e+3] 
6it [00:00, 57.01it/s, loss=2.71e+3]

Epoch: 77


800it [00:14, 53.59it/s, loss=2.69e+3]
6it [00:00, 54.19it/s, loss=2.76e+3]

Epoch: 78


800it [00:14, 56.28it/s, loss=2.73e+3]
7it [00:00, 64.50it/s, loss=2.71e+3]

Epoch: 79


800it [00:14, 53.38it/s, loss=2.7e+3] 
5it [00:00, 42.93it/s, loss=2.69e+3]

Epoch: 80


800it [00:14, 55.21it/s, loss=2.7e+3] 
4it [00:00, 37.14it/s, loss=2.78e+3]

Epoch: 81


800it [00:14, 54.49it/s, loss=2.71e+3]
7it [00:00, 55.91it/s, loss=2.73e+3]

Epoch: 82


800it [00:14, 53.81it/s, loss=2.71e+3]
5it [00:00, 44.85it/s, loss=2.75e+3]

Epoch: 83


800it [00:14, 55.27it/s, loss=2.72e+3]
5it [00:00, 45.76it/s, loss=2.7e+3] 

Epoch: 84


800it [00:14, 54.63it/s, loss=2.71e+3]
7it [00:00, 64.89it/s, loss=2.69e+3]

Epoch: 85


800it [00:14, 53.94it/s, loss=2.73e+3]
7it [00:00, 62.97it/s, loss=2.73e+3]

Epoch: 86


800it [00:14, 54.55it/s, loss=2.7e+3] 
6it [00:00, 51.94it/s, loss=2.71e+3]

Epoch: 87


800it [00:14, 55.04it/s, loss=2.69e+3]
7it [00:00, 64.39it/s, loss=2.73e+3]

Epoch: 88


800it [00:15, 53.01it/s, loss=2.72e+3]
5it [00:00, 45.30it/s, loss=2.73e+3]

Epoch: 89


800it [00:14, 54.30it/s, loss=2.71e+3]
7it [00:00, 64.73it/s, loss=2.74e+3]

Epoch: 90


800it [00:14, 54.73it/s, loss=2.72e+3]
5it [00:00, 46.07it/s, loss=2.71e+3]

Epoch: 91


800it [00:14, 53.41it/s, loss=2.73e+3]
7it [00:00, 64.74it/s, loss=2.71e+3]

Epoch: 92


800it [00:15, 52.08it/s, loss=2.72e+3]
5it [00:00, 42.04it/s, loss=2.72e+3]

Epoch: 93


800it [00:15, 50.99it/s, loss=2.7e+3] 
5it [00:00, 49.20it/s, loss=2.7e+3] 

Epoch: 94


800it [00:14, 54.26it/s, loss=2.73e+3]
6it [00:00, 55.21it/s, loss=2.73e+3]

Epoch: 95


800it [00:15, 52.75it/s, loss=2.72e+3]
5it [00:00, 45.60it/s, loss=2.73e+3]

Epoch: 96


800it [00:15, 52.23it/s, loss=2.71e+3]
6it [00:00, 57.40it/s, loss=2.75e+3]

Epoch: 97


800it [00:15, 50.73it/s, loss=2.72e+3]
7it [00:00, 62.52it/s, loss=2.69e+3]

Epoch: 98


800it [00:16, 49.25it/s, loss=2.73e+3]
7it [00:00, 62.88it/s, loss=2.7e+3] 

Epoch: 99


800it [00:15, 51.27it/s, loss=2.71e+3]
5it [00:00, 44.44it/s, loss=2.69e+3]

Epoch: 100


800it [00:15, 53.23it/s, loss=2.69e+3]
7it [00:00, 62.31it/s, loss=2.71e+3]

Epoch: 101


800it [00:15, 52.95it/s, loss=2.73e+3]
7it [00:00, 61.95it/s, loss=2.71e+3]

Epoch: 102


800it [00:15, 52.46it/s, loss=2.71e+3]
5it [00:00, 49.52it/s, loss=2.72e+3]

Epoch: 103


800it [00:15, 50.37it/s, loss=2.7e+3] 
5it [00:00, 41.30it/s, loss=2.69e+3]

Epoch: 104


800it [00:15, 50.31it/s, loss=2.74e+3]
4it [00:00, 34.87it/s, loss=2.71e+3]

Epoch: 105


800it [00:15, 52.58it/s, loss=2.72e+3]
6it [00:00, 58.82it/s, loss=2.69e+3]

Epoch: 106


800it [00:16, 49.56it/s, loss=2.69e+3]
7it [00:00, 63.45it/s, loss=2.72e+3]

Epoch: 107


800it [00:16, 49.82it/s, loss=2.69e+3]
7it [00:00, 63.55it/s, loss=2.71e+3]

Epoch: 108


800it [00:15, 52.73it/s, loss=2.71e+3]
7it [00:00, 62.51it/s, loss=2.69e+3]

Epoch: 109


800it [00:16, 49.34it/s, loss=2.72e+3]
4it [00:00, 38.21it/s, loss=2.7e+3] 

Epoch: 110


800it [00:14, 55.32it/s, loss=2.72e+3]
6it [00:00, 51.33it/s, loss=2.71e+3]

Epoch: 111


800it [00:15, 50.41it/s, loss=2.72e+3]
7it [00:00, 64.14it/s, loss=2.71e+3]

Epoch: 112


800it [00:15, 50.83it/s, loss=2.72e+3]
5it [00:00, 41.70it/s, loss=2.71e+3]

Epoch: 113


800it [00:15, 51.38it/s, loss=2.72e+3]
4it [00:00, 38.83it/s, loss=2.83e+3]

Epoch: 114


800it [00:15, 51.24it/s, loss=2.69e+3]
6it [00:00, 51.88it/s, loss=2.74e+3]

Epoch: 115


800it [00:15, 51.82it/s, loss=2.75e+3]
7it [00:00, 60.28it/s, loss=2.71e+3]

Epoch: 116


800it [00:15, 52.00it/s, loss=2.7e+3] 
5it [00:00, 48.06it/s, loss=2.72e+3]

Epoch: 117


800it [00:15, 50.54it/s, loss=2.71e+3]
7it [00:00, 62.60it/s, loss=2.7e+3] 

Epoch: 118


800it [00:15, 50.06it/s, loss=2.73e+3]
5it [00:00, 48.91it/s, loss=2.72e+3]

Epoch: 119


800it [00:15, 50.42it/s, loss=2.73e+3]
4it [00:00, 35.08it/s, loss=2.7e+3] 

Epoch: 120


800it [00:15, 50.84it/s, loss=2.82e+3]
7it [00:00, 64.05it/s, loss=2.72e+3]

Epoch: 121


800it [00:15, 51.42it/s, loss=2.72e+3]
7it [00:00, 63.27it/s, loss=2.73e+3]

Epoch: 122


800it [00:15, 50.29it/s, loss=2.71e+3]
7it [00:00, 63.82it/s, loss=2.72e+3]

Epoch: 123


800it [00:15, 50.25it/s, loss=2.71e+3]
6it [00:00, 53.54it/s, loss=2.73e+3]

Epoch: 124


800it [00:16, 49.30it/s, loss=2.73e+3]
6it [00:00, 52.60it/s, loss=2.74e+3]

Epoch: 125


800it [00:14, 53.75it/s, loss=2.7e+3] 
6it [00:00, 54.45it/s, loss=2.7e+3] 

Epoch: 126


800it [00:15, 51.99it/s, loss=2.71e+3]
6it [00:00, 50.62it/s, loss=2.71e+3]

Epoch: 127


800it [00:15, 50.71it/s, loss=2.7e+3] 
7it [00:00, 61.33it/s, loss=2.71e+3]

Epoch: 128


800it [00:14, 54.15it/s, loss=2.72e+3]
5it [00:00, 45.15it/s, loss=2.68e+3]

Epoch: 129


800it [00:15, 52.53it/s, loss=2.73e+3]
7it [00:00, 65.50it/s, loss=2.7e+3] 

Epoch: 130


800it [00:15, 50.47it/s, loss=2.72e+3]
7it [00:00, 61.20it/s, loss=2.7e+3] 

Epoch: 131


800it [00:16, 49.28it/s, loss=2.69e+3]
7it [00:00, 59.27it/s, loss=2.72e+3]

Epoch: 132


800it [00:16, 49.21it/s, loss=2.69e+3]
7it [00:00, 64.15it/s, loss=2.72e+3]

Epoch: 133


800it [00:15, 51.52it/s, loss=2.71e+3]
4it [00:00, 35.59it/s, loss=2.69e+3]

Epoch: 134


800it [00:16, 49.44it/s, loss=2.75e+3]
4it [00:00, 33.29it/s, loss=2.71e+3]

Epoch: 135


800it [00:15, 52.39it/s, loss=2.74e+3]
5it [00:00, 41.85it/s, loss=2.7e+3] 

Epoch: 136


800it [00:16, 49.52it/s, loss=2.72e+3]
6it [00:00, 58.78it/s, loss=2.69e+3]

Epoch: 137


800it [00:15, 50.61it/s, loss=2.73e+3]
4it [00:00, 35.65it/s, loss=2.69e+3]

Epoch: 138


800it [00:15, 51.36it/s, loss=2.73e+3]
7it [00:00, 56.73it/s, loss=2.69e+3]

Epoch: 139


800it [00:16, 49.68it/s, loss=2.73e+3]
7it [00:00, 62.52it/s, loss=2.71e+3]

Epoch: 140


800it [00:15, 52.22it/s, loss=2.73e+3]
5it [00:00, 43.81it/s, loss=2.68e+3]

Epoch: 141


800it [00:15, 52.71it/s, loss=2.71e+3]
4it [00:00, 37.37it/s, loss=2.7e+3] 

Epoch: 142


800it [00:16, 49.83it/s, loss=2.72e+3]
6it [00:00, 59.61it/s, loss=2.68e+3]

Epoch: 143


800it [00:15, 52.13it/s, loss=2.72e+3]
4it [00:00, 34.18it/s, loss=2.71e+3]

Epoch: 144


800it [00:15, 51.83it/s, loss=2.71e+3]
7it [00:00, 63.69it/s, loss=2.68e+3]

Epoch: 145


800it [00:14, 54.22it/s, loss=2.7e+3] 
7it [00:00, 64.54it/s, loss=2.69e+3]

Epoch: 146


800it [00:14, 53.68it/s, loss=2.72e+3]
7it [00:00, 63.88it/s, loss=2.72e+3]

Epoch: 147


800it [00:15, 52.27it/s, loss=2.71e+3]
7it [00:00, 61.28it/s, loss=2.71e+3]

Epoch: 148


800it [00:15, 51.27it/s, loss=2.71e+3]
5it [00:00, 49.99it/s, loss=2.72e+3]

Epoch: 149


800it [00:16, 49.50it/s, loss=2.71e+3]
4it [00:00, 34.67it/s, loss=2.69e+3]

Epoch: 150


800it [00:16, 48.50it/s, loss=2.68e+3]
4it [00:00, 32.57it/s, loss=2.72e+3]

Epoch: 151


800it [00:15, 51.31it/s, loss=2.71e+3]
4it [00:00, 35.18it/s, loss=2.72e+3]

Epoch: 152


800it [00:15, 51.15it/s, loss=2.71e+3]
5it [00:00, 44.12it/s, loss=2.68e+3]

Epoch: 153


800it [00:16, 49.49it/s, loss=2.71e+3]
6it [00:00, 52.94it/s, loss=2.7e+3] 

Epoch: 154


800it [00:15, 52.10it/s, loss=2.72e+3]
4it [00:00, 37.89it/s, loss=2.73e+3]

Epoch: 155


800it [00:15, 50.99it/s, loss=2.72e+3]
7it [00:00, 63.28it/s, loss=2.71e+3]

Epoch: 156


800it [00:16, 49.89it/s, loss=2.7e+3] 
6it [00:00, 50.58it/s, loss=2.72e+3]

Epoch: 157


800it [00:15, 50.63it/s, loss=2.71e+3]
7it [00:00, 63.44it/s, loss=2.7e+3] 

Epoch: 158


800it [00:16, 49.79it/s, loss=2.69e+3]
5it [00:00, 45.19it/s, loss=2.73e+3]

Epoch: 159


800it [00:15, 51.17it/s, loss=2.71e+3]
5it [00:00, 43.51it/s, loss=2.72e+3]

Epoch: 160


800it [00:15, 50.32it/s, loss=2.71e+3]
7it [00:00, 64.25it/s, loss=2.71e+3]

Epoch: 161


800it [00:15, 50.28it/s, loss=2.72e+3]
5it [00:00, 46.85it/s, loss=2.73e+3]

Epoch: 162


800it [00:15, 50.46it/s, loss=2.72e+3]
4it [00:00, 35.93it/s, loss=2.7e+3] 

Epoch: 163


800it [00:14, 53.72it/s, loss=2.72e+3]
5it [00:00, 45.60it/s, loss=2.69e+3]

Epoch: 164


800it [00:15, 51.59it/s, loss=2.7e+3] 
5it [00:00, 49.82it/s, loss=2.7e+3] 

Epoch: 165


800it [00:16, 49.38it/s, loss=2.73e+3]
5it [00:00, 48.42it/s, loss=2.7e+3] 

Epoch: 166


800it [00:15, 50.58it/s, loss=2.72e+3]
5it [00:00, 49.69it/s, loss=2.72e+3]

Epoch: 167


800it [00:15, 52.30it/s, loss=2.71e+3]
7it [00:00, 64.24it/s, loss=2.71e+3]

Epoch: 168


800it [00:16, 47.90it/s, loss=2.7e+3] 
7it [00:00, 62.76it/s, loss=2.71e+3]

Epoch: 169


800it [00:15, 52.15it/s, loss=2.69e+3]
5it [00:00, 39.59it/s, loss=2.73e+3]

Epoch: 170


800it [00:15, 50.33it/s, loss=2.72e+3]
7it [00:00, 63.36it/s, loss=2.72e+3]

Epoch: 171


800it [00:14, 54.64it/s, loss=2.7e+3] 
6it [00:00, 54.91it/s, loss=2.7e+3] 

Epoch: 172


800it [00:15, 51.64it/s, loss=2.7e+3] 
4it [00:00, 35.42it/s, loss=2.68e+3]

Epoch: 173


800it [00:16, 49.64it/s, loss=2.72e+3]
6it [00:00, 56.93it/s, loss=2.75e+3]

Epoch: 174


800it [00:15, 50.32it/s, loss=2.71e+3]
7it [00:00, 62.51it/s, loss=2.73e+3]

Epoch: 175


800it [00:15, 51.17it/s, loss=2.69e+3]
7it [00:00, 62.53it/s, loss=2.69e+3]

Epoch: 176


800it [00:15, 52.11it/s, loss=2.72e+3]
6it [00:00, 59.62it/s, loss=2.73e+3]

Epoch: 177


800it [00:16, 49.06it/s, loss=2.69e+3]
5it [00:00, 43.94it/s, loss=2.72e+3]

Epoch: 178


800it [00:16, 49.67it/s, loss=2.73e+3]
7it [00:00, 63.52it/s, loss=2.7e+3] 

Epoch: 179


800it [00:15, 51.03it/s, loss=2.69e+3]
7it [00:00, 63.68it/s, loss=2.7e+3] 

Epoch: 180


800it [00:15, 53.07it/s, loss=2.72e+3]
7it [00:00, 61.75it/s, loss=2.7e+3] 

Epoch: 181


800it [00:16, 49.99it/s, loss=2.73e+3]
6it [00:00, 46.68it/s, loss=2.68e+3]

Epoch: 182


800it [00:14, 53.41it/s, loss=2.68e+3]
6it [00:00, 58.62it/s, loss=2.7e+3] 

Epoch: 183


800it [00:15, 50.75it/s, loss=2.71e+3]
4it [00:00, 38.88it/s, loss=2.68e+3]

Epoch: 184


800it [00:15, 51.84it/s, loss=2.72e+3]
5it [00:00, 46.46it/s, loss=2.71e+3]

Epoch: 185


800it [00:15, 52.63it/s, loss=2.71e+3]
7it [00:00, 64.71it/s, loss=2.69e+3]

Epoch: 186


800it [00:16, 49.77it/s, loss=2.74e+3]
7it [00:00, 58.44it/s, loss=2.69e+3]

Epoch: 187


800it [00:16, 49.53it/s, loss=2.74e+3]
5it [00:00, 48.42it/s, loss=2.7e+3] 

Epoch: 188


800it [00:16, 49.49it/s, loss=2.72e+3]
6it [00:00, 55.31it/s, loss=2.72e+3]

Epoch: 189


800it [00:15, 52.78it/s, loss=2.69e+3]
7it [00:00, 62.85it/s, loss=2.72e+3]

Epoch: 190


800it [00:15, 51.21it/s, loss=2.72e+3]
7it [00:00, 64.10it/s, loss=2.74e+3]

Epoch: 191


800it [00:14, 55.00it/s, loss=2.72e+3]
7it [00:00, 64.11it/s, loss=2.71e+3]

Epoch: 192


800it [00:14, 53.68it/s, loss=2.7e+3] 
4it [00:00, 38.06it/s, loss=2.71e+3]

Epoch: 193


800it [00:15, 52.42it/s, loss=2.72e+3]
7it [00:00, 61.06it/s, loss=2.72e+3]

Epoch: 194


800it [00:15, 51.64it/s, loss=2.7e+3] 
4it [00:00, 38.69it/s, loss=2.69e+3]

Epoch: 195


800it [00:15, 52.16it/s, loss=2.72e+3]
7it [00:00, 61.18it/s, loss=2.7e+3] 

Epoch: 196


800it [00:16, 49.55it/s, loss=2.7e+3] 
5it [00:00, 46.89it/s, loss=2.7e+3] 

Epoch: 197


800it [00:15, 50.97it/s, loss=2.71e+3]
5it [00:00, 48.91it/s, loss=2.71e+3]

Epoch: 198


800it [00:15, 50.01it/s, loss=2.69e+3]
5it [00:00, 49.96it/s, loss=2.7e+3] 

Epoch: 199


800it [00:15, 52.70it/s, loss=2.75e+3]


### saving the model

In [21]:
torch.save(model.state_dict(), 'models/model_1')

### loading the model

In [30]:
model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE)
model.load_state_dict(torch.load('models/model_1'))
model.eval()

VariationalAutoEncoder(
  (enc1): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (fc1): Linear(in_features=576, out_features=128, bias=True)
  (fc_mu): Linear(in_features=128, out_features=3, bias=True)
  (fc_log_var): Linear(in_features=128, out_features=3, bias=True)
  (fc2): Linear(in_features=3, out_features=64, bias=True)
  (dec1): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(1, 1))
  (dec2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec4): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec5): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec6): ConvTranspose2d(1, 1, ker

### generating images

In [39]:
def inference(num_examples=20):
    idx = np.random.choice(range(0,800),num_examples,replace=True) # sampling from dataset
    for example,i in zip(range(num_examples),idx):
        image = dataset[i][0]
        with torch.no_grad():
            mu, sigma = model.encode(image.view(1,64,64)) # getting encoding of each image
        epsilon = torch.randn_like(sigma)
        z = mu + sigma*epsilon
        out = model.decode(z) # generating new image from the encoding
        out = out.view(-1, 1, 64, 64)
        save_image(out, f'output/img_{i}.png')

In [40]:
inference()