**This notebook has been used for the cGAN training**


In [None]:
from google.colab import drive
drive.mount('/content/drive')

# 1 Data collection

Two datasets are used: a small version of the COCO dataset with 21,837 images, and an animals dataset with 17,178 images (12 categories).

## 1.1 Animals dataset

We download this dataset from Kaggle (1.4 GB).

In [1]:
!pip install -q kaggle
from google.colab import files

You have to upload a file called kaggle.json. To obtain it you need to follow the first 2 steps described in https://www.kaggle.com/general/74235

In [None]:
files.upload()

In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! kaggle datasets list

In [None]:
!kaggle datasets download -d piyushkumar18/animal-image-classification-dataset

The data have been downloaded. The next cell unzips them.

In [5]:
!mkdir /content/animal_data
!unzip -qq /content/animal-image-classification-dataset.zip -d /content/animal_data/

## 1.2 COCO dataset

To download it we use fastai

In [None]:
!pip install fastai==2.4;

In [7]:
from fastai.data.external import untar_data, URLs
import os
import glob
import numpy as np

In [None]:
# Download (or reuse the cached copy of) the COCO sample and point at its image folder
coco_path = f"{untar_data(URLs.COCO_SAMPLE)}/train_sample"

# All JPEG files of the sample, kept as a numpy array of path strings
paths = np.array(glob.glob(f"{coco_path}/*.jpg"))
num_images_coco = len(paths)
print(f"# coco images: {num_images_coco}")

Upload either data_small_training.txt or data_big_training.txt

In [None]:
files.upload();

In [10]:
filename = "data_small_training.txt"

def read_lines(path):
  """
  Read a text file and return its lines as a list of strings.

  input:
    - path: path of the text file to read

  output:
    - list with one entry per line, trailing whitespace (newline included) removed
  """

  stripped = []

  with open(path) as handle:
    for raw_line in handle:
      stripped.append(raw_line.rstrip())

  return stripped

In [11]:
training_paths = read_lines(filename)
print(f"{len(training_paths)} images for training")

9600 images for training


# 2 Datasets and Dataloaders

In [12]:
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

np.random.seed(123)

## 2.1 Training Dataset

In [13]:
# Side length (in pixels) every image is resized to before entering the network
SIZE = 256

# Training-time preprocessing applied to the PIL image:
# resize to SIZE x SIZE with bilinear interpolation, then a random horizontal
# flip as data augmentation
train_transform = transforms.Compose([
                transforms.Resize((SIZE, SIZE),  transforms.InterpolationMode.BILINEAR),
                transforms.RandomHorizontalFlip(),
            ])

In [14]:
class GrayToColorDataset(Dataset):
  """
  Dataset that loads an image from disk and returns it split into the
  L (lightness) and ab (color) channels of the Lab color space.

  - paths: sequence of image file paths
  - transform: optional callable applied to the PIL RGB image before the
    Lab conversion (e.g. resize / flip); if None the image is used as loaded
  """

  def __init__(self, paths, transform = None):

    self.paths = paths
    self.transform = transform

  def __len__(self):

    return len(self.paths)

  def __getitem__(self, idx):
    """
    Returns the tuple (L, ab) for the image at position idx:
      - L: 1 x H x W float32 tensor, L channel rescaled with L/50 - 1
      - ab: 2 x H x W float32 tensor, ab channels rescaled with ab/110
    """

    # the context manager closes the file as soon as the pixels are decoded
    with Image.open(self.paths[idx]) as img_file:
      img_rgb = img_file.convert("RGB")

    # transform is optional: the default (None) must not crash
    if self.transform is not None:
      img_rgb = self.transform(img_rgb)

    img_rgb = np.array(img_rgb)

    #RGB -> Lab
    img_lab = rgb2lab(img_rgb).astype("float32")
    # H x W x 3 array -> 3 x H x W tensor
    img_lab = transforms.ToTensor()(img_lab)

    #rescale so that the values are (roughly) in range [-1,1]
    L = img_lab[[0],:]/50. - 1.
    ab = img_lab[[1,2],:] / 110.

    return (L,ab)


In [15]:
train_dataset = GrayToColorDataset(training_paths, train_transform)

In [16]:

# DataLoader settings for training
PIN_MEMORY = True   # passed as pin_memory to the DataLoader
N_WORKERS = 2       # number of worker processes loading batches
BATCH_SIZE = 8      # images per training batch

# Training batches are reshuffled at every epoch (shuffle = True)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=N_WORKERS,
                            pin_memory=PIN_MEMORY, shuffle = True)

# 3 cGAN models

## 3.1 Generator: U-Net

In [17]:
class UNetDown(nn.Module):
  """
  Encoder block of the U-Net: a stride-2 convolution (halves the spatial size)
  optionally followed by normalization, an activation and dropout.

  - in_channels, out_channels: channels of the input / output feature maps
  - kernel_size: kernel size of the convolution
  - normalization_type: None (no normalization, the convolution has a bias),
    "instance" (InstanceNorm2d); any other value selects BatchNorm2d
  - dropout: dropout probability, 0.0 disables the dropout layer
  - activation: None -> LeakyReLU(0.2), "ReLU" -> ReLU;
    any other value adds no activation layer
  """

  def __init__(self, in_channels, out_channels, kernel_size = 4, normalization_type = None, dropout = 0.0, activation = None):

    super(UNetDown, self).__init__()

    #if batchnorm/instancenorm used, bias not used
    use_bias = normalization_type is None
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, 2, 1, bias = use_bias)]

    if not use_bias:

      if normalization_type == "instance":
        layers.append(nn.InstanceNorm2d(out_channels))
      else:
        layers.append(nn.BatchNorm2d(out_channels))

    if activation is None:
      layers.append(nn.LeakyReLU(negative_slope = 0.2))
    elif activation == "ReLU":
      layers.append(nn.ReLU())

    if dropout:
      layers.append(nn.Dropout(p = dropout))

    self.model = nn.Sequential(*layers)


  def forward(self, x):

    return self.model(x)


In [18]:
class UNetUp(nn.Module):
  """
  Decoder block of the U-Net: a stride-2 transposed convolution (doubles the
  spatial size) optionally followed by normalization, then ReLU and dropout.

  - in_channels, out_channels: channels of the input / output feature maps
  - kernel_size: kernel size of the transposed convolution
  - normalization_type: None (no normalization, the convolution has a bias),
    "instance" (InstanceNorm2d); any other value selects BatchNorm2d
  - dropout: dropout probability, 0.0 disables the dropout layer
  """

  def __init__(self, in_channels, out_channels, kernel_size = 4,  normalization_type = None, dropout = 0.0):

    super(UNetUp, self).__init__()

    #if batchnorm/instancenorm used, bias not used
    use_bias = normalization_type is None

    layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, 2, 1, bias = use_bias)]

    if not use_bias:

      if normalization_type == "instance":
        layers.append(nn.InstanceNorm2d(out_channels))
      else:
        layers.append(nn.BatchNorm2d(out_channels))

    layers.append(nn.ReLU())

    if dropout:
      layers.append(nn.Dropout(p = dropout))

    self.model = nn.Sequential(*layers)


  def forward(self, x, skip = None):
    """
    Upsamples x; if skip is given it is concatenated in front of the result
    along the channel dimension (skip channels first).
    """

    x = self.model(x)

    if skip is not None:
      x = torch.cat((skip, x), 1)

    return x

In [19]:
class GeneratorUNet(nn.Module):
  """
  U-Net generator: a stack of UNetDown blocks (encoder), a mirrored stack of
  UNetUp blocks (decoder) fed with skip connections from the encoder, and a
  final transposed convolution followed by Tanh.

  - in_channels: channels of the input (default 1)
  - out_channels: channels of the output (default 2); Tanh bounds it to [-1, 1]
  - num_down: number of encoder blocks; the feature list always holds at least
    5 entries, so values below 5 behave like 5
  - ngf: number of filters of the first encoder block
  - normalization_type: forwarded to the inner UNetDown / UNetUp blocks
    (the first and the last encoder block are built without normalization)
  """

  def __init__(self, in_channels = 1, out_channels = 2, num_down = 8, ngf = 64, normalization_type = None):

    super(GeneratorUNet, self).__init__()

    self.downs = nn.ModuleList()
    self.ups = nn.ModuleList()
    

    # output channels of each encoder block, starting from ngf and doubling 3 times
    features =[ngf]

    for i in range(3):

      features.append(features[i]*2)

    # the fifth block keeps the same width as the fourth
    features.append(features[-1])
    #with ngf = 64: 64, 128, 256, 512, 512

    if num_down > 5:

      # every extra encoder block has ngf * 8 channels
      features += [ngf * 8 for i in range(num_down - 5)]
    #for num_down = 8: 64, 128, 256, 512, 512, 512, 512, 512 (->1x1 for input size 256x256)


    #ENCODER (CONTRACTING PATH)

    #outermost down block: no normalization and no dropout, only downconv (+ default LeakyReLU)
    self.downs.append(UNetDown(in_channels, ngf, 4))

    in_channels = ngf #new in_channels for the next down-block
    
    # intermediate down blocks: all the entries of features except the first and the last
    for i,n_features in enumerate(features[1:len(features)-1]):
      #no dropout
      self.downs.append(UNetDown(in_channels, n_features, 4, normalization_type, 0.0))
      in_channels = n_features

    
    #innermost down block: no normalization and no dropout, downconv followed by ReLU
    self.downs.append(UNetDown(in_channels, features[-1], 4, activation = "ReLU"))
    

    #DECODER (EXPANSIVE PATH)
    # one up block per encoder block except the innermost, walking features backwards
    i_channels = in_channels
    for i, n_features in enumerate((features[-2::-1])):
      
      
      #if i == 0, innermost(bottleneck), namely a block such that after down we go up:
      #its input is the bottleneck output. Every other up block receives the previous
      #up output concatenated with a skip connection, hence twice the channels
      i_channels = in_channels if i == 0  else i_channels * 2

      #dropout 0.5 only for the up blocks with i = 1, 2, 3
      dropout = 0.0 if (i == 0 or i  > 3) else 0.5

      self.ups.append(UNetUp(i_channels, n_features, 4, normalization_type, dropout))
      i_channels = n_features
    
    
    # the last up block outputs ngf channels and is concatenated with the output
    # of the first down block (ngf channels): ngf*2 input channels
    self.final = nn.Sequential(
        nn.ConvTranspose2d(ngf*2,out_channels, kernel_size=4, stride=2, padding=1),
        nn.Tanh()
    )



  def forward(self, x):
    """
    Runs the encoder storing the output of every down block, then the decoder,
    where the i-th up block is concatenated with the matching encoder output.
    """

    skip_connections = list()

    #encoder
    for down in self.downs:

      x = down(x)
      skip_connections.append(x)

    #decoder with skip connections
    for i, up in enumerate(self.ups):
      
      # -i-2: skip the bottleneck output (last element) and walk backwards
      x = up(x, skip_connections[-i-2])

    return self.final(x)

## 3.2 Discriminator: PatchGAN

The discriminator is a PatchGAN for $N \times N$ patches where $N=70$: given a $256 \times 256$ input, the output is $30 \times 30$.

In [20]:
class PatchDiscriminator(nn.Module):
  """
  PatchGAN discriminator: a stack of 4x4 convolutions that outputs one logit
  per image patch (no final sigmoid).

  - in_channels: channels of the input image
  - ndf: number of filters of the first convolution (doubled at each block)
  - n_down: n_down - 1 conv blocks are built, followed by a final 1-channel conv
  - normalization_type: "batchnorm" or "instance"; applied to every block
    except the first one
  """

  def __init__(self, in_channels = 3, ndf = 64, n_down = 5, normalization_type = "batchnorm"):

    super().__init__()

    widths = [ndf * 2**k for k in range(n_down-1)]
    last_block = len(widths) - 1

    blocks = []
    prev_channels = in_channels

    for k, width in enumerate(widths):

      # only the first block keeps the bias (it is the only one without normalization)
      is_first = k < 1
      # every block halves the resolution except the last one
      step = 2 if k < last_block else 1

      blocks.append(nn.Conv2d(prev_channels, width, 4, step, 1, bias = is_first))

      if not is_first:
        if normalization_type == "batchnorm":
          blocks.append(nn.BatchNorm2d(width))

        if normalization_type == "instance":
          blocks.append(nn.InstanceNorm2d(width))

      blocks.append(nn.LeakyReLU(0.2))

      prev_channels = width

    # final convolution: one output channel, i.e. one logit per patch
    blocks.append(nn.Conv2d(prev_channels, 1, 4, 1, 1))

    self.model = nn.Sequential(*blocks)

  def forward(self, x):

    return self.model(x)

# 4 GAN LOSS

The following class implements the GAN loss. The discriminator maximizes
\begin{equation}
\mathbb{E}_{x,y}[\log D(x,y)]+\mathbb{E}_{x,z}[\log(1-D(x,G(z, x)))]
\end{equation}

The generator instead maximizes

\begin{equation}
\mathbb{E}_{x,z}[\log D(x,G(z,x))]
\end{equation}

In [21]:
class GANLoss():
  """
  GAN loss on raw discriminator logits (binary cross entropy with logits).

  - device: kept for backward compatibility; the labels are created directly
    on the device (and with the dtype) of the logits passed to __call__
  """

  def __init__(self, device):

    self.criteria = nn.BCEWithLogitsLoss()
    self.real = 1.
    self.fake = 0.
    self.device = device

  def __call__(self, input, label_type):
    """
    - input: tensor with the discriminator logits
    - label_type: truthy -> compare against the "real" label (1),
      falsy -> compare against the "fake" label (0)

    Returns the scalar loss tensor.
    """

    target_value = self.real if label_type else self.fake

    # labels built where the logits live: no CPU allocation + device copy per call
    labels = torch.full_like(input, target_value)

    return self.criteria(input, labels)

# 5 Models initialization

Initialization generator and discriminator

In [None]:
#G: Generator (U-Net, 1 input channel, 2 output channels, 8 down blocks, 64 filters, batchnorm)
#D: Discriminator (PatchGAN with its default arguments)

# use the GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = GeneratorUNet(1,2,8,64, "batchnorm").to(device)
D = PatchDiscriminator().to(device)

def weights_init(m):
    """
    Initializer meant to be used with model.apply(weights_init).

    - layers whose class name contains 'Conv': Xavier-uniform weights and,
      when a bias exists, bias set to 0
    - layers whose class name contains 'BatchNorm': weights ~ N(1, 0.02), bias set to 0
    - any other layer is left untouched
    """

    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.xavier_uniform_(m.weight.data)

        # convolutions followed by a normalization layer are built without bias
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias.data, 0.0)

    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

G.apply(weights_init);
D.apply(weights_init);

# 6 Training setup

Here we define the losses (GAN loss and L1), the optimizer, the number of epochs and the hyperparameters

## 6.1 Losses

In [23]:
# adversarial loss and L1 reconstruction loss; both are read as globals by train_step
GAN_loss = GANLoss(device) 
L1_loss = nn.L1Loss()

#L1 hyperparam: weight of the L1 term in the generator loss (G_loss = lamb * L1 + GAN)

lamb = 100

## 6.2 Optimizers

In [24]:
#params for Adam: same learning rate and betas for generator and discriminator
lr_G = 2e-4 
lr_D = 2e-4

betas = (0.5, 0.999)

# one optimizer per network, each updating only its own parameters
G_opt = optim.Adam(G.parameters(), lr=lr_G, betas=betas)
D_opt = optim.Adam(D.parameters(), lr=lr_D, betas=betas)

# 7 Utility functions

In [25]:
def convert_lab_to_rgb(L, ab):

  """
  Converts one Lab image, or a batch of Lab images, to RGB.
  The inputs are expected in the normalized form produced by the dataset
  (L/50 - 1 and ab/110); the rescaling is undone here.

  input:
    - L: torch.tensor, 1 x H x W (single image) or N x 1 x H x W (batch)
    - ab: torch.tensor, 2 x H x W (single image) or N x 2 x H x W (batch)

  output:
    - numpy.ndarray with the rgb image (H x W x 3) or images (N x H x W x 3)
  """

  # undo the normalization
  lightness = (L+1.)*50.
  color = ab*110.

  # single image: at most 3 dimensions (C x H x W)
  if len(ab.shape) <= 3:
    lab_image = torch.cat([lightness, color], dim=0).permute(1, 2, 0)
    return lab2rgb(lab_image.cpu().detach().numpy())

  # batch: channels moved last, then every image is converted on its own
  lab_batch = torch.cat([lightness, color], dim=1).permute(0, 2, 3, 1)
  lab_batch = lab_batch.cpu().detach().numpy()

  return np.stack([lab2rgb(lab_image) for lab_image in lab_batch], axis=0)

In [26]:
def show_results(Ls, real_abs, fake_abs):

  """
  Plots a 3-row grid with one column per image of the batch:
  grayscale input (L channel), real color image, generated color image.

  input:
    - Ls: batch with L for each image, N x 1 x 256 x 256 tensor
    - real_abs: batch with ab for each real image, N x 2 x 256 x 256 tensor
    - fake_abs: batch with ab for each fake image, N x 2 x 256 x 256 tensor
  """

  n_cols = Ls.shape[0]

  real_images = convert_lab_to_rgb(Ls, real_abs)
  fake_images = convert_lab_to_rgb(Ls, fake_abs)

  plt.figure(figsize=(15, 15))

  for col in range(n_cols):

    # (image, imshow options) for the three rows of this column, top to bottom
    panels = (
        (Ls[col][0].cpu(), {"cmap": 'gray'}),
        (real_images[col], {}),
        (fake_images[col], {}),
    )

    for row, (image, options) in enumerate(panels):

      ax = plt.subplot(3, n_cols, col + 1 + row*n_cols)
      ax.imshow(image, **options)
      ax.axis("off")

  plt.show()

The following class tracks the losses over an epoch: it accumulates each loss (the generator GAN loss, the discriminator loss, etc.) over the dataset and then computes the mean of each one. In this way we can compute the mean losses at the end of each epoch and also at intermediate steps.

In [27]:
class LossTracker :
  """
  Accumulates the per-batch losses weighted by the batch size and computes
  their running means.
  """

  def __init__(self) :

    self.count = 0 #n of images seen up to now
    self.avg = {} #dict with avg of losses

    # cumulative (batch-size weighted) sum of each loss
    self.dict_losses = {
        "G_loss" : 0,
        "G_GAN_loss" : 0,
        "G_L1_loss" : 0,
        "D_loss" : 0,
        "D_loss_real" : 0,
        "D_loss_fake" : 0
    }

  def set_to_zero(self) :

    """
    Resets the cumulative sums and the image counter (e.g. at the start of
    an epoch). The last computed averages in self.avg are left untouched.
    """

    for key in self.dict_losses :
      self.dict_losses[key] = 0
    self.count = 0


  def update_losses(self, losses, batch_size) :
    """
    It updates the cumulative sum for the losses

    - losses: dict, the losses obtained in a single batch of dimension batch_size
      (its keys must be keys of self.dict_losses)
    - batch_size: int
    """
    for key in losses :
      self.dict_losses[key] += losses[key] * batch_size

    self.count += batch_size


  def avg_losses(self) :

    """
     for each loss metric compute the mean w.r.t. the
     losses accumulated up to now.

     If no image has been seen yet every mean is NaN (instead of raising
     ZeroDivisionError). Returns self.avg.
    """

    if self.count == 0 :
      for key in self.dict_losses :
        self.avg[key] = float("nan")
      return self.avg

    for key in self.dict_losses :
      self.avg[key] = self.dict_losses[key] / self.count

    return self.avg


  def print_losses(self) :

    """
    print of the mean loss for each loss metric
    (the values computed by the last call to avg_losses)
    """

    for key in self.avg :
      temp = self.avg[key]
      print(f"{key} : {temp}")

The following functions are used to save and load checkpoints during training

In [28]:
def save_checkpoint(epoch, model_G, model_D, opt_G, opt_D, dict_losses, path = "cGAN_training.pt"):
  """
  Provided the current epoch, G, D, optG, optD, and the losses, this function saves a checkpoint

  - epoch: int, epoch stored in the checkpoint
  - model_G, model_D: generator and discriminator (their state_dict is saved)
  - opt_G, opt_D: optimizers of G and D (their state_dict is saved)
  - dict_losses: dict with (at least) the keys 'G_GAN_loss', 'D_loss', 'G_loss'
  - path: file the checkpoint is written to; the default is relative to the
    current working directory and is overwritten at every call
  """

  dict_save = {
                'epoch': epoch,
                'G_state_dict': model_G.state_dict(),
                'G_opt_state_dict': opt_G.state_dict(),
                'D_state_dict' : model_D.state_dict(),
                'D_opt_state_dict' : opt_D.state_dict(),
                'G_GAN_loss' : dict_losses['G_GAN_loss'],
                'D_loss' : dict_losses['D_loss'],
                'G_loss' : dict_losses['G_loss']
              }

  torch.save(dict_save, path)

In [29]:
def load_checkpoint(G, D, opt_G, opt_D, path = "/content/cGAN_training.pt", map_location = None):

  """
  Restores models and optimizers (in place) from a checkpoint written by save_checkpoint.

  - G: Generator model, GeneratorUNet
  - D: Discriminator model, PatchDiscriminator
  - opt_G: optimizer for G
  - opt_D: optimizer for D

  - path: path to the file from which load the checkpoint
  - map_location: forwarded to torch.load; pass e.g. "cpu" or the current device
    to load a checkpoint saved on a GPU when no GPU is available

  Returns the tuple (epoch, losses) where epoch is the epoch stored in the
  checkpoint and losses is a dict with 'G_GAN_loss', 'D_loss' and 'G_loss'.
  """

  # NOTE: torch.load unpickles the file: only load checkpoints from a trusted source
  checkpoint = torch.load(path, map_location = map_location)
  epoch = checkpoint['epoch']

  G.load_state_dict(checkpoint['G_state_dict'])
  D.load_state_dict(checkpoint['D_state_dict'])

  opt_D.load_state_dict(checkpoint['D_opt_state_dict'])
  opt_G.load_state_dict(checkpoint['G_opt_state_dict'])

  print(f"Checkpoint at epoch {epoch} loaded")

  return epoch, {'G_GAN_loss' : checkpoint['G_GAN_loss'], 'D_loss' : checkpoint['D_loss'], 'G_loss' : checkpoint['G_loss']}

Functions to save and load only the generator

In [30]:
def save_generator(G, path = "/content/cGAN-gen.pt"):
  """
  Saves only the weights (state_dict) of the generator G to path.
  """

  generator_weights = G.state_dict()
  torch.save(generator_weights, path)

def load_generator(G, path = "/content/cGAN-gen.pt", map_location = None):
  """
  Loads (in place) into G the generator weights stored at path.

  - map_location: forwarded to torch.load; pass e.g. "cpu" or the current device
    to load weights saved on a GPU when no GPU is available
  """

  # NOTE: torch.load unpickles the file: only load weights from a trusted source
  G.load_state_dict(torch.load(path, map_location = map_location))



Create and update a .csv file with all the losses during the training

In [31]:
import pandas as pd
from csv import writer

loss_names = ['G_loss', 'G_GAN_loss', 'G_L1_loss', 'D_loss_real', 'D_loss_fake', 'D_loss']

def create_csv(columns = loss_names, path = "/content/losses.csv"):
  """
  Creates (overwriting any existing file) the csv where the losses are logged.
  Only the header is written: an empty DataFrame with the given columns is
  saved with pandas' default options, so the index column comes first.
  """

  pd.DataFrame(columns=columns).to_csv(path)

def update_csv(epoch, losses, columns = loss_names, path = "/content/losses.csv"):
  """
  Appends one row to the losses csv: the epoch followed by the loss of each
  column, in the order given by columns.

  - epoch: value written in the first field of the row
  - losses: dict containing (at least) one entry per name in columns
  - columns: names of the losses to write
  - path: csv file the row is appended to
  """

  row_to_add = [epoch] + [losses[column] for column in columns]

  # newline='' is what the csv module requires for files handed to csv.writer;
  # the with-block closes the file, no explicit close() is needed
  with open(path, 'a', newline='') as f_object:

    writer(f_object).writerow(row_to_add)



# 8 Training

Run the cell in 8.3 only if the training has been interrupted and you want to complete the 100 epochs (more details later)

## 8.1 Functions for training

Define a training step. Pass the b&w image (L), the ab channels of the true image and the other parameters needed. This function updates both the parameters of the generator $G$ and the parameters of the discriminator $D$.

In [32]:
def train_step(L, ab_real, g, d, device, dis_opt, gen_opt):
    """
    One optimization step on a batch: first the discriminator is updated,
    then the generator (using the already updated discriminator).

    - L: batch with the L channel (input of the generator)
    - ab_real: batch with the true ab channels
    - g, d: generator and discriminator
    - device: not used inside this function
    - dis_opt, gen_opt: optimizers of d and g

    It relies on the globals GAN_loss, L1_loss and lamb.

    Returns (losses, ab_fake): a dict with the value of each loss of the step
    and the ab channels produced by the generator.
    """

    g.train()
    d.train()

    #forward
    ab_fake = g(L)

    # the discriminator is conditioned on L: it always sees L together with ab
    real_pair = torch.cat([L, ab_real], dim = 1)
    fake_pair = torch.cat([L, ab_fake], dim = 1)

    #UPDATE DISCRIMINATOR
    dis_opt.zero_grad()

    # detach: the discriminator step must not backpropagate into the generator
    logits_real = d(real_pair)
    logits_fake = d(fake_pair.detach())

    D_loss_real = GAN_loss(logits_real, True)
    D_loss_fake = GAN_loss(logits_fake, False)

    # mean of the two terms, i.e. the discriminator objective is halved
    D_loss = (D_loss_real + D_loss_fake) * 0.5

    D_loss.backward()
    dis_opt.step()

    #UPDATE GENERATOR
    gen_opt.zero_grad()

    # the generator wants its output to be classified as real
    G_GAN_loss = GAN_loss(d(fake_pair), True)
    G_L1_loss = L1_loss(ab_fake, ab_real) * lamb

    G_loss = G_L1_loss + G_GAN_loss

    G_loss.backward()
    gen_opt.step()

    step_losses = {
        "G_loss" : G_loss.item(),
        "G_GAN_loss" : G_GAN_loss.item(),
        "G_L1_loss" : G_L1_loss.item(),
        "D_loss" : D_loss.item(),
        "D_loss_real" : D_loss_real.item(),
        "D_loss_fake" : D_loss_fake.item(),
    }

    return step_losses, ab_fake

The following function implements the training over the specified number of epochs

In [33]:
import time as time

def train(dataloader, epochs, g, d, device, dis_opt, gen_opt, print_every, last_epoch_done = None):
  """
  Trains g and d for the given number of epochs.

  - dataloader: yields batches (L, ab)
  - epochs: number of epochs to run in this call
  - g, d: generator and discriminator
  - device: device the batches are moved to
  - dis_opt, gen_opt: optimizers of d and g
  - print_every: the running mean losses are printed every print_every batches
  - last_epoch_done: when resuming, the last epoch already completed
    (epoch numbers then continue from last_epoch_done + 1)

  At the end of every epoch the mean losses are appended to the csv and a full
  checkpoint is saved; after the last epoch only the generator is saved.
  """

  losstracker = LossTracker()

  # offset added to the local epoch index to get the global epoch number
  add = 0 if last_epoch_done is None else 1+last_epoch_done

  for epoch in range(epochs):

    losstracker.set_to_zero()

    for count, batch in enumerate(tqdm(dataloader), start = 1):

        L = batch[0].to(device)
        ab = batch[1].to(device)

        dict_losses, _ = train_step(L, ab, g, d, device, dis_opt, gen_opt)

        #track losses at each epoch (weighted by the actual batch size)
        losstracker.update_losses(dict_losses, L.shape[0])

        if (count % print_every == 0):
            #PRINT avg losses over the batches seen so far in this epoch
            losstracker.avg_losses()
            losstracker.print_losses()
            print("\n")


    #SAVING MODELS + OPTs
    print("Saving model checkpoint")

    #compute avg losses over the epoch
    losses = losstracker.avg_losses()

    #save: only the generator after the last epoch, a full checkpoint otherwise
    if epoch == (epochs -1):
      save_generator(g)
    else:
      save_checkpoint(epoch+add, g, d, gen_opt, dis_opt, losses)


    #save losses in csv

    update_csv(epoch+add, losses)

    print(f'Epoch {epoch+add} finished')

## 8.2 Start training

In [None]:
#training phase (from scratch)

#create csv to save losses (an existing losses.csv is overwritten)
create_csv()

EPOCHS = 100    
# running mean losses are printed every print_every batches
print_every = 600
train( train_dataloader, EPOCHS, G, D, device, D_opt, G_opt, print_every)

## 8.3 Resume training

If the training is interrupted, you can resume it with this cell. First upload to Colab the last .pt file saved during training ('cGAN_training.pt'). If you changed either the file name or its location ('/content/'), pass the new path to 'load_checkpoint' as an extra argument (see the definition of the function).

In [None]:
#define the model from which to start
#(no weights_init here: the weights are overwritten by the checkpoint below)
EPOCHS = 100
G = GeneratorUNet(1,2,8,64, "batchnorm").to(device)
D = PatchDiscriminator().to(device) 
 
G_opt = optim.Adam(G.parameters(), lr=lr_G, betas=betas)   
D_opt = optim.Adam(D.parameters(), lr=lr_D, betas=betas)

#if the training is stopped load the model and restart from the last point until #EPOCHS are reached
#epoch: last epoch stored in the checkpoint
epoch, check_point_losses = load_checkpoint(G, D, G_opt, D_opt)       

print(check_point_losses)              

#epochs are numbered from 0, so epoch + 1 epochs have already been completed
epochs_left = EPOCHS - epoch  - 1   
print_every = 600    
  
#passing epoch as last_epoch_done makes the epoch numbering continue from epoch + 1
train( train_dataloader, epochs_left, G, D, device, D_opt, G_opt, print_every, epoch)

## 8.4 Load and test Generator (to test properly the generator look at the other notebooks)

The following cells load the generator and show the results on some new images. First of all, load a file with the paths of images from the animals dataset that were not used during training. **Upload "val_animals.txt"** (the file name read by the next cell).

In [None]:
files.upload();  

In [None]:
# read_lines is already defined in section 1.2 and is reused here
# (it was previously redefined, identically, in this cell)
filename = "val_animals.txt"

animals_paths = read_lines(filename)

print(f"# images of animals for testing: {len(animals_paths)}")

# images of animals for testing: 2400


Add the other images of COCO dataset

In [None]:
# copy: animals_paths must not be modified, otherwise re-running this cell
# would keep appending the COCO paths to it
test_paths = list(animals_paths)

# set membership is O(1): avoids scanning the whole training list for every COCO path
training_set = set(training_paths)

# add the COCO images that have not been used for training
test_paths.extend(path for path in paths if path not in training_set)

print(f"# images for testing: {len(test_paths)}")

# images for testing: 9437


In [None]:
# fixed seeds so that the shuffled order of the test images is reproducible
np.random.seed(33)
import random

random.seed(33)

# in-place shuffle of the test paths (mixes the animal and COCO images)
test_paths = np.array(test_paths)
np.random.shuffle(test_paths)

Create dataset and dataloader

In [None]:
SIZE = 256
# number of images shown per batch (one column per image in show_results)
batch_size = 3

# same resize as in training, without the random flip
test_transform = transforms.Compose([
                transforms.Resize((SIZE, SIZE),  transforms.InterpolationMode.BILINEAR),
            ])

test_dataset = GrayToColorDataset(test_paths, test_transform)

PIN_MEMORY = True
N_WORKERS = 2
# BATCH_SIZE is deliberately not reassigned here: the loader below uses
# batch_size, and overwriting BATCH_SIZE would silently change the training
# batch size if the training cells were re-run afterwards

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=N_WORKERS,
                            pin_memory=PIN_MEMORY, shuffle = False)

Load the generator

In [None]:
# same architecture used for training, so that the saved weights fit
G = GeneratorUNet(1,2,8,64, "batchnorm").to(device);

# weights read from the default path of load_generator
load_generator(G);

# evaluation mode: dropout disabled, batchnorm uses its running statistics
G.eval();

Loop that shows the results

In [None]:
def test_generator(G, device, dataloader, n_batcher_to_show = 4):
  """
  Shows the generator results on the first batches of dataloader.

  - G: generator (its train/eval mode is left as set by the caller)
  - device: device the batches are moved to
  - dataloader: yields batches (L, ab)
  - n_batcher_to_show: number of batches to visualize
  """

  # inference only: no autograd graph is needed, which saves memory and time
  with torch.no_grad():

    for i,batch in enumerate(dataloader):
      print(f"Results for batch {i+1}")
      L = batch[0].to(device)
      ab = batch[1].to(device)

      ab_fake = G(L)

      show_results(L, ab, ab_fake)

      print("\n\n")

      if i == n_batcher_to_show -1:
        break

test_generator(G, device, test_dataloader, 2)