In [None]:
# Optional: uncomment the two lines below to reinstall fastai in the runtime.
# !pip uninstall fastai -y
# !pip install fastai

# Echo the installed fastai version (the bare last expression is the cell output).
import fastai
fastai.__version__

'2.2.5'

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# @title Download dataset
import os
import shutil
import time
from google_drive_downloader import GoogleDriveDownloader as gdd

# Download and unzip the validation archive unless the zip is already present.
# NOTE(review): the existence check uses an absolute /content path while the
# download target is relative to the working directory; the two only agree
# when the notebook runs from /content — confirm.
if not os.path.exists("/content/images/valid.zip"):
    # https://drive.google.com/file/d/1bXD5hR5WlIB6LsqDHmNn29EMA43Z82f3/view?usp=sharing
    gdd.download_file_from_google_drive(file_id='1bXD5hR5WlIB6LsqDHmNn29EMA43Z82f3',
                                        dest_path='./images/valid.zip',
                                        unzip=True,
                                        showsize=True,
                                        )
    # !rm -rf /content/images/valid.zip

# download scripts
# Caution: this removes every .py file in the working directory before the
# helper scripts are fetched again.
!rm -rf *.py
time.sleep(2)

# Fetch the helper modules (trainer, utils, model, data_loader) that the
# setup cell imports.
!wget -q https://raw.githubusercontent.com/veb-101/GAN-Colorisation/master/trainer.py -O trainer.py
!wget -q https://raw.githubusercontent.com/veb-101/GAN-Colorisation/master/utils.py -O utils.py
!wget -q https://raw.githubusercontent.com/veb-101/GAN-Colorisation/master/model.py -O model.py
!wget -q https://raw.githubusercontent.com/veb-101/GAN-Colorisation/master/data_loader.py -O data_loader.py



# Download and unzip the training archive unless the zip is already present
# (same absolute-vs-relative path caveat as the validation archive).
if not os.path.exists("/content/images/train.zip"):
    # https://drive.google.com/file/d/1c5WQwglbVL9S_LHH5E9XU9bCWLzZg_V7/view?usp=sharing
    gdd.download_file_from_google_drive(file_id='1c5WQwglbVL9S_LHH5E9XU9bCWLzZg_V7',
                                        dest_path='./images/train.zip',
                                        unzip=True,
                                        showsize=True,
                                        )
    # !rm -rf /content/images/train.zip

In [None]:
#@title setup
import os
import gc
import random
import warnings
import importlib
import numpy as np
# from tqdm. import tqdm
from tqdm.notebook import tqdm



import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

import model
import utils
import trainer
import data_loader

# Seed Python's, NumPy's and PyTorch's random number generators with one value.
seed = 41
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# Reload the helper modules so the names imported below come from freshly
# executed module code.
for _module in (data_loader, utils, trainer, model):
    importlib.reload(_module)

from trainer import MainModel
from utils import init_model
from data_loader import make_dataloaders
from utils import update_losses, visualize
from utils import init_model, AverageMeter
from utils import create_loss_meters, log_results
from model import Generator_Res_Unet, Discriminator



def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)


# Device used by the rest of the notebook.
device = get_default_device()

# Suppress UserWarning messages from here on.
warnings.simplefilter("ignore", UserWarning)

In [None]:
def train_model(model, train_dl, valid_dl, start_epoch, epochs):
    """Run GAN training, writing validation previews and a checkpoint after every epoch.

    Args:
        model: Object exposing ``setup_input(batch)`` and ``optimize()`` plus the
            ``net_G``/``opt_G``/``scaler_G``/``net_D``/``opt_D``/``scaler_D``
            attributes whose ``state_dict()`` is stored in the checkpoint.
        train_dl: Sized iterable of training batches; each batch is indexed
            with ``"L"`` to obtain the batch size for the loss meters.
        valid_dl: Iterable of validation batches handed to ``visualize``.
        start_epoch: Number of the first epoch to run.
        epochs: How many epochs to run.

    Side effects:
        Saves preview images under ``/content/gan_train`` and writes
        ``gan_checkpoint_<epoch>.tar`` into the current working directory.
    """
    savedir = "/content/gan_train"
    os.makedirs(savedir, exist_ok=True)

    final_epoch = start_epoch + epochs - 1
    for e in range(start_epoch, start_epoch + epochs):
        loss_meter_dict = create_loss_meters()
        length_train = len(train_dl)
        # Losses are printed on the first, middle and last iteration of the epoch.
        log_steps = {0, int(length_train / 2), length_train - 1}

        # tqdm must wrap an iterable: wrapping the integer batch count raised
        # "TypeError: 'int' object is not iterable" on the first epoch.
        for i, data in enumerate(tqdm(train_dl, desc=f"Epoch: [{e}/{final_epoch}]")):
            model.setup_input(data)
            model.optimize()

            # Fold this batch's losses into the running meters.
            update_losses(model, loss_meter_dict, count=data["L"].size(0))

            if i in log_steps:
                print(f"Iteration:[{i}/{length_train}]")
                log_results(loss_meter_dict)

        # Render the model output for every validation batch of this epoch.
        for idx, valid_data in enumerate(valid_dl):
            visualize(model, valid_data, save_name=os.path.join(savedir, f"Validation_Epoch_{e}-{idx}.png"))

        torch.save({
            "epoch": e,
            "net_G": model.net_G.state_dict(),
            "opt_G": model.opt_G.state_dict(),
            "scaler_G": model.scaler_G.state_dict(),
            "net_D": model.net_D.state_dict(),
            "opt_D": model.opt_D.state_dict(),
            "scaler_D": model.scaler_D.state_dict(),
        }, f"gan_checkpoint_{e}.tar")

In [None]:
def _load_optional_state(target, checkpoint, key, loaded_msg):
    """Best-effort ``target.load_state_dict(checkpoint[key])``; report a failure instead of raising."""
    try:
        target.load_state_dict(checkpoint[key])
    except Exception as err:
        # Some checkpoints lack these entries (the warm-up checkpoint stores only
        # generator state), so a failure here is reported and training continues.
        print(f"[!] Could not load '{key}': {err!r}")
        return False
    print(loaded_msg)
    return True


def load_model(config, ckpt_number=-1, device=None, prefix="warm_"):
    """Build the generator/discriminator with optimizers and grad scalers, then restore a checkpoint.

    Args:
        config: Mapping with ``lr_G``, ``lr_D``, ``beta1``, ``beta2`` (Adam
            settings) and ``load_previous`` (when truthy, optimizer and grad
            scaler state are restored in addition to network weights).
        ckpt_number: Number used in the checkpoint file name
            ``<prefix>checkpoint_<ckpt_number>.tar``.
        device: Device handed to ``init_model`` and used as ``map_location``
            when reading the checkpoint; ``None`` keeps ``torch.load``'s default.
        prefix: File-name prefix of the checkpoint.

    Returns:
        dict with keys ``epoch``, ``net_G``, ``opt_G``, ``scaler_G``, ``net_D``,
        ``scaler_D`` and ``opt_D``. ``epoch`` is 0 when no checkpoint file exists.

    Raises:
        Errors from loading the generator weights (and, with ``load_previous``,
        the generator optimizer state) propagate; the other states are optional.
    """
    lr_G = config["lr_G"]
    lr_D = config["lr_D"]
    beta1 = config["beta1"]
    beta2 = config["beta2"]
    load_previous = config["load_previous"]

    epoch = 0
    gen_model = Generator_Res_Unet()
    net_G = init_model(gen_model.get_model(), device)

    net_D = init_model(Discriminator(input_channels=3), device)
    scaler_G = GradScaler()
    scaler_D = GradScaler()

    opt_G = optim.Adam(net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
    opt_D = optim.Adam(net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    # Directory searched for checkpoint files.
    drive_path = r"/content/"
    print(f"[*] Finding checkpoint {ckpt_number} in {drive_path}")

    checkpoint_file = f"{prefix}checkpoint_{ckpt_number}.tar"
    checkpoint_path = os.path.join(drive_path, checkpoint_file)

    if not os.path.exists(checkpoint_path):
        print(f"[!] No checkpoint for epoch {ckpt_number}")
    else:
        # map_location lets a checkpoint saved on GPU be opened on a CPU-only runtime.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        epoch = checkpoint["epoch"]

        net_G.load_state_dict(checkpoint["net_G"])
        print("Generator weight loaded")

        _load_optional_state(net_D, checkpoint, "net_D", "Discriminator weights loaded.")

        if load_previous:
            _load_optional_state(opt_D, checkpoint, "opt_D", "Discriminator optimizer loaded.")

            opt_G.load_state_dict(checkpoint["opt_G"])
            print("Optimizer's state loaded")

            _load_optional_state(scaler_G, checkpoint, "scaler_G", "Grad Scaler - Generator loaded")
            _load_optional_state(scaler_D, checkpoint, "scaler_D", "Grad Scaler - Discriminator loaded")

    return {
        "epoch": epoch,
        "net_G": net_G,
        "opt_G": opt_G,
        "scaler_G": scaler_G,
        "net_D": net_D,
        "scaler_D": scaler_D,
        "opt_D": opt_D,
    }

In [None]:
def pretrain_generator(train_dl, valid_dl, net, scaler, opt, criterion, start_epoch, epochs):
    """Warm up the generator alone with mixed-precision training against ``criterion``.

    Args:
        train_dl: Sized iterable of training batches; each batch provides
            ``"L"`` (network input) and ``"ab"`` (target) tensors.
        valid_dl: Iterable of validation batches handed to ``visualize``.
        net: Generator network being trained.
        scaler: Gradient scaler used for the scaled backward pass and step.
        opt: Optimizer for ``net``.
        criterion: Loss called as ``criterion(preds, ab)``.
        start_epoch: Number of the first epoch to run.
        epochs: How many epochs to run.

    Side effects:
        Saves preview images under ``/content/warmup`` and writes
        ``warm_checkpoint_<epoch>.tar`` into the current working directory.
        Batches are moved to the notebook-level ``device``.
    """
    savedir = "/content/warmup"
    os.makedirs(savedir, exist_ok=True)

    final_epoch = start_epoch + epochs - 1
    for e in range(start_epoch, start_epoch + epochs):
        loss_meter = AverageMeter()
        length_train = len(train_dl)
        # The running loss is printed on the first, middle and last iteration.
        log_steps = {0, int(length_train / 2), length_train - 1}

        # tqdm must wrap an iterable: wrapping the integer batch count raised
        # "TypeError: 'int' object is not iterable" on the first epoch.
        for i, data in enumerate(tqdm(train_dl, desc=f"Epoch: [{e}/{final_epoch}]")):
            opt.zero_grad()
            L, ab = data['L'].to(device), data['ab'].to(device)

            with autocast():
                preds = net(L)
                loss = criterion(preds, ab)

            scaler.scale(loss).backward()
            scaler.step(opt)
            # update() adjusts the loss scale for the next iteration; without it
            # the scale never changes after an overflowed (skipped) step.
            scaler.update()

            loss_meter.update(loss.detach().item(), L.size(0))
            torch.cuda.empty_cache()
            gc.collect()

            if i in log_steps:
                print(f"L1 Loss: [{loss_meter.avg:.5f}]")

        # Render the generator output for every validation batch of this epoch.
        for idx, valid_data in enumerate(valid_dl):
            visualize(net, valid_data, save_name=os.path.join(savedir, f"Validation_Epoch_{e}-{idx}.png"))

        torch.save({
            "epoch": e,
            "net_G": net.state_dict(),
            "opt_G": opt.state_dict(),
            "scaler_G": scaler.state_dict(),
        }, f"warm_checkpoint_{e}.tar")

In [None]:
# Directories passed to make_dataloaders for the training and validation splits.
train_ = "/content/images/train"
valid_ = "/content/images/valid"

device = get_default_device()

# NOTE(review): num_images=2 presumably limits each split to two images (the
# recorded output shows one batch per split) — confirm against
# data_loader.make_dataloaders before a full training run.
train_dl = make_dataloaders(path=train_, num_images=2)
valid_dl = make_dataloaders(path=valid_, num_images=2, is_training=False, shuffle=False)
    
print(f"Number of batches ::Train:: {len(train_dl)}, ::Valid:: {len(valid_dl)}")

Number of batches ::Train:: 1, ::Valid:: 1


In [None]:
# Settings read by load_model.
config = {
    "lr_G": 1e-4,           # Adam learning rate for the generator optimizer
    "lr_D": 2e-4,           # Adam learning rate for the discriminator optimizer
    "beta1": 0.5,           # first Adam beta, shared by both optimizers
    "beta2": 0.999,         # second Adam beta, shared by both optimizers
    "lambda_L1": 100.0,     # NOTE(review): not read by any code in this notebook — confirm where it is used
    "load_previous": True   # when True, load_model also restores optimizer and grad-scaler state
}

In [None]:
#@title Warmup generator

# With load_previous False, load_model restores network weights only and
# leaves the optimizer and grad scaler in their initial state.
config["load_previous"] = False # First run
# config["load_previous"] = True # uncomment for subsequent runs of same Type

# Number of the "warm_checkpoint_<n>.tar" file to resume from. When that file
# does not exist, load_model prints a notice and returns epoch 0 with no
# weights loaded.
load_ckpt = -1
gen_load = load_model(config, ckpt_number=load_ckpt, device=device, prefix="warm_")

net_G = gen_load["net_G"]
scaler_G = gen_load["scaler_G"]
opt_G = gen_load["opt_G"]
last_epoch = gen_load["epoch"]

# L1 loss between the generator prediction and the target batch.
criterion = nn.L1Loss()     
criterion.to(device)
# Continue numbering from the epoch stored in the checkpoint (0 when none was loaded).
start_epoch = last_epoch + 1
epochs = 30

pretrain_generator(train_dl, valid_dl, net_G, scaler_G, opt_G, criterion, start_epoch, epochs)

In [None]:
# Train the GAN model

config["lr_G"] = 2e-4

config["load_previous"] = False # First run
# config["load_previous"] = True # uncomment for subsequent runs of same Type

load_ckpt = -1
gan_load = load_model(config, ckpt_number=load_ckpt, device=device, prefix="warm_")

# Generator: weights come from the warm-up checkpoint when that file is found.
net_G = gan_load["net_G"]
scaler_G = gan_load["scaler_G"]
opt_G = gan_load["opt_G"]

# Discriminator: weights are restored whenever the checkpoint contains them;
# its optimizer and grad scaler are restored only when load_previous is True.
net_D = gan_load["net_D"]
scaler_D = gan_load["scaler_D"]
opt_D = gan_load["opt_D"]

# Adversarial training numbers its epochs from 1, independent of the epoch
# stored in the warm-up checkpoint.
last_epoch = 0
start_epoch = last_epoch + 1
epochs = 100

# NOTE: this rebinds `model`, which the setup cell imported as a module;
# re-running that cell re-imports the module before reloading it.
model = MainModel(net_G=net_G,
                  net_D=net_D,
                  scaler_G=scaler_G,
                  scaler_D=scaler_D,
                  opt_G=opt_G,
                  opt_D=opt_D,
                  device=device)

# train_model expects the MainModel instance as its first argument; passing the
# config dict failed because a dict has no setup_input/optimize methods.
train_model(model, train_dl, valid_dl, start_epoch, epochs)