In [None]:
# Pin the fastai version this notebook was run with (see the install log); later fastai
# releases changed create_body()'s API, which build_generator relies on — verify before bumping.
%pip install -q fastai==2.3.0

Collecting fastai
[?25l  Downloading https://files.pythonhosted.org/packages/5b/53/edf39e15b7ec5e805a0b6f72adbe48497ebcfa009a245eca7044ae9ee1c6/fastai-2.3.0-py3-none-any.whl (193kB)
[K     |█▊                              | 10kB 20.2MB/s eta 0:00:01[K     |███▍                            | 20kB 17.8MB/s eta 0:00:01[K     |█████                           | 30kB 14.7MB/s eta 0:00:01[K     |██████▊                         | 40kB 13.7MB/s eta 0:00:01[K     |████████▌                       | 51kB 8.9MB/s eta 0:00:01[K     |██████████▏                     | 61kB 10.3MB/s eta 0:00:01[K     |███████████▉                    | 71kB 9.7MB/s eta 0:00:01[K     |█████████████▌                  | 81kB 9.9MB/s eta 0:00:01[K     |███████████████▏                | 92kB 10.8MB/s eta 0:00:01[K     |█████████████████               | 102kB 8.6MB/s eta 0:00:01[K     |██████████████████▋             | 112kB 8.6MB/s eta 0:00:01[K     |████████████████████▎           | 122kB 8.6MB/s eta

In [None]:
import os
import glob
import time
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.models.resnet import resnet18
from torch.utils.data import Dataset, DataLoader

from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from fastai.data.external import untar_data, URLs

In [None]:
# Side length (pixels) every image is resized to; also the input size the U-Net is built for.
SIZE: int = 256

In [None]:
class AverageMeter:
  """Running, count-weighted average of a scalar such as a per-batch loss."""

  def __init__(self):
    self.reset()

  def reset(self):
    """Zero the accumulated sample count, sum and average."""
    self.count, self.avg, self.sum = 0., 0., 0.

  def update(self, val, count=1):
    """Accumulate `val` observed over `count` samples and refresh `avg`.

    Args:
      val: mean value over the `count` samples (e.g. a batch-mean loss).
      count: number of samples `val` averages over; defaults to 1.
    """
    self.count += count
    self.sum += val * count
    # Guard the division: a zero total count (e.g. a first update with count=0)
    # would otherwise raise ZeroDivisionError.
    self.avg = self.sum / self.count if self.count else 0.

def init_loss_meters():
  """Return a dict mapping each tracked GAN loss name to a fresh AverageMeter.

  The keys are the discriminator losses (fake, real, combined) followed by the
  generator losses (adversarial, L1, combined), in that insertion order.
  """
  tracked = (
      "loss_D_fake", "loss_D_real", "loss_D",
      "loss_G_GAN", "loss_G_L1", "loss_G",
  )
  return {name: AverageMeter() for name in tracked}

def update_losses(model, loss_meters, count):
  """Feed each meter the current value of the same-named loss attribute on `model`.

  Args:
    model: object exposing one attribute per key of `loss_meters`, each with `.item()`.
    loss_meters: dict of name -> meter with an `update(val, count)` method.
    count: sample count passed through to every meter update.
  """
  for name in loss_meters:
    current = getattr(model, name).item()
    loss_meters[name].update(current, count)

In [None]:
def build_generator(n_inputs, n_outputs, core, img_size=None, pretrained=True):
  """Build a fastai DynamicUnet generator on top of a backbone body.

  Args:
    n_inputs: number of input channels the backbone body accepts (e.g. 1 for L).
    n_outputs: number of channels the U-Net predicts (e.g. 2 for ab).
    core: backbone architecture passed to fastai's create_body (e.g. resnet18).
    img_size: (height, width) the U-Net is built for; defaults to (SIZE, SIZE).
    pretrained: forwarded to create_body to request pretrained backbone weights.

  Returns:
    The DynamicUnet module.
  """
  if img_size is None:
    # Resolved at call time, matching the original behaviour of reading SIZE per call.
    img_size = (SIZE, SIZE)
  # cut=-2 trims the backbone's last two children — presumably pooling + classifier
  # head for ResNets; TODO confirm for other cores.
  body = create_body(core, pretrained=pretrained, n_in=n_inputs, cut=-2)

  return DynamicUnet(body, n_outputs, img_size)

In [None]:
def train_generator(device, G, train_dl, val_dl, opt, criterion, epochs, save_dir="./gen_models"):
  """Train G with a pixel-wise loss, validating and checkpointing after every epoch.

  Args:
    device: torch device each batch is moved to (G is assumed to already be on it).
    G: generator called on the "L" batch; its output is compared against "ab".
    train_dl: iterable of dicts with "L" and "ab" tensors used for optimisation.
    val_dl: iterable of dicts with "L" and "ab" tensors used only for evaluation.
    opt: optimizer over G's parameters.
    criterion: callable loss(preds, ab) returning a scalar tensor.
    epochs: number of passes over train_dl.
    save_dir: directory for per-epoch state_dict checkpoints; created if missing.
  """
  # torch.save does not create parent directories, so make sure the target exists.
  os.makedirs(save_dir, exist_ok=True)

  for e in range(epochs):
    train_loss_meter = AverageMeter()
    val_loss_meter = AverageMeter()

    G.train()
    for data in tqdm(train_dl):
      L, ab = data["L"].to(device), data["ab"].to(device)

      preds = G(L)
      loss = criterion(preds, ab)

      opt.zero_grad()
      loss.backward()
      opt.step()

      train_loss_meter.update(loss.item(), L.size(0))

    G.eval()
    # Validation needs no gradients; without no_grad every batch would build
    # (and keep alive) an autograd graph, wasting memory and time.
    with torch.no_grad():
      for data in tqdm(val_dl):
        L, ab = data["L"].to(device), data["ab"].to(device)

        preds = G(L)
        loss = criterion(preds, ab)

        val_loss_meter.update(loss.item(), L.size(0))

    print(f"Epoch {e + 1}/{epochs}")
    print(f"L1 --- Trn_loss: {train_loss_meter.avg:.4f} --- Val_loss: {val_loss_meter.avg:.4f}")
    # Checkpoint name: <0-based epoch>_<unix timestamp>_res18-unet.pt
    torch.save(G.state_dict(), os.path.join(save_dir, f"{e}_{time.time()}_res18-unet.pt"))


In [None]:
class _LabColorizationDataset(Dataset):
  """Shared base: loads an RGB image, transforms it, and returns normalised Lab tensors.

  Each item is a dict with:
    "L":  lightness channel scaled by /50 - 1 (Lab L in [0, 100] -> [-1, 1]).
    "ab": colour channels scaled by /110.
  """

  def __init__(self, paths, transform_steps):
    self.transforms = transforms.Compose(transform_steps)
    self.to_tensor = transforms.ToTensor()
    self.paths = paths

  def __getitem__(self, idx):
    # Close the underlying file handle once the image has been decoded.
    with Image.open(self.paths[idx]) as raw:
      img = raw.convert("RGB")
    img = self.transforms(img)
    img = np.array(img)

    lab_img = rgb2lab(img).astype("float32")
    # ToTensor on a float array only reorders HWC -> CHW; it does not rescale.
    lab_img = self.to_tensor(lab_img)

    L = lab_img[[0], ...] / 50. - 1.
    ab = lab_img[[1, 2], ...] / 110.

    return {"L": L, "ab": ab}

  def __len__(self):
    return len(self.paths)


class TrainingDataset(_LabColorizationDataset):
  """Training split: resize to SIZE x SIZE plus random horizontal flips for augmentation."""

  def __init__(self, paths):
    super().__init__(paths, [
        transforms.Resize((SIZE, SIZE)),
        transforms.RandomHorizontalFlip(),
    ])


class ValidationDataset(_LabColorizationDataset):
  """Validation split: deterministic resize to SIZE x SIZE only (no augmentation)."""

  def __init__(self, paths):
    super().__init__(paths, [
        transforms.Resize((SIZE, SIZE)),
    ])

In [None]:
# Download (or reuse the cached copy of) fastai's COCO sample and build a seeded
# train/validation split over a random subset of its images.
root = os.path.join(str(untar_data(URLs.COCO_SAMPLE)), "train_sample")

# glob returns files in filesystem order; sort so the seeded sample below is reproducible.
paths = sorted(glob.glob(os.path.join(root, "*.jpg")))

N_SUBSET, N_TRAIN = 12_000, 10_000
assert len(paths) >= N_SUBSET, f"expected at least {N_SUBSET} images in {root}, found {len(paths)}"

np.random.seed(42)
paths_subset = np.random.choice(paths, N_SUBSET, replace=False)

rand_idxs = np.random.permutation(N_SUBSET)
train_idxs = rand_idxs[:N_TRAIN]
val_idxs = rand_idxs[N_TRAIN:]

train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]

train_dset = TrainingDataset(train_paths)
val_dset = ValidationDataset(val_paths)

# Reshuffle training samples every epoch (DataLoader defaults to shuffle=False);
# validation order does not affect the averaged loss.
train_dl = DataLoader(train_dset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)
val_dl = DataLoader(val_dset, batch_size=8, num_workers=2, pin_memory=True)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Generator: 1 input channel (L) -> 2 output channels (ab), ResNet-18 encoder, moved to `device`.
G = build_generator(n_inputs=1, n_outputs=2, core=resnet18).to(device)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [None]:
# Pixel-wise L1 loss on the predicted ab channels, optimised with Adam over all of G's weights.
l1_loss = nn.L1Loss()
l1_opt = optim.Adam(params=G.parameters(), lr=1e-4)

In [None]:
# L1-only training of the generator for 30 epochs; a checkpoint is written after each one.
train_generator(device, G, train_dl, val_dl, l1_opt, l1_loss, epochs=30)

HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 1/30
L1 --- Trn_loss: 0.0870 --- Val_loss: 0.0803


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 2/30
L1 --- Trn_loss: 0.0794 --- Val_loss: 0.0788


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 3/30
L1 --- Trn_loss: 0.0766 --- Val_loss: 0.0798


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 4/30
L1 --- Trn_loss: 0.0741 --- Val_loss: 0.0787


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 5/30
L1 --- Trn_loss: 0.0715 --- Val_loss: 0.0797


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 6/30
L1 --- Trn_loss: 0.0695 --- Val_loss: 0.0790


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 7/30
L1 --- Trn_loss: 0.0676 --- Val_loss: 0.0801


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 8/30
L1 --- Trn_loss: 0.0662 --- Val_loss: 0.0791


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 9/30
L1 --- Trn_loss: 0.0647 --- Val_loss: 0.0791


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 10/30
L1 --- Trn_loss: 0.0637 --- Val_loss: 0.0791


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 11/30
L1 --- Trn_loss: 0.0629 --- Val_loss: 0.0817


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 12/30
L1 --- Trn_loss: 0.0622 --- Val_loss: 0.0819


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 13/30
L1 --- Trn_loss: 0.0614 --- Val_loss: 0.0799


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 14/30
L1 --- Trn_loss: 0.0602 --- Val_loss: 0.0790


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 15/30
L1 --- Trn_loss: 0.0588 --- Val_loss: 0.0802


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 16/30
L1 --- Trn_loss: 0.0579 --- Val_loss: 0.0793


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 17/30
L1 --- Trn_loss: 0.0574 --- Val_loss: 0.0821


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 18/30
L1 --- Trn_loss: 0.0567 --- Val_loss: 0.0785


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 19/30
L1 --- Trn_loss: 0.0556 --- Val_loss: 0.0777


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 20/30
L1 --- Trn_loss: 0.0545 --- Val_loss: 0.0778


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 21/30
L1 --- Trn_loss: 0.0534 --- Val_loss: 0.0780


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 22/30
L1 --- Trn_loss: 0.0527 --- Val_loss: 0.0781


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 23/30
L1 --- Trn_loss: 0.0521 --- Val_loss: 0.0778


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 24/30
L1 --- Trn_loss: 0.0515 --- Val_loss: 0.0780


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 25/30
L1 --- Trn_loss: 0.0509 --- Val_loss: 0.0778


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 26/30
L1 --- Trn_loss: 0.0501 --- Val_loss: 0.0780


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 27/30
L1 --- Trn_loss: 0.0495 --- Val_loss: 0.0785


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 28/30
L1 --- Trn_loss: 0.0489 --- Val_loss: 0.0790


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 29/30
L1 --- Trn_loss: 0.0484 --- Val_loss: 0.0787


HBox(children=(FloatProgress(value=0.0, max=625.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))


Epoch 30/30
L1 --- Trn_loss: 0.0476 --- Val_loss: 0.0786


In [None]:
# Bundle every saved checkpoint into ./gen_models.zip. Pure Python, so it does not
# depend on a `zip` binary being installed; any existing archive is replaced.
import shutil

shutil.make_archive("gen_models", "zip", root_dir=".", base_dir="gen_models")

  adding: gen_models/ (stored 0%)
  adding: gen_models/17_1619678546.44514_res18-unet.pt (deflated 7%)
  adding: gen_models/4_1619672929.0411165_res18-unet.pt (deflated 7%)
  adding: gen_models/28_1619683290.488112_res18-unet.pt (deflated 7%)
  adding: gen_models/22_1619680706.0349762_res18-unet.pt (deflated 7%)
  adding: gen_models/16_1619678115.397143_res18-unet.pt (deflated 7%)
  adding: gen_models/24_1619681568.6929681_res18-unet.pt (deflated 7%)
  adding: gen_models/0_1619671187.9509108_res18-unet.pt (deflated 7%)
  adding: gen_models/21_1619680274.38992_res18-unet.pt (deflated 7%)
  adding: gen_models/7_1619674230.4844167_res18-unet.pt (deflated 7%)
  adding: gen_models/20_1619679842.199287_res18-unet.pt (deflated 7%)
  adding: gen_models/11_1619675958.1446617_res18-unet.pt (deflated 7%)
  adding: gen_models/19_1619679410.9194286_res18-unet.pt (deflated 7%)
  adding: gen_models/23_1619681137.5498242_res18-unet.pt (deflated 7%)
  adding: gen_models/9_1619675094.6967552_res18-unet.

In [None]:
# Mount Google Drive at ./drive (remounting if already mounted).
# NOTE(review): nothing in this view copies gen_models.zip into Drive — confirm a later cell does.
from google.colab import drive
drive.mount(mountpoint='./drive', force_remount=True)

Mounted at ./drive
