# IMAGE COLORIZATION USING cGAN

**Contents**

1. [Import Packages](#packages)
2. [Utils](#utils)
3. [Data preparation](#dataset)
4. [Generator architecture](#generator)
5. [Discriminator architecture](#discriminator)
6. [Trainer](#training)
7. [Validation](#validation)

## Import Packages <a class="anchor" id="packages"></a>

In [None]:
import os
import numpy as np
import seaborn as sns
import random
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from torchsummary import summary

from PIL import Image

from typing import List

from tqdm.notebook import tqdm


from colour import sRGB_to_XYZ, XYZ_to_Lab, Lab_to_XYZ, XYZ_to_sRGB

import gc

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

## Utils <a class="anchor" id="utils"></a>

In [None]:
# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Report which backend the notebook will run on.
print(f"{device}")

In [None]:
def lab_to_rgb(L, ab, device):
    """
    Decode a batch of normalized Lab tensors back to sRGB.

    Inverts the encoding used by ImageColorizeDataset: ``L`` is rescaled by
    100 and ``ab`` by ``(ab - 0.5) * 256``. Each image is then converted
    Lab -> XYZ -> sRGB with the ``colour`` package.

    Args:
        L: lightness tensor, channels-first (N, 1, H, W).
        ab: chrominance tensor, channels-first (N, 2, H, W).
        device: device the returned tensor is moved to.

    Returns:
        Tensor of shape (N, 3, H, W) with sRGB values, on ``device``.
    """
    lab = torch.cat([L * 100, (ab - 0.5) * 256], dim=1)
    # colour works on channels-last numpy arrays, so move to (N, H, W, 3) on the CPU.
    lab_np = lab.permute(0, 2, 3, 1).detach().cpu().numpy()
    rgb = np.stack([XYZ_to_sRGB(Lab_to_XYZ(sample)) for sample in lab_np], axis=0)
    return torch.tensor(rgb).permute(0, 3, 1, 2).to(device)

In [None]:
# image_size = (384, 384)
image_size = (256, 128)

# Resize only. ImageColorizeDataset passes the resulting tensor to colour's
# sRGB -> XYZ -> Lab conversion, which expects sRGB values in [0, 1] (the
# range ToTensor produces). The former Normalize(mean=0, std=0.5) computed
# x / 0.5, which doubled every pixel to [0, 2]. That pushed Lab lightness
# above 100 and made decoded previews overexposed.
t = transforms.Compose([
    transforms.Resize(image_size, antialias=True),
])

In [None]:
class ImageColorizeDataset(Dataset):
    """Colour-image dataset that yields ``(L, ab)`` pairs in a normalized Lab encoding.

    Files are listed from ``<path>/train_color`` (``train=True``) or
    ``<path>/test_color`` (``train=False``). For every file:

    * only the top half of the image is kept;
    * the optional ``transforms`` run on the [0, 1] RGB tensor;
    * the result is converted sRGB -> XYZ -> Lab with ``colour``.

    Returned tensors:
        L:  shape (1, H, W), Lab lightness divided by 100.
        ab: shape (2, H, W), Lab a/b channels mapped as ``ab / 256 + 0.5``.

    Args:
        path: dataset root directory.
        device: device the returned tensors are moved to. ``None`` (default)
            selects ``'cuda'`` when available, otherwise ``'cpu'``.
        train: use the ``train_color`` split instead of ``test_color``.
        transforms: optional callable applied to the cropped RGB tensor.
    """

    # ToTensor is stateless, so one shared instance is enough. It is created at
    # class level because the ``transforms`` parameter of __init__ shadows the
    # torchvision module inside that method.
    _to_tensor = transforms.ToTensor()

    def __init__(self, path: str, device=None, train: bool = False, transforms = None):
        _mode = 'train' if train else 'test'

        # __getitem__ used to ignore this argument and move tensors to the
        # notebook-global `device`. Auto-detecting here keeps that effective
        # default, while an explicit argument is now honoured.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self._input_path = os.path.join(path, f'{_mode}_color')

        # Sorted so the index -> file mapping is stable across runs and platforms.
        self.data = sorted(os.listdir(self._input_path))

        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        item = self.data[idx]

        # The context manager closes the file handle. convert('RGB') guarantees
        # three channels for grayscale / RGBA / palette files, which the
        # 3-channel colour conversion below needs.
        with Image.open(os.path.join(self._input_path, item)) as src:
            rgb = src.convert('RGB')

        # Keep only the top half of each stored image.
        w, h = rgb.size
        rgb = rgb.crop((0, 0, w, h / 2))

        input_ = self._to_tensor(rgb)

        # Reseeds Python's and torch's global RNGs from NumPy's before the
        # transforms run (kept from the original). Presumably this is meant to
        # make random transforms reproducible per sample. TODO confirm it is
        # still needed.
        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)
        if self.transforms is not None:
            input_ = self.transforms(input_)

        # colour works on channels-last arrays: (3, H, W) -> (H, W, 3) -> Lab -> (3, H, W).
        img = input_.permute(1, 2, 0).numpy()
        img = sRGB_to_XYZ(img)
        img = XYZ_to_Lab(img).transpose(2, 0, 1).astype("float32")

        L = torch.tensor(img[[0], ...] / 100)
        ab = torch.tensor(img[[1, 2], ...] / 256 + 0.5)

        return L.to(self.device), ab.to(self.device)

In [None]:
# Held-out split (files under data/test_color), reshuffled every pass.
test_data = ImageColorizeDataset("data", train=False, transforms=t)
test_dl = DataLoader(dataset=test_data, shuffle=True, batch_size=8)

In [None]:
# Training split (files under data/train_color), reshuffled every epoch.
train_data = ImageColorizeDataset("data", train=True, transforms=t)
train_dl = DataLoader(dataset=train_data, shuffle=True, batch_size=8)

In [None]:
# Preview one training batch: the grayscale L inputs (broadcast to 3 channels)
# come first, followed by the same images decoded back to colour from Lab.
images, targets = next(iter(train_dl))
to_show = lab_to_rgb(images, targets, device=device)

batch_to_plot = torch.cat([images.expand(to_show.shape), to_show], dim=0)
grid = torchvision.utils.make_grid(batch_to_plot, nrow=8, padding=0, scale_each=True)

fig = plt.figure(figsize=(16, 8))
plt.imshow(grid.cpu().permute(1, 2, 0))
plt.axis('off')
plt.show()

## Generator <a class="anchor" id="generator"></a>

A UNet is used as the generator. It takes the L (lightness) channel of an image as input and predicts the two ab chrominance channels. To randomize the generator's output, dropout is meant to act as a noise source at both training and evaluation time.

In [None]:
class ConvBlock(nn.Module):
    """Residual block: two (3x3 conv -> norm -> ReLU) layers plus a conv shortcut.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        stride: stride of the first conv and of the shortcut conv, so the
            block downsamples by ``stride``.
        norm_layer: normalization constructor called with the channel count
            (default ``nn.BatchNorm2d``).
    """
    def __init__(self, in_channels, out_channels, stride=1, norm_layer = nn.BatchNorm2d):
        super().__init__()

        # `norm_layer` used to be accepted but ignored (BatchNorm2d was hard-coded).
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False),
            norm_layer(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(inplace=True)
        )

        # Shortcut projection so the residual matches the main path's channels and resolution.
        self.identity = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # The main path must consume `x` itself. The previous `x.detach().clone()`
        # cut the autograd graph, so layers upstream of this block received
        # gradient only through the shortcut conv.
        out = self.block(x) + self.identity(x)
        return self.relu(out)

In [None]:
class EncoderBlock(nn.Module):
    """Downsampling UNet stage: max-pool by ``sampling_factor``, then a residual ConvBlock.

    Args:
        in_chans: channels of the incoming feature map.
        out_chans: channels produced by the ConvBlock.
        sampling_factor: kernel size (and stride) of the max-pool.
    """
    def __init__(self, in_chans, out_chans, sampling_factor=2):
        super().__init__()
        pool = nn.MaxPool2d(sampling_factor)
        conv = ConvBlock(in_chans, out_chans)
        self.block = nn.Sequential(pool, conv)

    def forward(self, x):
        downsampled = self.block(x)
        return downsampled
    
class DecoderBlock(nn.Module):
    """Upsampling UNet stage: upsample, concatenate the encoder skip, then a ConvBlock.

    Args:
        in_chans: channels of the incoming (coarser) feature map.
        out_chans: channels of the skip connection and of the block's output.
        sampling_factor: spatial upsampling factor. It presumably should mirror
            the pooling factor of the matching EncoderBlock.
    """
    def __init__(self, in_chans, out_chans, sampling_factor=2):
        super().__init__()
        # `sampling_factor` used to be ignored (scale factor hard-coded to 2).
        self.upsample = nn.Upsample(scale_factor=sampling_factor, mode="bilinear", align_corners=True)

        # The conv sees [upsampled x, skip]: in_chans + out_chans channels
        # (this assumes the skip carries out_chans channels).
        self.block = ConvBlock(in_chans + out_chans, out_chans)

    def forward(self, x, skip):
        x = self.upsample(x)
        # MaxPool2d floors odd sizes, so upsampling may not restore the skip's
        # exact H x W. Resize to match so the channel concat does not fail.
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=True)
        x = torch.cat([x, skip], dim=1)
        return self.block(x)

In [None]:
class UNet(nn.Module):
    """UNet generator mapping ``in_channels`` input maps to ``out_channels`` maps in [0, 1].

    Four encoder stages (64 -> 128 -> 256 -> 512 channels, each after the
    first halving resolution via max-pool) feed three decoder stages with skip
    connections. A 1x1 conv plus sigmoid produces the output. The notebook
    instantiates it as 1 -> 2 channels (L in, ab out).

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the prediction.
        dropout_rate: channel-dropout probability applied after every encoder stage.
        dropout_at_eval: keep dropout active in ``eval()`` mode too, so it acts
            as the generator's noise source as the notebook describes. Set to
            False for standard dropout (identity in eval mode).
    """
    def __init__(self, in_channels=1, out_channels=2, dropout_rate=0.2, dropout_at_eval=True):
        super().__init__()

        self.encoder = nn.ModuleList([
            ConvBlock(in_channels, 64),
            EncoderBlock(64, 128),
            EncoderBlock(128, 256),
            EncoderBlock(256, 512),
        ])
        self.decoder = nn.ModuleList([
            DecoderBlock(512, 256),
            DecoderBlock(256, 128),
            DecoderBlock(128, 64)
        ])

        self.dropout = nn.Dropout2d(dropout_rate)
        self.dropout_at_eval = dropout_at_eval

        self.logits = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def _noise(self, x):
        """Apply channel dropout. It is forced on in eval mode when ``dropout_at_eval`` is set."""
        # nn.Dropout2d becomes the identity after .eval(). The Trainer runs G in
        # eval mode for the discriminator step and for plotting, so without this
        # the intended noise would disappear there.
        if self.dropout_at_eval:
            return F.dropout2d(x, p=self.dropout.p, training=True)
        return self.dropout(x)

    def forward(self, x):
        skips = []
        for enc in self.encoder:
            x = self._noise(enc(x))
            skips.append(x)

        # The deepest encoder output is `x` itself, not a skip connection.
        skips.pop()

        # Decoders consume the skips from deepest to shallowest.
        for dec in self.decoder:
            x = dec(x, skips.pop())
        return torch.sigmoid(self.logits(x))

In [None]:
# Run Python's garbage collector to release unreachable objects. As the
# cell's last expression, its return value (the number of unreachable
# objects found) is displayed.
gc.collect()

## Discriminator <a class="anchor" id="discriminator"></a>

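The discriminator is conditional: it receives the grayscale L channel concatenated with the (real or generated) ab channels. Its input shape is therefore `batch_size x 3 x H x W`, where `H x W` is `image_size` (`384 x 384` in an earlier configuration). With the default `n_layers = 3`, our PatchGAN applies 3 stride-2 `4 x 4` convolutions followed by two stride-1 `4 x 4` convolutions. The result is a map of per-patch real/fake logits rather than a single score.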

In [None]:
class PatchGAN(nn.Module):
    """Conditional PatchGAN discriminator.

    ``forward(x, label)`` concatenates the condition ``x`` with ``label``
    along the channel axis. It returns a single-channel map of raw
    (un-activated) per-patch scores.

    Args:
        in_channels: total channels of ``x`` and ``label`` combined.
        n_features: width of the first conv layer.
        n_layers: number of stride-2 convs. Channel width doubles per layer,
            capped at 8 * n_features.
    """
    def __init__(self, in_channels, n_features=64, n_layers=3):
        super().__init__()

        kernel, pad = 4, 2

        def width(level):
            # Channel multiplier for a given depth, capped at 8.
            return min(2 ** level, 8)

        layers = [
            nn.Conv2d(in_channels, n_features, kernel_size=kernel, padding=pad, stride=2),
            nn.LeakyReLU(0.2, True),
        ]

        # Stride-2 stages 1 .. n_layers-1 (stage 0 is the conv above).
        for level in range(1, n_layers):
            c_in = n_features * width(level - 1)
            c_out = n_features * width(level)
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=kernel, padding=pad, stride=2),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, True),
            ])

        # One stride-1 stage, then a stride-1 projection to one score channel.
        c_in = n_features * width(max(n_layers - 1, 0))
        c_out = n_features * width(n_layers)
        layers.extend([
            nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=1, padding=pad),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(c_out, 1, kernel_size=kernel, stride=1, padding=pad),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x, label):
        return self.model(torch.cat((x, label), dim=1))

In [None]:
# Discriminator input: 1 L channel + 2 ab channels.
D = PatchGAN(in_channels=3).to(device)

# torchsummary's `summary` defaults to device="cuda" and builds CUDA dummy
# inputs, which fails on CPU-only machines. Pass the active device explicitly.
summary(D, [(1, image_size[0], image_size[1]), (2, image_size[0], image_size[1])], device=device)

In [None]:
# Generator: L channel in, 2 ab channels out.
G = UNet(in_channels=1, out_channels=2).to(device)

# torchsummary's `summary` defaults to device="cuda" and builds CUDA dummy
# inputs, which fails on CPU-only machines. Pass the active device explicitly.
summary(G, (1, image_size[0], image_size[1]), device=device)

## Trainer <a class="anchor" id="training"></a>

In [None]:
class Trainer:
    """Alternating cGAN trainer for a colorization generator and a conditional discriminator.

    Every batch updates the discriminator ``D``. Every ``k``-th batch
    (``discriminator_to_generator_training_rate``) also updates the generator
    ``G``. The generator loss is the adversarial BCE term plus ``l1_weight``
    times the L1 distance between generated and real ab channels.

    Args:
        G: generator, called as ``G(L)`` and returning ab predictions.
        D: discriminator, called as ``D(L, ab)`` and returning logits.
        device: device batches are moved to.
        batch_size: stored only; batching is done by the dataloader.
        lr: Adam learning rate for both networks.
        discriminator_to_generator_training_rate: discriminator steps per generator step.
        plot_rate: plot per-epoch diagnostics every ``plot_rate`` epochs.
        l1_weight: weight of the L1 reconstruction term in the generator loss.
    """
    def __init__(self, G, D, device, batch_size = 64, lr=3e-4, discriminator_to_generator_training_rate = 2, plot_rate=1, l1_weight=1.0):

        self.G = G
        self.D = D

        self.L1_G_loss = nn.L1Loss()
        self.G_loss = nn.BCEWithLogitsLoss()
        self.D_loss = nn.BCEWithLogitsLoss()

        self.G_optim = torch.optim.Adam(params=G.parameters(), lr=lr)
        self.D_optim = torch.optim.Adam(params=D.parameters(), lr=lr)

        self.batch_size = batch_size

        # Epoch means of the losses.
        self.loss_G_per_epoch = []
        self.loss_D_per_epoch = []
        self.loss_D_real_per_epoch = []
        self.loss_D_fake_per_epoch = []

        # One list per epoch holding every per-step loss of that epoch.
        self.loss_G_history = []
        self.loss_D_history = []
        self.loss_D_real_history = []
        self.loss_D_fake_history = []

        self.k = discriminator_to_generator_training_rate
        self.device = device

        self.plot_rate = plot_rate
        self.l1_weight = l1_weight

    @staticmethod
    def _mean(values):
        """Mean of ``values``, or NaN for an empty list.

        An epoch shorter than ``k`` batches records no generator step; this
        avoids numpy's empty-slice warning in that case.
        """
        return np.mean(values) if values else float('nan')

    def train(self, dataloader, epoch):
        """Train for ``epoch`` epochs over ``dataloader`` (yielding ``(L, ab)`` batches).

        Per-step losses go to the ``*_history`` lists and their epoch means to
        the ``*_per_epoch`` lists. Diagnostics are plotted every ``plot_rate``
        epochs, and the per-epoch curves once training finishes.
        """
        for ep in range(epoch):
            for history in (self.loss_G_history, self.loss_D_history,
                            self.loss_D_real_history, self.loss_D_fake_history):
                history.append([])

            print(f'EPOCH: {ep + 1}')
            images = targets = None
            for i, (images, targets) in enumerate(tqdm(dataloader)):
                # No-op when the dataset already yields tensors on self.device.
                images = images.to(self.device)
                targets = targets.to(self.device)

                self._train_discriminator(images, targets)

                if (i + 1) % self.k == 0:
                    self._train_generator(images, targets)

            self.loss_G_per_epoch.append(self._mean(self.loss_G_history[-1]))
            self.loss_D_per_epoch.append(self._mean(self.loss_D_history[-1]))
            self.loss_D_real_per_epoch.append(self._mean(self.loss_D_real_history[-1]))
            self.loss_D_fake_per_epoch.append(self._mean(self.loss_D_fake_history[-1]))

            if (ep + 1) % self.plot_rate == 0:
                self._plot_epoch_stats(ep)
                # Guard against an empty dataloader (no batch to visualise).
                if images is not None:
                    self._plot_fake_images(images, targets)

        # This used to read `self._plot_stats` without parentheses (a no-op).
        if self.loss_G_per_epoch:
            self._plot_stats()

    def _train_generator(self, inputs, targets):
        """One generator step: fool D (adversarial BCE) while staying close to the real ab (L1)."""
        self.G_optim.zero_grad()

        self.G.train()
        self.D.eval()

        fake_targets = self.G(inputs)

        predictions = self.D(inputs, fake_targets)
        # The generator wants D to label its output as REAL, so the target is 1.
        # (It used to be 0, which trained G to make its samples look fake.)
        real_labels = torch.ones_like(predictions)

        L1_loss = self.L1_G_loss(fake_targets, targets)
        BCE_loss = self.G_loss(predictions, real_labels)
        loss_g = BCE_loss + self.l1_weight * L1_loss
        self.loss_G_history[-1].append(loss_g.item())
        # Gradients also reach D's parameters here. They are cleared by
        # D_optim.zero_grad() at the start of the next discriminator step.
        loss_g.backward()
        self.G_optim.step()

    def _train_discriminator(self, inputs, real_targets):
        """One discriminator step on a real batch (target 1) and a generated batch (target 0)."""
        self.D_optim.zero_grad()

        self.G.eval()
        self.D.train()

        # Real pairs.
        real_predictions = self.D(inputs, real_targets)
        real_label = torch.ones_like(real_predictions)
        real = self.D_loss(real_predictions, real_label)

        # Fake pairs. G is not updated here, so skip building its autograd graph.
        with torch.no_grad():
            fake_targets = self.G(inputs)
        fake_predictions = self.D(inputs, fake_targets)
        fake_label = torch.zeros_like(fake_predictions)
        fake = self.D_loss(fake_predictions, fake_label)

        loss_d = real + fake

        self.loss_D_history[-1].append(loss_d.item())
        self.loss_D_real_history[-1].append(real.item())
        self.loss_D_fake_history[-1].append(fake.item())

        loss_d.backward()
        self.D_optim.step()

    @torch.no_grad()
    def _plot_fake_images(self, images, targets, nrow=8):
        """
        Show the generator's results for one batch.

        Tiles appear in this order, one group per batch:
        grayscale input, generated a channel, generated b channel,
        generated colour image, real colour image.
        """
        # no_grad: the generated a/b channels would otherwise require grad, and
        # matplotlib's numpy conversion of the grid raises on such tensors.
        self.G.eval()

        fake = self.G(images)

        to_show = lab_to_rgb(images, fake, device=self.device)
        to_show_real = lab_to_rgb(images, targets, device=self.device)
        a = fake[:, 0, :, :].unsqueeze(1).expand(to_show.shape)
        b = fake[:, 1, :, :].unsqueeze(1).expand(to_show.shape)
        tiles = torch.cat([images.expand(to_show.shape), a, b, to_show, to_show_real], dim=0)
        # `nrow` used to be accepted but ignored (8 was hard-coded).
        grid = torchvision.utils.make_grid(tiles, nrow=nrow, padding=1, scale_each=True)

        plt.figure(figsize=(16,8))
        plt.imshow(grid.cpu().permute(1,2,0))
        plt.axis('off')
        plt.show()

    def _plot_stats(self):
        """
        Plot the per-epoch mean losses (overall D, G, D on real, D on fake).
        """
        fig, axes = plt.subplots(2, 2, figsize=(10, 4))
        sns.lineplot(self.loss_D_per_epoch, label="discriminator", ax=axes[0][0])
        sns.lineplot(self.loss_G_per_epoch, label="generator", ax=axes[0][1])

        sns.lineplot(self.loss_D_real_per_epoch, label="real", ax=axes[1][0])
        sns.lineplot(self.loss_D_fake_per_epoch, label="fake", ax=axes[1][1])

        plt.tight_layout()
        plt.show()

    def _plot_epoch_stats(self, epoch):
        """
        Plot the per-step losses recorded during epoch index ``epoch``.
        """
        fig, axes = plt.subplots(2, 2, figsize=(10, 4))
        sns.lineplot(self.loss_D_history[epoch], label="discriminator", ax=axes[0][0])
        sns.lineplot(self.loss_G_history[epoch], label="generator", ax=axes[0][1])

        sns.lineplot(self.loss_D_real_history[epoch], label="real", ax=axes[1][0])
        sns.lineplot(self.loss_D_fake_history[epoch], label="fake", ax=axes[1][1])

        plt.tight_layout()
        plt.show()

    def _plot_losses(self):
        """
        Plot the generator and discriminator losses per epoch.
        """
        epochs = range(1, len(self.loss_G_per_epoch) + 1)

        plt.figure(figsize=(8, 5))
        plt.plot(epochs, self.loss_G_per_epoch, label='Generator loss', color='blue')
        plt.plot(epochs, self.loss_D_per_epoch, label='Discriminator loss', color='red')
        plt.title('Loss vs. Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

In [None]:
# Build the trainer with default hyper-parameters.
trainer = Trainer(G=G, D=D, device=device)

In [None]:
# Run 10 epochs over the training split.
trainer.train(dataloader=train_dl, epoch=10)

In [None]:
# Plot the per-epoch mean generator and discriminator losses recorded by trainer.train.
trainer._plot_losses()

In [None]:
# Display the trained generator; as the cell's last expression, its module repr is rendered.
trainer.G