In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

import cv2
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from tqdm.autonotebook import tqdm
from skimage.color import rgb2lab, lab2rgb

from sklearn.model_selection import train_test_split

!pip install -q torchsummary
from torchsummary import summary

Try:

* batch_size = 8
* lr_G = 1e-4
* lr_D = 1e-4
* Dropout_D?

In [2]:
# Fix the random seeds first so that the cells below are repeatable.
seed = 42
np.random.seed(seed)
torch.random.manual_seed(seed)

# Image geometry and training hyper-parameters.
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
BATCH_SIZE = 8  # initially 4
epochs = 100
loss_lambda = 100  # multiplier applied to the generator's L1 content loss

# File the training loop writes its checkpoint to.
PATH = 'model.pth'

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Currently using {device.upper()} device.')

In [3]:
# Root folder of the colorization dataset.
path = r'../input/image-colorization-dataset/data/'

# Gather the grayscale and the colour training files. Both lists are sorted so
# that rows are paired by position.
# NOTE(review): this assumes the two folders hold identically named files -- confirm.
grays = sorted(str(name) for name in glob(path + 'train_black' + '/*.jpg'))
colored = sorted(str(name) for name in glob(path + 'train_color' + '/*.jpg'))

# One row per image: grayscale path next to the colour path.
df_train = pd.DataFrame(data={'gray': grays, 'color': colored})

In [4]:
# Same positional pairing as for the training split, applied to the test folders.
grays = sorted(str(name) for name in glob(path + 'test_black' + '/*.jpg'))
colored = sorted(str(name) for name in glob(path + 'test_color' + '/*.jpg'))

# One row per test image: grayscale path next to the colour path.
test = pd.DataFrame(data={'gray': grays, 'color': colored})

In [5]:
# Hold out a quarter of the training frame for validation (seeded, shuffled split).
train, valid = train_test_split(
    df_train,
    test_size=0.25,
    random_state=seed,
    shuffle=True,
)
print(f'Train size: {len(train)}, valid size: {len(valid)}, test size: {len(test)}.')

In [6]:
# Pipeline for training images. It operates on the RGB array (via a PIL image)
# before the dataset converts the result to Lab.
train_transforms = T.Compose([
    T.ToPILImage(),
    # Resize expects the target size as (height, width).
    T.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    T.RandomHorizontalFlip(p=0.1),
])
# Applied by the dataset to the float32 Lab array to obtain a tensor.
final_transforms = T.Compose([
    T.ToTensor(),
])
# Deterministic pipeline for validation / test images: resize only, no augmentation.
valid_transforms = T.Compose([
    T.ToPILImage(),
    # Resize expects the target size as (height, width).
    T.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
])

In [7]:
class ColorDataset(Dataset):
    """Dataset that yields normalized Lab channels for image colorization.

    Every item is derived from the colour image alone: the file is read with
    OpenCV, converted to RGB, passed through ``transforms`` and converted to
    Lab. The 'gray' column of the frame is not read by this class.

    Args:
        df: DataFrame with a 'color' column holding image file paths.
        transforms: callable applied to the RGB array; its result must be
            convertible with ``np.array`` (e.g. a PIL image).
    """

    def __init__(self, df, transforms):
        self.df = df
        self.transforms = transforms

    def __len__(self):
        """Return the number of rows (images) in the frame."""
        return len(self.df)

    def __getitem__(self, ix):
        """Load image ``ix`` and return its ``(L, ab)`` tensors.

        Returns:
            L: the lightness channel (1 channel), rescaled as ``L / 50 - 1``.
            ab: the two chroma channels, rescaled as ``ab / 110``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        row = self.df.iloc[ix].squeeze()
        color_image = cv2.imread(row['color'])
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside cvtColor.
        if color_image is None:
            raise FileNotFoundError(f"Could not read image file: {row['color']}")
        color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
        color_image = self.transforms(color_image)
        color_image = np.array(color_image)
        img_lab = rgb2lab(color_image).astype("float32")
        # Channel-first tensor; index 0 is L, indices 1 and 2 are a and b.
        img_lab = final_transforms(img_lab)
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.
        return L, ab

    def collate_fn(self, batch):
        """Stack per-sample ``(L, ab)`` pairs into two batch tensors on ``device``."""
        L, ab = zip(*batch)
        return torch.stack(L).to(device), torch.stack(ab).to(device)

In [32]:
# Only the training set is augmented; validation and test share the resize-only pipeline.
train_dataset = ColorDataset(train, train_transforms)
valid_dataset = ColorDataset(valid, valid_transforms)
test_dataset = ColorDataset(test, valid_transforms)

# Training: shuffle and drop the trailing partial batch so every optimisation
# step sees exactly BATCH_SIZE samples.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=train_dataset.collate_fn, drop_last=True)
# Evaluation: keep the trailing partial batch so that no validation / test
# sample is silently excluded from the reported losses.
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False,
                              collate_fn=valid_dataset.collate_fn, drop_last=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             collate_fn=test_dataset.collate_fn, drop_last=False)

In [13]:
# initially batchnorm2d
class UnetBlock(nn.Module):
    """One nesting level of the U-Net generator.

    A block downsamples its input with a strided convolution, runs the nested
    ``submodule`` (if any), upsamples with a transposed convolution and, unless
    it is the outermost block, concatenates its own input with the result
    along the channel dimension (the skip connection).

    Args:
        nf: channels produced by the up-convolution (and, when ``input_c`` is
            None, channels expected at the input).
        ni: channels produced by the down-convolution.
        submodule: the nested block placed between the down and up paths;
            required for every block that is not ``innermost``.
        input_c: input channels; defaults to ``nf``.
        dropout: append ``Dropout(0.5)`` to the up path (intermediate blocks only).
        innermost: build the bottleneck block (no submodule, no down-norm).
        outermost: build the top block (no norms, biased up-convolution, Tanh
            output, no skip connection). Takes precedence over ``innermost``.
        instance: use ``InstanceNorm2d`` instead of ``BatchNorm2d``.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False, instance=False):
        super().__init__()
        self.outermost = outermost
        in_channels = nf if input_c is None else input_c
        norm_layer = nn.InstanceNorm2d if instance else nn.BatchNorm2d

        # Every variant starts from the same stride-2 encoder convolution.
        encoder_conv = nn.Conv2d(in_channels, ni, kernel_size=4,
                                 stride=2, padding=1, bias=False)

        if outermost:
            # The submodule returns ni * 2 channels because of its skip connection.
            decoder_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                              stride=2, padding=1)
            layers = [encoder_conv, submodule,
                      nn.ReLU(True), decoder_conv, nn.Tanh()]
        elif innermost:
            # Bottleneck: nothing nested, so the decoder sees only ni channels.
            decoder_conv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                              stride=2, padding=1, bias=False)
            layers = [nn.LeakyReLU(0.2, True), encoder_conv,
                      nn.ReLU(True), decoder_conv, norm_layer(nf)]
        else:
            decoder_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                              stride=2, padding=1, bias=False)
            layers = [nn.LeakyReLU(0.2, True), encoder_conv, norm_layer(ni),
                      submodule,
                      nn.ReLU(True), decoder_conv, norm_layer(nf)]
            if dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block; non-outermost blocks return ``cat([x, model(x)], 1)``."""
        out = self.model(x)
        if self.outermost:
            return out
        return torch.cat([x, out], 1)

class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` modules.

    The network is built from the inside out: one innermost block, then
    ``n_down - 5`` full-width blocks with dropout, then three blocks that each
    halve the channel width, and finally the outermost block.

    Args:
        input_c: channels of the network input.
        output_c: channels of the network output.
        n_down: controls the depth through the number of full-width
            intermediate blocks (``n_down - 5``).
        num_filters: base channel width; the widest blocks use ``num_filters * 8``.
        instance: forwarded to every ``UnetBlock`` (InstanceNorm instead of BatchNorm).
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64, instance=False):
        super().__init__()
        widest = num_filters * 8

        # Bottleneck first; every later block wraps the one built before it.
        block = UnetBlock(widest, widest, innermost=True, instance=instance)

        # Full-width intermediate blocks (these carry dropout).
        for _ in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block,
                              dropout=True, instance=instance)

        # Three blocks that halve the channel width on the way out.
        width = widest
        for _ in range(3):
            half = width // 2
            block = UnetBlock(half, width, submodule=block, instance=instance)
            width = half

        self.model = UnetBlock(output_c, width, input_c=input_c, submodule=block,
                               outermost=True, instance=instance)

    def forward(self, x):
        """Apply the nested blocks to ``x``."""
        return self.model(x)

In [38]:
# initially no dropout and batchnorm2d
# instance norm normalizes every sample (and channel) with its own statistics instead of batch statistics, so it can work better when the samples in a batch have different distributions
# Conv > Normalization > Activation > Dropout > Pooling
class Discriminator(nn.Module):
    """Convolutional discriminator for 3-channel inputs.

    Four stride-2 convolutions (64, 128, 256, 512 filters) are followed by a
    stride-1, unpadded convolution down to one channel and a sigmoid, so the
    output holds probabilities. BatchNorm is used after the 2nd-4th
    convolutions and ``Dropout(0.2)`` after the 2nd and 4th stages.

    Open questions kept from the original notes: spectral or instance
    normalization instead of BatchNorm? More / less dropout?
    """

    def __init__(self):
        super().__init__()

        def conv(in_ch, out_ch, stride=2, padding=1):
            # All convolutions share kernel size 4 and have no bias.
            return nn.Conv2d(in_ch, out_ch, 4, stride, padding, bias=False)

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        layers = [conv(3, 64), leaky()]
        layers += [conv(64, 128), nn.BatchNorm2d(128), leaky(), nn.Dropout(0.2)]
        layers += [conv(128, 256), nn.BatchNorm2d(256), leaky()]
        layers += [conv(256, 512), nn.BatchNorm2d(512), leaky(), nn.Dropout(0.2)]
        layers += [conv(512, 1, stride=1, padding=0), nn.Sigmoid()]
        self.discriminator = nn.Sequential(*layers)

    def forward(self, input):
        """Return the sigmoid scores for ``input``."""
        return self.discriminator(input)

In [37]:
# Content loss of the generator: L1 distance between predicted and true ab channels.
criterion_content_loss = nn.L1Loss()  # TODO: compare against nn.MSELoss

def gan_loss(model_type, **kwargs):
    """Binary cross-entropy adversarial loss for the generator or the discriminator.

    Args:
        model_type: ``'G'`` for the generator loss or ``'D'`` for the
            discriminator loss.
        **kwargs: discriminator outputs (probabilities):
            ``recon_discriminator_out`` -- scores of generated images
            (needed for both types);
            ``color_discriminator_out`` -- scores of real images (``'D'`` only).

    Returns:
        A scalar tensor. For ``'G'``: BCE of the fake scores against ones.
        For ``'D'``: the mean of BCE(real scores, ones) and BCE(fake scores, zeros).

    Raises:
        ValueError: if ``model_type`` is neither ``'G'`` nor ``'D'``.
        KeyError: if a required keyword argument is missing.
    """
    bce = nn.functional.binary_cross_entropy

    if model_type == 'G':
        recon_discriminator_out = kwargs['recon_discriminator_out']
        # The generator wants its images to be scored as real (target 1).
        return bce(recon_discriminator_out, torch.ones_like(recon_discriminator_out))

    if model_type == 'D':
        color_discriminator_out = kwargs['color_discriminator_out']
        recon_discriminator_out = kwargs['recon_discriminator_out']
        real_loss = bce(color_discriminator_out, torch.ones_like(color_discriminator_out))
        fake_loss = bce(recon_discriminator_out, torch.zeros_like(recon_discriminator_out))
        return (real_loss + fake_loss) / 2.0

    # Previously an unknown type fell through and silently returned None.
    raise ValueError(f"model_type must be 'G' or 'D', got {model_type!r}")
    
def init_weights(m):
    """Initialise one module in place; meant to be used with ``Module.apply``.

    Convolutions and transposed convolutions get N(0, 0.02) weights and a zero
    bias (when they have one); BatchNorm2d gets N(1, 0.02) weights and a zero
    bias. Every other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1, std=0.02)
        nn.init.zeros_(m.bias)

In [39]:
# Build both networks, run init_weights over every sub-module and move them to the selected device.
generator = Unet(input_c=1, output_c=2, n_down=8, num_filters=64, instance=False).apply(init_weights).to(device)
discriminator = Discriminator().apply(init_weights).to(device)

# initially lr_G = 1e-4, lr_D = 2e-4, lr_decay_G = lr_decay_D = 1e-6
# the bigger batch_size the bigger learning rate
# Both networks currently share the same AdamW settings (lr=5e-5, betas=(0.5, 0.999), AMSGrad).
optimizerG = torch.optim.AdamW(generator.parameters(), lr=5e-5, betas=(0.5, 0.999), amsgrad=True, weight_decay=1.5e-6)
optimizerD = torch.optim.AdamW(discriminator.parameters(), lr=5e-5, betas=(0.5, 0.999), amsgrad=True, weight_decay=1.5e-6)

# Halve each learning rate every 20 scheduler steps; the training loop steps the schedulers once per epoch.
schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG, step_size=20, gamma=0.5)
schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=20, gamma=0.5)

In [40]:
# Layer-by-layer summary of the generator for one L-channel image of the configured size.
summary(generator, (1, IMAGE_HEIGHT, IMAGE_WIDTH))

In [41]:
# Layer-by-layer summary of the discriminator for one 3-channel (L + ab) image of the configured size.
summary(discriminator, (3, IMAGE_HEIGHT, IMAGE_WIDTH))

In [17]:
def grad_req(model, is_required=True):
    """Enable or disable gradient tracking for every parameter of ``model``.

    Args:
        model: module whose parameters are updated in place.
        is_required: value assigned to ``requires_grad`` of each parameter.
            ``False`` freezes the model (used for the discriminator while the
            generator is being updated).
    """
    for param in model.parameters():
        # Honour the flag; it used to be ignored and gradients were always enabled.
        param.requires_grad = is_required

In [18]:
def train_one_batch(generator, discriminator, data, criterion_content_loss, optimizerG, optimizerD,
                    max_grad_norm=10.0):
    """Run one discriminator update followed by one generator update.

    Args:
        generator: network mapping the L channel to the ab channels.
        discriminator: network scoring the channel-wise concatenation of L and ab.
        data: ``(L, ab)`` batch.
        criterion_content_loss: loss between generated and true ab channels;
            it is scaled by the global ``loss_lambda``.
        optimizerG: optimizer of the generator.
        optimizerD: optimizer of the discriminator.
        max_grad_norm: gradient-norm clipping threshold applied to each
            network right before its optimizer step.

    Returns:
        ``(d_loss, g_loss)`` as Python floats.
    """
    generator.train()
    discriminator.train()

    L, ab = data
    fake_color = generator(L)
    fake_image = torch.cat([L, fake_color], dim=1)
    real_image = torch.cat([L, ab], dim=1)

    # ---- Discriminator step -------------------------------------------------
    grad_req(discriminator, True)
    optimizerD.zero_grad()
    # detach(): the discriminator loss must not back-propagate into the generator.
    fake_preds = discriminator(fake_image.detach())
    real_preds = discriminator(real_image)
    d_loss = gan_loss('D',
                      color_discriminator_out=real_preds,
                      recon_discriminator_out=fake_preds)
    d_loss.backward()
    # Clipping has to happen between backward() and step(); clipping after the
    # step (as before) leaves the applied update unclipped.
    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_grad_norm)
    optimizerD.step()

    # ---- Generator step -----------------------------------------------------
    grad_req(discriminator, False)
    optimizerG.zero_grad()
    # No detach here: gradients flow through the discriminator into the generator.
    fake_preds = discriminator(fake_image)
    g_loss_gan = gan_loss('G', recon_discriminator_out=fake_preds)
    loss_G_L1 = criterion_content_loss(fake_color, ab) * loss_lambda
    g_loss = g_loss_gan + loss_G_L1
    g_loss.backward()
    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_grad_norm)
    optimizerG.step()

    # Do not leave the discriminator frozen as a side effect of this call.
    grad_req(discriminator, True)

    return d_loss.item(), g_loss.item()

@torch.no_grad()
def validate(generator, discriminator, data, criterion_content_loss):
    """Compute the discriminator and generator losses for one batch without training.

    Both networks are switched to eval mode and no gradients are recorded.

    Args:
        generator: network mapping the L channel to the ab channels.
        discriminator: network scoring the concatenation of L and ab.
        data: ``(L, ab)`` batch.
        criterion_content_loss: loss between generated and true ab channels;
            it is scaled by the global ``loss_lambda``.

    Returns:
        ``(d_loss, g_loss)`` as Python floats.
    """
    for network in (generator, discriminator):
        network.eval()

    L, ab = data
    fake_color = generator(L)

    # Score the real pair first, then the generated one.
    real_scores = discriminator(torch.cat([L, ab], dim=1))
    fake_scores = discriminator(torch.cat([L, fake_color], dim=1))

    d_loss = gan_loss('D',
                      color_discriminator_out=real_scores,
                      recon_discriminator_out=fake_scores)

    adversarial_term = gan_loss('G', recon_discriminator_out=fake_scores)
    content_term = criterion_content_loss(fake_color, ab) * loss_lambda
    g_loss = adversarial_term + content_term

    return d_loss.item(), g_loss.item()

@torch.no_grad()
def visual_validate(data, generator):
    """Plot one random sample of a batch: L channel, ground truth and generator output.

    Args:
        data: ``(L, ab)`` batch in the normalized form produced by the dataset.
        generator: network mapping the L channel to the ab channels; it is
            switched to eval mode.
    """
    L, ab = data
    generator.eval()
    out = generator(L)

    # Draw the index from the actual batch size; relying on the global
    # BATCH_SIZE raises IndexError whenever the batch is smaller.
    i = np.random.randint(0, L.shape[0])
    out, L, ab = out[i], L[i], ab[i]

    # Undo the dataset normalization (L / 50 - 1 and ab / 110).
    L = (L + 1) * 50
    ab = ab * 110
    out = out * 110

    def to_rgb(chroma):
        # Join L with the given ab channels (channel-first) and convert to an (H, W, 3) RGB array.
        lab = torch.cat([L, chroma]).detach().cpu().numpy()
        return lab2rgb(lab, channel_axis=0).transpose(1, 2, 0)

    true_image = to_rgb(ab)
    black_image = L.detach().cpu().numpy().squeeze()
    out_image = to_rgb(out)

    panels = (
        (131, 'Black', black_image, 'gray'),
        (132, 'Colored', true_image, None),
        (133, 'Black Colored', out_image, None),
    )
    plt.figure(figsize=(12, 14))
    for position, title, image, cmap in panels:
        plt.subplot(position)
        plt.title(title)
        plt.imshow(image, cmap=cmap)
    plt.show()
    plt.pause(0.001)

In [None]:
# Per-epoch mean losses, kept for later plotting / inspection.
train_d_losses, train_g_losses, valid_d_losses, valid_g_losses = [], [], [], []

for epoch in range(epochs):
    print(f'Epoch {epoch + 1}/{epochs}')

    # ---- training pass ----
    train_epoch_d_loss, train_epoch_g_loss = [], []
    for data in tqdm(train_dataloader, leave=False):
        d_loss, g_loss = train_one_batch(generator, discriminator, data, criterion_content_loss, optimizerG, optimizerD)
        train_epoch_d_loss.append(d_loss)
        train_epoch_g_loss.append(g_loss)
    epoch_d_loss = np.mean(train_epoch_d_loss)
    epoch_g_loss = np.mean(train_epoch_g_loss)
    train_d_losses.append(epoch_d_loss)
    train_g_losses.append(epoch_g_loss)
    print(f'Train D loss: {epoch_d_loss:.4f}, train G loss: {epoch_g_loss:.4f}')

    # ---- validation pass ----
    valid_epoch_g_loss, valid_epoch_d_loss = [], []
    for data in tqdm(valid_dataloader, leave=False):
        d_loss, g_loss = validate(generator, discriminator, data, criterion_content_loss)
        valid_epoch_g_loss.append(g_loss)
        valid_epoch_d_loss.append(d_loss)
    epoch_g_loss = np.mean(valid_epoch_g_loss)
    epoch_d_loss = np.mean(valid_epoch_d_loss)
    valid_g_losses.append(epoch_g_loss)
    valid_d_losses.append(epoch_d_loss)
    print(f'Validation D loss: {epoch_d_loss:.4f}, validation G loss: {epoch_g_loss:.4f}')
    print('-'*50)

    # Learning-rate schedules advance once per epoch.
    schedulerD.step()
    schedulerG.step()

    # Every second epoch: write a checkpoint and show a sample colorization.
    if (epoch + 1) % 2 == 0:
        # Store state dicts rather than the live objects: pickled modules are
        # tied to the exact class definitions and are much more fragile.
        # Restore with e.g. generator.load_state_dict(checkpoint['G']).
        checkpoint = {
            'epoch': epoch,
            'G': generator.state_dict(),
            'D': discriminator.state_dict(),
            'optimizerG': optimizerG.state_dict(),
            'optimizerD': optimizerD.state_dict(),
            'schedulerG': schedulerG.state_dict(),
            'schedulerD': schedulerD.state_dict(),
        }
        torch.save(checkpoint, PATH)
        data = next(iter(valid_dataloader))
        visual_validate(data, generator)

In [None]:
# test_dataloader

#### As the author of the tutorial linked below suggests, we can take a U-Net with a pretrained ResNet-18 backbone, first pretrain it on the colorization task alone (L1 loss only), and only then use it as the generator in the U-Net GAN implementation above (just pass it to the `train_one_batch` and `validate` functions). This significantly improves the quality of the colorization results.
```
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet


def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a fastai ``DynamicUnet`` on top of a pretrained ResNet-18 body.

    Args:
        n_input: number of input channels passed to ``create_body`` as ``n_in``.
        n_output: number of output channels of the U-Net.
        size: side length; the U-Net is built for ``(size, size)`` inputs.

    Returns:
        The network, moved to CUDA when available, otherwise to the CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): cut=-2 presumably drops the last two layers of ResNet-18
    # (pooling + classifier head) -- confirm against the installed fastai version.
    body = create_body(resnet18, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G
    
def pretrain_generator(net_G, train_dl, opt, criterion, epochs):
    """Pre-train the generator alone with a supervised loss (no discriminator).

    Args:
        net_G: generator network mapping L to ab.
        train_dl: iterable of ``(L, ab)`` batches.
        opt: optimizer over the parameters of ``net_G``.
        criterion: loss called as ``criterion(preds, ab)``.
        epochs: number of passes over ``train_dl``.
    """
    for e in range(epochs):
        # Iterate over the loader that was passed in (the snippet referenced
        # an undefined, misspelled global here).
        for data in train_dl:
            L, ab = data
            preds = net_G(L)
            loss = criterion(preds, ab)
            opt.zero_grad()
            loss.backward()
            opt.step()
            
# Pre-train the ResNet-18 U-Net for 20 epochs with an L1 loss and keep its weights.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
# `optim` is never imported in this notebook, so use the fully qualified name.
opt = torch.optim.Adam(net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()
# The training loader of this notebook is called `train_dataloader`.
pretrain_generator(net_G, train_dataloader, opt, criterion, 20)
torch.save(net_G.state_dict(), "res18-unet.pth")
```
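[https://towardsdatascience.com/colorizing-black-white-images-with-u-net-and-conditional-gan-a-tutorial-81b2df111cd8](https://towardsdatascience.com/colorizing-black-white-images-with-u-net-and-conditional-gan-a-tutorial-81b2df111cd8)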

##### Tests

In [None]:
# Qualitative check on held-out data: take the first test batch (the loader is
# not shuffled) and plot one randomly chosen sample from it.
data = next(iter(test_dataloader))
visual_validate(data, generator)