# Colorizing Images using a U-Net

## Importing Libraries

In [1]:
# Mount Google Drive into the Colab filesystem at /content/drive so that the
# project folder below becomes reachable through ordinary file paths.
from google.colab import drive
drive.mount('/content/drive')

# Data directory: root folder for everything this notebook reads and writes.
# Sub-folders referenced further down: trainData, valData, testInput,
# testOutput and checkpoints.
dataDir = '/content/drive/MyDrive/Projects/gray2rgb'

Mounted at /content/drive


In [2]:
import torch
from torch import nn
from torchvision import utils, datasets, transforms
from matplotlib import pyplot as plt
%matplotlib inline
from PIL import Image
import numpy as np
import os
import time

In [3]:
'''# Run to extract a random validation set from training data
import random

files = os.listdir(os.path.join(dataDir,'trainData'))
valFiles = random.sample(files,k=100)

for f in valFiles:
    os.rename(os.path.join(dataDir,'trainData',f), os.path.join(dataDir,'valData',f))
'''

"# Run to extract a random validation set from training data\nimport random\n\nfiles = os.listdir(os.path.join(dataDir,'trainData'))\nvalFiles = random.sample(files,k=100)\n\nfor f in valFiles:\n    os.rename(os.path.join(dataDir,'trainData',f), os.path.join(dataDir,'valData',f))\n"

## Pre-processing to Generate Training Data

I have a lot of color images on my phone. Converting each one to grayscale gives a (grayscale input, RGB target) pair, which provides enough examples to train the model.

In [4]:
# A custom dataset to generate (RGB, Grayscale) image pairs
class myDataset(torch.utils.data.Dataset):
    """Serve (RGB, grayscale) tensor pairs built from every file in a folder.

    Each file is opened with PIL, forced to 3-channel RGB, resized so that its
    shorter side is 256 px and center-cropped to 256x256. The grayscale
    version is derived from that same RGB image, so both tensors of a pair are
    pixel-aligned.

    Args:
        path: Directory to read. Every directory entry is treated as an image
            file; nothing is filtered out.
    """
    def __init__(self, path):
        self.path = path
        # os.listdir returns entries in arbitrary order. Sorting makes the
        # index -> file mapping reproducible, which matters for the test
        # loader whose outputs are named after the sample index.
        self.files = sorted(os.listdir(path))
        # Produces the RGB tensor (the colorization target).
        self.transformIn = transforms.Compose([
                     transforms.Resize(256),
                     transforms.CenterCrop(256),
                     transforms.ToTensor()])
        # Produces the single-channel grayscale tensor (the network input).
        self.transformOut = transforms.Compose([
                     transforms.Grayscale(num_output_channels=1),
                     transforms.Resize(256),
                     transforms.CenterCrop(256),
                     transforms.ToTensor()])
        
    def __len__(self):
        """Return the number of directory entries found in `path`."""
        return len(self.files)
        
    def __getitem__(self, index):
        """Return the `(rgb, gray)` tensor pair for the file at `index`.

        Returns:
            A tuple `(inImgs, outImgs)`: `inImgs` is the 3-channel RGB tensor
            and `outImgs` the 1-channel grayscale tensor. Despite the names,
            the training loop feeds the grayscale tensor to the model and
            uses the RGB tensor as the target.
        """
        # The context manager closes the file handle. convert('RGB') gives
        # every sample exactly three channels, so RGBA / palette / grayscale
        # files no longer break batching or the 3-channel loss.
        with Image.open(os.path.join(self.path,self.files[index])) as raw:
            img = raw.convert('RGB')
        inImgs = self.transformIn(img)
        outImgs = self.transformOut(img)
        return inImgs, outImgs

# Datasets and loaders for the train / validation / test splits
batchSize = 4

# One dataset per split; each reads every file in its sub-folder of dataDir
dataset_train = myDataset(os.path.join(dataDir, 'trainData'))
dataset_val = myDataset(os.path.join(dataDir, 'valData'))
dataset_test = myDataset(os.path.join(dataDir, 'testInput'))

# Training and validation: shuffled mini-batches with pinned host memory
loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=batchSize,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
loader_val = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=batchSize,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Test: one image at a time, in dataset order
loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

## Unet Model

In [5]:
class encoder(nn.Module):
    """Down-sampling stage: doubles the channel count, halves height and width.

    Layers, in order: 3x3 convolution (padding 1), batch norm, LeakyReLU(0.2)
    and 2x2 max-pooling with stride 2.

    Args:
        input_channels: Number of channels of the incoming feature map.
    """
    def __init__(self, input_channels):
        super().__init__()
        output_channels = 2 * input_channels
        stages = [
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        # Attribute name and layer order define the state_dict keys
        self.gen = nn.Sequential(*stages)

    def forward(self, x):
        """Map `(B, C, H, W)` to `(B, 2C, H/2, W/2)`."""
        return self.gen(x)

class decoder(nn.Module):
    """Up-sampling stage that merges a skip connection.

    `gen1` doubles height and width with a stride-2 transposed convolution
    followed by a 3x3 convolution; `gen2` fuses the result with the skip
    tensor (concatenated along the channel axis) through another 3x3
    convolution.

    Args:
        input_channels: Channels of the incoming feature map.
        skip_channels: Channels of the skip tensor concatenated before `gen2`.
        final: When False the stage halves the channel count and uses batch
            norm. When True it keeps `input_channels` channels and omits
            batch norm.
    """
    def __init__(self, input_channels, skip_channels, final=False):
        super().__init__()
        if final:
            width = input_channels
            self.gen1 = nn.Sequential(
                nn.ConvTranspose2d(input_channels, width, kernel_size=2, stride=2, padding=0),
                nn.ReLU(inplace=True),
                nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            )
            self.gen2 = nn.Sequential(
                nn.Conv2d(width + skip_channels, width, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            )
        else:
            width = input_channels // 2
            self.gen1 = nn.Sequential(
                nn.ConvTranspose2d(input_channels, width, kernel_size=2, stride=2, padding=0),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
                nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            )
            self.gen2 = nn.Sequential(
                nn.Conv2d(width + skip_channels, width, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
            )

    def forward(self, x, xskip):
        """Upsample `x`, concatenate `xskip` on the channel axis, then fuse."""
        upsampled = self.gen1(x)
        merged = torch.cat([upsampled, xskip], dim=1)
        return self.gen2(merged)
    
class Unet(nn.Module):
    """Encoder/decoder network mapping a grayscale image to an RGB image.

    Eight `encoder` stages reduce a 1-channel 256x256 input to a 256-channel
    1x1 feature, two fully connected layers transform that vector, and eight
    `decoder` stages rebuild a 256x256 map while consuming the encoder outputs
    (and finally the raw input) as skip connections. A last 3x3 convolution
    plus sigmoid yields 3 channels in (0, 1).

    The bottleneck reshape hard-codes 256 features, so the network expects
    256x256 inputs.
    """
    def __init__(self):
        super().__init__()

        # e1..e8: channel counts 1, 2, 4, ..., 128 on the way in; every stage
        # doubles the channels and halves the resolution (256 -> 1).
        for stage in range(8):
            setattr(self, f'e{stage + 1}', encoder(2 ** stage))

        # Fully connected bottleneck acting on the flattened 256-vector.
        self.l1 = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
        )

        # d1..d6: (256,128), (128,64), ..., (8,4); every stage halves the
        # channels and doubles the resolution (1 -> 64).
        for stage in range(6):
            setattr(self, f'd{stage + 1}', decoder(256 >> stage, 128 >> stage))

        # d7 / d8 stay at 4 channels (final=True) up to 128 and 256 pixels;
        # their skips are the e1 output (2 channels) and the input (1 channel).
        self.d7 = decoder(4, 2, final=True)
        self.d8 = decoder(4, 1, final=True)

        # 256x256, 4 channels -> 256x256, 3 channels
        self.last = nn.Sequential(
            nn.Conv2d(4, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Colorize a batch of grayscale images.

        Args:
            x: Tensor of shape `(B, 1, 256, 256)`.

        Returns:
            Tensor of shape `(B, 3, 256, 256)`.
        """
        # Encoder path: remember every intermediate map, starting with the
        # raw input, so the decoder can pop them in reverse order.
        skips = [x]
        features = x
        for stage in range(1, 9):
            features = getattr(self, f'e{stage}')(features)
            skips.append(features)

        # The deepest map is the bottleneck, not a skip connection.
        bottleneck = skips.pop()
        latent = self.l1(bottleneck.view(-1, 256))
        features = latent.view(-1, 256, 1, 1)

        # Decoder path: d1 pairs with the e7 output, ..., d8 with the input.
        for stage in range(1, 9):
            features = getattr(self, f'd{stage}')(features, skips.pop())

        return self.last(features)

### Helper Functions for Validation and Testing 

In [6]:
def valRun(model, loader, device):
    """Show one batch as three stacked image grids: input, target, prediction.

    The first batch drawn from `loader` is colorized by `model`; the grayscale
    inputs, RGB ground truth and predictions are displayed top to bottom.

    Args:
        model: Network that maps a grayscale batch to an RGB batch.
        loader: Iterable yielding `(rgbs, grays)` batches.
        device: Device the model lives on; the inputs are moved there.

    Returns:
        None. The figure is shown as a side effect.
    """
    rgbs, grays = next(iter(loader))

    # Inference only: use eval-mode BatchNorm and build no autograd graph,
    # then hand the model back in whatever mode the caller had it.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            preds = model(grays.to(device)).cpu()
    finally:
        model.train(was_training)

    graygrd = utils.make_grid(grays)
    rgbgrd = utils.make_grid(rgbs)
    predgrd = utils.make_grid(preds)

    fig, (ax1, ax2, ax3) = plt.subplots(3,1)
    # make_grid returns (C, H, W); imshow expects (H, W, C)
    ax1.imshow(graygrd.permute(1,2,0))
    ax2.imshow(rgbgrd.permute(1,2,0))
    ax3.imshow(predgrd.permute(1,2,0))

    for ax in [ax1, ax2, ax3]:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    plt.show()
    # Release the figure so repeated calls during training do not pile up
    plt.close(fig)
    time.sleep(1)
    return

def testRun(model, loader, device):
    """Colorize every batch of `loader` and save the results as JPEG files.

    Outputs are written to `<dataDir>/testOutput/<i>.jpeg`, where `i` is the
    index of the batch within the loader.

    Args:
        model: Network that maps a grayscale batch to an RGB batch.
        loader: Iterable yielding `(rgb, gray)` batches; only the grayscale
            tensor is used.
        device: Device the model lives on; the inputs are moved there.

    Returns:
        None.
    """
    outDir = os.path.join(dataDir,'testOutput')
    # Create the target folder up front instead of failing on the first save
    os.makedirs(outDir, exist_ok=True)
    toImg = transforms.ToPILImage()

    # A freshly constructed model is in training mode, where BatchNorm
    # normalizes with the statistics of the current batch and keeps updating
    # its running averages. Switch to eval mode for inference and restore the
    # caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i, (_, grays) in enumerate(loader):
                preds = model(grays.to(device))
                predgrd = utils.make_grid(preds.cpu())
                toImg(predgrd).save(os.path.join(outDir,f'{i}.jpeg'))
    finally:
        model.train(was_training)
    return 

def saveChkPt(state, filename):
    """Serialize `state` to `filename` with `torch.save`.

    Args:
        state: Object to store (the training loop passes a dict holding the
            epoch, model/optimizer state dicts and the loss histories).
        filename: Destination path. Missing parent directories are created.
            File-like objects are passed straight through to `torch.save`.

    Returns:
        None.
    """
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(filename))
        # An empty parent means "current directory": nothing to create
        if parent:
            os.makedirs(parent, exist_ok=True)
    torch.save(state,filename)
    return

def loadChkPt(filename, model, optimizer=None, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        filename: Path of a checkpoint holding the keys 'model', 'optimizer',
            'epoch', 'loss_train' and 'loss_val'.
        model: Module whose parameters are overwritten in place.
        optimizer: Optional optimizer to restore in place; left untouched and
            returned as None when omitted.
        map_location: Device onto which the stored tensors are loaded.
            Defaults to the device of the model's first parameter, so a
            checkpoint written on a GPU can be read on a CPU-only runtime.

    Returns:
        Tuple `(model, optimizer, epoch, loss_train, loss_val)`.
    """
    if map_location is None:
        firstParam = next(model.parameters(), None)
        map_location = firstParam.device if firstParam is not None else 'cpu'

    chkpt = torch.load(filename, map_location=map_location)
    model.load_state_dict(chkpt['model'])
    if optimizer is not None:
        optimizer.load_state_dict(chkpt['optimizer'])
    loss_train = chkpt['loss_train']
    loss_val = chkpt['loss_val']
    return model, optimizer, chkpt['epoch'], loss_train, loss_val

## Train

In [None]:
# --- Training configuration ---
nEpochs = 50          # number of epochs to run in this session
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
step_interval = 1000  # NOTE: not used anywhere in this cell
save_interval = 10    # checkpoint + report whenever (absolute epoch + 1) is a multiple of this
epoch0 = 20           # absolute index of the first epoch to run; > 0 resumes from chkpt_{epoch0-1}.pt

model = Unet().to(device)
reconCriterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Per-epoch loss histories; replaced by the stored ones when resuming
loss_train = []
loss_val = []
if epoch0>0:
    model, optimizer, _, loss_train, loss_val = loadChkPt(os.path.join(dataDir,'checkpoints',f'chkpt_{epoch0-1}.pt'), model, optimizer)

cur_step=0

for epoch in range(nEpochs):
    print(f'Epoch: {epoch0+epoch}')
    loss_train_s = 0
    loss_val_s = 0
    
    # Train
    model.train()
    for rgbs, grays in loader_train:
        rgbs = rgbs.to(device)
        grays = grays.to(device)
        
        optimizer.zero_grad()
        pred = model(grays)
        loss = reconCriterion(pred, rgbs)
        # Accumulate over the epoch; plain assignment here would keep only
        # the loss of the last mini-batch.
        loss_train_s += loss.item()
        loss.backward()
        optimizer.step()
        
        cur_step+=1        

    # Validation
    model.eval()
    with torch.no_grad():
        for rgbs, grays in loader_val:
            rgbs = rgbs.to(device)
            grays = grays.to(device)
            pred = model(grays)
            loss = reconCriterion(pred, rgbs)
            loss_val_s+=loss.item()

    # Both histories hold the mean loss per batch, so the two curves are
    # comparable. Entries restored from a checkpoint written by an earlier
    # version of this cell hold last-batch training losses instead.
    loss_train.append(loss_train_s/len(loader_train))
    loss_val.append(loss_val_s/len(loader_val))

    if (epoch0+epoch+1)%save_interval==0:
        chkpt = {'epoch': epoch0+epoch,
                 'model': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'loss_train': loss_train,
                 'loss_val': loss_val}
        saveChkPt(chkpt, os.path.join(dataDir,'checkpoints',f'chkpt_{epoch0+epoch}.pt'))

        print(f'Steps: {cur_step}, Train Loss: {loss_train[-1]}, Val Loss: {loss_val[-1]}')
        valRun(model,loader_val,device)

        # Use the history length for the x-axis so the plot still works when
        # the restored history does not start at epoch 0.
        plt.figure()
        plt.plot(range(len(loss_train)), loss_train, label='Train')
        plt.plot(range(len(loss_val)), loss_val, label='Validation')
        plt.legend()
        plt.ylabel('Cost')
        plt.xlabel('Epochs')
        plt.show()
        time.sleep(1)

Epoch: 20
Epoch: 21
Epoch: 22
Epoch: 23
Epoch: 24
Epoch: 25


## Test

In [None]:
# Load the checkpoint written after epoch (epoch0 - 1) and colorize the test set
epoch0 = 50
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = Unet().to(device)
if epoch0>0:
    model, _, _, _, _ = loadChkPt(os.path.join(dataDir,'checkpoints',f'chkpt_{epoch0-1}.pt'), model)

# A freshly constructed module is in training mode; switch to eval so that
# BatchNorm uses its learned running statistics instead of the statistics of
# each single-image test batch.
model.eval()

testRun(model, loader_test, device)