# 1 Data collection

Two datasets are used: a small version of the COCO dataset with 21,837 images, and a dataset of 17,178 animal images (12 categories).

## 1.1 Animals dataset

We download this dataset from Kaggle (1.4 GB).

In [None]:
!pip install -q kaggle
from google.colab import files

Upload your Kaggle API token, a file called kaggle.json. To obtain it, follow the first two steps described in https://www.kaggle.com/general/74235.

In [None]:
# Opens Colab's upload widget: upload kaggle.json here (the returned dict is shown as cell output).
files.upload()

In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! kaggle datasets list

In [None]:
!kaggle datasets download -d piyushkumar18/animal-image-classification-dataset

The data have been downloaded. Unzip them:

In [None]:
!mkdir /content/animal_data
!unzip -qq /content/animal-image-classification-dataset.zip -d /content/animal_data/

## 1.2 COCO dataset

We download it with fastai.

In [None]:
!pip install fastai==2.4

In [None]:
from fastai.data.external import untar_data, URLs
import os
import glob
import numpy as np

In [None]:
# Download and extract the COCO sample with fastai; its images live in train_sample/
coco_path = str(untar_data(URLs.COCO_SAMPLE)) + "/train_sample"

# every COCO sample image, kept as a numpy array of path strings
paths = np.array(glob.glob(coco_path + "/*.jpg"))
num_images_coco = len(paths)
print(f"# coco images: {num_images_coco}")

Upload one of the two training-set lists: "data_big_training.txt" (16k images, of which 4.2k are animal images) or "data_small_training.txt" (9.6k images, of which 3k are animal images).
Since we want to test the generators, choose the big one.

In [None]:
# Upload data_big_training.txt or data_small_training.txt; the trailing ";" hides the cell output.
files.upload();

In [None]:
filename = "data_big_training.txt"  # choose the proper file name


def read_lines(path):
  """Return every line of the text file at *path*, with trailing whitespace removed."""
  with open(path) as handle:
    stripped = [raw_line.rstrip() for raw_line in handle]
  return stripped

In [None]:
# image paths used to train the generators, one per line of the uploaded list
training_paths = read_lines(
    filename
)
print(f"{len(training_paths)} images for training")

Upload the list of test images, "test_animals.txt".


In [None]:
# Upload test_animals.txt (one test image path per line); the trailing ";" hides the cell output.
files.upload();

# 2 Load all the generators

In [None]:
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [None]:
class UNetDown(nn.Module):
  """U-Net encoder block: stride-2 convolution, optional normalization,
  activation and optional dropout.

  With the default kernel_size=4 (stride 2, padding 1) the block halves even
  spatial sizes.

  Args:
    in_channels, out_channels: channel counts of the convolution.
    kernel_size: convolution kernel size (stride is 2, padding is 1).
    normalization_type: None (no normalization; the conv keeps its bias),
      "instance" (InstanceNorm2d) or any other value (BatchNorm2d).
    dropout: dropout probability; 0 disables dropout.
    activation: None for LeakyReLU(0.2), or "ReLU" for ReLU.

  Raises:
    ValueError: if activation is neither None nor "ReLU".
  """

  def __init__(self, in_channels, out_channels, kernel_size = 4, normalization_type = None, dropout = 0.0, activation = None):

    super().__init__()

    # if batchnorm/instancenorm follows the conv, the conv bias is redundant
    use_bias = normalization_type is None
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, 2, 1, bias = use_bias)]

    if not use_bias:
      if normalization_type == "instance":
        layers.append(nn.InstanceNorm2d(out_channels))
      else:
        layers.append(nn.BatchNorm2d(out_channels))

    if activation is None:
      layers.append(nn.LeakyReLU(negative_slope = 0.2))
    elif activation == "ReLU":
      layers.append(nn.ReLU())
    else:
      # fail loudly instead of silently building a block without activation
      raise ValueError(f"unsupported activation: {activation!r} (expected None or 'ReLU')")

    if dropout:
      layers.append(nn.Dropout(p = dropout))

    # the attribute name "model" and the layer order define the state_dict keys of saved checkpoints
    self.model = nn.Sequential(*layers)


  def forward(self, x):
    """Apply conv -> [norm] -> activation -> [dropout] to x."""
    return self.model(x)

In [None]:
class UNetUp(nn.Module):
  """U-Net decoder block: stride-2 transposed convolution, optional
  normalization, ReLU and optional dropout.

  With the default kernel_size=4 (stride 2, padding 1) the block doubles the
  spatial size. forward() can concatenate an encoder skip connection to the
  upsampled tensor along the channel dimension.

  Args:
    in_channels, out_channels: channel counts of the transposed convolution.
    kernel_size: kernel size (stride is 2, padding is 1).
    normalization_type: None (no normalization; the conv keeps its bias),
      "instance" (InstanceNorm2d) or any other value (BatchNorm2d).
    dropout: dropout probability; 0 disables dropout.
  """

  def __init__(self, in_channels, out_channels, kernel_size = 4,  normalization_type = None, dropout = 0.0):

    super().__init__()

    # if batchnorm/instancenorm follows the conv, the conv bias is redundant
    use_bias = normalization_type is None

    layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, 2, 1, bias = use_bias)]

    if not use_bias:
      if normalization_type == "instance":
        layers.append(nn.InstanceNorm2d(out_channels))
      else:
        layers.append(nn.BatchNorm2d(out_channels))

    layers.append(nn.ReLU())

    if dropout:
      layers.append(nn.Dropout(p = dropout))

    # the attribute name "model" and the layer order define the state_dict keys of saved checkpoints
    self.model = nn.Sequential(*layers)


  def forward(self, x, skip = None):
    """Upsample x; if skip is given, return torch.cat((skip, upsampled), dim=1)."""
    x = self.model(x)
    if skip is not None:
      x = torch.cat((skip, x), 1)
    return x

In [None]:
class GeneratorUNet(nn.Module):
  """U-Net generator (encoder/decoder with skip connections) built from UNetDown/UNetUp.

  The output passes through a final Tanh, so it lies in [-1, 1]. In this
  notebook the input is the scaled L channel (1 channel) and the output the
  scaled ab channels (2 channels).

  Args:
    in_channels: input channels.
    out_channels: output channels.
    num_down: number of down blocks. Values below 5 still build 5 levels,
      because `features` always has at least 5 entries.
    ngf: filters of the first down block; deeper blocks use 2x, 4x and 8x ngf.
    normalization_type: used by every block except the outermost and innermost
      down blocks (None, "instance", or any other value for batch norm).

  NOTE: the module construction order and the attribute names (downs, ups,
  final) determine both the random initialization order and the state_dict
  keys of the saved .pt checkpoints.
  """

  def __init__(self, in_channels = 1, out_channels = 2, num_down = 8, ngf = 64, normalization_type = None):

    super(GeneratorUNet, self).__init__()

    self.downs = nn.ModuleList()
    self.ups = nn.ModuleList()
    

    # output channels of each down block
    features =[ngf]

    for i in range(3):

      features.append(features[i]*2)

    features.append(features[-1])
    # with ngf = 64: 64, 128, 256, 512, 512

    if num_down > 5:

      features += [ngf * 8 for i in range(num_down - 5)]
    #for num_down = 8: 64, 128, 256, 512, 512, 512, 512, 512 (->1x1 for input size 256x256)


    #ENCODER (CONTRACTING PATH)

    # outermost down block: conv + LeakyReLU, no normalization and no dropout
    self.downs.append(UNetDown(in_channels, ngf, 4))

    in_channels = ngf #new in_channels for the next down-block
    
    for i,n_features in enumerate(features[1:len(features)-1]):
      #no dropout
      self.downs.append(UNetDown(in_channels, n_features, 4, normalization_type, 0.0))
      in_channels = n_features

    
    # innermost down block (bottleneck): conv + ReLU, no normalization and no dropout
    self.downs.append(UNetDown(in_channels, features[-1], 4, activation = "ReLU"))
    

    #DECODER (EXPANSIVE PATH)
    i_channels = in_channels
    for i, n_features in enumerate((features[-2::-1])):
      
      
      # i == 0: innermost up block, fed directly by the bottleneck output. Its input has
      # in_channels = features[-2] channels, which equals features[-1] by construction.
      # i > 0: the input is the previous up output concatenated with a skip connection
      # of the same width, hence twice the previous output channels.
      i_channels = in_channels if i == 0  else i_channels * 2

      # dropout 0.5 only for ups i = 1, 2, 3; no dropout for the first up (i == 0)
      # or for any up with i > 3 (the last 3 ups when num_down = 8)
      dropout = 0.0 if (i == 0 or i  > 3) else 0.5

      self.ups.append(UNetUp(i_channels, n_features, 4, normalization_type, dropout))
      i_channels = n_features
    
    
    # final upsampling to full resolution. Its input is the last up output (ngf)
    # concatenated with the outermost skip (ngf) = ngf*2 channels; Tanh bounds it to [-1, 1].
    self.final = nn.Sequential(
        nn.ConvTranspose2d(ngf*2,out_channels, kernel_size=4, stride=2, padding=1),
        nn.Tanh()
    )



  def forward(self, x):
    """Run the U-Net on x (N x in_channels x H x W).

    Each down block halves H and W, so H and W must be divisible by
    2 ** len(self.downs) for the skip connections to line up
    (256 x 256 for num_down = 8).
    """

    skip_connections = list()

    #encoder
    for down in self.downs:

      x = down(x)
      skip_connections.append(x)

    # decoder with skip connections: skip_connections[-1] is the bottleneck itself,
    # so up block i is paired with the encoder output skip_connections[-i-2]
    for i, up in enumerate(self.ups):
      
      x = up(x, skip_connections[-i-2])

    return self.final(x)

## 2.1 Connect to drive and load the .pt for the generators

In [None]:
# Mount Google Drive so the trained networks under MyDrive/TrainedNets can be copied to /content.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!cp -r /content/drive/MyDrive/TrainedNets/WGAN_9k_120.pt /content/ #WGAN generator

In [None]:
!cp -r /content/drive/MyDrive/TrainedNets/cGAN_big.pt /content/ # cGAN generator with 16k dataset

In [None]:
!cp -r /content/drive/MyDrive/TrainedNets/cGAN_small_16_100.pt /content/ # cGAN generator with 9.6k dataset

In [None]:
!cp -r /content/drive/MyDrive/TrainedNets/cGAN_small_32_100.pt /content/

In [None]:
!cp -r /content/drive/MyDrive/TrainedNets/cGAN_small_8.pt /content/

In [None]:
# Add the .pt you want to test.
# The order matters: the generator list Gs is built in this order, so Gs[i] is loaded from G_paths[i].
G_paths = ["WGAN_9k_120.pt", "cGAN_big.pt", "cGAN_small_16_100.pt", "cGAN_small_32_100.pt", "cGAN_small_8.pt" ]

In [None]:
def load_generator(G, path = "/content/cGAN-gen.pt"):
  """Load the state_dict saved at *path* into generator *G*, in place.

  The checkpoint is deserialized onto the CPU first (map_location="cpu").
  load_state_dict then copies each tensor into G's existing parameters,
  on whatever device those already live. This way a checkpoint saved from a
  GPU run also loads on a CPU-only runtime.
  """
  state_dict = torch.load(path, map_location="cpu")
  G.load_state_dict(state_dict)

In [None]:
Gs = []  # all the loaded generators, in the order of G_paths

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for G_path in G_paths:
  # every checkpoint uses the same architecture: 1 input channel, 2 output channels,
  # 8 down blocks, 64 base filters, batch normalization
  generator = GeneratorUNet(1, 2, 8, 64, "batchnorm").to(device)
  Gs.append(generator)
  generator.eval()
  load_generator(generator, "/content/" + G_path)

# 3 Visualize Results: show results with some test images

The following cells plot the results for some test images and save them as JPEG images.

In [None]:
# animal test image paths, one per line of the uploaded list
test_animals_paths = read_lines(
    "test_animals.txt"
)
print(f"{len(test_animals_paths)} animal images for testing")

In [None]:
# The COCO test split is every COCO sample image that is not in the training list.
# A set makes each membership test O(1) instead of a linear scan of the training list.
training_set = set(training_paths)
test_coco_paths = [path for path in paths if path not in training_set]

print(f"{len(test_coco_paths)} coco images for testing")


## 3.1 Dataset and Dataloader

In [None]:
SIZE = 256

# test-time preprocessing: bilinear resize to SIZE x SIZE, no augmentation
test_transform = transforms.Compose(
    [transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BILINEAR)]
)

In [None]:
class GrayToColorDataset(Dataset):
  """Dataset of (L, ab) Lab-channel pairs for colorization.

  Each item is read from disk as RGB and, when a transform is given, transformed
  (e.g. resized). It is then converted to Lab and rescaled to about [-1, 1]:
    - L:  1 x H x W tensor, L / 50 - 1
    - ab: 2 x H x W tensor, ab / 110
  """

  def __init__(self, paths, transform = None):
    # paths: sequence of image file paths
    # transform: optional callable applied to the RGB PIL image
    self.paths = paths
    self.transform = transform

  def __len__(self):

    return len(self.paths)

  def __getitem__(self, idx):

    # the with-block closes the file as soon as the RGB copy exists
    with Image.open(self.paths[idx]) as img:
      img_rgb = img.convert("RGB")

    # transform defaults to None, so apply it only when one was given
    if self.transform is not None:
      img_rgb = self.transform(img_rgb)
    img_rgb = np.array(img_rgb)

    #RGB -> Lab
    img_lab = rgb2lab(img_rgb).astype("float32")
    img_lab = transforms.ToTensor()(img_lab)

    #to have values in range [-1,1]
    L = img_lab[[0],:]/50. - 1.
    ab = img_lab[[1,2],:] / 110.

    return (L,ab)


In [None]:
# both test splits share the same test-time transform
test_coco_dataset, test_animals_dataset = (
    GrayToColorDataset(test_coco_paths, test_transform),
    GrayToColorDataset(test_animals_paths, test_transform),
)

In [None]:

PIN_MEMORY = True
N_WORKERS = 2
BATCH_SIZE = 9  # 9 images fill the 3 x 3 grid drawn by show_results

# COCO test images in a fixed order
test_coco_dataloader = DataLoader(
    test_coco_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=N_WORKERS,
    pin_memory=PIN_MEMORY,
)

# animal test images, reshuffled every time the loader is iterated
test_animals_dataloader = DataLoader(
    test_animals_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_WORKERS,
    pin_memory=PIN_MEMORY,
)

## 3.2 Plot and save results

In [None]:
def convert_lab_to_rgb(L, ab):
  """Convert scaled Lab tensors back to RGB numpy image(s).

  Undoes the dataset scaling (L: x -> (x + 1) * 50, ab: x -> x * 110) and
  applies skimage's lab2rgb image by image.

  input:
    - L: torch.tensor, 1 x H x W (one image) or N x 1 x H x W (batch)
    - ab: torch.tensor, 2 x H x W (one image) or N x 2 x H x W (batch)

  output:
    - img: numpy.ndarray, H x W x 3 for one image or N x H x W x 3 for a batch
  """
  L = (L + 1.) * 50.
  ab = ab * 110.

  if len(ab.shape) <= 3:
    # single image: stack the channels and move them to the last axis
    single = torch.cat([L, ab], dim=0).permute(1, 2, 0).cpu().detach().numpy()
    return lab2rgb(single)

  # batch: stack the channels, move them last, then convert image by image
  batch = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().detach().numpy()
  return np.stack([lab2rgb(lab_image) for lab_image in batch], axis=0)

In [None]:
def show_results(Ls, real_abs, fake_abs, path):
  """Plot a batch of real and generated colorizations and save them as JPEGs.

  Two figures are produced: one with the real colors, one with the generator's
  colors. Images are placed column by column in a grid with 3 rows (image k goes
  to row k % 3, column k // 3), so a batch of 9 gives the usual 3 x 3 grid.
  Smaller batches leave the unused cells blank instead of raising IndexError.

  input:
    - Ls: batch with L for each image, N x 1 x 256 x 256 tensor
    - real_abs: batch with ab for each real image, N x 2 x 256 x 256 tensor
    - fake_abs: batch with ab for each fake image, N x 2 x 256 x 256 tensor
    - path: file prefix; the figures are saved to path + "_real.jpg" and path + "_fake.jpg"
  """
  real_images = convert_lab_to_rgb(Ls, real_abs)
  fake_images = convert_lab_to_rgb(Ls, fake_abs)

  _plot_image_grid(real_images, path + "_real.jpg")
  _plot_image_grid(fake_images, path + "_fake.jpg")


def _plot_image_grid(images, out_file, n_rows = 3):
  """Draw *images* column-major in an n_rows-row grid, save it to out_file and show it."""
  n_images = len(images)
  n_cols = max(1, -(-n_images // n_rows))  # ceil division
  # 20 x 20 inches for the 3 x 3 case, scaled with the grid size otherwise
  fig, axes = plt.subplots(n_rows, n_cols, figsize=(20 * n_cols / 3, 20 * n_rows / 3), squeeze=False)

  # axes.T.flat walks the grid column by column: k -> (row k % n_rows, column k // n_rows)
  for k, ax in enumerate(axes.T.flat):
    ax.axis("off")
    if k < n_images:
      ax.imshow(images[k], aspect = "auto")

  plt.subplots_adjust(wspace=0.05, hspace = .05)
  plt.savefig(out_file)
  plt.show()
  plt.close(fig)  # release the figure's memory
  

In [None]:
def show_images_for_model(G, dataloader, path):
  """Colorize the first batch of *dataloader* with generator *G* and plot it.

  The generator runs under torch.no_grad(); inference needs no autograd graph.
  The figures are saved by show_results using the *path* prefix.
  """
  Ls, real_abs = next(iter(dataloader))
  Ls = Ls.to(device)
  real_abs = real_abs.to(device)
  with torch.no_grad():
    fake_abs = G(Ls)
  show_results(Ls, real_abs, fake_abs, path)
  



Choose the model to test by changing the index into `Gs` in the next cell. `Gs[i]` is the generator loaded from `G_paths[i]`, so look at `G_paths` to select the proper index.

In [None]:
import random

# Index of the model to show: Gs[idx] was loaded from G_paths[idx] (valid range 0 .. len(Gs) - 1).
idx = 4

np.random.seed(123)
random.seed(10)
# The DataLoader's shuffle draws from torch's RNG, so seed it too for a reproducible batch.
torch.manual_seed(123)

# the saved images are named after the checkpoint, e.g. "cGAN_small_8_real.jpg"
show_images_for_model(Gs[idx], test_animals_dataloader, os.path.splitext(G_paths[idx])[0])

Now you can download the .jpg images.

# 5 Evaluate Generator - first metric

## 5.1 Load classifiers

Copy to Colab the .pt file of the classifier trained on color images. The same classifier (`C_c`) is applied to both the color and the grayscale images.

In [None]:
!cp -r /content/drive/MyDrive/TrainedNets/vgg16-color.pt /content/

In [None]:
from torchvision import models, transforms

# Classifier for color images: VGG16 whose last layer is replaced by a 12-way output (one per animal category).
# ImageNet weights are not downloaded, because the checkpoint below overwrites every parameter.
C_c = models.vgg16(pretrained=False)

C_c.classifier[6] = nn.Linear(in_features=4096, out_features=12)

# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only runtime;
# .to(device) then moves the weights where they are needed.
C_c.load_state_dict(torch.load("/content/vgg16-color.pt", map_location="cpu"))

C_c = C_c.to(device);


C_c.eval();

## 5.2 Dataset and Dataloader

In [None]:
import os
from sklearn.preprocessing import LabelEncoder
import pandas as pd

In [None]:
def build_test_dataset(test_path = "test_animals.txt"):
  """Build a DataFrame that describes the animal test split.

  Each line of *test_path* is an image path. Its class name is the 5th
  "/"-separated component, i.e. the folder right below
  "/content/animal_data/Animal Image Dataset/" for paths of that form.

  Returns a DataFrame with columns:
    - Images: image path
    - Labels: class name
    - encoded_labels: integer id from a LabelEncoder fitted on these labels
  NOTE(review): the encoder is fitted on the test labels only, so its ids match
  the classifier's only if every class appears in the test list — confirm.
  """
  with open(test_path) as file:
    image_paths = [line.rstrip() for line in file]

  data = pd.DataFrame({
      'Images': image_paths,
      'Labels': [image_path.split('/')[4] for image_path in image_paths],
  })
  data['encoded_labels'] = LabelEncoder().fit_transform(data['Labels'])
  return data



In [None]:
# DataFrame with the columns Images, Labels, encoded_labels for the test split
data = build_test_dataset("test_animals.txt")

In [None]:
# per-channel mean/std used to normalize every classifier input
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Transform used by the classifiers: PIL image -> normalized 224 x 224 tensor
trans_classifier = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Transform used to convert the result from the generator (already a tensor) to the input format for the classifier
trans_gan_to_classifier = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Transform for the GAN: bilinear resize to 256 x 256
trans_gan = transforms.Compose([
    transforms.Resize((256, 256), transforms.InterpolationMode.BILINEAR),
])

The following dataset provides all the inputs for the classifiers and for the generators: the color image, the grayscale image, the L channel and the label.

In [None]:
class Animals_Dataset(Dataset):
    """Animal images in the formats needed by the classifier and by the generators.

    Each item is (image, gray_image, L, label):
      - image: the RGB image, passed through tr1 when tr1 is given
      - gray_image: the grayscale image replicated to 3 channels, through tr1 when given
      - L: Lab lightness of tr2(image) (of the image itself when tr2 is None),
        scaled as L / 50 - 1, as a 1 x H x W tensor
      - label: tensor with the row's 'encoded_labels' value

    img_data is a DataFrame with 'Images' and 'encoded_labels' columns. Rows are
    looked up by position, so DataFrames with a filtered or shuffled index work too.
    """
    def __init__(self, img_data, tr1, tr2):
        self.tr1 = tr1  # transform for the classifier inputs (may be None)
        self.tr2 = tr2  # transform applied before the Lab conversion (may be None)
        self.img_data = img_data
        
    def __len__(self):
        return len(self.img_data)
    
    def __getitem__(self, index):
        img_name = self.img_data['Images'].iloc[index]

        #format for classifiers
        # the with-block closes the file as soon as the RGB copy exists
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        gray_image = image.convert('L').convert('RGB')

        #format for cGAN, WGAN
        img_rgb = self.tr2(image) if self.tr2 is not None else image
        img_np = np.array(img_rgb)
        img_lab = rgb2lab(img_np).astype("float32")
        img_lab = transforms.ToTensor()(img_lab)

        L = img_lab[[0],:] /50.-1.

        label = torch.tensor(self.img_data['encoded_labels'].iloc[index])
        
        if self.tr1 is not None:
            image = self.tr1(image)
            gray_image = self.tr1(gray_image)
        
        return image, gray_image, L, label

In [None]:
# classifier inputs via trans_classifier, generator input via trans_gan
test_dataset = Animals_Dataset(data, tr1=trans_classifier, tr2=trans_gan)

In [None]:
# fixed order, 16 images per batch
test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
)

## 5.3 Function for the evaluation

In [None]:
def evaluate_generator(G):
  """Evaluate generator G with the color classifier C_c on the animal test set.

  For every image of test_loader, C_c classifies three versions of it:
    - real: the original color image,
    - gray: the grayscale image replicated to 3 channels,
    - fake: the image colorized by G from its L channel.

  Printed metrics:
    - Weighting metric: mean probability C_c assigns to the true class. Each image
      is weighted by |p_real - p_gray| (its true-class probabilities on the real and
      gray versions), so images whose classification depends on color count more.
    - Accuracy metric: fraction of images whose top-1 prediction is the true class.

  Returns a dict with the weighted metrics ("weighted_*"), the unweighted mean
  true-class probabilities ("multinoulli_*") and the accuracies ("accuracy_*")
  for fake/real/gray. A metric whose denominator is 0 (empty test set, or all
  weights 0) is NaN.

  Relies on the globals C_c, test_loader, test_dataset, device,
  convert_lab_to_rgb and trans_gan_to_classifier.
  """
  def true_class_prob(probs, labels):
    # probability each row of probs assigns to its own label (same device as probs)
    return probs.gather(1, labels.unsqueeze(1)).squeeze(1)

  def ratio(num, den):
    return num / den if den else float("nan")

  keys = ("fake", "real", "gray")
  softmax = nn.Softmax(dim = 1)
  G.eval()

  # accumulated as Python numbers (float64) rather than float32 tensors
  weighted_sums = dict.fromkeys(keys, 0.0)
  prob_sums = dict.fromkeys(keys, 0.0)
  correct = dict.fromkeys(keys, 0)
  to_normalize = 0.0

  with torch.no_grad():
    for image, gray, L, labels in tqdm(test_loader):
      image = image.to(device)
      gray = gray.to(device)
      L = L.to(device)
      labels = labels.to(device)

      #FORWARD
      probs = {
          "real": softmax(C_c(image)),
          "gray": softmax(C_c(gray)),
      }

      ab_fake = G(L)

      # from GAN output to RGB (input classifier); cast to float32 in case lab2rgb
      # returns float64, which the float32 classifier weights would reject
      rgb_fake = torch.from_numpy(convert_lab_to_rgb(L, ab_fake)).permute(0,3,1,2).float()
      rgb_fake = trans_gan_to_classifier(rgb_fake).to(device)
      probs["fake"] = softmax(C_c(rgb_fake))

      true_probs = {key: true_class_prob(p, labels) for key, p in probs.items()}

      # per-image weight: how much the true-class probability changes when color is removed
      weights = torch.abs(true_probs["real"] - true_probs["gray"])
      to_normalize += torch.sum(weights).item()

      for key in keys:
        weighted_sums[key] += torch.sum(weights * true_probs[key]).item()
        prob_sums[key] += torch.sum(true_probs[key]).item()
        correct[key] += torch.sum(probs[key].argmax(dim = 1) == labels).item()

  n_images = len(test_dataset)
  results = {}
  for key in keys:
    results[f"weighted_{key}"] = ratio(weighted_sums[key], to_normalize)
    results[f"multinoulli_{key}"] = ratio(prob_sums[key], n_images)
    results[f"accuracy_{key}"] = ratio(correct[key], n_images)

  print(f"Weighting metric, fake: {results['weighted_fake']}")
  print(f"Weighting metric, real: {results['weighted_real']}") 
  print(f"Weighting metric, gray: {results['weighted_gray']}")
  
  print()

  print(f"Accuracy metric, fake: {results['accuracy_fake']}")
  print(f"Accuracy metric, real: {results['accuracy_real']}") 
  print(f"Accuracy metric, gray: {results['accuracy_gray']}") 

  return results

In [None]:
import warnings
#["WGAN_9k_120.pt", "cGAN_big.pt", "cGAN_small_16_100.pt", "cGAN_small_32_100.pt", "cGAN_small_8.pt" ]

In [None]:
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")  # silence library warnings during evaluation
    evaluate_generator(Gs[0])  # Gs[0] <- WGAN_9k_120.pt

In [None]:
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")  # silence library warnings during evaluation
    evaluate_generator(Gs[1])  # Gs[1] <- cGAN_big.pt

In [None]:
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")  # silence library warnings during evaluation
    evaluate_generator(Gs[2])  # Gs[2] <- cGAN_small_16_100.pt

In [None]:
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")  # silence library warnings during evaluation
    evaluate_generator(Gs[3])  # Gs[3] <- cGAN_small_32_100.pt

In [None]:
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")  # silence library warnings during evaluation
    evaluate_generator(Gs[4])  # Gs[4] <- cGAN_small_8.pt

# 6 Weight comparison - second metric

## 6.1 Dataset and dataloader

In [None]:
def build_dataset(path_noadd = "data_big_training.txt",
                  start_path = '/content/animal_data/Animal Image Dataset/'):
  """Build a DataFrame of the animal images that are NOT listed in *path_noadd*.

  Walks *start_path*, the root of the Kaggle animal dataset with one sub-folder
  per class. Every file whose path does not appear in the text file *path_noadd*
  (one path per line, e.g. the GAN training list) is kept, so the result is
  disjoint from that list.

  Returns a DataFrame with columns:
    - Images: image path (class folder path + "/" + file name)
    - Labels: class name = first folder below start_path
    - encoded_labels: integer class id, with classes numbered in sorted name order
      (the same numbering sklearn's LabelEncoder produces)
  Files lying directly in start_path belong to no class and are skipped.
  """
  with open(path_noadd) as file:
    # a set gives an O(1) membership test per walked file
    excluded = {line.rstrip() for line in file}

  images = []
  labels = []

  for folder, _, file_names in os.walk(start_path, topdown=True):
    relative = os.path.relpath(folder, start_path)
    if relative == os.curdir:
      continue  # root folder: its files have no class label
    label = relative.split(os.sep)[0]
    for file_name in file_names:
      path_file = os.path.join(folder, file_name)
      if path_file not in excluded:
        images.append(path_file)
        labels.append(label)

  data = pd.DataFrame({'Images': images, 'Labels': labels})

  # factorize(sort=True) numbers the sorted unique labels 0..K-1, exactly like LabelEncoder
  data['encoded_labels'] = pd.factorize(data['Labels'], sort=True)[0]

  return data

In [None]:
# animal images that are not in the GAN training list
new_data = build_dataset(
    path_noadd="data_big_training.txt",
)
print(len(new_data))

12978


In [None]:
# classifier inputs via trans_classifier, generator input via trans_gan
classifier_dataset = Animals_Dataset(new_data, tr1=trans_classifier, tr2=trans_gan)

In [None]:
# shuffled batches of 128 images for classifier training
classifier_loader = DataLoader(
    classifier_dataset,
    batch_size=128,
    shuffle=True,
)

## 6.2 Classifier training

In this second metric we train two different networks. One with the original colored images and one with the images colored by the GAN. If the colorization is good then the weights of the two trained models shouldn't be very different.

### 6.2.1 Train the first model

Train this model with the original colored images

In [None]:
from torchvision import models, transforms

# ImageNet-pretrained VGG16, fine-tuned here on the real color images;
# the last layer is replaced by a 12-way output (one per animal category).
# The layer is replaced before .to(device) so that the new layer is moved too.
classifier_model = models.vgg16(pretrained=True)
classifier_model.classifier[6] = nn.Linear(4096, 12)
classifier_model = classifier_model.to(device)

In [None]:
learning_rate = 0.005

# Loss and optimizer: cross-entropy, SGD with momentum and weight decay
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    classifier_model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=0.005,
)

total_step = len(classifier_loader)  # batches per epoch

In [None]:
n_epochs = 2
print_every = 25
valid_loss_min = np.inf  # np.Inf no longer exists in NumPy 2.0; np.inf works in all versions
val_loss = []
val_acc = []
train_loss = []
train_acc = []
total_step = len(classifier_loader)

# training mode must be set before the first epoch, not after each one
classifier_model.train()

for epoch in range(1, n_epochs+1):
    running_loss = 0.0
    correct = 0
    total = 0
    print(f'Epoch {epoch}\n')

    for i, (images, _, _, labels) in tqdm(enumerate(classifier_loader), total = len(classifier_loader)):

        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = classifier_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # print statistics
        running_loss += loss.item()
        _, pred = torch.max(outputs, dim=1)
        correct += torch.sum(pred == labels).item()
        total += labels.size(0)

        if i % print_every == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch, n_epochs, i, total_step, loss.item()))
            
    train_acc.append(100 * correct / total)
    train_loss.append(running_loss/total_step)
    # report this epoch's mean loss and accuracy
    print(f'\ntrain loss: {train_loss[-1]:.4f}, train acc: {train_acc[-1]:.4f}')

Save the trained network

In [None]:
# Written to the current working directory. Section 6.4 loads it from /content/, which
# assumes the working directory is /content (Colab's default — confirm).
torch.save(classifier_model.state_dict(), 'vgg16_real_images.pt')

### 6.2.2 Train the second model

Train this model with the images colored by the GAN

In [None]:
from torchvision import models, transforms

# ImageNet-pretrained VGG16, fine-tuned here on GAN-colorized images;
# the last layer is replaced by a 12-way output (one per animal category).
# The layer is replaced before .to(device) so that the new layer is moved too.
classifier_fake_images = models.vgg16(pretrained=True)
classifier_fake_images.classifier[6] = nn.Linear(4096, 12)
classifier_fake_images = classifier_fake_images.to(device)

In [None]:
learning_rate = 0.005

# Loss and optimizer for the classifier trained on GAN-colorized images
criterion = nn.CrossEntropyLoss()
optimizer_fake = torch.optim.SGD(
    classifier_fake_images.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=0.005,
)

total_step = len(classifier_loader)  # batches per epoch

In [None]:
def train_classifier_generator(G, epochs = 2, print_every = 25):
  """Train the global classifier_fake_images on images colorized by generator G.

  For each batch of classifier_loader, G predicts the ab channels from the L
  channel. The result is converted to RGB, resized/normalized with
  trans_gan_to_classifier, and used to train classifier_fake_images with
  optimizer_fake and criterion.

  G is used for inference only: it runs in eval mode, so batch-norm running
  statistics are not modified and dropout is off, and under torch.no_grad().
  Its previous train/eval mode is restored afterwards.

  Args:
    G: trained colorization generator.
    epochs: number of passes over classifier_loader.
    print_every: log the batch loss every print_every steps.

  Returns:
    (train_loss, train_acc): per-epoch mean loss and accuracy (in %).
  """
  train_loss = []
  train_acc = []
  total_step = len(classifier_loader)

  was_training = G.training
  G.eval()
  classifier_fake_images.train()

  try:
    for epoch in range(1, epochs + 1):
      running_loss = 0.0
      correct = 0
      total = 0
      print(f'Epoch {epoch}\n')

      for i, (_, _, L, labels) in tqdm(enumerate(classifier_loader), total = total_step):

        # Move tensors to the configured device
        labels = labels.to(device)
        L = L.to(device)

        # colorize with the frozen generator; no autograd graph is needed
        with torch.no_grad():
          ab = G(L)
          rgb_fake = torch.from_numpy(convert_lab_to_rgb(L, ab)).permute(0,3,1,2).float()
          rgb_fake = trans_gan_to_classifier(rgb_fake).to(device)

        optimizer_fake.zero_grad()

        outputs = classifier_fake_images(rgb_fake)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer_fake.step()
        
        # print statistics
        running_loss += loss.item()
        _, pred = torch.max(outputs, dim=1)
        correct += torch.sum(pred == labels).item()
        total += labels.size(0)

        if i % print_every == 0:
          print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                 .format(epoch, epochs, i, total_step, loss.item()))
            
      train_acc.append(100 * correct / total)
      train_loss.append(running_loss/total_step)
      # report this epoch's mean loss and accuracy
      print(f'\ntrain loss: {train_loss[-1]:.4f}, train acc: {train_acc[-1]:.4f}')
  finally:
    G.train(was_training)

  return train_loss, train_acc


In [None]:
import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore")  # silence library warnings during training
    train_classifier_generator(Gs[4])  # Gs[4] <- cGAN_small_8.pt

In [None]:
# The file name records which generator produced the training images
# (the training cell above used Gs[4], loaded from cGAN_small_8.pt);
# update it when training with a different generator.
torch.save(classifier_fake_images.state_dict(), 'classifier_cGAN_small_8.pt')

Copy the saved classifier to Google Drive:

In [None]:
!cp /content/classifier_cGAN_small_8.pt /content/drive/MyDrive/TrainedNets/Weight_Comparison

## 6.3 Weights comparison

Compute the L2 norm of the difference between the weights of the two models. A small value means the two classifiers learned similar weights, which suggests that the colorization is good.

In [None]:
def weights_distance(m_1, m_2) :
  """Return the L2 distance between the parameters of two models.

  All parameters of each model are treated as one long vector, and the
  Euclidean norm of the difference is returned:
      sqrt( sum over parameters p of sum((p_1 - p_2) ** 2) )
  The models must share the architecture: the same parameter names, in the
  same order, with the same shapes.

  Raises:
    ValueError: if the parameter names or shapes of the two models differ.
  """
  params_1 = list(m_1.named_parameters())
  params_2 = list(m_2.named_parameters())

  names_1 = [name for name, _ in params_1]
  names_2 = [name for name, _ in params_2]
  if names_1 != names_2:
    raise ValueError("the two models have different parameters; weights_distance needs identical architectures")

  squared_distance = 0.0
  # comparing weights needs no autograd graph
  with torch.no_grad():
    for (name, p_1), (_, p_2) in zip(params_1, params_2):
      # without this check, broadcasting could silently compare mismatched tensors
      if p_1.shape != p_2.shape:
        raise ValueError(f"shape mismatch for parameter {name}: {tuple(p_1.shape)} vs {tuple(p_2.shape)}")
      # accumulate in a Python float (float64) to limit rounding error
      squared_distance += torch.sum((p_1 - p_2.to(p_1.device)) ** 2).item()

  return squared_distance ** 0.5

The smaller the distance, the better it is.

In [None]:
# L2 distance between the parameters of the two classifiers (smaller = more similar)
distance = weights_distance(
    classifier_model,
    classifier_fake_images,
)
print(distance)

## 6.4 Import saved models to evaluate the weight distance

In [None]:
!cp /content/drive/MyDrive/TrainedNets/Weight_Comparison/classifier_cGAN_big.pt /content/

In [None]:
# Classifier trained on the real color images (saved as vgg16_real_images.pt).
# ImageNet weights are not downloaded, because the checkpoint overwrites every parameter.
m_2 = models.vgg16(pretrained=False)
m_2.classifier[6] = nn.Linear(in_features=4096, out_features=12)

# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only runtime
m_2.load_state_dict(torch.load("/content/vgg16_real_images.pt", map_location="cpu"))

m_2 = m_2.to(device)

In [None]:
# Classifier trained on GAN-colorized images. The file must already be in /content;
# the copy cell just above fetches classifier_cGAN_big.pt, so that is the file loaded here.
# Change both together to compare a different generator's classifier.
classifier_fake_path = "/content/classifier_cGAN_big.pt"

# ImageNet weights are not downloaded, because the checkpoint overwrites every parameter.
m_1 = models.vgg16(pretrained=False)
m_1.classifier[6] = nn.Linear(in_features=4096, out_features=12)

# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only runtime
m_1.load_state_dict(torch.load(classifier_fake_path, map_location="cpu"))

m_1 = m_1.to(device)

In [None]:
# L2 distance between the parameters of the two loaded classifiers (smaller = more similar)
distance = weights_distance(
    m_1,
    m_2,
)
print(distance)