In [3]:
# Show which GPU (if any) this runtime was assigned, plus driver / CUDA versions.
!nvidia-smi

Thu Jan 28 13:46:51 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
!pip uninstall fastai -y
!pip install fastai -q

import fastai
fastai.__version__

Uninstalling fastai-1.0.61:
  Successfully uninstalled fastai-1.0.61
[K     |████████████████████████████████| 194kB 8.8MB/s 
[K     |████████████████████████████████| 61kB 7.5MB/s 
[?25h

'2.2.5'

In [5]:
# Mount Google Drive: checkpoints are read from and copied to a folder under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
# @title Download dataset
import os
import shutil
import time
from google_drive_downloader import GoogleDriveDownloader as gdd

if not os.path.exists("/content/images/valid.zip"):
    # https://drive.google.com/file/d/1bXD5hR5WlIB6LsqDHmNn29EMA43Z82f3/view?usp=sharing
    gdd.download_file_from_google_drive(file_id='1bXD5hR5WlIB6LsqDHmNn29EMA43Z82f3',
                                        dest_path='./images/valid.zip',
                                        unzip=True,
                                        showsize=True,
                                        )
    # !rm -rf /content/images/valid.zip

# download scripts

!rm -rf *.py
!rm -rf *.zip
time.sleep(1)
!wget -q https://github.com/veb-101/GAN-Colorisation/archive/master.zip
!unzip -qq -o /content/master.zip 
!mv ./GAN-Colorisation-master/trainer.py ./trainer.py
!mv ./GAN-Colorisation-master/utils.py ./utils.py
!mv ./GAN-Colorisation-master/model.py ./model.py
!mv ./GAN-Colorisation-master/data_loader.py ./data_loader.py
!rm -rf /content/GAN-Colorisation-master


if not os.path.exists("/content/images/train.zip"):
    # https://drive.google.com/file/d/1c5WQwglbVL9S_LHH5E9XU9bCWLzZg_V7/view?usp=sharing
    gdd.download_file_from_google_drive(file_id='1c5WQwglbVL9S_LHH5E9XU9bCWLzZg_V7',
                                        dest_path='./images/train.zip',
                                        unzip=True,
                                        showsize=True,
                                        )
    # !rm -rf /content/images/train.zip

Downloading 1bXD5hR5WlIB6LsqDHmNn29EMA43Z82f3 into ./images/valid.zip... 
9.3 MiB Done.
Unzipping...Done.
Downloading 1c5WQwglbVL9S_LHH5E9XU9bCWLzZg_V7 into ./images/train.zip... 
495.6 MiB Done.
Unzipping...Done.


In [7]:
#@title setup

import os
import gc
import random
import shutil
import warnings
import importlib
import numpy as np
# from tqdm. import tqdm
from tqdm.notebook import tqdm



import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

import model
import utils
import trainer
import data_loader

# Seed the Python, NumPy and PyTorch random number generators with one fixed value.
seed = 41
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Reload the helper modules before the `from ... import ...` lines that follow,
# so the imported names come from the files currently on disk — presumably
# needed when this cell is re-run after the download cell re-fetched the scripts.
importlib.reload(data_loader)
importlib.reload(utils)
importlib.reload(trainer)
importlib.reload(model)

from trainer import MainModel
from utils import init_model
from data_loader import make_dataloaders
from utils import update_losses, visualize
from utils import init_model, AverageMeter
from utils import create_loss_meters, log_results
from model import Generator_Res_Unet, Discriminator



def get_default_device():
    """Return the CUDA device when CUDA is usable, otherwise the CPU device."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)

# Notebook-level device; the training code in later cells reads this global.
device = get_default_device()

# Silence warnings of category UserWarning.
warnings.simplefilter("ignore", UserWarning)

In [8]:
def _restore_optional(target, checkpoint, key, success_msg):
    """Best-effort ``target.load_state_dict(checkpoint[key])``.

    A missing key or a failing load is reported on stdout instead of being
    silently ignored, and execution continues (e.g. warm-up checkpoints carry
    no discriminator entries).

    Returns:
        True when the state was restored, False otherwise.
    """
    if key not in checkpoint:
        print(f"[!] '{key}' not found in checkpoint - skipped")
        return False
    try:
        target.load_state_dict(checkpoint[key])
    except Exception as err:  # keep best-effort semantics, but say what happened
        print(f"[!] Could not restore '{key}': {err!r}")
        return False
    print(success_msg)
    return True


def load_model(config, ckpt_number=-1, device=None, prefix="warm_"):
    """Create the networks, optimizers and grad scalers, optionally restoring a checkpoint.

    Args:
        config: dict providing "lr_G", "lr_D" (Adam learning rates), "beta1",
            "beta2" (Adam betas) and "load_previous" (whether optimizer and
            grad-scaler state should also be restored from the checkpoint).
        ckpt_number: Number used in the checkpoint file name
            ``{prefix}checkpoint_{ckpt_number}.tar``.
        device: Passed to ``init_model`` and used as ``map_location`` when the
            checkpoint is read; ``None`` keeps ``torch.load``'s default.
        prefix: File-name prefix selecting the checkpoint family
            ("warm_" or "gan_").

    Returns:
        dict with keys "epoch", "net_G", "opt_G", "scaler_G", "net_D",
        "scaler_D", "opt_D". "epoch" is 0 when no checkpoint file was found.

    Raises:
        KeyError: if ``config`` lacks a required key, or the checkpoint lacks
            "epoch", "net_G" or (with ``load_previous``) "opt_G".
    """
    lr_G = config["lr_G"]
    lr_D = config["lr_D"]
    beta1 = config["beta1"]
    beta2 = config["beta2"]
    load_previous = config["load_previous"]

    epoch = 0
    gen_model = Generator_Res_Unet()
    net_G = init_model(gen_model.get_model(), device)

    net_D = init_model(Discriminator(input_channels=3), device)
    scaler_G = GradScaler()
    scaler_D = GradScaler()

    opt_G = optim.Adam(net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
    opt_D = optim.Adam(net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    drive_path = r"/content/drive/MyDrive/Project-Colorisation"
    print(f"[*] Finding checkpoint {ckpt_number} in {drive_path}")

    checkpoint_file = f"{prefix}checkpoint_{ckpt_number}.tar"
    checkpoint_path = os.path.join(drive_path, checkpoint_file)

    if not os.path.exists(checkpoint_path):
        print(f"[!] No checkpoint for epoch {ckpt_number}")
    else:
        # map_location lets a checkpoint written on one device be read on another
        # (e.g. saved on GPU, resumed on a CPU-only runtime).
        checkpoint = torch.load(checkpoint_path, map_location=device)

        epoch = checkpoint["epoch"]

        # The generator is mandatory: a failure here should surface.
        net_G.load_state_dict(checkpoint["net_G"])
        print("Generator weight loaded")

        # Discriminator weights are restored whenever the checkpoint has them,
        # independently of `load_previous`.
        _restore_optional(net_D, checkpoint, "net_D", "Discriminator weights loaded.")

        if load_previous:
            _restore_optional(opt_D, checkpoint, "opt_D", "Discriminator optimizer loaded.")

            opt_G.load_state_dict(checkpoint["opt_G"])
            print("Optimizer's state loaded")

            # Each scaler is restored independently of the other.
            _restore_optional(scaler_G, checkpoint, "scaler_G", "Grad Scaler - Generator loaded")
            _restore_optional(scaler_D, checkpoint, "scaler_D", "Grad Scaler - Discriminator loaded")

    return {
        "epoch": epoch,
        "net_G": net_G,
        "opt_G": opt_G,
        "scaler_G": scaler_G,
        "net_D": net_D,
        "scaler_D": scaler_D,
        "opt_D": opt_D,
    }

In [14]:
def pretrain_generator(train_dl, valid_dl, net, scaler, opt, criterion, start_epoch, epochs):
    """Train the generator on its own with a supervised loss under mixed precision.

    After every epoch the validation batches are rendered with ``visualize``
    and a ``warm_checkpoint_{epoch}.tar`` file is written to the working
    directory and copied to the Drive project folder.

    Args:
        train_dl: Iterable of batches; each batch is indexed with 'L' (network
            input) and 'ab' (target).
        valid_dl: Iterable of validation batches passed to ``visualize``.
        net: Generator network being trained.
        scaler: ``GradScaler`` used for loss scaling; its state is checkpointed.
        opt: Optimizer for ``net``.
        criterion: Loss called as ``criterion(preds, ab)``.
        start_epoch: Number of the first epoch to run.
        epochs: How many epochs to run.

    Note:
        Batches are moved to the notebook-level ``device`` global.
    """
    savedir = "/content/warmup_train"
    os.makedirs(savedir, exist_ok=True)
    drive_path = "/content/drive/MyDrive/Project-Colorisation"
    os.makedirs(drive_path, exist_ok=True)

    last_epoch = start_epoch + epochs - 1
    length_train = len(train_dl)
    # Progress is printed on the first, middle and last iteration of each epoch.
    log_steps = {0, int(length_train / 2), length_train - 1}

    for e in range(start_epoch, last_epoch + 1):
        loss_meter = AverageMeter()

        for i, data in enumerate(tqdm(train_dl, desc=f"Epoch: [{e}/{last_epoch}]")):
            opt.zero_grad()
            L, ab = data['L'].to(device), data['ab'].to(device)

            with autocast():
                preds = net(L)
                loss = criterion(preds, ab)

            scaler.scale(loss).backward()
            scaler.step(opt)
            # Update the scaler that was passed in; the previous code reached for
            # the notebook global `scaler_G`, which only worked when the caller
            # happened to pass that very object.
            scaler.update()

            loss_meter.update(loss.item(), L.size(0))
            # Release cached GPU memory and collect garbage after every step.
            torch.cuda.empty_cache()
            gc.collect()

            if i in log_steps:
                print(f"Iteration:[{i+1}/{length_train}] L1 Loss: [{loss_meter.avg:.5f}]")

        for idx, valid_data in enumerate(valid_dl):
            visualize(net, valid_data, save_name=os.path.join(savedir, f"Validation_Epoch_{e}-{idx}.png"))

        # Save locally first, then copy to Drive.
        file_name = f"warm_checkpoint_{e}.tar"
        torch.save({
            "epoch": e,
            "net_G": net.state_dict(),
            "opt_G": opt.state_dict(),
            "scaler_G": scaler.state_dict(),
        }, file_name)
        shutil.copyfile(file_name, os.path.join(drive_path, file_name))

In [10]:
def train_model(model, train_dl, valid_dl, start_epoch, epochs):
    """Run the GAN training loop and checkpoint after every epoch.

    For each epoch: feed every training batch to ``model.setup_input`` /
    ``model.optimize``, accumulate the losses, print them on the first, middle
    and last iteration, render every validation batch with ``visualize``, then
    write ``gan_checkpoint_{epoch}.tar`` locally and copy it to Drive.

    Args:
        model: Object exposing ``setup_input``, ``optimize`` and the attributes
            ``net_G``, ``opt_G``, ``scaler_G``, ``net_D``, ``opt_D``, ``scaler_D``.
        train_dl: Training batches; each batch is indexed with "L".
        valid_dl: Validation batches used for the visualizations.
        start_epoch: Number of the first epoch to run.
        epochs: How many epochs to run.
    """
    # Output folders: validation renders locally, checkpoints mirrored to Drive.
    savedir = "/content/gan_train"
    drive_path = "/content/drive/MyDrive/Project-Colorisation"
    for folder in (savedir, drive_path):
        os.makedirs(folder, exist_ok=True)

    final_epoch = start_epoch + epochs - 1
    for e in range(start_epoch, start_epoch + epochs):
        loss_meter_dict = create_loss_meters()
        batches = iter(train_dl)
        n_batches = len(batches)
        # Iterations at which the running losses are printed.
        report_at = (0, int(n_batches / 2), n_batches - 1)

        for step in tqdm(range(n_batches), desc=f"Epoch: [{e}/{final_epoch}]"):
            batch = next(batches)
            model.setup_input(batch)
            model.optimize()

            # Fold this batch's losses into the running meters.
            update_losses(model, loss_meter_dict, count=batch["L"].size(0))

            if step in report_at:
                print(f"Iteration:[{step+1}/{n_batches}]", end=" ")
                # Print the accumulated losses.
                log_results(loss_meter_dict)

        # Visualize the model output on every validation batch after the epoch.
        for idx, valid_data in enumerate(valid_dl):
            render_path = os.path.join(savedir, f"Validation_Epoch_{e}-{idx}.png")
            visualize(model, valid_data, save_name=render_path)

        file_name = f"gan_checkpoint_{e}.tar"
        state = {
            "epoch": e,
            "net_G": model.net_G.state_dict(),
            "opt_G": model.opt_G.state_dict(),
            "scaler_G": model.scaler_G.state_dict(),
            "net_D": model.net_D.state_dict(),
            "opt_D": model.opt_D.state_dict(),
            "scaler_D": model.scaler_D.state_dict(),
        }
        torch.save(state, file_name)

        shutil.copyfile(file_name, os.path.join(drive_path, file_name))

In [11]:
# Folders expected to hold the unzipped training / validation images.
train_ = "/content/images/train"
valid_ = "/content/images/valid"

# Same call as in the setup cell; repeated so this cell also works on its own.
device = get_default_device()

# num_images=-1 presumably means "use every image" — confirm in data_loader.py.
train_dl = make_dataloaders(path=train_, num_images=-1)
# Validation loader: built with is_training=False and without shuffling.
valid_dl = make_dataloaders(path=valid_, num_images=-1, is_training=False, shuffle=False)
    
print(f"Number of batches ::Train:: {len(train_dl)}, ::Valid:: {len(valid_dl)}")

Number of batches ::Train:: 109, ::Valid:: 2


In [12]:
# Hyper-parameters consumed by load_model; later cells mutate this dict in place.
config = {
    "lr_G": 1e-4,           # Adam learning rate for the generator
    "lr_D": 2e-4,           # Adam learning rate for the discriminator
    "beta1": 0.5,           # first Adam beta (both optimizers)
    "beta2": 0.999,         # second Adam beta (both optimizers)
    "lambda_L1": 100.0,     # NOTE(review): not read anywhere in this notebook — presumably an L1 loss weight; confirm
    "load_previous": True   # also restore optimizer / grad-scaler state from a checkpoint
}

In [13]:
#@title Warmup generator

# load previous epoch optimizers and grad scaler
config["load_previous"] = False # First run
# config["load_previous"] = True # uncomment for subsequent runs of same Type

# Number of the warm-up checkpoint to resume from. In the recorded run no
# checkpoint exists for -1, so load_model reports epoch 0 and training starts
# from freshly initialised weights.
load_ckpt = -1
gen_load = load_model(config, ckpt_number=load_ckpt, device=device, prefix="warm_")

# Only the generator side is needed for the warm-up phase.
net_G = gen_load["net_G"]
scaler_G = gen_load["scaler_G"]
opt_G = gen_load["opt_G"]
last_epoch = gen_load["epoch"]

# L1 loss between the generator's prediction and the target.
criterion = nn.L1Loss()
criterion.to(device)
# Continue numbering after the last completed epoch (1 when starting fresh).
start_epoch = last_epoch + 1
epochs = 2

pretrain_generator(train_dl, valid_dl, net_G, scaler_G, opt_G, criterion, start_epoch, epochs)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))


model initialized with norm initialization
model initialized with norm initialization
[*] Finding checkpoint -1 in /content/drive/MyDrive/Project-Colorisation
[!] No checkpoint for epoch -1


HBox(children=(FloatProgress(value=0.0, description='Epoch: [1/2]', max=109.0, style=ProgressStyle(description…

Iteration:[1/109] L1 Loss: [0.09907]


RuntimeError: ignored

In [None]:
# train Gan model

# Use a higher generator learning rate for the adversarial phase
# (this mutates the shared config dict: 1e-4 -> 2e-4).
config["lr_G"] = 2e-4

# load previous epoch optimizers and grad scaler
config["load_previous"] = False # First run
# config["load_previous"] = True # uncomment for subsequent runs of same Type
# First GAN run: start from the warm-up checkpoint with this number.
load_ckpt = 2
gan_load = load_model(config, ckpt_number=load_ckpt, device=device, prefix="warm_")

# Subsequent GAN runs: resume from a GAN checkpoint instead.
# load_ckpt = 1
# gan_load = load_model(config, ckpt_number=load_ckpt, device=device, prefix="gan_")

# pretrained Generator 
net_G = gan_load["net_G"]
scaler_G = gan_load["scaler_G"]
opt_G = gan_load["opt_G"]

# Discriminator: its weights are restored whenever the checkpoint contains
# "net_D" (warm-up checkpoints do not, so it is then newly initialised); its
# optimizer and grad scaler are restored only when load_previous is True.
net_D = gan_load["net_D"]
scaler_D = gan_load["scaler_D"]
opt_D = gan_load["opt_D"]

last_epoch = gan_load["epoch"]

# When not resuming, restart epoch numbering at 1 regardless of the epoch
# stored in the loaded checkpoint.
if not config["load_previous"]:
    last_epoch = 0

start_epoch = last_epoch + 1
epochs = 2

# NOTE(review): this rebinds `model`, which the setup cell imported as a module
# (`import model`); consider a distinct name such as `gan_model`.
model = MainModel(net_G=net_G,
                  net_D=net_D,
                  scaler_G=scaler_G,
                  scaler_D=scaler_D,
                  opt_G=opt_G,
                  opt_D=opt_D,
                  device=device)

train_model(model, train_dl, valid_dl, start_epoch, epochs)


model initialized with norm initialization
model initialized with norm initialization
[*] Finding checkpoint 2 in /content/drive/MyDrive/Project-Colorisation
Generator weight loaded
Optimizer's state loaded
Grad Scaler - Generator loaded


HBox(children=(FloatProgress(value=0.0, description='Epoch: [3/4]', max=1.0, style=ProgressStyle(description_w…

Iteration:[1/1] loss_D: 0.97997 loss_G_GAN: 1.31731 loss_G_L1: 8.91221 loss_G: 10.22951 



HBox(children=(FloatProgress(value=0.0, description='Epoch: [4/4]', max=1.0, style=ProgressStyle(description_w…

Iteration:[1/1] loss_D: 0.98914 loss_G_GAN: 2.06088 loss_G_L1: 8.74411 loss_G: 10.80499 



In [None]:
# !rm -rf /content/drive/MyDrive/Project-Colorisation/*tar
# !rm -rf *tar
# !rm -rf /content/gan_train
# !rm -rf /content/warmup_train