<a href="https://colab.research.google.com/github/yandexdataschool/MLatImperial2022/blob/master/Seminars/lab_07_02_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generative adversarial networks

Let's do our usual imports and load a tool for downloading the dataset we'll use today:

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


import tensorflow_datasets as tfds
from tqdm import tqdm
from PIL import Image

The code below will download and preprocess the dataset we'll work with today:

In [None]:
# Load Labeled Faces in the Wild dataset

lfw = tfds.image_classification.LFW()
lfw.download_and_prepare()
ds = lfw.as_dataset()

In [None]:
def get_img(x):
    return x['image'][80:-80,80:-80]

data = np.array([
  np.array(Image.fromarray(img.numpy()).convert('L').resize((36, 36)))
  for img in tqdm(ds['train'].map(get_img))
])

Let's have a look at the shape of the dataset:

In [None]:
print("shape:", data.shape)
print("min, max:", data.min(), data.max())

So far our data has the shape (n_images, height, width): converting the images to grayscale dropped the channel dimension. PyTorch convolutional layers expect a channel dimension as the second one (axis=1), so let's add it (and normalize the pixel values to [0, 1]):

In [None]:
data = data.astype(np.float32)[:, None, :, :] / 255.
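
To see what this line does to the array, here is a minimal sketch on a small stand-in batch. The `dummy` array is hypothetical; it only mimics the shape and dtype the real `data` is assumed to have at this point.

```python
import numpy as np

# Hypothetical stand-in for the dataset: 4 grayscale images, 36x36, uint8.
rng = np.random.default_rng(0)
dummy = rng.integers(0, 256, size=(4, 36, 36), dtype=np.uint8)

# Indexing with None inserts a new axis of length 1 at position 1
# (the channel axis); dividing by 255 maps pixel values into [0, 1].
dummy_nchw = dummy.astype(np.float32)[:, None, :, :] / 255.

print(dummy_nchw.shape, dummy_nchw.dtype)
```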

And here's a function to plot an (optionally random) subset of images:

In [None]:
def plot_mn(images, m=5, n=5, shuffle=True):
    if shuffle:
        images = images[np.random.permutation(len(images))[:m * n]]
    h, w = images.shape[2:]
    images = images[:m*n].reshape(m, n, *images.shape[1:]).transpose(0, 1, 3, 4, 2) # plotting requires channels last
    images = images.transpose(0, 2, 1, 3, 4).reshape(m * h, n * w)
    plt.imshow(images, cmap='gray')

plt.figure(figsize=(8, 8))
plot_mn(data)
plt.axis('off');
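
The reshape/transpose trick inside `plot_mn` can be hard to read, so here is a small sketch of the same tiling idea without any plotting. The helper name `tile_grid` and the toy input are illustrative only, and the sketch assumes single-channel images.

```python
import numpy as np

def tile_grid(images, m, n):
    """Arrange the first m*n single-channel images (N, 1, H, W) into an (m*H, n*W) mosaic."""
    h, w = images.shape[2:]
    grid = images[:m * n].reshape(m, n, h, w)   # (row, col, H, W)
    grid = grid.transpose(0, 2, 1, 3)           # (row, H, col, W)
    return grid.reshape(m * h, n * w)

# Hypothetical toy input: six 2x3 "images", image k filled with the value k.
toy = np.stack([np.full((1, 2, 3), k, dtype=np.float32) for k in range(6)])
mosaic = tile_grid(toy, m=2, n=3)
print(mosaic.shape)  # (4, 9)
```

Image `k` ends up in grid row `k // n` and grid column `k % n`, which is why the row axis has to be moved next to the height axis before the final reshape.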

Finally, let's import torch and define the Reshape layer (same as in the convolutions notebook):

In [None]:
import torch
import torch.nn as nn

class Reshape(torch.nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(x.shape[0], *self.shape)

Now we'll start from a very simple generator and discriminator:

In [None]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    print("WARNING: gpu not found, the code will run on cpu")
    device = torch.device('cpu')

print(f'Device is: "{device}".')

In [None]:
latent_dims = 128
batch_size = 64

class Generator(torch.nn.Module):
    def __init__(self, latent_dims):
        super().__init__()
        self.latent_dims = latent_dims
        
        self.generator = torch.nn.Sequential(
            torch.nn.Linear(latent_dims, 64),
            torch.nn.ELU(),
            torch.nn.Linear(64, 1 * 36 * 36),
            Reshape(1, 36, 36),
            torch.nn.Sigmoid()
        )
    
    def forward(self, X):
        return self.generator(X)


class Discriminator(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.discriminator = torch.nn.Sequential(
            Reshape(1 * 36 * 36),
            torch.nn.Linear(1 * 36 * 36, 128),
            torch.nn.ELU(),
            torch.nn.Linear(128, 1),
        )
        
    def forward(self, X):
        return self.discriminator(X)

class Sampler():
    def __init__(self, data, batch_size, device):
        self.data = torch.from_numpy(data).to(device)
        self.batch_size = batch_size
        self.device = device
        
    def sample_true(self):
        ids = np.random.choice(len(self.data), size=self.batch_size)
        return self.data[ids]
    
    def sample_fake(self, G):
        noise = torch.randn(self.batch_size, G.latent_dims, device=self.device)
        return G(noise)
        
def get_n_params(model):
    return sum(p.numel() for p in model.parameters())

G = Generator(latent_dims).to(device)
D = Discriminator().to(device)
sampler = Sampler(data, batch_size, device)

print('generator params:', get_n_params(G))
print('discriminator params:', get_n_params(D))

Let's have a look at what we can generate before any training:

In [None]:
G.eval()

imgs = sampler.sample_fake(G).cpu().detach().numpy()
print(imgs.shape)
plt.figure(figsize=(8, 8))
plot_mn(imgs.clip(0, 1))
plt.axis('off')
plt.show();

Now that the models are defined, we need loss functions. Historically, the first loss used in GANs was the cross-entropy that we have already used many times. For the discriminator:
$$\mathscr{L}^{\text{discr}} =
-\text{E}_{x_{real} \sim p(x)}\left[\log D_{\phi}(x_{real})\right]
-\text{E}_{z \sim q(z)}\left[\log(1 - D_{\phi}(G_{\theta}(z)))\right] \rightarrow \min_{\phi}
$$

For the generator, we use the non-saturating form of the loss, which maximizes the probability that fakes are classified as real:

$$
\mathscr{L}^{\text{gen}} =
-\text{E}_{z \sim q(z)}\left[\log D_{\phi}(G_{\theta}(z))\right] \rightarrow \min_{\theta}
$$

Note that here $D(x)$ is the probability the discriminator assigns to $x$ being from the real dataset, so $D(x) = \sigma(\text{discriminator}(x))$, where `discriminator` outputs a raw logit.

Try implementing these loss functions below. Note that $1-\sigma(x)=\sigma(-x)$. You should use `logsigmoid` as a numerically stable implementation of $\log\sigma(x)$.
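
The identity and the stability remark can be checked numerically. The sketch below uses only the standard library; `sigmoid` and `log_sigmoid` are illustrative scalar helpers, not the PyTorch functions.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) for a Python float."""
    # log(sigmoid(x)) = -log(1 + exp(-x)); for negative x rewrite it as
    # x - log(1 + exp(x)) so that exp() never receives a large positive argument.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

for x in (-3.0, 0.0, 2.5):
    # The identity 1 - sigmoid(x) == sigmoid(-x) ...
    assert math.isclose(1.0 - sigmoid(x), sigmoid(-x))
    # ... means log(1 - D) can be computed as log_sigmoid(-logit).
    assert math.isclose(math.log(1.0 - sigmoid(x)), log_sigmoid(-x))

# A very negative logit is still handled without overflow.
print(log_sigmoid(-1000.0))  # -1000.0
```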

In [None]:
logsigmoid = torch.nn.functional.logsigmoid

def generator_loss(fake):
    return <YOUR_CODE>
  
def discriminator_loss(real, fake):
    return <YOUR_CODE>
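
One possible way to fill in the placeholders is sketched below. It assumes, as the training loop does, that `real` and `fake` are batches of images and that the discriminator returns raw logits. The small `D` here is a hypothetical stand-in so the snippet runs on its own; in the notebook you would use the `D` defined earlier.

```python
import torch

logsigmoid = torch.nn.functional.logsigmoid

# Hypothetical stand-in discriminator: flattens 1x36x36 images, returns one logit each.
D = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(36 * 36, 1))

def generator_loss(fake):
    # -E[log D(G(z))], with log D = logsigmoid(logit)
    return -logsigmoid(D(fake)).mean()

def discriminator_loss(real, fake):
    # -E[log D(x_real)] - E[log(1 - D(G(z)))], using 1 - sigmoid(t) = sigmoid(-t)
    return -logsigmoid(D(real)).mean() - logsigmoid(-D(fake)).mean()

real = torch.rand(8, 1, 36, 36)
fake = torch.rand(8, 1, 36, 36)
print(generator_loss(fake).item(), discriminator_loss(real, fake).item())
```

With these definitions the discriminator loss equals the sum of two binary cross-entropy terms (targets 1 for reals, 0 for fakes), which gives a convenient sanity check.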

Let's do some more set-up and run the learning process:

In [None]:
generator_losses = []
discriminator_losses = []

generator_grad_norms = []
discriminator_grad_norms = []

optimizer_G = torch.optim.RMSprop(G.parameters(), lr=0.001)
optimizer_D = torch.optim.RMSprop(D.parameters(), lr=0.001)

scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=10, gamma=0.98)
scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=10, gamma=0.98)
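
`StepLR` multiplies the learning rate by `gamma` once every `step_size` calls to `scheduler.step()`. The sketch below reproduces that rule in plain Python to show how far the rate decays; `step_lr` is an illustrative helper, not part of PyTorch.

```python
def step_lr(base_lr, n_scheduler_steps, step_size=10, gamma=0.98):
    """Learning rate after n calls to scheduler.step() under a StepLR-style rule."""
    return base_lr * gamma ** (n_scheduler_steps // step_size)

for n in (0, 10, 100, 1200):
    print(n, step_lr(0.001, n))
```

With one scheduler step per training iteration, 1200 iterations leave the learning rate at a bit under a tenth of its initial value.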

In [None]:
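def calc_grad_norm(model):
    # L2 norm of all parameter gradients, treated as one long vector
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is None:  # skip parameters that received no gradient
            continue
        param_norm = p.grad.detach().norm(2)
        total_norm += param_norm.item() ** 2
    return np.sqrt(total_norm)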

In [None]:
from IPython.display import clear_output

for i in range(1200):
    # Since our models are updated in turns,
    # we first set the discriminator to train,
    # while the generator is in the eval mode
    G.eval()
    D.train()
  
    # Several discriminator updates per step:
    avg_D_grad_norm = 0.0
    for j in range(5):
        # Sampling reals and fakes
        real = sampler.sample_true()
        with torch.no_grad():
            fake = sampler.sample_fake(G)

        # Calculating the loss
        loss = discriminator_loss(real, fake)

        # Doing our regular optimization step for the discriminator
        D.zero_grad()
        loss.backward()
        avg_D_grad_norm += calc_grad_norm(D)
        optimizer_D.step()
        
    
    # Remember the value of discriminator loss for plotting
    avg_D_grad_norm /= 5
    discriminator_losses.append(loss.item())
    discriminator_grad_norms.append(avg_D_grad_norm)

    # Now it's generator's time to learn:
    G.train()
    D.eval()

    fake = sampler.sample_fake(G)
    loss = generator_loss(fake)
    
    G.zero_grad()
    loss.backward()
    generator_grad_norms.append(calc_grad_norm(G))
    optimizer_G.step()
    
    generator_losses.append(loss.item())

    scheduler_D.step()
    scheduler_G.step()

    if i % 20 == 0:
        G.eval()
        imgs = sampler.sample_fake(G).cpu().detach().numpy()

        clear_output(wait=True)

        plt.figure(figsize=(16, 8))

        plt.subplot(1, 2, 1)
        plt.plot(generator_losses    , label='generator loss')
        plt.plot(discriminator_losses, label='discriminator loss')
        plt.grid()
        plt.legend()

        plt.subplot(1, 2, 2)
        plot_mn(imgs.clip(0, 1))
        plt.axis('off')
        
        plt.show();

In [None]:
plt.figure(figsize=(10, 10))
plt.plot(generator_grad_norms, label='generator grad')
plt.plot(discriminator_grad_norms, label='discriminator grad')
plt.grid()
plt.legend()
plt.show()

In [None]:
G.eval()
num_images = 10
Z = torch.linspace(-0.1, 0.2, num_images).unsqueeze(1).expand(num_images, latent_dims).to(device)  # deterministic (not random) latent vectors: row i is a constant vector

plt.figure(figsize=(25, 10))
for i in range(1, num_images + 1):
    plt.subplot(1, num_images, i)
    res = G(Z[i-1].unsqueeze(0)).squeeze().cpu().detach().numpy()
    plt.imshow(res.clip(0, 1), cmap='gray')
    plt.axis('off')

plt.show();

## Modifications

1. Convolutional neural network (take care to avoid zero gradients)
- Use ELU or LeakyReLU instead of ReLU
- Use LayerNorm instead of BatchNorm2d
- Decrease the image size with convolutions or `AvgPool2d` instead of `MaxPool2d`
- Use transposed convolutions instead of `Upsample` (more difficult and heavier)
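
A minimal sketch of a discriminator that follows these hints is shown below. The class name `ConvDiscriminator`, the channel counts and the LeakyReLU slope are assumptions chosen for illustration, not tuned values.

```python
import torch
import torch.nn as nn

class ConvDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # (16, 36, 36)
            nn.LeakyReLU(0.2),                            # non-zero gradient for negative inputs
            nn.AvgPool2d(2),                              # (16, 18, 18)
            nn.LayerNorm([16, 18, 18]),                   # normalizes each sample independently
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # (32, 18, 18)
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(2),                              # (32, 9, 9)
            nn.LayerNorm([32, 9, 9]),
            nn.Flatten(),
            nn.Linear(32 * 9 * 9, 1),                     # one logit per image
        )

    def forward(self, x):
        return self.net(x)

logits = ConvDiscriminator()(torch.rand(4, 1, 36, 36))
print(logits.shape)  # torch.Size([4, 1])
```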

2. Discriminator gradient L2 penalty for real images only (see this paper: https://arxiv.org/pdf/1801.04406.pdf):

```
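def discriminator_penalty(real):
    # Work on a detached copy so the caller's tensor is left untouched
    real = real.detach().requires_grad_(True)
    scores = D(real)
    # create_graph=True keeps the graph, so the penalty itself can be backpropagated
    grad = torch.autograd.grad(outputs=scores.mean(), inputs=real,
        create_graph=True)[0]
    penalty = (grad**2).sum()
    return penalty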
```
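
A sketch of how such a penalty could be used: it is added to the discriminator loss with some weight before calling `backward()`. The stand-in `D` and the coefficient `penalty_weight` are hypothetical; the weight would need tuning.

```python
import torch

# Hypothetical stand-in discriminator returning one logit per image.
D = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(36 * 36, 32),
    torch.nn.ELU(),
    torch.nn.Linear(32, 1),
)

def discriminator_penalty(real):
    real = real.detach().requires_grad_(True)
    scores = D(real)
    grad = torch.autograd.grad(outputs=scores.mean(), inputs=real,
                               create_graph=True)[0]
    return (grad ** 2).sum()

real = torch.rand(8, 1, 36, 36)
penalty_weight = 10.0  # hypothetical coefficient
penalty = discriminator_penalty(real)

# In training this term would be added to discriminator_loss(real, fake).
(penalty_weight * penalty).backward()
print(penalty.item())
```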


3. Feature Matching Loss

```
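def feature_matching_loss(real, fake):
    # Assumes D is modified to return (logits, features); matches batch-averaged features
    real_f, fake_f = D(real)[1], D(fake)[1]
    return ((real_f.mean(0) - fake_f.mean(0))**2).mean()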
```
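
The snippet above relies on a discriminator that returns its intermediate features along with the logits. A minimal sketch of such a model follows; `FeatureDiscriminator` and its layer sizes are hypothetical, and the loss here matches batch-averaged features, which is one common form of feature matching.

```python
import torch
import torch.nn as nn

class FeatureDiscriminator(nn.Module):
    """Hypothetical discriminator that also returns its hidden activations."""
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Flatten(), nn.Linear(36 * 36, 128), nn.ELU())
        self.head = nn.Linear(128, 1)

    def forward(self, x):
        features = self.body(x)
        return self.head(features), features

D = FeatureDiscriminator()

def feature_matching_loss(real, fake):
    real_features = D(real)[1].mean(dim=0)  # average over the batch
    fake_features = D(fake)[1].mean(dim=0)
    return ((real_features - fake_features) ** 2).mean()

x = torch.rand(8, 1, 36, 36)
print(feature_matching_loss(x, x).item())  # identical batches give 0.0
```

Note that with this change the other losses have to use `D(x)[0]` to get the logits.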

### Bonus: Try replacing `nn.Upsample` in the generator with `nn.ConvTranspose2d`
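
As a starting point, here is a sketch of a generator built on transposed convolutions. The class name `ConvTransposeGenerator`, the 9x9 starting resolution and the channel counts are assumptions; the kernel/stride/padding values are chosen so that each layer exactly doubles the spatial size.

```python
import torch
import torch.nn as nn

class ConvTransposeGenerator(nn.Module):
    def __init__(self, latent_dims=128):
        super().__init__()
        self.latent_dims = latent_dims
        self.net = nn.Sequential(
            nn.Linear(latent_dims, 64 * 9 * 9),
            nn.ELU(),
            nn.Unflatten(1, (64, 9, 9)),
            # out = (in - 1) * stride - 2 * padding + kernel_size
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 9 -> 18
            nn.ELU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),   # 18 -> 36
            nn.Sigmoid(),  # pixel values in [0, 1], like the normalized data
        )

    def forward(self, z):
        return self.net(z)

G_conv = ConvTransposeGenerator()
imgs = G_conv(torch.randn(4, G_conv.latent_dims))
print(imgs.shape)  # torch.Size([4, 1, 36, 36])
```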