# pytorch_VAE.ipynb
# WESmith 07/15/23
## Variational Autoencoder (VAE)
## reference:
## https://github.com/rasbt/stat453-deep-learning-ss21/blob/main/L17/1_VAE_mnist_sigmoid_mse.ipynb


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
# Index of the GPU to use when CUDA is available; it was previously
# undefined, which raised NameError on any machine with a GPU.
CUDA_DEVICE_NUM = 0

# Select that GPU if CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device(f'cuda:{CUDA_DEVICE_NUM}' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)

In [None]:
# Conversion applied to every dataset sample: turn the image into a tensor.
transform = transforms.ToTensor()

In [None]:
# Training hyperparameters and file locations used by the cells below.
batch_size       = 64    # samples per DataLoader batch
lr               = 1e-3  # Adam learning rate
decay            = 1e-5  # passed to Adam as weight_decay
data_dir         = 'data'  # root directory handed to datasets.MNIST
model_path       = 'results/model_VAE.pth'      # model state_dict checkpoint
optimizer_path   = 'results/optimizer_VAE.pth'  # optimizer state_dict checkpoint

In [None]:
# MNIST training split stored under data_dir (download requested), with
# each image passed through `transform`.
mnist_data = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)

In [None]:
# Iterates over mnist_data in shuffled batches of batch_size samples.
data_loader = torch.utils.data.DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)

In [None]:
# get example data: pull one (images, labels) batch from the loader; this
# batch is reused below for sanity checks and plots
dataiter = iter(data_loader)
images, labels = next(dataiter)

In [None]:
# Inspect the example batch shape
# (presumably (batch_size, 1, 28, 28) -- confirm from the cell output).
images.shape

In [None]:
class Reshape(nn.Module):
    """Module wrapper around ``Tensor.view`` so a reshape can sit inside
    an ``nn.Sequential``.

    The positional arguments given at construction are stored as the
    target shape and applied to every input in ``forward``.
    """

    def __init__(self, *args):
        super().__init__()
        # tuple of target dimensions, e.g. (-1, 64, 7, 7)
        self.shape = args

    def forward(self, x):
        """Return ``x`` viewed with the stored target shape."""
        target_shape = self.shape
        reshaped = x.view(target_shape)
        return reshaped

class Trim(nn.Module):
    """Crop the last two axes of a 4-D tensor to at most 28 x 28.

    Constructor arguments are accepted but ignored.
    """

    def __init__(self, *args):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with axes 2 and 3 cut down to their first 28 entries."""
        keep = slice(None, 28)
        everything = slice(None)
        return x[everything, everything, keep, keep]

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder with a 2-D latent space.

    The encoder is a stack of four convolutions (two of them stride 2)
    followed by a flatten; two linear heads map the resulting
    3136-element vector (64 x 7 x 7) to the latent mean and
    log-variance. The decoder mirrors this with transposed
    convolutions, trims the 29x29 result to 28x28 and applies a sigmoid
    so outputs lie in (0, 1).
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
                nn.Conv2d(1, 32, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.Flatten(),)

        # separate linear heads for the latent mean and log-variance
        self.z_mean    = nn.Linear(3136, 2)
        self.z_log_var = nn.Linear(3136, 2)

        self.decoder = nn.Sequential(
                nn.Linear(2, 3136),
                Reshape(-1, 64, 7, 7),
                nn.ConvTranspose2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 32, stride=(2, 2), kernel_size=(3, 3), padding=0),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(32, 1, stride=(1, 1), kernel_size=(3, 3), padding=0),
                Trim(),  # 1x29x29 -> 1x28x28
                nn.Sigmoid())

    def reparameterize(self, z_mu, z_log_var):
        """Sample a latent vector with the reparameterization trick.

        Computes ``z = z_mu + eps * exp(z_log_var / 2)`` with ``eps``
        drawn from a standard normal distribution.

        Args:
            z_mu: latent means.
            z_log_var: latent log-variances, same shape as ``z_mu``.

        Returns:
            Tensor of sampled latents with the shape of ``z_mu``.
        """
        # randn_like draws the noise on the same device and with the same
        # dtype as z_mu; a plain torch.randn(...) lands on the CPU and
        # fails with a device mismatch once the model lives on a GPU.
        eps = torch.randn_like(z_mu)
        z = z_mu + eps * torch.exp(z_log_var / 2.)
        return z

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(encoded, z_mean, z_log_var, decoded)`` where
            ``encoded`` is the sampled latent vector and ``decoded`` is
            the reconstruction produced from it.
        """
        x = self.encoder(x)
        z_mean, z_log_var = self.z_mean(x), self.z_log_var(x)
        encoded = self.reparameterize(z_mean, z_log_var)
        decoded = self.decoder(encoded)
        return encoded, z_mean, z_log_var, decoded

In [None]:
# Build the network, pick the reconstruction loss, and create the
# optimizer over all model parameters using the lr / decay set earlier.
model     = VAE()
loss_fn   = F.mse_loss  # called with reduction='none' in the training loop
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=decay)

In [None]:
# total number of scalar entries across all model parameters
count = sum(param.numel() for param in model.parameters())
count

In [None]:
# examine the model outputs: run the example batch through the model
# to get the sampled latent, its mean / log-variance and the reconstruction
encoded, z_mean, z_log_var, decoded = model(images)

In [None]:
# Shapes of the four forward-pass outputs for the example batch.
encoded.shape, z_mean.shape, z_log_var.shape, decoded.shape

In [None]:
# Mean-squared reconstruction error on the example batch. Arguments follow
# the (input, target) order of F.mse_loss and of the training loop; the
# value is the same either way because MSE is symmetric.
loss_fn(decoded, images)

In [None]:
# Resume from an earlier run when both checkpoint files are present:
# restore the model weights first, then the optimizer state. If either file
# is missing nothing is loaded and the freshly initialised objects are kept.
checkpoint_files = (model_path, optimizer_path)
if all(os.path.isfile(path) for path in checkpoint_files):
    model.load_state_dict(torch.load(model_path))
    optimizer.load_state_dict(torch.load(optimizer_path))

In [None]:
# TRAINING
# Minimises reconstruction error plus KL divergence, then checkpoints the
# model and optimizer at the end of every epoch.
n_epochs = 1  # 1 epoch took 1m 25s to run on acer
outputs = []

# make sure the checkpoint directories exist before the first save;
# torch.save does not create missing parent directories
for ckpt_path in (model_path, optimizer_path):
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(n_epochs):
    for (img, _) in data_loader:

        encoded, z_mean, z_log_var, decoded = model(img)

        # total loss = reconstruction loss + KL divergence
        # KL term per sample: -0.5 * sum(1 + log_var - mean^2 - exp(log_var))
        kl_div = -0.5 * torch.sum(1 + z_log_var
                                    - z_mean**2
                                    - torch.exp(z_log_var),
                                  dim=1) # sum over latent dimension

        batchsize = kl_div.size(0)
        kl_div    = kl_div.mean() # average over batch dimension

        pixel_term = loss_fn(decoded, img, reduction='none')
        pixel_term = pixel_term.view(batchsize, -1).sum(dim=1) # sum over pixels
        pixel_term = pixel_term.mean() # average over batch dimension

        loss = pixel_term + kl_div

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # loss reported here is that of the last batch of the epoch
    print(f'Epoch: {epoch + 1}, Loss: {loss.item():.4f}')
    # keep the last batch and its reconstruction; detach so the stored
    # tensor does not keep the autograd graph alive
    outputs.append((epoch, img, decoded.detach()))
    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optimizer_path)

In [None]:
# Re-run the example batch through the model after training; these
# results feed the comparison plot below.
encoded, z_mean, z_log_var, decoded = model(images)

In [None]:
# Shapes of the example batch and of each forward-pass output.
images.shape, encoded.shape, z_mean.shape, z_log_var.shape, decoded.shape

In [None]:
# Three-row figure for the first nc samples of the example batch:
# row 1 = input images, row 2 = reconstructions, row 3 = latent coordinates.
nc  = 20
lim = 2
plt.figure(figsize=(20,8))
plt.gray()

# row 1: originals; each item is (1, 28, 28) with a singleton channel axis
# from data_loader, so [0] selects the 2-D picture
for panel, picture in enumerate(images[:nc], start=1):
    plt.subplot(3, nc, panel)
    plt.imshow(picture[0])

# row 2: reconstructions (also a singleton channel axis, so take [0])
for panel, picture in enumerate(decoded[:nc], start=nc + 1):
    plt.subplot(3, nc, panel)
    plt.imshow(picture[0].detach().numpy())

# row 3: one scatter point per sample at its 2-D latent position
for panel, latent in enumerate(encoded[:nc], start=2 * nc + 1):
    plt.subplot(3, nc, panel)
    point = latent.detach().numpy()
    plt.scatter(point[0], point[1])
    plt.xlim(-lim, lim)
    plt.ylim(-lim, lim)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.grid()

In [None]:
# Decode a regular nr x nc grid of latent points covering [-lim, lim] on
# both axes; y runs from +lim at the top row down to -lim at the bottom.
nr, nc = (7, 7)
lim = 2
plt.figure(figsize=(20,20))
plt.gray()
grid_y = np.linspace(lim, -lim, num=nr)
grid_x = np.linspace(-lim, lim, num=nc)
count = 1
for y in grid_y:
    for x in grid_x:
        latent = torch.tensor((x, y), dtype=torch.float32)
        dd = model.decoder(latent)
        plt.subplot(nr, nc, count)
        plt.imshow(dd[0, 0].detach().numpy())
        # label each panel with the latent coordinates it was decoded from
        plt.text(0, 0, f'({x:.2f},{y:.2f})', va='top', color='white', fontsize=12)
        count += 1

In [None]:
# Quick check of a descending linspace, as used for the y axis of the grid.
np.linspace(2, -2, num=5)

In [None]:
# Decode the single latent point (-1, 0) and display the generated image.
latent_point = torch.tensor((-1, 0), dtype=torch.float32)
dd = model.decoder(latent_point)
generated = dd[0, 0].detach().numpy()
plt.imshow(generated)
plt.show()