In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

N_CLASSES = 10  # number of label classes in MNIST

In [4]:
class Encoder(nn.Module):
    '''Convolutional encoder of the VAE.

    Maps a batch of images to the parameters (mean and log-variance) of a
    diagonal Gaussian over the latent space.

    The layer sizes are hard-coded for 3-channel 32x32 inputs: two
    (5x5 conv -> ReLU -> 2x2 max-pool) stages shrink the spatial size
    32 -> 28 -> 14 -> 10 -> 5, which yields the 16 * 5 * 5 features that
    ``fc1`` expects. Other input sizes (e.g. 28x28) will fail at ``fc1``.

    Args:
        latent_dim: Size of the latent vector.
        hidden_dim: Width of the fully connected hidden layer.
        dr: Dropout probability used after both conv stages and after ``fc1``.
        n_out: Unused; kept so existing callers that pass it keep working.
    '''
    def __init__(self, latent_dim, hidden_dim, dr=0.2, n_out=10):
        # The original called super(ConvNetwork, self), but no ConvNetwork
        # class exists, so constructing the encoder raised NameError.
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=5)
        self.drop1 = nn.Dropout2d(dr)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5)
        self.drop2 = nn.Dropout2d(dr)

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

        self.fc1 = nn.Linear(16 * 5 * 5, hidden_dim)
        self.drop = nn.Dropout(dr)

        # Two bias-free heads on top of the shared hidden representation.
        # NOTE: despite its name, ``var`` produces the *log*-variance
        # (see ``forward``); the attribute name is kept for compatibility.
        self.mu = nn.Linear(hidden_dim, latent_dim, bias=False)
        self.var = nn.Linear(hidden_dim, latent_dim, bias=False)

    def features(self, x):
        '''Run the conv stack and flatten the result to [batch_size, 16 * 5 * 5].'''
        x = self.drop1(self.pool(self.relu(self.conv1(x))))
        x = self.drop2(self.pool(self.relu(self.conv2(x))))
        # Flatten everything except the batch dimension.
        return x.view(x.size(0), -1)

    def forward(self, x):
        '''Encode ``x`` into latent Gaussian parameters.

        Args:
            x: Image batch of shape [batch_size, 3, 32, 32].

        Returns:
            Tuple ``(mean, log_var)``, each of shape [batch_size, latent_dim].
        '''
        hidden = self.drop(self.relu(self.fc1(self.features(x))))

        # latent parameters, each of shape [batch_size, latent_dim]
        mean = self.mu(hidden)
        log_var = self.var(hidden)

        return mean, log_var
    

In [6]:
class Decoder(nn.Module):
    '''Convolutional decoder of the VAE.

    Maps a latent vector back to image space with a stack that mirrors a
    two-stage (5x5 conv -> 2x2 pool) encoder: a [batch_size, 16, 5, 5]
    feature map is upsampled and transposed-convolved twice,
    5 -> 10 -> 14 -> 28 -> 32, giving a [batch_size, 3, 32, 32] output.
    No output non-linearity is applied.

    Args:
        latent_dim: Size of the vector fed to the decoder.
        hidden_dim: Width of the fully connected hidden layer.
        dr: Dropout probability.
        n_out: Unused; kept so existing callers that pass it keep working.
    '''
    def __init__(self, latent_dim, hidden_dim, dr=0.2, n_out=10):
        # The original called super(ConvNetwork, self), but no ConvNetwork
        # class exists, so constructing the decoder raised NameError.
        super().__init__()
        self.conv1_t = nn.ConvTranspose2d(8, 3, kernel_size=5)
        self.drop1 = nn.Dropout2d(dr)
        self.conv2_t = nn.ConvTranspose2d(16, 8, kernel_size=5)
        self.drop2 = nn.Dropout2d(dr)

        # nn.MaxUnpool2d needs the pooling indices produced by the matching
        # MaxPool2d call, which are not available when decoding from a latent
        # sample. Nearest-neighbour upsampling doubles the spatial size the
        # same way and has no parameters either, so state dicts are unaffected.
        self.Unpool = nn.Upsample(scale_factor=2, mode='nearest')
        self.relu = nn.ReLU()

        self.fc1_inv = nn.Linear(hidden_dim, 16 * 5 * 5)
        self.drop = nn.Dropout(dr)

        self.latent2hidden = nn.Linear(latent_dim, hidden_dim, bias=False)

    def reverse_features(self, x):
        '''Turn flat features of shape [batch_size, 16 * 5 * 5] into an image batch.

        Returns:
            Tensor of shape [batch_size, 3, 32, 32].
        '''
        # Keep the batch dimension; the original reshape(16, 5, 5) dropped it
        # and therefore failed for any batch size other than one.
        x = x.reshape(-1, 16, 5, 5)
        x = self.conv2_t(self.relu(self.Unpool(self.drop2(x))))  # -> [B, 8, 14, 14]
        x = self.conv1_t(self.relu(self.Unpool(self.drop1(x))))  # -> [B, 3, 32, 32]
        return x

    def forward(self, x):
        '''Decode latent vectors.

        Args:
            x: Tensor of shape [batch_size, latent_dim].

        Returns:
            Tensor of shape [batch_size, 3, 32, 32].
        '''
        hidden = self.relu(self.drop(self.latent2hidden(x)))
        flat = self.fc1_inv(hidden)
        return self.reverse_features(flat)

In [7]:
class CVAE(nn.Module):
    '''Conditional VAE built from an ``Encoder`` and a ``Decoder``.

    Both the encoder input and the decoder input are conditioned on a
    one-hot label vector.
    '''
    def __init__(self, input_dim, hidden_dim, latent_dim, n_classes):
        '''
        Args:
            input_dim: An integer indicating the size of the input. Not used by
                the convolutional encoder/decoder; kept for interface
                compatibility.
            hidden_dim: An integer indicating the size of the hidden dimension.
            latent_dim: An integer indicating the latent size.
            n_classes: An integer indicating the number of classes (dimension
                of the one-hot representation of labels).
        '''
        super().__init__()

        self.encoder = Encoder(latent_dim, hidden_dim)
        # The decoder receives the latent sample concatenated with the one-hot
        # label, so its input width is latent_dim + n_classes (the original
        # passed latent_dim only, which mismatched the tensor built in forward).
        self.decoder = Decoder(latent_dim + n_classes, hidden_dim)

        # The encoder is convolutional, so a [batch, n_classes] label cannot be
        # concatenated to a [batch, C, H, W] image directly. Labels are
        # broadcast to constant planes, stacked onto the image channels, and a
        # 1x1 convolution fuses them back to the channel count the encoder's
        # first layer expects.
        image_channels = self.encoder.conv1.in_channels
        self.condition_in = nn.Conv2d(image_channels + n_classes, image_channels, kernel_size=1)

    def forward(self, x, C):
        '''Encode, sample with the reparameterisation trick, and decode.

        Args:
            x: Image batch of shape [batch_size, channels, height, width].
            C: One-hot labels of shape [batch_size, n_classes].

        Returns:
            Tuple ``(generated_x, z_mu, z_var)`` where ``z_var`` holds the
            log-variance of the approximate posterior.
        '''
        C = C.to(dtype=x.dtype)

        # condition the encoder input on the label
        label_planes = C[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
        conditioned_x = self.condition_in(torch.cat((x, label_planes), dim=1))

        # encode
        z_mu, z_var = self.encoder(conditioned_x)

        # sample from the distribution having latent parameters z_mu, z_var
        # reparameterize: z = mu + sigma * eps, with sigma = exp(log_var / 2)
        std = torch.exp(z_var / 2)
        eps = torch.randn_like(std)
        x_sample = z_mu + eps * std

        # condition the decoder input on the same label
        # (the original referenced an undefined name ``C2`` here)
        z = torch.cat((x_sample, C), dim=1)

        # decode
        generated_x = self.decoder(z)

        return generated_x, z_mu, z_var