In [1]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn
from ipynb.fs.full.model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image  
from torch.utils.data import DataLoader, RandomSampler
import numpy as np
import matplotlib.pyplot as plt

### configuration

In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_DIM = 256
INIT_DIM = 8
LATENT_DIM = 3
NUM_EPOCHS = 500
BATCH_SIZE = 1
LR_RATE = 3e-4
KERNEL_SIZE = 4

### loading dataset

In [3]:
# import os

# samples = np.random.choice(os.listdir('dataset/0'),400,replace=False)
# for sample in samples:
#     os.rename('dataset/0/'+sample, 'dataset2/0/'+sample)

In [4]:
# Dataset Loading
data_path = 'dataset' # setting path
transform = transforms.Compose([transforms.Resize((INPUT_DIM, INPUT_DIM)),   # sequence of transformations to be done
                                transforms.Grayscale(num_output_channels=1), # on each image (resize, greyscale,
                                transforms.ToTensor()])                      # convert to tensor)

dataset = datasets.ImageFolder(root=data_path, transform=transform) # read data from folder

train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True) # create dataloader object

model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE).to(DEVICE) # initializing model object

optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE) # defining optimizer
loss_fn = nn.BCELoss(reduction='sum') # define loss function

### training model

In [5]:
# Start Training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    print(f'Epoch: {epoch}')
    for i, (x, _) in loop:
        # forward pass
        x = x.to(DEVICE).view(1, INPUT_DIM, INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)
        
        # compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        
        # backpropagation
        beta = min(1, epoch/(NUM_EPOCHS/2))
        loss = reconstruction_loss + (beta * kl_div)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())
    print(f'Loss: {loss.item()}')

3it [00:00, 26.90it/s, loss=4.5e+4]

Epoch: 0


800it [00:21, 37.99it/s, loss=4.47e+4]


Loss: 44685.2421875


4it [00:00, 31.21it/s, loss=4.47e+4]

Epoch: 1


800it [00:24, 32.87it/s, loss=4.44e+4]


Loss: 44442.671875


4it [00:00, 35.36it/s, loss=4.45e+4]

Epoch: 2


800it [00:23, 34.23it/s, loss=4.44e+4]


Loss: 44375.87109375


4it [00:00, 38.76it/s, loss=4.42e+4]

Epoch: 3


800it [00:23, 33.56it/s, loss=4.39e+4]


Loss: 43880.55859375


3it [00:00, 29.12it/s, loss=4.38e+4]

Epoch: 4


800it [00:25, 31.95it/s, loss=4.36e+4]


Loss: 43563.109375


4it [00:00, 35.66it/s, loss=4.33e+4]

Epoch: 5


800it [00:25, 31.67it/s, loss=4.25e+4]


Loss: 42539.171875


4it [00:00, 34.78it/s, loss=4.35e+4]

Epoch: 6


800it [00:25, 31.87it/s, loss=4.26e+4]


Loss: 42604.25390625


4it [00:00, 34.42it/s, loss=4.29e+4]

Epoch: 7


800it [00:25, 31.01it/s, loss=4.29e+4]


Loss: 42854.35546875


3it [00:00, 29.90it/s, loss=4.25e+4]

Epoch: 8


800it [00:26, 30.73it/s, loss=4.21e+4]


Loss: 42052.94140625


4it [00:00, 29.97it/s, loss=4.26e+4]

Epoch: 9


800it [00:25, 31.47it/s, loss=4.27e+4]


Loss: 42694.77734375


4it [00:00, 33.54it/s, loss=4.19e+4]

Epoch: 10


800it [00:26, 30.45it/s, loss=4.21e+4]


Loss: 42086.40625


4it [00:00, 34.69it/s, loss=4.24e+4]

Epoch: 11


800it [00:25, 31.11it/s, loss=4.27e+4]


Loss: 42695.7890625


4it [00:00, 31.93it/s, loss=4.21e+4]

Epoch: 12


800it [00:26, 30.47it/s, loss=4.18e+4]


Loss: 41836.296875


4it [00:00, 34.84it/s, loss=4.18e+4]

Epoch: 13


800it [00:26, 30.34it/s, loss=4.21e+4]


Loss: 42102.28125


3it [00:00, 29.22it/s, loss=4.21e+4]

Epoch: 14


800it [00:26, 30.49it/s, loss=4.21e+4]


Loss: 42057.6015625


3it [00:00, 27.48it/s, loss=4.23e+4]

Epoch: 15


800it [00:25, 30.97it/s, loss=4.24e+4]


Loss: 42412.24609375


3it [00:00, 25.74it/s, loss=4.23e+4]

Epoch: 16


800it [00:26, 30.71it/s, loss=4.18e+4]


Loss: 41772.2109375


4it [00:00, 31.67it/s, loss=4.19e+4]

Epoch: 17


800it [00:26, 30.41it/s, loss=4.21e+4]


Loss: 42131.375


4it [00:00, 30.82it/s, loss=4.19e+4]

Epoch: 18


800it [00:26, 30.39it/s, loss=4.18e+4]


Loss: 41757.015625


3it [00:00, 27.30it/s, loss=4.21e+4]

Epoch: 19


800it [00:26, 30.43it/s, loss=4.2e+4] 


Loss: 42021.96875


4it [00:00, 31.38it/s, loss=4.17e+4]

Epoch: 20


800it [00:26, 29.89it/s, loss=4.2e+4] 


Loss: 42032.984375


4it [00:00, 31.87it/s, loss=4.22e+4]

Epoch: 21


800it [00:25, 31.47it/s, loss=4.22e+4]


Loss: 42201.88671875


3it [00:00, 25.48it/s, loss=4.2e+4] 

Epoch: 22


800it [00:26, 29.94it/s, loss=4.29e+4]


Loss: 42901.7265625


4it [00:00, 32.49it/s, loss=4.2e+4] 

Epoch: 23


800it [00:26, 30.23it/s, loss=4.17e+4]


Loss: 41656.22265625


4it [00:00, 31.08it/s, loss=4.2e+4] 

Epoch: 24


800it [00:26, 29.75it/s, loss=4.2e+4] 


Loss: 41963.578125


3it [00:00, 26.52it/s, loss=4.19e+4]

Epoch: 25


800it [00:26, 29.90it/s, loss=4.23e+4]


Loss: 42346.25390625


4it [00:00, 29.81it/s, loss=4.18e+4]

Epoch: 26


800it [00:26, 30.35it/s, loss=4.25e+4]


Loss: 42476.40625


3it [00:00, 26.82it/s, loss=4.17e+4]

Epoch: 27


800it [00:26, 29.80it/s, loss=4.23e+4]


Loss: 42348.36328125


4it [00:00, 32.81it/s, loss=4.17e+4]

Epoch: 28


800it [00:26, 30.40it/s, loss=4.21e+4]


Loss: 42090.1796875


3it [00:00, 29.80it/s, loss=4.19e+4]

Epoch: 29


800it [00:27, 29.41it/s, loss=4.16e+4]


Loss: 41622.21875


3it [00:00, 29.35it/s, loss=4.21e+4]

Epoch: 30


800it [00:26, 30.52it/s, loss=4.17e+4]


Loss: 41674.61328125


4it [00:00, 29.73it/s, loss=4.2e+4]

Epoch: 31


800it [00:26, 29.96it/s, loss=4.19e+4]


Loss: 41879.19140625


4it [00:00, 33.14it/s, loss=4.27e+4]

Epoch: 32


800it [00:26, 29.69it/s, loss=4.16e+4]


Loss: 41643.6640625


4it [00:00, 32.57it/s, loss=4.14e+4]

Epoch: 33


800it [00:26, 30.26it/s, loss=4.18e+4]


Loss: 41791.98828125


4it [00:00, 32.07it/s, loss=4.18e+4]

Epoch: 34


800it [00:26, 29.67it/s, loss=4.22e+4]


Loss: 42179.04296875


3it [00:00, 29.24it/s, loss=4.18e+4]

Epoch: 35


800it [00:26, 30.04it/s, loss=4.21e+4]


Loss: 42052.24609375


4it [00:00, 30.95it/s, loss=4.2e+4] 

Epoch: 36


800it [00:26, 29.94it/s, loss=4.18e+4]


Loss: 41783.76953125


3it [00:00, 26.02it/s, loss=4.17e+4]

Epoch: 37


800it [00:26, 29.90it/s, loss=4.17e+4]


Loss: 41686.90625


3it [00:00, 28.98it/s, loss=4.15e+4]

Epoch: 38


800it [00:26, 29.82it/s, loss=4.18e+4]


Loss: 41801.203125


3it [00:00, 24.44it/s, loss=4.19e+4]

Epoch: 39


800it [00:26, 29.73it/s, loss=4.17e+4]


Loss: 41743.14453125


3it [00:00, 26.88it/s, loss=4.18e+4]

Epoch: 40


800it [00:26, 30.13it/s, loss=4.2e+4] 


Loss: 42014.18359375


4it [00:00, 32.51it/s, loss=4.16e+4]

Epoch: 41


800it [00:27, 29.34it/s, loss=4.15e+4]


Loss: 41545.28125


4it [00:00, 32.41it/s, loss=4.19e+4]

Epoch: 42


800it [00:27, 29.34it/s, loss=4.16e+4]


Loss: 41590.2890625


4it [00:00, 32.37it/s, loss=4.18e+4]

Epoch: 43


800it [00:27, 29.19it/s, loss=4.19e+4]


Loss: 41914.20703125


4it [00:00, 32.86it/s, loss=4.18e+4]

Epoch: 44


800it [00:27, 29.27it/s, loss=4.2e+4] 


Loss: 41980.7734375


3it [00:00, 27.05it/s, loss=4.17e+4]

Epoch: 45


800it [00:27, 29.38it/s, loss=4.21e+4]


Loss: 42124.484375


4it [00:00, 32.63it/s, loss=4.15e+4]

Epoch: 46


800it [00:26, 29.74it/s, loss=41662.0]


Loss: 41662.0


4it [00:00, 32.95it/s, loss=4.21e+4]

Epoch: 47


800it [00:27, 29.51it/s, loss=4.15e+4]


Loss: 41545.21484375


4it [00:00, 30.28it/s, loss=4.18e+4]

Epoch: 48


800it [00:27, 29.45it/s, loss=4.16e+4]


Loss: 41576.96484375


4it [00:00, 31.75it/s, loss=4.2e+4] 

Epoch: 49


800it [00:27, 29.45it/s, loss=4.16e+4]


Loss: 41576.52734375


4it [00:00, 31.86it/s, loss=4.14e+4]

Epoch: 50


800it [00:27, 29.15it/s, loss=4.17e+4]


Loss: 41680.58984375


3it [00:00, 25.66it/s, loss=4.19e+4]

Epoch: 51


800it [00:27, 29.04it/s, loss=4.22e+4]


Loss: 42213.875


4it [00:00, 31.45it/s, loss=4.19e+4]

Epoch: 52


800it [00:27, 29.12it/s, loss=4.18e+4]


Loss: 41848.90234375


4it [00:00, 30.57it/s, loss=4.16e+4]

Epoch: 53


800it [00:27, 29.21it/s, loss=4.15e+4]


Loss: 41453.15234375


3it [00:00, 28.38it/s, loss=4.15e+4]

Epoch: 54


800it [00:27, 28.78it/s, loss=4.15e+4]


Loss: 41540.375


3it [00:00, 25.88it/s, loss=4.14e+4]

Epoch: 55


800it [00:27, 28.77it/s, loss=4.18e+4]


Loss: 41761.41015625


3it [00:00, 29.56it/s, loss=4.2e+4] 

Epoch: 56


800it [00:27, 29.19it/s, loss=4.15e+4]


Loss: 41523.00390625


3it [00:00, 25.26it/s, loss=4.15e+4]

Epoch: 57


800it [00:28, 28.51it/s, loss=4.2e+4] 


Loss: 42042.31640625


4it [00:00, 31.57it/s, loss=4.15e+4]

Epoch: 58


800it [00:28, 28.20it/s, loss=4.18e+4]


Loss: 41804.0078125


4it [00:00, 31.71it/s, loss=4.16e+4]

Epoch: 59


800it [00:28, 28.57it/s, loss=4.2e+4] 


Loss: 42040.43359375


3it [00:00, 25.78it/s, loss=4.17e+4]

Epoch: 60


800it [00:28, 28.57it/s, loss=4.27e+4]


Loss: 42670.07421875


3it [00:00, 25.38it/s, loss=4.16e+4]

Epoch: 61


800it [00:28, 28.33it/s, loss=4.19e+4]


Loss: 41944.6328125


4it [00:00, 31.95it/s, loss=4.16e+4]

Epoch: 62


800it [00:27, 28.91it/s, loss=4.23e+4]


Loss: 42266.48046875


4it [00:00, 31.23it/s, loss=4.16e+4]

Epoch: 63


800it [00:28, 28.34it/s, loss=4.19e+4]


Loss: 41947.8984375


4it [00:00, 30.48it/s, loss=4.2e+4] 

Epoch: 64


800it [00:28, 27.67it/s, loss=4.15e+4]


Loss: 41520.4609375


4it [00:00, 29.55it/s, loss=4.14e+4]

Epoch: 65


800it [00:28, 28.45it/s, loss=4.2e+4] 


Loss: 42037.07421875


4it [00:00, 30.35it/s, loss=4.14e+4]

Epoch: 66


800it [00:28, 28.07it/s, loss=4.22e+4]


Loss: 42178.6875


3it [00:00, 29.52it/s, loss=4.15e+4]

Epoch: 67


800it [00:28, 27.90it/s, loss=4.19e+4]


Loss: 41877.05078125


3it [00:00, 26.72it/s, loss=4.18e+4]

Epoch: 68


800it [00:29, 27.55it/s, loss=4.18e+4]


Loss: 41750.953125


3it [00:00, 25.50it/s, loss=4.18e+4]

Epoch: 69


800it [00:28, 27.72it/s, loss=4.15e+4]


Loss: 41487.4765625


3it [00:00, 28.02it/s, loss=4.16e+4]

Epoch: 70


800it [00:29, 27.11it/s, loss=4.17e+4]


Loss: 41697.2421875


3it [00:00, 29.97it/s, loss=4.21e+4]

Epoch: 71


800it [00:29, 27.53it/s, loss=4.19e+4]


Loss: 41893.125


3it [00:00, 28.60it/s, loss=4.17e+4]

Epoch: 72


800it [00:29, 27.02it/s, loss=4.19e+4]


Loss: 41915.7890625


3it [00:00, 29.12it/s, loss=4.16e+4]

Epoch: 73


800it [00:29, 27.02it/s, loss=4.18e+4]


Loss: 41778.8828125


3it [00:00, 28.33it/s, loss=4.17e+4]

Epoch: 74


800it [00:29, 26.79it/s, loss=4.16e+4]


Loss: 41646.17578125


3it [00:00, 23.75it/s, loss=4.15e+4]

Epoch: 75


800it [00:30, 26.37it/s, loss=4.16e+4]


Loss: 41627.32421875


3it [00:00, 24.10it/s, loss=4.18e+4]

Epoch: 76


800it [00:30, 26.63it/s, loss=4.14e+4]


Loss: 41377.0078125


3it [00:00, 24.18it/s, loss=4.16e+4]

Epoch: 77


800it [00:30, 25.90it/s, loss=4.14e+4]


Loss: 41373.66796875


3it [00:00, 27.63it/s, loss=4.15e+4]

Epoch: 78


800it [00:31, 25.68it/s, loss=4.19e+4]


Loss: 41927.9453125


3it [00:00, 25.19it/s, loss=4.2e+4] 

Epoch: 79


800it [00:31, 25.60it/s, loss=4.19e+4]


Loss: 41859.34375


3it [00:00, 26.35it/s, loss=4.15e+4]

Epoch: 80


800it [00:31, 25.37it/s, loss=4.13e+4]


Loss: 41313.0546875


3it [00:00, 26.05it/s, loss=4.23e+4]

Epoch: 81


800it [00:31, 25.30it/s, loss=4.18e+4]


Loss: 41750.9921875


3it [00:00, 28.02it/s, loss=4.19e+4]

Epoch: 82


800it [00:32, 24.90it/s, loss=4.15e+4]


Loss: 41464.078125


3it [00:00, 27.51it/s, loss=4.19e+4]

Epoch: 83


800it [00:32, 24.94it/s, loss=4.17e+4]


Loss: 41747.18359375


3it [00:00, 21.50it/s, loss=4.18e+4]

Epoch: 84


800it [00:31, 25.05it/s, loss=4.18e+4]


Loss: 41752.39453125


3it [00:00, 26.23it/s, loss=4.15e+4]

Epoch: 85


800it [00:32, 24.72it/s, loss=4.18e+4]


Loss: 41800.0703125


3it [00:00, 27.03it/s, loss=4.17e+4]

Epoch: 86


800it [00:33, 24.14it/s, loss=4.14e+4]


Loss: 41384.546875


3it [00:00, 24.09it/s, loss=4.13e+4]

Epoch: 87


800it [00:32, 24.48it/s, loss=4.16e+4]


Loss: 41596.0625


3it [00:00, 20.79it/s, loss=4.16e+4]

Epoch: 88


800it [00:32, 24.90it/s, loss=4.27e+4]


Loss: 42712.37109375


3it [00:00, 21.05it/s, loss=4.17e+4]

Epoch: 89


800it [00:32, 24.57it/s, loss=4.17e+4]


Loss: 41740.30859375


3it [00:00, 22.24it/s, loss=4.17e+4]

Epoch: 90


800it [00:33, 24.14it/s, loss=4.19e+4]


Loss: 41896.30078125


3it [00:00, 22.40it/s, loss=4.17e+4]

Epoch: 91


800it [00:33, 24.00it/s, loss=4.16e+4]


Loss: 41618.6953125


3it [00:00, 25.88it/s, loss=4.18e+4]

Epoch: 92


800it [00:33, 23.68it/s, loss=4.14e+4]


Loss: 41432.60546875


3it [00:00, 25.99it/s, loss=4.17e+4]

Epoch: 93


800it [00:33, 23.96it/s, loss=4.15e+4]


Loss: 41495.59765625


3it [00:00, 24.99it/s, loss=41893.0]

Epoch: 94


800it [00:33, 23.70it/s, loss=4.18e+4]


Loss: 41831.07421875


3it [00:00, 22.47it/s, loss=4.16e+4]

Epoch: 95


800it [00:34, 23.45it/s, loss=4.18e+4]


Loss: 41808.22265625


3it [00:00, 24.67it/s, loss=4.18e+4]

Epoch: 96


800it [00:33, 23.92it/s, loss=41939.0]


Loss: 41939.0


3it [00:00, 25.62it/s, loss=4.14e+4]

Epoch: 97


800it [00:33, 23.55it/s, loss=4.2e+4] 


Loss: 41951.01953125


3it [00:00, 22.05it/s, loss=4.17e+4]

Epoch: 98


800it [00:34, 23.13it/s, loss=4.15e+4]


Loss: 41451.3671875


3it [00:00, 25.10it/s, loss=4.15e+4]

Epoch: 99


800it [00:34, 23.21it/s, loss=4.2e+4] 

Loss: 41984.109375





### saving the model

In [8]:
torch.save(model.state_dict(), 'models/test_overfitting')

### loading the model

In [9]:
model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE)
model.load_state_dict(torch.load('models/test_overfitting'))
model.eval()

VariationalAutoEncoder(
  (enc1): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (fc1): Linear(in_features=14400, out_features=128, bias=True)
  (fc_mu): Linear(in_features=128, out_features=3, bias=True)
  (fc_log_var): Linear(in_features=128, out_features=3, bias=True)
  (fc2): Linear(in_features=3, out_features=256, bias=True)
  (dec1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(1, 1))
  (dec2): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec4): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec5): ConvTranspose2d(32, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec6): ConvTranspose2d