In [7]:
# pip install fastai==2.4

In [None]:
# Execution environment: 'local', 'colab', or 'kaggle'
runtime = 'kaggle'
# First epoch to run; any value other than 1 resumes from the saved checkpoint
starting_epoch = 1
# True selects the upsampled generator architecture
upsample = False
# Extra generator loss: 'l1' selects L1, any other value selects the perceptual loss
loss_type = 'l1'

# True runs the generator pre-training stage before adversarial training
pretrain_generator = False
# True loads already pre-trained generator weights from disk
load_pretrained_generator = False
# Samples per training batch (lower it if memory runs out)
batch_size = 16
# Index of the last adversarial-training epoch
n_epochs = 50
# Weight (lambda) multiplied into the L1 loss term of the generator
lmbda = 100.

In [8]:
if runtime == 'colab':

    # Mount google-drive
    from google.colab import drive
    drive.mount('/content/gdrive')

    # Copy necessary files to the current environment
    !cp gdrive/MyDrive/Colorization/utils.py .
    !cp gdrive/MyDrive/Colorization/generator.py .
    !cp gdrive/MyDrive/Colorization/discriminator.py .
    !cp gdrive/MyDrive/Colorization/evaluation_metrics.py .
    !cp gdrive/MyDrive/Colorization/network_Lab_VGG.py .
    !cp gdrive/MyDrive/Colorization/perceptual_loss.py .
    !cp gdrive/MyDrive/Colorization/residual_upsampled_generator.py .
    !cp gdrive/MyDrive/Colorization/LabVGG16_BN_epoch120_batchsize32.pth .

    # Extract the dataset in the current environment
    !unzip gdrive/MyDrive/Colorization/data.zip
    
if runtime == 'kaggle':
    import sys
    sys.path.insert(1, '../input/colorization/Colorization')

In [9]:
import os
import time
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

from generator import UNet
from discriminator import PatchGAN
from utils import load_transformed_batch, load_rgb_batch, init_weights, lab_to_rgb
from evaluation_metrics import mean_absolute_error, epsilon_accuracy, peak_signal_to_noise_ratio

import torch
from torch import optim
from torchvision import transforms

from residual_upsampled_generator import UNetUpsampled
from perceptual_loss import PerceptualLoss

import warnings
warnings.filterwarnings("ignore")

In [10]:
# Pick the compute device: the first CUDA GPU when one is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)

# Results and checkpoints go to the mounted Google Drive on Colab,
# and to the current working directory everywhere else
if runtime == 'colab':
    res_dir = os.path.join(os.getcwd(), 'gdrive', 'MyDrive', 'Colorization', 'results')
    model_dir = os.path.join(os.getcwd(), 'gdrive', 'MyDrive', 'Colorization', 'models')
else:
    res_dir = os.path.join(os.getcwd(), 'results')
    model_dir = os.path.join(os.getcwd(), 'models')

# exist_ok=True makes re-running this cell safe and avoids the check-then-create race
os.makedirs(res_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Start a fresh log (header row only) for a new run; a resumed run keeps appending
# to the existing file
if starting_epoch == 1:
    header = 'epoch,generator-adversarial-loss,perceptual-or-l1-loss,generator-loss-total,discriminator-loss,mae,epsilon,psnr'
    with open(os.path.join(res_dir, 'logs.csv'), 'w') as f:
        f.write(header + '\n')

# Root directory for data
if runtime == 'kaggle':
    data_root = '../input/colorization/Colorization/data/data'
else:
    data_root = os.path.join(os.getcwd(), 'data')

# Dataset splits and the file names found in each of them
train_dir = os.path.join(data_root, 'train')
test_dir = os.path.join(data_root, 'test')
vis_dir = os.path.join(data_root, 'visualize')
train_files = os.listdir(train_dir)
test_files = os.listdir(test_dir)
vis_files = os.listdir(vis_dir)

In [11]:
# Preview a 4x4 grid of training images drawn at random
random_files = np.random.choice(train_files, size=16)
random_samples = []
for name in random_files:
    random_samples.append(os.path.join(train_dir, name))

_, axes = plt.subplots(4, 4, figsize=(10, 10))
flat_axes = axes.flatten()
for idx in range(min(len(flat_axes), len(random_samples))):
    # Show the picture without axis ticks or frame
    flat_axes[idx].imshow(Image.open(random_samples[idx]))
    flat_axes[idx].axis("off")

In [14]:
# Training-time preprocessing: resize to the model input size, then random
# horizontal flips as data augmentation
train_transforms = transforms.Compose([transforms.Resize((256, 256), Image.BICUBIC),
                                       transforms.RandomHorizontalFlip()])

# Preprocessing of the RGB targets used with the upsampled architecture
upsample_transforms = transforms.Compose([transforms.Resize((512, 512), Image.BICUBIC),
                                          transforms.ToTensor()])

# Evaluation-time preprocessing: resize only, no augmentation
val_transforms = transforms.Compose([transforms.Resize((256, 256), Image.BICUBIC)])

# Create generator object and initialize weights
if upsample:
    # UNetUpsampled is the class this notebook imports from
    # residual_upsampled_generator; the previously used name `UResNet` is not
    # defined anywhere in the notebook and raised NameError.
    # NOTE(review): assumes UNetUpsampled accepts the same keyword arguments - confirm.
    generator = UNetUpsampled(in_channels=1, out_channels=3, n_filters=64)
else:
    generator = UNet(in_channels=1, out_channels=2, n_filters=64)
generator = init_weights(generator)
generator.to(device)

# Create discriminator object and initialize weights
discriminator = PatchGAN(in_channels=3)
discriminator = init_weights(discriminator)
discriminator.to(device)

# When resuming, restore both networks from the checkpoints written by the training loop
if starting_epoch != 1:
    generator.load_state_dict(torch.load(os.path.join(model_dir, 'generator.pth'), map_location=device))
    discriminator.load_state_dict(torch.load(os.path.join(model_dir, 'discriminator.pth'), map_location=device))

# One Adam optimizer per network, with identical hyper-parameters
generator_optimizer = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Adversarial loss on raw discriminator logits, plus the extra generator loss
# selected by `loss_type`
criterion = torch.nn.BCEWithLogitsLoss()
if loss_type == 'l1':
    additional_criterion = torch.nn.L1Loss()
else:
    additional_criterion = PerceptualLoss(color='rgb')

# Scalar targets for real and fake predictions (expanded to the prediction shape when used)
real_label = torch.tensor(1.0)
fake_label = torch.tensor(0.0)

# Number of full batches per epoch; a trailing partial batch is dropped
n_batches = len(train_files) // batch_size

In [17]:
def _build_resnet_unet():
    """Return a new fastai DynamicUnet generator with a ResNet-18 body, moved to `device`.

    The body takes 1 input channel and the network produces 2 output channels
    for 256x256 images. fastai is imported here so the notebook still runs
    without it when neither pre-training flag is set.
    """
    from fastai.vision.learner import create_body
    from torchvision.models.resnet import resnet18
    from fastai.vision.models.unet import DynamicUnet

    body = create_body(resnet18, pretrained=True, n_in=1, cut=-2)
    return DynamicUnet(body, n_out=2, img_size=(256, 256)).to(device)


# Replace the generator with a DynamicUnet carrying previously pre-trained weights
if load_pretrained_generator:
    generator = _build_resnet_unet()
    generator.load_state_dict(torch.load('../input/colorization/generator_pretrained.pth', map_location=device))

# Pre-training the generator for 20 epochs
if pretrain_generator:

    print('Pre-training the Generator')

    pretrain_epochs = 20

    # Pre-training always starts from a fresh network (as before, this discards
    # any weights loaded just above)
    generator = _build_resnet_unet()

    pretrain_generator_optimizer = optim.Adam(generator.parameters(), lr=1e-4)

    # Pre-training regresses the 2-channel ab output with plain L1. The configurable
    # `additional_criterion` can be the perceptual loss, which the adversarial loop
    # only ever calls on 3-channel images, so it is not used here.
    pretrain_criterion = torch.nn.L1Loss()

    for epoch in range(1, pretrain_epochs+1):

        print('Epoch {}/{}'.format(epoch, pretrain_epochs))
        print('-' * 10)

        running_generator_loss_l1 = 0.0

        generator.train()

        # Iterate over all the batches
        for j in tqdm(range(n_batches), desc='Batch'):

            # Get the train data and labels for the current batch
            batch_files = train_files[j*batch_size:(j+1)*batch_size]
            L, ab = load_transformed_batch(train_dir, batch_files, train_transforms)

            # Put the data to the device
            L, ab = L.to(device), ab.to(device)

            with torch.set_grad_enabled(True):

                # Run the batch through the generator
                output = generator(L)

                # Calculate the loss
                generator_loss_l1 = pretrain_criterion(output, ab)

                # Make gradients zero
                pretrain_generator_optimizer.zero_grad()

                # backward + optimize
                generator_loss_l1.backward()
                pretrain_generator_optimizer.step()

            running_generator_loss_l1 += generator_loss_l1.item() * batch_size

        # Calculate average loss for current epoch
        epoch_generator_loss_l1 = running_generator_loss_l1 / (n_batches*batch_size)

        print('Generator Loss: {:.4f}'.format(epoch_generator_loss_l1))

        # Checkpoint the pre-trained generator after every epoch
        torch.save(generator.state_dict(), os.path.join(model_dir, 'generator_pretrained.pth'))

# The generator object was replaced above, so the optimizer created earlier still
# points at the parameters of the discarded network. Rebuild it (same
# hyper-parameters) so adversarial training actually updates the new generator.
if load_pretrained_generator or pretrain_generator:
    generator_optimizer = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [None]:
# Adversarial Training
# Each epoch: alternate discriminator / generator updates over all full batches,
# print and log the epoch averages, checkpoint both networks, and save
# colorized versions of the visualization images.
for epoch in range(starting_epoch, n_epochs+1):

    # Variable to record time taken in each epoch
    since = time.time()

    print('Epoch {}/{}'.format(epoch, n_epochs))
    print('-' * 10)

    running_generator_loss_adversarial = 0.0
    running_generator_loss_additional = 0.0
    running_generator_loss_total = 0.0
    running_discriminator_loss_total = 0.0
    running_mae = 0.0
    running_epsilon = 0.0
    running_psnr = 0.0

    # Iterate over all the batches
    for j in tqdm(range(n_batches), desc='Batch'):

        # Get the train data and labels for the current batch
        batch_files = train_files[j*batch_size:(j+1)*batch_size]
        L, ab = load_transformed_batch(train_dir, batch_files, train_transforms)

        # Put the data to the device
        L, ab = L.to(device), ab.to(device)

        # The upsampled architecture is compared against RGB targets
        if upsample:
            rgb_images = load_rgb_batch(train_dir, batch_files, upsample_transforms)
            rgb_images = rgb_images.to(device)

        # Create a fake color image using the generator
        fake_color = generator(L)

        # ---- Discriminator update ----
        discriminator.train()

        # Enable grads for the discriminator parameters
        for p in discriminator.parameters():
            p.requires_grad = True

        # Make gradients zero before forward pass
        discriminator_optimizer.zero_grad()

        # Run fake examples through the discriminator; detach() keeps this update
        # from back-propagating into the generator
        if upsample:
            fake_image = fake_color
        else:
            fake_image = torch.cat([L, fake_color], dim=1)  # concatenate along the channel dimension
        fake_preds = discriminator(fake_image.detach())
        discriminator_loss_fake = criterion(fake_preds, fake_label.expand_as(fake_preds).to(device))

        # Run real examples through the discriminator
        if upsample:
            real_image = rgb_images
        else:
            real_image = torch.cat([L, ab], dim=1)  # concatenate along the channel dimension
        real_preds = discriminator(real_image)
        discriminator_loss_real = criterion(real_preds, real_label.expand_as(real_preds).to(device))

        # Total loss is the mean of both the losses
        discriminator_loss_total = (discriminator_loss_fake + discriminator_loss_real) * 0.5

        # backward + optimize
        discriminator_loss_total.backward()
        discriminator_optimizer.step()

        # ---- Generator update (discriminator weights held constant) ----
        generator.train()

        # Disable grads for the discriminator parameters
        for p in discriminator.parameters():
            p.requires_grad = False

        # Make gradients zero before forward pass
        generator_optimizer.zero_grad()

        # Score the (non-detached) fake image so gradients reach the generator
        fake_preds = discriminator(fake_image)

        # Adversarial loss for the generator: fake predictions should look real.
        # The target is shaped after fake_preds, the tensor it is compared with.
        generator_loss_adversarial = criterion(fake_preds, real_label.expand_as(fake_preds).to(device))

        # Extra loss for the generator: lambda * L1 on the generator output,
        # or the perceptual loss on the full images
        if loss_type == 'l1':
            generator_loss_additional = additional_criterion(fake_color, ab) * lmbda
        else:
            generator_loss_additional = additional_criterion(fake_image, real_image)

        # Total loss is the sum of both the losses
        generator_loss_total = generator_loss_adversarial + generator_loss_additional

        # backward + optimize
        generator_loss_total.backward()
        generator_optimizer.step()

        # Add up the losses for current batch (weighted by batch size)
        running_generator_loss_adversarial += generator_loss_adversarial.item() * batch_size
        running_generator_loss_additional += generator_loss_additional.item() * batch_size
        running_generator_loss_total += generator_loss_total.item() * batch_size
        running_discriminator_loss_total += discriminator_loss_total.item() * batch_size

        # Build the image pairs for the metrics; detached so the metrics never
        # hold on to the autograd graph
        if upsample:
            real_image_array = rgb_images
            fake_image_array = fake_image.detach()
        else:
            real_image_array = torch.from_numpy(lab_to_rgb(L.detach(), ab.detach()))
            fake_image_array = torch.from_numpy(lab_to_rgb(L.detach(), fake_color.detach()))

        running_mae += mean_absolute_error(real_image_array, fake_image_array) * batch_size
        running_epsilon += epsilon_accuracy(real_image_array, fake_image_array, epsilon=0.05) * batch_size  # epsilon set at 5% of 255
        running_psnr += peak_signal_to_noise_ratio(real_image_array, fake_image_array) * batch_size

    # Calculate the average metrics and average losses for current epoch
    n_samples_seen = n_batches*batch_size
    epoch_generator_loss_adversarial = running_generator_loss_adversarial / n_samples_seen
    epoch_generator_loss_additional = running_generator_loss_additional / n_samples_seen
    epoch_generator_loss_total = running_generator_loss_total / n_samples_seen
    epoch_discriminator_loss_total = running_discriminator_loss_total / n_samples_seen
    epoch_mae = running_mae / n_samples_seen
    epoch_epsilon = running_epsilon / n_samples_seen
    epoch_psnr = running_psnr / n_samples_seen

    print('Generator Loss Adversarial: {:.4f}'.format(epoch_generator_loss_adversarial))
    print('Generator Loss L1/Perceptual: {:.4f}'.format(epoch_generator_loss_additional))
    print('Generator Loss Total: {:.4f}'.format(epoch_generator_loss_total))
    print('Discriminator Loss Total: {:.4f}'.format(epoch_discriminator_loss_total))
    print('Mean Absolute Error: {:.4f}'.format(epoch_mae))
    print('Epsilon Accuracy: {:.4f}'.format(epoch_epsilon))
    print('Peak SNR: {:.4f}'.format(epoch_psnr))

    time_elapsed = time.time() - since
    print('Epoch complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    # Save the generator and discriminator model (overwritten every epoch)
    torch.save(generator.state_dict(), os.path.join(model_dir, 'generator.pth'))
    torch.save(discriminator.state_dict(), os.path.join(model_dir, 'discriminator.pth'))

    # Append this epoch's losses and metrics to the log file
    log = [[epoch,
            epoch_generator_loss_adversarial,
            epoch_generator_loss_additional,
            epoch_generator_loss_total,
            epoch_discriminator_loss_total,
            epoch_mae,
            epoch_epsilon,
            epoch_psnr]]
    with open(os.path.join(res_dir, 'logs.csv'), 'a') as f:
        np.savetxt(f, log, delimiter=',')

    # Colorize the visualization images with the current generator
    generator.eval()

    with torch.no_grad():

        # Transform the images and get their L and ab channels
        L, ab = load_transformed_batch(vis_dir, vis_files, val_transforms)
        L, ab = L.to(device), ab.to(device)

        if upsample:
            # Run the L channel through the generator to get 'rgb' results;
            # .cpu() is required because .numpy() does not work on CUDA tensors
            res_images = generator(L).permute(0, 2, 3, 1).detach().cpu().numpy()
        else:
            # Run the L channel through the generator to get 'ab' channels
            res_images = lab_to_rgb(L, generator(L))

        # Create directory for saving visualizations of images for the current epoch
        vis_result_dir = os.path.join(res_dir, 'epoch '+str(epoch))
        os.makedirs(vis_result_dir, exist_ok=True)

        # Save output images for this epoch, scaled by 255 and stored as 8-bit
        for i in range(len(res_images)):
            image = res_images[i] * 255
            Image.fromarray(image.astype(np.uint8)).save(os.path.join(vis_result_dir, vis_files[i]))

    # Back to training mode for the next epoch
    generator.train()