**This notebook has been used for WGAN training**



# 1 Data collection

Two datasets are used: a small version of the COCO dataset with 21,837 images, and a dataset of 17,178 animal images (12 categories).

## 1.1 Animals dataset

We download this dataset from kaggle (1.4 GB)

In [None]:
!pip install -q kaggle
from google.colab import files

You have to upload a file called kaggle.json. To obtain it you need to follow the first 2 steps described in https://www.kaggle.com/general/74235

In [None]:
files.upload();

In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! kaggle datasets list

In [None]:
!kaggle datasets download -d piyushkumar18/animal-image-classification-dataset

The data have been downloaded. The next cell unzips them.

In [None]:
!mkdir /content/animal_data
!unzip -qq /content/animal-image-classification-dataset.zip -d /content/animal_data/

## 1.2 COCO dataset

To download it we use fastai

In [None]:
!pip install fastai==2.4;

In [None]:
from fastai.data.external import untar_data, URLs
import os
import glob
import numpy as np

In [None]:
# Fetch the fastai COCO sample and point at its "train_sample" image folder.
coco_path = untar_data(URLs.COCO_SAMPLE)
coco_path = str(coco_path) + "/train_sample"

# Collect every .jpg path in that folder as a NumPy array of strings.
paths = glob.glob(coco_path+"/*.jpg")
paths =np.array(paths)
num_images_coco = len(paths)
print(f"# coco images: {num_images_coco}")

Upload either data_small_training.txt or data_big_training.txt (we have used only the small dataset).

In [None]:
files.upload();

In [None]:
filename = "data_small_training.txt"

def read_lines(path):
  """Return the lines of the text file at `path`, each stripped of trailing whitespace."""

  stripped = []

  with open(path) as handle:
    for raw_line in handle:
      stripped.append(raw_line.rstrip())

  return stripped

In [None]:
training_paths = read_lines(filename)
print(f"{len(training_paths)} images for training")

9600 images for training


# 2 Datasets and Dataloaders

In [None]:
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

np.random.seed(123)

## 2.1 Training Dataset

In [None]:
# Side length (pixels) every image is resized to.
SIZE = 256

# Training-time augmentation applied to the PIL image before the Lab conversion:
# bilinear resize to SIZE x SIZE, then a random horizontal flip.
train_transform = transforms.Compose([
                transforms.Resize((SIZE, SIZE),  transforms.InterpolationMode.BILINEAR),
                transforms.RandomHorizontalFlip(),
            ])

In [None]:
class GrayToColorDataset(Dataset):
  """Dataset yielding (L, ab) tensor pairs for colorization.

  Each image is opened as RGB, optionally transformed, converted to Lab and
  split into the lightness channel L and the two colour channels ab.

  - paths: sequence of image file paths
  - transform: optional callable applied to the PIL image (e.g. a
    torchvision transform); None means the image is used as loaded
  """

  def __init__(self, paths, transform = None):

    self.paths = paths
    self.transform = transform

  def __len__(self):

    return len(self.paths)

  def __getitem__(self, idx):
    """Return (L, ab): L has 1 channel, ab has 2 channels."""

    img_rgb = Image.open(self.paths[idx]).convert("RGB")

    # transform defaults to None, so it must not be called unconditionally
    if self.transform is not None:
      img_rgb = self.transform(img_rgb)

    img_rgb = np.array(img_rgb)

    #RGB -> Lab
    img_lab = rgb2lab(img_rgb).astype("float32")
    # H x W x 3 array -> 3 x H x W tensor
    img_lab = transforms.ToTensor()(img_lab)

    #to have values in range [-1,1]
    # (assumes L in [0, 100] and ab roughly within [-110, 110], the usual Lab ranges)
    L = img_lab[[0],:]/50. - 1.
    ab = img_lab[[1,2],:] / 110.

    return (L,ab)


In [None]:
train_dataset = GrayToColorDataset(training_paths, train_transform)

In [None]:

# DataLoader settings for training.
PIN_MEMORY = True
N_WORKERS = 2
BATCH_SIZE = 32


# Shuffled loader over the training set; each batch is an (L, ab) pair of tensors.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=N_WORKERS,
                            pin_memory=PIN_MEMORY, shuffle = True)

# 3 cGAN models

## 3.1 Generator: U-Net

In [None]:
class UNetDown(nn.Module):
  """Encoder block: stride-2 conv -> optional norm -> optional activation -> optional dropout.

  - normalization_type: None (no norm layer, conv keeps its bias), "instance"
    (InstanceNorm2d) or any other value (BatchNorm2d)
  - activation: None -> LeakyReLU(0.2), "ReLU" -> ReLU, anything else -> no activation layer
  - dropout: a Dropout layer is appended only when this is non-zero
  """

  def __init__(self, in_channels, out_channels, kernel_size = 4, normalization_type = None, dropout = 0.0, activation = None):

    super().__init__()

    # the conv bias is dropped whenever a normalization layer follows it
    has_norm = normalization_type is not None

    blocks = [nn.Conv2d(in_channels, out_channels, kernel_size, 2, 1, bias = not has_norm)]

    if has_norm:
      norm_cls = nn.InstanceNorm2d if normalization_type == "instance" else nn.BatchNorm2d
      blocks.append(norm_cls(out_channels))

    if activation is None:
      blocks.append(nn.LeakyReLU(negative_slope = 0.2))
    elif activation == "ReLU":
      blocks.append(nn.ReLU())

    if dropout:
      blocks.append(nn.Dropout(p = dropout))

    self.model = nn.Sequential(*blocks)


  def forward(self, x):
    """Apply the block to x."""

    return self.model(x)


In [None]:
class UNetUp(nn.Module):
  """Decoder block: stride-2 transposed conv -> optional norm -> ReLU -> optional dropout.

  - normalization_type: None (no norm layer, conv keeps its bias), "instance"
    (InstanceNorm2d) or any other value (BatchNorm2d)
  - dropout: a Dropout layer is appended only when this is non-zero
  """

  def __init__(self, in_channels, out_channels, kernel_size = 4,  normalization_type = None, dropout = 0.0):

    super().__init__()

    # the conv bias is dropped whenever a normalization layer follows it
    has_norm = normalization_type is not None

    blocks = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, 2, 1, bias = not has_norm)]

    if has_norm:
      norm_cls = nn.InstanceNorm2d if normalization_type == "instance" else nn.BatchNorm2d
      blocks.append(norm_cls(out_channels))

    blocks.append(nn.ReLU())

    if dropout:
      blocks.append(nn.Dropout(p = dropout))

    self.model = nn.Sequential(*blocks)


  def forward(self, x, skip = None):
      """Upsample x; when `skip` is given, concatenate it in front of the result along dim 1."""
      upsampled = self.model(x)

      if skip is None:
        return upsampled

      return torch.cat((skip, upsampled), 1)

In [None]:
class GeneratorUNet(nn.Module):
  """U-Net generator built from UNetDown / UNetUp blocks with skip connections.

  - in_channels / out_channels: channels of the input and of the generated output
  - num_down: number of down-sampling blocks; values above 5 add extra
    (ngf * 8)-channel blocks. NOTE(review): values below 5 still build 5 down
    blocks, because the base feature list always has 5 entries.
  - ngf: number of filters of the first down block
  - normalization_type: forwarded to the inner UNetDown / UNetUp blocks

  The output goes through a final transposed conv followed by Tanh.
  """

  def __init__(self, in_channels = 1, out_channels = 2, num_down = 8, ngf = 64, normalization_type = None):

    super(GeneratorUNet, self).__init__()

    self.downs = nn.ModuleList()
    self.ups = nn.ModuleList()
    

    # output channels of each down block: ngf, 2*ngf, 4*ngf, 8*ngf, 8*ngf
    features =[ngf]

    for i in range(3):

      features.append(features[i]*2)

    features.append(features[-1])
    #64, 128, 256, 512, 512

    if num_down > 5:

      features += [ngf * 8 for i in range(num_down - 5)]
    #for num_down = 8: 64, 128, 256, 512, 512, 512, 512, 512 (->1x1 for input size 256x256)


    #ENCODER (CONTRACTING PATH)

    #outermost down block: no normalization and no dropout (UNetDown defaults)
    self.downs.append(UNetDown(in_channels, ngf, 4))

    in_channels = ngf #new in_channels for the next down-block
    
    for i,n_features in enumerate(features[1:len(features)-1]):
      #intermediate down blocks: normalization as requested, no dropout
      self.downs.append(UNetDown(in_channels, n_features, 4, normalization_type, 0.0))
      in_channels = n_features

    
    #innermost down block: no normalization and no dropout, ReLU activation
    self.downs.append(UNetDown(in_channels, features[-1], 4, activation = "ReLU"))
    

    #DECODER (EXPANSIVE PATH)
    # one up block per entry of `features` except the last, walked from deepest to shallowest
    i_channels = in_channels
    for i, n_features in enumerate((features[-2::-1])):
      
      
      #i == 0 is the first up block after the bottleneck: its input is the bottleneck output
      #(in_channels equals features[-1] here because the last two entries of `features` are equal).
      #Every later up block receives the previous up output concatenated with a skip of the
      #same width, hence the factor 2.
      i_channels = in_channels if i == 0  else i_channels * 2

      #dropout 0.5 only for the up blocks with i in {1, 2, 3}; none for i == 0 or i > 3
      dropout = 0.0 if (i == 0 or i  > 3) else 0.5

      self.ups.append(UNetUp(i_channels, n_features, 4, normalization_type, dropout))
      i_channels = n_features
    
    
    # the last up block outputs ngf channels concatenated with the ngf-channel first skip -> ngf*2
    self.final = nn.Sequential(
        nn.ConvTranspose2d(ngf*2,out_channels, kernel_size=4, stride=2, padding=1),
        nn.Tanh()
    )



  def forward(self, x):
    """Run the encoder, then the decoder with skip connections, then the final layer."""

    skip_connections = list()

    #encoder: keep every down-block output for the skip connections
    for down in self.downs:

      x = down(x)
      skip_connections.append(x)

    #decoder with skip connections
    for i, up in enumerate(self.ups):
      
      # skip_connections[-1] is the bottleneck output itself, so up block i
      # is paired with the encoder output one level shallower each time
      x = up(x, skip_connections[-i-2])

    return self.final(x)

## 3.2 Discriminator: PatchGAN

The discriminator is a PatchGAN for $N \times N$ patches where $N=70$: given a $256 \times 256$ input, the output is $30 \times 30$.

In [None]:
class PatchDiscriminator(nn.Module):
  """PatchGAN critic: a stack of 4x4 convolutions ending in a 1-channel score map.

  - in_channels: channels of the input image (L + ab = 3)
  - ndf: number of filters of the first conv; doubled at every following conv
  - n_down: total number of convs, including the final 1-channel one
  - normalization_type: "batchnorm", "instance", or anything else (e.g. None)
    for no normalization

  With the defaults, a 256 x 256 input gives a 30 x 30 output.
  """

  def __init__(self, in_channels = 3, ndf = 64, n_down = 5, normalization_type = "batchnorm"):

    super(PatchDiscriminator, self).__init__()

    features = [ndf * 2**i for i in range(n_down-1)]

    norm_layers = {"batchnorm" : nn.BatchNorm2d, "instance" : nn.InstanceNorm2d}
    norm_cls = norm_layers.get(normalization_type)

    layers = []
    last = len(features) - 1

    for i, n_features in enumerate(features):
      # The first conv is never normalized. A later conv loses its bias only
      # when a norm layer really follows it (previously an unrecognized
      # normalization_type produced bias-free convs with no norm at all).
      use_norm = i > 0 and norm_cls is not None
      # every conv halves the resolution except the last one of the stack
      stride = 2 if i < last else 1

      layers.append(nn.Conv2d(in_channels, n_features, 4, stride, 1, bias = not use_norm))

      if use_norm:
        layers.append(norm_cls(n_features))

      layers.append(nn.LeakyReLU(0.2))

      in_channels = n_features

    # final 1-channel conv: one raw score per patch (no activation)
    layers.append(nn.Conv2d(in_channels, 1, 4, 1, 1))

    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Return the patch score map for x."""

    return self.model(x)

# 4 Models initialization

Initialization generator and critic

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

G = GeneratorUNet(1,2,8,64, "batchnorm").to(device)
C = PatchDiscriminator(normalization_type = "instance").to(device)

def weights_init(m):
    """Initializer meant for `module.apply`.

    Modules whose class name contains 'Conv' get Xavier-uniform weights and a
    zero bias (when they have one). Modules whose class name contains
    'BatchNorm' get weights drawn from N(1, 0.02) and a zero bias. Every other
    module is left untouched.
    """

    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.xavier_uniform_(m.weight.data)

        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias.data, 0.0)

        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

G.apply(weights_init);
C.apply(weights_init);

# 5 Training setup

Here we define the L1 loss, the optimizers, the number of epochs and the hyperparameters

## 5.1 Losses

In [None]:
# L1 reconstruction loss; the generator train steps read it as a module-level global.
L1_loss = nn.L1Loss()

# L1 weight.
# NOTE(review): `lamb` is not referenced by any code visible in this notebook;
# the train functions take the weight through C_L1 / Cs["L1"] instead.

lamb = 100

## 5.2 Optimizers

In [None]:
#params for Adam
lr_G = 2e-4 
lr_C = 2e-4

betas_G = (0.5, 0.999)
betas_C = (0.0, 0.9)

G_opt = optim.Adam(G.parameters(), lr=lr_G, betas=betas_G)
C_opt = optim.Adam(C.parameters(), lr=lr_C, betas=betas_C)

# 6 Utility functions

In [None]:
def convert_lab_to_rgb(L, ab):

  """
  Convert one normalized Lab image, or a batch of them, to RGB.

  input:
    - L: torch.tensor, lightness scaled as L/50 - 1 (1 x H x W, or N x 1 x H x W)
    - ab: torch.tensor, colour channels scaled as ab/110 (2 x H x W, or N x 2 x H x W)

  output:
    - img: numpy.ndarray, H x W x 3 for a single image or N x H x W x 3 for a batch
  """

  # undo the normalization applied by the dataset
  L = (L + 1.) * 50.
  ab = ab * 110.

  # a single image has at most 3 dims (C x H x W): channels-last, then convert
  if ab.dim() <= 3:
    single_lab = torch.cat([L, ab], dim=0).permute(1, 2, 0).cpu().detach().numpy()
    return lab2rgb(single_lab)

  # batch: N x 3 x H x W -> N x H x W x 3, converted image by image
  batch_lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().detach().numpy()

  return np.stack([lab2rgb(lab_image) for lab_image in batch_lab], axis=0)

In [None]:
def show_results(Ls, real_abs, fake_abs):

  """
  Plot a 3-row grid for a batch: grayscale input (top), real colours (middle),
  generated colours (bottom), one column per image.

  input:
    - Ls: batch with L for each image, N x 1 x 256 x 256 tensor
    - real_abs: batch with ab for each real image, N x 2 x 256 x 256 tensor
    - fake_abs: batch with ab for each fake image, N x 2 x 256 x 256 tensor
  """

  n_cols = Ls.shape[0]

  real_images = convert_lab_to_rgb(Ls, real_abs)
  fake_images = convert_lab_to_rgb(Ls, fake_abs)

  plt.figure(figsize=(15, 15))

  for col in range(n_cols):

    # (image, imshow keyword arguments) for rows 0, 1 and 2 of this column
    panels = (
        (Ls[col][0].cpu(), {"cmap": 'gray'}),
        (real_images[col], {}),
        (fake_images[col], {}),
    )

    for row, (image, imshow_kwargs) in enumerate(panels):
      ax = plt.subplot(3, n_cols, col + 1 + row * n_cols)
      ax.imshow(image, **imshow_kwargs)
      ax.axis("off")

  plt.show()

The following class tracks the losses over an epoch: it accumulates each loss over the dataset (the WGAN loss for the generator, the critic losses, etc.) and then computes the mean of each of them. In this way we can compute the mean loss at the end of each epoch and also at intermediate steps.

In [None]:
class WLossTracker :
  """Batch-size-weighted running averages of the generator and critic losses.

  Generator and critic losses are accumulated separately because the two
  networks are not necessarily updated on the same number of batches.
  """

  def __init__(self) :

    self.G_count = 0 # number of images accumulated for the generator losses
    self.D_count = 0 # number of images accumulated for the critic losses

    self.avg = {} # loss name -> average, filled in by avg_losses()

    self.G_dict_losses = {
        "G_loss" : 0,
        "G_WGAN_loss" : 0,
        "G_L1_loss" : 0,

    }

    self.D_dict_losses = {
        "D_loss" : 0,
        "D_loss_real" : 0,
        "D_loss_fake" : 0,
        "GP" : 0
    }

  def set_to_zero(self) :

    """
    Reset the accumulated sums and the image counters (e.g. at the start of
    an epoch). The averages of the last avg_losses() call are kept.
    """

    for key in self.G_dict_losses :
      self.G_dict_losses[key] = 0

    for key in self.D_dict_losses :
      self.D_dict_losses[key] = 0

    self.G_count = 0
    self.D_count = 0


  @staticmethod
  def _accumulate(totals, losses, batch_size) :
    """Add every per-batch mean loss, weighted by the batch size, to `totals`."""

    for key, value in losses.items() :
      totals[key] += value * batch_size


  def update_G_losses(self, losses, batch_size) :
    """
    It updates the cumulative sum for the generator losses

    - losses: dict, the losses obtained in a single batch of dimension batch_size
    - batch_size: int
    """
    self._accumulate(self.G_dict_losses, losses, batch_size)
    self.G_count += batch_size

  def update_D_losses(self, losses, batch_size) :
    """
    It updates the cumulative sum for the critic losses

    - losses: dict, the losses obtained in a single batch of dimension batch_size
    - batch_size: int
    """
    self._accumulate(self.D_dict_losses, losses, batch_size)
    self.D_count += batch_size


  def avg_losses(self) :

    """
    For each loss compute the mean w.r.t. the losses accumulated up to now
    and return the dict of averages.

    A side that has not seen any batch yet (count == 0) reports 0.0 instead
    of raising ZeroDivisionError.
    """

    for key, total in self.G_dict_losses.items() :
      self.avg[key] = total / self.G_count if self.G_count else 0.0

    for key, total in self.D_dict_losses.items() :
      self.avg[key] = total / self.D_count if self.D_count else 0.0

    return self.avg


  def print_losses(self) :

    """
    Print the mean of each loss, as computed by the last avg_losses() call
    """

    for key, value in self.avg.items() :
      print(f"{key} : {value}")

The following functions are used to save and load checkpoints during training

In [None]:
def wsave_checkpoint(epoch, model_G, model_D, opt_G, opt_D, dict_losses):
  """
  Save a training checkpoint to "WGAN_training.pt" in the current working directory.

  The file stores the epoch, the state dicts of both models and both
  optimizers, and the 'G_WGAN_loss', 'D_loss' and 'G_loss' entries of dict_losses.
  """

  checkpoint = {
      'epoch': epoch,
      'G_state_dict': model_G.state_dict(),
      'G_opt_state_dict': opt_G.state_dict(),
      'D_state_dict' : model_D.state_dict(),
      'D_opt_state_dict' : opt_D.state_dict(),
  }

  # only these three averages are stored with the checkpoint
  for loss_name in ('G_WGAN_loss', 'D_loss', 'G_loss'):
    checkpoint[loss_name] = dict_losses[loss_name]

  torch.save(checkpoint, "WGAN_training.pt")

In [None]:
def wload_checkpoint(G, D, opt_G, opt_D, path = "/content/WGAN_training.pt"):

  """
  Restore models and optimizers in place from a checkpoint file.

  - G: Generator model, GeneratorUNet
  - D: critic model, PatchDiscriminator
  - opt_G: optimizer for G
  - opt_D: optimizer for D

  - path: path to the file from which load the checkpoint

  Returns (epoch, losses): the epoch stored in the checkpoint and a dict with
  its 'G_WGAN_loss', 'D_loss' and 'G_loss' values.
  """

  # Deserialize on the CPU so that a checkpoint written on a GPU runtime can
  # still be opened on a CPU-only one; load_state_dict then copies the values
  # onto whatever device the models/optimizer parameters live on.
  checkpoint = torch.load(path, map_location = "cpu")
  epoch = checkpoint['epoch']
  G.load_state_dict(checkpoint['G_state_dict'])
  D.load_state_dict(checkpoint['D_state_dict'])

  opt_D.load_state_dict(checkpoint['D_opt_state_dict'])

  opt_G.load_state_dict(checkpoint['G_opt_state_dict'])


  print(f"Checkpoint at epoch {epoch} loaded")

  return epoch, {'G_WGAN_loss' : checkpoint['G_WGAN_loss'], 'D_loss' : checkpoint['D_loss'], 'G_loss' : checkpoint['G_loss']}

Functions to save and load only the generator

In [None]:
def save_generator(G, path = "/content/WGAN-gen.pt"):
  """Save only the generator's weights (its state dict) to `path`."""

  generator_weights = G.state_dict()
  torch.save(generator_weights, path)

def load_generator(G, path = "/content/WGAN-gen.pt"):
  """Load into G, in place, the generator weights stored at `path`."""

  # Deserialize on the CPU so that weights saved on a GPU runtime also load on
  # a CPU-only one; load_state_dict copies them onto G's own device.
  G.load_state_dict(torch.load(path, map_location = "cpu"))



Create and update a .csv file with all the losses during the training

In [None]:
import pandas as pd
from csv import writer

# Column order of the losses log (the epoch number goes in the unnamed index column).
wloss_names = ['G_loss', 'G_WGAN_loss', 'G_L1_loss', 'D_loss_real', 'D_loss_fake', 'D_loss', 'GP']

def create_csv(columns = wloss_names, path = "/content/wlosses.csv"):
  """Create (or overwrite) the losses log at `path` with only the header row."""

  df = pd.DataFrame(columns=columns)
  df.to_csv(path)

def update_csv(epoch, losses, columns = wloss_names, path = "/content/wlosses.csv"):
  """
  Append one row to the losses log: the epoch followed by losses[column] for
  every column, in order.

  - epoch: int, written in the index column
  - losses: dict with at least one entry per name in `columns`
  """

  row_to_add = [epoch] + [losses[column] for column in columns]

  # newline='' is what the csv module asks for, so that it controls the line
  # endings itself; the `with` block closes the file.
  with open(path, 'a', newline = '') as f_object:

    writer(f_object).writerow(row_to_add)



# 7 Training


## 7.1 Functions for training

Function that computes the GP term

In [None]:
def GP_penalty_term(critic, real_imgs, fake_imgs, device) :
  """
  Gradient penalty: mean over the batch of (||grad critic(x_hat)||_2 - 1)^2,
  where x_hat is a random per-sample interpolation of real and fake images.

  - critic: callable mapping a batch of images to scores
  - real_imgs, fake_imgs: tensors of the same shape, batch on dim 0
    (the random weights are created with shape N x 1 x 1 x 1)
  - device: device on which the random weights and the grad seed are created

  The result is built with create_graph=True so it can be back-propagated.
  """

  n_samples = real_imgs.shape[0]

  # one mixing coefficient per sample
  mix = torch.rand((n_samples, 1, 1, 1), device = device)

  x_hat = mix * real_imgs + ((1 - mix) * fake_imgs)
  x_hat.requires_grad_(True)

  scores = critic(x_hat)

  (grads,) = torch.autograd.grad(
      outputs = scores,
      inputs = x_hat,
      grad_outputs = torch.ones(scores.size(), device = device, requires_grad = False),
      create_graph = True,
      retain_graph = True,
      only_inputs = True,
  )

  # L2 norm of each sample's flattened gradient
  per_sample_norm = grads.view(grads.size(0), -1).norm(2, dim = 1)

  return torch.mean((per_sample_norm - 1) ** 2)

The following functions define how a single training step works: one function for the critic and two for the generator.

In [None]:
def w_gen_train_step(L, ab_real, ab_fake, fake_imgs, critic, g,  g_opt, device, C_L1 = 100) :
  """
  One generator update with the WGAN loss plus the weighted L1 loss.

  - ab_real / ab_fake: real and generated ab channels (L1 term)
  - fake_imgs: generated images fed to the critic; they must still be attached
    to the graph of g so that backward() reaches the generator
  - C_L1: weight of the L1 term
  (L and device are accepted but not used here.)

  Returns a dict of floats: "G_WGAN_loss", "G_L1_loss" (already weighted), "G_loss".
  """

  g.train()
  g_opt.zero_grad()

  # the generator wants the critic score of its fakes to be high
  adversarial_term = -torch.mean(critic(fake_imgs))
  l1_term = L1_loss(ab_fake, ab_real) * C_L1
  total = adversarial_term + l1_term

  total.backward()
  g_opt.step()

  return {
      "G_WGAN_loss" : adversarial_term.item(),
      "G_L1_loss" : l1_term.item(),
      "G_loss" : total.item(),
  }

Train step for the generator only with L1 norm

In [None]:
def w_gen_train_step_L1(L, ab_real, ab_fake, fake_imgs, critic, g,  g_opt, device, adversial = False, C_L1 = 100 ) :
  """
  One generator update with a single loss term.

  - adversial = False (default): only the weighted L1 loss; critic is not used
  - adversial = True: only the WGAN loss, -mean(critic(fake_imgs))

  Returns a dict with "G_WGAN_loss", "G_L1_loss" and "G_loss"; the term that
  is switched off is reported as the integer 0, the others as floats.
  (L and device are accepted but not used here.)
  """

  g.train()
  g_opt.zero_grad()

  if adversial:
    G_WGAN_loss = -torch.mean(critic(fake_imgs))
    G_L1_loss = 0
  else:
    G_WGAN_loss = 0
    G_L1_loss = L1_loss(ab_fake, ab_real) * C_L1

  # exactly one of the two terms is a tensor, the other is the constant 0
  G_loss = G_WGAN_loss + G_L1_loss

  G_loss.backward()
  g_opt.step()

  return {
      "G_WGAN_loss" : G_WGAN_loss.item() if adversial else G_WGAN_loss,
      "G_L1_loss" : G_L1_loss if adversial else G_L1_loss.item(),
      "G_loss" : G_loss.item(),
  }

Critic train step

In [None]:
def w_critic_train_step(L, ab_real, fake_imgs, critic, critic_opt, device, C_GP = 10) :
  """
  One critic update with the WGAN-GP loss:
  D_loss = -mean(critic(real)) + mean(critic(fake)) + C_GP * gradient penalty.

  - L, ab_real: concatenated along dim 1 to form the real images
  - fake_imgs: generated images; detached here, so the generator gets no gradient
  - C_GP: weight of the gradient penalty

  Returns a dict of floats: "D_loss_real" (reported with its minus sign),
  "D_loss_fake", "D_loss" and "GP" (already weighted).
  """

  critic.train()

  real_imgs = torch.cat([L, ab_real], dim = 1)

  critic_opt.zero_grad()

  real_score = torch.mean(critic(real_imgs))
  fake_score = torch.mean(critic(fake_imgs.detach()))

  # the penalty is computed on graph-free copies of both batches
  GP = GP_penalty_term(critic, real_imgs.detach(), fake_imgs.detach(), device) * C_GP

  D_loss = - real_score + fake_score + GP

  D_loss.backward()
  critic_opt.step()

  return {
      "D_loss_real" : -real_score.item(),
      "D_loss_fake" : fake_score.item(),
      "D_loss" : D_loss.item(),
      "GP" : GP.item(),
  }

## 7.2 Train functions



### 7.2.1. L1 loss only

The following function trains only the generator (the critic is passed only so that it can be included in the saved checkpoint).

In [None]:
def wgan_train_L1(dataloader, epochs, g, device, g_opt, print_every, critic, c_opt, C_L1 = 100, last_epoch_done = None) :
  """
  Train only the generator, with the L1 loss alone.

  - dataloader: yields (L, ab) batches
  - epochs: number of epochs to run in this call
  - g, g_opt: generator and its optimizer
  - print_every: print the running average losses every `print_every` batches
  - critic, c_opt: not trained here, only written into the checkpoints
  - C_L1: weight of the L1 loss
  - last_epoch_done: last epoch completed by a previous run (None for a fresh
    run); used to keep the epoch numbering continuous

  After every epoch a full checkpoint is saved and a row is appended to
  /content/wlosses.csv; after the last epoch the generator alone is saved too.
  """

  losstracker = WLossTracker()

  # offset added to the local epoch index when resuming
  add = 0 if last_epoch_done is None else 1 + last_epoch_done

  for epoch in range(epochs) :

    g.train()

    losstracker.set_to_zero()
    # The critic is not trained in this phase: register zero-valued critic
    # losses so that the critic averages are defined (and equal to 0).
    losstracker.update_D_losses({"D_loss_real" : 0, "D_loss_fake" : 0, "D_loss" : 0, "GP" : 0}, 32)
    progress_bar = tqdm(enumerate(dataloader), total = len(dataloader))

    for i, batch in progress_bar :

      L = batch[0].to(device)
      ab = batch[1].to(device)

      ab_fake = g(L)
      fake_imgs = torch.cat([L, ab_fake], dim = 1)

      # forward the caller's L1 weight (it used to be silently ignored)
      G_losses = w_gen_train_step_L1(L, ab, ab_fake, fake_imgs, None,  g,  g_opt, device, C_L1 = C_L1)
      losstracker.update_G_losses(G_losses, L.shape[0])

      if (i + 1) % print_every == 0 :
        losstracker.avg_losses()
        losstracker.print_losses()
        print("\n\n")

    #SAVING MODELS + OPTs
    print("Saving model checkpoint")

    #compute avg losses over the epoch
    losses = losstracker.avg_losses()

    # Checkpoint every epoch, the last one included, so that a later phase
    # or a resume starts from the most recent weights.
    wsave_checkpoint(epoch+add, g, critic, g_opt, c_opt, losses)

    if epoch == (epochs -1):
      save_generator(g)


    #save losses in csv

    update_csv(epoch+add, losses, columns = wloss_names, path = "/content/wlosses.csv")

    print(f'Epoch {epoch+add} finished')

### 7.2.2 L1 loss + WGAN

In [None]:
def wgan_train(dataloader, epochs, g, critic, device, critic_opt, g_opt, print_every, Cs = {"L1" : 100, "GP" : 10},last_epoch_done = None, n_critic = 5) :
  """
  Train generator and critic with the WGAN-GP loss plus the L1 loss.

  - dataloader: yields (L, ab) batches
  - epochs: number of epochs to run in this call
  - g, g_opt / critic, critic_opt: the two networks and their optimizers
  - print_every: show sample results and running average losses when the
    batch number is a multiple of both n_critic and print_every
  - Cs: dict with the weights "L1" (L1 loss) and "GP" (gradient penalty); only read
  - last_epoch_done: last epoch completed by a previous run (None for a fresh
    run); used to keep the epoch numbering continuous
  - n_critic: the critic is updated on every batch, the generator once every
    n_critic batches

  After every epoch a full checkpoint is saved and a row is appended to
  /content/wlosses.csv; after the last epoch the generator alone is saved too.
  """

  C_L1 = Cs["L1"]
  C_GP = Cs["GP"]

  losstracker = WLossTracker()

  # offset added to the local epoch index when resuming
  add = 0 if last_epoch_done is None else 1 + last_epoch_done

  for epoch in range(epochs) :

    losstracker.set_to_zero()
    progress_bar = tqdm(enumerate(dataloader), total = len(dataloader))

    for i, batch in progress_bar :

      L = batch[0].to(device)
      ab = batch[1].to(device)

      ab_fake = g(L)
      fake_imgs = torch.cat([L, ab_fake], dim = 1)


      D_losses = w_critic_train_step(L, ab, fake_imgs, critic, critic_opt, device, C_GP)

      losstracker.update_D_losses(D_losses, L.shape[0])


      if (i + 1) % n_critic == 0 :


        G_losses = w_gen_train_step(L, ab, ab_fake, fake_imgs, critic, g,  g_opt, device, C_L1 = C_L1 )
        losstracker.update_G_losses(G_losses, L.shape[0])

        if (i + 1) % print_every == 0 :
            show_results(L[:3], ab[:3], ab_fake[:3])
            losstracker.avg_losses()
            losstracker.print_losses()
            print("\n\n")

    #SAVING MODELS + OPTs
    print("Saving model checkpoint")

    #compute avg losses over the epoch
    losses = losstracker.avg_losses()

    # Checkpoint every epoch, the last one included, so that a resume starts
    # from the most recent weights.
    wsave_checkpoint(epoch+add, g, critic, g_opt, critic_opt, losses)

    if epoch == (epochs -1):
      save_generator(g)


    #save losses in csv

    update_csv(epoch+add, losses, columns = wloss_names, path = "/content/wlosses.csv")

    print(f'Epoch {epoch+add} finished')

## 7.3 Start L1 training

In [None]:
#training phase

#create csv to save losses
# Create the losses log; this overwrites any existing /content/wlosses.csv
# with a header-only file.
create_csv()

# Generator-only (L1) phase: 60 epochs, running losses printed every 150 batches.
EPOCHS = 60
print_every = 150
wgan_train_L1(train_dataloader, EPOCHS, G, device, G_opt, print_every, C, C_opt, C_L1 = 100, last_epoch_done = None)


## 7.4 Resume training L1

If the L1-only training is interrupted, you can resume it with this cell. First upload the last checkpoint saved during training ('WGAN_training.pt') to Colab. If you changed the file name or its location ('/content/'), pass the new path to `wload_checkpoint` as the extra `path` argument (see the definition of the function).

In [None]:
# Total number of L1-only epochs planned (not the number still to run).
EPOCHS = 60

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Rebuild models and optimizers exactly as in the original run, then restore their state.
G = GeneratorUNet(1,2,8,64, "batchnorm").to(device)
C = PatchDiscriminator(normalization_type = "instance").to(device)

lr_G = 2e-4
lr_C = 2e-4

betas_G = (0.5, 0.999)
betas_C = (0.0, 0.9) # beta1 = 0: no first-moment (momentum) term for the critic

G_opt = optim.Adam(G.parameters(), lr=lr_G, betas=betas_G)
C_opt = optim.Adam(C.parameters(), lr=lr_C, betas=betas_C)

epoch, check_point_losses = wload_checkpoint(G, C, G_opt, C_opt)

print(check_point_losses)


# `epoch` is the last completed (0-based) epoch, so this many are still missing.
epochs_left = EPOCHS - epoch  - 1
print_every = 150
# Train only for the remaining epochs (EPOCHS used to be passed here, which
# ran 60 more epochs on top of the ones already done).
wgan_train_L1( train_dataloader, epochs_left, G, device, G_opt,print_every, C, C_opt, last_epoch_done = epoch)

## 7.5 Start training L1 + WGAN

In [None]:
# Epoch index at which the whole training (L1 phase + WGAN phase) stops.
EPOCHS = 80

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

G = GeneratorUNet(1,2,8,64, "batchnorm").to(device)
C = PatchDiscriminator(normalization_type = "instance").to(device)
     
lr_G = 2e-4 
lr_C = 2e-4 

betas_G = (0.5, 0.999)
betas_C = (0.0, 0.9) # beta1 = 0: no first-moment (momentum) term for the critic

# These optimizers only exist so that wload_checkpoint has something to load into.
G_opt = optim.Adam(G.parameters(), lr=lr_G, betas=betas_G)   
C_opt = optim.Adam(C.parameters(), lr=lr_C, betas=betas_C)

epoch, check_point_losses = wload_checkpoint(G, C, G_opt, C_opt)

# The optimizers are re-created, which discards the optimizer state just loaded.
# NOTE(review): presumably intentional, so that the adversarial phase starts with
# fresh Adam statistics - confirm.
G_opt = optim.Adam(G.parameters(), lr=lr_G, betas=betas_G)   
C_opt = optim.Adam(C.parameters(), lr=lr_C, betas=betas_C)

print(check_point_losses)


# `epoch` is the last completed (0-based) epoch of the loaded checkpoint.
epochs_left = EPOCHS - epoch  - 1   
print_every = 150    


# Both the L1 and the gradient-penalty weights are set to 1 in this phase.
wgan_train( train_dataloader, epochs_left, G, C, device, C_opt, G_opt, print_every, Cs = {"L1": 1, "GP" : 1}, last_epoch_done = epoch)

## 7.6 Resume training L1 + WGAN

In [None]:
# Epoch index at which the whole training stops.
EPOCHS = 120
 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Rebuild models and optimizers, then restore all of them from the checkpoint.
G = GeneratorUNet(1,2,8,64, "batchnorm").to(device)
C = PatchDiscriminator(normalization_type = "instance").to(device)

lr_G = 2e-4 
lr_C = 2e-4

betas_G = (0.5, 0.999)
betas_C = (0.0, 0.9) # beta1 = 0: no first-moment (momentum) term for the critic

G_opt = optim.Adam(G.parameters(), lr=lr_G, betas=betas_G)   
C_opt = optim.Adam(C.parameters(), lr=lr_C, betas=betas_C)

# Here the loaded optimizer state is kept (the optimizers are not re-created afterwards).
epoch, check_point_losses = wload_checkpoint(G, C, G_opt, C_opt)


print(check_point_losses)


# `epoch` is the last completed (0-based) epoch of the loaded checkpoint.
epochs_left = EPOCHS - epoch  - 1   
print_every = 150     




wgan_train( train_dataloader, epochs_left, G, C, device, C_opt, G_opt, print_every, Cs = {"L1": 1, "GP" : 1}, last_epoch_done = epoch)

## 8.4 Load and test Generator (to test properly the generator look at the other notebooks)

In [None]:
# NOTE(review): EPOCHS is not used in this cell.
EPOCHS = 120
 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

G = GeneratorUNet(1,2,8,64, "batchnorm").to(device)
C = PatchDiscriminator(normalization_type = "instance").to(device)

lr_G = 2e-4 
lr_C = 2e-4

betas_G = (0.5, 0.999)
betas_C = (0.0, 0.9) # beta1 = 0: no first-moment (momentum) term for the critic

# The critic and both optimizers are built only because wload_checkpoint
# restores all four objects; just G is used afterwards.
G_opt = optim.Adam(G.parameters(), lr=lr_G, betas=betas_G)   
C_opt = optim.Adam(C.parameters(), lr=lr_C, betas=betas_C)

epoch, check_point_losses = wload_checkpoint(G, C, G_opt, C_opt,"/content/WGAN_training_59.pt")


print(check_point_losses)

# Export the generator weights alone (relative path: written to the current working directory).
save_generator(G, "WGAN_9k_60.pt")

Checkpoint at epoch 59 loaded
{'G_WGAN_loss': 0.0, 'D_loss': 0.0, 'G_loss': 5.827366714477539}


The following cells load the generator and show the results with some new images. First of all load a file with the paths for images of the animals dataset (not used during training) 

In [None]:
files.upload(); #val_txt

In [None]:
filename = "val_animals.txt"

# read_lines is already defined in section 1.2; it is reused here instead of
# being defined a second time.
animals_paths = read_lines(filename)

print(f"# images of animals for testing: {len(animals_paths)}")

# images of animals for testing: 2400


Add the other images of COCO dataset

In [None]:
# Start from a copy, so that animals_paths is not modified and re-running
# this cell does not append the COCO paths a second time.
test_paths = list(animals_paths)

# set membership is O(1); the previous list lookup was O(len(training_paths)) per path
training_set = set(training_paths)

# add every COCO image that was not used for training
test_paths.extend(path for path in map(str, paths) if path not in training_set)

print(f"# images for testing: {len(test_paths)}")

# images for testing: 17637


In [None]:
import random

# Seed both generators so that the shuffled order is reproducible.
# (The original line, `np.random.seed)`, was a syntax error; the value 1
# matches the seed given to `random` below.)
np.random.seed(1)
random.seed(1)

test_paths = np.array(test_paths)
np.random.shuffle(test_paths)

Create dataset and dataloader

In [None]:
SIZE = 256
# images shown per batch (one column each in show_results)
batch_size = 3

# Test-time preprocessing: only the resize, no random flip.
test_transform = transforms.Compose([
                transforms.Resize((SIZE, SIZE),  transforms.InterpolationMode.BILINEAR),
            ])

test_dataset = GrayToColorDataset(test_paths, test_transform)

PIN_MEMORY = True
N_WORKERS = 2
# The unused `BATCH_SIZE = 16` was removed: the loader below uses `batch_size`,
# and the assignment overwrote the training BATCH_SIZE constant.

# shuffle = False: test_paths has already been shuffled with a fixed seed
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=N_WORKERS,
                            pin_memory=PIN_MEMORY, shuffle = False)

Load the generator

In [None]:
# Rebuild the generator with the same arguments used for training,
# then load the exported weights.
# NOTE(review): the export cell writes "WGAN_9k_60.pt" with a relative path;
# this absolute path matches it only if the working directory is /content - confirm.
G = GeneratorUNet(1,2,8,64, "batchnorm").to(device)
load_generator(G,"/content/WGAN_9k_60.pt")

# Evaluation mode for inference (switch to G.train() to keep dropout active).
G.eval();

Loop that shows the results

In [None]:
def test_generator(G, device, dataloader, n_batcher_to_show = 4):
  """
  Colorize the first batches of `dataloader` with G and plot them.

  - G: generator; it is used in whatever mode (train/eval) the caller set
  - device: device the batches are moved to
  - dataloader: yields (L, ab) batches
  - n_batcher_to_show: number of batches to display
  """

  # inference only: no autograd graph is needed, which saves memory
  with torch.no_grad():

    for i,batch in enumerate(dataloader):
        print(f"Results for batch {i+1}")
        L = batch[0].to(device)
        ab = batch[1].to(device)

        ab_fake = G(L)

        show_results(L, ab, ab_fake)

        print("\n\n")

        if i == n_batcher_to_show -1:
          break

test_generator(G, device, test_dataloader, 10)