# Conditional Colorisation Solution

In [1]:
import os
import cv2
import torch
import torch.utils.data
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import time
from dataloader import ColoringDataset, tensor2img
from unet_resblock import *

## Data Augmentation

In [2]:
# train_dir = "data/train/"
# training_images = [os.path.join(train_dir, item) for item in os.listdir(train_dir)]
# print(f"Count of image before augmentation: {len(training_images)}")

# for img_name in tqdm.tqdm(training_images):
#     imgx = cv2.imread(img_name)
#     imgx = cv2.cvtColor(imgx , cv2.COLOR_BGR2RGB)
#     flip_imgx = cv2.flip(imgx, 1)
#     rot_imgx = cv2.rotate(imgx , cv2.ROTATE_180)
#     cv2.imwrite(os.path.join(train_dir, 'flip_image{}.png'.format(training_images.index(img_name))), flip_imgx)
#     cv2.imwrite(os.path.join(train_dir, 'rot_image{}.png'.format(training_images.index(img_name))), rot_imgx)

# print(f"Count of image after augmentation: {len(os.listdir(train_dir))}")

## Define logic for training and validation

In [3]:
class Utility(object):
    """Running weighted-average tracker for a scalar (e.g. loss or elapsed time).

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted sum of every value seen since the last ``reset``.
        count: total weight seen since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new observation.
            n: weight of the observation (for example the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(..., n=0) on a fresh meter) would
        # otherwise raise ZeroDivisionError; report 0 as the average instead.
        self.avg = self.sum / self.count if self.count else 0

def train(train_loader, model, criterion, optimizer, epoch, print_freq=225):
    """Run one optimisation pass over ``train_loader``.

    Every batch must be a mapping holding tensors under the keys ``"l"``,
    ``"ab"``, ``"gray"`` and ``"cue"``. The target is ``l`` and ``ab``
    concatenated along dim 1; the model input is ``l``, ``gray`` and ``cue``
    concatenated along dim 1.

    Args:
        train_loader: iterable of batches (supports ``len`` for the progress log).
        model: network to train; batches are moved to the device of its parameters.
        criterion: callable ``criterion(output, target)`` returning a scalar loss tensor.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: zero-based epoch index, used only for log messages.
        print_freq: print a progress line every ``print_freq`` batches
            (must be a positive integer).

    Returns:
        The batch-size-weighted average loss over the epoch.
    """
    print('Starting training epoch {}'.format(epoch+1))
    model.train()

    # Follow the model's own device instead of assuming CUDA, so the same
    # loop works whether the model lives on the GPU or the CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    batch_time, data_time, losses = Utility(), Utility(), Utility()

    end = time.time()
    for i, data in enumerate(tqdm.tqdm(train_loader)):
        l_channel = data["l"].to(device)
        ab = data["ab"].to(device)
        img_gray = data["gray"].to(device)
        img_cue = data["cue"].to(device)

        gt_image = torch.cat((l_channel, ab), dim=1)
        img_gray_image = torch.cat((l_channel, img_gray, img_cue), dim=1)
        # Time spent fetching the batch, moving it to the device and assembling it
        data_time.update(time.time() - end)

        # Forward pass
        output_img_gray = model(img_gray_image)
        loss = criterion(output_img_gray, gt_image)
        # Weight by batch size so a smaller final batch does not skew the average
        losses.update(loss.item(), img_gray_image.size(0))

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Total time for this iteration (data handling + forward + backward)
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                epoch+1, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses))

    print(f'training epoch: {epoch+1} completed')
    return losses.avg


def validate(val_loader, model, criterion, save_images, epoch):
    """Evaluate ``model`` on ``val_loader`` and optionally write its outputs as PNGs.

    Every batch must be a mapping holding tensors under the keys ``"l"``,
    ``"ab"``, ``"gray"`` and ``"cue"``; target and input are assembled the
    same way as during training (concatenation along dim 1).

    Args:
        val_loader: iterable of batches (supports ``len`` for the progress log).
        model: network to evaluate; batches are moved to the device of its parameters.
        criterion: callable ``criterion(output, target)`` returning a scalar loss tensor.
        save_images: when truthy, each batch's output is converted with
            ``tensor2img`` and LAB->BGR and written to
            ``outputs/outputs/output_<batch index>.png``.
        epoch: currently unused; files from earlier epochs are overwritten.

    Returns:
        The batch-size-weighted average loss over the loader.
    """
    model.eval()

    # Follow the model's own device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Prepare value counters and timers
    batch_time, data_time, losses = Utility(), Utility(), Utility()

    output_dir = "outputs/outputs"
    if save_images:
        # cv2.imwrite does not create missing directories, so make sure it exists.
        os.makedirs(output_dir, exist_ok=True)

    end = time.time()
    # Evaluation never needs gradients; disabling them here keeps the function
    # safe even when the caller forgets to wrap it in torch.no_grad().
    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(val_loader)):
            l_channel = data["l"].to(device)
            ab = data["ab"].to(device)
            img_gray = data["gray"].to(device)
            img_cue = data["cue"].to(device)

            gt_image = torch.cat((l_channel, ab), dim=1)
            img_gray_image = torch.cat((l_channel, img_gray, img_cue), dim=1)
            data_time.update(time.time() - end)
            output_img_gray = model(img_gray_image)

            loss = criterion(output_img_gray, gt_image)
            losses.update(loss.item(), img_gray_image.size(0))

            # Time for data handling + forward pass (image saving is not included)
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 0:
                print('Validate: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                    i, len(val_loader), batch_time=batch_time, loss=losses))

            if save_images:
                out_img_gray_np = tensor2img(output_img_gray)
                out_img_gray_bgr = cv2.cvtColor(out_img_gray_np, cv2.COLOR_LAB2BGR)
                cv2.imwrite(output_dir + "/output_" + str(i) + ".png", out_img_gray_bgr)

    print('Validation Completed.')
    return losses.avg

### Prepare Dataloader

In [4]:
data_dir = "data/"
# Only request the GPU when one is actually available, so the notebook does
# not fail at model.cuda() on a CPU-only machine.
use_cuda = torch.cuda.is_available()

# Training split: small shuffled mini-batches.
# NOTE(review): the second argument (128) is presumably the image size -
# confirm against dataloader.ColoringDataset.
train_dataset = ColoringDataset(data_dir, 128)
train_dataset.set_mode("training")
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)

# Validation split: DataLoader defaults (batch_size=1, no shuffling).
test_dataset = ColoringDataset(data_dir, 128)
test_dataset.set_mode("validation")
test_dataloader = torch.utils.data.DataLoader(test_dataset)

## Load and set Model

In [5]:
model = ResAttdU_Net()
# print(model)

criterion = nn.L1Loss()

# Move the model to the GPU *before* building the optimizer: the PyTorch
# optimizer docs ask for parameters to already live on their final device
# when the optimizer is constructed.
if use_cuda:
    model = model.cuda()
    criterion = criterion.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=25e-5, weight_decay=0.0)

## Model Training

In [6]:
# Checkpoint directory and run configuration
os.makedirs('checkpoints', exist_ok=True)
epochs = 150
save_images = True
best_losses = 1e10  # lowest validation loss seen so far

# Each epoch: one training pass, then one validation pass without gradients
for epoch in range(epochs):
    train(train_dataloader, model, criterion, optimizer, epoch)
    with torch.no_grad():
        losses = validate(test_dataloader, model, criterion, save_images, epoch)
    # Persist the weights only when validation loss improved
    if losses < best_losses:
        best_losses = losses
        checkpoint_path = 'checkpoints/model-epoch-{}-losses-{:.3f}.pth'.format(epoch + 1, losses)
        torch.save(model.state_dict(), checkpoint_path)

Starting training epoch 1


  0%|          | 2/2251 [00:00<11:06,  3.37it/s]

Epoch: [1][0/2251]	Time 0.559 (0.559)	Data 0.175 (0.175)	Loss 0.7398 (0.7398)	


  4%|▍         | 91/2251 [00:10<04:06,  8.77it/s]


KeyboardInterrupt: 