# Generating Faces with CVAE in PyTorch [TRAIN]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from torchvision import datasets, transforms, models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.utils import save_image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import math
from PIL import Image
from IPython.display import display
import glob

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Preprocessing applied to every image, in order: resize, centre-crop to
# 64x64, then convert to a tensor.
transformObj = transforms.Compose(
    [
        transforms.Resize(size=64),
        transforms.CenterCrop(size=64),
        transforms.ToTensor(),
    ]
)

In [None]:
# Root folder handed to ImageFolder; images are read from its sub-directories.
dataroot = "../input/celeba-dataset/img_align_celeba/"

# Every image goes through transformObj when it is loaded.
dataset = datasets.ImageFolder(root=dataroot, transform=transformObj)

# Reshuffled mini-batches of 16 images.
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x64x64 images.

    Encoder: four stride-2 convolutions, each followed by batch norm and a
    leaky ReLU, take a 64x64 input down to a 256x4x4 feature map. Two linear
    heads map the flattened features to the mean and log-variance of the
    latent Gaussian.

    Decoder: a linear layer expands the latent vector back to 256x4x4, then
    four stride-2 transposed convolutions upsample to 3x64x64. The last one
    is followed by a sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self, latent_size=100):
        super().__init__()

        self.latent_size = latent_size

        # Every (transposed) convolution halves/doubles the spatial size.
        conv_kw = dict(kernel_size=4, stride=2, padding=1)

        # Encoder convolutions and their batch-norm layers.
        self.l1 = nn.Conv2d(3, 32, **conv_kw)
        self.l1b = nn.BatchNorm2d(32)
        self.l2 = nn.Conv2d(32, 64, **conv_kw)
        self.l2b = nn.BatchNorm2d(64)
        self.l3 = nn.Conv2d(64, 128, **conv_kw)
        self.l3b = nn.BatchNorm2d(128)
        self.l4 = nn.Conv2d(128, 256, **conv_kw)
        self.l4b = nn.BatchNorm2d(256)

        # Latent heads: l41 -> mean, l42 -> log-variance.
        flat_features = 256 * 4 * 4
        self.l41 = nn.Linear(flat_features, self.latent_size)
        self.l42 = nn.Linear(flat_features, self.latent_size)

        # Projection from latent space back to the decoder's first feature map.
        self.f = nn.Linear(self.latent_size, flat_features)

        # Decoder transposed convolutions.
        self.l5 = nn.ConvTranspose2d(256, 128, **conv_kw)
        self.l6 = nn.ConvTranspose2d(128, 64, **conv_kw)
        self.l7 = nn.ConvTranspose2d(64, 32, **conv_kw)
        self.l8 = nn.ConvTranspose2d(32, 3, **conv_kw)

    def encoder(self, x_in):
        """Return ``(mu, log_var)`` of the latent Gaussian for a batch of images."""
        stages = (
            (self.l1, self.l1b),
            (self.l2, self.l2b),
            (self.l3, self.l3b),
            (self.l4, self.l4b),
        )
        features = x_in
        for conv, norm in stages:
            features = F.leaky_relu(norm(conv(features)))

        # Flatten everything but the batch dimension for the linear heads.
        flat = features.view(features.size(0), -1)

        mu = self.l41(flat)
        log_var = self.l42(flat)
        return mu, log_var

    def decoder(self, z):
        """Map latent vector(s) ``z`` to image(s) with values in [0, 1]."""
        grid = self.f(z).view(-1, 256, 4, 4)

        for deconv in (self.l5, self.l6, self.l7):
            grid = F.leaky_relu(deconv(grid))

        return torch.sigmoid(self.l8(grid))

    def sampling(self, mu, log_var):
        """Reparameterisation trick: draw ``mu + sigma * eps`` with ``eps ~ N(0, I)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x_in):
        """Encode, sample a latent vector, decode; return ``(recon, mu, log_var)``."""
        mu, log_var = self.encoder(x_in)
        latent = self.sampling(mu, log_var)
        return self.decoder(latent), mu, log_var

In [None]:
# Build the model with its default 100-dimensional latent space and move its
# parameters to the selected device.
vae = VAE(latent_size=100)

vae.to(device)

In [None]:
# Adam over every parameter of the VAE, learning rate 5e-4.
optimizer = optim.Adam(params=vae.parameters(), lr=5e-4)

def loss_function(recon_x, x, mu, log_var):
    """Return the VAE loss: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: Reconstructed batch; values must lie in [0, 1].
        x: Target batch, same shape as ``recon_x``.
        mu: Mean of the approximate posterior.
        log_var: Log-variance of the approximate posterior.

    Returns:
        A scalar tensor. Both terms are summed (not averaged) over every
        element of the batch.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over all latent dims.
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    divergence = -0.5 * torch.sum(kl_terms)

    return reconstruction + divergence

In [None]:
def train(epoch):
    """Run one optimisation pass over ``dataloader`` and return the summed loss.

    Args:
        epoch: Epoch number; only used in the printed summary.

    Returns:
        The sum of the per-batch losses (a Python number) for this pass. The
        printed "mean loss" is that sum divided by the number of samples in
        the dataset.
    """
    vae.train()

    running_total = 0
    for step, (images, _) in enumerate(dataloader):
        images = images.to(device)

        optimizer.zero_grad()
        reconstruction, mu, log_var = vae(images)
        batch_loss = loss_function(reconstruction, images, mu, log_var)
        batch_loss.backward()
        optimizer.step()

        running_total += batch_loss.item()

        # Light progress indicator: one line every 2500 batches.
        if step % 2500 == 0:
            print("Batch no. finished in Epoch: ", step)

    separator = "-------------------------------------------------"
    mean_loss = running_total / len(dataloader.dataset)
    print(separator)
    print('Epoch: {} Train mean loss: {:.8f}'.format(epoch, mean_loss))
    print(separator)
    return running_total

In [None]:
# Number of full passes over the dataset.
n_epoches = 8

# loss_hist[i] is the summed loss that train() returned for epoch i + 1.
loss_hist = []

for epoch in range(1, n_epoches + 1):
    loss_epoch = train(epoch)
    loss_hist += [loss_epoch]

In [None]:
# Generate 100 faces by decoding latent vectors drawn from the prior.
#
# The KL term in loss_function pulls the approximate posterior towards a
# standard normal N(0, I), so that is the distribution the decoder has been
# trained to decode. Sampling from it needs torch.randn; the previous
# torch.rand(100) * 2 drew from uniform [0, 2) instead, which does not match
# the prior.
vae.eval()  # inference mode; train() switches back with vae.train()
with torch.no_grad():
    for counter in range(1, 101):
        # One latent vector of the model's own latent size, created on `device`.
        z = torch.randn(1, vae.latent_size, device=device)
        sample = vae.decoder(z)
        # Drop the batch dimension and write ./sample1.png ... ./sample100.png.
        save_image(sample.view(3, 64, 64), './sample' + str(counter) + '.png')

In [None]:
# Plot the per-epoch training loss. loss_hist holds the summed loss returned
# by train() for each epoch. train() numbers epochs from 1, so the x values
# are 1..len(loss_hist) rather than the 0-based list indices that
# plt.plot(loss_hist) would use.
plt.plot(range(1, len(loss_hist) + 1), loss_hist)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

In [None]:
# Show every PNG in the working directory inline.
# glob returns file *names*; passing a name straight to display() only prints
# the string, so each file is opened as a PIL image first.
for img_path in glob.glob("*.png"):
    with Image.open(img_path) as img:
        display(img)