In [1]:
import sys
import os
import matplotlib.pyplot as plt
import PIL
import cv2
import time
import skimage
import numpy as np
from sklearn.model_selection import train_test_split

from torchsummary import summary
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as vision_models

# pip install fastai==2.4
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_colab = None

print(f"device: {device}")

ModuleNotFoundError: No module named 'matplotlib'

In [2]:
home = "./input"
os.listdir("./input/ab/ab")

['ab1.npy', 'ab2.npy', 'ab3.npy']

In [3]:
def load_data(home, channels_first=True, train_percent=0.8):
    ab1 = np.load(os.path.join(home,"ab/ab", "ab1.npy"))
    ab2 = np.load(os.path.join(home, "ab/ab", "ab2.npy"))
    ab3 = np.load(os.path.join(home,"ab/ab", "ab3.npy"))
    ab = np.concatenate([ab1, ab2, ab3], axis=0)
#     ab = np.transpose(ab, [0, 3, 1, 2])
    l = np.load(os.path.join(home,"l/gray_scale.npy"))
    
    
    return train_test_split(ab,l, train_size=train_percent)
    

In [3]:
ab_train, ab_test, l_train, l_test = load_data(home, channels_first=True)

NameError: name 'load_data' is not defined

In [3]:
def plot_channels(img_batch, figsize=(8,3), cmap=None):
    if len(img_batch.shape)==3:
        img_batch = np.expand_dims(img_batch, axis=0)
    for img in img_batch:
        fig = plt.figure(figsize=figsize)
        plt.subplot(1,3,1)
        plt.imshow(img[:,:,0].T, cmap=cmap)
        plt.title(f"ab-0")
        
        plt.subplot(1,3,2)
        plt.imshow(img[:,:,1].T, cmap=cmap)
        plt.title(f"ab-1")
        
        plt.subplot(1,3,3)
        plt.imshow(img[:,:,2].T, cmap=cmap)
        plt.title(f"l")
        
        plt.show()

def plot_image(img_batch, figsize=(8,3), cmap=None, title=None):
    if len(img_batch.shape)==3:
        img_batch = np.expand_dims(img_batch, axis=0)
    N = len(img_batch)
    fig = plt.figure(figsize=figsize)
    for i in range(N):
        img = img_batch[i]
#         img = np.transpose(img, [1,0,2])
        plt.subplot(1,N,i+1)
        plt.imshow(img, cmap=cmap)
    if title is not None:
        plt.title(f"{title}")
    plt.show()
        
def to_lab(l, ab, channels_first=True):
    if channels_first:
        if len(l.shape)==3:
            l = np.expand_dims(l, axis=1)
        lab = np.concatenate([l, ab], axis=1)
    else:
        if len(l.shape)==3:
            l = np.expand_dims(l, axis=3)
        lab = np.concatenate([l, ab], axis=3)
    return lab

def lab2rgb(lab):
    if len(lab.shape)==4:
        arr = []
        for img in lab:
            img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
            arr.append(img)
        arr = np.array(arr)
    else:
        arr = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    return arr

def rgb2lab(rgb):
    if len(rgb.shape)==4:
        arr = []
        for img in rgb:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
            arr.append(img)
        arr = np.array(arr)
    else:
        arr = cv2.cvtColor(rgb, cv2.COLOR_LAB2RGB)
    return arr

def to_channel_first(arr):
    if len(arr.shape)==4:
        arr = np.transpose(arr, [0,3,2,1])
    else:
        arr = np.transpose(arr, [2,1,0])
    return arr

def to_channel_last(arr):
    if len(arr.shape)==4:
        arr = np.transpose(arr, [0,3,2,1])
    else:
        arr = np.transpose(arr, [2,1,0])
    return arr

In [4]:
class UnetBlock(nn.Module):
    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None: input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)
        
        if outermost:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout: up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)
    
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

class Unet(nn.Module):
    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)
    
    def forward(self, x):
        return self.model(x)
    


## Transform functions

In [5]:
def transform_expand_dim(axis):
    def fn(arr):
        arr = np.expand_dims(arr, axis=axis)
        return arr
    return fn

def transform_multiply(mul):
    def fn(arr):
        arr = arr * mul
        return arr
    return fn

def transform_divide(div):
    def fn(arr):
        arr = arr / div
        return arr
    return fn

def model_parameters_count(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params

## Unet Model

In [6]:
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

def build_fastai_model(in_channels=1, out_channels=2, image_shape=(224, 224)):
    model_body = create_body(resnet18(), n_in=in_channels, cut=-2)
    model = DynamicUnet(encoder=model_body, n_out=out_channels, img_size=image_shape)
    return model.to(device)


## dataset

In [7]:
class DatasetImg(Dataset):
    def __init__(self, l, ab, input_transforms=[], output_transforms=[]):
        self.l = l
        self.ab = ab
        self.input_transforms = input_transforms
        self.output_transforms = output_transforms
    
    def __len__(self):
        return len(self.l)
    
    def __getitem__(self, idx):
        x = self.l[idx]
        y = self.ab[idx]
        
        if self.input_transforms is not None:
            for fn in self.input_transforms:
                x = fn(x)
        
        if self.output_transforms is not None:
            for fn in self.output_transforms:
                y = fn(y)
        return x,y

## Utility

In [8]:


def to_channel_first(arr):
    if len(arr.shape)==4:
        arr = np.transpose(arr, [0,3,2,1])
    else:
        arr = np.transpose(arr, [2,1,0])
    return arr

def transform_expand_dim(axis):
    def fn(arr):
        arr = np.expand_dims(arr, axis=axis)
        return arr
    return fn

def transform_divide(div):
    def fn(arr):
        arr = arr / div
        return arr
    return fn


# ab_train.shape, l_train.shape, ab_test.shape, l_test.shape

input_shape = [224, 224]
batch_size = 1
num_examples = -1
device=  "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")
epochs = 1
plot_freq = 1


class DatasetImg(Dataset):
    def __init__(self, l, ab, input_transforms=[], output_transforms=[]):
        
        self.l = l
        self.ab = ab
        self.input_transforms = input_transforms
        self.output_transforms = output_transforms
    
    def __len__(self):
        return len(self.l)
    
    def __getitem__(self, idx):
        l = self.l[idx]
        ab = self.ab[idx]
        
        if self.input_transforms is not None:
            for fn in self.input_transforms:
                l = fn(l)
        
        if self.output_transforms is not None:
            for fn in self.output_transforms:
                ab = fn(ab)
        return {'L': l, 'ab': ab}
        return l, ab
    
def load_data(home, channels_first=True, train_percent=0.8):
    ab1 = np.load(os.path.join(home,"ab/ab", "ab1.npy"))
    ab2 = np.load(os.path.join(home, "ab/ab", "ab2.npy"))
    ab3 = np.load(os.path.join(home,"ab/ab", "ab3.npy"))
    ab = np.concatenate([ab1, ab2, ab3], axis=0).astype("float32")
    # ab = np.transpose(ab, [0, 3, 1, 2])
    l = np.load(os.path.join(home,"l/gray_scale.npy")).astype("float32")


    return train_test_split(ab,l, train_size=train_percent)

ab_train, ab_test, l_train, l_test = load_data("input", channels_first=True)
    
input_transforms = [transform_expand_dim(axis=2),
                    to_channel_first,
                   transform_divide(255.0)
                   ]
output_transforms = [
                    to_channel_first,
                   transform_divide(255.0)
                   ]

ds_train = DatasetImg(l_train[:num_examples], ab_train[:num_examples], input_transforms=input_transforms, output_transforms=output_transforms)
train_dl = DataLoader(ds_train, batch_size=batch_size, shuffle=True)

ds_test = DatasetImg(l_test[:num_examples], ab_test[:num_examples], input_transforms=input_transforms, output_transforms=output_transforms)
test_dl = DataLoader(ds_test, batch_size=batch_size, shuffle=True)


device: cuda


In [10]:
class AverageMeter:
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3
    
    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    loss_D_fake = AverageMeter()
    loss_D_real = AverageMeter()
    loss_D = AverageMeter()
    loss_G_GAN = AverageMeter()
    loss_G_L1 = AverageMeter()
    loss_G = AverageMeter()
    
    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)

def lab_to_rgb(L, ab):
    """
    Takes a batch of images
    """
    
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)
    
def visualize(model, data, save=True):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    fig = plt.figure(figsize=(15, 8))
    for i in range(5):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
        
def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f"{loss_name}: {loss_meter.avg:.5f}")

In [11]:
def pretrain_generator(net_G, train_dl, opt, criterion, epochs):
    for e in range(epochs):
        loss_meter = AverageMeter()
        for data in tqdm(train_dl):
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)
            loss = criterion(preds, ab)
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            loss_meter.update(loss.item(), L.size(0))
            
        print(f"Epoch {e + 1}/{epochs}")
        print(f"L1 Loss: {loss_meter.avg:.5f}")

net_G = build_fastai_model(in_channels=1, out_channels=2, image_shape=(224,224))
opt = torch.optim.Adam(net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()        
pretrain_generator(net_G, train_dl, opt, criterion, 20)
#torch.save(net_G.state_dict(), "res18-unet.pt")

100%|██████████| 19999/19999 [35:31<00:00,  9.38it/s] 


Epoch 1/20
L1 Loss: 0.04467


100%|██████████| 19999/19999 [34:50<00:00,  9.57it/s]


Epoch 2/20
L1 Loss: 0.04125


 76%|███████▌  | 15225/19999 [26:24<08:18,  9.58it/s]