<a href="https://colab.research.google.com/github/savadikarc/vae/blob/master/conditional_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive (Colab only) so the figures saved under `base_path` persist after the session ends.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
# Output directory on the mounted Drive; perturb() and plot_latent() write their PNGs here.
base_path = '/content/drive/My Drive/DL_ML/VAE/Images_CVAE/'
# Sanity check that the Drive folder is reachable (IPython shell escape; Jupyter/Colab only).
!ls /content/drive/My\ Drive/DL_ML/VAE/

In [None]:
# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms
import torch.nn.functional as F

# NumPy, standard
import numpy as np
from scipy.stats import norm

# Visualization
import imageio
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from keras.datasets import mnist, cifar10

In [None]:
# Model / training hyper-parameters.
latent_size = 2  # 2-D latent space so perturb() and plot_latent() can draw it directly (15 was an earlier choice)
hidden_size = 1024  # width of the single hidden layer in EncoderFC and DecoderFC
base_filters = 8  # NOTE(review): not referenced by any visible cell (leftover from a conv variant?)
batch_size = 2048
EPOCHS = 1000
SEED = 0

# Seed every RNG the notebook draws from (weight init, per-epoch shuffling,
# reparameterisation noise, the fixed visualisation noise) so runs are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
class EncoderFC(nn.Module):
    """Fully connected conditional encoder q(z | x, c).

    Flattens the input, appends the one-hot class condition and maps the result
    through one hidden ReLU layer to the mean and log-variance of a diagonal
    Gaussian over the latent code.

    Args:
        latent_size: dimensionality of the latent code z.
        input_size: number of values in one flattened input (784 = 28 * 28 for MNIST).
        num_classes: length of the one-hot condition vector.

    The hidden width is taken from the notebook-level ``hidden_size`` at construction time.
    """

    def __init__(self, latent_size=10, input_size=784, num_classes=10):

        super(EncoderFC, self).__init__()
        self.latent_size = latent_size
        self.input_size = input_size
        self.num_classes = num_classes

        self.fc = nn.Sequential(
            nn.Linear(input_size + num_classes, hidden_size, bias=True),
            nn.ReLU(inplace=True)
        )

        # Two parallel heads on the shared features: posterior mean and log-variance.
        self.linear_mu = nn.Linear(hidden_size, latent_size)
        self.linear_log_var = nn.Linear(hidden_size, latent_size)

    def forward(self, x, condition_vector):
        """Return ``(mu, log_var)``, each of shape (batch, latent_size).

        x: input batch; reshaped to (-1, input_size), so e.g. (B, 1, 28, 28) and (B, 784) both work.
        condition_vector: (batch, num_classes) one-hot labels, concatenated after the flattened input.
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(-1, self.input_size)
        x = torch.cat([x, condition_vector], dim=1)

        features = self.fc(x)
        mu = self.linear_mu(features)
        log_var = self.linear_log_var(features)

        return mu, log_var

In [None]:
class DecoderFC(nn.Module):
    """Fully connected conditional decoder p(x | z, c).

    Expects its input to be the latent code already concatenated with the one-hot
    condition, i.e. shape (batch, latent_size + num_classes). It produces per-pixel
    probabilities in (0, 1) through a final sigmoid.

    Args:
        latent_size: dimensionality of the latent code z.
        num_classes: length of the one-hot condition appended to z.
        output_shape: shape of one decoded sample, without the batch axis.

    The hidden width is taken from the notebook-level ``hidden_size`` at construction time.
    """

    def __init__(self, latent_size=10, num_classes=10, output_shape=(1, 28, 28)):

        super(DecoderFC, self).__init__()

        self.latent_size = latent_size
        self.num_classes = num_classes
        self.output_shape = tuple(output_shape)

        self.fc = nn.Sequential(
            nn.Linear(latent_size + num_classes, hidden_size, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, int(np.prod(self.output_shape)))
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        """Map (batch, latent_size + num_classes) inputs to (batch, *output_shape) probabilities."""
        x = self.sigmoid(self.fc(z))
        return x.view(-1, *self.output_shape)
        

In [None]:
def weight_init(m):
    """Re-initialise one module's parameters; intended for ``model.apply(weight_init)``.

    * ``nn.Linear`` / ``nn.Conv2d``: Kaiming-normal weights, zero bias (when present).
    * ``nn.BatchNorm2d``: unit scale, zero shift (when affine parameters exist).
    Any other module type is left untouched.

    Initialisation happens in place via ``torch.nn.init``. Each parameter therefore
    keeps its identity, device, dtype and ``requires_grad`` flag. Assigning new
    tensors to ``.data`` would silently reset dtype/device.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # affine=False BatchNorm has no weight/bias parameters.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
# Build both halves of the CVAE, re-initialise their weights and move them to `device`.
# nn.Module.apply() and nn.Module.to() both return the module itself, so the calls chain.
encoder = EncoderFC(latent_size).apply(weight_init).to(device)
decoder = DecoderFC(latent_size).apply(weight_init).to(device)

In [None]:
def train_step(batch_X, batch_y, criterion, optimizer_e, optimizer_d):
    """Apply one gradient update to the conditional VAE using a single mini-batch.

    Relies on the notebook-level ``encoder``, ``decoder`` and ``device``.

    Args:
        batch_X: mini-batch of inputs (numpy array), converted to a float tensor.
        batch_y: integer class labels, used as indices into a 10-way one-hot condition.
        criterion: reconstruction loss comparing the decoder output with the input.
        optimizer_e, optimizer_d: optimisers over the encoder / decoder parameters.

    Returns:
        ``(reconstruction_loss, kl_divergence)`` as Python floats, *before* division by
        the batch size. The loss actually optimised is their sum divided by ``len(batch_X)``.
    """
    for opt in (optimizer_e, optimizer_d):
        opt.zero_grad()

    # One-hot condition: for each label, the matching row of the 10x10 identity.
    one_hot = np.eye(10)[batch_y]
    condition_vector = torch.from_numpy(one_hot).float().to(device)
    x = torch.FloatTensor(batch_X).to(device)

    # Posterior parameters, then a reparameterised sample z = mu + eps * sigma,
    # where sigma = sqrt(exp(log_var)) = exp(log_var / 2).
    mu, log_var = encoder(x, condition_vector)
    sigma = torch.exp(log_var / 2.)
    latent = mu + torch.randn_like(mu) * sigma
    x_reconstructed = decoder(torch.cat([latent, condition_vector], dim=1))

    reconstruction_loss = criterion(x_reconstructed, x)
    # Closed-form KL(q(z|x,c) || N(0, I)); Kingma & Welling, Auto-Encoding
    # Variational Bayes, Appendix B: https://arxiv.org/abs/1312.6114
    kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    loss = (reconstruction_loss + kl_divergence) / batch_X.shape[0]
    loss.backward()

    for opt in (optimizer_e, optimizer_d):
        opt.step()

    return reconstruction_loss.item(), kl_divergence.item()

In [None]:
def train_epoch(X, y, criterion, optimizer_e, optimizer_d, batch_size=128):
    """Run one optimisation pass over ``(X, y)`` in consecutive mini-batches.

    Args:
        X: inputs, first axis = samples (shuffle beforehand if desired).
        y: class labels aligned with ``X``.
        criterion, optimizer_e, optimizer_d: forwarded unchanged to ``train_step``.
        batch_size: samples per step; the final batch may be smaller.

    Returns:
        ``(reconstruction_loss, kl_divergence)``: the per-batch values returned by
        ``train_step``, accumulated over the epoch and then normalised. The
        reconstruction term is divided by ``X.size``; the KL term is divided by
        ``len(X) * latent_size``.
    """
    n_samples = X.shape[0]

    reconstruction_loss = 0.
    kl_divergence = 0.
    # Advance the window by batch_size each step so every sample is visited exactly once
    # (the batch start must move; re-using offset 0 would train on the first batch only).
    for start in range(0, n_samples, batch_size):
        batch_X, batch_y = X[start:start + batch_size, ...], y[start:start + batch_size]
        step_rec, step_kl = train_step(batch_X, batch_y, criterion, optimizer_e, optimizer_d)

        reconstruction_loss += step_rec
        kl_divergence += step_kl

    return reconstruction_loss / np.prod(X.shape), kl_divergence / (X.shape[0] * latent_size)

In [None]:
def val_step(batch_X, batch_y, criterion):
    """Score the conditional VAE on one mini-batch without tracking gradients.

    Uses the same conditioning, reparameterised sample and loss terms as
    ``train_step``, but performs no parameter update. Relies on the notebook-level
    ``encoder``, ``decoder`` and ``device``.

    Returns:
        ``(reconstruction_loss, kl_divergence)`` for the batch as Python floats.
    """
    with torch.no_grad():
        # One-hot condition: for each label, the matching row of the 10x10 identity.
        condition_vector = torch.from_numpy(np.eye(10)[batch_y]).float().to(device)
        x = torch.FloatTensor(batch_X).to(device)

        mu, log_var = encoder(x, condition_vector)
        # z = mu + eps * sigma with sigma = exp(log_var / 2); the sample stays random at eval time.
        sigma = torch.exp(log_var / 2.)
        sampled = mu + torch.randn_like(mu) * sigma
        x_reconstructed = decoder(torch.cat([sampled, condition_vector], dim=1))

        reconstruction_loss = criterion(x_reconstructed, x)
        # Closed-form KL to the N(0, I) prior (Kingma & Welling, Appendix B:
        # https://arxiv.org/abs/1312.6114).
        kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return reconstruction_loss.item(), kl_divergence.item()

In [None]:
def val_epoch(X, y, criterion, batch_size=128):
    """Evaluate the model over ``(X, y)`` in consecutive mini-batches.

    Args:
        X: inputs, first axis = samples.
        y: class labels aligned with ``X``.
        criterion: forwarded unchanged to ``val_step``.
        batch_size: samples per step; the final batch may be smaller.

    Returns:
        ``(reconstruction_loss, kl_divergence)``: per-batch values from ``val_step``,
        accumulated and then normalised. The reconstruction term is divided by
        ``X.size``; the KL term is divided by ``len(X) * latent_size``.
    """
    n_samples = X.shape[0]

    reconstruction_loss = 0.
    kl_divergence = 0.
    # Move the batch window forward each step so the whole split is scored once,
    # rather than scoring the first batch repeatedly.
    for start in range(0, n_samples, batch_size):
        batch_X, batch_y = X[start:start + batch_size, ...], y[start:start + batch_size]
        step_rec, step_kl = val_step(batch_X, batch_y, criterion)

        reconstruction_loss += step_rec
        kl_divergence += step_kl

    return reconstruction_loss / np.prod(X.shape), kl_divergence / (X.shape[0] * latent_size)

In [None]:
def visualize(noise, rows=10, cols=8):
    """Decode ``noise`` with the notebook-level ``decoder`` and show a rows x cols image grid.

    Args:
        noise: torch Tensor of decoder inputs (latent code concatenated with the one-hot
            condition) on the decoder's device. Row ``k`` (for ``k < rows * cols``) is
            drawn at grid row ``k // cols``, column ``k % cols``; extra rows are ignored.
        rows, cols: grid layout. The defaults match ``fixed_noise_vector``
            (10 labels x 8 noise draws).

    Raises:
        ValueError: if ``noise`` has fewer than ``rows * cols`` rows.
    """
    n_images = noise.size(0)
    n_cells = rows * cols
    if n_images < n_cells:
        raise ValueError('visualize() needs at least {} noise rows for a {}x{} grid, got {}'.format(
            n_cells, rows, cols, n_images))

    with torch.no_grad():
        images = decoder(noise).cpu().numpy()

    # DecoderFC outputs (N, 1, H, W) probabilities; the grid is one single-channel uint8 image.
    height, width = images.shape[-2], images.shape[-1]
    grid = np.zeros((rows * height, cols * width), dtype=np.uint8)

    for k in range(n_cells):
        r, c = divmod(k, cols)
        img = np.squeeze(images[k], axis=0)
        grid[r * height:(r + 1) * height, c * width:(c + 1) * width] = (img * 255.).astype(np.uint8)

    plt.imshow(grid, cmap='gray')
    plt.show()

In [None]:
# Load MNIST through Keras. Presumably raw uint8 images of shape (N, 28, 28) plus integer
# labels; the next cell adds a channel axis and divides by 255.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [None]:
# Add a channel axis, (N, H, W) -> (N, 1, H, W), and scale raw pixel values into [0, 1].
# Guarded on ndim so re-running this cell does not add a second axis or divide by 255 again.
if X_train.ndim == 3:
    X_train = np.expand_dims(X_train, axis=1) / 255.
if X_test.ndim == 3:
    X_test = np.expand_dims(X_test, axis=1) / 255.

In [None]:
# 8 latent draws, each reused for all 10 digit labels -> 80 decoder inputs.
# Sample k has label k // 8, so row r of the 10x8 grid drawn by visualize() shows digit r.
noise_vector = torch.FloatTensor(8, latent_size).normal_(0., 1.).repeat(10, 1)
condition = np.eye(10)[np.repeat(np.arange(10), 8)]
condition_vector = torch.from_numpy(condition).float()
# .to() returns the same tensor when it is already on `device`, so no device check is needed.
fixed_noise_vector = torch.cat([noise_vector, condition_vector], dim=1).to(device)

In [None]:
# One Adam optimiser per sub-network, both with learning rate 1e-4.
optimizer_e, optimizer_d = [optim.Adam(net.parameters(), lr=1e-4) for net in (encoder, decoder)]

In [None]:
# Summed (not averaged) binary cross-entropy; train_step divides the total loss by the batch size.
criterion = nn.BCELoss(reduction='sum')

In [None]:
for epoch in range(EPOCHS):

    # Fresh permutation of each split every epoch (train first, then test).
    train_indices = np.random.permutation(X_train.shape[0])
    test_indices = np.random.permutation(X_test.shape[0])
    _X_train, _y_train = X_train[train_indices], y_train[train_indices]
    _X_test, _y_test = X_test[test_indices], y_test[test_indices]

    train_reconstruction_loss, train_kl_divergence = train_epoch(
        _X_train, _y_train, criterion, optimizer_e, optimizer_d, batch_size=batch_size)
    print('Train: Epoch: {} | BCE: {:.5f} | KL Divergence: {:.5f}'.format(
        epoch, train_reconstruction_loss, train_kl_divergence))

    val_reconstruction_loss, val_kl_divergence = val_epoch(_X_test, _y_test, criterion, batch_size)
    print('Val: Epoch: {} | BCE: {:.5f} | KL Divergence: {:.5f}'.format(
        epoch, val_reconstruction_loss, val_kl_divergence))

    # Every 10th epoch, decode the fixed noise grid to monitor sample quality.
    if not epoch % 10:
        visualize(fixed_noise_vector)

In [None]:
# Switch both networks to eval mode for the analysis cells below. eval() returns the module
# itself, so `dec`/`enc` alias `decoder`/`encoder`; .to(device) is a no-op since both were moved earlier.
dec = decoder.eval().to(device)
enc = encoder.eval().to(device)

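# Jitter Experiments:
For a fixed digit label, sweep both latent dimensions over a grid of standard-normal quantiles (5th to 95th percentile) and decode every grid point, to see how each latent dimension changes the generated digit.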

In [None]:
def perturb(X, y, steps=30, digit=0, dec=None, save=False):
    """Decode a ``steps x steps`` sweep over the 2-D latent space for one digit label.

    Each latent coordinate runs over standard-normal quantiles from the 5th to the
    95th percentile, covering the bulk of the N(0, I) prior. Every grid point is
    concatenated with the one-hot condition for ``digit`` and decoded. In the
    displayed grid, z1 increases left-to-right and z2 increases bottom-to-top.

    Args:
        X, y: unused; kept so existing calls keep working.
        steps: number of quantiles per latent axis.
        digit: class index used for the one-hot condition.
        dec: decoder taking (N, 2 + 10) inputs. Assumed to return (N, 1, H, W)
            values in [0, 1], as DecoderFC does (latent_size must be 2).
        save: if True, also write the grid as a PNG under ``base_path``.

    Raises:
        ValueError: if ``dec`` is None.
    """
    if dec is None:
        raise ValueError('perturb() needs a decoder: pass dec=...')

    # Quantiles of N(0, 1) over the central 90% of the prior, shared by both latent axes.
    values = norm.ppf(np.linspace(0.05, 0.95, num=steps))
    n_points = values.shape[0]
    # Row k of `latent` is (values[k // n_points], values[k % n_points]).
    latent = np.stack([np.repeat(values, n_points), np.tile(values, n_points)], axis=1)

    condition = np.zeros((latent.shape[0], 10))
    condition[:, digit] = 1.

    # Decode the whole sweep in one batched, gradient-free forward pass.
    z = torch.from_numpy(np.concatenate([latent, condition], axis=1)).float().to(device)
    with torch.no_grad():
        images = dec(z).cpu().numpy()

    height, width = images.shape[-2], images.shape[-1]
    grid = np.zeros((height * n_points, width * n_points), dtype=np.uint8)

    for k, img in enumerate(images):
        # Column follows z1 (outer index); z2 (inner index) is counted from the bottom row.
        col, row_from_bottom = divmod(k, n_points)
        top = (n_points - 1 - row_from_bottom) * height
        left = col * width
        grid[top:top + height, left:left + width] = (img[0] * 255.).astype(np.uint8)

    plt.figure(figsize=(10, 10))
    plt.imshow(grid, cmap='gray')
    plt.axis('off')
    plt.show()

    if save:
        grid_img = Image.fromarray(grid)
        grid_img.save(base_path + 'conditional_vae_mnist_manifold_{}.png'.format(digit))
        grid_img.close()
        grid_img.close()

In [None]:
# Show and save one 10 x 10 latent-sweep grid per digit label (X_test / y_test are not used by perturb).
for digit in range(10):
    perturb(X_test, y_test, steps=10, digit=digit, dec=decoder, save=True)

In [None]:
def plot_latent(X, y, enc, samples_per_class=10):
    """Scatter sampled latent codes of random images, one colour per digit class.

    For each class 0-9, up to ``samples_per_class`` images are drawn at random
    (without replacement) from ``X[y == digit]``. They are encoded together with
    their one-hot label, and a reparameterised sample z ~ q(z | x, c) is plotted.
    Only the first two latent dimensions are shown. The figure is also saved under
    ``base_path``.

    Args:
        X: inputs accepted by ``enc`` after conversion to a float tensor.
        y: class labels aligned with ``X``.
        enc: conditional encoder returning ``(mu, log_var)``.
        samples_per_class: maximum number of points drawn per class.
    """
    plt.figure(figsize=(10, 10))

    with torch.no_grad():
        for digit in range(10):

            X_digit = X[y == digit]
            # Random subset without replacement; it is smaller when the class has fewer samples.
            chosen = np.random.permutation(np.arange(X_digit.shape[0]))[:samples_per_class]
            X_digit = X_digit[chosen]
            n_drawn = X_digit.shape[0]

            X_sample = torch.from_numpy(X_digit).float().to(device)

            # Size the condition by the number actually drawn, not by samples_per_class.
            condition = np.zeros((n_drawn, 10))
            condition[:, digit] = 1.
            condition_vector = torch.from_numpy(condition).float().to(device)

            mu, log_var = enc(X_sample, condition_vector)
            z = (mu + torch.randn_like(mu) * torch.exp(log_var / 2.)).cpu().numpy()

            plt.scatter(z[:, 0], z[:, 1], alpha=0.7, label=str(digit))

    plt.legend()
    plt.title('Sampled latent codes by digit class')
    plt.xlabel('z1')
    plt.ylabel('z2')
    plt.savefig(base_path + '/conditional_vae_latent_plot.png', dpi=300)
    plt.show()


In [None]:
# Latent scatter of up to 300 randomly chosen test images per digit class.
plot_latent(X_test, y_test, encoder, samples_per_class=300)