# Lab 2: Generative Adversarial Networks

## 2. InfoGAN

In this lab, we will look at a variant of GAN called InfoGAN.
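
For reference, this is the training objective from the InfoGAN paper (Chen et al., 2016), summarized in the paper's notation. The generator input is split into incompressible noise $z$ and a latent code $c$. The usual GAN value function $V(D,G)$ is regularized by a variational lower bound $L_I$ on the mutual information between $c$ and the generated image, using an auxiliary network $Q(c \mid x)$ and a weight $\lambda$:

$$
\min_{G,Q}\,\max_{D}\; V_{\text{InfoGAN}}(D,G,Q) = V(D,G) - \lambda\, L_I(G,Q)
$$

$$
L_I(G,Q) = \mathbb{E}_{c \sim P(c),\, x \sim G(z,c)}\big[\log Q(c \mid x)\big] + H(c) \;\le\; I\big(c;\, G(z,c)\big)
$$

In most of the paper's experiments, $Q$ shares its convolutional layers with the discriminator. Neither $Q$ nor $D$ is needed to generate images, which is why only the generator is loaded below.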

Let us load an InfoGAN model that has been pretrained on the MNIST dataset. We load only the generator, since that is all we need to produce images.


In [None]:
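import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from InfoGAN import InfoGAN_Generator  # class definition needed to unpickle the saved model

# map_location='cpu' lets a model saved on a GPU load on a CPU-only machine;
# the plotting code below also expects CPU tensors.
# On PyTorch >= 2.6, also pass weights_only=False: the file stores a full pickled model.
model = torch.load('../data/lab2/InfoGAN_Gen.pt', map_location='cpu')
model.eval()  # inference mode (matters if the generator uses batch norm)
print(model)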


Here we define helper functions that generate the three parts of the generator's input: uniform random noise, Gaussian continuous codes, and one-hot discrete codes.

In [None]:
import numpy as np
import torch


def gen_noise(n_instance, n_dim=2):
    """Generate n_dim-dimensional noise, uniform on [-3, 3)."""
    return torch.Tensor(np.random.uniform(low=-3.0, high=3.0,
                                          size=(n_instance, n_dim)))


def gen_conti_codes(n_instance, n_conti, mean=0, std=1):
    """Generate Gaussian continuous codes with the given mean and std."""
    codes = np.random.randn(n_instance, n_conti) * std + mean
    return torch.Tensor(codes)


def gen_discrete_code(n_instance, n_discrete, num_category=10):
    """Generate n_discrete one-hot codes, each over num_category categories.

    The result has shape (n_instance, n_discrete * num_category).
    """
    codes = []
    for _ in range(n_discrete):
        code = np.zeros((n_instance, num_category))
        # pick a random category per instance and set that entry to 1
        random_cate = np.random.randint(0, num_category, n_instance)
        code[range(n_instance), random_cate] = 1
        codes.append(code)

    codes = np.concatenate(codes, 1)
    return torch.Tensor(codes)
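
To make the structure of the discrete code concrete, here is a vectorized NumPy sketch of what `gen_discrete_code` produces. The name `gen_discrete_code_np` and its `rng` parameter are hypothetical, introduced only for illustration; it returns a NumPy array rather than a torch tensor, and it indexes an identity matrix instead of looping over blocks.

```python
import numpy as np


def gen_discrete_code_np(n_instance, n_discrete, num_category=10, rng=None):
    """Hypothetical vectorized variant of gen_discrete_code (NumPy output).

    Returns an (n_instance, n_discrete * num_category) float32 array made of
    n_discrete one-hot blocks, each with a uniformly random active category.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Random category index for every (instance, block) pair.
    cats = rng.integers(0, num_category, size=(n_instance, n_discrete))
    # Indexing the identity matrix turns indices into one-hot rows:
    # shape (n_instance, n_discrete, num_category).
    one_hot = np.eye(num_category, dtype=np.float32)[cats]
    # Place the blocks side by side, like np.concatenate(codes, 1).
    return one_hot.reshape(n_instance, n_discrete * num_category)


codes = gen_discrete_code_np(4, n_discrete=2, rng=np.random.default_rng(0))
print(codes.shape)                      # (4, 20)
print(codes.reshape(4, 2, 10).sum(-1))  # every block sums to 1
```

Each block of `num_category` columns contains exactly one 1, which is what lets the generator read it as a categorical choice.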

Let us generate a batch of random generator inputs and look at the resulting images.

In [None]:
from torch.autograd import Variable  # no-op wrapper since PyTorch 0.4; kept for older versions
import numpy as np

batch_size = 64
noise_dim = 10     # size of the entangled (incompressible) noise vector
n_conti = 2        # number of continuous latent codes
n_discrete = 1     # number of discrete latent codes (expected to capture the digit class)
num_category = 10  # number of categories of each discrete code

noises = Variable(gen_noise(batch_size, n_dim=noise_dim))
conti_codes = Variable(gen_conti_codes(batch_size, n_conti, 0.0, 0.5))
discr_codes = Variable(gen_discrete_code(batch_size, n_discrete, num_category))
print(noises.size(), conti_codes.size(), discr_codes.size())

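The generator's input is the concatenation of the entangled noise vector (10 values), the continuous codes (2 values) and the one-hot discrete code (10 values), i.e. 22 values per image. During training, InfoGAN maximizes the mutual information between the latent codes and the generated image, which encourages the codes (unlike the noise) to capture interpretable factors.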
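
The input width follows directly from the settings above. The NumPy sketch below concatenates stand-in arrays with the same shapes as `noises`, `conti_codes` and `discr_codes` (random values, not the lab's tensors). It assumes, as the generation cell implies, that the pretrained generator expects exactly this 22-dimensional input.

```python
import numpy as np

# Values from the lab's setup cell.
batch_size, noise_dim, n_conti, n_discrete, num_category = 64, 10, 2, 1, 10

# Stand-ins for the three parts (only the shapes matter here).
noise = np.random.uniform(-3.0, 3.0, size=(batch_size, noise_dim))
conti = np.random.randn(batch_size, n_conti) * 0.5
discr = np.eye(num_category)[np.random.randint(0, num_category, batch_size)]

# Same layout as torch.cat((noises, conti_codes, discr_codes), 1).
gen_input = np.concatenate((noise, conti, discr), axis=1)
expected_dim = noise_dim + n_conti + n_discrete * num_category
print(gen_input.shape, expected_dim)  # (64, 22) 22
```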

Let us concatenate the three parts and pass them through the generator to see the output.

In [None]:
import torchvision
import matplotlib.pyplot as plt
%matplotlib inline

gen_inputs = torch.cat((noises, conti_codes, discr_codes), 1)
with torch.no_grad():  # inference only, no gradients needed
    fake_images = model(gen_inputs)

# Tile the 64 images into an 8x8 grid; make_grid returns a (C, H, W) tensor.
grid = torchvision.utils.make_grid(fake_images.cpu())
plt.imshow(grid.permute(1, 2, 0).numpy())  # matplotlib expects (H, W, C)

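For the next few experiments, we keep the noise fixed: we take a single noise sample and repeat it for the whole batch, so that differences between the generated images come only from the latent codes. Run the following cell to do that.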

In [None]:
noises = noises[1, :]                  # keep a single noise sample
noises = noises.repeat(batch_size, 1)  # copy it to every row of the batch
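
A NumPy analogue of the cell above, using a stand-in random array rather than the lab's tensor, shows that after this step every row of the batch holds the same noise vector. `np.tile` plays the role of `Tensor.repeat` here.

```python
import numpy as np

batch_size, noise_dim = 64, 10
noises = np.random.uniform(-3.0, 3.0, size=(batch_size, noise_dim))

# Pick one sample and tile it so that every row of the batch is identical,
# mirroring noises[1, :].repeat(64, 1) in PyTorch.
fixed = np.tile(noises[1, :], (batch_size, 1))

print(fixed.shape)                # (64, 10)
print(np.all(fixed == fixed[0]))  # True: no noise variation is left
```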


### Meaning of the discrete codes

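Now let us see what the discrete code means. Instead of sampling it randomly, we set it systematically: the one-hot vectors for the first 8 categories are repeated 8 times, giving 64 codes. Only 8 of the 10 categories are used so that the batch fills an 8x8 grid.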

In [None]:
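discr_codes = Variable(torch.eye(10))   # one-hot rows for categories 0-9
discr_codes = discr_codes[0:8, :]       # keep categories 0-7
discr_codes = discr_codes.repeat(8, 1)  # tile to 64 rows: row i is category i % 8
print(discr_codes.size())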

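Now re-run the generation cell above (the one that calls `model`) to see the generated images. All images in a column share the same discrete code; since the noise is fixed, the remaining variation comes from the randomly sampled continuous codes.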
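
Where each category lands in the displayed grid follows from how the codes are tiled and from the layout of `torchvision.utils.make_grid`. The NumPy sketch below rebuilds the same 64 codes and maps them onto the grid, assuming make_grid's default of `nrow=8` images per row (so image `i` sits at row `i // 8`, column `i % 8`).

```python
import numpy as np

# Same construction as the cell above, in NumPy:
# one-hot rows for categories 0..7, tiled 8 times -> 64 rows.
discr = np.tile(np.eye(10)[0:8, :], (8, 1))

# Row-major layout with 8 images per grid row.
category = discr.argmax(axis=1).reshape(8, 8)
print(category[0])                      # [0 1 2 3 4 5 6 7]
print(np.all(category == category[0]))  # True: every grid row has the same layout
```

So each column of the grid corresponds to one category of the discrete code.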

### Meaning of the continuous codes

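Let us see what the continuous codes mean. This model uses two continuous latent codes. We sweep each one over 8 evenly spaced values in [-40, 40] and generate all 64 combinations. This range is far wider than the standard deviation of 0.5 used when sampling the codes earlier, which should make the effect of each code easy to see.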

In [None]:
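c = torch.linspace(-40, 40, 8)  # 8 evenly spaced values per code
c1 = np.repeat(c.numpy(), 8)    # first code: each value repeated 8 times (varies by grid row)
c2 = c.repeat(8)                # second code: the whole sweep tiled 8 times (varies by grid column)

c = np.transpose(np.stack((c1, c2.numpy())))  # shape (64, 2): all combinations
conti_codes = torch.from_numpy(c)
conti_codes = Variable(conti_codes)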
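
The repeat/tile combination above builds the Cartesian product of the two sweeps. The NumPy-only sketch below checks that it matches `np.meshgrid` with `indexing="ij"` and shows how the codes map onto the 8x8 grid, again assuming make_grid's default of `nrow=8`.

```python
import numpy as np

values = np.linspace(-40, 40, 8, dtype=np.float32)

# Lab construction: c1 repeats each value 8 times, c2 tiles the whole sweep.
c1 = np.repeat(values, 8)
c2 = np.tile(values, 8)
codes = np.stack((c1, c2), axis=1)  # shape (64, 2)

# Equivalent meshgrid formulation ('ij' indexing: first code varies by row).
g1, g2 = np.meshgrid(values, values, indexing="ij")
codes_mesh = np.stack((g1.ravel(), g2.ravel()), axis=1)
print(np.array_equal(codes, codes_mesh))  # True

grid = codes.reshape(8, 8, 2)
# Along a grid row the first code is constant and the second one sweeps.
print(grid[0, :, 0], grid[0, :, 1])
```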


The cell below sets the same discrete code (category 7) for all 64 images, so that only the continuous codes vary across the grid.

In [None]:
discr_codes = Variable(torch.zeros(64,10))
discr_codes[:,7] = 1 

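Now re-run the generation cell above to pass the latent codes through the generator and display the resulting images. Moving down the grid changes the first continuous code; moving across a row changes the second. In the InfoGAN paper, the two continuous codes learned on MNIST captured digit rotation and width; check which factors this model's codes control.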

#### References 
<ol>
<li>Paper: Chen, Xi, et al. "InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets." Advances in Neural Information Processing Systems. 2016.</li>
<li>Code used to train the model: https://github.com/AaronYALai/Generative_Adversarial_Networks_PyTorch/tree/master/InfoGAN </li>
</ol>