<p>This notebook implements the well-known <b>pix2pix</b> research paper for the <b>image colorization</b> task, with <b>additional pretraining</b> of the generator: its backbone is replaced with a ResNet pretrained on ImageNet, and the generator is then pretrained for colorization in a supervised manner using L1 loss. This pretraining compensates for the fact that we will be using a much <b>smaller dataset</b> than the original paper.

pix2pix proposed a general solution to many image-to-image tasks in deep learning, one of which is colorization. The approach uses two losses: an L1 loss, which makes it a regression task, and an adversarial (GAN) loss, which helps to solve the problem in an unsupervised manner.</p>

**Loss Function to be optimized**

$G^*=\arg \min _G \max _D \mathcal{L}_{c G A N}(G, D)+\lambda \mathcal{L}_{L 1}(G)$

**L1 loss** 

$\mathcal{L}_{L 1}(G)=\mathbb{E}_{x, y, z}\left[\|y-G(x, z)\|_1\right]$

**GAN Loss**

$\begin{aligned} \mathcal{L}_{c G A N}(G, D)=& \mathbb{E}_{x, y}[\log D(x, y)]+\\ & \mathbb{E}_{x, z}[\log (1-D(x, G(x, z)))]\end{aligned}$

* x -> grayscale image (the condition introduced)
* y -> real 2-channel color target that the generator output G(x, z) is compared against
* z -> input noise of generator
* G -> Generator Model
* D -> Discriminator Model
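As a rough illustration of how the generator's side of this objective could be evaluated, here is a minimal PyTorch sketch. The tensors are random stand-ins, `disc` is a hypothetical one-layer discriminator rather than the model used in this notebook, and the adversarial term uses the common "label fake as real" formulation instead of the literal $\log(1-D)$ term. The value of `lambda_l1` mirrors the one in the Config class further down.

```python
import torch
from torch import nn

torch.manual_seed(0)

lambda_l1 = 100.0  # weight on the L1 term (same value as Config.lambda_l1)

# Random stand-ins, used only to check shapes and arithmetic:
x = torch.rand(4, 1, 8, 8)        # grayscale condition
y = torch.rand(4, 2, 8, 8)        # real 2-channel color target
fake = torch.rand(4, 2, 8, 8)     # plays the role of G(x, z)

# Hypothetical discriminator: sees the condition and the colors together.
disc = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)

bce = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()

def generator_objective(x, y, fake):
    """Adversarial term plus lambda-weighted L1 term for the generator."""
    logits = disc(torch.cat([x, fake], dim=1))
    adv = bce(logits, torch.ones_like(logits))   # generator wants "real"
    rec = l1(fake, y)                            # mean absolute error
    return adv + lambda_l1 * rec, adv, rec

total, adv, rec = generator_objective(x, y, fake)
print(f"adversarial={adv.item():.4f}  L1={rec.item():.4f}  total={total.item():.4f}")
```

With λ = 100 the L1 term dominates the total, which is why the generator is pulled strongly toward the ground-truth colors while the adversarial term mainly sharpens the result.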

---

<h1 align = "center"> 📚 Theory </h1>

## 💡 Basics of GANs

<p>GANs involve two types of models: <b>a generative model and a discriminative model</b>.</p>

* The generative model tries to find the joint probability P(X,Y), or P(X) when there are no labels, where X is the set of data instances and Y is the set of labels. The joint probability can then be used to find P(Y|X) or P(X|Y).
* The discriminative model tries to find the conditional probability P(Y|X) directly.

<p>Discriminative models try to draw boundaries in the data space, while generative models try to model how data is placed throughout the space.</p>
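To make the relationship between the joint and conditional probabilities concrete, the following sketch recovers P(Y|X) from a joint table P(X,Y). The numbers are made up purely for illustration.

```python
import numpy as np

# Hypothetical joint distribution P(X, Y):
# rows index two values of X, columns index two labels Y.
joint = np.array([[0.30, 0.10],
                  [0.15, 0.45]])

# Marginal P(X): sum the joint over all labels.
p_x = joint.sum(axis=1, keepdims=True)

# Conditional P(Y | X) = P(X, Y) / P(X); each row now sums to 1.
p_y_given_x = joint / p_x

print(p_y_given_x)
```

A generative model that knows the full table can answer this question and others, such as P(X|Y), whereas a discriminative model only ever learns the conditional rows directly.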


## 👨‍🏭 Basic Working of GANs

<p>Both the generator and the discriminator are neural networks. The generator output is connected directly to the discriminator input. Through backpropagation, the discriminator's classification provides a signal that the generator uses to update its weights.

A generative adversarial network (GAN) has two parts:</p>

* The generator learns to generate plausible data. The generated instances become negative training examples for the discriminator.
* The discriminator learns to distinguish the generator's fake data from real data. The discriminator penalizes the generator for producing implausible results.

## ✔ The Discriminator

<p>The discriminator in a GAN is simply a classifier. It tries to distinguish real data from the data created by the generator. It could use any network architecture appropriate to the type of data it's classifying.</p> 

![](https://developers.google.com/static/machine-learning/gan/images/gan_diagram_discriminator.svg)

### Training process for the Discriminator

* The training data for the discriminator includes real data (positive examples) and fake data (negative examples produced by the generator).

* During discriminator training the generator does not train. Its weights remain constant while it produces examples for the discriminator to train on. The discriminator connects to two loss functions. During discriminator training, the discriminator ignores the generator loss and just uses the discriminator loss.

1. The discriminator classifies both real data and fake data from the generator.
2. The discriminator loss penalizes the discriminator for misclassifying a real instance as fake or a fake instance as real.
3. The discriminator updates its weights through backpropagation from the discriminator loss through the discriminator network.
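The three steps above can be sketched in PyTorch as follows. This is a toy setup under stated assumptions: `gen` and `disc` are hypothetical single-layer networks on 2-D data, and the optimizer settings are borrowed from the Config class later in this notebook.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy networks; real GANs use much larger models.
gen = nn.Linear(4, 2)     # noise (4-D) -> fake sample (2-D)
disc = nn.Linear(2, 1)    # sample -> real/fake logit
disc_opt = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCEWithLogitsLoss()

real = torch.randn(16, 2)     # stand-in for a batch of real data
noise = torch.randn(16, 4)

gen_before = [p.detach().clone() for p in gen.parameters()]
disc_before = [p.detach().clone() for p in disc.parameters()]

# 1. Classify real and fake data; detach() keeps the generator fixed.
fake = gen(noise).detach()
real_logits = disc(real)
fake_logits = disc(fake)

# 2. Penalize misclassification: real should score 1, fake should score 0.
disc_loss = (bce(real_logits, torch.ones_like(real_logits))
             + bce(fake_logits, torch.zeros_like(fake_logits)))

# 3. Backpropagate through the discriminator only and update its weights.
disc_opt.zero_grad()
disc_loss.backward()
disc_opt.step()
```

Detaching the generator output is what keeps the generator's weights constant here: no gradient reaches them, so only the discriminator changes.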

## ⚗ The Generator

<p>The generator part of a GAN learns to create fake data by incorporating feedback from the discriminator. It learns to make the discriminator classify its output as real.</p>

![](https://developers.google.com/static/machine-learning/gan/images/gan_diagram_generator.svg)

The generator training includes:
* Random input noise to the generator.
* The generator network that generates the new data instances.
* The discriminator network that classifies the generator output.
* The generator loss that penalizes the generator for failing to fool the discriminator.


The steps involved are:
1. Sample random noise.
2. Produce generator output from the sampled random noise.
3. Get the discriminator's classification of the generator output.
4. Calculate the loss from the discriminator's classification.
5. Backpropagate through both the discriminator and the generator to obtain the gradients, but update only the generator weights.
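The generator steps listed above might look like this in PyTorch. As before, this is only a sketch: `gen` and `disc` are hypothetical single-layer networks, and the loss uses "real" targets for the fake samples so that the generator is rewarded for fooling the discriminator.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy networks, for illustration only.
gen = nn.Linear(4, 2)     # noise -> fake sample
disc = nn.Linear(2, 1)    # sample -> real/fake logit
gen_opt = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCEWithLogitsLoss()

gen_before = [p.detach().clone() for p in gen.parameters()]
disc_before = [p.detach().clone() for p in disc.parameters()]

noise = torch.randn(16, 4)                        # 1. sample random noise
fake = gen(noise)                                 # 2. generator output
logits = disc(fake)                               # 3. discriminator classification
gen_loss = bce(logits, torch.ones_like(logits))   # 4. loss: we want "real"

gen_opt.zero_grad()
gen_loss.backward()   # 5. gradients flow through both networks...
gen_opt.step()        #    ...but only the generator's weights are updated
```

Because the optimizer only holds the generator's parameters, the discriminator receives gradients but is never stepped. In a full training loop those stale discriminator gradients would need to be cleared (or avoided by disabling `requires_grad`) before the next discriminator update.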

## GAN Training

<p>A GAN has to juggle the training of two networks, and its convergence is hard to identify.

The training proceeds in the following manner:</p>
1. The discriminator trains for one or more epochs
2. The generator trains for one or more epochs
3. Repeat steps 1 and 2

## 📉 Loss Functions

<p>GANs try to replicate a probability distribution. They should therefore use loss functions that reflect the distance between the distribution of the data generated by the GAN and the distribution of the real data.</p>

### Minimax Loss

In the paper that introduced GANs, the generator tries to minimize the following function while the discriminator tries to maximize it:

$\begin{aligned} \mathcal{L}_{G A N}(G, D)=& \mathbb{E}_{x}[\log D(x)]+\\ & \mathbb{E}_{z}[\log (1-D(G(z)))]\end{aligned}$

In this function:

* D(x) is the discriminator's estimate of the probability that real data instance x is real.
* $\mathbb{E}_{x}$ is the expected value over all real data instances.
* G(z) is the generator's output when given noise z.
* D(G(z)) is the discriminator's estimate of the probability that a fake instance is real.
* $\mathbb{E}_{z}$ is the expected value over all random inputs to the generator (in effect, the expected value over all generated fake instances G(z)).

The original GAN paper notes that the above minimax loss function can cause the GAN to get stuck in the early stages of GAN training when the discriminator's job is very easy. The paper therefore suggests modifying the generator loss so that the generator tries to maximize log D(G(z)). (**Modified Minimax Loss**)
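A small numerical check can show why the original generator term stalls. Assuming the discriminator outputs D(G(z)) = sigmoid(a) for some logit a, the sketch below compares the gradient of each generator loss with respect to a. The function names are made up for this example.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_grad(a):
    """d/da of log(1 - sigmoid(a)): the original generator term."""
    return -sigmoid(a)

def non_saturating_grad(a):
    """d/da of -log(sigmoid(a)): the modified generator term."""
    return -(1.0 - sigmoid(a))

# A very negative logit means the discriminator easily rejects the fake.
for a in (-6.0, 0.0, 6.0):
    print(f"logit={a:+.0f}  D(G(z))={sigmoid(a):.4f}  "
          f"saturating={saturating_grad(a):+.4f}  "
          f"non-saturating={non_saturating_grad(a):+.4f}")
```

At a logit of -6, where the discriminator is confident the sample is fake, the original term gives a gradient magnitude of roughly 0.0025 while the modified term gives roughly 0.9975, so the generator still receives a usable learning signal.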

Other loss functions exist as well, such as the **Wasserstein loss**.

## 🥼 LAB color space

LAB is a color space just like RGB, but here the three dimensions represent the lightness (L), green-redness (a) and yellow-blueness (b) of each pixel.

The main advantages of using this color space over RGB for the Image Colorization task are:
* Here, the L channel can straight away be used as the grayscale input of the image.
* The model only needs to output 2 channels.
* With RGB, you would first have to convert the image to grayscale, feed the grayscale image to the model, and have it predict 3 numbers per pixel, which is a much more difficult and unstable task because there are far more possible combinations of 3 numbers than of 2.
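The split into a one-channel input and a two-channel target can be sketched with scikit-image's `rgb2lab` and `lab2rgb`. The image here is random noise standing in for a real photo, so only the shapes and value ranges are meaningful.

```python
import numpy as np
from skimage.color import rgb2lab, lab2rgb

rng = np.random.default_rng(0)
rgb = rng.random((8, 8, 3))     # stand-in RGB image, floats in [0, 1]

lab = rgb2lab(rgb)              # shape (8, 8, 3): L, a, b
L = lab[..., :1]                # lightness: the grayscale model input
ab = lab[..., 1:]               # a and b: the 2 channels to predict

print("L range :", L.min(), L.max())
print("ab range:", ab.min(), ab.max())

# Recombining L with (predicted) ab channels gives back a color image.
recovered = lab2rgb(np.concatenate([L, ab], axis=-1))
```

In a colorization model, `ab` would be replaced by the generator's output before the final `lab2rgb` call, while `L` is passed through unchanged.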

---

In [1]:
# !pip install --upgrade torch torchvision

# 📃 Config Class

In [2]:
class Config:
    external_data_siae = 10000
    train_size = 8000
    image_size_1 = 256
    image_size_2 = 256
    batch_size = 32
    LeakyReLU_slope = 0.2
    dropout = 0.5
    kernel_size = 4
    stride = 2
    padding = 1
    gen_lr = 2e-4
    disc_lr = 2e-4
    beta1 = 0.5
    beta2 = 0.999
    lambda_l1 = 100
    gan_mode = 'vanilla'
    layers_to_cut = -2
    epochs = 20
    pretrain_lr = 1e-4

<h1>📦 Importing Packages</h1>

In [3]:
import os
from pathlib import Path
import glob
import time
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm

import PIL
from PIL import Image
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.models.resnet import resnet18
from torchvision.models.vgg import vgg19
from torch.utils.data import DataLoader, Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## FastAI imports + Loading in the Data

The data is a small sample of the COCO object-detection dataset, loaded through fastai's external data functionality. The original paper uses the entire ImageNet dataset, whereas we use only 8000 training images.

In [4]:
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from fastai.data.external import untar_data, URLs

path = untar_data(URLs.COCO_SAMPLE)
path

Path('/home/tim/.fastai/data/coco_sample')