# CVAE
A Conditional Variational Autoencoder (CVAE) is a type of neural network that combines the concepts of variational autoencoders (VAEs) and conditional models. It is used for generating new data samples that are conditioned on some input data. In a CVAE, both the encoder and decoder are conditioned on additional information, which allows the model to generate more specific and controlled outputs. This makes CVAEs particularly useful for tasks such as image generation, where the generated images need to adhere to certain conditions or labels.

## Imports

In [32]:
# !pip install torch torchvision

In [33]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

## cuda setup
for gpu

In [34]:
# device = torch.device("cuda")
# kwargs = {'num_workers': 1, 'pin_memory': True} 

for cpu

In [35]:
device = torch.device("cpu")
kwargs = {'num_workers': 1, 'pin_memory': False} 


# Model Definition

In [36]:
class CVAE(nn.Module):
    def __init__(self, feature_size, latent_size, class_size):
        super(CVAE, self).__init__()
        self.feature_size = feature_size
        self.class_size = class_size

        # encode
        self.fc1  = nn.Linear(feature_size + class_size, 400)
        self.fc21 = nn.Linear(400, latent_size)
        self.fc22 = nn.Linear(400, latent_size)

        # decode
        self.fc3 = nn.Linear(latent_size + class_size, 400)
        self.fc4 = nn.Linear(400, feature_size)

        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x, c): # Q(z|x, c)
        '''
        x: (bs, feature_size)
        c: (bs, class_size)
        '''
        inputs = torch.cat([x, c], 1) # (bs, feature_size+class_size)
        h1 = self.elu(self.fc1(inputs))
        z_mu = self.fc21(h1)
        z_var = self.fc22(h1)
        return z_mu, z_var
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z, c): # P(x|z, c)
        '''
        z: (bs, latent_size)
        c: (bs, class_size)
        '''
        inputs = torch.cat([z, c], 1) # (bs, latent_size+class_size)
        h3 = self.elu(self.fc3(inputs))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x, c):
        mu, logvar = self.encode(x.view(-1, 28*28), c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar


In [37]:
def one_hot(labels, class_size):
    targets = torch.zeros(labels.size(0), class_size)
    for i, label in enumerate(labels):
        targets[i, label] = 1
    return targets.to(device)

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        labels = one_hot(labels, 10)
        recon_batch, mu, logvar = model(data, labels)
        optimizer.zero_grad()
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.detach().cpu().numpy()
        optimizer.step()
        if batch_idx % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, labels) in enumerate(test_loader):
            data, labels = data.to(device), labels.to(device)
            labels = one_hot(labels, 10)
            recon_batch, mu, logvar = model(data, labels)
            test_loss += loss_function(recon_batch, data, mu, logvar).detach().cpu().numpy()
            if i == 0:
                n = min(data.size(0), 5)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'reconstruction_' + str(f"{epoch:02}") + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

## Hyper Parameters

In [38]:
batch_size = 64
latent_size = 20
epochs = 30

# Load data

In [39]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False, **kwargs)

# Main

In [40]:
model = CVAE(28*28, latent_size, 10).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            c = torch.eye(10, 10).to(device)
            sample = torch.randn(10, latent_size).to(device)
            sample = model.decode(sample, c).cpu()
            save_image(sample.view(10, 1, 28, 28),
                       'sample_' + str(f"{epoch:02}") + '.png')

====> Epoch: 1 Average loss: 144.0097
====> Test set loss: 119.3797
====> Epoch: 2 Average loss: 116.8415
====> Test set loss: 112.6678
====> Epoch: 3 Average loss: 111.7811
====> Test set loss: 109.3593
====> Epoch: 4 Average loss: 108.7166
====> Test set loss: 107.1574
====> Epoch: 5 Average loss: 106.7611
====> Test set loss: 105.7611
====> Epoch: 6 Average loss: 105.4390
====> Test set loss: 104.3345
====> Epoch: 7 Average loss: 104.3514
====> Test set loss: 103.2793
====> Epoch: 8 Average loss: 103.6150
====> Test set loss: 102.8086
====> Epoch: 9 Average loss: 103.0056
====> Test set loss: 102.2745
====> Epoch: 10 Average loss: 102.5373
====> Test set loss: 102.0384
====> Epoch: 11 Average loss: 102.0464
====> Test set loss: 101.5557
====> Epoch: 12 Average loss: 101.6635
====> Test set loss: 101.3816
====> Epoch: 13 Average loss: 101.3778
====> Test set loss: 101.1414
====> Epoch: 14 Average loss: 101.0660
====> Test set loss: 101.1256
====> Epoch: 15 Average loss: 100.8876
====

In [41]:
import numpy as np

with torch.no_grad():
    class_index = 0
    c = torch.zeros(10, 10).to(device)
    c[:, class_index] = 1
    sample = torch.randn(10, latent_size).to(device)
    sample = model.decode(sample, c).cpu()
    save_image(sample.view(10, 1, 28, 28), 'sample_class_10.png')

    # import matplotlib.pyplot as plt

    # # Print the sample tensor as an image
    # plt.figure(figsize=(10, 1))
    # for i in range(10):
    #     plt.subplot(1, 10, i + 1)
    #     plt.imshow(sample[i].view(28, 28).cpu().numpy(), cmap='gray')
    #     plt.axis('off')
    # plt.show()

In [42]:
torch.save(model.state_dict(), 'cvae_model.pth')

# Load Model

In [43]:
model = CVAE(28*28, latent_size, 10).to(device)
model.load_state_dict(torch.load('cvae_model.pth'))
model.eval()

  model.load_state_dict(torch.load('cvae_model.pth'))


CVAE(
  (fc1): Linear(in_features=794, out_features=400, bias=True)
  (fc21): Linear(in_features=400, out_features=20, bias=True)
  (fc22): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=30, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=784, bias=True)
  (elu): ELU(alpha=1.0)
  (sigmoid): Sigmoid()
)

In [44]:
import matplotlib.pyplot as plt

with torch.no_grad():
    class_index = 0
    c = torch.zeros(10, 10).to(device)
    c[:, class_index] = 1
    sample = torch.randn(10, latent_size).to(device)
    sample = model.decode(sample, c).cpu()
    # save_image(sample.view(10, 1, 28, 28), 'sample_class_10.png')

    # # Print the sample tensor as an image
    plt.figure(figsize=(10, 1))
    for i in range(10):
        plt.subplot(1, 10, i + 1)
        plt.imshow(sample[i].view(28, 28).cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.show()

ModuleNotFoundError: No module named 'matplotlib'