# Variational Auto-Encoder in PyTorch & Fastai

Playing with the plain VAE and DFC-VAE as described in the following papers: 
https://arxiv.org/abs/1312.6114
https://arxiv.org/abs/1610.00291

Experimenting with TransposeConv vs Subpixel Conv upscaling methods.

Using the aligned & cropped CelebA face dataset (~200k celebrity face images; attribute labels from `list_attr_celeba.csv`).

Requires GPU device to run.

In [None]:
import os
import PIL

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
import torchvision

from fastai.basic_data import *
from fastai.data_block import *
from fastai.vision import *

# Notebook setup: render matplotlib figures inline and auto-reload edited modules.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Display the installed PyTorch version.
torch.__version__

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

In [None]:
# Root directory of the face images, relative to the notebook's working directory.
# Used both for the fastai ImageList and for direct file lookups in get_mean_vec.
DATA_PATH = 'sandro/'

In [None]:
# Gather every image under DATA_PATH, hold out a random 1% as the validation set,
# and attach folder-based labels (the VAE ignores labels; batches are unpacked as `x, _`).
imlist = (
    ImageList.from_folder(DATA_PATH)
    .split_by_rand_pct(valid_pct=0.01)
    .label_from_folder()
)

In [None]:
BS = 128  # batch size
SZ = 64  # images will be resized to (SZxSZ) for training.

# Squish (aspect ratio not preserved) every image to SZ x SZ, batch it, then
# normalize with per-channel mean 0.5 / std 0.5 -- the same constants `denorm` inverts.
_squished = imlist.transform(size=SZ, resize_method=ResizeMethod.SQUISH)
db = _squished.databunch(bs=BS).normalize([(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)])

In [None]:
# defining denormalizer for normalized images.

class DeNormalize:
    """Invert a per-channel ``(x - mean) / std`` normalization.

    Accepts a single image ``(C, H, W)`` or a batch ``(N, C, H, W)``; the
    statistics are broadcast along the channel axis in both cases.
    """

    def __init__(self, mean, std):
        # one entry per image channel
        self.mean = mean
        self.std = std

    def __call__(self, x, inplace=False):
        """Map a normalized tensor back to its original scale.

        Args:
            x (Tensor): normalized image ``(C, H, W)`` or batch ``(N, C, H, W)``.
            inplace (bool): modify ``x`` directly instead of working on a clone.

        Returns:
            Tensor: ``x * std + mean`` applied per channel.
        """
        tensor = x if inplace else x.clone()
        # (C, 1, 1) statistics on the tensor's own device/dtype broadcast over H, W
        # (and N for batches); iterating over dim 0 would walk the batch axis
        # instead of the channels when given a batch.
        mean = torch.as_tensor(self.mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        tensor.mul_(std).add_(mean)
        return tensor

denorm = DeNormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

In [None]:
# testing that data loader and denorm work
# (batches may live on the GPU; matplotlib can only draw CPU tensors)

x, _ = next(iter(db.train_dl))
plt.imshow(denorm(x[0]).permute(1, 2, 0).cpu())

In [None]:
# Render a sample grid of training images with fastai's DataBunch.show_batch.
db.show_batch()

In [None]:
# defining conv and transposed conv layer blocks

def get_conv(nf: int, of: int, ks: int, stride: int = 1, use_bn: bool = True):
    """Build a Conv2d -> [BatchNorm2d] -> LeakyReLU(0.2) block.

    Args:
        nf: number of input channels.
        of: number of output channels.
        ks: kernel size; padding is ``ks // 2`` so an odd kernel keeps the
            spatial size at stride 1.
        stride: convolution stride.
        use_bn: insert a BatchNorm2d between the conv and the activation.

    Returns:
        nn.Sequential with the layers in that order (the conv has no bias).
    """
    layers = [nn.Conv2d(in_channels=nf, out_channels=of, kernel_size=ks,
                        stride=stride, padding=ks // 2, bias=False)]
    if use_bn:
        layers.append(nn.BatchNorm2d(of))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)

def get_deconv(nf: int, of: int, ks: int, stride: int = 1, opad: int = 0, use_bn: bool = True):
    """Build a ConvTranspose2d -> [BatchNorm2d] -> LeakyReLU(0.2) block.

    Args:
        nf: number of input channels.
        of: number of output channels.
        ks: kernel size; padding is ``ks // 2``.
        stride: transposed-convolution stride.
        opad: ``output_padding``; with an odd ``ks``, ``stride=2`` and
            ``opad=1`` the spatial size exactly doubles.
        use_bn: insert a BatchNorm2d between the deconv and the activation.

    Returns:
        nn.Sequential with the layers in that order (the deconv has no bias).
    """
    layers = [nn.ConvTranspose2d(in_channels=nf, out_channels=of, kernel_size=ks,
                                 stride=stride, padding=ks // 2,
                                 output_padding=opad, bias=False)]
    if use_bn:
        layers.append(nn.BatchNorm2d(of))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)

# Quick visual check of the layer stacks produced by get_conv / get_deconv.
get_conv(3, 10, 5, 2, use_bn=False), get_deconv(10, 3, 5, 2, 0)

In [None]:
# defining model and helper methods

# copied from FastAI library.
def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    """ICNR init of `x`, with `scale` and `init` function.

    Overwrites the weight tensor ``x`` in place (``x.data.copy_``). ``init`` is
    applied to a tensor ``scale**2`` times smaller along dim 0. Each kernel
    is then repeated so that every run of ``scale**2`` consecutive entries
    along dim 0 of ``x`` gets the same kernel. Presumably the intent is that,
    for a Conv2d weight feeding ``nn.PixelShuffle(scale)``, the sub-pixels of
    each output channel start out identical.

    Args:
        x: weight tensor of shape ``(ni, nf, h, w)``; ``ni`` is assumed to be a
            multiple of ``scale**2`` (otherwise the final ``view`` fails).
        scale: upsampling factor of the following pixel shuffle.
        init: in-place initializer applied to the reduced tensor.
    """
    # for a Conv2d weight, `ni` here is out_channels and `nf` is in_channels
    ni,nf,h,w = x.shape
    # number of distinct kernels along dim 0
    ni2 = int(ni/(scale**2))
    k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    # tile each kernel scale**2 times along the flattened last axis ...
    k = k.repeat(1, 1, scale**2)
    # ... then reinterpret so the copies land on consecutive dim-0 entries
    k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
    x.data.copy_(k)

class PixelShuffleICNR(nn.Module):
    """1x1 conv followed by PixelShuffle upsampling, optionally ICNR-initialized.

    Args:
        ni: number of input channels.
        of: number of output channels after the shuffle; defaults to ``ni``.
        scale: upsampling factor for height and width.
        use_icnr: initialize the conv weight with ``icnr`` (off by default, so
            the conv keeps PyTorch's default initialization).
    """

    def __init__(self, ni: int, of: int = None, scale: int = 2, use_icnr: bool = False):

        super().__init__()

        # `of=None` means "same number of channels as the input"
        if of is None:
            of = ni

        # produce of * scale**2 channels so PixelShuffle can fold them into an
        # `of`-channel map that is `scale` times larger in each spatial dim
        self.conv = nn.Conv2d(ni, of * scale * scale, kernel_size=1)
        if use_icnr:
            icnr(self.conv.weight, scale=scale)

        self.shuffle = nn.PixelShuffle(scale)

    def forward(self, x):
        return self.shuffle(self.conv(x))

class Unflatten(nn.Module):
    """Reshape every sample of a batch to ``sizes``, keeping the batch dimension."""

    def __init__(self, *sizes):
        super().__init__()
        # target per-sample shape, e.g. (channels, height, width)
        self.sizes = sizes

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, *self.sizes)
    
def weights_init(m):
    """Custom init for use with ``nn.Module.apply``.

    Modules whose class name contains 'Conv' get weights ~ N(0, 0.02).
    Otherwise, modules whose class name contains 'BatchNorm' get
    weights ~ N(1, 0.02) and zero bias. All other modules are left untouched.
    """
    name = type(m).__name__

    if 'Conv' in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm' in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def tranp_conv_decoder(nf: int, ks: int, z_dim: int):
    """Decoder that upsamples with strided transposed convolutions.

    A linear layer projects the latent ``(batch, z_dim)`` to an ``8*nf x 8 x 8``
    feature map. Three stride-2 deconvs then double the resolution
    (8 -> 16 -> 32 -> 64), and a stride-1 deconv maps to 3 channels.

    Args:
        nf: base channel width (the first feature map has ``8 * nf`` channels).
        ks: kernel size (odd, so ``opad=1`` exactly doubles the size).
        z_dim: latent dimensionality.
    """
    deconv1 = get_deconv(8*nf, 4*nf, ks, stride=2, opad=1)
    deconv2 = get_deconv(4*nf, 2*nf, ks, stride=2, opad=1)
    deconv3 = get_deconv(2*nf, nf, ks, stride=2, opad=1)
    # NOTE: get_deconv always appends LeakyReLU, so this runs before VAE.decode's tanh
    deconv4 = get_deconv(nf, 3, ks, use_bn=False)

    # derive the projection size from nf (8*64*8*8 == 32768 for the default nf=64)
    base_ch, base_hw = 8 * nf, 8

    return nn.Sequential(nn.Linear(z_dim, base_ch * base_hw * base_hw),
                         Unflatten(base_ch, base_hw, base_hw),
                         deconv1,
                         deconv2,
                         deconv3,
                         deconv4)

def subpixel_conv_decoder(nf: int, ks: int, z_dim: int):
    """Decoder that upsamples with sub-pixel convolution (PixelShuffle).

    A linear layer projects the latent to an ``8*nf x 8 x 8`` map. Three stages
    of conv -> PixelShuffleICNR(x2) -> LeakyReLU take it to 64 x 64, and a final
    conv maps to 3 channels.

    Args:
        nf: base channel width (the first feature map has ``8 * nf`` channels).
        ks: conv kernel size.
        z_dim: latent dimensionality.
    """
    # derive the projection size from nf (8*64*8*8 == 32768 for the default nf=64)
    base_ch, base_hw = 8 * nf, 8

    return nn.Sequential(nn.Linear(z_dim, base_ch * base_hw * base_hw),
                         Unflatten(base_ch, base_hw, base_hw),

                         get_conv(8*nf, 4*nf, ks, use_bn=True),
                         PixelShuffleICNR(4*nf, 4*nf, scale=2),
                         nn.LeakyReLU(0.2, inplace=True),

                         get_conv(4*nf, 2*nf, ks, use_bn=True),
                         PixelShuffleICNR(2*nf, 2*nf, scale=2),
                         nn.LeakyReLU(0.2, inplace=True),

                         get_conv(2*nf, nf, ks, use_bn=True),
                         # same local upsampler as the other stages; the similarly named
                         # `PixelShuffle_ICNR` is not defined in this notebook (if it
                         # resolves via a star import it builds a different layer stack)
                         PixelShuffleICNR(nf, nf, scale=2),
                         nn.LeakyReLU(0.2, inplace=True),

                         get_conv(nf, 3, ks, use_bn=False)
                         )
    
class VAE(nn.Module):
    """Convolutional VAE for 3 x 64 x 64 images.

    Encoder: four conv blocks (strides 1, 2, 2, 2) reduce the image to an
    ``8*nf x 8 x 8`` map, which is flattened and projected to the mean ``mu``
    and log-variance ``logvar`` of a ``z_dim``-dimensional diagonal Gaussian.
    Decoder: built by ``decoder(nf, ks, z_dim)``; its output is passed through tanh.

    Args:
        z_dim: latent dimensionality.
        ks: convolution kernel size for encoder and decoder.
        decoder: factory ``(nf, ks, z_dim) -> nn.Module`` for the decoder.
    """

    def __init__(self, z_dim: int, ks: int = 5, decoder=tranp_conv_decoder):

        super().__init__()

        nf = 64

        self.z_dim = z_dim

        conv1 = get_conv(3, nf, ks, use_bn=False)
        conv2 = get_conv(nf, 2*nf, ks, stride=2)
        conv3 = get_conv(2*nf, 4*nf, ks, stride=2)
        conv4 = get_conv(4*nf, 8*nf, ks, stride=2)

        self.encoder = nn.Sequential(conv1,
                                     conv2,
                                     conv3,
                                     conv4,
                                     Flatten())

        # flattened encoder output: 8*nf channels on an 8 x 8 grid (64 / 2**3),
        # i.e. 32768 for nf=64
        flat_dim = 8 * nf * 8 * 8
        self.mu = nn.Linear(flat_dim, z_dim)
        self.logvar = nn.Linear(flat_dim, z_dim)

        self.decoder = decoder(nf, ks, z_dim)

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, z_dim)``."""
        conv = self.encoder(x)
        return self.mu(conv), self.logvar(conv)

    def sample_z(self, mu, logvar):
        """Reparameterization: ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        # randn_like draws eps directly on mu's device with mu's dtype, avoiding a
        # CPU allocation plus host-to-device copy on every call
        eps = torch.randn_like(mu)
        return mu + eps * torch.exp(0.5 * logvar)

    def decode(self, z):
        """Decode latents to images squashed into [-1, 1] by tanh."""
        return torch.tanh(self.decoder(z))

    def forward(self, x):
        """Encode, draw one latent sample per input, and decode it."""
        mu, logvar = self.encode(x)
        z = self.sample_z(mu, logvar)
        return self.decode(z)

# Smoke test: a random batch must come back from encode -> sample -> decode with its input shape.
model = VAE(z_dim=100)
model.apply(weights_init)
inp = torch.empty(5, 3, 64, 64).normal_()
with torch.no_grad():  # shape check only; no autograd graph needed
    mu, logvar = model.encode(inp)
    out = model.decode(model.sample_z(mu, logvar))
assert inp.size() == out.size(), f'decoder output {tuple(out.size())} != input {tuple(inp.size())}'
out.size()

In [None]:
# defining perceptual loss using pretrained VGG16 architecture.

class Hook:
    """Record the most recent forward output of a module in ``self.feats``."""

    def __init__(self, m):
        # handle returned by PyTorch, needed to detach the hook later
        self.hook = m.register_forward_hook(self._hook_fn)
        # latest captured output; None until the module has run once
        self.feats = None

    def _hook_fn(self, m, inp, out):
        """Forward-hook callback: keep a reference to the module's output."""
        self.feats = out

    def remove(self):
        """Unregister the forward hook from the module."""
        self.hook.remove()

class PerceptualLoss:
    """Feature-space reconstruction loss on a frozen, pretrained VGG16-BN.

    Both images are run through the leading layers of ``vgg16_bn().features``.
    The loss is the summed squared difference of the activations captured at
    ``feat_indices``. For the default (2, 5, 9), these should be the ReLUs
    after the first three convs in torchvision's layout -- verify if changed.

    Args:
        feat_indices: indices into ``vgg16_bn().features`` whose outputs are compared.
    """

    def __init__(self, feat_indices=(2, 5, 9)):

        vgg16 = models.vgg16_bn(pretrained=True)

        feat_indices = list(feat_indices)

        # only layers up to the deepest hooked one affect the loss, so keep just
        # that prefix of the feature extractor (indices are unchanged)
        self.vgg16_head = nn.Sequential(*list(vgg16.features)[:max(feat_indices) + 1])

        self.vgg16_head.to(DEVICE)

        # eval(): the head contains BatchNorm layers; in train mode they would
        # normalize with per-batch statistics (different for x_recon and x) and
        # keep updating their running stats even with requires_grad=False
        self.vgg16_head.eval()

        for p in self.vgg16_head.parameters():
            p.requires_grad = False

        # NOTE(review): inputs are fed as-is (tanh range [-1, 1] in this notebook);
        # torchvision's pretrained VGG expects ImageNet-normalized inputs.
        self.hooks = [Hook(self.vgg16_head[i]) for i in feat_indices]

    def __call__(self, x_recon, x):
        """Return the sum over hooked layers of ``mse_loss(..., reduction='sum')``."""
        self.vgg16_head(x_recon)
        # clone so the reconstruction features survive the second forward pass
        x_recon_feats = [h.feats.clone() for h in self.hooks]
        # target features need no autograd graph
        with torch.no_grad():
            self.vgg16_head(x)
        return sum(F.mse_loss(rf, h.feats, reduction='sum')
                   for rf, h in zip(x_recon_feats, self.hooks))

    def close(self):
        """Remove all forward hooks."""
        for h in self.hooks:
            h.remove()
            
# Shared perceptual-loss instance; the training loop calls it as kacuri_loss(x_recon, x).
kacuri_loss = PerceptualLoss()
# Display the VGG16 layers the loss runs through.
kacuri_loss.vgg16_head

In [None]:
# VAE training loop.

L = 1 # number of trials latent vector Z is sampled.
N_EPOCHS = 10 # number of epochs
LOG_INTERVAL = 50 # interval of logging info during training.
ALPHA, BETA = 1, 0.75 # weights for KLD and Perceptual loss components.
LR_DECAY = 0.75 # learning rate is multiplied by this after every epoch.

# defining a model
model = VAE(z_dim=100, decoder=tranp_conv_decoder).to(DEVICE)

lr = 5e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

TB_LOG = False # whether to write info in tensorboard or not.
if TB_LOG:
    # expand '~' explicitly rather than relying on SummaryWriter to do it
    writer = SummaryWriter(os.path.expanduser('~/tb-logs-1'))

# fixed validation batch used to visualize reconstruction quality during training.
x_val, _ = next(iter(db.valid_dl))

model.train()
step = 1
for i_epoch in range(N_EPOCHS):

    train_loss = 0
    for i_batch, (x, _) in enumerate(db.train_dl):

        optimizer.zero_grad()

        mu, logvar = model.encode(x)

        if TB_LOG:
            writer.add_histogram('mu', mu, step)
            writer.add_histogram('var', logvar.exp(), step)

        # Monte Carlo estimate of the reconstruction loss, averaged over L samples of z.
        recon_loss = 0
        recon_loss_mse = 0  # pixel-wise MSE; logged only, not part of the objective
        for _ in range(L):
            z = model.sample_z(mu, logvar)
            x_recon = model.decode(z)
            recon_loss_mse += F.mse_loss(x_recon, x, reduction='sum')
            recon_loss += kacuri_loss(x_recon, x)
        recon_loss = recon_loss / L
        recon_loss_mse = recon_loss_mse / L

        # closed-form KL(q(z|x) || N(0, I)), summed over the batch and latent dims
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # final loss (append e.g. `+ 0.05 * recon_loss_mse` to mix in a pixel-wise term)
        loss = ALPHA * kld + BETA * recon_loss

        if TB_LOG:
            # each component under its own tag
            writer.add_scalar('kld_loss', kld.item(), step)
            writer.add_scalar('recon_loss', recon_loss.item(), step)
            writer.add_scalar('recon_loss_mse', recon_loss_mse.item(), step)
            writer.add_scalar('loss', loss.item(), step)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if i_batch % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                i_epoch, i_batch * x.size(0), len(db.train_ds),
                100. * i_batch / len(db.train_dl),
                loss.item() / x.size(0)))

            # eval(): BatchNorm uses its running statistics and the validation batch
            # does not update them; training mode is restored right after.
            model.eval()
            with torch.no_grad():
                x_val_recon = model(x_val[0:8])
            model.train()

            # visualizing reconstructions: originals on the top row, reconstructions below
            grd = denorm(make_grid(torch.cat([x_val[0:8], x_val_recon], dim=0), 8))
            if TB_LOG:
                writer.add_image('val_recon', grd, step)
            plt.figure(figsize=(20, 20))
            plt.imshow(grd.permute(1, 2, 0).cpu())  # matplotlib needs a CPU tensor
            plt.show()

        step += 1

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          i_epoch, train_loss / len(db.train_ds)))

    # learning rate decay
    lr *= LR_DECAY
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

if TB_LOG:
    writer.close()

In [None]:
# saving the model

# NOTE(review): the tags baked into this filename (ep(5), lr(2e-4), "mse") do not
# match the training cell (N_EPOCHS = 10, lr = 5e-4, KLD + perceptual loss); only
# BS comes from the live config. The loading cell reads this exact name.
torch.save(model.state_dict(), f'model_mse_ep(5)_lr(2e-4)_bs({BS})_L(1)_')

In [None]:
# loading model

model = VAE(z_dim=100, decoder=tranp_conv_decoder).to(DEVICE)
# map_location restores the tensors onto DEVICE, so a checkpoint saved on a GPU
# can also be loaded on a CPU-only machine (and vice versa)
state_dict = torch.load(f'model_mse_ep(5)_lr(2e-4)_bs({BS})_L(1)_', map_location=DEVICE)
model.load_state_dict(state_dict)

In [None]:
# testing reconstruction quality: original (left) | reconstruction (right),
# for the first validation batch only, at most 101 images

model.eval()
with torch.no_grad():
    for x, _ in db.valid_dl:
        mu, logvar = model.encode(x)
        x_recon = model.decode(model.sample_z(mu, logvar)).cpu()
        x = x.cpu()

        # one-pixel white column between the two images
        divider = torch.ones(x.size(-1), 1, 3)
        for i in range(min(x.size(0), 101)):
            original = denorm(x[i]).permute(1, 2, 0)
            recon = denorm(x_recon[i]).permute(1, 2, 0)
            plt.imshow(torch.cat([original, divider, recon], dim=1))
            plt.show()
        break

### Testing Visual Arithmetic

In [None]:
# reading the CelebA attribute annotations: one row per `image_id`, with binary
# attributes (e.g. Eyeglasses, Smiling) coded as 1 / -1
df = pd.read_csv('list_attr_celeba.csv')

In [None]:
# Peek at the first rows of the attribute table.
df.head()

In [None]:
# gets vector of special facial feature (e.g. eyeglasses, beard, etc...)
# encodes up to a thousand images with the given feature and averages their sampled z's,
# as described in the original paper.

def _load_face(path):
    """Load one image as a ``(1, 3, SZ, SZ)`` float tensor scaled to [-1, 1] like the training data."""
    with PIL.Image.open(path) as img:
        # squish to SZ x SZ (aspect ratio not preserved) to mirror the
        # ResizeMethod.SQUISH used for training; an aspect-preserving resize
        # yields non-square inputs the encoder's fixed-size Linear layers reject
        img = img.convert('RGB').resize((SZ, SZ), resample=PIL.Image.BILINEAR)
    x = torch.from_numpy(np.array(img)).float().permute(2, 0, 1).unsqueeze(0) / 255
    return (x - 0.5) / 0.5


def get_mean_vec(condition, n_images=1000):
    """Average latent code of images whose attributes satisfy ``condition``.

    Args:
        condition: boolean mask over the rows of ``df`` (e.g. ``df.Eyeglasses == 1``).
        n_images: maximum number of images to encode; files missing from
            ``DATA_PATH`` are skipped and do not count towards this limit.

    Returns:
        Tensor of shape ``(1, z_dim)`` on ``DEVICE``: the mean of one sampled
        ``z`` per encoded image.

    Raises:
        RuntimeError: if none of the selected images exists under ``DATA_PATH``.
    """
    res = []
    # inference only: without no_grad every encoded image keeps its autograd graph alive
    with torch.no_grad():
        for image_id in df[condition].image_id.values:
            try:
                x = _load_face(os.path.join(DATA_PATH, image_id))
            except FileNotFoundError:
                continue

            mu, logvar = model.encode(x.to(DEVICE))
            res.append(model.sample_z(mu, logvar))
            if len(res) >= n_images:
                break

    if not res:
        raise RuntimeError(f'no images found under {DATA_PATH!r} for the given condition')

    return torch.mean(torch.cat(res, dim=0), dim=0).unsqueeze(0)

In [None]:
# deriving the eyeglasses vectors: mean latent code of faces with and without eyeglasses
z_eye = get_mean_vec((df.Eyeglasses == 1))
z_no_eye = get_mean_vec((df.Eyeglasses == -1))

In [None]:
# visualizing the eyeglasses vectors: 1) decoded mean eyeglasses latent,
# 2) decoded mean no-eyeglasses latent, 3) decoded difference 1) - 2)

# no_grad: the decoder output would otherwise require grad, and matplotlib
# cannot convert such tensors to numpy
with torch.no_grad():
    for z in (z_eye, z_no_eye, z_eye - z_no_eye):
        plt.imshow(denorm(model.decode(z)[0].cpu()).permute(1, 2, 0))
        plt.show()

In [None]:
# putting eyeglasses onto celebs by adding the eyeglasses direction
# (mean eyeglasses latent minus mean no-eyeglasses latent) to each sample;
# shows the first validation batch only, at most 101 images.

model.eval()
with torch.no_grad():
    # The attribute direction is the difference of the two means; adding z_eye on
    # its own would also add the whole "average eyeglasses face" to every sample.
    eyeglass_vec = z_eye - z_no_eye
    for i_batch, (x, _) in enumerate(db.valid_dl):
        mu, logvar = model.encode(x)

        z = model.sample_z(mu, logvar) + eyeglass_vec

        x_recon = model.decode(z).cpu()

        x = x.cpu()
        for i in range(x.size(0)):
            img = torch.cat([denorm(x[i,:,:,:]).permute(1, 2, 0), torch.ones(x.size(-1), 1, 3), denorm(x_recon[i,:,:,:]).permute(1, 2, 0)], dim=1)
            plt.imshow(img)
            plt.show()
            if i >= 100:
                break
        break

### Visualizing 2D space with basis vectors being "Smile" and "Eyeglasses"

In [None]:
# NOTE(review): the variable names below do not match their filters; the
# conditions (and the per-line comments) are authoritative:
#   z_glass_smile       -> eyeglasses,    NOT smiling
#   z_glass_no_smile    -> NO eyeglasses, smiling
#   z_no_glass_no_smile -> no eyeglasses, not smiling
# Each is moved to the CPU after averaging.

# mean Z vector of no smile and with eyeglasses
z_glass_smile = get_mean_vec((df.Eyeglasses == 1) & (df.Smiling == -1)).cpu()

# mean Z vector of smile and with no eyeglasses
z_glass_no_smile = get_mean_vec((df.Eyeglasses == -1) & (df.Smiling == 1)).cpu()

# mean Z vector of no smile and with no eyeglasses
z_no_glass_no_smile = get_mean_vec((df.Eyeglasses == -1) & (df.Smiling == -1)).cpu()

In [None]:
# visualizing the above vectors (each mean latent decoded to an image)

# no_grad: plotted tensors must not require grad; DEVICE instead of a hard-coded
# .cuda() keeps the cell runnable wherever the model lives
with torch.no_grad():
    for z in (z_glass_smile, z_glass_no_smile, z_no_glass_no_smile):
        plt.imshow(denorm(model.decode(z.to(DEVICE))[0].cpu()).permute(1, 2, 0))
        plt.show()

In [None]:
# constructing 2D space

line = np.linspace(0, 1, 10)

# 2D basis vectors, from the filters used for the mean vectors:
# I points from the base face (no eyeglasses, not smiling) towards eyeglasses,
# J points from it towards smiling.
I, J = z_glass_smile - z_no_glass_no_smile, z_glass_no_smile - z_no_glass_no_smile

res = []
# no_grad: 100 decodes would otherwise each keep an autograd graph, and the
# collected images could not be plotted
with torch.no_grad():
    for a in line:          # outer loop -> grid rows (along I)
        for b in line:      # inner loop -> grid columns (along J)
            z = z_no_glass_no_smile + float(a) * I + float(b) * J
            x_recon = denorm(model.decode(z.to(DEVICE))).cpu()
            res.append(x_recon)

In [None]:
# displaying space: tile the decoded images 10 per row
tiles = torch.cat(res, dim=0)
grid_img = make_grid(tiles, nrow=10)
plt.figure(figsize=(20, 20))
plt.imshow(grid_img.permute(1, 2, 0))