In [None]:
import glob
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
import torch.utils.tensorboard
import torchvision
from IPython import display
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from torch.autograd import Variable
from torch.nn import functional
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter       
from torchvision import transforms

%matplotlib inline 


In [None]:
def weights_init(model):
    """Initialise one module in place; intended for use with ``nn.Module.apply``.

    Modules whose exact type is ``nn.Conv2d``, ``nn.ConvTranspose2d`` or
    ``nn.Linear`` get Xavier-normal weights and a zero bias. Every other
    module (including subclasses of those three) is left untouched.

    Args:
        model: the module handed in by ``apply``; modified in place.
    """
    if type(model) in (nn.Conv2d, nn.ConvTranspose2d, nn.Linear):
        nn.init.xavier_normal_(model.weight)
        # Layers constructed with bias=False expose `bias` as None; skip them
        # instead of failing on `None.data`.
        if model.bias is not None:
            nn.init.constant_(model.bias, 0.)

class Color_model(nn.Module):
    """Fully convolutional network mapping a 1-channel image to a 2-channel map.

    The network is ten ``nn.Sequential`` stages (``layer1`` .. ``layer10``)
    made of repeated 3x3 Conv2d -> ReLU -> BatchNorm2d units:

    * layer1-layer3 each end with a stride-2 convolution (8x downsampling
      overall) while widening 1 -> 64 -> 128 -> 256 channels.
    * layer4-layer7 keep the resolution at 512 channels; layer5 and layer6
      use dilation 2, and layer6 additionally starts with a ReLU.
    * layer8-layer10 each start with a 2x ``nn.Upsample`` and narrow the
      channels 512 -> 256 -> 64 -> 2.

    For inputs whose height and width are multiples of 8 the output has the
    same spatial size as the input. All Conv2d layers are re-initialised with
    ``weights_init`` at the end of construction.
    """

    @staticmethod
    def _unit(c_in, c_out, stride=1, dilation=1):
        """Return ``[Conv2d(3x3), ReLU, BatchNorm2d]``; padding equals dilation."""
        return [
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride,
                      padding=dilation, dilation=dilation),
            nn.ReLU(),
            nn.BatchNorm2d(c_out),
        ]

    def __init__(self):
        super().__init__()
        unit = self._unit

        # conv1: 1 -> 64, halves the resolution
        self.layer1 = nn.Sequential(
            *unit(1, 64),
            *unit(64, 64, stride=2))
        # conv2: 64 -> 128, halves the resolution
        self.layer2 = nn.Sequential(
            *unit(64, 128),
            *unit(128, 128, stride=2))
        # conv3: 128 -> 256, halves the resolution
        self.layer3 = nn.Sequential(
            *unit(128, 256),
            *unit(256, 256),
            *unit(256, 256, stride=2))
        # conv4: 256 -> 512
        self.layer4 = nn.Sequential(
            *unit(256, 512),
            *unit(512, 512),
            *unit(512, 512))
        # conv5: dilated 3x3 convolutions
        self.layer5 = nn.Sequential(
            *unit(512, 512, dilation=2),
            *unit(512, 512, dilation=2),
            *unit(512, 512, dilation=2))
        # conv6: leading ReLU, then dilated 3x3 convolutions
        self.layer6 = nn.Sequential(
            nn.ReLU(),
            *unit(512, 512, dilation=2),
            *unit(512, 512, dilation=2),
            *unit(512, 512, dilation=2))
        # conv7
        self.layer7 = nn.Sequential(
            *unit(512, 512),
            *unit(512, 512),
            *unit(512, 512))
        # conv8: upsample, 512 -> 256
        self.layer8 = nn.Sequential(
            nn.Upsample(scale_factor=2.0),
            *unit(512, 256),
            *unit(256, 256),
            *unit(256, 256))
        # upsample, 256 -> 64
        self.layer9 = nn.Sequential(
            nn.Upsample(scale_factor=2.0),
            *unit(256, 128),
            *unit(128, 128),
            *unit(128, 64))
        # upsample, 64 -> 2; the 2-channel output also passes through ReLU + BatchNorm
        self.layer10 = nn.Sequential(
            nn.Upsample(scale_factor=2.0),
            *unit(64, 32),
            *unit(32, 32),
            *unit(32, 2))

        self.apply(weights_init)

    def forward(self, gray_image):
        """Run ``gray_image`` through layer1 .. layer10 in order and return the result."""
        stages = (
            self.layer1, self.layer2, self.layer3, self.layer4, self.layer5,
            self.layer6, self.layer7, self.layer8, self.layer9, self.layer10,
        )
        out = gray_image
        for stage in stages:
            out = stage(out)
        return out

In [None]:
class Places(Dataset):
    """Dataset over a ``.npy`` array of channel-last images.

    The array is indexed as ``data[idx][:, :, c]``, i.e. it is laid out as
    (N, H, W, C). Channel 0 is served as the network input and the remaining
    channels as the target.
    # NOTE(review): presumably Lab images with C == 3 (L, a, b), already
    # normalised by the preprocessing step — confirm against the code that
    # wrote the file.

    Args:
        path: location of the ``.npy`` file; defaults to ``"imgLab.npy"`` in
            the current working directory.
    """

    def __init__(self, path="imgLab.npy"):
        # The file name must be a string; the bare expression `imgLab.npy`
        # is an attribute lookup on an undefined name and raises NameError.
        self.data = np.load(path)

    def __len__(self):
        """Number of images (size of the first axis)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(imgGrey, imgAB)`` tensors in channel-first layout.

        ``imgGrey`` has shape (1, H, W) and ``imgAB`` has shape (C - 1, H, W).
        Both share memory with the underlying NumPy array.
        """
        imgLab = self.data[idx]
        imgGrey = torch.from_numpy(np.expand_dims(imgLab[:, :, 0], 0))
        imgAB = torch.from_numpy(imgLab[:, :, 1:]).permute(2, 0, 1)
        return (imgGrey, imgAB)

In [None]:
# Training loss: plain mean squared error (the Gaussian weighting is disabled)
def cust_loss(output, target):
    """Return the mean squared error between ``output`` and ``target``.

    An earlier experiment multiplied the per-channel error by a Gaussian
    bump of the prediction, ``2 * exp(-(p - c)**2 / w) + 1`` with
    ``(c, w) = (1, 0.2)`` for channel 0 and ``(0.5, 0.07)`` for channel 1.
    That weighting is not applied here.

    Args:
        output: predicted tensor.
        target: reference tensor, broadcastable against ``output``.

    Returns:
        A scalar tensor that stays attached to the autograd graph of ``output``.
    """
    return torch.mean((target - output) ** 2)

In [None]:
# ---- Run configuration ----
eta = .005             # Adam learning rate
momentum=0.9           # not used by the Adam optimizer below; kept for other cells
epochs = 50000
batchSize = 12
seed = 0               # fixes the weight init and the train/valid split across restarts
detectAnomaly = False  # autograd anomaly detection slows training; enable only to debug
ckptDir = "Models"

lossFreq = 20          # print and preview every `lossFreq` training batches
validFreq = 50         # not used in this cell
saveFreq = 5           # checkpoint every `saveFreq` epochs
startEpoch = 1

np.random.seed(seed)
torch.manual_seed(seed)

# Fall back to the CPU so the cell still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

writer = SummaryWriter() 

net = Color_model().float().to(device)
optimizer = optim.Adam(net.parameters(), lr = eta)


def show_lab_image(gray, ab):
    """Display one (1, H, W) input and its (2, H, W) colour channels as RGB.

    The channels are rescaled with ``* [100, 255, 255] - [0, 128, 128]``
    before ``lab2rgb``.
    # NOTE(review): this implies L is stored as L/100 and a, b as
    # (v + 128)/255 — confirm against the preprocessing that wrote the data.
    """
    lab = torch.cat([gray, ab], dim=0).detach().cpu().permute(1, 2, 0).numpy()
    plt.imshow(lab2rgb(lab * [100.0, 255.0, 255.0] - [0, 128, 128]))
    plt.show()


print("Constructing Data")

train_Set = Places()

indices = list(range(len(train_Set)))
np.random.shuffle(indices)

# 85 % of the images for training, the rest for validation.
split = int(len(train_Set) * 0.85)
trainIdx = indices[:split]
validIdx = indices[split:]

print("Random Sampling")

train_sampler = SubsetRandomSampler(trainIdx)
valid_sampler = SubsetRandomSampler(validIdx)

pinMemory = device.type == "cuda"
trainSet = torch.utils.data.DataLoader(train_Set, batch_size=batchSize, sampler = train_sampler, pin_memory=pinMemory, num_workers = 4)
validSet = torch.utils.data.DataLoader(train_Set, batch_size=batchSize, sampler = valid_sampler, pin_memory=pinMemory, num_workers = 4)

print("BEGIN TRAINING, NUM IMAGES IN TRAIN/VALID SET:", len(trainIdx), "/", len(validIdx))
torch.autograd.set_detect_anomaly(detectAnomaly)

# Create the checkpoint directory up front so torch.save cannot fail on it.
os.makedirs(ckptDir, exist_ok=True)

count = 0              # global training-step counter (TensorBoard x-axis)
lastTrainLoss = None   # loss of the most recent training batch, stored in checkpoints
for i in range(startEpoch, epochs + startEpoch):
    # ---- training pass ----
    net.train()
    for batch_idx, data in enumerate(trainSet):
        X_batch = data[0].to(device).float()
        Y_batch = data[1].to(device).float()

        optimizer.zero_grad()
        pred = net(X_batch)
        loss_mse = cust_loss(pred, Y_batch)
        loss_mse.backward()
        optimizer.step()

        count += 1
        lastTrainLoss = loss_mse.item()

        # Record the training loss of every batch.
        writer.add_scalar('Batch MSE Train/Loss', lastTrainLoss, count)

        if count % lossFreq == 0:
            print("Epoch: ", i, "Batch:", batch_idx + 1, "Loss: ", lastTrainLoss)
            show_lab_image(X_batch[0], pred[0])

    # ---- validation pass: one averaged value per epoch ----
    net.eval()
    validTotal = 0.0
    validSeen = 0
    with torch.no_grad():
        for batch_idx, data in enumerate(validSet):
            X_valid = data[0].to(device).float()
            Y_valid = data[1].to(device).float()

            predValid = net(X_valid)
            batchN = X_valid.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            validTotal += F.mse_loss(predValid, Y_valid).item() * batchN
            validSeen += batchN

    if validSeen > 0:
        lossValid = validTotal / validSeen
        writer.add_scalar('Valid Loss', lossValid, i)
        print("Epoch: ", i, "Valid batches:", batch_idx + 1, "Valid Loss: ", lossValid)
        show_lab_image(X_valid[0], predValid[0])

    # Flushing once per epoch is enough to keep TensorBoard current.
    writer.flush()

    if (i % saveFreq == 0):
        PATH = ckptDir + "/" + str(i) + "ckpt.pt"

        torch.save({
            'epoch': i,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': lastTrainLoss,
            }, PATH)

writer.close()
    