In [3]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn
from ipynb.fs.full.model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image  
from torch.utils.data import DataLoader, RandomSampler
import numpy as np
import matplotlib.pyplot as plt

### configuration

In [4]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_DIM = 256
INIT_DIM = 8
LATENT_DIM = 3
NUM_EPOCHS = 10
BATCH_SIZE = 1
LR_RATE = 3e-4
KERNEL_SIZE = 4

### loading dataset

In [5]:
# Dataset Loading
data_path = 'dataset' # setting path
transform = transforms.Compose([transforms.Resize((INPUT_DIM, INPUT_DIM)),   # sequence of transformations to be done
                                transforms.Grayscale(num_output_channels=1), # on each image (resize, greyscale,
                                transforms.ToTensor()])                      # convert to tensor)

dataset = datasets.ImageFolder(root=data_path, transform=transform) # read data from folder

train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True) # create dataloader object

model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE).to(DEVICE) # initializing model object

optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE) # defining optimizer
loss_fn = nn.BCELoss(reduction='sum') # define loss function

### training model

In [6]:
# Start Training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    print(f'Epoch: {epoch}')
    for i, (x, _) in loop:
        # forward pass
        x = x.to(DEVICE).view(1, INPUT_DIM, INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)
        
        # compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        
        # backpropagation
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

2it [00:00, 17.83it/s, loss=4.47e+4]

Epoch: 0


800it [00:20, 39.38it/s, loss=4.46e+4]
4it [00:00, 37.58it/s, loss=4.45e+4]

Epoch: 1


800it [00:22, 35.01it/s, loss=4.43e+4]
3it [00:00, 29.15it/s, loss=4.41e+4]

Epoch: 2


800it [00:24, 33.15it/s, loss=4.34e+4]
3it [00:00, 26.06it/s, loss=4.28e+4]

Epoch: 3


800it [00:26, 30.00it/s, loss=4.41e+4]
4it [00:00, 33.49it/s, loss=4.35e+4]

Epoch: 4


800it [00:26, 30.28it/s, loss=4.29e+4]
3it [00:00, 24.14it/s, loss=4.32e+4]

Epoch: 5


800it [00:27, 29.44it/s, loss=4.24e+4]
3it [00:00, 26.08it/s, loss=4.23e+4]

Epoch: 6


800it [00:27, 29.56it/s, loss=4.26e+4]
4it [00:00, 32.93it/s, loss=4.21e+4]

Epoch: 7


800it [00:27, 29.44it/s, loss=4.28e+4]
3it [00:00, 26.63it/s, loss=4.23e+4]

Epoch: 8


800it [00:27, 29.62it/s, loss=4.21e+4]
3it [00:00, 28.76it/s, loss=4.24e+4]

Epoch: 9


800it [00:27, 29.61it/s, loss=42230.5]


### saving the model

In [7]:
torch.save(model.state_dict(), 'models/model_256x_10')

### loading the model

In [4]:
model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE)
model.load_state_dict(torch.load('models/model_256x_10'))
model.eval()

VariationalAutoEncoder(
  (enc1): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (fc1): Linear(in_features=14400, out_features=128, bias=True)
  (fc_mu): Linear(in_features=128, out_features=3, bias=True)
  (fc_log_var): Linear(in_features=128, out_features=3, bias=True)
  (fc2): Linear(in_features=3, out_features=256, bias=True)
  (dec1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(1, 1))
  (dec2): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec4): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec5): ConvTranspose2d(32, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec6): ConvTranspose2d