# Variational Autoencoder

## 1. Settings

### 1) Import required libraries

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import matplotlib as mpl
from IPython.display import Image

### 2) Set hyperparameters

In [None]:
# Training hyperparameters shared by both VAE models in this notebook.
batch_size, learning_rate, num_epochs = 128, 1e-3, 10

## 2. Data

### 1) Download Data

In [None]:
# Load MNIST train/test splits as tensors via ToTensor, downloading into ./ if missing.
mnist_train = dset.MNIST("./", train=True, transform=transforms.ToTensor(), target_transform=None, download=True)
mnist_test = dset.MNIST("./", train=False, transform=transforms.ToTensor(), target_transform=None, download=True)
# Carve a 10k-image validation split out of the training set (sizes must sum to its length).
# NOTE(review): random_split uses torch's global RNG here, so the train/val split differs
# between runs unless a seed (or `generator=`) is set.
mnist_train, mnist_val = torch.utils.data.random_split(mnist_train, [50000, 10000])

In [None]:
# Shape of a single image tensor: (channels, height, width).
mnist_train[0][0].size()    # (1, 28, 28)

In [None]:
# The second element of each sample is its digit label.
mnist_train[0][1]           # label

### 2) Set DataLoader

In [None]:
# One loader per split; only the training loader reshuffles every epoch.
dataloaders = {
    'train': DataLoader(mnist_train, batch_size=batch_size, shuffle=True),
    'val': DataLoader(mnist_val, batch_size=batch_size, shuffle=False),
    'test': DataLoader(mnist_test, batch_size=batch_size, shuffle=False),
}

In [None]:
# Number of training batches per epoch (the last batch may be smaller than batch_size).
len(dataloaders["train"])

## 3. Model & Optimizer

In [None]:
# Download and display the VAE illustration (vae_gaussian.png); figure source:
# https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html
# NOTE: `!wget` is an IPython shell escape and requires wget on the PATH.
!wget -q https://www.dropbox.com/s/lmpjzzkqhk7d408/vae_gaussian.png
Image("vae_gaussian.png")

### 1) Model

In [None]:
# Low-capacity variational autoencoder:
# encoder: 784(28*28) -> 256
# sampling: 256 -> 10 (separate mean and log-variance heads)
# decoder: 10 -> 256 -> 784(28*28)

class VariationalAutoencoder(nn.Module):
    """Single-hidden-layer VAE for flattened images.

    Args:
        input_dim: flattened input size (default 28*28 for MNIST).
        hidden_dim: width of the encoder and decoder hidden layer (default 256).
        latent_dim: size of the latent code z (default 10).
    """

    def __init__(self, input_dim=28*28, hidden_dim=256, latent_dim=10):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),                          # activation function
        )

        # Two heads on the shared encoder output: mean and log-variance of q(z|x).
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.Tanh(),                          # activation function
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),                       # outputs in (0, 1), as BCE expects
        )

    def encode(self, x):
        """Map flattened inputs to the (mu, log_var) of the approximate posterior."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample z = mu + eps * std with eps ~ N(0, I) (reparameterization trick).

        Drawing the noise separately keeps z differentiable w.r.t. mu and log_var.
        """
        std = torch.exp(0.5 * log_var)     # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to flattened reconstructions."""
        recon = self.decoder(z)
        return recon

    def forward(self, x):                # x: (batch_size, 1, 28, 28)
        """Return (reconstruction, mu, log_var) for a batch of images."""
        batch_size = x.size(0)
        # reshape (not view) also accepts non-contiguous inputs.
        mu, log_var = self.encode(x.reshape(batch_size, -1))
        z = self.reparameterize(mu, log_var)
        out = self.decode(z)
        return out, mu, log_var

### 2) Loss func & Optimizer

In [None]:
# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
BCE = torch.nn.BCELoss(reduction='sum')

def loss_func(x, recon_x, mu, log_var):
    """Return the VAE loss (negative ELBO) summed over the batch.

    Args:
        x: original inputs, e.g. (batch, 1, 28, 28); values must lie in [0, 1] for BCE.
        recon_x: decoder output in (0, 1), shape (batch, D).
        mu: posterior means from the encoder.
        log_var: posterior log-variances from the encoder.

    Returns:
        Scalar tensor: summed reconstruction BCE plus summed KL divergence.
    """
    # Flatten x to exactly the decoder's output shape instead of hard-coding 784.
    bce_loss = BCE(recon_x, x.reshape(recon_x.shape))
    # Closed-form KL( N(mu, exp(log_var)) || N(0, I) ), summed over batch and latent dims.
    kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce_loss + kld_loss

In [None]:
# Instantiate the low-capacity VAE on the chosen device and pair it with Adam.
model = VariationalAutoencoder()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## 4. Train 

In [None]:
import time
import copy

def train_model(model, dataloaders, criterion, optimizer, num_epochs=10):
    """Train ``model`` and return it with the weights that gave the lowest val loss.

    Each epoch runs a 'train' phase (with gradient updates), then a 'val' phase
    (without gradients). An epoch's loss is the sum of the batch losses divided by
    the number of samples in that phase's dataset.

    Args:
        model: module whose forward returns ``(outputs, mu, log_var)``.
        dataloaders: dict with 'train' and 'val' DataLoaders yielding
            ``(inputs, labels)`` batches (labels are ignored).
        criterion: loss called as ``criterion(inputs, outputs, mu, log_var)``.
        optimizer: optimizer that updates ``model``'s parameters.
        num_epochs: number of epochs to train.

    Returns:
        ``(model, train_loss_history, val_loss_history)``. ``model`` is the same
        object that was passed in, with the best validation weights loaded. Each
        history holds one loss per epoch.
    """
    since = time.time()
    # Send batches to wherever the model lives instead of relying on a global `device`.
    device = next(model.parameters()).device

    train_loss_history = []
    val_loss_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_loss = float('inf')   # any finite validation loss improves on this

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)        # train mode for 'train', eval mode for 'val'

            running_loss = 0.0

            for inputs, _labels in dataloaders[phase]:
                inputs = inputs.to(device)
                optimizer.zero_grad()

                # Track autograd history only in the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs, mu, log_var = model(inputs)
                    loss = criterion(inputs, outputs, mu, log_var)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            if is_train:
                train_loss_history.append(epoch_loss)
            else:
                val_loss_history.append(epoch_loss)
                # Snapshot the weights whenever validation loss improves.
                if epoch_loss < best_val_loss:
                    best_val_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:.4f}'.format(best_val_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, train_loss_history, val_loss_history

In [None]:
# Train the low-capacity VAE. train_model returns the same `model` object with its
# best-validation weights loaded, plus per-epoch train/val loss histories.
best_model, train_loss_history, val_loss_history = train_model(model, dataloaders, loss_func, optimizer, num_epochs=num_epochs)

In [None]:
# Learning curves: per-sample loss at every epoch for the train and val splits.
for history, split_name in ((train_loss_history, 'train'), (val_loss_history, 'val')):
    plt.plot(history, label=split_name)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

## 5. Check with Test Images (Can the VAE reconstruct input images?)

In [None]:
# Average per-image test loss of the low-capacity model (no gradients needed).
# After the loop, `inputs`/`outputs` hold the last test batch, which the next cell
# plots; note the loop also overwrites the global `labels`.
with torch.no_grad():
    running_loss = 0.0
    for inputs, labels in dataloaders["test"]:
        inputs = inputs.to(device)

        outputs, mu, log_var = best_model(inputs)
        test_loss = loss_func(inputs, outputs, mu, log_var)
        
        running_loss += test_loss.item()

    # loss_func sums over the batch, so divide by the number of test images.
    test_loss = running_loss / len(dataloaders["test"].dataset)
    print(test_loss)        

In [None]:
# Show the first five images of the last test batch next to their reconstructions.
out_img = outputs.cpu().data.squeeze()
print(out_img.size())

for i in range(5):
    plt.subplot(1, 2, 1)
    plt.imshow(inputs[i].squeeze().cpu().numpy(), cmap='gray')
    plt.subplot(1, 2, 2)
    plt.imshow(out_img[i].reshape(28, 28).numpy(), cmap='gray')
    plt.show()

## 6. Visualizing MNIST

In [None]:
# Fix NumPy's global RNG for reproducibility of the cells below.
# NOTE(review): assumes TSNE (random_state unset) draws from the global NumPy RNG — verify.
np.random.seed(42)

from sklearn.manifold import TSNE

In [None]:
# Raw pixels of the underlying MNIST training set (the dataset wrapped by the
# random_split subset, so validation images are included), scaled to [0, 1] float32,
# plus the matching digit labels.
train_dataset_array = np.float32(mnist_train.dataset.data.numpy() / 255)
labels = mnist_train.dataset.targets.numpy()

In [None]:
# Keep the first sixth of each digit's examples to make t-SNE tractable,
# grouped by class in order 0..9.
subset_indices_per_class = []
for digit in range(10):
    digit_positions = np.flatnonzero(labels == digit)
    subset_indices_per_class.append(digit_positions[:len(digit_positions) // 6].tolist())
subset_indices = [idx for per_class in subset_indices_per_class for idx in per_class]

train_dataset_array = train_dataset_array[subset_indices]
labels = labels[subset_indices]

In [None]:
# Encode the selected MNIST subset with the trained low-capacity VAE.
# This is inference only, so disable autograd: otherwise a graph over the whole
# subset is built and kept alive by `mu`/`outputs`.
train_dataset_array = torch.tensor(train_dataset_array)
inputs = train_dataset_array.to(device)
with torch.no_grad():
    outputs, mu, log_var = best_model(inputs)

In [None]:
# Embed the latent means in 2-D with t-SNE, then min-max scale using a single
# global min/max over both axes.
encoded = mu.detach().cpu().numpy()
tsne = TSNE()
X_train_2D = tsne.fit_transform(encoded)
lo, hi = X_train_2D.min(), X_train_2D.max()
X_train_2D = (X_train_2D - lo) / (hi - lo)

In [None]:
# Quick scatter of the 2-D t-SNE embedding, colored by digit label.
plt.scatter(X_train_2D[:, 0], X_train_2D[:, 1], c=labels, s=10, cmap="tab10")
plt.axis("off")
plt.show()

Let's make this diagram a bit prettier:

In [None]:
# adapted from https://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html
# Import explicitly: `mpl.offsetbox` only exists once the submodule has been imported.
from matplotlib import offsetbox

plt.figure(figsize=(10, 8))
cmap = plt.cm.tab10
plt.scatter(X_train_2D[:, 0], X_train_2D[:, 1], c=labels, s=10, cmap=cmap)
# Convert the encoded images to numpy once instead of on every loop iteration.
thumbnails = torch.squeeze(inputs).cpu().numpy()
# (1, 1) acts as an initial occupied position. A thumbnail is placed only when its
# squared distance to every placed position exceeds 0.02.
image_positions = np.array([[1., 1.]])
for index, position in enumerate(X_train_2D):
    dist = np.sum((position - image_positions) ** 2, axis=1)
    if np.min(dist) > 0.02: # if far enough from other images
        image_positions = np.r_[image_positions, [position]]
        imagebox = offsetbox.AnnotationBbox(
            offsetbox.OffsetImage(thumbnails[index], cmap="binary"),
            position, bboxprops={"edgecolor": cmap(labels[index]), "lw": 2})
        plt.gca().add_artist(imagebox)
plt.axis("off")
plt.show()

## 7. Walk through latent space of MNIST

In [None]:
# One row of latent means (mu) per subset image; one column per latent dimension.
encoded.shape

In [None]:
# Mean latent code of each digit class, indexed by digit 0..9.
mean_encoded = [encoded[labels == digit].mean(axis=0) for digit in range(10)]

In [None]:
# Decode 10 points evenly spaced on the line between the mean latent codes of two
# digits (coef=0 -> selected_class[1], coef=1 -> selected_class[0]).
selected_class = [1, 7]
samples = []
with torch.no_grad():
    for coef in np.linspace(0, 1, 10):
        interpolated = coef * mean_encoded[selected_class[0]] + (1.-coef) * mean_encoded[selected_class[1]]
        samples.append(interpolated)
    samples = np.stack(samples)
    # Force float32 to match the decoder weights. Under NumPy >= 2 (NEP 50) the
    # np.float64 `coef` promotes the float32 means to float64, which nn.Linear rejects.
    z = torch.tensor(samples, dtype=torch.float32).to(device)
    generated = best_model.decoder(z)   # already on `device`

In [None]:
# Arrange the 10 decoded images in a single row and display them.
generated = generated.view(10, 1, 28, 28)
img = make_grid(generated, nrow=10)
npimg = img.cpu().numpy()
# Move channels last for imshow: (C, H, W) -> (H, W, C).
plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

In [None]:
# Decode 10 points evenly spaced on the line between the mean latent codes of two
# digits (coef=0 -> selected_class[1], coef=1 -> selected_class[0]).
selected_class = [1, 8]
samples = []
with torch.no_grad():
    for coef in np.linspace(0, 1, 10):
        interpolated = coef * mean_encoded[selected_class[0]] + (1.-coef) * mean_encoded[selected_class[1]]
        samples.append(interpolated)
    samples = np.stack(samples)
    # Force float32 to match the decoder weights. Under NumPy >= 2 (NEP 50) the
    # np.float64 `coef` promotes the float32 means to float64, which nn.Linear rejects.
    z = torch.tensor(samples, dtype=torch.float32).to(device)
    generated = best_model.decoder(z)   # already on `device`

In [None]:
# Arrange the 10 decoded images in a single row and display them.
generated = generated.view(10, 1, 28, 28)
img = make_grid(generated, nrow=10)
npimg = img.cpu().numpy()
# Move channels last for imshow: (C, H, W) -> (H, W, C).
plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

## 8. Comparison between a low-capacity model and a high-capacity model

In [None]:
# High-capacity variational autoencoder:
# encoder: 784(28*28) -> 512 -> 256
# sampling: 256 -> 10 (separate mean and log-variance heads)
# decoder: 10 -> 256 -> 512 -> 784(28*28)

class VariationalAutoencoderHigh(nn.Module):
    """Two-hidden-layer VAE for flattened images.

    Args:
        input_dim: flattened input size (default 28*28 for MNIST).
        hidden_dim1: width of the outer hidden layer (default 512).
        hidden_dim2: width of the inner hidden layer next to the latent (default 256).
        latent_dim: size of the latent code z (default 10).
    """

    def __init__(self, input_dim=28*28, hidden_dim1=512, hidden_dim2=256, latent_dim=10):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),                          # activation function
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU()                           # activation function
        )

        # Two heads on the shared encoder output: mean and log-variance of q(z|x).
        self.fc_mu = nn.Linear(hidden_dim2, latent_dim)
        self.fc_var = nn.Linear(hidden_dim2, latent_dim)

        # Mirror of the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim2),
            nn.ReLU(),                          # activation function
            nn.Linear(hidden_dim2, hidden_dim1),
            nn.ReLU(),                          # activation function
            nn.Linear(hidden_dim1, input_dim),
            nn.Sigmoid()                        # outputs in (0, 1), as BCE expects
        )

    def encode(self, x):
        """Map flattened inputs to the (mu, log_var) of the approximate posterior."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample z = mu + eps * std with eps ~ N(0, I) (reparameterization trick).

        Drawing the noise separately keeps z differentiable w.r.t. mu and log_var.
        """
        std = torch.exp(0.5 * log_var)     # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to flattened reconstructions."""
        recon = self.decoder(z)
        return recon

    def forward(self, x):                # x: (batch_size, 1, 28, 28)
        """Return (reconstruction, mu, log_var) for a batch of images."""
        batch_size = x.size(0)
        # reshape (not view) also accepts non-contiguous inputs.
        mu, log_var = self.encode(x.reshape(batch_size, -1))
        z = self.reparameterize(mu, log_var)
        out = self.decode(z)
        return out, mu, log_var

In [None]:
# Fresh high-capacity VAE with its own Adam optimizer (same learning rate as before).
model = VariationalAutoencoderHigh().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Train the high-capacity VAE. `best_model_high` is the same object as `model`, with
# its best-validation weights loaded.
best_model_high, train_loss_history_high, val_loss_history_high = train_model(model, dataloaders, loss_func, optimizer, num_epochs=num_epochs)

In [None]:
# Learning curves of the low- and high-capacity models on shared axes.
for history, curve_name in (
    (train_loss_history, 'low_train'),
    (val_loss_history, 'low_val'),
    (train_loss_history_high, 'high_train'),
    (val_loss_history_high, 'high_val'),
):
    plt.plot(history, label=curve_name)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

In [None]:
# Average per-image test loss of the high-capacity model (no gradients needed).
# After the loop, `inputs`/`outputs` hold the last test batch, which the next cell
# plots; note the loop also overwrites the global `labels`.
with torch.no_grad():
    running_loss = 0.0
    for inputs, labels in dataloaders["test"]:
        inputs = inputs.to(device)

        outputs, mu, log_var = best_model_high(inputs) # best_model_high 
        test_loss = loss_func(inputs, outputs, mu, log_var)
        
        running_loss += test_loss.item()

    # loss_func sums over the batch, so divide by the number of test images.
    test_loss = running_loss / len(dataloaders["test"].dataset)
    print(test_loss)        

In [None]:
# Side-by-side for the last test batch: input | low-capacity recon | high-capacity recon.
# NOTE(review): `out_img` comes from the low-capacity display cell earlier. The test
# loader is not shuffled, so its last batch should hold the same images as `inputs` here.
out_img_high = torch.squeeze(outputs.cpu().data) # out_img_high
print(out_img_high.size())

for i in range(5):
    plt.subplot(1,3,1)
    plt.imshow(torch.squeeze(inputs[i]).cpu().numpy(),cmap='gray')
    plt.subplot(1,3,2)
    plt.imshow(out_img[i].numpy().reshape(28, 28),cmap='gray')
    plt.subplot(1,3,3)
    plt.imshow(out_img_high[i].numpy().reshape(28, 28),cmap='gray')
    plt.show()

## 9. What if we sample z from a normal distribution?

In [None]:
# Decode 100 latent vectors drawn from the standard-normal prior N(0, I) with the
# high-capacity model, and show them in a 10x10 grid.
# Use the notebook's `device` instead of a hard-coded `.cuda()` so this also runs on CPU.
with torch.no_grad():
    z = torch.randn(100, 10).to(device)   # 10 = latent size of the models above
    outputs = best_model_high.decoder(z)
out_img_high = torch.squeeze(outputs.cpu())
print(out_img_high.shape)

plt.figure(figsize=(10, 10))
for i in range(100):
    plt.subplot(10, 10, i+1)
    plt.imshow(out_img_high[i].reshape(28, 28).numpy(), cmap='gray')
    plt.axis('off')

plt.show()

## 10. BCE loss and MSE loss

In [None]:
# Figure from "Tutorial on Variational Autoencoders" (Carl Doersch): BCE vs. MSE loss.
# NOTE: `!wget` is an IPython shell escape and requires wget on the PATH.
# NOTE(review): Image() assumes wget saved the file with "%20" decoded to spaces — verify.
!wget -q https://www.dropbox.com/s/5kkhyo7apxkay5z/BCE_loss%20and%20MSE_loss.PNG
Image("BCE_loss and MSE_loss.PNG")