In [1]:
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tarfile
import torch
import torch.nn as nn
import torch.nn.functional as F

from matplotlib.image import imread
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
from skimage import io
from skimage.color import rgb2lab, lab2rgb, rgb2gray
from sklearn.metrics import r2_score
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate

In [2]:
torch.backends.cudnn.benchmark = True

## Data

In [None]:
def unpack_dataset():
    ! mkdir data
    
    # train dataset
    # ! wget http://data.csail.mit.edu/places/places365/train_256_places365standard.tar
    # ! tar -xvf train_256_places365standard.tar -C data
    
    # validation dataset
    ! wget http://data.csail.mit.edu/places/places365/val_256.tar
    ! tar -xvf val_256.tar -C data
    
    # test dataset
    # ! wget http://data.csail.mit.edu/places/places365/test_256.tar
    # ! tar -xvf test_256.tar -C data

In [None]:
unpack_dataset()

In [3]:
PATH = Path("data/")
list(PATH.iterdir())

[PosixPath('data/val_256')]

## Dataset

In [4]:
class ImageDataset(Dataset):
    def __init__(self, files):
        self.files = np.array(files)
        self.length = len(files)
    
    def __getitem__(self, idx):
        img = imread(self.files[idx])
        if img.shape == (256, 256, 3):  # if a color image
            img_lab = rgb2lab(img)
            img_lab = (img_lab + [0, 128, 128]) / [100, 255, 255]  # normalize L, a, b dimensions
            img_lightness = img_lab[:, :, 0:1].transpose(2, 0, 1)
            img_ab = img_lab[:, :, 1:3].transpose(2, 0, 1)
        else:  # if a grayscale image
            img_lightness = (img/255)[None, :, :]
            img_ab = np.zeros(shape=(2, 256, 256))
        return img_lightness, img_ab
    
    def __len__(self):
        return self.length

In [5]:
val_files = glob.glob("data/val_256/*.jpg")
len(val_files)

36500

In [6]:
train, val = val_files[:29200], val_files[29200:]

In [7]:
train_ds = ImageDataset(train)
val_ds = ImageDataset(val)

In [8]:
batch_size = 30
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=10)
val_dl = DataLoader(val_ds, batch_size=batch_size, num_workers=10)

## Model

In [9]:
def basic_block(in_, out_, kernel_size=3, stride=1):
    """Return a block consisting of a conv2d, ReLU and BatchNorm2d layer."""
    padding = kernel_size // 2
    block = nn.Sequential(
        nn.Conv2d(in_, out_, kernel_size, stride, padding),
        nn.ReLU(),
        nn.BatchNorm2d(out_))
    return block

In [10]:
class ColorizationNet(nn.Module):
    def __init__(self):
        super(ColorizationNet, self).__init__()
        self.layers = nn.Sequential(
            basic_block(1, 64, kernel_size=3, stride=2),
            basic_block(64, 128, kernel_size=3, stride=2),
            basic_block(128, 256, kernel_size=3, stride=2),
            basic_block(256, 512, kernel_size=3, stride=1),
            basic_block(512, 256, kernel_size=3, stride=1))
        self.upsample = nn.Upsample(scale_factor=8)
        self.out_layer = basic_block(256, 2, kernel_size=3, stride=1)
    
    def forward(self, x):
        x = self.layers(x)
        x = self.upsample(x)
        return self.out_layer(x)

## Training

In [11]:
def save_model(model, path): torch.save(model.state_dict(), path)

def load_model(model, path): model.load_state_dict(torch.load(path))

In [12]:
def train_epoch(model, train_dl, val_dl, optimizer, epochs=10):
    iterations = len(train_dl) * epochs
    pbar = tqdm(total=iterations)
    best_val_loss = float("inf")
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total = 0
        for x, y in train_dl:
            x = x.float().cuda()
            y = y.float().cuda()
            y_hat = model(x)
            loss = F.mse_loss(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * y.shape[0]
            total += y.shape[0]
            pbar.update()
        val_loss = val_metrics(model, val_dl)
        print(f"train loss: {total_loss/total:.8f}\tval loss: {val_loss:.8f}")
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            path = f"models/model_{best_val_loss:.6f}"
            save_model(model, path)
            print(path)
    
    return best_val_loss

In [13]:
def val_metrics(model, val_dl):
    model.eval()
    total_loss = 0
    total = 0
    y_pred = []
    y_true = []
    
    for x, y in val_dl:
        x = x.float().cuda()
        y = y.float()
        out = model(x)
        loss = F.mse_loss(out, y.cuda())
        total_loss += loss.item() * y.shape[0]
        total += y.shape[0]
        y_pred.append(out.cpu().detach().numpy())
        y_true.append(y)
    
    y_pred = np.vstack(y_pred)
    y_true = np.vstack(y_true)
    return total_loss/total

In [14]:
model = ColorizationNet().cuda()

In [15]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
best_val = train_epoch(model, train_dl, val_dl, optimizer, epochs=2)

HBox(children=(FloatProgress(value=0.0, max=1948.0), HTML(value='')))

train loss: 0.04738433	val loss: 0.00348357
models/model_0.003484
train loss: 0.00343976	val loss: 0.00339602
models/model_0.003396
