# 0. Set Parameters

In [None]:
# Data-loading / reproducibility knobs used throughout the notebook.
batch_size, num_workers, seed = 5, 2, 814

In [None]:
# Drive folder holding the train/val/test image sub-folders.
coco_dataset_folder = '/content/drive/MyDrive/Final_Project/COCO_DATASET'
# Directory for checkpoints and the loss log; '' resolves relative to the working directory.
MODEL_SAVE_PATH = ""

# 1. Dataset loading

In [None]:
!pip install fastai

In [None]:
# Download and build dataset
# NOTE(review): this cell is disabled — the code is wrapped in a triple-quoted string,
# so running the cell only evaluates and discards that string.
# If re-enabled: `dataset_path` is not defined anywhere in the visible notebook; it
# presumably should be `coco_dataset_folder` — TODO confirm.
'''
import os
from google.colab import drive
from fastai.data.external import untar_data, URLs

drive.mount('/content/drive')
download_path = untar_data(URLs.COCO_SAMPLE, data=dataset_path)
'''

In [None]:
# NOTE(review): disabled one-off split script (kept as a string literal, so it does not run).
# It relies on `os` and `download_path` from the previous, also disabled, cell.
# It moves every '*jpg*' file of train_sample into train/val/test folders with roughly
# 10% test, 27% val (30% of the remaining 90%) and 63% train.
# If re-enabled: train_test_split is called without random_state, so the split is not
# reproducible; os.makedirs without exist_ok=True fails if the folders already exist.
'''
st = os.path.join(download_path, 'train_sample')
TRAIN_DATASET_PATH = '/content/drive/MyDrive/Final_Project/COCO_DATASET/train'
VAL_DATASET_PATH = '/content/drive/MyDrive/Final_Project/COCO_DATASET/val'
TEST_DATASET_PATH = '/content/drive/MyDrive/Final_Project/COCO_DATASET/test'
os.makedirs(TRAIN_DATASET_PATH)
os.makedirs(VAL_DATASET_PATH)
os.makedirs(TEST_DATASET_PATH)

all_jpg_file = []
for dir in os.listdir(st):
    if os.path.isfile(os.path.join(st, dir)) and ('jpg' in dir):
        all_jpg_file.append(dir)

from sklearn.model_selection import train_test_split
import shutil

remain, test = train_test_split(all_jpg_file, test_size=0.1)
train, val = train_test_split(remain, test_size = 0.3)
print(len(all_jpg_file))
print(len(train))
print(len(val))
print(len(test))
def move_files(ARR, DIR, target_DIR):
    for jpg in ARR:
        c_DIR = os.path.join(DIR, jpg)
        shutil.move(c_DIR, os.path.join(target_DIR, jpg))
move_files(train, st, TRAIN_DATASET_PATH)
move_files(val, st, VAL_DATASET_PATH)
move_files(test, st, TEST_DATASET_PATH)
'''

In [None]:
from google.colab import drive
import os, shutil, sys

drive.mount('/content/drive')
# Patch_GAN.save_model and the loss log write into MODEL_SAVE_PATH, so make sure it
# exists. An empty MODEL_SAVE_PATH means "current working directory", which always exists.
if MODEL_SAVE_PATH:
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

Mounted at /content/drive


In [None]:
# sys.path.append('/content/drive/MyDrive/Final_Project')

# COCO_DATA_DIR = '/content/coco_dataset'

# print('Copying training dataset to machine disk')
# shutil.copytree(coco_dataset_folder, COCO_DATA_DIR)
# TRAIN_DATASET_PATH = os.path.join(COCO_DATA_DIR, 'train')
# VAL_DATASET_PATH = os.path.join(COCO_DATA_DIR, 'val')
# TEST_DATASET_PATH = os.path.join(COCO_DATA_DIR, 'test')

In [None]:
COCO_DATA_DIR = coco_dataset_folder

# One sub-folder per split (the layout created by the disabled split cell above).
TRAIN_DATASET_PATH, VAL_DATASET_PATH, TEST_DATASET_PATH = (
    os.path.join(COCO_DATA_DIR, split_name) for split_name in ('train', 'val', 'test')
)

In [None]:
import os, torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

from skimage.color import rgb2lab, lab2rgb
import numpy as np


class COCO_Dataset(Dataset):
    """Image dataset yielding normalized ``(L, ab)`` Lab-colour pairs for colorization.

    Item ``index`` is read from ``os.path.join(BASE_PATH, ALL_IMG_PATH[index])``,
    converted to RGB, passed through ``transform`` and converted to CIE Lab with
    ``skimage.color.rgb2lab``.

    Args:
        ALL_IMG_PATH: sequence of file names relative to ``BASE_PATH``.
        BASE_PATH: directory containing the images.
        purpose: one of ``'train'``, ``'val'``, ``'test'``.
        transform: optional callable applied to the PIL image. When ``None``, a bicubic
            resize to ``(self.H, self.H)`` (256x256) is used for every split.

    Each item is ``(L_ch, AB_ch)``: float tensors of shape ``(1, H, W)`` and ``(2, H, W)``.
    L is mapped from [0, 100] to [-1, 1]; ab is divided by 110.
    """

    def __init__(self, ALL_IMG_PATH, BASE_PATH, purpose, transform = None):
        assert(purpose in ['train', 'val', 'test'])
        self.H = 256

        # All splits currently share the same default (resize only, no augmentation);
        # pass an explicit `transform` to add training-time augmentation.
        if transform is None:
            transform = transforms.Resize((self.H, self.H), transforms.InterpolationMode.BICUBIC)
        self.transform = transform

        self.purpose = purpose
        self.ALL_IMG_PATH = ALL_IMG_PATH
        self.BASE_PATH = BASE_PATH

    def __len__(self):
        return len(self.ALL_IMG_PATH)

    def _normalize_L_channel(self, l_channel):
        """Map a ``(1, H, W)`` L channel from [0, 100] to [-1, 1]."""
        assert(l_channel.shape[0] == 1)
        return (l_channel / 50.0) - 1.0

    def _normalize_AB_channel(self, AB_channel):
        """Scale a ``(2, H, W)`` ab tensor by 1/110 (roughly into [-1, 1])."""
        assert(AB_channel.shape[0] == 2)
        return (AB_channel / 110.0)

    def __getitem__(self, index):
        c_image_path = os.path.join(self.BASE_PATH, self.ALL_IMG_PATH[index])
        # Context manager closes the file handle PIL keeps open after a lazy open;
        # convert() returns an in-memory copy, so the image stays usable afterwards.
        with Image.open(c_image_path) as img:
            img = img.convert('RGB')
        img = self.transform(img)
        img = np.array(img)

        # RGB -> Lab so the network only has to predict the two chroma channels.
        lab_img = torch.from_numpy(rgb2lab(img)).float().permute(2, 0, 1)
        L_ch = self._normalize_L_channel(lab_img[[0], ...])
        AB_ch = self._normalize_AB_channel(lab_img[[1, 2], ...])
        return (L_ch, AB_ch)


In [None]:
def _list_images(folder, extensions=('.jpg', '.jpeg', '.png')):
    """Return the sorted image file names directly inside ``folder``.

    Sorting makes the dataset order independent of the filesystem's (arbitrary)
    ``os.listdir`` order. Filtering skips stray non-image entries that PIL cannot open.
    """
    return sorted(name for name in os.listdir(folder) if name.lower().endswith(extensions))


train_imgs = _list_images(TRAIN_DATASET_PATH)
train_Dataset = COCO_Dataset(ALL_IMG_PATH=train_imgs, BASE_PATH=TRAIN_DATASET_PATH,
                             purpose='train')
train_dataloader = DataLoader(train_Dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True, drop_last=True)

val_imgs = _list_images(VAL_DATASET_PATH)
val_Dataset = COCO_Dataset(ALL_IMG_PATH=val_imgs, BASE_PATH=VAL_DATASET_PATH,
                           purpose='val')
# NOTE(review): drop_last=True also drops the last partial validation batch; harmless
# while only the first val batch is used for visualization.
val_dataloader = DataLoader(val_Dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True, drop_last=True)


# 2. Model

In [None]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Seed every global RNG, not just torch: Python's `random` and numpy's legacy global RNG
# are seeded too, in case any downstream code draws from them.
random.seed(seed)
np.random.seed(seed)
# Trailing ';' suppresses the Generator repr in the notebook output.
torch.manual_seed(seed);

<torch._C.Generator at 0x7d8224d035b0>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm


# --- UNet --- #
# credit to https://github.com/milesial/Pytorch-UNet for the unet implementation

class UNet(nn.Module):
    """Five-level UNet (milesial/Pytorch-UNet layout) mapping ``in_channels`` to ``out_im_channels``.

    Args:
        in_channels: channels of the input image.
        out_im_channels: channels of the predicted output.
        batchnorm: whether each double convolution uses BatchNorm.
        dropout: dropout probability applied after every down/up stage.
        bc: base width; encoder widths are bc, 2bc, 4bc, 8bc, 8bc.
    """

    def __init__(self, in_channels, out_im_channels, batchnorm = True, dropout=0.3, bc=64):
        super().__init__()
        norm, p = batchnorm, dropout

        # Encoder: full-resolution stem, then four max-pool stages.
        self.inc = inconv(in_channels, bc, norm)
        self.down1 = down(bc, 2 * bc, norm, dropout=p)
        self.down2 = down(2 * bc, 4 * bc, norm, dropout=p)
        self.down3 = down(4 * bc, 8 * bc, norm, dropout=p)
        self.down4 = down(8 * bc, 8 * bc, norm, dropout=p)
        # Decoder: in_ch of each stage counts the concatenated skip connection.
        self.up1 = up(16 * bc, 4 * bc, norm, dropout=p)
        self.up2 = up(8 * bc, 2 * bc, norm, dropout=p)
        self.up3 = up(4 * bc, bc, norm, dropout=p)
        self.up4 = up(2 * bc, 2 * bc, norm, dropout=p)
        self.outc = outconv(2 * bc, out_im_channels,)

    def forward(self, x):
        # Collect every encoder activation: [x1, x2, x3, x4, x5].
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(stage(skips[-1]))

        # Bottleneck is x5; the decoder consumes x4, x3, x2, x1 in that order.
        y = skips.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4):
            y = stage(y, skips.pop())
        return self.outc(y)

# --- helper modules --- #
def convrelu(in_channels, out_channels, kernel, padding):
    """Return ``Conv2d(kernel, padding) -> ReLU(inplace)`` wrapped in an ``nn.Sequential``."""
    layers = (
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )
    return nn.Sequential(*layers)


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2, with the BN step omitted when ``batchnorm`` is False.

    Both convolutions are 3x3 with padding 1, so the spatial size is preserved.
    '''
    def __init__(self, in_ch, out_ch, batchnorm=True):
        super().__init__()
        layers = []
        # First conv maps in_ch -> out_ch, second keeps out_ch.
        for conv_in in (in_ch, out_ch):
            layers.append(nn.Conv2d(conv_in, out_ch, 3, padding=1))
            if batchnorm:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class inconv(nn.Module):
    """UNet stem: one ``double_conv`` applied at full input resolution."""
    def __init__(self, in_ch, out_ch, batchnorm):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch, batchnorm)

    def forward(self, x):
        return self.conv(x)


class down(nn.Module):
    """UNet encoder stage: 2x2 max-pool, then ``double_conv``, then optional dropout.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        batchnorm: forwarded to ``double_conv``.
        dropout: dropout probability; ``None`` or ``0`` disables dropout.
    """
    def __init__(self, in_ch, out_ch, batchnorm, dropout=None):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(nn.MaxPool2d(2), double_conv(in_ch, out_ch, batchnorm))
        # Always define the attribute so forward() also works when dropout is disabled
        # (previously it was only set for a truthy dropout -> AttributeError in forward).
        self.dropout = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        x = self.mpconv(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x


class up(nn.Module):
    """UNet decoder stage: upsample ``x1``, pad it to ``x2``'s size, concat, ``double_conv``.

    Args:
        in_ch: channels after concatenating the skip connection. For the 'conv' and
            'upconv' paths ``x1`` must carry ``in_ch // 2`` channels (the up-sampler's input).
        out_ch: output channels.
        batchnorm: forwarded to ``double_conv``.
        method: 'bilinear', 'conv' (transposed conv, default), 'upconv' or 'none'.
        dropout: dropout probability; ``None`` or ``0`` disables dropout.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    def __init__(self, in_ch, out_ch, batchnorm, method='conv', dropout=None):
        super(up, self).__init__()

        if method == 'bilinear':
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        elif method == 'conv':
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
        elif method == 'upconv':
            # NOTE(review): Upsample(x2) -> ReflectionPad2d(1) -> Conv2d(k=2, s=2) gives
            # (2H + 2 - 2) / 2 + 1 = H + 1, i.e. no net upsampling; forward() then crops by
            # one pixel via a negative F.pad. Confirm this is intended.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.ReflectionPad2d(1),
                # note the interesting size and stride
                nn.Conv2d(in_ch // 2, in_ch // 2, kernel_size=2, stride=2, padding=0),
            )
        elif method == 'none':
            self.up = nn.Identity()
        else:
            # Fail at construction instead of with an AttributeError on the first forward().
            raise ValueError(f"unknown upsampling method: {method!r}")

        self.conv = double_conv(in_ch, out_ch, batchnorm)
        # Always define the attribute so forward() works when dropout is disabled.
        self.dropout = nn.Dropout(dropout) if dropout else None

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Inputs are NCHW: pad (or crop, if negative) x1 so its H/W match the skip x2.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))

        # for padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)

        if self.dropout is not None:
            x = self.dropout(x)
        return x


class outconv(nn.Module):
    """Final 1x1 convolution projecting ``in_ch`` feature maps onto ``out_ch`` outputs."""
    def __init__(self, in_ch, out_ch,):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [None]:
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

def generate_pretrain_unet(n_input=1, n_output=2, size=256, pretrained=True):
    """Build a fastai ``DynamicUnet`` generator on a ResNet-18 encoder.

    Args:
        n_input: input channels (1 for the L channel); fastai's ``create_body`` adapts
            the first convolution to this count.
        n_output: output channels (2 for ab).
        size: square spatial size DynamicUnet uses to infer its decoder shapes.
        pretrained: load ImageNet weights into the ResNet-18 encoder.

    Returns:
        The network, moved to CUDA when available, else CPU.

    Note:
        Patch_GAN, as constructed in this notebook, Kaiming-reinitializes every
        Conv2d/Linear of the generator it receives, which overwrites these weights.
    """
    from torchvision.models import ResNet18_Weights

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Recent fastai's create_body cuts the *given model instance*; its `pretrained` flag
    # only controls how the first conv is adapted. A bare resnet18() is randomly
    # initialized, so the ImageNet weights must be requested explicitly.
    weights = ResNet18_Weights.DEFAULT if pretrained else None
    body = create_body(resnet18(weights=weights), pretrained=pretrained, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G


In [None]:
# Discriminator

class dcgan_Discriminator(nn.Module):
    """Placeholder for a DCGAN-style discriminator (not implemented yet).

    NOTE(review): no forward() is defined, so calling an instance raises
    NotImplementedError from nn.Module.
    """
    def __init__(self):
        # nn.Module.__init__ must run first; otherwise the module's internal registries
        # (_parameters, _modules, ...) never exist and e.g. .parameters() crashes.
        super().__init__()
        self.m = None

class Patch_Discriminator(nn.Module):
    """PatchGAN discriminator: five 4x4 convolutions producing a map of per-patch logits.

    Args:
        in_channels: channels of the image fed to the discriminator.
        bc: width of the first convolution, doubled per stage up to ``8 * bc``.

    ``forward`` returns raw logits of shape ``(N, 1, h, w)``; no sigmoid is applied.
    """

    def __init__(self, in_channels, bc=64):
        super().__init__()

        def stage(c_in, c_out, stride, norm, act, bias=False):
            # One 4x4 conv (padding 1), optionally followed by BatchNorm and LeakyReLU(0.2).
            mods = [nn.Conv2d(c_in, c_out, 4, stride, 1, bias=bias)]
            if norm:
                mods.append(nn.BatchNorm2d(c_out))
            if act:
                mods.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*mods)

        self.layers = nn.Sequential(
            stage(in_channels, bc, 2, norm=False, act=True),        # step 0: no BN on input stage
            stage(bc, bc * 2, 2, norm=True, act=True),              # step 1
            stage(bc * 2, bc * 4, 2, norm=True, act=True),          # step 2
            stage(bc * 4, bc * 8, 1, norm=True, act=True),          # step 3: stride 1
            stage(bc * 8, 1, 1, norm=False, act=False, bias=True),  # step 4: logits
        )

    def forward(self, x):
        return self.layers(x)



In [None]:
# Loss
class loss_GAN(nn.Module):
    """Adversarial loss: BCE-with-logits against an all-ones (real) or all-zeros (fake) target."""
    def __init__(self,):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()

    def calculate_loss(self, pred, isreal):
        """Return the mean BCE-with-logits loss of ``pred`` against a constant target.

        Args:
            pred: discriminator logits (any shape).
            isreal: if true the target is 1 everywhere, otherwise 0.

        Returns:
            Scalar loss tensor.
        """
        # *_like already inherits dtype and device from `pred`.
        make_target = torch.ones_like if isreal else torch.zeros_like
        return self.loss(pred, make_target(pred))

In [None]:
def denorm_l_channel(l_channel):
    """Undo the dataset's L normalization: map [-1, 1] back to Lab L in [0, 100]."""
    shifted = l_channel + 1.0
    return shifted * 50.0

def denorm_ab_channel(ab_channel):
    """Undo the dataset's ab normalization by multiplying with the same factor (110)."""
    scale = 110.0
    return ab_channel * scale

In [None]:
class Patch_GAN(nn.Module):
    """Pix2pix-style colorization GAN: G predicts ab from L, a PatchGAN D judges (L, ab).

    Args:
        G_model: generator mapping an L batch to an ab batch (presumably ``(N, 1, H, W)``
            -> ``(N, 2, H, W)`` — TODO confirm for other generators).
        D_model: discriminator applied to ``cat((L, ab), dim=1)``, returning logits.
        lr: learning rate shared by all AdamW optimizers.
        loss_lambda: weight of the L1 term in the generator loss.
        init_G_weights: if True (default), Kaiming-initialize every Conv2d/Linear of
            ``G_model``. Pass False for a pretrained generator, whose weights would
            otherwise be overwritten.
    """
    def __init__(self, G_model, D_model, lr, loss_lambda = 100., init_G_weights=True):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if init_G_weights:
            self.G_model = self.initialize_and_to_device(G_model, self.device)
        else:
            self.G_model = G_model.to(self.device)
        self.D_model = self.initialize_and_to_device(D_model, self.device)

        adamw_kwargs = dict(lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
        self.optim_G = optim.AdamW(self.G_model.parameters(), **adamw_kwargs)
        self.optim_D = optim.AdamW(self.D_model.parameters(), **adamw_kwargs)
        # Separate optimizer (separate Adam state) for the L1-only phase.
        self.optim_G_only = optim.AdamW(self.G_model.parameters(), **adamw_kwargs)
        self.l1_loss = nn.L1Loss()
        self.GAN_loss = loss_GAN()
        self.loss_lambda = loss_lambda

    def forward(self, l_channel):
        """Predict the (normalized) ab channels for an L batch."""
        return self.G_model(l_channel)

    def whether_train_model(self, model, whether_train):
        """Freeze (False) or unfreeze (True) all parameters of ``model`` via requires_grad."""
        for param in model.parameters():
            param.requires_grad = whether_train

    def train_step(self, l_channel, ab_channel):
        """One adversarial update (D first, then G).

        Returns:
            ``[D_loss_fake, D_loss_real, D_loss_total, G_loss_GAN, G_loss_L1, G_loss_total]``
            as Python floats.
        """
        l_channel = l_channel.to(self.device)
        ab_channel = ab_channel.to(self.device)

        # Train mode *before* the forward pass so BN/dropout in G behave as in training.
        self.G_model.train()
        fake_ab = self.forward(l_channel)
        gt_img = torch.cat((l_channel, ab_channel), dim=1)
        fake_img = torch.cat((l_channel, fake_ab), dim=1)

        # --- Discriminator update ---
        self.D_model.train()
        self.whether_train_model(self.D_model, whether_train=True)
        self.optim_D.zero_grad()
        # detach(): the discriminator loss must not back-propagate into G.
        D_pred_fake = self.D_model(fake_img.detach())
        D_loss_fake = self.GAN_loss.calculate_loss(pred=D_pred_fake, isreal=False)
        D_pred_gt = self.D_model(gt_img)
        D_loss_real = self.GAN_loss.calculate_loss(pred=D_pred_gt, isreal=True)
        D_loss_total = 0.5 * D_loss_fake + 0.5 * D_loss_real
        D_loss_total.backward()
        self.optim_D.step()

        # --- Generator update: freeze the *discriminator* so only G receives gradients ---
        self.whether_train_model(self.D_model, whether_train=False)
        self.optim_G.zero_grad()
        G_pred_fake = self.D_model(fake_img)
        G_loss_GAN = self.GAN_loss.calculate_loss(pred=G_pred_fake, isreal=True)
        G_loss_L1 = self.l1_loss(fake_ab, ab_channel)
        G_loss_total = G_loss_GAN + G_loss_L1 * self.loss_lambda
        G_loss_total.backward()
        self.optim_G.step()

        losses = (D_loss_fake, D_loss_real, D_loss_total, G_loss_GAN, G_loss_L1, G_loss_total)
        return [loss.item() for loss in losses]

    def visualize(self, l_channel, ab_channel, epoch=0, img_num=5):
        """Plot grey input | prediction | ground truth for the first samples of a batch.

        Args:
            l_channel: normalized L batch.
            ab_channel: normalized ground-truth ab batch.
            epoch: only referenced by the commented-out savefig call below.
            img_num: maximum number of rows, clipped to the batch size.
        """
        l_channel = l_channel.to(self.device)
        ab_channel = ab_channel.to(self.device)
        self.G_model.eval()
        try:
            with torch.no_grad():
                fake_ab = self.forward(l_channel)
                l_denorm = denorm_l_channel(l_channel)
                fake_img = torch.cat((l_denorm, denorm_ab_channel(fake_ab)), dim=1)
                gt_img = torch.cat((l_denorm, denorm_ab_channel(ab_channel)), dim=1)

                # NCHW -> NHWC: lab2rgb expects the channel axis last.
                fake_img = fake_img.permute(0, 2, 3, 1).cpu().numpy()
                gt_img = gt_img.permute(0, 2, 3, 1).cpu().numpy()
                all_fake_rgb = [lab2rgb(img) for img in fake_img]
                all_gt_rgb = [lab2rgb(img) for img in gt_img]
        finally:
            # Restore training mode even if plotting preparation fails.
            self.G_model.train()

        n_rows = min(img_num, l_channel.shape[0])
        vis_fig = plt.figure(figsize=(9, 16))
        for i in range(n_rows):
            c_grey_img = l_channel[i][0].detach().cpu().numpy()
            panels = ((c_grey_img, 'gray'), (all_fake_rgb[i], None), (all_gt_rgb[i], None))
            for col, (image, cmap) in enumerate(panels):
                vis = plt.subplot(n_rows, 3, 1 + col + i * 3)
                vis.imshow(image, cmap=cmap)
                vis.axis('off')
        # To save, call before plt.show():
        # plt.savefig(os.path.join(MODEL_SAVE_PATH, f"epoch={epoch}.png"))
        plt.show()
        # Close explicitly so figures do not pile up when called inside a training loop.
        plt.close(vis_fig)

    def save_model(self, save_path, epoch):
        """Save G, D and their adversarial optimizers to ``<save_path>/epoch=<epoch>.pt``."""
        save_path = os.path.join(save_path, f'epoch={epoch}.pt')
        torch.save({
            'epoch': epoch,
            'G_state_dict': self.G_model.state_dict(),
            'D_state_dict': self.D_model.state_dict(),
            'opt_G_state_dict': self.optim_G.state_dict(),
            'opt_D_state_dict': self.optim_D.state_dict(),
        }, save_path)

    def load_model(self, load_DIR):
        """Restore a checkpoint written by ``save_model``.

        ``map_location`` remaps stored tensors onto this instance's device, so a
        checkpoint saved on GPU also loads on a CPU-only runtime.
        """
        CHECKPOINT = torch.load(load_DIR, map_location=self.device)
        self.G_model.load_state_dict(CHECKPOINT['G_state_dict'])
        self.D_model.load_state_dict(CHECKPOINT['D_state_dict'])
        self.optim_G.load_state_dict(CHECKPOINT['opt_G_state_dict'])
        self.optim_D.load_state_dict(CHECKPOINT['opt_D_state_dict'])

    def weight_initialization(self, model, ):
        """Kaiming-normal (fan_out) init for every Conv2d/Linear weight, zero biases; returns ``model``."""
        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        return model

    def initialize_and_to_device(self, model, device):
        """Move ``model`` to ``device`` and apply ``weight_initialization``."""
        model = model.to(device)
        return self.weight_initialization(model)

    def train_generator_only(self, l_channel, ab_channel, ):
        """One L1-only generator update (no discriminator); returns the loss as a float."""
        self.G_model.train()
        l_channel = l_channel.to(self.device)
        ab_channel = ab_channel.to(self.device)
        self.optim_G_only.zero_grad()

        fake_ab = self.forward(l_channel)
        G_loss = self.l1_loss(fake_ab, ab_channel)
        G_loss.backward()
        self.optim_G_only.step()
        return G_loss.item()

# Training Loop

In [None]:
def train(model, epoch, gan=False, vis_every=5):
    """Train ``model`` for ``epoch`` epochs, logging losses and periodically plotting/saving.

    Reads the notebook globals ``train_dataloader``, ``val_dataloader`` and
    ``MODEL_SAVE_PATH``. Per-epoch mean losses are appended to
    ``<MODEL_SAVE_PATH>/loss.npy`` after every epoch.

    Args:
        model: a ``Patch_GAN`` instance.
        epoch: number of epochs to run.
        gan: False (default) -> L1-only generator pretraining via
            ``model.train_generator_only``; True -> full adversarial training via
            ``model.train_step`` (the per-epoch loss is then a length-6 vector).
        vis_every: visualize and checkpoint every ``vis_every`` epochs (must be > 0),
            and always after the final epoch.

    Raises:
        RuntimeError: if ``train_dataloader`` yields no batches (e.g. fewer images than
            ``batch_size`` with ``drop_last=True``), instead of silently logging NaN.
    """
    # One fixed validation batch so successive visualizations are comparable.
    vis_l, vis_ab = next(iter(val_dataloader))

    loss_log = []
    loss_log_save_DIR = os.path.join(MODEL_SAVE_PATH, 'loss.npy')
    step_fn = model.train_step if gan else model.train_generator_only

    for epoch_num in range(epoch):
        batch_losses = [step_fn(c_l_channel, c_ab_channel)
                        for c_l_channel, c_ab_channel in tqdm(train_dataloader)]
        if not batch_losses:
            raise RuntimeError('train_dataloader produced no batches; check dataset size, '
                               'batch_size and drop_last.')
        loss_arr = np.mean(np.array(batch_losses), axis=0)

        loss_log.append(loss_arr)
        np.save(loss_log_save_DIR, np.array(loss_log))
        print(loss_arr)

        is_last_epoch = epoch_num == epoch - 1
        if epoch_num % vis_every == 0 or is_last_epoch:
            print(f'Epoch: {epoch_num}')
            model.visualize(vis_l, vis_ab, epoch_num)
            model.save_model(MODEL_SAVE_PATH, epoch_num)

In [None]:
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100

# Alternative generator: UNet(1, 2) (defined in section 2).
G_model = generate_pretrain_unet(n_input=1, n_output=2, size=256)
D_model = Patch_Discriminator(3)  # 3 = L (1) + ab (2) channels
model = Patch_GAN(G_model, D_model, LEARNING_RATE)
train(model, NUM_EPOCHS)

# Visualization

In [None]:
# TODO: point at an ``epoch=N.pt`` checkpoint written by Patch_GAN.save_model.
model_path = ''
G_model, D_model = (generate_pretrain_unet(n_input=1, n_output=2, size=256),
                    Patch_Discriminator(3))
model = Patch_GAN(G_model, D_model, 1e-4)

In [None]:
model_path = ''

# Alternative to the previous cell: the from-scratch UNet from section 2 (the class is
# `UNet`; `Unet` raised NameError). Run only one of the two cells, and use the generator
# architecture the checkpoint at `model_path` was trained with.
G_model = UNet(1, 2)
D_model = Patch_Discriminator(3)
model = Patch_GAN(G_model, D_model, 1e-4)

In [None]:
from torchvision import transforms

batch_size = 5
num_workers = 8

# Hand-picked images so every visualization run shows the same samples.
train_imgs = ['000000179025.jpg', '000000331331.jpg', '000000156323.jpg', '000000562305.jpg', '000000222908.jpg']
val_imgs = ['000000002347.jpg', '000000002429.jpg', '000000002444.jpg', '000000004125.jpg', '000000005115.jpg']


def _fixed_loader(names, base_path, purpose):
    """Build a dataset and a non-shuffled loader over ``names`` with the cell's settings."""
    dataset = COCO_Dataset(ALL_IMG_PATH=names, BASE_PATH=base_path, purpose=purpose)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=True, drop_last=True)
    return dataset, loader


train_Dataset, train_dataloader = _fixed_loader(train_imgs, TRAIN_DATASET_PATH, 'train')
val_Dataset, val_dataloader = _fixed_loader(val_imgs, VAL_DATASET_PATH, 'val')

In [None]:
if not model_path:
    raise ValueError('Set model_path to a checkpoint written by Patch_GAN.save_model '
                     'before running this cell.')

# Load the weights once, up front. Previously they were reloaded for every training batch,
# and never loaded at all if the training loader was empty.
model.load_model(model_path)

print("training set")
for vis_l, vis_ab in tqdm(train_dataloader):
    model.visualize(vis_l, vis_ab)

print("validation set")
for vis_l, vis_ab in tqdm(val_dataloader):
    model.visualize(vis_l, vis_ab)