In [1]:
""" (InfoGAN)

From the paper:
"In this paper, we present a simple modification to the generative adversarial 
network objective that encourages it to learn interpretable and meaningful 
representations. We do so by maximizing the mutual information between a fixed 
small subset of the GAN’s noise variables and the observations, which turns out 
to be relatively straightforward. Despite its simplicity, we found our method to be
surprisingly effective: it was able to discover highly semantic and meaningful 
hidden representations on a number of image datasets: digits (MNIST), faces (CelebA), 
and house numbers (SVHN)."

The Generator input is split into two parts: a traditional "noise" vector (z)
and a latent "code" vector (c) that targets the salient structured semantic features of
the data distribution. The code is made meaningful by maximizing a lower bound on the
mutual information between c and G(z, c).

https://arxiv.org/pdf/1606.03657.pdf

"""

import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import os
import matplotlib.pyplot as plt
import numpy as np

from itertools import product
from tqdm import tqdm_notebook
from load_data import get_data

def to_cuda(x):
    """ Return x moved to the GPU when CUDA is available, otherwise return x unchanged """
    return x.cuda() if torch.cuda.is_available() else x

# Enable Jupyter notebook plotting
# (IPython line magic: only valid inside a notebook / IPython session, not in plain Python)
%matplotlib inline

# Load in binarized MNIST data, separate into data loaders
# NOTE(review): get_data comes from the project-local load_data module; the dataset and the
# type of the three returned iterators are not visible from this notebook -- confirm there.
train_iter, val_iter, test_iter = get_data()


class Generator(nn.Module):
    """ Generator: a two-layer MLP that maps a latent vector to a flattened image.

    Inputs:
        image_size: int, number of values in one flattened output image
        hidden_dim: int, width of the single hidden layer
        z_dim: int, dimensionality of the latent input vector
    """
    def __init__(self, image_size, hidden_dim, z_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=z_dim, out_features=hidden_dim)
        self.generate = nn.Linear(in_features=hidden_dim, out_features=image_size)

    def forward(self, x):
        """ Map a batch [batch_size, z_dim] to sigmoid-squashed images [batch_size, image_size] """
        hidden = torch.relu(self.linear(x))
        return torch.sigmoid(self.generate(hidden))


class Discriminator(nn.Module):
    """ Discriminator: a two-layer MLP that scores a flattened image (real or generated).

    The output is a sigmoid score. InfoGANTrainer in this notebook trains it with label 1
    for real images and 0 for generated ones, so the score is read as P(real).

    Inputs:
        image_size: int, number of values in one flattened input image
        hidden_dim: int, width of the single hidden layer
        output_dim: int, number of sigmoid outputs per image
    """
    def __init__(self, image_size, hidden_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=image_size, out_features=hidden_dim)
        self.discriminate = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """ Map a batch [batch_size, image_size] to sigmoid scores [batch_size, output_dim] """
        hidden = torch.relu(self.linear(x))
        return torch.sigmoid(self.discriminate(hidden))


class InfoGAN(nn.Module):
    """ Container module that bundles the Generator (G) and the Discriminator (D).

    Inputs:
        image_size: int, size of a flattened image (G output width, D input width)
        hidden_dim: int, hidden layer width shared by G and D
        z_dim: int, dimensionality of the latent vector fed to G (also kept as self.z_dim)
        output_dim: int, number of outputs of D (default 1)
    """
    def __init__(self, image_size, hidden_dim, z_dim, output_dim=1):
        super().__init__()

        # Remembered so callers can sample latent vectors of the right width
        self.z_dim = z_dim

        # G is built before D so that parameter initialization order is unchanged
        self.G = Generator(image_size=image_size, hidden_dim=hidden_dim, z_dim=z_dim)
        self.D = Discriminator(image_size=image_size, hidden_dim=hidden_dim, output_dim=output_dim)

    
class InfoGANTrainer:
    """ Object to hold data iterators, train a GAN variant

    NOTE: despite the name, train() below optimizes the vanilla GAN objective only.
    discrete_code / continuous_code sample InfoGAN-style latent codes but are not
    used by the training loop in this class.
    """
    def __init__(self, model, train_iter, val_iter, test_iter, viz=False):
        self.model = to_cuda(model)
        self.name = model.__class__.__name__

        self.train_iter = train_iter
        self.val_iter = val_iter
        self.test_iter = test_iter

        self.Glosses = []
        self.Dlosses = []

        # Number of completed epochs; set here so viz_loss() works before train() is called
        self.num_epochs = 0

        # Live batch streams keyed by id() of their source iterable (see process_batch)
        self._batch_streams = {}

        self.viz = viz

    def train(self, num_epochs, G_lr=2e-4, D_lr=2e-4, D_steps=1):
        """ Train a vanilla GAN using the non-saturating gradients loss for the generator.
            Logs G loss and D loss per epoch and visualizes Generator output.

        Inputs:
            num_epochs: int, number of epochs to train for
            G_lr: float, learning rate for generator's Adam optimizer (default 2e-4)
            D_lr: float, learning rate for discriminator's Adam optimizer (default 2e-4)
            D_steps: int, training step ratio for how often to train D compared to G (default 1)
        Raises:
            ValueError: if D_steps is smaller than 1
        """
        if D_steps < 1:
            raise ValueError("D_steps must be >= 1, got %r" % (D_steps,))

        # Initialize optimizers
        G_optimizer = torch.optim.Adam(params=[p for p in self.model.G.parameters() if p.requires_grad], lr=G_lr)
        D_optimizer = torch.optim.Adam(params=[p for p in self.model.D.parameters() if p.requires_grad], lr=D_lr)

        # Approximate steps/epoch given D_steps per epoch --> roughly train in the same way as if D_step (1) == G_step (1)
        # (uses the trainer's own iterator, not a notebook-level global)
        epoch_steps = int(np.ceil(len(self.train_iter) / D_steps))

        # Begin training
        for epoch in tqdm_notebook(range(1, num_epochs+1)):
            self.model.train()
            G_losses, D_losses = [], []

            for _ in range(epoch_steps):

                D_step_loss = []

                for _ in range(D_steps):

                    # Fetch the next batch, flattened to [batch_size, -1]
                    images = self.process_batch(self.train_iter)

                    # TRAINING D: Zero out gradients for D
                    D_optimizer.zero_grad()

                    # Train the discriminator to learn to discriminate between real and generated images
                    D_loss = self.train_D(images)

                    # Backpropagate, update parameters
                    D_loss.backward()
                    D_optimizer.step()

                    # Log results
                    D_step_loss.append(D_loss.item())

                # We report D_loss in this way so that G_loss and D_loss have the same number of entries.
                D_losses.append(np.mean(D_step_loss))

                # TRAINING G: Zero out gradients for G
                G_optimizer.zero_grad()

                # Train the generator to generate images that fool the discriminator
                # (images from the last D step is only used for its batch size)
                G_loss = self.train_G(images)

                # Log results, backpropagate, update parameters
                G_losses.append(G_loss.item())
                G_loss.backward()
                G_optimizer.step()

            # Save progress
            self.Glosses.extend(G_losses)
            self.Dlosses.extend(D_losses)

            # Progress logging
            print ("Epoch[%d/%d], G Loss: %.4f, D Loss: %.4f"
                   %(epoch, num_epochs, np.mean(G_losses), np.mean(D_losses)))
            self.num_epochs = epoch

            # Visualize generator progress
            self.generate_images(epoch)

            if self.viz:
                plt.show()

    def train_D(self, images):
        """ Run 1 step of training for discriminator

        Input:
            images: batch of images (reshaped to [batch_size, 784])
        Output:
            D_loss: non-saturing loss for discriminator,
            -E[log(D(x))] - E[log(1 - D(G(z)))]
        """
        # Generate labels (ones indicate real images, zeros indicate generated)
        X_labels = to_cuda(torch.ones(images.shape[0], 1))
        G_labels = to_cuda(torch.zeros(images.shape[0], 1))

        # Classify the real batch images, get the loss for these
        DX_score = self.model.D(images)
        DX_loss = F.binary_cross_entropy(DX_score, X_labels)

        # Sample noise z, generate output G(z).
        # Detached: the D step must not backpropagate into (or accumulate gradients on) G.
        noise = self.compute_noise(images.shape[0], self.model.z_dim)
        G_output = self.model.G(noise).detach()

        # Classify the fake batch images, get the loss for these using sigmoid cross entropy
        DG_score = self.model.D(G_output)
        DG_loss = F.binary_cross_entropy(DG_score, G_labels)

        # Compute vanilla (original paper) D loss
        D_loss = DX_loss + DG_loss

        return D_loss

    def train_G(self, images):
        """ Run 1 step of training for generator

        Input:
            images: batch of images reshaped to [batch_size, -1] (only the batch size is used)
        Output:
            G_loss: non-saturating loss for how well G(z) fools D,
            -E[log(D(G(z)))]
        """
        # Target labels for the generated batch: all 1 ("real"), because the non-saturating
        # generator loss rewards G when D classifies its samples as real
        G_labels = to_cuda(torch.ones(images.shape[0], 1))

        # Get noise (denoted z), transform it with G, then classify the output of G using D.
        noise = self.compute_noise(images.shape[0], self.model.z_dim) # z
        G_output = self.model.G(noise) # G(z)
        DG_score = self.model.D(G_output) # D(G(z))

        # Compute the non-saturating loss for how D did versus the generations of G using sigmoid cross entropy
        G_loss = F.binary_cross_entropy(DG_score, G_labels)

        return G_loss

    def compute_noise(self, batch_size, z_dim):
        """ Compute random noise for the generator to learn to make images from """
        return to_cuda(torch.randn(batch_size, z_dim))

    def discrete_code(self, batch_size, z_dim):
        """ Randomly distributed categorical latent variables

        Output: float tensor [batch_size, z_dim] of one-hot rows, hot index drawn uniformly
        """
        c = np.zeros((batch_size, z_dim))
        categorical = np.random.randint(0, z_dim, batch_size)
        c[range(batch_size), categorical] = 1
        return to_cuda(torch.Tensor(c))

    def continuous_code(self, batch_size, z_dim):
        """ Gaussian latent code: [batch_size, z_dim] samples with mean 0.0 and std 0.5 """
        return to_cuda(torch.Tensor(np.random.randn(batch_size, z_dim) * 0.5 + 0.0))

    def process_batch(self, iterator):
        """ Fetch the next (images, labels) batch from `iterator` and flatten the images

        A live stream is kept per source iterable so successive calls walk through the
        data; when the stream is exhausted a new pass is started. (Calling
        next(iter(iterator)) on every step would restart the loader each time and, for
        an unshuffled loader, return the same first batch forever.)

        Output:
            images: tensor reshaped to [batch_size, -1], on GPU when available
        """
        entry = self._batch_streams.get(id(iterator))

        # Also compare identity, in case an id() was recycled after a source was freed
        if entry is None or entry[0] is not iterator:
            entry = (iterator, iter(iterator))
            self._batch_streams[id(iterator)] = entry

        try:
            images, _ = next(entry[1])
        except StopIteration:
            # End of the data: begin a fresh pass over the same source
            entry = (iterator, iter(iterator))
            self._batch_streams[id(iterator)] = entry
            images, _ = next(entry[1])

        images = to_cuda(images.view(images.shape[0], -1))
        return images

    def generate_images(self, epoch, num_outputs=36, save=True):
        """ Visualize progress of generator learning

        Inputs:
            epoch: int, used in the saved file name
            num_outputs: int, number of samples to draw; the plot shows a square grid of
                int(sqrt(num_outputs)) per side (default 36)
            save: bool, also write the samples to ../viz/<model name>/reconst_<epoch>.png
        """
        # Turn off any regularization
        self.model.eval()

        # Sampling only, so no autograd graph is needed
        with torch.no_grad():
            # Sample noise vector
            noise = self.compute_noise(num_outputs, self.model.z_dim)

            # Transform noise to image
            images = self.model.G(noise)

        # Reshape to proper image size (28x28 is hard-coded) and move to host memory,
        # since matplotlib/numpy cannot read CUDA tensors
        images = images.view(images.shape[0], 28, 28).cpu()

        # Plot
        plt.close()
        size_figure_grid = int(num_outputs**0.5)
        # squeeze=False keeps `ax` two-dimensional even for a 1x1 grid
        fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5), squeeze=False)
        for i, j in product(range(size_figure_grid), range(size_figure_grid)):
            # Clear first so the visibility settings below are not reset afterwards
            ax[i,j].cla()
            ax[i,j].get_xaxis().set_visible(False)
            ax[i,j].get_yaxis().set_visible(False)
            # Row-major index: every grid cell shows a distinct sample
            ax[i,j].imshow(images[i * size_figure_grid + j].numpy(), cmap='gray')

        # Save images if desired
        if save:
            outname = '../viz/' + self.name + '/'
            os.makedirs(outname, exist_ok=True)
            # Same images-per-row as the plotted grid
            torchvision.utils.save_image(images.unsqueeze(1),
                                         outname + 'reconst_%d.png'
                                         %(epoch), nrow = size_figure_grid)

    def viz_loss(self):
        """ Visualize loss for the generator, discriminator """
        # Set style, figure size
        plt.style.use('ggplot')
        plt.rcParams["figure.figsize"] = (8,6)

        # Plot Discriminator loss in red, Generator loss in green
        plt.plot(np.linspace(1, self.num_epochs, len(self.Dlosses)), self.Dlosses, 'r')
        plt.plot(np.linspace(1, self.num_epochs, len(self.Glosses)), self.Glosses, 'g')

        # Add legend, title
        plt.legend(['Discriminator', 'Generator'])
        plt.title(self.name)
        plt.show()

    def save_model(self, savepath):
        """ Save model state dictionary """
        torch.save(self.model.state_dict(), savepath)

    def load_model(self, loadpath):
        """ Load state dictionary into model """
        # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine;
        # load_state_dict then copies the values onto whatever device the model lives on.
        # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
        state = torch.load(loadpath, map_location='cpu')
        self.model.load_state_dict(state)

# model = InfoGAN(image_size=784, 
#               hidden_dim=256, 
#               z_dim=128)

# trainer = InfoGANTrainer(model=model, 
#                        train_iter=train_iter, 
#                        val_iter=val_iter, 
#                        test_iter=test_iter,
#                        viz=True)

# trainer.train(num_epochs=25, 
#               G_lr=2e-4, 
#               D_lr=2e-4, 
#               D_steps=1)

In [3]:
# Scratch: 10 one-hot rows over 10 equally likely categories
# (one multinomial trial per row, pvals = [0.1] * 10), converted to a float tensor
torch.from_numpy(np.random.multinomial(1, 10 * [float(1.0 / 10)],
                                          size=[10])).type(torch.FloatTensor)

tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])

In [None]:
# NOTE(review): this cell has no execution count (In [None]) and references `self`, which is
# not defined at notebook top level -- running it as-is raises NameError.
# Each torch.FloatTensor(*sizes) call allocates an uninitialized buffer, and .cuda() requires a GPU.
# Presumably pre-allocated buffers for images, labels, a 10-way discrete code, a 2-dim
# continuous code and 62-dim noise -- TODO confirm intent or delete the cell.
real_x = torch.FloatTensor(self.batch_size, 1, 28, 28).cuda()
label = torch.FloatTensor(self.batch_size).cuda()
dis_c = torch.FloatTensor(self.batch_size, 10).cuda()
con_c = torch.FloatTensor(self.batch_size, 2).cuda()
noise = torch.FloatTensor(self.batch_size, 62).cuda()

In [8]:
# Scratch: 5x10 Gaussian sample scaled to std 0.5 (mean 0.0), same recipe as InfoGANTrainer.continuous_code
torch.Tensor(np.random.randn(5, 10) * 0.5 + 0.0)

tensor([[ 0.8167,  0.3772,  0.1451,  0.5858,  0.0409,  0.2101,  0.4110, -0.1784,
          0.2621,  1.2358],
        [-0.0499, -0.3472, -1.1722,  0.2678, -0.2321, -0.3586, -0.0119, -0.2737,
         -0.1089,  0.4942],
        [-0.3976,  0.5503,  0.0553, -0.6882,  0.7961, -0.3377, -0.1622,  0.0986,
         -0.2087, -0.4003],
        [-0.4787,  0.5330,  0.1904, -1.0048, -0.1684,  0.5088, -0.0069, -0.3866,
         -0.1516,  0.7155],
        [ 0.0871, -1.1096, -0.1513, -0.1864,  0.4398, -0.2301,  0.4090, -0.4943,
         -0.2718, -0.1424]])

In [55]:
# Scratch: the same std-0.5 Gaussian recipe at shape (100, 64), moved to the GPU when one is available
to_cuda(torch.Tensor(np.random.randn(100, 64) * 0.5 + 0.0))

tensor([[-0.3850,  0.3367, -0.5648,  ..., -0.5966, -0.1753,  0.6119],
        [-0.4176,  0.8222,  0.3787,  ..., -0.1547,  0.1506, -0.3886],
        [-0.0974,  0.3372, -0.0439,  ...,  0.3419,  0.5537,  0.3183],
        ...,
        [ 0.4631, -0.5071,  0.2694,  ..., -0.2048, -0.5828,  0.2354],
        [ 0.3200,  0.0720,  0.2960,  ...,  0.2946,  0.2431, -0.0116],
        [-0.2538, -0.3455,  0.0388,  ...,  0.0065,  0.0180, -1.3905]])

In [53]:
def compute_discrete_c(n_size, dim):
    """ Sample n_size one-hot rows of width dim, each hot index drawn uniformly from [0, dim).

    Same recipe as InfoGANTrainer.discrete_code; returns a float tensor [n_size, dim],
    on GPU when available.
    """
    hot_index = np.random.randint(0, dim, n_size)
    # Selecting rows of the identity matrix yields the one-hot encoding of each index
    one_hot = np.eye(dim)[hot_index]
    return to_cuda(torch.Tensor(one_hot))

In [50]:
# Scratch: 4 integers drawn uniformly from {0, 1, 2} (the upper bound 3 is exclusive)
np.random.randint(0, 3, 4)

array([2, 0, 0, 2])

In [52]:
# NOTE(review): within the visible notebook, n_size and dim are only assigned in the In [21]
# cell further down, so a fresh top-to-bottom run raises NameError here (hidden kernel state).
torch.Tensor(np.random.randn(n_size, dim) * 0.5 + 0.0)

tensor([[-0.4414, -0.3434, -0.0854, -0.1526, -0.8968],
        [ 0.0283, -0.6033,  0.3400,  0.4386, -0.4224],
        [ 0.4602,  0.0730, -0.0973, -0.5957, -0.2101],
        [ 0.2695, -0.0637,  0.4636,  0.6081, -0.7417],
        [ 0.2429, -0.4586,  0.0943, -0.1694,  0.1278],
        [ 0.6369, -0.5725, -0.4205, -0.2046,  0.1934],
        [-0.5997,  1.3792,  0.0808, -0.6309,  0.0621],
        [-0.6252, -0.1120,  0.3336, -0.7573,  0.2322],
        [-0.0265, -0.3666, -0.8169,  0.2990,  1.0263],
        [ 0.5395, -0.5501, -0.6658, -0.2304, -1.1266]])

In [21]:
# Scratch: step-by-step version of compute_discrete_c for 10 rows and 5 categories
n_size, dim = 10, 5
# Start from all zeros, one row per sample
code = np.zeros((n_size, dim))
# One uniformly random category index in [0, dim) per row
random_cate = np.random.randint(0, dim, n_size)
# Set the chosen column of each row to 1, giving one-hot rows
code[range(n_size), random_cate] = 1