In [1]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn
from ipynb.fs.full.model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image  
from torch.utils.data import DataLoader, RandomSampler
import numpy as np

### configuration

In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_DIM = 256
INIT_DIM = 8
LATENT_DIM = 3
NUM_EPOCHS = 50
BATCH_SIZE = 1
LR_RATE = 3e-4
KERNEL_SIZE = 4

### loading dataset

In [3]:
# Dataset Loading
data_path = 'dataset' # setting path
transform = transforms.Compose([transforms.Resize((INPUT_DIM, INPUT_DIM)),   # sequence of transformations to be done
                                transforms.Grayscale(num_output_channels=1), # on each image (resize, greyscale,
                                transforms.ToTensor()])                      # convert to tensor)

dataset = datasets.ImageFolder(root=data_path, transform=transform) # read data from folder

train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True) # create dataloader object

model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE).to(DEVICE) # initializing model object

optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE) # defining optimizer
loss_fn = nn.BCELoss(reduction='sum') # define loss function

### training model

In [4]:
# Start Training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    print(f'Epoch: {epoch}')
    for i, (x, _) in loop:
        # forward pass
        x = x.to(DEVICE).view(1, INPUT_DIM, INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)
        
        # compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))
        
        # backpropagation
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

2it [00:00, 19.54it/s, loss=4.69e+4]

Epoch: 0


800it [00:20, 38.66it/s, loss=4.46e+4]
3it [00:00, 29.24it/s, loss=4.46e+4]

Epoch: 1


800it [00:24, 33.10it/s, loss=4.45e+4]
5it [00:00, 40.64it/s, loss=4.45e+4]

Epoch: 2


800it [00:23, 34.64it/s, loss=4.42e+4]
4it [00:00, 34.75it/s, loss=4.44e+4]

Epoch: 3


800it [00:25, 31.20it/s, loss=4.32e+4]
4it [00:00, 34.17it/s, loss=4.33e+4]

Epoch: 4


800it [00:26, 30.66it/s, loss=4.35e+4]
4it [00:00, 32.34it/s, loss=4.32e+4]

Epoch: 5


800it [00:26, 29.72it/s, loss=4.24e+4]
4it [00:00, 34.08it/s, loss=4.28e+4]

Epoch: 6


800it [00:26, 29.94it/s, loss=4.28e+4]
4it [00:00, 31.97it/s, loss=4.26e+4]

Epoch: 7


800it [00:26, 30.05it/s, loss=4.52e+4]
3it [00:00, 26.60it/s, loss=4.2e+4] 

Epoch: 8


800it [00:26, 30.27it/s, loss=42089.0]
3it [00:00, 28.75it/s, loss=4.29e+4]

Epoch: 9


800it [00:26, 30.08it/s, loss=4.2e+4] 
4it [00:00, 30.03it/s, loss=4.18e+4]

Epoch: 10


800it [00:26, 30.27it/s, loss=4.19e+4]
3it [00:00, 25.88it/s, loss=4.17e+4]

Epoch: 11


800it [00:27, 29.52it/s, loss=4.19e+4]
4it [00:00, 32.81it/s, loss=4.24e+4]

Epoch: 12


800it [00:26, 30.13it/s, loss=4.19e+4]
4it [00:00, 31.54it/s, loss=4.21e+4]

Epoch: 13


800it [00:26, 30.05it/s, loss=4.19e+4]
4it [00:00, 32.24it/s, loss=4.17e+4]

Epoch: 14


800it [00:26, 29.97it/s, loss=4.26e+4]
4it [00:00, 32.66it/s, loss=4.23e+4]

Epoch: 15


800it [00:26, 29.91it/s, loss=4.22e+4]
3it [00:00, 27.38it/s, loss=4.25e+4]

Epoch: 16


800it [00:27, 29.36it/s, loss=4.21e+4]
4it [00:00, 30.81it/s, loss=4.15e+4]

Epoch: 17


800it [00:27, 29.61it/s, loss=4.21e+4]
4it [00:00, 32.50it/s, loss=4.25e+4]

Epoch: 18


800it [00:27, 29.28it/s, loss=4.19e+4]
3it [00:00, 26.93it/s, loss=4.18e+4]

Epoch: 19


800it [00:27, 29.05it/s, loss=4.15e+4]
4it [00:00, 32.49it/s, loss=4.26e+4]

Epoch: 20


800it [00:27, 29.51it/s, loss=4.19e+4]
3it [00:00, 29.86it/s, loss=4.2e+4] 

Epoch: 21


800it [00:27, 28.94it/s, loss=4.21e+4]
3it [00:00, 27.82it/s, loss=4.19e+4]

Epoch: 22


800it [00:27, 29.57it/s, loss=4.22e+4]
3it [00:00, 26.16it/s, loss=4.2e+4] 

Epoch: 23


800it [00:27, 28.98it/s, loss=4.2e+4] 
3it [00:00, 26.94it/s, loss=4.17e+4]

Epoch: 24


800it [00:26, 29.64it/s, loss=4.19e+4]
3it [00:00, 29.68it/s, loss=4.18e+4]

Epoch: 25


800it [00:27, 28.68it/s, loss=4.2e+4] 
4it [00:00, 32.76it/s, loss=4.14e+4]

Epoch: 26


800it [00:27, 28.96it/s, loss=4.16e+4]
3it [00:00, 29.31it/s, loss=4.21e+4]

Epoch: 27


800it [00:28, 28.53it/s, loss=4.22e+4]
4it [00:00, 30.56it/s, loss=4.18e+4]

Epoch: 28


800it [00:27, 29.25it/s, loss=4.19e+4]
3it [00:00, 22.87it/s, loss=4.15e+4]

Epoch: 29


800it [00:27, 29.14it/s, loss=4.19e+4]
3it [00:00, 28.17it/s, loss=4.28e+4]

Epoch: 30


800it [00:27, 29.34it/s, loss=4.15e+4]
4it [00:00, 30.75it/s, loss=4.2e+4] 

Epoch: 31


800it [00:27, 29.09it/s, loss=4.17e+4]
3it [00:00, 29.46it/s, loss=4.2e+4] 

Epoch: 32


800it [00:27, 29.28it/s, loss=4.2e+4] 
3it [00:00, 25.75it/s, loss=4.18e+4]

Epoch: 33


800it [00:27, 28.88it/s, loss=4.2e+4] 
4it [00:00, 31.06it/s, loss=4.14e+4]

Epoch: 34


800it [00:28, 28.29it/s, loss=4.22e+4]
3it [00:00, 26.21it/s, loss=4.19e+4]

Epoch: 35


800it [00:28, 28.54it/s, loss=4.14e+4]
4it [00:00, 31.75it/s, loss=4.19e+4]

Epoch: 36


800it [00:27, 28.65it/s, loss=4.15e+4]
3it [00:00, 29.00it/s, loss=4.14e+4]

Epoch: 37


800it [00:27, 29.28it/s, loss=4.15e+4]
3it [00:00, 25.08it/s, loss=4.2e+4] 

Epoch: 38


800it [00:28, 28.47it/s, loss=4.16e+4]
4it [00:00, 31.66it/s, loss=4.17e+4]

Epoch: 39


800it [00:27, 28.97it/s, loss=4.18e+4]
3it [00:00, 28.45it/s, loss=4.18e+4]

Epoch: 40


800it [00:28, 28.50it/s, loss=4.15e+4]
4it [00:00, 32.24it/s, loss=4.2e+4] 

Epoch: 41


800it [00:28, 28.53it/s, loss=4.2e+4] 
3it [00:00, 28.96it/s, loss=4.19e+4]

Epoch: 42


800it [00:27, 28.91it/s, loss=4.17e+4]
4it [00:00, 32.23it/s, loss=4.15e+4]

Epoch: 43


800it [00:28, 28.19it/s, loss=4.14e+4]
3it [00:00, 26.63it/s, loss=4.17e+4]

Epoch: 44


800it [00:27, 28.80it/s, loss=4.21e+4]
4it [00:00, 30.67it/s, loss=41943.0]

Epoch: 45


800it [00:28, 28.39it/s, loss=4.13e+4]
4it [00:00, 30.27it/s, loss=4.19e+4]

Epoch: 46


800it [00:28, 28.50it/s, loss=4.16e+4]
4it [00:00, 31.59it/s, loss=4.14e+4]

Epoch: 47


800it [00:27, 28.64it/s, loss=4.14e+4]
4it [00:00, 31.60it/s, loss=4.18e+4]

Epoch: 48


800it [00:28, 28.34it/s, loss=4.19e+4]
4it [00:00, 31.56it/s, loss=4.17e+4]

Epoch: 49


800it [00:27, 28.70it/s, loss=4.22e+4]


### saving the model

In [5]:
torch.save(model.state_dict(), 'models/model_256x')

### loading the model

In [6]:
model = VariationalAutoEncoder(init_dim=INIT_DIM, latent_dim=LATENT_DIM, kernel_size=KERNEL_SIZE)
model.load_state_dict(torch.load('models/model_256x'))
model.eval()

VariationalAutoEncoder(
  (enc1): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc3): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (enc4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (fc1): Linear(in_features=14400, out_features=128, bias=True)
  (fc_mu): Linear(in_features=128, out_features=3, bias=True)
  (fc_log_var): Linear(in_features=128, out_features=3, bias=True)
  (fc2): Linear(in_features=3, out_features=256, bias=True)
  (dec1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(1, 1))
  (dec2): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec4): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec5): ConvTranspose2d(32, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (dec6): ConvTranspose2d

### generating images

In [7]:
def inference(num_examples=100):
    idx = np.random.choice(range(0,800),num_examples,replace=False) # sampling from dataset
    for example,i in zip(range(num_examples),idx):
        image = dataset[i][0]
        with torch.no_grad():
            mu, sigma = model.encode(image.view(1,256,256)) # getting encoding of each image
        epsilon = torch.randn_like(sigma)
        z = mu + sigma*epsilon
        out = model.decode(z) # generating new image from the encoding
        out = out.view(-1, 1, 256, 256)
        save_image(out, f'output/256x/img_{i}.png')

In [9]:
inference()