# variational.ipynb
A variational autoencoder implementing the loss function from Eq. (7) in
[Doersch 2016, Tutorial on Variational Autoencoders](https://arxiv.org/pdf/1606.05908.pdf)


-Sergio Verduzco  
June 2023



### Some resources I consulted:
https://youtu.be/uaaqyVS9-rM  
https://www.youtube.com/watch?v=YV9D3TWY5Zo  
https://www.youtube.com/watch?v=8wrLjnQ7EWQ  
https://www.youtube.com/watch?v=VELQT1-hILo  
https://github.com/karpathy/examples/blob/master/vae/main.py  
https://arxiv.org/pdf/1906.02691.pdf

In [None]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import lovely_tensors as lt
lt.monkey_patch()

In [None]:
# Preliminary parameters
import os

# Dataset cache directory. Override it with the VAE_DATA_DIR environment
# variable; the default is the original author's local path.
data_dir = os.environ.get('VAE_DATA_DIR', '/home/z/Downloads/data/')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"device = {device}")
batch_size = 64
dataset = 'MNIST'  # MNIST. CIFAR10 not implemented yet.

Tricky things with the cell above:
* Choosing the batch size.


Tricky things with the cell below:
* Which normalization to use?
  * It must be consistent with the output nonlinearity of the decoder.
  * It turns out `ToTensor` already scales the pixel values to [0, 1], and this is enough. Further normalization may hurt performance.

In [None]:
# Build the training set for the chosen dataset and wrap it in a shuffling loader.
if dataset not in ('CIFAR10', 'MNIST'):
    raise ValueError("Specify a valid dataset")

if dataset == 'MNIST':
    # ToTensor alone scales pixels to [0, 1]; further normalization may hurt.
    # torch.squeeze removes singleton dims (for MNIST, the grayscale channel).
    steps = [transforms.ToTensor(), torch.squeeze]
    dataset_class = torchvision.datasets.MNIST
else:  # 'CIFAR10' -- ignore! not implemented yet in the rest of the notebook
    steps = [transforms.ToTensor(),
             transforms.Normalize((0.5,) * 3, (0.5,) * 3)]
    dataset_class = torchvision.datasets.CIFAR10
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

transform = transforms.Compose(steps)
trainset = dataset_class(root=data_dir, train=True, download=True, transform=transform)
# The test split (train=False) and a test loader are not built here.

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)


In [None]:
# Scratch cell: show the documentation of torch.randn_like, which the training
# loop uses to draw the reparameterization noise.
help(torch.randn_like)

In [None]:
class FFP(nn.Module):
    """ A plain feedforward perceptron made of fully connected layers.

        Layer i maps sizes[i] units to sizes[i+1] units and is followed by
        the activation named in nltypes[i] ('linear' means no activation).
    """
    def __init__(self, sizes, nltypes, bias=True):
        """
            sizes: width of every layer, input layer first.
            nltypes: one activation name per non-input layer; 'relu', 'sig',
                'tanh', or 'linear' (no activation).
            bias: whether each nn.Linear layer gets a bias term.
        """
        assert len(sizes)-1 == len(nltypes), "length mismatch in nltypes, sizes"
        super().__init__()
        activation_classes = {"relu": nn.ReLU, "sig": nn.Sigmoid, "tanh": nn.Tanh}
        # "linear" is kept as a plain string marker and skipped in forward()
        self.nlfs = []
        for kind in nltypes:
            if kind == "linear":
                self.nlfs.append("linear")
            elif kind in activation_classes:
                self.nlfs.append(activation_classes[kind]())
            else:
                raise ValueError(f"unknown nonlinearity {kind}")
        self.bias = bias
        self.sizes = sizes
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out, bias=bias)
             for n_in, n_out in zip(sizes[:-1], sizes[1:])])

    def forward(self, x):
        for layer, act in zip(self.layers, self.nlfs):
            x = layer(x)
            if not isinstance(act, str):  # the "linear" marker is a str
                x = act(x)
        return x

class normal_encoder(nn.Module):
    """ A feedforward network whose output parametrizes a diagonal normal.

        The trunk is built exactly like FFP. Next to the final layer sits a
        second head of the same shape. It is fed by the same hidden
        activations and squashed through a sigmoid. The forward pass returns
        the concatenation [means, variances]. The regular final-layer outputs
        are the means, and the sigmoid head gives the (positive) diagonal
        entries of the covariance matrix.
    """
    def __init__(self, sizes, nltypes, bias=True):
        """
            sizes: width of every layer, input layer first.
            nltypes: one activation name per non-input layer; 'relu', 'sig',
                'tanh', or 'linear' (no activation).
            bias: whether each nn.Linear layer gets a bias term.
        """
        assert len(sizes)-1 == len(nltypes), "length mismatch in nltypes, sizes"
        super().__init__()
        self.n_layers = len(nltypes)
        activation_classes = {"relu": nn.ReLU, "sig": nn.Sigmoid, "tanh": nn.Tanh}
        # "linear" is kept as a plain string marker and skipped in forward()
        self.nlfs = []
        for kind in nltypes:
            if kind == "linear":
                self.nlfs.append("linear")
            elif kind in activation_classes:
                self.nlfs.append(activation_classes[kind]())
            else:
                raise ValueError(f"unknown nonlinearity {kind}")
        self.bias = bias
        self.sizes = sizes
        trunk = [nn.Linear(n_in, n_out, bias=bias)
                 for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        # variance head: same shape as the last trunk layer, stored last
        variance_head = nn.Linear(sizes[-2], sizes[-1], bias=bias)
        self.layers = nn.ModuleList(trunk + [variance_head])

    def forward(self, x):
        *body, mean_layer, variance_layer = self.layers
        for layer, act in zip(body, self.nlfs):
            x = layer(x)
            if not isinstance(act, str):
                x = act(x)
        # both heads read the same hidden activations
        variances = torch.sigmoid(variance_layer(x))
        means = mean_layer(x)
        last_act = self.nlfs[-1]
        if not isinstance(last_act, str):
            means = last_act(means)
        return torch.cat((means, variances), dim=-1)
        
class standard_SGD():
    """ An SGD optimizer for my FFP module.

        Works with any model that exposes `layers` (an iterable of nn.Linear
        layers) and a boolean `bias` attribute, e.g. FFP or normal_encoder.
    """
    def __init__(self, model, lr=0.1):
        """
            model: an instance of the FFP class (or any object with `layers`
                and `bias` attributes as described above)
            lr: learning rate
        """
        self.model = model
        self.lr = lr

    def _params(self):
        """ Yield each layer's weight, followed by its bias when the model uses biases. """
        for layer in self.model.layers:
            yield layer.weight
            if self.model.bias:
                yield layer.bias

    def step(self):
        """ Updates the model's parameters: p <- p - lr * dL/dp.

            Parameters that have no gradient yet (e.g. before the first
            backward pass) are left unchanged instead of raising.
        """
        with torch.no_grad():
            for param in self._params():
                if param.grad is not None:
                    param -= self.lr * param.grad

    def zero_grad(self):
        """ Resets existing gradients to zero; missing gradients are skipped. """
        for param in self._params():
            if param.grad is not None:
                param.grad.zero_()

def KL_Loss(mu: torch.Tensor, sigma: torch.Tensor, reduction='mean') -> torch.Tensor:
    """ Encoder loss from the KL divergence.

        The loss is KL( N(mu, diag(sigma)) || N(0, I) ). This is the divergence
        between a multivariate normal with the given means and a diagonal
        covariance matrix with entries `sigma`, and a zero-mean normal with
        identity covariance. Per sample it equals

            0.5 * (sum(sigma) + sum(mu**2) - sum(log(sigma)) - k)

        The constant -k is not relevant for the computation of gradients and
        is omitted.

        Args:
            mu: mean values, size (m,k), where m=minibatch size, k=dimension,
                or size (k,) for a single sample.
            sigma: variances, same size as mu.
                   All values must be positive.
            reduction: 'mean', 'sum', or None (per-sample losses, size (m,)).
                Ignored for single-sample (1-D) inputs, which give a scalar.
        Returns:
            loss: KL divergence (without the -k term), as a tensor.
        Raises:
            ValueError: if `reduction` is not 'mean', 'sum', or None.
    """
    # Use a sum of logs rather than the log of a product. sigma.prod() underflows
    # to 0 (log -> -inf) or overflows to inf once the latent dimension is large.
    log_det = sigma.log().sum(dim=-1)
    loss = 0.5 * (sigma.sum(dim=-1) + (mu * mu).sum(dim=-1) - log_det)

    if loss.dim() == 0:
        return loss
    if reduction is None:
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError("type of reduction was not understood.")

def distro_loss(mu: torch.Tensor, sigma: torch.Tensor, reduction='mean') -> torch.Tensor:
    """ A squared-error loss on the parameters of the multivariate normal.

        Penalizes how far each sample's parameters are from those of a
        standard normal (mu = 0, sigma = 1):

            ||mu||^2 + ||sigma - 1||^2   (summed over the latent dimension)

        Args:
            mu: mean values, size (m,k), where m=minibatch size, k=dimension
            sigma: variances, size (m,k)
                   All values must be positive.
            reduction: 'mean', 'sum', or None (per-sample losses, size (m,)).
        Returns:
            squared norm of mu plus squared norm of (sigma - 1), reduced as
            requested.
        Raises:
            ValueError: if `reduction` is not 'mean', 'sum', or None.
    """
    # Reduce both terms per sample so they share one scale. Summing mu over the
    # whole batch would make its penalty grow with the batch size while the
    # sigma term stays averaged.
    loss = (mu * mu).sum(dim=-1) + (sigma - 1).pow(2).sum(dim=-1)

    if loss.dim() == 0:
        return loss
    if reduction is None:
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError("type of reduction was not understood.")

# Reconstruction criterion: squared error averaged over all elements.
mse_loss = nn.MSELoss(reduction='mean')

Things that caused trouble with the cell above:
* Handling the split in the last encoder layer: half the outputs are mu, half are sigma.
* Ensuring that the sigma values of the encoder are always positive, and finding the right nonlinearity for the mu values.
  * The answer is to use a sigmoidal activation for sigma and a linear layer for mu.


Things that caused trouble with the cell below:
* Size of the network and type of nonlinearity for each layer.
  * The last layer of the encoder must be able to produce both positive and negative values for mu, but only positive values for sigma.
  * One video said the last layer of the decoder should use sigmoid units because of the MNIST image encoding. While I was normalizing the input images, nothing worked until I used a linear output layer. After I stopped normalizing, a sigmoidal output layer was probably better.
* Optimizer and learning rate.

In [None]:
# create the encoder and decoder
n_latent = 2  # number of latent variables
# Encoder: 784 -> 200 -> 100 -> n_latent means. normal_encoder adds a second
# n_latent-wide sigmoid head for the variances, so its output is 2*n_latent wide.
enc_sizes = [784, 200, 100, n_latent]
enc_types = ['tanh', 'relu', 'linear']  # linear output so the means can take either sign
encoder = normal_encoder(enc_sizes, enc_types, bias=True).to(device)

# Decoder mirrors the encoder: n_latent -> 100 -> 200 -> 784
dec_sizes = [n_latent, 100, 200, 784]
dec_types = ['tanh', 'relu', 'linear']
decoder = FFP(dec_sizes, dec_types, bias=True).to(device)

# The encoder's latent size must match the decoder's input size
assert enc_sizes[-1] == dec_sizes[0], "Check bottleneck sizes"

# Multivariate normal: zero mean, identity covariance (created on the CPU;
# random_image_generator moves its samples to `device`)
std_multi_normal = torch.distributions.MultivariateNormal(torch.zeros(n_latent), torch.eye(n_latent))

# Optimizers: separate Adam optimizers for encoder and decoder; the
# hand-written SGD and torch SGD alternatives are kept commented out.
# encoder_optim = standard_SGD(encoder, lr=0.001)
# decoder_optim = standard_SGD(decoder, lr=0.001)
encoder_optim = torch.optim.Adam(encoder.parameters(), lr=5e-4)
decoder_optim = torch.optim.Adam(decoder.parameters(), lr=5e-4)
# encoder_optim = torch.optim.SGD(encoder.parameters(), lr=0.1)
# decoder_optim = torch.optim.SGD(decoder.parameters(), lr=0.1)


In [None]:
# training loop
n_epochs = 10
bsize = trainloader.batch_size
w_bit = 1e-5  # step size for adapting the KL weight w
w = 0.0       # weight on the KL term; starts at 0 and is adapted after every batch
log_every = 200  # print running averages every `log_every` batches

for epoch in range(n_epochs):
    accum_distrib_error = 0.
    accum_output_error = 0.
    print("----------------------------------------")
    for i, data in enumerate(trainloader):
        # retrieve the input (named `images` to avoid shadowing the builtin `input`)
        images = data[0].flatten(start_dim=1).to(device)  # (bsize, 784)
        # feed the input to the encoder
        latent = encoder(images)  # (bsize, 2*n_latent)
        # extract means and diagonal of variance matrix
        means = latent[:, :enc_sizes[-1]]  # (bsize, n_latent)
        variances = latent[:, enc_sizes[-1]:]  # (bsize, n_latent)
        # Reparameterization trick: z = mu + Sigma^(1/2) * eps with eps ~ N(0, I)
        # (Doersch 2016, Eq. 10). `variances` holds the diagonal of Sigma, and
        # KL_Loss treats it as variances. The noise must therefore be scaled by
        # the standard deviation sqrt(variances), not by the variances.
        epsilons = torch.randn_like(variances)
        z = means + epsilons * variances.sqrt()
        # feed the sample to the decoder
        output = decoder(z)
        # Losses
        distrib_error = w * KL_Loss(means, variances)
        # distrib_error = w * distro_loss(means, variances)
        # Criteria take (prediction, target); the order matters for BCE.
        output_error = mse_loss(output, images)
        #output_error = F.binary_cross_entropy(output, images, reduction='mean')
        error = distrib_error + output_error

        # Gradient descent
        error.backward()

        encoder_optim.step()
        decoder_optim.step()

        encoder_optim.zero_grad()
        decoder_optim.zero_grad()

        # Work with plain Python floats from here on. Accumulating the loss
        # tensors themselves keeps extending the autograd graph across the epoch.
        distrib_value = distrib_error.item()
        output_value = output_error.item()

        # Adapt the KL weight.
        # NOTE(review): distrib_value already includes the factor w, so this
        # update effectively uses w**2 * KL -- confirm that is intended.
        w += w_bit * (output_value - w * distrib_value)

        # display error
        accum_distrib_error += distrib_value
        accum_output_error += output_value
        if (i + 1) % log_every == 0:
            n_batches = i + 1  # batches accumulated so far (i is 0-based)
            print(f"    distribution error = {accum_distrib_error/n_batches} up to example {n_batches*bsize}")
            print(f"    reconstruction error = {accum_output_error/n_batches} up to example {n_batches*bsize}")
    accum_distrib_error /= len(trainloader)
    accum_output_error /= len(trainloader)
    print(f"distribution error = {accum_distrib_error} in epoch {epoch}")
    print(f"reconstruction error = {accum_output_error} in epoch {epoch}")
    print(f"w = {w}")

In [None]:
class random_image_generator():
    """ Draws random latent samples and decodes them into 28x28 images. """
    def __init__(self, decoder, distribution):
        """
            decoder: module mapping one latent vector to 784 (= 28*28) values.
            distribution: torch distribution whose .sample() yields one latent vector.
        """
        self.decoder = decoder
        self.distribution = distribution

    def generate(self, show=True):
        """ Decode one random latent sample into a (28, 28) CPU tensor.

            show: if True, also display the image with plt.imshow.
            Returns: the decoded image, detached from the autograd graph.
        """
        # Use the stored decoder (not a global), on whatever device it lives on.
        first_param = next(iter(self.decoder.parameters()), None)
        target = first_param.device if first_param is not None else torch.device('cpu')
        latent = self.distribution.sample().to(target)
        with torch.no_grad():  # inference only
            output = self.decoder(latent).reshape((28, 28)).to('cpu').detach()
        if show:
            plt.imshow(output)
        return output

# Alternative: sample latents from an empirically chosen normal (variance 0.7)
# instead of the standard normal.
# empirical_multi_normal = torch.distributions.MultivariateNormal(0.0*torch.ones(n_latent), 0.7*torch.eye(n_latent))
# imgen = random_image_generator(decoder, empirical_multi_normal)

# Default: sample latents from the zero-mean, identity-covariance normal.
imgen = random_image_generator(decoder, std_multi_normal)

In [None]:
# Scratch cell: print two integer tensors via `.p` (added by lovely_tensors'
# monkey_patch) and try element-wise arithmetic between them.
a = torch.tensor(np.arange(5))
b = torch.tensor([6, 7, 8, 9, 0])
print(a.p)
print(b.p)

(a.pow(2) / (b + 1e-3)).sum()

In [None]:
# run repeatedly to get many images
# Each call draws a fresh latent sample, decodes and shows it, and keeps the
# (28, 28) image in `val`.
val = imgen.generate()

In [None]:
# create a grid of images (when n_latent=2)
# Decodes an n_grid x n_grid lattice of latent points; rows follow z1 and
# columns follow z2.
if n_latent != 2:
    raise ValueError(f"This grid needs n_latent == 2, but n_latent = {n_latent}")
min_z1 = -2.
max_z1 = 2.
min_z2 = -2.
max_z2 = 2.
n_grid = 10
fig, axs = plt.subplots(ncols=n_grid, nrows=n_grid, figsize=(n_grid, n_grid))
with torch.no_grad():  # inference only; no autograd graph needed
    for idx1, z1 in enumerate(np.linspace(min_z1, max_z1, n_grid)):
        for idx2, z2 in enumerate(np.linspace(min_z2, max_z2, n_grid)):
            latent_point = torch.tensor([z1, z2], dtype=torch.float32).to(device)
            image = decoder(latent_point).reshape((28, 28)).to('cpu')
            ax = axs[idx1, idx2]
            ax.set_xticks([])
            ax.set_yticks([])
            ax.imshow(image)

In [None]:
# create a grid of images (when n_latent=1)
# Decodes n_grid evenly spaced latent values into a single row of images.
if n_latent != 1:
    raise ValueError(f"This row needs n_latent == 1, but n_latent = {n_latent}")
min_z1 = -2.
max_z1 = 2.
n_grid = 15
# One row of roughly 1-inch panels; a square n_grid x n_grid figure would
# leave most of the canvas empty.
fig, axs = plt.subplots(ncols=n_grid, nrows=1, figsize=(n_grid, 1))
with torch.no_grad():  # inference only; no autograd graph needed
    for idx1, z1 in enumerate(np.linspace(min_z1, max_z1, n_grid)):
        latent_point = torch.tensor([z1], dtype=torch.float32).to(device)
        image = decoder(latent_point).reshape((28, 28)).to('cpu')
        axs[idx1].set_xticks([])
        axs[idx1].set_yticks([])
        axs[idx1].imshow(image)

In [None]:
# Test the decoder's statistics: print the last generated image so its values
# can be compared with the "input" printed by the next cell.
print(val)
#print(std_multi_normal.sample())

In [None]:
# Test the encoder's output statistics
sample_num = 5

for idx in range(sample_num):
    # one flattened training image as a batch of size 1
    x_in = trainset[idx][0].reshape(1, -1).to(device)
    enc_out = encoder(x_in)
    mu, sigma = enc_out[0, :n_latent], enc_out[0, n_latent:]
    for label, value in (("input = ", x_in), ("mu = ", mu), ("sigma = ", sigma)):
        print(label, end=" ")
        print(value)
    print("==========================")

In [None]:
# Persist the decoder's weights (its state_dict) to disk.
# NOTE(review): the file name mentions "distro_loss", but the training cell
# uses KL_Loss (the distro_loss line there is commented out). Confirm the name
# is still intended.
torch.save(decoder.state_dict(), 'decoder_distro_loss01.pth')