<a href="https://colab.research.google.com/github/smkj33/greyscale-image-colourisation/blob/master/Image_Colorisation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports
**Try** to keep most of the imports in this cell.

In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np

from skimage import color
from skimage.transform import resize
import scipy.ndimage.interpolation as sni
from skimage import color, io

# import cv2

from PIL import Image

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Pytorch device: " + str(device))

# RUNTIME ENVIRONMENT
Loaded all code into Google Colab; Google Drive account used: sid.mkjee.h5@gmail.com

In [0]:
# Mount Google Drive at /content/gdrive so the datasets, model checkpoints and
# colorization previews used later in this notebook persist across Colab sessions.
# Once you run it, follow the generated link and paste back the authorization code.
from google.colab import drive
drive.mount('/content/gdrive')

# IMPORTING DATASET (optional)


In [0]:
"""
Using the following commands, python connects to kaggle and
the data set is downloaded in the colab environment.
This will get deleted as soon as the instance is disconnected.
To avoid this the data is permanently uploaded in my google drive
"""

# Optional: configure Kaggle API credentials (only needed to re-download the dataset).
# NOTE(review): do not commit a real API key here; paste it in at runtime only.

# !mkdir ~/.kaggle
# !touch ~/.kaggle/kaggle.json
# api_token = {"username":"smukherj","key":"<---add latest key here--->"}

# Write the token to the location the kaggle CLI reads it from.
# import json
# with open('/root/.kaggle/kaggle.json', 'w') as file:
#     json.dump(api_token, file)

# Restrict the key file's permissions so it is readable only by the owner.
# !chmod 600 ~/.kaggle/kaggle.json

In [0]:
"""
The following dataset is used:  
https://www.kaggle.com/shravankumar9892/image-colorization#gray_scale.npy   
"""

# Downloads the dataset archive (zip) into the current Colab working directory.
# !kaggle datasets download -d shravankumar9892/image-colorization

# DATA PREPROCESSING 



In [0]:
"""
File is obtained in zip format and uploaded to drive.  
Python is used for unzipping.

DO NOT RERUN, FILE ALREADY EXTRACTED!!
"""


# One-off extraction of the downloaded archive into the Drive Datasets folder.
# import zipfile
# with zipfile.ZipFile('image-colorization.zip', 'r') as zip_ref:
#     zip_ref.extractall(path = 'gdrive/My Drive/Datasets')

In [0]:
# Reference: https://stackoverflow.com/questions/57989716/loading-npy-files-as-dataset-for-pytorch
# https://discuss.pytorch.org/t/loading-npy-files-using-torchvision/28481/2
def npy_loader(path, file_name):
    """Load a ``.npy`` file and return it as a torch tensor.

    Args:
        path: directory containing the file; a trailing separator is optional.
        file_name: name of the ``.npy`` file inside ``path``.

    Returns:
        torch.Tensor created with ``torch.from_numpy``, so it shares memory
        with the loaded NumPy array.
    """
    import os  # local import keeps this helper self-contained

    # os.path.join inserts the separator only when it is missing, so
    # 'dir' and 'dir/' both resolve to 'dir/<file_name>'.
    return torch.from_numpy(np.load(os.path.join(path, file_name)))

def reconstruct_predicted(image_gray, image_ab):
    '''Reconstruct an RGB image from an L channel plus ab channels.

    Args:
        image_gray: L channel, shape (1, 1, H, W); both leading singleton
            dimensions are squeezed away.
        image_ab: ab channels, shape (1, 2, H, W), with the same H and W as
            image_gray.

    Returns:
        numpy uint8 array of shape (H, W, 3) with values 0-255.
    '''
    image_gray = image_gray.squeeze(0).squeeze(0)          # (H, W)
    image_ab = image_ab.squeeze(0).permute((1, 2, 0))      # (H, W, 2)
    # Take the canvas size from the input instead of assuming 256x256.
    height, width = image_gray.shape[-2:]
    img = np.zeros((height, width, 3))
    img[:, :, 0] = image_gray
    img[:, :, 1:] = image_ab
    img = color.lab2rgb(img)  # (H, W, 3), float(0., 1.)
    img = img * 255
    return img.astype('uint8')  # (H, W, 3), int(0, 255)

# Hyperparameters

In [0]:
# Training hyperparameters.
n_epochs = 100       # full passes over the training split
minibatch_size = 1   # images per optimisation step
# TODO: add other hyperparams

# Dataset

In [0]:
# Import Dataset
# The ab (chrominance) channels are stored across three files. Loading them
# inside one torch.cat call means the three partial tensors (and an extra
# intermediate concatenation) are not kept alive next to the full tensor.
path = '/content/gdrive/My Drive/Datasets/Archive/ab/ab/'
abval = torch.cat(
    [npy_loader(path, name) for name in ('ab1.npy', 'ab2.npy', 'ab3.npy')], 0)

# The L (lightness / grayscale) channel of every image is stored in a single file.
path = '/content/gdrive/My Drive/Datasets/Archive/l/'
lval = npy_loader(path, 'gray_scale.npy')

# Image k of lval is paired with image k of abval, so the counts must agree.
if lval.shape[0] != abval.shape[0]:
    raise ValueError(
        "L and ab arrays hold different numbers of images: {} vs {}".format(
            lval.shape[0], abval.shape[0]))

# split into training and testing: the last n_test images are held out
n_test = 5000
n_train = lval.shape[0] - n_test
lval_training, lval_testing = lval.split((n_train, n_test))
abval_training, abval_testing = abval.split((n_train, n_test))

# split to mini_batch sized chunks (tuples of tensors; the last chunk may be smaller)
lval_training = torch.split(lval_training, minibatch_size, 0)
lval_testing = torch.split(lval_testing, minibatch_size, 0)
abval_training = torch.split(abval_training, minibatch_size, 0)
abval_testing = torch.split(abval_testing, minibatch_size, 0)
n_training = len(lval_training)
n_testing = len(lval_testing)
print("Number training batches: {}\nNumber testing batches: {}".format(n_training, n_testing))
print("Shape of a batch for l: {}".format(lval_training[0].shape))
print("Shape of a batch for ab: {}".format(abval_training[0].shape))

In [0]:
# Bins for quantized ab color space: 313 (a, b) bin centres on a 10-unit grid,
# shape (313, 2), moved to `device`. Row order defines the bin index used by
# get_index / h_inverse (one-hot over 313) and by H (expected ab from the
# network's 313 output channels).
points = torch.tensor([[-90, 50], [-90, 60], [-90, 70], [-90, 80], [-90, 90], [-80, 20], [-80, 30], [-80, 40], [-80, 50], [-80, 60], [-80, 70], [-80, 80], [-80, 90], [-70, 0], [-70, 10], [-70, 20], [-70, 30], [-70, 40], [-70, 50], [-70, 60], [-70, 70], [-70, 80], [-70, 90], [-60, -20], [-60, -10], [-60, 0], [-60, 10], [-60, 20], [-60, 30], [-60, 40], [-60, 50], [-60, 60], [-60, 70], [-60, 80], [-60, 90], [-50, -30], [-50, -20], [-50, -10], [-50, 0], [-50, 10], [-50, 20], [-50, 30], [-50, 40], [-50, 50], [-50, 60], [-50, 70], [-50, 80], [-50, 90], [-50, 100], [-40, -40], [-40, -30], [-40, -20], [-40, -10], [-40, 0], [-40, 10], [-40, 20], [-40, 30], [-40, 40], [-40, 50], [-40, 60], [-40, 70], [-40, 80], [-40, 90], [-40, 100], [-30, -50], [-30, -40], [-30, -30], [-30, -20], [-30, -10], [-30, 0], [-30, 10], [-30, 20], [-30, 30], [-30, 40], [-30, 50], [-30, 60], [-30, 70], [-30, 80], [-30, 90], [-30, 100], [-20, -50], [-20, -40], [-20, -30], [-20, -20], [-20, -10], [-20, 0], [-20, 10], [-20, 20], [-20, 30], [-20, 40], [-20, 50], [-20, 60], [-20, 70], [-20, 80], [-20, 90], [-20, 100], [-10, -60], [-10, -50], [-10, -40], [-10, -30], [-10, -20], [-10, -10], [-10, 0], [-10, 10], [-10, 20], [-10, 30], [-10, 40], [-10, 50], [-10, 60], [-10, 70], [-10, 80], [-10, 90], [-10, 100], [0, -70], [0, -60], [0, -50], [0, -40], [0, -30], [0, -20], [0, -10], [0, 0], [0, 10], [0, 20], [0, 30], [0, 40], [0, 50], [0, 60], [0, 70], [0, 80], [0, 90], [0, 100], [10, -80], [10, -70], [10, -60], [10, -50], [10, -40], [10, -30], [10, -20], [10, -10], [10, 0], [10, 10], [10, 20], [10, 30], [10, 40], [10, 50], [10, 60], [10, 70], [10, 80], [10, 90], [20, -80], [20, -70], [20, -60], [20, -50], [20, -40], [20, -30], [20, -20], [20, -10], [20, 0], [20, 10], [20, 20], [20, 30], [20, 40], [20, 50], [20, 60], [20, 70], [20, 80], [20, 90], [30, -90], [30, -80], [30, -70], [30, -60], [30, -50], [30, -40], [30, -30], [30, -20], [30, -10], [30, 0], [30, 10], [30, 20], [30, 30], [30, 40], [30, 50], [30, 60], 
[30, 70], [30, 80], [30, 90], [40, -100], [40, -90], [40, -80], [40, -70], [40, -60], [40, -50], [40, -40], [40, -30], [40, -20], [40, -10], [40, 0], [40, 10], [40, 20], [40, 30], [40, 40], [40, 50], [40, 60], [40, 70], [40, 80], [40, 90], [50, -100], [50, -90], [50, -80], [50, -70], [50, -60], [50, -50], [50, -40], [50, -30], [50, -20], [50, -10], [50, 0], [50, 10], [50, 20], [50, 30], [50, 40], [50, 50], [50, 60], [50, 70], [50, 80], [60, -110], [60, -100], [60, -90], [60, -80], [60, -70], [60, -60], [60, -50], [60, -40], [60, -30], [60, -20], [60, -10], [60, 0], [60, 10], [60, 20], [60, 30], [60, 40], [60, 50], [60, 60], [60, 70], [60, 80], [70, -110], [70, -100], [70, -90], [70, -80], [70, -70], [70, -60], [70, -50], [70, -40], [70, -30], [70, -20], [70, -10], [70, 0], [70, 10], [70, 20], [70, 30], [70, 40], [70, 50], [70, 60], [70, 70], [70, 80], [80, -110], [80, -100], [80, -90], [80, -80], [80, -70], [80, -60], [80, -50], [80, -40], [80, -30], [80, -20], [80, -10], [80, 0], [80, 10], [80, 20], [80, 30], [80, 40], [80, 50], [80, 60], [80, 70], [90, -110], [90, -100], [90, -90], [90, -80], [90, -70], [90, -60], [90, -50], [90, -40], [90, -30], [90, -20], [90, -10], [90, 0], [90, 10], [90, 20], [90, 30], [90, 40], [90, 50], [90, 60], [90, 70], [100, -90], [100, -80], [100, -70], [100, -60], [100, -50], [100, -40], [100, -30], [100, -20], [100, -10], [100, 0]]).to(device)

In [0]:
class CustomDatasetLAB():
    """Minibatch-indexed view over paired L and ab chunks.

    ``dataset[idx]`` returns minibatch ``idx`` as ``(img_l, img_ab)``:
    resized to 256x256, converted to float tensors on ``device``, with
    img_l of shape (B, 1, 256, 256) and img_ab of shape (B, 2, 256, 256).
    Iterating (``for batch in dataset``) uses Python's legacy ``__getitem__``
    protocol and stops when indexing the underlying sequences raises IndexError
    (true for the tuples returned by ``torch.split``).
    """

    def __init__(self, lval, abval):
        # sequences of minibatch tensors (luminances and chrominances), index-aligned
        self.l = lval
        self.ab = abval

    def size(self):
        """Return the number of minibatches held by this dataset."""
        # Previously read the global `lval`, not this dataset's own data.
        return len(self.l)

    def __len__(self):
        return self.size()

    def __getitem__(self, idx):
        img_l = np.asarray(self.l[idx])
        img_ab = np.asarray(self.ab[idx])
        # Use the chunk's real batch size: the last chunk from torch.split can be
        # smaller than minibatch_size, and resizing it to minibatch_size would
        # interpolate across the batch axis.
        batch = img_l.shape[0]
        img_l = resize(img_l, (batch, 256, 256), preserve_range=True)  # TODO: rewrite with F.interpolate
        # skimage keeps the trailing (channel) dimension not given in output_shape.
        img_ab = resize(img_ab, (batch, 256, 256), preserve_range=True)  # TODO: rewrite with F.interpolate

        img_l = torch.from_numpy(img_l).unsqueeze(1).float().to(device)       # (B, 1, 256, 256)
        img_ab = torch.from_numpy(img_ab).permute(0, 3, 1, 2).float().to(device)  # (B, 2, 256, 256)

        # NOTE(review): this maps stored values 0..255 to roughly 0..109.6, while
        # CIE Lab lightness spans 0..100; confirm the dataset's L encoding
        # (e.g. v * 100 / 255) before changing it, since trained weights depend on it.
        img_l = img_l / 256 * 110
        # Stored ab values are offset by +128; shift them back around zero.
        img_ab = img_ab - 128
        return (img_l, img_ab)

# Wrap the minibatch tuples of each split so they can be indexed and iterated.
training_dataset, testing_dataset = (
    CustomDatasetLAB(lval_training, abval_training),
    CustomDatasetLAB(lval_testing, abval_testing),
)

# Preview dataset

In [0]:
# Show the ground-truth colours of one test image.
l, ab = testing_dataset[1]  # minibatch 1 of the test split
n = 0  # which image of that minibatch to display
l, ab = l[n:n + 1], ab[n:n + 1]
img = reconstruct_predicted(l.to('cpu'), ab.to('cpu'))
plt.imshow(img)

# Colorizer NN

In [0]:
class Net(nn.Module):
    ''' Net parameters as in: https://github.com/richzhang/colorization/blob/master/colorization/models/colorization_deploy_v2.prototxt

    Input:  L channel, shape (N, 1, H, W) -- 256x256 in this notebook.
    Output: per-pixel probabilities over the 313 ab bins, shape
            (N, 313, H/4, W/4) when H and W are divisible by 8
            (64x64 for a 256x256 input); sums to 1 over dim 1.

    The module does not choose a device itself. Move it with
    ``net.to(device)`` so that every layer, including the batch-norm layers,
    ends up on the same device. Previously only the conv layers were moved,
    leaving the BatchNorm layers on the CPU.
    '''
    def __init__(self):
        super(Net, self).__init__()
        # Block 1: the stride-2 conv halves the spatial size (256 -> 128).
        self.conv1_1 = nn.Conv2d(1, 64, 3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, 3, stride=2, padding=1)
        self.bn1_2 = nn.BatchNorm2d(64)  # normalises each channel's activations

        # Block 2: 128 -> 64
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, stride=2, padding=1)
        self.bn2_2 = nn.BatchNorm2d(128)

        # Block 3: 64 -> 32
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, stride=2, padding=1)
        self.bn3_3 = nn.BatchNorm2d(256)

        # Block 4: resolution preserved
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn4_3 = nn.BatchNorm2d(512)

        # Blocks 5-6: dilated convs (padding=2, dilation=2) keep the resolution
        # while widening the receptive field.
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.bn5_3 = nn.BatchNorm2d(512)

        self.conv6_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv6_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv6_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.bn6_3 = nn.BatchNorm2d(512)

        # Block 7: resolution preserved
        self.conv7_1 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
        self.conv7_2 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
        self.conv7_3 = nn.Conv2d(512, 512, 3, padding=1, dilation=1)
        self.bn7_3 = nn.BatchNorm2d(512)

        # Block 8: transposed ("de"-)convolution doubles the spatial size (32 -> 64).
        self.conv8_1 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, dilation=1)
        self.conv8_2 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)
        self.conv8_3 = nn.Conv2d(256, 256, 3, padding=1, dilation=1)

        # 1x1 conv to one logit per ab bin, then softmax over the bin dimension.
        self.conv8_313 = nn.Conv2d(256, 313, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        ''' x: L tensor (N, 1, H, W), e.g. [1,1,256,256]; returns z_hat (N, 313, H/4, W/4). '''
        x = self.bn1_2(F.relu(self.conv1_2(F.relu(self.conv1_1(x)))))
        x = self.bn2_2(F.relu(self.conv2_2(F.relu(self.conv2_1(x)))))
        x = self.bn3_3(F.relu(self.conv3_3(F.relu(self.conv3_2(F.relu(self.conv3_1(x)))))))
        x = self.bn4_3(F.relu(self.conv4_3(F.relu(self.conv4_2(F.relu(self.conv4_1(x)))))))
        x = self.bn5_3(F.relu(self.conv5_3(F.relu(self.conv5_2(F.relu(self.conv5_1(x)))))))
        x = self.bn6_3(F.relu(self.conv6_3(F.relu(self.conv6_2(F.relu(self.conv6_1(x)))))))
        x = self.bn7_3(F.relu(self.conv7_3(F.relu(self.conv7_2(F.relu(self.conv7_1(x)))))))
        x = F.relu(self.conv8_3(F.relu(self.conv8_2(F.relu(self.conv8_1(x))))))
        z_hat = self.softmax(self.conv8_313(x))
        return z_hat



# Training

In [0]:
def get_index(ab, bins=None):
    ''' Quantize ab values to the index of the nearest bin, for every pixel.

    @param ab: tensor of shape (N, 2, h, w); channel 0 is a, channel 1 is b.
    @param bins: optional (Q, 2) tensor of (a, b) bin centres; defaults to the
                 module-level `points` table (Q = 313).
    @return: int64 tensor of shape (N, h, w) with the nearest-bin index.
             Ties resolve to the lower index (torch.argmin behaviour).
    '''
    if bins is None:
        bins = points
    bins = bins.to(ab.device)
    ab = ab.unsqueeze(1)  # (N, 1, 2, h, w): new axis for broadcasting over bins
    bins_reshaped = bins.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # (1, Q, 2, 1, 1)
    # Squared Euclidean distance in the ab plane: (N, Q, h, w).
    sq_distance = ((ab - bins_reshaped) ** 2).sum(dim=2)
    return sq_distance.argmin(dim=1)  # (N, h, w)

# NOTE: This is the one-hot encoded version
def h_inverse(ab):
    ''' Hard-encode ground-truth ab values as one-hot bin targets.

    Every pixel is assigned to its nearest quantised ab bin (via get_index),
    and that bin index is expanded to a one-hot vector over the 313 bins.

    @param ab: ab tensor; the training loop passes (minibatch_size, 2, 64, 64)
    @return: float tensor (minibatch_size, 64, 64, 313), one-hot along the last dim
    '''
    nearest_bin = get_index(ab)
    return F.one_hot(nearest_bin, num_classes=313).float()

# Init Net, Loss, Optimizer



In [0]:
net = Net().to(device)  # nn.Module.to moves the module in place and returns it

class OutCrossEntropyLoss(torch.nn.Module):
    """Cross-entropy (base 2) between target bin distributions and predictions.

    Returns ``-sum(z * log2(z_hat))`` summed over every element (not averaged),
    so the value scales with batch size and image area.
    """

    def __init__(self, eps=1e-12):
        super(OutCrossEntropyLoss, self).__init__()
        # Lower bound applied to z_hat before the log. Softmax outputs can
        # underflow to exactly 0; log2(0) = -inf, and 0 * -inf is NaN.
        self.eps = eps

    def forward(self, z, z_hat):
        """z: target probabilities (e.g. one-hot); z_hat: predicted probabilities, same shape."""
        log_probs = z_hat.clamp(min=self.eps).log2()
        cross_entropy = z.mul(log_probs).sum()
        return -cross_entropy

# Loss: cross entropy over the quantised ab bins (an MSE alternative is kept below).
# criterion = nn.MSELoss()  # Simple L2 squared loss
criterion = OutCrossEntropyLoss().to(device)  # Cross Entropy Loss
# Optimiser: SGD with momentum; an Adam configuration is kept for experiments.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.99), weight_decay=0.004)

In [0]:
torch.set_printoptions(edgeitems=10)  # widen print output
losses = []  # per-minibatch training loss, plotted in the statistics cell


class JobDone(Exception):
    """Raised once the loss target is reached, to leave both training loops at once."""


def _save_preview(epoch, i, sample_idx=12):
    """Colourise image 0 of minibatch `sample_idx` of testing_dataset and save it as PNG.

    Runs without autograd and with the network in eval mode, so batch-norm
    running statistics are not updated from test data. Afterwards it restores
    the previous train/eval mode. It uses its own local variables, so the
    training loop's `l`, `ab` and `z_hat` are left untouched.
    NOTE: needs `H` (defined in the "Displaying trained Colorizer results"
    cell) to have been run first.
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            l_test, _ = testing_dataset[sample_idx]
            l_test = l_test[0:1]  # image 0
            ab_hat = H(net(l_test)).permute(2, 0, 1).unsqueeze(0)  # (1, 2, 64, 64)
            ab_hat_256 = F.interpolate(ab_hat, (256, 256))
            img_rgb = reconstruct_predicted(l_test.cpu(), ab_hat_256.cpu())
    finally:
        net.train(was_training)
    path = '/content/gdrive/My Drive/SavedModels/Results/colorized{}_{}.png'.format(epoch, i)
    Image.fromarray(img_rgb).save(path)
    print("       Colorization preview saved")


try:
    for epoch in range(n_epochs):  # loop over the dataset multiple times
        for i, (l, ab) in enumerate(training_dataset):
            # Ground truth at the network's 64x64 output resolution.
            ab_64 = F.interpolate(ab, (64, 64))
            # given ground truth ab, convert it to probabilities. Paper page 5, footnote
            z = h_inverse(ab_64)  # (minibatch_size, 64, 64, 313)
            z = z.permute(0, 3, 1, 2)  # (minibatch_size, 313, 64, 64)

            optimizer.zero_grad()
            z_hat = net(l)
            loss = criterion(z, z_hat)  # cross entropy loss

            # Check the training output itself (not a preview), and do it before
            # stepping, so NaN gradients never reach the weights.
            if torch.isnan(z_hat).any() or torch.isnan(loss):
                raise ValueError("ERROR: One of the predicted values is NaN, aborting.")

            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)  # TODO: useful?
            optimizer.step()

            # print statistics
            losses.append(loss.item())
            if i % 1000 == 0:
                print('Epoch {}/{}, minibatch {}/{}(size = {}): loss {}'.format(epoch, n_epochs, i, n_training, minibatch_size, loss.item()))
                _save_preview(epoch, i)

            if i % 20000 == 0:
                # Save model checkpoints
                PATH = '/content/gdrive/My Drive/SavedModels/Checkpoints/dummy_image_{}_{}.pth'.format(epoch, i)
                torch.save(net.state_dict(), PATH)
                print("       Checkpoint saved.")
            if loss.item() < 500:
                raise JobDone
except JobDone:
    print('Finished Training')
    PATH = '/content/gdrive/My Drive/SavedModels/Checkpoints/job_done.pth'
    torch.save(net.state_dict(), PATH)
    print("       Job saved.")
else:
    # Reached only when every epoch ran without hitting the loss target;
    # a NaN raises ValueError, which is not caught here.
    print("Training ran all {} epochs without reaching the loss target".format(n_epochs))

In [0]:
# Display statistics of loss: training loss smoothed with a non-overlapping
# running mean of n minibatches per plotted point.
fig = plt.figure()
n = 16  # running mean window
# Each point averages exactly n consecutive losses. The previous version folded
# n + 1 losses into the first window while still dividing by n.
# An incomplete trailing window is dropped.
running_losses = [
    sum(losses[start:start + n]) / n
    for start in range(0, len(losses) - n + 1, n)
]
plt.plot(running_losses)
plt.show()

# Saving a trained model (optional)

In [0]:
# Persist the current network weights to Drive (state_dict only; optimizer state is not saved).
PATH = '/content/gdrive/My Drive/SavedModels/colorization_net.pth'
weights = net.state_dict()
torch.save(weights, PATH)
print("Saved.")

# Loading a trained model (optional)

In [0]:
# Rebuild the network and restore weights from a saved checkpoint.
net = Net()
PATH = '/content/gdrive/My Drive/SavedModels/Checkpoints/dummy_image_8_0.pth'
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint
# loaded on a CPU-only runtime), which would otherwise make torch.load fail.
net.load_state_dict(torch.load(PATH, map_location=device))
net.to(device)
print("Loaded.")

# Displaying trained colorizer results

In [0]:
def H(z_hat):
    """
    Map predicted bin probabilities to a single ab value per pixel.

    Input: z_hat (1, 313, 64, 64) - 313 probabilities (one per quantised
           a-b bin) for each pixel.
    Output: (64, 64, 2) - ab value matrix. Each value is the expectation of
            the bin centres under the f_t re-weighted (annealed) distribution.
    """
    probs = f_t(z_hat).squeeze(0).permute(1, 2, 0)  # (64, 64, 313): bins last
    # points[:, k] has shape (313,) and broadcasts along the last (bin) axis.
    expected_a = (probs * points[:, 0]).sum(2)
    expected_b = (probs * points[:, 1]).sum(2)
    return torch.stack((expected_a, expected_b), dim=2)  # (64, 64, 2)

def f_t(z_hat, T=0.38):
    """
    f_t function in equation 5 (annealed-mean re-weighting).

    Per pixel, computes f_T(z)_q = exp(log(z_q) / T) / sum_q' exp(log(z_q') / T).
    T < 1 sharpens the distribution towards its most likely bins, and T = 1
    returns z unchanged. The expected value over the bins is then taken in H.

    Args:
        z_hat: bin probabilities with bins along dim 1, e.g. (1, 313, 64, 64);
            any batch size works.
        T: temperature, default 0.38 (from page 6 of the paper).

    Returns:
        Tensor with the same shape as z_hat that sums to 1 over dim 1.
    """
    # softmax(log(z) / T) is the same expression. It is evaluated stably
    # (max-subtracted) and normalises over dim 1 for every batch element; the
    # old sum(dim=1) without keepdim only broadcast correctly for batch size 1.
    return torch.softmax(torch.log(z_hat) / T, dim=1)

# Colourise one dataset image with the current network and display the result.
from random import seed
from random import randint

# l, ab = dataset[randint(0, 100) * -1]
# l, ab = testing_dataset[11] # minibatch 0
l, ab = training_dataset[0]  # minibatch 0 of the training split
l = l[0:1]  # image 0
ab = ab[0:1]  # image 0 (ground truth, kept for comparison)
# Inference: no autograd graph, and batch-norm uses its running statistics
# instead of this single image's. The previous train/eval mode is restored
# afterwards so a later training run is unaffected.
was_training = net.training
net.eval()
try:
    with torch.no_grad():
        z_hat = net(l)
        ab_hat = H(z_hat)
        ab_hat = ab_hat.permute(2, 0, 1).unsqueeze(0)  # (1, 2, 64, 64)
        ab_hat_256 = F.interpolate(ab_hat, (256, 256))
finally:
    net.train(was_training)
img_rgb = reconstruct_predicted(l.cpu(), ab_hat_256.detach().cpu())
plt.imshow(img_rgb)

In [0]:
# TODO: add the v(.) coefficient to the loss function (colour-rarity rebalancing term that up-weights rare colours).
# TODO: find the optimal minibatch size; check whether passing minibatches of 32 images to the network works correctly and is faster.
# TODO: decide whether the batch-norm layers in Net are required.