## Import Dependencies

In [1]:
import os
from google.colab import drive

# Mount Google Drive, then work from the project folder so the relative paths used
# later (checkpoints, test_data/, colorized_images/, RealESRGAN_Results/) resolve there.
drive.mount('/content/drive/')
os.chdir('/content/drive/My Drive/DL_FinalProject/')

Mounted at /content/drive/


In [2]:
#install fastai to use Dynamic Unet model
!pip install fastai==2.4

# Set up the environment for Real ESRGAN model
!pip install basicsr
!pip install facexlib
!pip install gfpgan

# Download the pre-trained Real ESRGAN model and save it to "experiments/pretrained_models"
!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models

Collecting fastai==2.4
  Downloading fastai-2.4-py3-none-any.whl (187 kB)
[K     |████████████████████████████████| 187 kB 655 kB/s 
Collecting fastcore<1.4,>=1.3.8
  Downloading fastcore-1.3.29-py3-none-any.whl (55 kB)
[K     |████████████████████████████████| 55 kB 2.9 MB/s 
[?25hCollecting torch<1.10,>=1.7.0
  Downloading torch-1.9.1-cp37-cp37m-manylinux1_x86_64.whl (831.4 MB)
[K     |████████████████████████████████| 831.4 MB 6.6 kB/s 
Collecting torchvision>=0.8.2
  Downloading torchvision-0.12.0-cp37-cp37m-manylinux1_x86_64.whl (21.0 MB)
[K     |████████████████████████████████| 21.0 MB 213 kB/s 
[?25h  Downloading torchvision-0.11.3-cp37-cp37m-manylinux1_x86_64.whl (23.2 MB)
[K     |████████████████████████████████| 23.2 MB 130 kB/s 
[?25h  Downloading torchvision-0.11.2-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)
[K     |████████████████████████████████| 23.3 MB 1.2 MB/s 
[?25h  Downloading torchvision-0.11.1-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)
[K     |████████

--2022-05-06 17:40:12--  https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
Resolving github.com (github.com)... 52.192.72.89
Connecting to github.com (github.com)|52.192.72.89|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/387326890/08f0e941-ebb7-48f0-9d6a-73e87b710e7e?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20220506%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220506T174013Z&X-Amz-Expires=300&X-Amz-Signature=525d3700285063f0eba8c9f6a15e75c6890290cb3f36dce66b76ace5b1e1b919&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=387326890&response-content-disposition=attachment%3B%20filename%3DRealESRGAN_x4plus.pth&response-content-type=application%2Foctet-stream [following]
--2022-05-06 17:40:13--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/387326890/08f0e941-ebb7-48f0-9d6a-73e87b

In [3]:
import os
import glob
import time
import numpy as np
from PIL import Image, ImageOps
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader


from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
use_colab = None  # set to True by the "Download Data" cell once the COCO sample is available

import os.path as osp
import glob
import cv2
import numpy as np
import torch

## Download Data

In [7]:
from fastai.data.external import URLs, untar_data

# Fetch fastai's COCO sample and point at its image folder.
coco_path = str(untar_data(URLs.COCO_SAMPLE)) + "/train_sample"
use_colab = True

Split the data into training and validation sets

In [8]:
# Raw .jpg images: fastai's COCO sample if it was downloaded above, else ./train/.
path = coco_path if use_colab else "./train/"

N_SAMPLES = 10_000  # images drawn at random from the folder
N_TRAIN = 8_000     # first N_TRAIN of the shuffled subset -> training, rest -> validation

# sorted(): glob returns files in filesystem order, which can differ between runs or
# machines. Without sorting, the seeded np.random.choice below would still pick a
# different subset, so the train/val split would not be reproducible.
paths = sorted(glob.glob(path + "/*.jpg"))  # Grabbing all the image file names
if len(paths) < N_SAMPLES:
    raise ValueError(f"found {len(paths)} .jpg files under {path!r}, need at least {N_SAMPLES}")

np.random.seed(123)
paths_subset = np.random.choice(paths, N_SAMPLES, replace=False)  # random 10,000-image subset
rand_idxs = np.random.permutation(N_SAMPLES)
train_idxs = rand_idxs[:N_TRAIN]  # first 8000 -> training set
val_idxs = rand_idxs[N_TRAIN:]    # last 2000 -> validation set
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]
print(len(train_paths), len(val_paths))

8000 2000


## Creating Data Loaders

In [9]:
SIZE = 256  # every image is resized to SIZE x SIZE pixels


class ColorizationDataset(Dataset):
    """Dataset that turns RGB image files into normalised L*a*b* training pairs.

    Args:
        paths: indexable sequence of image file paths.
        split: 'train' (bicubic resize + random horizontal flip) or
            'val' (bicubic resize only).

    Raises:
        ValueError: if ``split`` is neither 'train' nor 'val'.

    Each item is a dict ``{'L': tensor (1, SIZE, SIZE), 'ab': tensor (2, SIZE, SIZE)}``.
    L (0..100 in L*a*b*) is mapped to [-1, 1], and a/b are divided by 110 so
    they fall roughly in [-1, 1].
    """

    def __init__(self, paths, split='train'):
        # InterpolationMode instead of PIL's int constant: the int form triggers
        # torchvision's "Argument interpolation should be of type InterpolationMode"
        # deprecation warning.
        resize = transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),  # a little data augmentation
            ])
        elif split == 'val':
            self.transforms = resize
        else:
            # Previously an unknown split left self.transforms unset and only failed
            # later, inside __getitem__, with an AttributeError.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = np.array(self.transforms(img))
        img_lab = rgb2lab(img).astype("float32")  # RGB -> L*a*b*, HWC float32
        img_lab = transforms.ToTensor()(img_lab)  # HWC ndarray -> CHW tensor (float input is not rescaled)
        L = img_lab[[0], ...] / 50. - 1.  # 0..100 -> -1..1
        ab = img_lab[[1, 2], ...] / 110.  # roughly -1..1
        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, shuffle=None, **kwargs):
    """Build a DataLoader over a ColorizationDataset.

    Args:
        batch_size, n_workers, pin_memory: forwarded to ``torch.utils.data.DataLoader``.
        shuffle: reshuffle the data every epoch. ``None`` (default) means shuffle
            only for the training split. Without shuffling, the training loader
            yields the same batch order every epoch. Validation/inference loaders
            stay in a fixed order so the saved file names are stable.
        **kwargs: forwarded to ``ColorizationDataset`` (e.g. ``paths=..., split=...``).

    Returns:
        The configured DataLoader.
    """
    if shuffle is None:
        # ColorizationDataset's own default split is 'train'.
        shuffle = kwargs.get('split', 'train') == 'train'
    dataset = ColorizationDataset(**kwargs)
    return DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                      pin_memory=pin_memory, shuffle=shuffle)

In [10]:
train_dl = make_dataloaders(split='train', paths=train_paths)
val_dl = make_dataloaders(split='val', paths=val_paths)

# Sanity check: tensor shapes of one training batch, then batches per loader.
data = next(iter(train_dl))
Ls = data['L']
abs_ = data['ab']
print(Ls.shape, abs_.shape)
print(len(train_dl), len(val_dl))

  "Argument interpolation should be of type InterpolationMode instead of int. "


torch.Size([16, 1, 256, 256]) torch.Size([16, 2, 256, 256])
500 125


## Image Colorization Model (Training from Scratch)

Generator: a U-Net (fastai `DynamicUnet`) with a pretrained ResNet18 encoder backbone. It maps the L channel to the two ab color channels.

In [11]:
def build_res_unet(n_input=1, n_output=2, size=256):
    """Create a fastai DynamicUnet generator with a pretrained ResNet18 encoder.

    Args:
        n_input: input channels (1 for the L channel).
        n_output: output channels (2 for the predicted ab channels).
        size: square spatial size the U-Net is built for.

    Returns:
        The U-Net, moved to CUDA when available, otherwise CPU.
    """
    target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # cut=-2 drops the last two children of resnet18 (pooling + classifier head).
    encoder = create_body(resnet18, pretrained=True, n_in=n_input, cut=-2)
    unet = DynamicUnet(encoder, n_output, (size, size))
    return unet.to(target_device)

Patch Discriminator: a model built by stacking Conv-BatchNorm-LeakyReLU blocks. It outputs one real/fake logit per image patch rather than a single score per image.

In [12]:
class PatchDiscriminator(nn.Module):
    """PatchGAN-style discriminator built from Conv -> BatchNorm -> LeakyReLU blocks.

    Layout (all 4x4 convs, padding 1):
      * first block: no BatchNorm, stride 2;
      * ``n_down`` blocks doubling the channel count, stride 2 except the last (stride 1);
      * final conv to 1 channel with no norm and no activation, so the output is a
        map of raw per-patch logits (paired with a BCE-with-logits loss).

    Args:
        input_c: input channels (3 here: L concatenated with ab).
        num_filters: channels produced by the first block.
        n_down: number of channel-doubling blocks.
    """
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        blocks = [self.get_layers(input_c, num_filters, norm=False)]
        for depth in range(n_down):
            in_ch = num_filters * 2 ** depth
            is_last = depth == n_down - 1
            # Keep stride 1 on the last down block so the patch grid is not halved again.
            blocks.append(self.get_layers(in_ch, in_ch * 2, s=1 if is_last else 2))
        # Final projection to one logit per patch: no normalisation, no activation.
        blocks.append(self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False))
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Return one Conv2d (+ optional BatchNorm2d, + optional LeakyReLU(0.2)) block.

        The conv has no bias when BatchNorm follows, since BatchNorm's shift makes it redundant.
        """
        parts = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            parts.append(nn.BatchNorm2d(nf))
        if act:
            parts.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*parts)

    def forward(self, x):
        return self.model(x)

Loss Computations

In [13]:
class GANLoss(nn.Module):
    """Vanilla GAN loss: BCE-with-logits against a constant real or fake target.

    Args:
        real_label: target value used when ``target_is_real`` is True.
        fake_label: target value used when ``target_is_real`` is False.

    Call as ``loss = criterion(preds, target_is_real)`` where ``preds`` are raw
    discriminator logits of any shape.
    """
    def __init__(self, real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers follow .to(device) and are stored in state_dict; keep the names so
        # existing checkpoints (which contain GANcriterion.real_label/fake_label) load.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        self.loss = nn.BCEWithLogitsLoss()

    def get_labels(self, preds, target_is_real):
        """Return the real/fake target broadcast to ``preds``' shape (no copy)."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(preds)

    def forward(self, preds, target_is_real):
        # Implemented as forward() rather than an overridden __call__, so that
        # nn.Module.__call__ runs and registered hooks fire as usual.
        return self.loss(preds, self.get_labels(preds, target_is_real))

In [14]:
def init_weights(net, init='norm', gain=0.02):
    """Re-initialise conv and BatchNorm2d layers of ``net`` in place.

    Args:
        net: module whose submodules are visited via ``net.apply``.
        init: scheme for conv weights: 'norm' (N(0, gain)), 'xavier'
            (xavier_normal_ with ``gain``) or 'kaiming' (kaiming_normal_, fan_in).
        gain: std for 'norm', gain for 'xavier', and std of the BatchNorm weights.

    Conv biases and BatchNorm biases are zeroed; BatchNorm weights ~ N(1, gain).

    Returns:
        ``net`` itself (for chaining).

    Raises:
        ValueError: for an unknown ``init`` (previously conv weights were silently
            left untouched).
    """
    if init not in ('norm', 'xavier', 'kaiming'):
        raise ValueError(f"unknown init {init!r}; expected 'norm', 'xavier' or 'kaiming'")

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            else:  # 'kaiming'
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname and getattr(m, 'weight', None) is not None:
            # affine=False BatchNorm has weight/bias set to None: nothing to initialise.
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    """Move ``model`` to ``device``, apply init_weights' default scheme, and return it."""
    return init_weights(model.to(device))

Complete Conditional GAN architecture

In [15]:
class MainModel(nn.Module):
    """Conditional GAN for colorization: generator ``net_G`` + PatchDiscriminator.

    The generator maps the L channel to predicted ab channels. The discriminator
    scores 3-channel images made by concatenating L with either the real or the
    generated ab. The generator loss is the adversarial loss plus
    ``lambda_L1`` * L1(fake_ab, real_ab).

    Args:
        net_G: generator module (e.g. from ``build_res_unet``); moved to ``self.device``.
        lr_G, lr_D: Adam learning rates for the generator and discriminator.
        beta1, beta2: Adam betas shared by both optimizers.
        lambda_L1: weight of the L1 reconstruction term in the generator loss.

    Per batch: call ``setup_input(batch)`` then ``optimize()``. The most recent losses
    are left on ``loss_D_fake``, ``loss_D_real``, ``loss_D``, ``loss_G_GAN``,
    ``loss_G_L1`` and ``loss_G``.
    """
    def __init__(self, net_G, lr_G=2e-4, lr_D=2e-4, 
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        
        self.net_G = net_G.to(self.device)
        # Discriminator input = L (1 channel) concatenated with ab (2 channels) -> 3 channels.
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss().to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
    
    def set_requires_grad(self, model, requires_grad=True):
        """Enable/disable gradient tracking for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad
        
    def setup_input(self, data):
        """Store a batch dict's 'L' and 'ab' tensors on ``self.device``."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)
        
    def forward(self):
        # NOTE: unlike a usual nn.Module.forward(x), this takes no input: it reads
        # self.L (set by setup_input) and stores the prediction in self.fake_color
        # instead of returning it.
        self.fake_color = self.net_G(self.L)
    
    def backward_D(self):
        """Discriminator loss: fake pairs -> 'fake', real pairs -> 'real'; then backprop."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach(): the discriminator update must not send gradients into the generator.
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()
    
    def backward_G(self):
        """Generator loss: fool D (target 'real') + lambda_L1 * L1 to true ab; then backprop."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
    
    def optimize(self):
        """One training step: generate once, update D, then update G with D frozen."""
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        
        self.net_G.train()
        # Freeze D so the generator's backward pass does not accumulate grads into D.
        # D stays frozen after this method returns; it is re-enabled on the next call.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

Loss Tracking and Result Visualization

In [16]:
class AverageMeter:
    """Running, count-weighted mean of a scalar (e.g. a loss value per batch)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to 0.0."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` (typically the batch size) and refresh ``avg``."""
        self.sum += val * count
        self.count += count
        self.avg = self.sum / self.count

def create_loss_meters():
    """Return a fresh AverageMeter per loss attribute recorded by MainModel.

    Keys match the attribute names so ``update_losses`` can look them up via getattr.
    """
    loss_names = ('loss_D_fake', 'loss_D_real', 'loss_D',
                  'loss_G_GAN', 'loss_G_L1', 'loss_G')
    return {name: AverageMeter() for name in loss_names}

def update_losses(model, loss_meter_dict, count):
    """Feed each of ``model``'s current loss tensors into the meter with the same name.

    Args:
        model: object holding scalar loss tensors as attributes named like the dict keys.
        loss_meter_dict: mapping loss name -> AverageMeter-like object.
        count: weight for this update (the batch size).
    """
    for name in loss_meter_dict:
        current = getattr(model, name).item()
        loss_meter_dict[name].update(current, count=count)

def lab_to_rgb(L, ab):
    """Convert a batch of normalised L / ab tensors back to RGB images.

    Undoes ColorizationDataset's scaling (L: [-1, 1] -> [0, 100]; ab: x110),
    concatenates along the channel dim, moves to CPU, and converts each image
    with skimage's ``lab2rgb``.

    Args:
        L: tensor (N, 1, H, W).
        ab: tensor (N, 2, H, W).

    Returns:
        numpy array with the N converted HxWx3 images stacked on axis 0.
    """
    lab_batch = torch.cat([(L + 1.) * 50., ab * 110.], dim=1)
    lab_hwc = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_hwc], axis=0)
    
def visualize(model, data, batch_no, p, save=True):
    """Colorize one batch with ``model``, plot it, and optionally save PNGs.

    The generator runs in eval mode under ``torch.no_grad()`` and is put back in
    train mode afterwards. One figure is shown per group of up to 4 images, with
    rows: grayscale input, generated colours, ground-truth colours.

    Args:
        model: MainModel-like object (net_G, setup_input, forward, L, ab, fake_color).
        data: batch dict with 'L' and 'ab' tensors, as yielded by make_dataloaders.
        batch_no: batch index; used only in saved file names.
        p: output-location flag. 1 saves under ./colorized_images/LR_*;
            any other value saves under ./RealESRGAN_Results/super*.
        save: when True, write fake / real-grayscale / real-RGB PNGs for each image.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    if p == 1:
        fake_dir, bw_dir, rgb_dir = ("./colorized_images/LR_fake/",
                                     "./colorized_images/LR_real_bw/",
                                     "./colorized_images/LR_real_rgb/")
    else:
        fake_dir, bw_dir, rgb_dir = ("./RealESRGAN_Results/supertocolor/",
                                     "./RealESRGAN_Results/super_bw/",
                                     "./RealESRGAN_Results/super_rgb/")
    if save:
        # Previously a missing folder made the first .save() raise FileNotFoundError.
        for out_dir in (fake_dir, bw_dir, rgb_dir):
            os.makedirs(out_dir, exist_ok=True)

    per_fig = 4
    n_imgs = len(fake_imgs)  # actual batch size; the last batch may hold fewer than 16
    for k in range(0, n_imgs, per_fig):
        fig, axes = plt.subplots(3, per_fig, figsize=(15, 8))
        for i in range(per_fig):
            col = axes[:, i]
            for ax in col:
                ax.axis("off")
            idx = k + i
            if idx >= n_imgs:  # partial last group: leave the remaining columns blank
                continue
            col[0].imshow(L[idx][0].cpu(), cmap='gray')
            col[1].imshow(fake_imgs[idx])
            col[2].imshow(real_imgs[idx])
            if save:
                fake_u8 = np.uint8(fake_imgs[idx] * 255)
                real_u8 = np.uint8(real_imgs[idx] * 255)
                fname = "image_" + str(batch_no) + "_" + str(idx) + ".png"
                Image.fromarray(fake_u8).save(fake_dir + fname)
                ImageOps.grayscale(Image.fromarray(real_u8)).save(bw_dir + fname)
                Image.fromarray(real_u8).save(rgb_dir + fname)
        plt.show()
        plt.close(fig)  # release the figure; many batches otherwise accumulate open figures
        
def log_results(loss_meter_dict):
    """Print ``name: average`` (5 decimals) for each loss meter, in dict order."""
    lines = (f"{name}: {meter.avg:.5f}" for name, meter in loss_meter_dict.items())
    for line in lines:
        print(line)

Model Training

In [17]:
def train_model(model, train_dl, epochs, display_every=200, vis_dl=None):
    """Train the conditional GAN, periodically printing losses and sample outputs.

    Args:
        model: MainModel instance.
        train_dl: training DataLoader.
        epochs: number of passes over ``train_dl``.
        display_every: log/visualise every this many iterations within an epoch.
        vis_dl: loader providing the fixed batch to visualise; defaults to the
            notebook-level ``val_dl`` (the previous implicit behaviour).
    """
    # Fixed batch for visualising progress at regular intervals. It used to be stored
    # in `data` and immediately shadowed by the training loop variable, so the
    # "fixed" batch was never actually displayed.
    vis_source = vis_dl if vis_dl is not None else val_dl
    vis_batch = next(iter(vis_source))
    for e in range(epochs):
        loss_meter_dict = create_loss_meters()  # one running-average meter per logged loss
        for i, data in enumerate(tqdm(train_dl), start=1):
            model.setup_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0))
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict)
                visualize(model, vis_batch, batch_no=i, p=1, save=False)

In [None]:
def pretrain_generator(net_G, train_dl, opt, criterion, epochs):
    """Supervised warm-up of the generator with a pixel loss only (no discriminator).

    Args:
        net_G: generator mapping L -> ab.
        train_dl: loader of {'L', 'ab'} batches.
        opt: optimizer over ``net_G``'s parameters.
        criterion: pixel loss (an nn.L1Loss at the call site).
        epochs: number of passes over ``train_dl``.

    Batches are moved to the notebook-level ``device``. The sample-weighted mean
    loss is printed after each epoch.
    """
    for epoch in range(1, epochs + 1):
        meter = AverageMeter()
        for batch in tqdm(train_dl):
            L = batch['L'].to(device)
            ab = batch['ab'].to(device)
            loss = criterion(net_G(L), ab)
            opt.zero_grad()
            loss.backward()
            opt.step()
            meter.update(loss.item(), L.size(0))
        print(f"Epoch {epoch}/{epochs}")
        print(f"L1 Loss: {meter.avg:.5f}")

# Warm up the generator alone with an L1 pixel loss for 20 epochs, then checkpoint it.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
opt = optim.Adam(params=net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()
pretrain_generator(net_G=net_G, train_dl=train_dl, opt=opt, criterion=criterion, epochs=20)
torch.save(net_G.state_dict(), "res18-unet.pt")

In [None]:
# Full conditional-GAN training, starting from the L1-pretrained generator.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
pretrained_state = torch.load("res18-unet.pt", map_location=device)
net_G.load_state_dict(pretrained_state)
model = MainModel(net_G=net_G)
train_model(model=model, train_dl=train_dl, epochs=20)
torch.save(model.state_dict(), "color.pt")

## Evaluate

Loading our pre-trained generator weights and GAN model.

In [19]:
model_path_color = './color.pt'
net_G = build_res_unet(n_input=1, n_output=2, size=256)
# L1-pretrained generator weights. These are overwritten just below, because the full
# GAN checkpoint also contains every net_G.* parameter (strict=True).
net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
model_color = MainModel(net_G=net_G)
# map_location: a checkpoint saved from a GPU session otherwise fails to load on a
# CPU-only runtime (consistent with the generator load above).
model_color.load_state_dict(torch.load(model_path_color, map_location=device), strict=True)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

model initialized with norm initialization


<All keys matched successfully>

## Pipeline 1: LR -> Color -> Super



Colorization:

In [20]:
custom_paths = './test_data/LR/'

# sorted(): glob order is filesystem-dependent. Sorting makes the [:80] subset and the
# batch/index numbers in the saved file names (image_<batch>_<idx>.png) reproducible.
# It also keeps pipeline 2 (which reads the super-resolved copies of the same files)
# working on the same images in the same order.
paths = sorted(glob.glob(custom_paths + "/*.png"))  # Grabbing all the image file names

customval_paths = paths[:80]  # cap the evaluation set at 80 images

custom_dl = make_dataloaders(paths=customval_paths, split='val')
batch_no = 0
for d in tqdm(custom_dl):
    visualize(model_color, d, batch_no, p=1, save=True)  # display + save this batch's outputs
    batch_no += 1

Output hidden; open in https://colab.research.google.com to view.

Super Resolution:

In [21]:
# Pipeline 1, step 2: upscale the colorized LR images with Real-ESRGAN.
# If it runs out of memory, try the `--tile` option.
# The images are upsampled by an overall scale factor of 3.5 (`--outscale 3.5`).
# NOTE(review): leftover alternative working directory; as written, the command
# below runs inference_realesrgan.py from the current (project) directory.
#os.chdir("./Real-ESRGAN")
!python inference_realesrgan.py -n RealESRGAN_x4plus -i 'colorized_images/LR_fake' --outscale 3.5 --output 'RealESRGAN_Results/colortosuper'
# Arguments
# -n, --model_name: model name
# -i, --input: input folder or image
# --outscale: output scale; can be an arbitrary scale factor
# --output: output folder

Testing 0 image_0_0
Testing 1 image_0_1
Testing 2 image_0_10
Testing 3 image_0_11
Testing 4 image_0_12
Testing 5 image_0_13
Testing 6 image_0_14
Testing 7 image_0_15
Testing 8 image_0_2
Testing 9 image_0_3
Testing 10 image_0_4
Testing 11 image_0_5
Testing 12 image_0_6
Testing 13 image_0_7
Testing 14 image_0_8
Testing 15 image_0_9
Testing 16 image_1_0
Testing 17 image_1_1
Testing 18 image_1_10
Testing 19 image_1_11
Testing 20 image_1_12
Testing 21 image_1_13
Testing 22 image_1_14
Testing 23 image_1_15
Testing 24 image_1_2
Testing 25 image_1_3
Testing 26 image_1_4
Testing 27 image_1_5
Testing 28 image_1_6
Testing 29 image_1_7
Testing 30 image_1_8
Testing 31 image_1_9
Testing 32 image_2_0
Testing 33 image_2_1
Testing 34 image_2_10
Testing 35 image_2_11
Testing 36 image_2_12
Testing 37 image_2_13
Testing 38 image_2_14
Testing 39 image_2_15
Testing 40 image_2_2
Testing 41 image_2_3
Testing 42 image_2_4
Testing 43 image_2_5
Testing 44 image_2_6
Testing 45 image_2_7
Testing 46 image_2_8
Testi

## Pipeline 2: LR -> Super -> Color

Super Resolution:

In [22]:
# Glob pattern for the input images (either colored or LR images).
# NOTE(review): not referenced by the visible cells; the command below passes 'test_data/LR' directly.
test_img_folder = "./test_data/LR/*"

In [23]:
# Pipeline 2, step 1: upscale the raw LR images with Real-ESRGAN before colorizing.
# If it runs out of memory, try the `--tile` option.
# The images are upsampled by an overall scale factor of 3.5 (`--outscale 3.5`).
!python inference_realesrgan.py -n RealESRGAN_x4plus -i 'test_data/LR' --outscale 3.5 --output 'RealESRGAN_Results/LRtosuper/'
# Arguments
# -n, --model_name: model name
# -i, --input: input folder or image
# --outscale: output scale; can be an arbitrary scale factor
# --output: output folder

Testing 0 0
Testing 1 1
Testing 2 10
Testing 3 11
Testing 4 12
Testing 5 13
Testing 6 14
Testing 7 15
Testing 8 16
Testing 9 17
Testing 10 18
Testing 11 19
Testing 12 2
Testing 13 20
Testing 14 21
Testing 15 22
Testing 16 23
Testing 17 24
Testing 18 25
Testing 19 26
Testing 20 27
Testing 21 28
Testing 22 29
Testing 23 3
Testing 24 30
Testing 25 31
Testing 26 32
Testing 27 33
Testing 28 34
Testing 29 35
Testing 30 36
Testing 31 37
Testing 32 38
Testing 33 39
Testing 34 4
Testing 35 40
Testing 36 41
Testing 37 42
Testing 38 43
Testing 39 44
Testing 40 45
Testing 41 46
Testing 42 47
Testing 43 48
Testing 44 49
Testing 45 5
Testing 46 50
Testing 47 51
Testing 48 52
Testing 49 53
Testing 50 54
Testing 51 55
Testing 52 56
Testing 53 57
Testing 54 58
Testing 55 59
Testing 56 6
Testing 57 60
Testing 58 61
Testing 59 62
Testing 60 63
Testing 61 64
Testing 62 65
Testing 63 66
Testing 64 67
Testing 65 68
Testing 66 69
Testing 67 7
Testing 68 70
Testing 69 71
Testing 70 72
Testing 71 73
Testing 72

Colorization:

In [24]:
custom_paths = './RealESRGAN_Results/LRtosuper/'

# sorted(): glob order is filesystem-dependent. Sorting makes the [:80] subset and the
# batch/index numbers in the saved file names reproducible, and lines this pipeline up
# with pipeline 1, which selects its images from the same sorted source listing.
paths = sorted(glob.glob(custom_paths + "/*.png"))  # Grabbing all the image file names
customval_paths = paths[:80]  # cap the evaluation set at 80 images

custom_dl = make_dataloaders(paths=customval_paths, split='val')
batch_no = 0
for d in tqdm(custom_dl):
    visualize(model_color, d, batch_no, p=0, save=True)  # display + save this batch's outputs
    batch_no += 1

Output hidden; open in https://colab.research.google.com to view.