In [1]:
# Show which GPU this runtime was assigned, plus driver/CUDA versions and memory use.
!nvidia-smi

Sat Feb 13 11:58:02 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.39       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   55C    P0    41W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip uninstall fastai -y
!pip install fastai -qU

import fastai
print(fastai.__version__)

Uninstalling fastai-2.2.5:
  Successfully uninstalled fastai-2.2.5
2.2.5


In [2]:
from google.colab import drive
# Mount Google Drive at /content/drive; the training cells copy their
# checkpoints under /content/drive/MyDrive/Project-Colorisation.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# @title old dataset
# import os
# import shutil
# import time
# from google_drive_downloader import GoogleDriveDownloader as gdd

# if not os.path.exists("/content/images/valid.zip"):
#     # https://drive.google.com/file/d/1bXD5hR5WlIB6LsqDHmNn29EMA43Z82f3/view?usp=sharing
#     gdd.download_file_from_google_drive(file_id='1bXD5hR5WlIB6LsqDHmNn29EMA43Z82f3',
#                                         dest_path='./images/valid.zip',
#                                         unzip=True,
#                                         showsize=True,
#                                         )
#     # !rm -rf /content/images/valid.zip

# # download scripts

# !rm -rf *.py
# !rm -rf *.zip
# time.sleep(1)
# !wget -q https://github.com/veb-101/GAN-Colorisation/archive/master.zip
# !unzip -qq -o /content/master.zip 
# !mv ./GAN-Colorisation-master/trainer.py ./trainer.py
# !mv ./GAN-Colorisation-master/utils.py ./utils.py
# !mv ./GAN-Colorisation-master/model.py ./model.py
# !mv ./GAN-Colorisation-master/data_loader.py ./data_loader.py
# !rm -rf /content/GAN-Colorisation-master



# if not os.path.exists("/content/images/train.zip"):
#     # https://drive.google.com/file/d/1c5WQwglbVL9S_LHH5E9XU9bCWLzZg_V7/view?usp=sharing
#     gdd.download_file_from_google_drive(file_id='1c5WQwglbVL9S_LHH5E9XU9bCWLzZg_V7',
#                                         dest_path='./images/train.zip',
#                                         unzip=True,
#                                         showsize=True,
#                                         )
#     # !rm -rf /content/images/train.zip

In [3]:
#@title Download datasets and scripts

import os
import shutil
import time
from google_drive_downloader import GoogleDriveDownloader as gdd


time.sleep(1)

if not os.path.exists("/content/GAN-Colorisation"):
    # !rm -rf *.py
    !git clone -b new_training https://github.com/veb-101/GAN-Colorisation.git

    for file in os.listdir("/content/GAN-Colorisation"):
        if not os.path.isdir(os.path.join("/content/GAN-Colorisation", file)):
            shutil.copyfile(os.path.join("/content/GAN-Colorisation", file), f"/content/{file}")

#     # !rm -rf /content/GAN-Colorisation



# https://drive.google.com/file/d/1rDE8V7FuvTtyhxJ1MyJUrMeYvvoG0gCT/view?usp=sharing
if not os.path.exists("images.zip"):
    gdd.download_file_from_google_drive(file_id='1rDE8V7FuvTtyhxJ1MyJUrMeYvvoG0gCT',
                                        dest_path='./images.zip',
                                        unzip=True,
                                        showsize=True,
                                        )

In [6]:
!pip install -qU wandb
!wandb login 984196dfc7bc6ae6093fed3667fd5da413300d29

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [4]:
#@title setup

import os
import gc
import random
import shutil
import warnings
import importlib
import numpy as np
from tqdm.notebook import tqdm
import wandb



import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

import utils
import data_loader

importlib.reload(utils)
importlib.reload(data_loader)


from utils import init_model
from data_loader import make_dataloaders
from utils import update_losses, visualize
from utils import init_model, AverageMeter
from utils import create_loss_meters, log_results

seed = 41
# Seed Python's, NumPy's and PyTorch's random number generators with the same value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)



def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)

# Resolve the compute device once; later cells pass it to load_model and the
# training functions.
device = get_default_device()

# Suppress every UserWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

## Pretrain Generator

In [5]:
def pretrain_generator(train_dl, valid_dl, net, scaler, opt, criterion, start_epoch, epochs, device=None):
    """Warm up the generator on its own with a supervised loss.

    Each epoch runs one mixed-precision pass over ``train_dl``, passes every
    batch of ``valid_dl`` through ``visualize`` to collect PSNR/SSIM, logs to
    the active wandb run, and writes ``warm_checkpoint_<epoch>.tar`` locally and
    to Google Drive. The previous epoch's checkpoint is removed afterwards.

    Args:
        train_dl: Training loader; each batch is indexed with ``'L'`` (network
            input) and ``'ab'`` (target passed to ``criterion``).
        valid_dl: Validation loader whose batches are handed to ``visualize``.
        net: Generator being trained.
        scaler: ``GradScaler`` used to scale the loss.
        opt: Optimizer updating ``net``.
        criterion: Loss called as ``criterion(preds, ab)``.
        start_epoch: Number of the first epoch to run.
        epochs: How many epochs to run.
        device: Device the batch tensors are moved to.
    """
    savedir_train = "/content/warmup_train/train"
    savedir_val = "/content/warmup_train/validation"
    drive_path = "/content/drive/MyDrive/Project-Colorisation"
    for directory in (savedir_train, savedir_val, drive_path):
        os.makedirs(directory, exist_ok=True)

    # NOTE(review): the scheduler is rebuilt on every call, so the milestones
    # count epochs of this call rather than absolute epoch numbers when
    # resuming with start_epoch > 1 -- confirm that this is intended.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, [30, 60, 90])

    length_train = len(train_dl)
    final_epoch = start_epoch + epochs - 1
    # Batch indices (first, middle, last) at which progress is printed and a
    # sample image is rendered.
    report_at = {0, int(length_train / 2), length_train - 1}

    for e in range(start_epoch, start_epoch + epochs):
        loss_meter = AverageMeter()
        psnr_val = []
        ssim_val = []
        # Bound up front so an empty train_dl cannot leave it undefined below.
        skip_lr_sched = False

        net.train()

        torch.cuda.empty_cache()
        gc.collect()
        progress = tqdm(train_dl, total=length_train, desc=f"Epoch: [{e}/{final_epoch}]")
        for i, data in enumerate(progress):
            opt.zero_grad()
            L, ab = data['L'].to(device), data['ab'].to(device)

            with autocast():
                preds = net(L)
                loss = criterion(preds, ab)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scale_before = scaler.get_scale()
            scaler.update()
            # GradScaler lowers its scale only when the optimizer step was
            # skipped because of inf/NaN gradients; it raises the scale after a
            # run of successful steps. Comparing with `!=` therefore also
            # treated a scale growth as a skipped step.
            skip_lr_sched = scale_before > scaler.get_scale()

            loss_meter.update(loss.detach().item(), L.size(0))
            # Kept from the original: frees cached memory after every step at
            # the cost of extra time per iteration.
            torch.cuda.empty_cache()
            gc.collect()

            if i in report_at:
                print(f"Iteration:[{i+1}/{length_train}] L1 Loss: [{loss_meter.avg:.5f}]")
                visualize(net,
                          data,
                          save_name=os.path.join(savedir_train, f"Train_Epoch_{e}-{i}.png"),
                          device=device)

        training_epoch_loss = round(loss_meter.avg, 4)
        wandb.log({
            "epoch": e,
            "train_loss": training_epoch_loss,
        })

        # NOTE(review): `net` is still in train mode here; confirm that
        # `visualize` switches to eval / no-grad internally.
        for idx, valid_data in enumerate(valid_dl):
            psnr_batch, ssim_batch_ = visualize(net,
                                                valid_data,
                                                save_name=os.path.join(savedir_val, f"Validation_Epoch_{e}-{idx}.png"),
                                                device=device)

            psnr_val.append(psnr_batch)
            ssim_val.append(ssim_batch_)

        if psnr_val:
            mean_psnr = sum(psnr_val) / len(psnr_val)
            mean_ssim = sum(ssim_val) / len(ssim_val)
            print(f"Validation PSNR: {mean_psnr:.5}, SSIM: {mean_ssim:.5}")

            wandb.log({
                "epoch": e,
                "valid_PSNR": round(mean_psnr, 5),
                "valid_SSIM": round(mean_ssim, 5),
            })
            wandb.log({
                "validation_images": wandb.Image(
                    os.path.join(savedir_val, f"Validation_Epoch_{e}-0.png")),
            })
        else:
            # Avoid a ZeroDivisionError when the validation loader is empty.
            print("[!] valid_dl yielded no batches; validation metrics skipped")

        if not skip_lr_sched:
            lr_scheduler.step()

        file_name = f"warm_checkpoint_{e}.tar"
        torch.save({
            "epoch": e,
            "net_G": net.state_dict(),
            "opt_G": opt.state_dict(),
            "scaler_G": scaler.state_dict(),
        }, file_name)

        shutil.copyfile(file_name, os.path.join(drive_path, file_name))

        # Remove the previous epoch's checkpoint from both locations. Each
        # removal is attempted independently: a missing local file (e.g. after
        # a runtime reset) must not stop the Drive copy from being cleaned up.
        stale_paths = (
            f"/content/warm_checkpoint_{e-1}.tar",
            f"/content/drive/MyDrive/Project-Colorisation/warm_checkpoint_{e-1}.tar",
        )
        for stale_path in stale_paths:
            try:
                os.remove(stale_path)
            except FileNotFoundError:
                pass  # nothing to clean up
            except OSError as err:
                print(f"[!] Could not remove {stale_path}: {err}")

## Train GAN

In [6]:
def train_model(model, train_dl, valid_dl, start_epoch, epochs):
    """Run GAN training and checkpoint generator and discriminator every epoch.

    Each epoch feeds every batch of ``train_dl`` to ``model.setup_input`` and
    ``model.optimize``, accumulates the losses via ``update_losses``, passes
    every batch of ``valid_dl`` through ``visualize`` to collect PSNR/SSIM, and
    writes ``gan_checkpoint_<epoch>.tar`` locally and to Google Drive. The
    previous epoch's checkpoint is removed afterwards. The module-level
    ``device`` is passed to ``visualize``.

    Args:
        model: Object providing ``setup_input(batch)``, ``optimize()`` and the
            attributes ``net_G``, ``net_D``, ``opt_G``, ``opt_D``, ``scaler_G``
            and ``scaler_D`` that are saved in the checkpoint.
        train_dl: Training loader; each batch is indexed with ``"L"``.
        valid_dl: Validation loader whose batches are handed to ``visualize``.
        start_epoch: Number of the first epoch to run.
        epochs: How many epochs to run.
    """
    savedir_train = "/content/gan_train/train"
    savedir_val = "/content/gan_train/validation"
    drive_path = "/content/drive/MyDrive/Project-Colorisation"
    for directory in (savedir_train, savedir_val, drive_path):
        os.makedirs(directory, exist_ok=True)

    # NOTE(review): neither scheduler is ever stepped, so both learning rates
    # stay constant. The flags returned by model.optimize() look intended to
    # gate that step -- wire it up once the monitored metric is decided.
    plateau_args = dict(mode="min", factor=0.1, patience=4, threshold=1e-3, cooldown=2)
    scheduler_G = torch.optim.lr_scheduler.ReduceLROnPlateau(model.opt_G, **plateau_args)
    scheduler_D = torch.optim.lr_scheduler.ReduceLROnPlateau(model.opt_D, **plateau_args)

    length_train = len(train_dl)
    final_epoch = start_epoch + epochs - 1
    # Batch indices (first, middle, last) at which losses are printed and a
    # sample image is rendered.
    report_at = {0, int(length_train / 2), length_train - 1}

    for e in range(start_epoch, start_epoch + epochs):
        loss_meter_dict = create_loss_meters()
        psnr_val = []
        ssim_val = []

        progress = tqdm(train_dl, total=length_train, desc=f"Epoch: [{e}/{final_epoch}]")
        for i, data in enumerate(progress):
            model.setup_input(data)
            update_LR_G, update_LR_D = model.optimize()

            # function updating the log objects
            update_losses(model, loss_meter_dict, count=data["L"].size(0))

            if i in report_at:
                print(f"Iteration:[{i+1}/{length_train}]", end=" ")
                # function to print out the losses
                log_results(loss_meter_dict)
                visualize(model,
                          data,
                          save_name=os.path.join(savedir_train, f"Train_Epoch_{e}-{i}.png"),
                          device=device)

        for idx, valid_data in enumerate(valid_dl):
            psnr_batch, ssim_batch_ = visualize(model,
                                                valid_data,
                                                save_name=os.path.join(savedir_val, f"Validation_Epoch_{e}-{idx}.png"),
                                                device=device)

            psnr_val.append(psnr_batch)
            ssim_val.append(ssim_batch_)

        if psnr_val:
            print(f"Validation PSNR: {sum(psnr_val)/len(psnr_val):.5}, SSIM: {sum(ssim_val)/len(ssim_val):.5}")
        else:
            # Avoid a ZeroDivisionError when the validation loader is empty.
            print("[!] valid_dl yielded no batches; validation metrics skipped")

        file_name = f"gan_checkpoint_{e}.tar"
        torch.save({
            "epoch": e,
            "net_G": model.net_G.state_dict(),
            "opt_G": model.opt_G.state_dict(),
            "scaler_G": model.scaler_G.state_dict(),
            "net_D": model.net_D.state_dict(),
            "opt_D": model.opt_D.state_dict(),
            "scaler_D": model.scaler_D.state_dict(),
        }, file_name)

        # Mirror the checkpoint to Drive, as the warm-up loop does; load_model
        # only looks for checkpoints in drive_path.
        shutil.copyfile(file_name, os.path.join(drive_path, file_name))

        # Remove the previous epoch's checkpoint from both locations. Each
        # removal is attempted independently so a missing local file does not
        # stop the Drive copy from being cleaned up.
        stale_paths = (
            f"/content/gan_checkpoint_{e-1}.tar",
            f"/content/drive/MyDrive/Project-Colorisation/gan_checkpoint_{e-1}.tar",
        )
        for stale_path in stale_paths:
            try:
                os.remove(stale_path)
            except FileNotFoundError:
                pass  # nothing to clean up
            except OSError as err:
                print(f"[!] Could not remove {stale_path}: {err}")

## Training Setup

In [9]:
#@title Dataloaders 
num_images_train =   9000#@param {type:"integer"}
num_images_valid =  -1 #@param {type:"integer"}
batch_size =  32#@param {type:"integer"}
image_size =  128#@param {type:"integer"}


if __name__ == "__main__":
    train_ = "/content/train"
    valid_ = "/content/valid"
    device = get_default_device()

    # Arguments that both loaders share, taken from the form fields above.
    shared_loader_args = {"size": image_size, "batch_size": batch_size}

    # Training loader: shuffled. `num_images` presumably limits how many
    # images are used (-1 presumably meaning all) -- confirm in data_loader.py.
    train_dl = make_dataloaders(
        path=train_,
        num_images=num_images_train,
        is_training=True,
        shuffle=True,
        **shared_loader_args
    )

    # Validation loader: fixed order.
    valid_dl = make_dataloaders(
        path=valid_,
        num_images=num_images_valid,
        is_training=False,
        shuffle=False,
        **shared_loader_args
    )

    print(f"Number of batches ::Train:: {len(train_dl)}, ::Valid:: {len(valid_dl)}")

Number of batches ::Train:: 282, ::Valid:: 4


In [10]:
# Hyper-parameters read by load_model and handed to MainModel. Later cells
# override individual entries ("lr_G", "lr_D", "load_previous") before use.
config = dict(
    lr_G=1e-4,
    lr_D=4e-4,
    beta1=0.5,
    beta2=0.999,
    perceptual_loss_factor=1.0,
    adversarial_loss_factor=5e-3,
    lambda_L1=100.0,
    load_previous=True,
)

In [11]:
import wandb
# Start the wandb run that pretrain_generator logs to through wandb.log;
# the handle is kept so the run can be closed with run.finish().
run = wandb.init(project="image-colorization")

[34m[1mwandb[0m: Currently logged in as: [33mveb-101[0m (use `wandb login --relogin` to force relogin)


In [28]:
#@title reload
import importlib
import trainer
import model
import utils
import os


# Reload the project modules so edits to them are picked up without
# restarting the kernel.
# NOTE(review): `model` is presumably reloaded first because trainer imports
# from it -- confirm before changing the order.
importlib.reload(model)
importlib.reload(trainer)

from trainer import MainModel
from model import Generator_Res_Unet, Discriminator, Generator_Unet



In [29]:
#@title Load Model
model_type = "Generator_Unet" #@param ["Generator_Res_Unet-pretrained", "Generator_Res_Unet-new", "Generator_Unet"]

# (opt code, pretrained flag) for the two Res-UNet choices of the form field.
_res_unet_variants = {
    "Generator_Res_Unet-pretrained": (1, True),
    "Generator_Res_Unet-new": (2, False),
}

# load_model reads the globals assigned here: `opt` picks the branch,
# `model` / `pretrained` feed the Res-UNet branch and `model_3` the UNet one.
# NOTE(review): assigning `model` rebinds the name of the imported `model`
# module for the rest of the session.
if model_type in _res_unet_variants:
    opt, pretrained = _res_unet_variants[model_type]
    model = Generator_Res_Unet()
elif model_type == "Generator_Unet":
    opt = 3
    model_3 = Generator_Unet()

def load_model(config, ckpt_number=-1, device=None, prefix="warm_"):
    """Create generator, discriminator, optimizers and grad scalers, restoring a checkpoint if one exists.

    The generator is chosen through the module-level names set by the model
    selection code (``opt``, ``model``, ``pretrained``, ``model_3``). The
    checkpoint ``<prefix>checkpoint_<ckpt_number>.tar`` is looked up in the
    Google Drive project folder; when it is missing, freshly initialised
    objects are returned with ``epoch`` 0.

    Args:
        config: Dict providing ``lr_G``, ``lr_D``, ``beta1``, ``beta2`` and
            ``load_previous``. When ``load_previous`` is true, optimizer and
            grad-scaler state is restored in addition to network weights.
        ckpt_number: Epoch number embedded in the checkpoint file name.
        device: Device passed to ``init_model`` and used as ``map_location``
            when reading the checkpoint.
        prefix: File-name prefix of the checkpoint (``"warm_"`` or ``"gan_"``).

    Returns:
        Dict with keys ``epoch``, ``net_G``, ``opt_G``, ``scaler_G``,
        ``net_D``, ``scaler_D`` and ``opt_D``.
    """

    def _restore_optional(target, checkpoint, key, loaded_message):
        """Load checkpoint[key] into target when present; report, rather than hide, failures."""
        if key not in checkpoint:
            # Warm-up checkpoints only store the generator-side entries.
            print(f"[!] '{key}' not in checkpoint; keeping freshly initialised state")
            return
        try:
            target.load_state_dict(checkpoint[key])
        except Exception as err:
            # Best effort, as before, but the reason is no longer swallowed.
            print(f"[!] Could not restore '{key}': {err}")
        else:
            print(loaded_message)

    lr_G = config["lr_G"]
    lr_D = config["lr_D"]
    beta1 = config["beta1"]
    beta2 = config["beta2"]
    load_previous = config["load_previous"]

    epoch = 0
    # `opt`, `model`, `pretrained` and `model_3` are module-level globals.
    if opt in (1, 2,):
        net_G = init_model(model.get_model(pretrained=pretrained), device)
    else:
        net_G = init_model(model_3, device)

    net_D = init_model(Discriminator(input_channels=3), device)
    scaler_G = GradScaler()
    scaler_D = GradScaler()

    opt_G = optim.Adam(net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
    opt_D = optim.Adam(net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    drive_path = r"/content/drive/MyDrive/Project-Colorisation"
    print(f"[*] Finding checkpoint {ckpt_number} in {drive_path}")

    checkpoint_file = f"{prefix}checkpoint_{ckpt_number}.tar"
    checkpoint_path = os.path.join(drive_path, checkpoint_file)

    if not os.path.exists(checkpoint_path):
        print(f"[!] No checkpoint for epoch {ckpt_number}")
    else:
        # map_location lets a checkpoint written on a GPU runtime be read on a
        # CPU-only runtime (and vice versa).
        checkpoint = torch.load(checkpoint_path, map_location=device)

        epoch = checkpoint["epoch"]

        # The generator weights are mandatory: failures here propagate.
        net_G.load_state_dict(checkpoint["net_G"])
        print("Generator weight loaded")

        _restore_optional(net_D, checkpoint, "net_D", "Discriminator weights loaded.")

        if load_previous:
            _restore_optional(opt_D, checkpoint, "opt_D", "Discriminator optimizer loaded.")

            opt_G.load_state_dict(checkpoint["opt_G"])
            print("Optimizer's state loaded")

            # The two scalers are restored independently of each other.
            _restore_optional(scaler_G, checkpoint, "scaler_G", "Grad Scaler - Generator loaded")
            _restore_optional(scaler_D, checkpoint, "scaler_D", "Grad Scaler - Discriminator loaded")

    return_ = {
        "epoch": epoch,
        "net_G": net_G,
        "opt_G": opt_G,
        "scaler_G": scaler_G,
        "net_D": net_D,
        "scaler_D": scaler_D,
        "opt_D": opt_D
    }
    return return_

## Warmup generator

In [None]:
# Controls whether load_model also restores optimizer and grad-scaler state.
config["load_previous"] = False # First run
# config["load_previous"] = True # uncomment for subsequent runs of same Type

load_ckpt = 0
gen_load = load_model(config, ckpt_number=load_ckpt, device=device, prefix="warm_")

net_G, scaler_G, opt_G, last_epoch = (
    gen_load[key] for key in ("net_G", "scaler_G", "opt_G", "epoch")
)

# L1 loss for the warm-up; Module.to returns the module itself.
criterion = nn.L1Loss().to(device)
start_epoch = last_epoch + 1
epochs = 110

# from torchsummary import summary
# summary(net_G.to("cuda"), (1, 128, 128))

pretrain_generator(
    train_dl, valid_dl, net_G, scaler_G, opt_G, criterion,
    start_epoch=start_epoch, epochs=epochs, device=device,
)

model initialized with norm initialization
model initialized with norm initialization
[*] Finding checkpoint 0 in /content/drive/MyDrive/Project-Colorisation
[!] No checkpoint for epoch 0


HBox(children=(FloatProgress(value=0.0, description='Epoch: [1/110]', max=282.0, style=ProgressStyle(descripti…

Iteration:[1/282] L1 Loss: [0.08666]
Iteration:[142/282] L1 Loss: [0.08788]
Iteration:[282/282] L1 Loss: [0.08693]

Validation PSNR: 20.415, SSIM: 0.9095


HBox(children=(FloatProgress(value=0.0, description='Epoch: [2/110]', max=282.0, style=ProgressStyle(descripti…

Iteration:[1/282] L1 Loss: [0.10311]
Iteration:[142/282] L1 Loss: [0.08453]
Iteration:[282/282] L1 Loss: [0.08407]

Validation PSNR: 20.779, SSIM: 0.91125


HBox(children=(FloatProgress(value=0.0, description='Epoch: [3/110]', max=282.0, style=ProgressStyle(descripti…

Iteration:[1/282] L1 Loss: [0.09160]


In [17]:
# Close the wandb run created by wandb.init; the output below is its final
# upload progress and run summary.
run.finish()

VBox(children=(Label(value=' 7.40MB of 7.40MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,5.0
train_loss,0.0846
_runtime,137.0
_timestamp,1613217088.0
_step,14.0
valid_PSNR,20.05325
valid_SSIM,0.9085


0,1
epoch,▁▁▃▃▅▅▆▆██
train_loss,███▁▁
_runtime,▁▁▁▃▃▃▄▅▅▆▆▆███
_timestamp,▁▁▁▃▃▃▄▅▅▆▆▆███
_step,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
valid_PSNR,▁▂▃▅█
valid_SSIM,▁▁▁██


## Train GAN

In [None]:
# Train the GAN, starting from a warmed-up generator checkpoint.

config["lr_G"] = 2e-4
config["lr_D"] = 4e-4
torch.cuda.empty_cache()
gc.collect()

# load previous epoch optimizers and grad scaler
config["load_previous"] = False # First run
# config["load_previous"] = True # uncomment for subsequent runs of same Type

load_ckpt = 106
gan_load = load_model(config, ckpt_number=load_ckpt, device=device, prefix="warm_")


# pretrained Generator
net_G = gan_load["net_G"]
scaler_G = gan_load["scaler_G"]
opt_G = gan_load["opt_G"]

# Discriminator: its weights are restored whenever the checkpoint contains
# "net_D" (warm-up checkpoints do not store one); its optimizer and scaler are
# only restored when load_previous is True. Otherwise it is freshly initialised.
net_D = gan_load["net_D"]
scaler_D = gan_load["scaler_D"]
opt_D = gan_load["opt_D"]


last_epoch = gan_load["epoch"]

# A first GAN run restarts the epoch numbering instead of continuing from the
# warm-up epoch stored in the checkpoint.
if not config["load_previous"]:
    last_epoch = 0

start_epoch = last_epoch + 1
epochs = 400

# Bound to its own name on purpose: assigning to `model` would replace the
# generator object that load_model reads from the global `model` for the
# Res-UNet choices (and the imported `model` module), so a later load_model
# call in the same session would fail.
gan_model = MainModel(net_G=net_G,
                      net_D=net_D,
                      scaler_G=scaler_G,
                      scaler_D=scaler_D,
                      opt_G=opt_G,
                      opt_D=opt_D,
                      config=config,
                      device=device)

train_model(gan_model, train_dl, valid_dl, start_epoch, epochs)

## Random


Run this JavaScript function in the browser console; it clicks Colab's connect button every 5 minutes:
```
// Logs a message and clicks the page's "colab-connect-button" element.
function KeepClicking(){
console.log("Clicking");
document.querySelector("colab-connect-button").click()
}
// Repeat the click every 300000 ms (5 minutes).
setInterval(KeepClicking,300000)
```



In [None]:
# !rm -rf /content/drive/MyDrive/Project-Colorisation/*tar
# !rm -rf *tar
# !rm -rf /content/gan_train
# !rm -rf /content/warmup_train