In [1]:
from typing import List
import torch
from torch import nn
from torch.nn import functional as F
from typing import List, Callable, Union, Any, TypeVar, Tuple
import torchvision
from torchvision.utils import save_image
import torchvision.transforms as transforms

In [2]:
# from torch import tensor as Tensor
Tensor = TypeVar('torch.tensor')
manual_seed = 1265

In [3]:
hidden_dim = 400
in_channels = 1
latent_dim = 20
batch_size = 128

# kld_weight = 0.00025
kld_weight = 0.0001

In [4]:
class VanillaVAE(nn.Module):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        super(VanillaVAE, self).__init__()

        self.latent_dim = latent_dim
        out_channels = in_channels

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1])
        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            print(f"hidden_dims[i]: {hidden_dims[i]}, hidden_dims[i+1]: {hidden_dims[i+1]}")
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=2,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= out_channels,
                                      kernel_size= 3, padding= 0),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """        
        result = self.encoder(input)
        # print(f"encoder: {result.shape}")  
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        # print(f"decoder shape of z: {z.shape}")
        result = self.decoder_input(z)

        result = result.unsqueeze(-1).unsqueeze(-1)

        # print(f"shape of decoder_input output: {result.shape}")
        # result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        # print(f"shape of decoder output: {result.shape}")
        result = self.final_layer(result)
        # print(f"shape of final_layer output: {result.shape}")
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  self.decode(z), mu, log_var

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples, self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_dataset = torchvision.datasets.MNIST(root='data',
                                          train=False,
                                          transform=transforms.ToTensor())

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=True)

In [6]:
model = VanillaVAE(in_channels=1, latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

hidden_dims[i]: 512, hidden_dims[i+1]: 256
hidden_dims[i]: 256, hidden_dims[i+1]: 128
hidden_dims[i]: 128, hidden_dims[i+1]: 64
hidden_dims[i]: 64, hidden_dims[i+1]: 32


In [7]:
def loss_function(recons,input, mu, log_var) -> dict:
    """
    Computes the VAE loss function.
    KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
    :param args:
    :param kwargs:
    :return:
    """

    recons_loss =F.mse_loss(recons, input)
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

    loss = recons_loss + kld_weight * kld_loss

    # print ({'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':-kld_loss.detach()})
    return loss

In [8]:
# Train function
def train(epoch):
    model.train()
    train_loss = 0
    for i, (images, _) in enumerate(train_loader):
        images = images.to(device)
        reconstructed, mu, logvar = model(images)
        # print(f"model output reconstructed.shape: {reconstructed.shape}")
        # print(f"input images shape output: {images.shape}")
        loss = loss_function(reconstructed, images, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if i % 100 == 0:
            print("Train Epoch {} [Batch {}/{}]\tLoss: {:.4f}".format(epoch, i, len(train_loader), loss.item()/len(images)))
            
    print('=====> Epoch {}, Average Loss: {:.4f}'.format(epoch, train_loss/len(train_loader.dataset)))


# Test function
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(test_loader):
            images = images.to(device)
            reconstructed, mu, logvar = model(images)
            test_loss += loss_function(reconstructed, images, mu, logvar)
            if batch_idx == 0:
                comparison = torch.cat([images[:5], reconstructed.view(batch_size, 1, 28, 28)[:5]])
                save_image(comparison.cpu(), 'results/reconstruction_' + str(epoch) + '.png', nrow = 5)

    print('=====> Average Test Loss: {:.4f}'.format(test_loss/len(test_loader.dataset)))

In [9]:
# Main function

epochs = 20

for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Get rid of the encoder and sample z from the gaussian ditribution and feed it to the decoder to generate samples
        sample = torch.randn(64,20).to(device)
        generated = model.decode(sample).cpu()
        save_image(generated.view(64,1,28,28), 'results/x_sample_' + str(epoch) + '.png')

Train Epoch 1 [Batch 0/469]	Loss: 0.0018
Train Epoch 1 [Batch 100/469]	Loss: 0.0003
Train Epoch 1 [Batch 200/469]	Loss: 0.0002
Train Epoch 1 [Batch 300/469]	Loss: 0.0001
Train Epoch 1 [Batch 400/469]	Loss: 0.0001
=====> Epoch 1, Average Loss: 0.0002
=====> Average Test Loss: 0.0001
Train Epoch 2 [Batch 0/469]	Loss: 0.0001
Train Epoch 2 [Batch 100/469]	Loss: 0.0001
Train Epoch 2 [Batch 200/469]	Loss: 0.0001
Train Epoch 2 [Batch 300/469]	Loss: 0.0001
Train Epoch 2 [Batch 400/469]	Loss: 0.0001
=====> Epoch 2, Average Loss: 0.0001
=====> Average Test Loss: 0.0001
