In [None]:
from src import *

import numpy as np
import torch
from torch import optim, nn

import matplotlib.pyplot as plt

import time
from tqdm.auto import tqdm


# CUDA

In [None]:
# Fix the seeds of the default (CPU) generator and of the current CUDA
# device's generator before anything random happens.
torch.manual_seed(1)
torch.cuda.manual_seed(1)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('PyTorch is using', device)

# Parameters

In [None]:
# --- dataset ---
dataset_name = "MNIST"   # forwarded to CustomDataset(dataset_name=...)
resize_h = 28            # target image height given to transforms.Resize
resize_w = 28            # target image width given to transforms.Resize

# --- dataloader ---
batch_size = 36          # forwarded to dataset.get_dataloader for both loaders
num_workers = 8          # forwarded to dataset.get_dataloader for both loaders
prefetch_factor = 4      # forwarded to dataset.get_dataloader for both loaders

# --- model ---
num_layers = 2           # number of entries in the `layers` width list
num_z = 1024             # second argument of VAE(layers, num_z)

# --- training ---
epochs = 30              # number of iterations of the training loop
init_lr = 1e-3           # learning rate handed to optim.Adam


# Training-related functions

In [None]:
def train(epoch,
          model, train_loader,
          criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: Index of the current epoch (accepted for a uniform signature;
            not used inside the function).
        model: Network called as ``model(batch)`` and returning
            ``(reconstruction, mu, logvar)``.
        train_loader: Iterable of ``(inputs, labels)`` batches that exposes
            a ``dataset`` attribute supporting ``len()``.
        criterion: Callable ``criterion(reconstruction, inputs, mu, logvar)``
            returning a scalar loss tensor.
        optimizer: Optimizer that updates the parameters of ``model``.

    Returns:
        The accumulated value of ``loss / len(train_loader.dataset)`` over
        all batches (a plain Python number).
    """
    model.train()
    running_loss = 0

    # Labels are ignored: the loss compares the input with its reconstruction.
    for inputs, _ in tqdm(train_loader):
        inputs = inputs.to(device)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(inputs)
        batch_loss = criterion(reconstruction, inputs, mu, logvar)

        # Each batch contributes its loss scaled by the dataset size.
        running_loss += batch_loss.item() / len(train_loader.dataset)

        batch_loss.backward()
        optimizer.step()

    return running_loss

def reconstruct(model, test_loader):
    """Reconstruct a few inputs from the first batch of ``test_loader``.

    Args:
        model: Network called as ``model(batch)`` and returning
            ``(reconstruction, mu, logvar)``.
        test_loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        A ``(samples, recons)`` pair of NumPy arrays holding at most the
        first 8 inputs of the first batch and the model's reconstructions
        of those inputs.

    Raises:
        ValueError: If ``test_loader`` yields no batch at all (previously
            this surfaced as an ``UnboundLocalError`` at the return).
    """
    model.eval()

    with torch.no_grad():
        # Only the first batch is needed, so return from inside the loop.
        for data, _ in test_loader:
            data = data.to(device)
            recon_batch, _, _ = model(data)

            n = min(data.shape[0], 8)
            samples = data[:n].cpu().numpy()
            recons = recon_batch[:n].cpu().numpy()
            return samples, recons

    raise ValueError('test_loader yielded no batches; nothing to reconstruct')

def test(epoch,
         model, test_loader,
         criterion):
    """Evaluate ``model`` on ``test_loader`` without updating its weights.

    Args:
        epoch: Index of the current epoch (accepted for a uniform signature;
            not used inside the function).
        model: Network called as ``model(batch)`` and returning
            ``(reconstruction, mu, logvar)``.
        test_loader: Iterable of ``(inputs, labels)`` batches that exposes
            a ``dataset`` attribute supporting ``len()``.
        criterion: Callable ``criterion(reconstruction, inputs, mu, logvar)``
            returning a scalar loss tensor.

    Returns:
        The accumulated value of ``loss / len(test_loader.dataset)`` over
        all batches (a plain Python number).
    """
    model.eval()
    running_loss = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            reconstruction, mu, logvar = model(inputs)
            batch_loss = criterion(reconstruction, inputs, mu, logvar)

            # Each batch contributes its loss scaled by the dataset size.
            running_loss += batch_loss.item() / len(test_loader.dataset)

    return running_loss


# Visualization-related functions

In [None]:
def to_rgb(sample):
    """Convert a channel-first float image into an ``(H, W, 3)`` uint8 array.

    Args:
        sample: Array-like whose first three entries along axis 0 are the
            red, green and blue planes. Values are expected in ``[0, 1]``;
            anything outside that range is clipped to it.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the three planes stacked
        along the last axis.
    """
    stacked = np.dstack((sample[0], sample[1], sample[2]))
    # Clip before scaling: casting a negative or >255 float to uint8 wraps
    # around (or is undefined), which would show up as speckled pixels.
    stacked = np.clip(stacked, 0.0, 1.0)
    # 255.999 maps 1.0 to 255 after truncation while keeping the bins even.
    return (stacked * 255.999).astype(np.uint8)

def visualize_imgs(samples, recons):
    """Show inputs (top row) above their reconstructions (bottom row).

    Args:
        samples: Array of shape ``(n, c, h, w)`` with ``c == 1`` (grayscale)
            or ``c == 3`` (RGB).
        recons: ``n`` reconstructions; each entry must contain ``c * h * w``
            values, either flat or already shaped ``(c, h, w)``.

    Nothing is drawn for any other channel count.
    """
    n, c, h, w = samples.shape
    if c not in (1, 3):
        # Unsupported layout: bail out before creating a figure that would
        # never be shown or closed.
        return

    def as_image(arr):
        # Use the real image size instead of a hard-coded 28x28 so the plot
        # keeps working when the resize parameters change.
        planes = arr.reshape(c, h, w)
        return to_rgb(planes) if c == 3 else planes[0]

    # Grayscale images use the inverted gray colormap; RGB needs none.
    cmap = 'gray_r' if c == 1 else None

    plt.figure(figsize=(28, 8))
    for i in range(n):
        plt.subplot(2, n, i + 1)
        plt.imshow(as_image(samples[i]), cmap=cmap)

        plt.subplot(2, n, i + 1 + n)
        plt.imshow(as_image(recons[i]), cmap=cmap)
    plt.show()
    

# Main

### Dataset

In [None]:
# Build the preprocessing pipeline under its own name. Assigning the result
# back to `transforms` would shadow the `transforms` namespace (brought in by
# `from src import *`), so running this cell a second time would fail on
# `transforms.Compose`.
data_transforms = transforms.Compose([
    transforms.Resize(size=(resize_h, resize_w)),
    transforms.ToTensor(),
])
dataset = CustomDataset(dataset_name=dataset_name, transforms=data_transforms)

# Training loader: shuffled.
train_loader = dataset.get_dataloader(is_train=True,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      prefetch_factor=prefetch_factor)
# Test loader: fixed order, so the first batch (used for the reconstruction
# previews) is the same on every pass.
test_loader = dataset.get_dataloader(is_train=False,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=num_workers,
                                     prefetch_factor=prefetch_factor)

### Model

In [None]:
# One input unit per pixel of the resized image.
input_dim = resize_h * resize_w

# Layer widths: input_dim, input_dim / 2, input_dim / 4, ... (num_layers entries).
layers = [
    int(input_dim / (2 ** depth))
    for depth in range(num_layers)
]

# Build the network and move it to the selected device.
model = VAE(layers, num_z)
model = model.to(device)
print(model)

# Loss and optimizer used by the training loop.
criterion = ELBO(input_dim)
optimizer = optim.Adam(params=model.parameters(), lr=init_lr)

In [None]:
# Learn
# `last_loss` holds the test loss of the previous epoch; it starts at the
# largest float32 so the first epoch always counts as an improvement.
last_loss = torch.finfo(torch.float32).max
total_time = 0
train_losses = []
test_losses = []

for epoch in range(1, epochs + 1):
    # generate and visualize (before training, so epoch 1 shows the
    # untrained model)
    samples, recons = reconstruct(model, test_loader)
    visualize_imgs(samples, recons)

    # train -- only this pass is timed; perf_counter is monotonic, so the
    # measurement is not affected by system clock adjustments
    start_time = time.perf_counter()
    loss = train(epoch,
                 model, train_loader,
                 criterion, optimizer)
    end_time = time.perf_counter()
    train_losses.append(loss)
    dt = end_time - start_time
    total_time += dt

    # test
    loss = test(epoch,
                model, test_loader,
                criterion)
    test_losses.append(loss)

    # '*' marks an epoch whose test loss is lower than the previous epoch's.
    print(f'Epoch {epoch} / {epochs}: {loss:.2f} in {dt:.2f} secs', '*' if loss < last_loss else '')
    # Remember this epoch's test loss; without this update the marker above
    # would be printed on every epoch.
    last_loss = loss

print('Train loss:', train_losses)
print('Test loss:', test_losses)
print(f'Average {total_time / epochs:.2f} secs per epoch consumed')
print(f'Total {total_time:.2f} secs consumed')