## PyTorch Conditional GAN

This kernel is a PyTorch implementation of [Conditional GAN](https://arxiv.org/abs/1411.1784), which is a GAN that allows you to choose the label of the generated image.
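For reference, the paper linked above conditions the usual two-player minimax game on extra information $y$ (here, the class label):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]$$

In practice the label is fed to both networks, so the discriminator scores an image together with its label. With `nn.BCELoss`, the discriminator's side amounts to classifying real images with target 1 and generated images with target 0. The generator is commonly trained with target 1 on its own samples, which maximizes $\log D(G(z \mid y))$ instead of minimizing $\log(1 - D(G(z \mid y)))$. The original GAN paper suggested this substitution because it gives stronger gradients early in training.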

However, the kernel is incomplete: you need to complete the generator and discriminator models. The training loop is also incomplete. The `generator_train_step` and `discriminator_train_step` functions must be written before the GAN can be trained.

First, read about Generative Adversarial Networks and try to understand the code written here. Then complete the cells for the [`Generator` class](#1), the [`Discriminator` class](#2) and the [`train_step` functions](#3).

In [None]:
%matplotlib inline
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch import autograd
from torch.autograd import Variable
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

Let's start by defining a `Dataset` class:
* [Data Loading and Processing Tutorial on PyTorch's documentation](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)
* [torchvision](https://github.com/pytorch/vision) has a [built-in class for Fashion MNIST](https://pytorch.org/docs/stable/torchvision/datasets.html#fashion-mnist)

In [None]:
class FashionMNIST(Dataset):
    def __init__(self, transform=None):
        self.transform = transform
        fashion_df = pd.read_csv('../input/fashion-mnist_train.csv')
        self.labels = fashion_df.label.values
        self.images = fashion_df.iloc[:, 1:].values.astype('uint8').reshape(-1, 28, 28)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        label = self.labels[idx]
        img = Image.fromarray(self.images[idx])
        
        if self.transform:
            img = self.transform(img)

        return img, label

In [None]:
dataset = FashionMNIST()
dataset[0][0]

In [None]:
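transform = transforms.Compose([
        transforms.ToTensor(),
        # Fashion-MNIST images have one channel, so Normalize needs one mean/std value;
        # (x - 0.5) / 0.5 maps pixels from [0, 1] to [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])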
dataset = FashionMNIST(transform=transform)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

Now let's define the generator and the discriminator, which are simple MLPs. I'm going to use an embedding layer for the label:
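To make the embedding idea concrete: `nn.Embedding` turns each integer label into a learned vector, which can then be concatenated with the noise vector (or the flattened image) before the first linear layer. A small sketch follows. The sizes (10 classes, 10-dimensional embedding, 100-dimensional noise) are chosen to match the `torch.randn(..., 100)` calls later in the notebook and are otherwise assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes, embed_dim, latent_dim = 10, 10, 100  # assumed sizes

# Each of the 10 class ids is mapped to a learned 10-dimensional vector
label_emb = nn.Embedding(num_classes, embed_dim)

z = torch.randn(4, latent_dim)        # batch of 4 noise vectors
labels = torch.tensor([0, 3, 3, 9])   # the classes we want to generate
c = label_emb(labels)                 # shape (4, 10); equal labels give equal rows
gen_input = torch.cat([z, c], dim=1)  # shape (4, 110): what a generator MLP would receive
print(tuple(c.shape), tuple(gen_input.shape))  # (4, 10) (4, 110)
```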

### Discriminator
<a id='2'></a>

In [None]:
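class Discriminator(nn.Module):
    # TODO: define the discriminator here (suggested interface: forward(images, labels) -> P(real), shape (N,))
    def forward(self, images, labels):
        raise NotImplementedError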
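If you get stuck, here is one possible way to fill in the `Discriminator` cell. It is a sketch, not the kernel author's reference solution. It embeds the label with `nn.Embedding`, concatenates it with the flattened image, and feeds the result to an MLP ending in a sigmoid, so the output is a probability that `nn.BCELoss` can consume. The layer sizes, LeakyReLU slope and dropout rate are arbitrary illustrative choices, and the `forward(x, labels)` signature is an assumed interface.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """MLP discriminator D(x | y) for 28x28 grayscale images and 10 classes (assumed sizes)."""

    def __init__(self, num_classes=10, embed_dim=10, img_size=28):
        super().__init__()
        self.img_pixels = img_size * img_size
        self.label_emb = nn.Embedding(num_classes, embed_dim)
        self.model = nn.Sequential(
            nn.Linear(self.img_pixels + embed_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that x is real, as nn.BCELoss expects
        )

    def forward(self, x, labels):
        x = x.reshape(x.size(0), self.img_pixels)  # accepts (N, 28, 28) or (N, 1, 28, 28)
        c = self.label_emb(labels)                 # (N, embed_dim)
        out = self.model(torch.cat([x, c], dim=1))
        return out.view(-1)                        # shape (N,)


# Usage: score a random batch of 4 "images" with their labels
d = Discriminator()
scores = d(torch.randn(4, 1, 28, 28), torch.tensor([0, 1, 2, 3]))
print(tuple(scores.shape))  # (4,)
```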

### Generator 
<a id='1'></a>

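In [None]:
class Generator(nn.Module):
    # TODO: define the generator here (forward(z, labels) with z of shape (N, 100) -> images (N, 28, 28))
    def forward(self, z, labels):
        raise NotImplementedError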
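Likewise, here is a sketch of one possible `Generator`, again not the author's reference solution. Its `forward(z, labels)` signature and `(N, 28, 28)` output match how the notebook calls the generator: `z` has shape `(N, 100)` and the result is passed through `.unsqueeze(1)` to add a channel axis. The hidden sizes and embedding size are illustrative assumptions. The final `Tanh` keeps outputs in the same `[-1, 1]` range as the normalized real images.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """MLP generator G(z | y) producing 28x28 images in [-1, 1] (assumed sizes)."""

    def __init__(self, latent_dim=100, num_classes=10, embed_dim=10, img_size=28):
        super().__init__()
        self.img_size = img_size
        self.label_emb = nn.Embedding(num_classes, embed_dim)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_size * img_size),
            nn.Tanh(),  # matches the [-1, 1] range produced by Normalize(0.5, 0.5)
        )

    def forward(self, z, labels):
        c = self.label_emb(labels)  # (N, embed_dim)
        out = self.model(torch.cat([z, c], dim=1))
        # (N, 28, 28): the sampling cells call .unsqueeze(1) to add the channel axis
        return out.view(-1, self.img_size, self.img_size)


# Usage: one image for each of classes 0-8, as in the per-epoch preview
g = Generator()
imgs = g(torch.randn(9, 100), torch.arange(9))
print(tuple(imgs.shape))  # (9, 28, 28)
```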

In [None]:
generator = Generator().cuda()
discriminator = Discriminator().cuda()

In [None]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

### Training Loop 
<a id='3'></a>

In [None]:
def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion):
    # TODO: define the generator train step here and return the generator loss
    raise NotImplementedError
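A sketch of one possible generator step. It assumes both models take `(inputs, labels)` and that the discriminator returns probabilities of shape `(N,)`. It samples noise and random labels, then trains the generator with target 1 on its own samples (the non-saturating generator loss). The `latent_dim` and `num_classes` keyword arguments are hypothetical additions whose defaults match this notebook (100-dimensional noise, 10 classes), so the call in the training loop works unchanged. The device is read from the generator's parameters instead of hard-coding `.cuda()`.

```python
import torch


def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion,
                         latent_dim=100, num_classes=10):
    """One generator update; returns the generator loss as a Python float."""
    device = next(generator.parameters()).device
    g_optimizer.zero_grad()
    # Sample noise and random target labels, then generate fake images for them
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
    fake_images = generator(z, fake_labels)
    # Train G so that D classifies its fakes as real (target 1)
    validity = discriminator(fake_images, fake_labels)
    g_loss = criterion(validity, torch.ones(batch_size, device=device))
    g_loss.backward()
    g_optimizer.step()  # only the generator's parameters are updated here
    return g_loss.item()
```

`g_loss.backward()` also fills the discriminator's `.grad` buffers, but only `g_optimizer` steps here. Those stale gradients are harmless as long as the discriminator step calls `d_optimizer.zero_grad()` before its own backward pass.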

In [None]:
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, labels):
    # TODO: define the discriminator train step here and return the discriminator loss
    raise NotImplementedError
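And a sketch of one possible discriminator step, under the same assumptions: both models take `(inputs, labels)`, the discriminator outputs probabilities of shape `(N,)`, and the hypothetical `latent_dim` / `num_classes` keyword defaults match this notebook. Real images with their true labels get target 1 and generated images get target 0. The generated batch is detached so that this step does not backpropagate into the generator.

```python
import torch


def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion,
                             real_images, labels, latent_dim=100, num_classes=10):
    """One discriminator update; returns the discriminator loss as a Python float."""
    device = real_images.device
    d_optimizer.zero_grad()
    # Real images paired with their true labels should be classified as real (target 1)
    real_validity = discriminator(real_images, labels)
    real_loss = criterion(real_validity, torch.ones(batch_size, device=device))
    # Fake images for random labels should be classified as fake (target 0)
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
    fake_images = generator(z, fake_labels).detach()  # no gradient flows into G here
    fake_validity = discriminator(fake_images, fake_labels)
    fake_loss = criterion(fake_validity, torch.zeros(batch_size, device=device))
    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()
```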

In [None]:
num_epochs = 30
n_critic = 5        # not used below: the loop does one discriminator step per generator step
display_step = 300  # not used below
for epoch in range(num_epochs):
    print('Starting epoch {}...'.format(epoch))
    for i, (images, labels) in enumerate(data_loader):
        real_images = images.cuda()
        labels = labels.cuda()
        generator.train()
        batch_size = real_images.size(0)
        # One discriminator update, then one generator update, per batch
        d_loss = discriminator_train_step(batch_size, discriminator,
                                          generator, d_optimizer, criterion,
                                          real_images, labels)
        g_loss = generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion)

    # End of epoch: report the last batch's losses and sample one image for each of classes 0-8
    generator.eval()
    print('g_loss: {}, d_loss: {}'.format(g_loss, d_loss))
    with torch.no_grad():
        z = torch.randn(9, 100).cuda()
        labels = torch.LongTensor(np.arange(9)).cuda()
        sample_images = generator(z, labels).unsqueeze(1).cpu()
    grid = make_grid(sample_images, nrow=3, normalize=True).permute(1,2,0).numpy()
    plt.imshow(grid)
    plt.show()

## Results

In [None]:
with torch.no_grad():
    z = torch.randn(100, 100).cuda()
    # Labels 0-9 repeated ten times, so each column of the 10x10 grid shows a single class
    labels = torch.LongTensor([i for _ in range(10) for i in range(10)]).cuda()
    sample_images = generator(z, labels).unsqueeze(1).cpu()
grid = make_grid(sample_images, nrow=10, normalize=True).permute(1,2,0).numpy()
fig, ax = plt.subplots(figsize=(15,15))
ax.imshow(grid)
_ = plt.yticks([])
_ = plt.xticks(np.arange(15, 300, 30), ['T-Shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], rotation=45, fontsize=20)
