We look to:

- Build a variational auto-encoder
- Use it to generate images

# Building a variational auto-encoder

In [4]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

In [5]:
class VAE(nn.Module):
    def __init__(self, in_channels, latent_dim, hidden_dims=None):
        super().__init__() # we need this so it inherits everything from the class nn
        self.latent_dim = latent_dim

        # Build encoder; it is standard for the encoder's number of channels to increase
        modules = []
        if hidden_dims is None: # define a certain number  configuration for the encoder
            hidden_dims = [32, 64, 128, 256, 512]
        
        for h_dim in hidden_dims: 
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size = 3, stride = 2, padding = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU()
                )
            )
            in_channels = h_dim 
        
        # the encoder consist of the list modules which is len(hidden_dims) times the sequence 
        # we defined above in nn.Sequential
        self.encoder = nn.Sequential(*modules)
        # now we will define the mu and sigma layers (remember: encoder maps our x into mu and sigma (=\psi)
        # which will be the input for our q_{\psi}(z|x) to sample z)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
        
        # Build decoder (takes as input z of dimensionality latent_dim)
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1]*4)
        hidden_dims.reverse() # same structure as hidden_dims but other way around
        
        for i in range(len(hidden_dims)-1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i], # this is a deconvolution layer
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        
        self.decoder = nn.Sequential(*modules)
        # final layer then computes tanh(something); this is the mu and we assume sigma=Identity
        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())
    
    def encode(self, input):
        """ 
        Encodes the input by passing it through the encoder network and returning z (the latent variable) in list form
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1) # only dimensions starting with start_dim are flattened
        # Split the result into mu and var; each of dimension latent_dim
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]
        
    def decode(self, z): # this is in essense a single sample from p_{\theta}(x|z)=Gaussian(output,I)
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2) # we need this to change z to the right dimensions for conv layers
        result = self.decoder(result)
        result = self.final_layer(result)
        return result
    
    def reparametrize(self, mu, var):
        """
        does the reparametrizatrion trick
        """
        std = torch.exp(0.5 * var)
        eps = torch.randn_like(std)
        return eps * std + mu
    
    def forward(self, input):
        mu, log_var = self.encode(input)
        z = self.reparametrize(mu, log_var) # sample a random z
        return [self.decode(z), input, mu, log_var]
    
    def loss_function(self, *args):
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        recons_loss =F.mse_loss(recons, input)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
        return recons_loss + 0.0001*kld_loss
        
    def sample(self, num_samples):
        """
        Samples from the latent space and returns the corresponding input (basically just passing through decoder)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        samples = self.decode(z)
        return samples


In [6]:
# Load data
import os
from urllib.request import urlretrieve
fullfilename = os.path.join('./data_faces/', 'celeba')
urlretrieve("https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip ", fullfilename)

('./data_faces/celeba', <http.client.HTTPMessage at 0x28163ba7880>)

In [7]:
import zipfile

with zipfile.ZipFile("./data_faces/celeba","r") as zip_ref:
    zip_ref.extractall("data_faces/")

root = 'data_faces/img_align_celeba'
img_list = os.listdir(root)
print(len(img_list))

202599


In [8]:
# Crop and scale the data
crop_size = 108
re_size = 64
offset_height = (218 - crop_size) // 2
offset_width = (178 - crop_size) // 2
crop = lambda x: x[:, offset_height:offset_height + crop_size, offset_width:offset_width + crop_size]

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Lambda(crop),
     transforms.ToPILImage(),
     transforms.Resize(size=(re_size, re_size)),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)])

batch_size = 64
celeba_data = torchvision.datasets.ImageFolder('./data_faces/', transform=transform)
data_loader = torch.utils.data.DataLoader(celeba_data,batch_size=batch_size,shuffle=True)

dataiter = iter(data_loader)
img, labels = dataiter.next()
plt.imshow(np.transpose(img[0].numpy(), (1, 2, 0)))
plt.plot()

AttributeError: '_SingleProcessDataLoaderIter' object has no attribute 'next'

In [9]:
# Train the model 
epochs = 1
in_channels = 3
latent_dim = 100
vae = VAE(in_channels=in_channels, latent_dim=latent_dim)

optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)

step = 0
losses = []
for epoch in range(epochs):
    for idx, (img_batch, _) in enumerate(data_loader):
        step += 1
        mbatch_size = img_batch.size()[0]
        # forward pass 
        [a,b,c,d] = vae.forward(img_batch)
        loss = vae.loss_function(a,b,c,d)
        # Zero out the gradients
        optimizer.zero_grad()
        # Compute the gradients
        loss.backward()
        # Take the optimisation step
        optimizer.step()

        # print statistics
        losses += [loss.item()]
        if step % 10 == 0:    # print every 100 iterations
            print(step, loss.item())

print('Finished Training')

10 0.2190493941307068
20 0.1481684148311615
30 0.11184293776750565
40 0.11844523996114731
50 0.09888102114200592
60 0.08754809945821762
70 0.08022349327802658
80 0.08917193114757538
90 0.07941785454750061
100 0.08565007895231247
110 0.07814276963472366
120 0.0798088014125824
130 0.07963082194328308
140 0.07748065143823624
150 0.07393917441368103
160 0.06918950378894806
170 0.07207281142473221
180 0.067186638712883
190 0.06815145164728165
200 0.06659197062253952
210 0.061425477266311646
220 0.06252864003181458
230 0.061073508113622665
240 0.06800016015768051
250 0.06448479741811752
260 0.06115265190601349
270 0.06915315985679626
280 0.061398863792419434
290 0.05872933566570282
300 0.06793393939733505
310 0.062022771686315536
320 0.060304343700408936
330 0.05263516679406166
340 0.056093860417604446
350 0.05368071049451828
360 0.06266869604587555
370 0.060806095600128174
380 0.054120007902383804
390 0.05770733952522278
400 0.06139296293258667
410 0.05731608346104622
420 0.0573205165565013

In [None]:
# Sample 
num_samples = 10
outputs = vae.sample(num_samples)
numpy_outputs = outputs.detach().numpy()
print(numpy_outputs.shape)
for it in range(num_samples):
    plt.imshow(np.transpose(numpy_outputs[it], (1, 2, 0)))
    plt.show()