# pytorch_VAE.ipynb
# WESmith 07/15/23
## Variational Autoencoder (VAE)
## reference:
## https://github.com/rasbt/stat453-deep-learning-ss21/blob/main/L17/1_VAE_mnist_sigmoid_mse.ipynb


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import os
from torchsummary import summary

In [None]:
# Index of the GPU to use when CUDA is available; must exist before DEVICE is built.
CUDA_DEVICE_NUM = 0
DEVICE = torch.device(f'cuda:{CUDA_DEVICE_NUM}' if torch.cuda.is_available() else 'cpu')
# NOTE(review): no cell shown here moves models or tensors to DEVICE, so the
# computations below run on the CPU regardless of this value.
print('Device:', DEVICE)

In [None]:
# Convert each MNIST sample to a float tensor. The decoders end in a Sigmoid and
# are trained against these tensors, so targets are expected to lie in [0, 1]
# (the range ToTensor produces for 8-bit images).
transform = transforms.ToTensor()

In [None]:
# --- optimization settings: mini-batch size, Adam learning rate, Adam weight decay ---
batch_size, lr, decay = 64, 1e-3, 1e-5

# --- dataset location and checkpoint files ---
data_dir = 'data'
model_path, optimizer_path = 'results/model_VAE.pth', 'results/optimizer_VAE.pth'
# checkpoints for the WS-modified model; the encode-size cells further down rebind these names
model_path_2, optimizer_path_2 = 'results/model_VAE_2.pth', 'results/optimizer_VAE_2.pth'

In [None]:
# MNIST training split, stored under data_dir (downloaded on first use)
mnist_data = datasets.MNIST(data_dir, train=True, transform=transform, download=True)

In [None]:
# reshuffled mini-batches of (image, label) pairs
data_loader = torch.utils.data.DataLoader(mnist_data, batch_size, shuffle=True)

In [None]:
# get a data batch for demonstrations
dataiter = iter(data_loader)
images, labels = next(dataiter)

In [None]:
images.shape  # presumably (batch_size, 1, 28, 28) — the models below assume 1x28x28 inputs; confirm

In [None]:
class Reshape(nn.Module):
    """Reshape the input with ``Tensor.view`` to a shape fixed at construction.

    Example: ``Reshape(-1, 64, 7, 7)`` turns a flat batch of 3136 features
    into 64x7x7 feature maps.
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args  # the target shape, exactly as passed in

    def forward(self, x):
        target_shape = self.shape
        return x.view(target_shape)

class Trim(nn.Module):
    """Crop dims 2 and 3 (H and W of an NCHW batch) to at most 28 each.

    The decoder's last transposed convolution yields 29x29 maps; this keeps the
    top-left 28x28 window. Constructor arguments are accepted but ignored.
    """

    def __init__(self, *args):
        super().__init__()

    def forward(self, x):
        keep = slice(None, 28)
        return x[:, :, keep, keep]

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder with a 2-D latent space.

    Built for single-channel 28x28 inputs. The encoder downsamples
    28 -> 14 -> 7 with two stride-2 convolutions and flattens the 64x7x7 map
    to 3136 features. Two linear heads then give the latent mean and
    log-variance. The decoder mirrors this with transposed convolutions,
    trims the 29x29 result to 28x28 and squashes it into (0, 1) with a Sigmoid.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
                nn.Conv2d(1, 32, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),   # 28 -> 14
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),   # 14 -> 7
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.Flatten(),)

        # 3136 = 64 channels * 7 * 7 flattened encoder features
        self.z_mean    = nn.Linear(3136, 2)
        self.z_log_var = nn.Linear(3136, 2)

        self.decoder = nn.Sequential(
                nn.Linear(2, 3136),
                Reshape(-1, 64, 7, 7),
                nn.ConvTranspose2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),  # 7 -> 7
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),  # 7 -> 13
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 32, stride=(2, 2), kernel_size=(3, 3), padding=0),  # 13 -> 27
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(32, 1, stride=(1, 1), kernel_size=(3, 3), padding=0),   # 27 -> 29
                Trim(),  # 1x29x29 -> 1x28x28
                nn.Sigmoid())

    def reparameterize(self, z_mu, z_log_var):
        """Sample z = z_mu + eps * exp(z_log_var / 2) with eps ~ N(0, I).

        The noise is drawn with ``torch.randn_like`` so it matches z_mu's
        shape, dtype and device (e.g. stays on the GPU when the model does).
        """
        eps = torch.randn_like(z_mu)
        # exp(log_var / 2) is the standard deviation
        z = z_mu + eps * torch.exp(z_log_var / 2.)
        return z

    def forward(self, x):
        """Encode, sample a latent code and decode.

        Returns:
            (encoded, z_mean, z_log_var, decoded), where ``encoded`` is the
            sampled latent code and ``decoded`` the reconstruction.
        """
        x = self.encoder(x)
        z_mean, z_log_var = self.z_mean(x), self.z_log_var(x)
        encoded = self.reparameterize(z_mean, z_log_var)
        decoded = self.decoder(encoded)
        return encoded, z_mean, z_log_var, decoded

In [None]:
class VAE2(nn.Module):  # WS mod to increase autoencoding layer size
    """Convolutional VAE like ``VAE`` but with a configurable latent size.

    Args:
        encode_size: dimension of the latent space (``VAE`` fixes it at 2).

    The encoder/decoder layers are the same as in ``VAE`` and assume
    single-channel 28x28 inputs (3136 = 64 * 7 * 7 flattened features).
    """

    def __init__(self, encode_size):  # WS mod
        super().__init__()
        self.encode_size = encode_size  # WS mod

        self.encoder = nn.Sequential(
                nn.Conv2d(1, 32, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),   # 28 -> 14
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),   # 14 -> 7
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.Flatten(),)

        self.z_mean    = nn.Linear(3136, encode_size)  # WS mod
        self.z_log_var = nn.Linear(3136, encode_size)  # WS mod

        self.decoder = nn.Sequential(
                nn.Linear(encode_size, 3136),  # WS mod
                Reshape(-1, 64, 7, 7),
                nn.ConvTranspose2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),  # 7 -> 7
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),  # 7 -> 13
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 32, stride=(2, 2), kernel_size=(3, 3), padding=0),  # 13 -> 27
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(32, 1, stride=(1, 1), kernel_size=(3, 3), padding=0),   # 27 -> 29
                Trim(),  # 1x29x29 -> 1x28x28
                nn.Sigmoid())

    def reparameterize(self, z_mu, z_log_var):
        """Sample z = z_mu + eps * exp(z_log_var / 2) with eps ~ N(0, I).

        ``torch.randn_like`` keeps the noise on z_mu's device and dtype.
        """
        eps = torch.randn_like(z_mu)
        z = z_mu + eps * torch.exp(z_log_var / 2.)
        return z

    def forward(self, x):
        """Return (encoded, z_mean, z_log_var, decoded) for an input batch."""
        x = self.encoder(x)
        z_mean, z_log_var = self.z_mean(x), self.z_log_var(x)
        encoded = self.reparameterize(z_mean, z_log_var)
        decoded = self.decoder(encoded)
        return encoded, z_mean, z_log_var, decoded

In [None]:
def train(model, optimizer, loss_fn, data_loader, n_epochs=1, save=False):
    """Train a VAE by minimizing reconstruction loss plus KL divergence.

    Args:
        model: module called as ``model(img)`` that returns
            ``(encoded, z_mean, z_log_var, decoded)``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: element-wise reconstruction loss, called as
            ``loss_fn(decoded, img, reduction='none')`` (e.g. ``F.mse_loss``).
        data_loader: iterable of ``(img, label)`` batches; labels are ignored.
        n_epochs: number of passes over ``data_loader``.
        save: ``(model_path, optimizer_path)`` to checkpoint both state dicts
            after every epoch (missing parent directories are created), or
            ``False`` (default) to skip saving.

    Prints the loss of every 100th batch of each epoch. At the end of each
    epoch it prints the mean batch loss of that epoch.
    """
    # feed batches to wherever the model's parameters live
    device = next(model.parameters()).device
    model.train()
    for epoch in range(n_epochs):
        loss_sum, n_batches = 0.0, 0
        for batch_idx, (img, _) in enumerate(data_loader):
            img = img.to(device)
            _, z_mean, z_log_var, decoded = model(img)
            n_img = img.size(0)

            # total loss = reconstruction loss + KL divergence, where
            # KL(N(z_mean, exp(z_log_var)) || N(0, I)) is summed over the latent dimension
            kl_div = -0.5 * torch.sum(1 + z_log_var
                                      - z_mean**2
                                      - torch.exp(z_log_var),
                                      dim=1)
            kl_div = kl_div.mean()  # average over batch dimension

            pixel_term = loss_fn(decoded, img, reduction='none')
            pixel_term = pixel_term.view(n_img, -1).sum(dim=1)  # sum over pixels
            pixel_term = pixel_term.mean()  # average over batch dimension

            loss = pixel_term + kl_div

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            loss_sum += batch_loss
            n_batches += 1
            if batch_idx % 100 == 0:
                print(f'epoch {epoch + 1} batch {batch_idx} has loss {batch_loss:.4f}')

        # an empty loader must not crash on an undefined 'loss'
        if n_batches:
            print(f'Epoch: {epoch + 1}, Loss: {loss_sum / n_batches:.4f}')
        else:
            print(f'Epoch: {epoch + 1}, no batches in data_loader')

        if save:
            for obj, path in ((model, save[0]), (optimizer, save[1])):
                parent = os.path.dirname(path)
                if parent:  # torch.save does not create missing directories
                    os.makedirs(parent, exist_ok=True)
                torch.save(obj.state_dict(), path)

In [None]:
def train_or_load(model, optimizer, load, loss_fn=None, data_loader=None, n_epochs=1, force_train=False):
    """Restore model/optimizer state from checkpoints, or train and save them.

    Args:
        model: module to restore into or train.
        optimizer: optimizer over ``model``'s parameters.
        load: ``(model_path, optimizer_path)``, used both for reading and,
            when training, for saving after every epoch.
        loss_fn, data_loader, n_epochs: forwarded to ``train``; ``loss_fn`` and
            ``data_loader`` are required whenever training happens.
        force_train: train even if both checkpoint files exist. This is useful for
            additional training of a saved model:
            - first call with force_train=False to either read in an existing
              training or start a new one
            - subsequent calls with force_train=True to further train the model

    Raises:
        ValueError: training is required but ``loss_fn`` or ``data_loader`` is None.
    """
    model_file, optimizer_file = load[0], load[1]
    if os.path.isfile(model_file) and os.path.isfile(optimizer_file) and not force_train:
        print(f'reading model from {load}')
        # map tensors onto the model's current device so checkpoints written on a
        # GPU machine can be restored on a CPU-only one (and vice versa)
        device = next(model.parameters()).device
        # NOTE: torch.load unpickles the file; only load checkpoints you trust
        model.load_state_dict(torch.load(model_file, map_location=device))
        optimizer.load_state_dict(torch.load(optimizer_file, map_location=device))
        return

    if loss_fn is None or data_loader is None:
        raise ValueError(f'no usable checkpoint at {load} (or force_train=True): '
                         'loss_fn and data_loader are required to train')
    print(f'training model, will be saved in {load}')
    train(model, optimizer, loss_fn, data_loader, n_epochs=n_epochs, save=load)

In [None]:
def examine_loss(model, data, loss_fn=None):
    '''convenience function to look at loss from a batch of data for a given model

    Runs ``model(data)`` without gradient tracking and prints the
    reconstruction loss ``loss_fn(decoded, data)``. This covers only the
    reconstruction term, not the KL term. For a VAE it varies between calls
    because the latent code is sampled.

    Args:
        model: callable returning ``(encoded, z_mean, z_log_var, decoded)``.
        data: input batch; ``data.shape[0]`` is reported as the batch size.
        loss_fn: ``loss_fn(input, target)`` returning a scalar tensor;
            defaults to ``F.mse_loss``.
    '''
    fn = F.mse_loss if loss_fn is None else loss_fn
    with torch.no_grad():  # evaluation only; no autograd graph needed
        *_, decoded = model(data)
        loss = fn(decoded, data)  # (input, target) order expected by torch losses
    print(f'LOSS over batch size of {data.shape[0]}: {loss.item():.5f}')

In [None]:
def examine_decoded(model, images, nr=6, nc=6, wid=12, hei=12):
    """Plot input images next to their reconstructions.

    Subplots are filled left-to-right in (original, decoded) pairs on an
    ``nr`` x ``nc`` grid. Plotting stops once the grid cannot hold another
    pair or the batch is exhausted.

    Args:
        model: callable returning ``(encoded, z_mean, z_log_var, decoded)``.
        images: input batch; each item is indexed as ``item[0]`` to drop the
            singleton channel dimension.
        nr, nc: subplot grid rows and columns.
        wid, hei: figure size in inches.
    """
    with torch.no_grad():  # plotting only; no autograd graph needed
        encoded, z_mean, z_log_var, decoded = model(images)
    plt.figure(figsize=(wid, hei))
    plt.gray()
    count = 1
    for i, j in zip(images, decoded):
        if count >= nr * nc: break  # no room left for another (original, decoded) pair
        plt.subplot(nr, nc, count)
        # item is (1, 28, 28) with singleton from data_loader
        plt.imshow(i[0].cpu().numpy())
        plt.subplot(nr, nc, count + 1)
        plt.imshow(j[0].cpu().numpy())  # also a singleton, so take [0]
        count += 2
    plt.tight_layout()
    plt.suptitle(f'ENCODING DIMENSION: {encoded.shape[1]}', fontsize=14)

## ORIGINAL MODEL WITH ENCODING SIZE FIXED AT 2

In [None]:
model_path, optimizer_path = 'results/model_VAE.pth', 'results/optimizer_VAE.pth'
loss_fn   = F.mse_loss  # reconstruction term; train() calls it with reduction='none'
model     = VAE()       # original architecture: latent size fixed at 2
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=decay)

In [None]:
sum(map(torch.numel, model.parameters()))  # total number of parameters

In [None]:
# reconstruction loss on the demo batch before training
examine_loss(model=model, data=images)

In [None]:
# Pass the loss and data explicitly: without them a fresh run (no checkpoints on
# disk yet) would try to train with loss_fn=None and data_loader=None.
train_or_load(model, optimizer, (model_path, optimizer_path),
              loss_fn=loss_fn, data_loader=data_loader)

In [None]:
# reconstruction loss on the same demo batch after training
examine_loss(model=model, data=images)

In [None]:
model  # display the module tree of the original (2-D latent) VAE

In [None]:
# torchsummary builds its dummy inputs on 'cuda' by default whenever a GPU is
# available; the model is never moved off the CPU here, so ask for CPU inputs.
summary(model, images.shape[1:], device='cpu')

## NEW MODEL WITH ENCODING SIZE 8

In [None]:
encode_size = 8  # latent dimension (the original VAE uses 2)
# WS mods to model: checkpoint files are named after the latent size
model_path_2, optimizer_path_2 = (f'results/{kind}_VAE_{encode_size}.pth'
                                  for kind in ('model', 'optimizer'))
model2 = VAE2(encode_size)
optimizer2 = optim.Adam(model2.parameters(), lr=lr, weight_decay=decay)

In [None]:
model_path_2, optimizer_path_2  # show the checkpoint paths derived from encode_size

In [None]:
sum(map(torch.numel, model2.parameters()))  # total number of parameters

In [None]:
# reconstruction loss on the demo batch before training
examine_loss(model=model2, data=images)

In [None]:
load = model_path_2, optimizer_path_2  # (model checkpoint, optimizer checkpoint)
train_or_load(model2, optimizer2, load, data_loader=data_loader, loss_fn=loss_fn)

In [None]:
# reconstruction loss on the same demo batch after training
examine_loss(model=model2, data=images)

## NEW MODEL WITH ENCODING SIZE 16

In [None]:
encode_size = 16  # latent dimension (the original VAE uses 2)
# WS mods to model: checkpoint files are named after the latent size
model_path_3, optimizer_path_3 = (f'results/{kind}_VAE_{encode_size}.pth'
                                  for kind in ('model', 'optimizer'))
model3 = VAE2(encode_size)
optimizer3 = optim.Adam(model3.parameters(), lr=lr, weight_decay=decay)

In [None]:
model_path_3, optimizer_path_3  # show the checkpoint paths derived from encode_size

In [None]:
sum(map(torch.numel, model3.parameters()))  # total number of parameters

In [None]:
load = model_path_3, optimizer_path_3  # (model checkpoint, optimizer checkpoint)
train_or_load(model3, optimizer3, load, data_loader=data_loader, loss_fn=loss_fn)

## EXAMINE ORIGINAL VS. DECODED IMAGES (ENCODING SIZES 2, 8, 16)

In [None]:
examine_decoded(model, images, nc=8)  # nr keeps its default of 6

In [None]:
examine_decoded(model2, images, nr=6, nc=8)

In [None]:
examine_decoded(model3, images, nr=6, nc=8)

## EXAMINE DECODED FROM LINEAR SCAN OF ENCODING 
### (only works for 2D encoding)

In [None]:
import sys  # no longer used in this cell; kept in case later cells rely on it

# Decode an nr x nc grid of latent points covering [-lim, lim] x [-lim, lim] and
# show them in a matching subplot grid: x grows left to right, y shrinks top to
# bottom. Only meaningful for a model whose latent space is 2-D.
nr, nc = (7, 7)
lim = 2
count = 1
with torch.no_grad():  # inference only; no autograd graph needed
    encoded, z_mean, z_log_var, decoded = model(images)
if encoded.shape[1] != 2:
    # raise instead of sys.exit(): SystemExit is an awkward way to stop a notebook cell
    raise ValueError(f'ERROR: this plot only works for an encoding dimension of 2, '
                     f'got {encoded.shape[1]}')
plt.figure(figsize=(20, 20))
plt.gray()
with torch.no_grad():
    for y in np.linspace(lim, -lim, num=nr):
        for x in np.linspace(-lim, lim, num=nc):
            # a (2,) latent vector decodes (via Reshape(-1, 64, 7, 7)) to a 1x1x28x28 image
            dd = model.decoder(torch.tensor((x, y), dtype=torch.float32))
            plt.subplot(nr, nc, count)
            txt = f'({x:.2f},{y:.2f})'
            plt.imshow(dd[0][0].cpu().numpy())
            plt.text(0, 0, txt, va='top', color='white', fontsize=12)
            count += 1
plt.suptitle(f'ENCODING DIMENSION: {encoded.shape[1]}', fontsize=14)
plt.show()