In [None]:
# Print the NVIDIA driver / GPU status of this runtime (needs the `nvidia-smi` CLI, i.e. an NVIDIA GPU runtime).
!nvidia-smi

# **Intro to Generative Adversarial Networks (GANs)**

Generative adversarial networks (GANs) are algorithmic architectures that use two neural networks, competing one against the other (thus the “adversarial”), in order to generate new, synthetic instances of data that can pass for real data. They are widely used in image generation, video generation and voice generation.

GANs were introduced in [a paper by Ian Goodfellow](https://arxiv.org/abs/1406.2661) and other researchers at the University of Montreal, including Yoshua Bengio, in 2014. Referring to GANs, Facebook’s AI research director Yann LeCun called adversarial training “the most interesting idea in the last 10 years in ML.”


## **Some cool demos**:
* Progress over the last several years, from an [Ian Goodfellow tweet](https://twitter.com/goodfellow_ian/status/1084973596236144640)

<img src='http://drive.google.com/uc?export=view&id=1PSfze4ZHgAn4BAjLuZhqAZO_HJQ1NEHX' width=1000 height=350/>


A generative adversarial network (GAN) has two parts:

* The **generator** learns to generate plausible data. The generated instances become negative training examples for the discriminator.
* The **discriminator** learns to distinguish the generator's fake data from real data. The discriminator penalizes the generator for producing implausible results.


When training begins, the generator produces obviously fake data, and the discriminator quickly learns to tell that it's fake:

<img src='http://drive.google.com/uc?export=view&id=1Auxzsi3395vL0K80GfYlAEvWufTMTZ59' width=1000 height=350/>


# **<font color='Darkblue'>Import Required Libraries:</font>**

In [None]:
from __future__ import print_function
#%matplotlib inline
import random
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision import datasets # Training dataset
from torchvision.utils import make_grid
from torchvision import utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import numpy as np

import matplotlib.animation as animation
from IPython.display import HTML


# Seed Python's `random` module and PyTorch's RNG so runs are repeatable.
my_seed = 123
random.seed(my_seed)
torch.manual_seed(my_seed)

# Use the first CUDA GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("My device: => ", device)

# **Fashion-MNIST Dataset:**

`Fashion-MNIST` is a dataset of [Zalando](https://jobs.zalando.com/en/tech/?gh_src=281f2ef41us)'s article images—consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. Zalando intends Fashion-MNIST to serve as a direct drop-in replacement for the original [MNIST dataset](http://yann.lecun.com/exdb/mnist/) for benchmarking machine learning algorithms. It shares the same image size and structure of training and testing splits.

<img src='https://raw.githubusercontent.com/zalandoresearch/fashion-mnist/master/doc/img/fashion-mnist-sprite.png' width=1000 height=700/>



In [None]:
# Mini-batch size used by the training DataLoader.
batch_size = 128

# ToTensor() converts each PIL image to a float tensor scaled to [0, 1],
# the same range the generator's final Sigmoid produces.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Training split only; downloaded into ./data on first run.
data_train = datasets.FashionMNIST(
    './data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    data_train,
    batch_size=batch_size,
    shuffle=True,
)


In [None]:
# Human-readable names for the 10 Fashion-MNIST label indices.
classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot']
n_show = 30  # fills the 3 x 10 subplot grid below

# Grab a single shuffled batch. Use the builtin next(): newer PyTorch releases removed the
# Python-2-style `.next()` method from DataLoader iterators.
dataiter = iter(train_loader)
images, labels = next(dataiter)

fig = plt.figure(figsize=(25, 10))
for i in range(n_show):
  ax = fig.add_subplot(3, 10, i + 1, xticks=[], yticks=[])
  # Each image is (1, 28, 28); squeeze away the channel axis for imshow instead of
  # resizing the batch tensor in place.
  ax.imshow(images[i].squeeze().numpy(), cmap='gray')
  ax.set_title(classes[labels[i].item()], color="blue")


# **<font color='darkorange'>Generator Part:</font>**



The generator part of a GAN learns to create fake data by incorporating feedback from the discriminator. It learns to make the discriminator classify its output as real.

Generator training requires tighter integration between the generator and the discriminator than discriminator training requires. The portion of the GAN that trains the generator includes:



*   random input
*   generator network, which transforms the random input into a data instance
*   discriminator network, which classifies the generated data
*   discriminator output
*   generator loss, which penalizes the generator for failing to fool the discriminator






<img src='http://drive.google.com/uc?export=view&id=1dbk5FmAHE3LHwspYm8qxL-qHBLXBq29i' width=1000 height=350/>

### **Generator Block:**

In [None]:
def get_generator_block(input_dim, output_dim):
  """Build one fully-connected generator stage.

  Layers: Linear(input_dim -> output_dim), BatchNorm1d, LeakyReLU(0.2), Dropout(0.3).

  Args:
    input_dim: number of input features.
    output_dim: number of output features.

  Returns:
    nn.Sequential containing the four layers in that order.
  """
  layers = [
      nn.Linear(input_dim, output_dim),
      nn.BatchNorm1d(output_dim),
      nn.LeakyReLU(0.2),
      nn.Dropout(p=0.3),
  ]
  return nn.Sequential(*layers)

### **Generator Class:**

In [None]:
class Generator(nn.Module):
  """MLP generator that maps a noise vector to a flattened image in [0, 1].

  Architecture: four `get_generator_block` stages widening
  z_dim -> hidden_dim -> 2*hidden_dim -> 4*hidden_dim -> 8*hidden_dim,
  followed by Linear(8*hidden_dim -> img_dim) and a Sigmoid.

  Args:
    z_dim: length of the input noise vector.
    img_dim: length of the flattened output image (28*28 by default).
    hidden_dim: width of the first hidden stage.
  """
  def __init__(self, z_dim=10, img_dim=28*28, hidden_dim=128):
    super().__init__()
    widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]
    stages = [get_generator_block(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
    self.gen = nn.Sequential(*stages, nn.Linear(widths[-1], img_dim), nn.Sigmoid())

  def forward(self, noise):
    """Map `noise` of shape (batch, z_dim) to images of shape (batch, img_dim)."""
    return self.gen(noise)


      

In [None]:
# Generate Noise:

def get_generator_noise(n_sample, z_dim, device='cpu'):
  """Sample generator input noise.

  Args:
    n_sample: number of noise vectors (batch size).
    z_dim: length of each noise vector.
    device: device on which the tensor is created.

  Returns:
    Tensor of shape (n_sample, z_dim) drawn from a standard normal distribution.
  """
  return torch.randn((n_sample, z_dim), device=device)

# **<font color='darkorange'>Discriminator Part:</font>**



The discriminator in a GAN is simply a classifier. It tries to distinguish real data from the data created by the generator. It could use any network architecture appropriate to the type of data it's classifying.

The discriminator's training data comes from two sources:

* **Real data** instances, such as real pictures of people (in this notebook, real Fashion-MNIST clothing images). The discriminator uses these instances as positive examples during training.
* **Fake data** instances created by the generator. The discriminator uses these instances as negative examples during training.

<img src='http://drive.google.com/uc?export=view&id=1A3_gYqcPORqXFio1wNHAsc8ZndY3zpIP' width=1000 height=350/>

In [None]:
# Discriminator Block

def get_discriminator_block(input_dim, output_dim):
  """Build one fully-connected discriminator stage: Linear followed by LeakyReLU(0.2).

  Args:
    input_dim: number of input features.
    output_dim: number of output features.

  Returns:
    nn.Sequential(Linear(input_dim -> output_dim), LeakyReLU(0.2)).
  """
  linear = nn.Linear(input_dim, output_dim)
  activation = nn.LeakyReLU(0.2)
  return nn.Sequential(linear, activation)

<img src='https://miro.medium.com/max/1400/1*siH_yCvYJ9rqWSUYeDBiRA.png' width=800 height=400/>

In [None]:
# Discriminator Class:

class Discriminator(nn.Module):
  """MLP discriminator that maps a flattened image to a single real/fake logit.

  Architecture: three `get_discriminator_block` stages narrowing
  img_dim -> 4*hidden_dim -> 2*hidden_dim -> hidden_dim, then Linear(hidden_dim -> 1).
  No final sigmoid: the output is a raw logit, intended for a logits-based loss.

  Args:
    img_dim: length of the flattened input image (28*28 by default).
    hidden_dim: width of the last hidden stage.
  """
  def __init__(self, img_dim=28*28, hidden_dim=128):
    super().__init__()
    widths = [img_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
    stages = [get_discriminator_block(w_in, w_out) for w_in, w_out in zip(widths, widths[1:])]
    self.disc = nn.Sequential(*stages, nn.Linear(widths[-1], 1))

  def forward(self, image):
    """Map `image` of shape (batch, img_dim) to logits of shape (batch, 1)."""
    return self.disc(image)


# **<font color='deepskyblue'>Training Process:</font>**


Because a GAN contains two separately trained networks, its training algorithm must address two complications:

* GANs must juggle two different kinds of training (generator and discriminator).
* GAN convergence is hard to identify.

### **Set Hyperparameters:**

In [None]:
# Set your parameters

criterion = nn.BCEWithLogitsLoss()  # expects raw logits, matching the Discriminator's output
num_epochs = 51
z_dim = 64          # length of the generator's input noise vector
display_step = 100  # print losses / store a sample grid every this many steps
lr = 0.0001

size = (1, 28, 28)  # (channels, height, width) used to un-flatten generated images

# Pick the device from what is actually available. Hard-coding 'cuda' here overrode the
# auto-detection from the setup cell and crashed on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the generator first, then the discriminator (this order fixes which
# random draws each network's weight initialisation receives).

# Generator and its Adam optimizer
generator = Generator(z_dim=z_dim).to(device)
gen_optimizer = torch.optim.Adam(params=generator.parameters(), lr=lr)

# Discriminator and its Adam optimizer
discriminator = Discriminator().to(device)
disc_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=lr)

In [None]:
# Discriminator Loss:

def get_discriminator_loss(gen, disc, criterion, real, num_images, z_dim, device):
  """Compute the discriminator loss on one batch.

  Generated images are scored against target 0 and real images against target 1.
  The two losses are averaged.

  Args:
    gen: generator network.
    disc: discriminator network (returns logits).
    criterion: loss taking (logits, targets), e.g. BCEWithLogitsLoss.
    real: batch of real (flattened) images already on `device`.
    num_images: number of fake images to generate.
    z_dim: noise vector length.
    device: device for the noise tensor.

  Returns:
    Scalar loss tensor whose gradients flow into `disc` only.
  """
  noise = get_generator_noise(num_images, z_dim, device=device)
  # The generator is not updated by this loss. Running it under no_grad gives the same result
  # as `.detach()` without first building an autograd graph that is immediately discarded.
  with torch.no_grad():
    gen_output = gen(noise)
  disc_out_fake = disc(gen_output)
  disc_loss_fake = criterion(disc_out_fake, torch.zeros_like(disc_out_fake))
  disc_out_real = disc(real)
  disc_loss_real = criterion(disc_out_real, torch.ones_like(disc_out_real))

  disc_loss = (disc_loss_fake + disc_loss_real) / 2
  return disc_loss

In [None]:
# Generator Loss:
def get_generator_loss(gen, disc, criterion, num_images, z_dim, device):
  """Compute the generator loss on one batch.

  Fresh fakes are scored by the discriminator against target 1 ("real"). Minimising this
  pushes the generator toward outputs the discriminator accepts.

  Args:
    gen: generator network.
    disc: discriminator network (returns logits).
    criterion: loss taking (logits, targets), e.g. BCEWithLogitsLoss.
    num_images: number of fake images to generate.
    z_dim: noise vector length.
    device: device for the noise tensor.

  Returns:
    Scalar loss tensor whose gradients flow back through `disc` into `gen`.
  """
  fake_images = gen(get_generator_noise(num_images, z_dim, device=device))
  # Deliberately NOT detached: detaching here would block the gradient to the generator.
  fake_logits = disc(fake_images)
  all_real_targets = torch.ones_like(fake_logits)
  return criterion(fake_logits, all_real_targets)

In [None]:
# Show Images Function:

def show_tensor_images(real, fake, num_images=25, size=(1, 28, 28)):
  """Plot real and fake images side by side, each as a 5-column grid.

  Args:
    real: tensor of real images (flattened or not), any device.
    fake: tensor of generated images (flattened or not), any device.
    num_images: how many images from each tensor to show.
    size: (channels, height, width) used to un-flatten each image.
  """
  plt.figure(figsize=(15,15))
  panels = (("Real Images", real), ("Fake Images", fake))
  for position, (title, batch) in enumerate(panels, start=1):
    unflat = batch.detach().cpu().view(-1, *size)
    grid = make_grid(unflat[:num_images], nrow=5, normalize=True, padding=2)
    plt.subplot(1, 2, position)
    plt.axis("off")
    plt.title(title)
    # make_grid returns (C, H, W); imshow wants (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).squeeze())
  plt.show()

In [None]:
# Training Loop

img_list = []  # 6x6 sample grids captured every `display_step` steps (used by the animation cell)
G_losses = []  # generator loss at every step
D_losses = []  # discriminator loss at every step

cur_step = 0
img_show = 3   # show a real-vs-fake comparison every `img_show` epochs

# Running means over the current display window of `display_step` steps.
mean_generator_loss = 0
mean_discriminator_loss = 0


def _sample_fake(n_images):
  """Draw `n_images` generator samples for display without disturbing training.

  Uses eval mode (BatchNorm uses running stats and is not updated; Dropout is off) and
  `torch.no_grad()` (no autograd graph), then switches the generator back to train mode.
  """
  generator.eval()
  with torch.no_grad():
    samples = generator(get_generator_noise(n_images, z_dim, device=device))
  generator.train()
  return samples


for epoch in range(num_epochs):

  for real, _ in tqdm(train_loader):
    cur_batch_size = len(real)
    # Flatten each image to a vector to match the MLP discriminator's input.
    real = real.view(cur_batch_size, -1).to(device)

    # --- Discriminator update ---
    disc_optimizer.zero_grad()
    disc_loss = get_discriminator_loss(generator, discriminator, criterion, real, cur_batch_size, z_dim, device)
    disc_loss.backward()
    disc_optimizer.step()

    # --- Generator update ---
    gen_optimizer.zero_grad()
    gen_loss = get_generator_loss(generator, discriminator, criterion, cur_batch_size, z_dim, device)
    gen_loss.backward()
    gen_optimizer.step()

    # The criterion already averages over the batch, so average over the display window instead.
    mean_discriminator_loss += disc_loss.item() / display_step
    mean_generator_loss += gen_loss.item() / display_step

    # Per-step losses, each stored under the network it belongs to.
    G_losses.append(gen_loss.item())
    D_losses.append(disc_loss.item())

    # Count this step first, so every window covers exactly `display_step` steps.
    cur_step += 1

    if cur_step % display_step == 0:
      print(f"[Epoch: {epoch}/{num_epochs}] | [Step: {cur_step}/{num_epochs*len(train_loader)}], Generator Loss: {mean_generator_loss}, Discriminator Loss: {mean_discriminator_loss}")
      fake = _sample_fake(cur_batch_size)
      img_list.append(make_grid(fake.cpu().view(-1, *size)[:36], nrow=6, normalize=True, padding=2))
      mean_discriminator_loss = 0
      mean_generator_loss = 0

  if epoch % img_show == 0:
    fake = _sample_fake(cur_batch_size)
    show_tensor_images(real, fake)

# **Visualization:**

In [None]:
# Per-step losses of both networks over the whole training run.
loss_fig, loss_ax = plt.subplots(figsize=(20, 7))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(G_losses, label="Generator")
loss_ax.plot(D_losses, label="Discriminator")
loss_ax.set_xlabel("steps")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [None]:
# Animate the sample grids captured during training (one frame per `display_step`).
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
# Each stored grid is a (C, H, W) tensor; imshow expects (H, W, C).
imgs = [[plt.imshow(img.permute(1, 2, 0).numpy(), animated=True)] for img in img_list]
anim = animation.ArtistAnimation(fig, imgs, interval=100, repeat_delay=1000, blit=True)
# Close the figure so the inline backend does not also render a static copy of the last frame.
plt.close(fig)

HTML(anim.to_jshtml())