In [28]:
import warnings
# Silences every warning for the rest of the session.
# NOTE(review): this also hides deprecation/runtime warnings (e.g. from imshow
# clipping or torch); narrow the filter when debugging.
warnings.filterwarnings('ignore')

import torch
# NOTE(review): the bare `torchvision` and `autograd` names below are not
# referenced in the cells shown; transforms are used via the `T` alias.
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import torch.autograd as autograd

import cv2
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from tqdm.autonotebook import tqdm

from sklearn.model_selection import train_test_split

# Optional: `torchsummary` provides `summary(model, input_size)` for layer-by-layer
# model printouts. It is NOT imported here, so `summary` is undefined unless the
# two lines below are uncommented (and the package installed).
#!pip install -q torchsummary
#from torchsummary import summary

In [2]:
# --- Reproducibility: one seed shared by torch and numpy's global RNG ---
seed = 42
torch.random.manual_seed(seed)
np.random.seed(seed)

# --- Data / training hyper-parameters ---
BATCH_SIZE = 8
IMAGE_HEIGHT, IMAGE_WIDTH = 224, 224
epochs = 100
mse_loss_lambda = 100  # weighting factor used when combining generator loss terms

# Where training checkpoints are written.
PATH = 'model.pth'

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Currently using {device.upper()} device.')

In [3]:
path = r'../input/image-colorization-dataset/data/'

# Pair grayscale and color images by sorted path order: the i-th gray file is
# taken to correspond to the i-th color file.
# NOTE(review): relies on both folders using matching file names — confirm for the dataset.
grays = sorted(glob(path + 'train_black' + '/*.jpg'))
colored = sorted(glob(path + 'train_color' + '/*.jpg'))

# Fail early with a clear message. An empty or ragged listing would otherwise
# surface much later as an opaque error (e.g. in train_test_split).
assert grays, f'No training images found in {path}train_black'
assert len(grays) == len(colored), (
    f'Mismatched training pairs: {len(grays)} gray vs {len(colored)} color images'
)

df_train = pd.DataFrame(data={'gray': grays, 'color': colored})

In [4]:
# Same pairing rule as for training: sorted gray paths line up with sorted color paths.
# NOTE(review): relies on both folders using matching file names — confirm for the dataset.
grays = sorted(glob(path + 'test_black' + '/*.jpg'))
colored = sorted(glob(path + 'test_color' + '/*.jpg'))

# Fail early instead of building an empty or misaligned test frame.
assert grays, f'No test images found in {path}test_black'
assert len(grays) == len(colored), (
    f'Mismatched test pairs: {len(grays)} gray vs {len(colored)} color images'
)

test = pd.DataFrame(data={'gray': grays, 'color': colored})

In [5]:
# Hold out a quarter of the training pairs for validation (seeded, shuffled).
train, valid = train_test_split(
    df_train,
    test_size=0.25,
    random_state=seed,
    shuffle=True,
)
print(f'Train size: {len(train)}, valid size: {len(valid)}, test size: {len(test)}.')

In [6]:
# Per-image preprocessing. Inputs are HxWxC uint8 arrays decoded by OpenCV.
# Outputs are CxHxW float tensors in [-1, 1]: ToTensor gives [0, 1], then
# Normalize(0.5, 0.5) maps x -> (x - 0.5) / 0.5.
# torchvision's Resize expects (height, width); the previous code passed (width, height).
train_transforms = T.Compose([
    T.ToPILImage(),
    T.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    # NOTE(review): the flip draws a fresh random coin on every call. The gray and
    # color images of one pair must share that draw or the pair becomes
    # misaligned — verify the caller applies this jointly per pair.
    T.RandomHorizontalFlip(p=0.1),
    T.ToTensor(),
    T.Normalize(0.5, 0.5),
])
# Deterministic variant (no augmentation) for validation and test.
valid_transforms = T.Compose([
    T.ToPILImage(),
    T.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    T.ToTensor(),
    T.Normalize(0.5, 0.5),
])

def denormalize(image_tensor):
    """Undo Normalize(0.5, 0.5): map values from [-1, 1] back to [0, 1].

    Works on anything supporting ``+`` and ``/`` (tensors, arrays, scalars).
    """
    shifted = image_tensor + 1
    return shifted / 2.0

In [7]:
class ColorDataset(Dataset):
    """Paired (grayscale, color) image dataset backed by a DataFrame of file paths.

    Args:
        df: DataFrame with ``'gray'`` and ``'color'`` columns holding image paths.
        transforms: callable applied to each decoded image inside ``collate_fn``;
            it must return a tensor (results are stacked along a new batch dim).

    ``__getitem__`` returns raw decoded RGB arrays. Transforms are applied in
    ``collate_fn``, which also moves the batch to the module-level ``device``.
    """

    def __init__(self, df, transforms):
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_rgb(path):
        """Decode ``path`` with OpenCV and convert BGR -> RGB; raise if unreadable."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising,
            # which would otherwise surface as an opaque cvtColor error.
            raise FileNotFoundError(f'Could not read image: {path}')
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, ix):
        row = self.df.iloc[ix]
        return self._read_rgb(row['gray']), self._read_rgb(row['color'])

    def _paired_transform(self, gray, color):
        """Transform one pair so both images see the same random draws.

        Random augmentations (e.g. RandomHorizontalFlip) would otherwise flip the
        gray and color images independently and misalign the pair. Restoring
        torch's CPU RNG state before the second call replays the same draws.
        NOTE(review): assumes the transforms draw randomness from torch's CPU
        generator (true for current torchvision transforms).
        """
        rng_state = torch.get_rng_state()
        gray_t = self.transforms(gray)
        torch.set_rng_state(rng_state)
        color_t = self.transforms(color)
        return gray_t, color_t

    def collate_fn(self, batch):
        """Transform a list of (gray, color) pairs into two batched tensors on ``device``."""
        pairs = [self._paired_transform(gray, color) for gray, color in batch]
        grays = torch.stack([g for g, _ in pairs]).to(device)
        colored = torch.stack([c for _, c in pairs]).to(device)
        return grays, colored

In [8]:
train_dataset = ColorDataset(train, train_transforms)
valid_dataset = ColorDataset(valid, valid_transforms)
test_dataset = ColorDataset(test, valid_transforms)

# collate_fn moves every batch to `device`, so loading stays in the main process
# (default num_workers=0); creating CUDA tensors in forked workers is not supported.
# Training drops the last partial batch to keep batch statistics stable.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=train_dataset.collate_fn, drop_last=True)
# Evaluation loaders keep the final partial batch so every sample is scored.
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False,
                              collate_fn=valid_dataset.collate_fn, drop_last=False)
# Smaller test batches; max(1, ...) guards against BATCH_SIZE < 4 yielding 0.
test_dataloader = DataLoader(test_dataset, batch_size=max(1, BATCH_SIZE // 4), shuffle=False,
                             collate_fn=test_dataset.collate_fn, drop_last=False)

In [9]:
class Identity(nn.Module):
    """No-op module: returns its input object unchanged.

    Used as a placeholder where a layer slot must exist but should do nothing.
    """

    def forward(self, x):
        return x

class DownConv(nn.Module):
    """UNet encoder stage: optional 2x max-pool, then two conv-BN-LeakyReLU layers.

    Args:
        ni: input channels.
        no: output channels (both convs produce ``no``).
        maxpool: if True, halve the spatial size first; otherwise a no-op slot
            keeps the layer indices (and state_dict keys) identical.
    """

    def __init__(self, ni, no, maxpool=True):
        super().__init__()
        layers = [nn.MaxPool2d(2) if maxpool else Identity()]
        for in_channels in (ni, no):
            layers += [
                nn.Conv2d(in_channels, no, 3, padding=1),
                nn.BatchNorm2d(no),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
    
class UpConv(nn.Module):
    """UNet decoder stage: 2x transposed-conv upsample, concat skip, two conv blocks.

    Args:
        ni: channels of the incoming (deeper) feature map.
        no: output channels; the skip tensor must also have ``no`` channels.
        maxpool: unused; kept for signature compatibility.
    """

    def __init__(self, ni, no, maxpool=True):
        super().__init__()
        self.convtranspose = nn.ConvTranspose2d(ni, no, 2, stride=2)
        stages = []
        # First conv consumes upsampled + skip channels (no + no), second keeps `no`.
        for in_channels in (no + no, no):
            stages.extend([
                nn.Conv2d(in_channels, no, 3, padding=1),
                nn.BatchNorm2d(no),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.convlayers = nn.Sequential(*stages)

    def forward(self, x, y):
        upsampled = self.convtranspose(x)
        merged = torch.cat([upsampled, y], dim=1)
        return self.convlayers(merged)
    
class UNet(nn.Module):
    """Five-level UNet mapping a 3-channel image to a 3-channel image.

    The encoder (d1..d5) goes 3 -> 64 -> 128 -> 256 -> 512 -> 1024 channels.
    d2..d5 each halve the spatial size; d1 pools only when ``maxpool=True``.
    The decoder (u5..u2) upsamples and fuses the matching encoder skips, and u1
    is a plain 1x1 conv, so the output is unbounded (no final activation).
    Spatial input size should be divisible by 16 (32 if ``maxpool=True``) so the
    skip shapes line up.
    """

    def __init__(self, maxpool=False):
        super().__init__()
        # Construction order matches the original so parameter init draws are identical.
        self.d1 = DownConv(3, 64, maxpool=maxpool)
        self.d2 = DownConv(64, 128)
        self.d3 = DownConv(128, 256)
        self.d4 = DownConv(256, 512)
        self.d5 = DownConv(512, 1024)
        self.u5 = UpConv(1024, 512)
        self.u4 = UpConv(512, 256)
        self.u3 = UpConv(256, 128)
        self.u2 = UpConv(128, 64)
        self.u1 = nn.Conv2d(64, 3, kernel_size=1, stride=1)

    def forward(self, x):
        skips = []
        h = x
        for encoder in (self.d1, self.d2, self.d3, self.d4):
            h = encoder(h)
            skips.append(h)
        h = self.d5(h)
        # Deepest decoder pairs with the deepest skip.
        for decoder, skip in zip((self.u5, self.u4, self.u3, self.u2), reversed(skips)):
            h = decoder(h, skip)
        return self.u1(h)

In [35]:
# The torchsummary import in the first cell is commented out, so `summary` is
# undefined here. Report the parameter count and output shape with plain torch instead.
_unet_probe = UNet(maxpool=False)
with torch.no_grad():
    _probe_out = _unet_probe.eval()(torch.zeros(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH))
print(f'UNet: {sum(p.numel() for p in _unet_probe.parameters()):,} parameters, '
      f'output shape {tuple(_probe_out.shape)}')
del _unet_probe, _probe_out

In [61]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a map of per-patch probabilities.

    Four stride-2 4x4 convs (3 -> 64 -> 128 -> 256 -> 512 channels). Every conv
    after the first is followed by InstanceNorm, and all four by LeakyReLU(0.2).
    A final 4x4 conv to 1 channel and a Sigmoid follow, so outputs are in (0, 1).
    """

    def __init__(self):
        super().__init__()
        widths = [3, 64, 128, 256, 512]
        layers = []
        for stage, (cin, cout) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(cin, cout, 4, 2, 1, bias=False))
            if stage > 0:
                layers.append(nn.InstanceNorm2d(cout))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers += [nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.discriminator = nn.Sequential(*layers)

    def forward(self, input):
        return self.discriminator(input)

In [62]:
# The torchsummary import in the first cell is commented out, so `summary` is
# undefined here. Report the parameter count and output (patch-map) shape with plain torch.
_disc_probe = Discriminator()
with torch.no_grad():
    _probe_out = _disc_probe.eval()(torch.zeros(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH))
print(f'Discriminator: {sum(p.numel() for p in _disc_probe.parameters()):,} parameters, '
      f'output shape {tuple(_probe_out.shape)}')
del _disc_probe, _probe_out

In [63]:
# Pixel-wise content loss between the generator output and the color target.
criterion_content_loss = nn.MSELoss()

def gan_loss(model_type, **kwargs):
    """Binary cross-entropy GAN loss for the generator or the discriminator.

    Discriminator outputs must be probabilities in [0, 1] (BCE, not BCE-with-logits).

    Args:
        model_type: ``'G'`` for the generator loss, ``'D'`` for the discriminator loss.
        recon_discriminator_out: discriminator output on generated images (both modes).
        color_discriminator_out: discriminator output on real color images ('D' only).

    Returns:
        'G': BCE pushing fake outputs toward 1.
        'D': mean of BCE(real -> 1) and BCE(fake -> 0).

    Raises:
        ValueError: if ``model_type`` is neither 'G' nor 'D' (previously this
            silently returned None).
    """
    if model_type not in ('G', 'D'):
        raise ValueError(f"model_type must be 'G' or 'D', got {model_type!r}")

    bce = nn.functional.binary_cross_entropy
    if model_type == 'G':
        fake_out = kwargs['recon_discriminator_out']
        return bce(fake_out, torch.ones_like(fake_out))

    real_out = kwargs['color_discriminator_out']
    fake_out = kwargs['recon_discriminator_out']
    real_loss = bce(real_out, torch.ones_like(real_out))
    fake_loss = bce(fake_out, torch.zeros_like(fake_out))
    return (real_loss + fake_loss) / 2.0
    
def init_weights(m):
    """DCGAN-style initialisation, meant for ``module.apply(init_weights)``.

    - BatchNorm2d: weight ~ N(1, 0.02), bias = 0.
    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias = 0 when present.
    - Any other module is left untouched.
    """
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1, 0.02)
        nn.init.zeros_(m.bias)
        return
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return
    nn.init.normal_(m.weight, 0, 0.02)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

In [64]:
generator = UNet().apply(init_weights).to(device)
discriminator = Discriminator().apply(init_weights).to(device)

optimizerG = torch.optim.AdamW(generator.parameters(), lr=1e-4, betas=(0.5, 0.999), amsgrad=True, weight_decay=1e-6)
optimizerD = torch.optim.AdamW(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999), amsgrad=True, weight_decay=1e-6)


def lr_lambda(epoch, hold=20, decay=100):
    """LR multiplier: 1.0 for the first ``hold`` epochs, then linear decay over ``decay`` epochs.

    Floored at 0 so a longer run (epoch > hold + decay) never yields a negative
    learning rate, which would turn optimisation into gradient ascent.
    """
    if epoch <= hold:
        return 1.0
    return max(0.0, 1 - (epoch - hold) / decay)


schedulerG = torch.optim.lr_scheduler.LambdaLR(optimizerG, lr_lambda=lr_lambda)
schedulerD = torch.optim.lr_scheduler.LambdaLR(optimizerD, lr_lambda=lr_lambda)

In [17]:
def train_one_batch(generator, discriminator, data, criterion_content_loss, optimizerG, optimizerD):
    """Run one adversarial training step: update the discriminator, then the generator.

    Args:
        generator: model mapping a gray batch to a reconstructed color batch.
        discriminator: model whose outputs are fed to ``gan_loss`` (BCE on probabilities).
        data: ``(gray, colored)`` tensor pair.
        criterion_content_loss: pixel-wise loss between reconstruction and target.
        optimizerG: optimizer over the generator's parameters.
        optimizerD: optimizer over the discriminator's parameters.

    Returns:
        ``(d_loss, g_loss)`` as Python floats.
    """
    generator.train()
    discriminator.train()

    gray, colored = data
    recon = generator(gray)

    # --- Discriminator step ---
    # Detach the reconstruction so the D loss does not backpropagate into the
    # generator. No retained graph or second generator forward pass is needed.
    optimizerD.zero_grad()
    d_loss = gan_loss(
        'D',
        color_discriminator_out=discriminator(colored),
        recon_discriminator_out=discriminator(recon.detach()),
    )
    d_loss.backward()
    # Clipping must happen between backward() and step() to affect the update.
    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 10.0)
    optimizerD.step()

    # --- Generator step ---
    # Re-score the same reconstruction with the freshly updated discriminator.
    # (Gradients this leaves on D's parameters are cleared by the next D zero_grad.)
    optimizerG.zero_grad()
    g_loss_gan = gan_loss('G', recon_discriminator_out=discriminator(recon))
    content_loss_g = criterion_content_loss(recon, colored)
    # `mse_loss_lambda` weights the content (MSE) term, as its name says:
    # the usual adversarial + lambda * reconstruction objective.
    g_loss = g_loss_gan + mse_loss_lambda * content_loss_g
    g_loss.backward()
    torch.nn.utils.clip_grad_norm_(generator.parameters(), 10.0)
    optimizerG.step()

    return d_loss.item(), g_loss.item()

@torch.no_grad()
def validate(generator, discriminator, data, criterion_content_loss):
    """Score one batch without updating either model.

    Both models are switched to eval mode. The generator loss weights the content
    (MSE) term by ``mse_loss_lambda``, matching that constant's name.

    Args:
        generator: model mapping a gray batch to a reconstructed color batch.
        discriminator: model whose outputs are fed to ``gan_loss``.
        data: ``(gray, colored)`` tensor pair.
        criterion_content_loss: pixel-wise loss between reconstruction and target.

    Returns:
        ``(d_loss, g_loss)`` as Python floats.
    """
    generator.eval()
    discriminator.eval()

    gray, colored = data
    recon = generator(gray)
    color_discriminator_out = discriminator(colored)
    recon_discriminator_out = discriminator(recon)

    d_loss = gan_loss(
        'D',
        color_discriminator_out=color_discriminator_out,
        recon_discriminator_out=recon_discriminator_out,
    )

    g_loss_gan = gan_loss('G', recon_discriminator_out=recon_discriminator_out)
    content_loss_g = criterion_content_loss(recon, colored)
    g_loss = g_loss_gan + mse_loss_lambda * content_loss_g

    return d_loss.item(), g_loss.item()

@torch.no_grad()
def visual_validate(data, model):
    """Plot one random sample from a batch: grayscale input, color target, model output.

    Args:
        data: ``(gray, colored)`` tensor pair with values normalized to [-1, 1].
        model: generator; it is left in eval mode afterwards.
    """
    img, tar = data
    model.eval()
    out = model(img)

    # Sample from the actual batch size. randint's upper bound is exclusive, so
    # the old `randint(0, BATCH_SIZE-1)` never chose the last sample and could
    # index past smaller batches (e.g. the test loader's).
    i = np.random.randint(len(img))

    panels = []
    for tensor in (img[i], tar[i], out[i]):
        # Back to [0, 1]; clamp because model output is not guaranteed to lie in
        # [-1, 1] (e.g. a final layer with no activation).
        tensor = denormalize(tensor).clamp(0, 1)
        panels.append(tensor.cpu().numpy().transpose(1, 2, 0))

    fig, axes = plt.subplots(1, 3, figsize=(12, 14))
    for ax, title, panel in zip(axes, ('Black', 'Colored', 'Black Colored'), panels):
        ax.set_title(title)
        ax.imshow(panel)
    plt.show()

In [None]:
train_d_losses, train_g_losses, valid_d_losses, valid_g_losses = [], [], [], []


def _mean_batch_losses(step, loader):
    """Apply ``step`` to every batch of ``loader``; return the mean (D loss, G loss)."""
    d_values, g_values = [], []
    for batch in tqdm(loader, leave=False):
        d_value, g_value = step(batch)
        d_values.append(d_value)
        g_values.append(g_value)
    return np.array(d_values).mean(), np.array(g_values).mean()


for epoch in range(epochs):
    print(f'Epoch {epoch + 1}/{epochs}')

    # Training pass.
    epoch_d_loss, epoch_g_loss = _mean_batch_losses(
        lambda batch: train_one_batch(generator, discriminator, batch, criterion_content_loss, optimizerG, optimizerD),
        train_dataloader,
    )
    train_d_losses.append(epoch_d_loss)
    train_g_losses.append(epoch_g_loss)
    print(f'Train D loss: {epoch_d_loss:.4f}, train G loss: {epoch_g_loss:.4f}')

    # Validation pass (no parameter updates).
    epoch_d_loss, epoch_g_loss = _mean_batch_losses(
        lambda batch: validate(generator, discriminator, batch, criterion_content_loss),
        valid_dataloader,
    )
    valid_g_losses.append(epoch_g_loss)
    valid_d_losses.append(epoch_d_loss)
    print(f'Validation D loss: {epoch_d_loss:.4f}, validation G loss: {epoch_g_loss:.4f}')
    print('-' * 50)

    # Per-epoch LR schedule, stepped after the optimizers.
    schedulerD.step()
    schedulerG.step()

    # Every second epoch: checkpoint, then show a sample colorization.
    if (epoch + 1) % 2 == 0:
        torch.save({
            'epoch': epoch,
            'G': generator,
            'D': discriminator,
            'optimizerG': optimizerG,
            'optimizerD': optimizerD,
        }, PATH)
        data = next(iter(valid_dataloader))
        visual_validate(data, generator)

In [25]:
# Visual spot-check of the current generator on the first validation batch.
valid_iter = iter(valid_dataloader)
data = next(valid_iter)
visual_validate(data, generator)