In [None]:
# Show the GPU, driver and CUDA version attached to this runtime.
!nvidia-smi

Sun Feb  7 09:00:49 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.39       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P8     9W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install -qU wandb
!wandb login 984196dfc7bc6ae6093fed3667fd5da413300d29

In [None]:
# @title Download dataset
import os
import shutil
import time
from google_drive_downloader import GoogleDriveDownloader as gdd

if not os.path.exists("/content/images/valid.zip"):
    # https://drive.google.com/file/d/1ux-mv7250XsgZo-vsk4QIedfGbELl9oZ/view?usp=sharing
    # LR - 64x64 images
    gdd.download_file_from_google_drive(file_id='1ux-mv7250XsgZo-vsk4QIedfGbELl9oZ',
                                        dest_path='./images/valid.zip',
                                        unzip=True,
                                        showsize=True,
                                        )
    # !rm -rf /content/images/valid.zip

# download scripts
!rm -rf *.py
time.sleep(2)
!wget -q https://raw.githubusercontent.com/veb-101/Esrgan-pytorch/master/trainer.py -O trainer.py
!wget -q https://raw.githubusercontent.com/veb-101/Esrgan-pytorch/master/utils.py -O utils.py
!wget -q https://raw.githubusercontent.com/veb-101/Esrgan-pytorch/master/models.py -O models.py

# https://drive.google.com/file/d/1ak35jClzM1kAf7p4325b5P9CIudfuwWq/view?usp=sharing

# if not os.path.exists("/content/pretrained.zip"):
#     # https://drive.google.com/file/d/1ak35jClzM1kAf7p4325b5P9CIudfuwWq/view?usp=sharing
#     # HR - 256x256 images
#     gdd.download_file_from_google_drive(file_id='1ak35jClzM1kAf7p4325b5P9CIudfuwWq',
#                                         dest_path='./pretrained.zip',
#                                         unzip=True,
#                                         showsize=True,
#                                         )


if not os.path.exists("/content/images/train.zip"):
    # https://drive.google.com/file/d/1RGvBO7wVCI4aPNkjy8w-Sl3Kzp2UCqXm/view?usp=sharing
    # HR - 256x256 images
    gdd.download_file_from_google_drive(file_id='1RGvBO7wVCI4aPNkjy8w-Sl3Kzp2UCqXm',
                                        dest_path='./images/train.zip',
                                        unzip=True,
                                        showsize=True,
                                        )
    # !rm -rf /content/images/train.zip

Downloading 1ux-mv7250XsgZo-vsk4QIedfGbELl9oZ into ./images/valid.zip... 
7.9 MiB Done.
Unzipping...Done.
Downloading 1RGvBO7wVCI4aPNkjy8w-Sl3Kzp2UCqXm into ./images/train.zip... 
1.3 GiB Done.


In [None]:
#@title Setup

import trainer
import os
import random
from PIL import Image
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
import importlib

import warnings

# Silence UserWarning and RuntimeWarning for the rest of the session.
for _category in (UserWarning, RuntimeWarning):
    warnings.simplefilter("ignore", _category)

# Seed the Python, NumPy and PyTorch random number generators with one value.
seed = 41
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)


def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)


# Device used by the rest of the notebook.
device = get_default_device()


class ESR_Dataset(Dataset):
    """Paired low-/high-resolution image dataset.

    ``path`` must contain two sub-directories, ``hr`` and ``lr``, holding
    images under identical file names. Each item is a dict
    ``{"lr": tensor, "hr": tensor}``; both tensors are produced by
    ``to_tensor`` and then normalised with ``self.mean`` / ``self.std``.

    Args:
        num_images: number of files sampled (without replacement) from the
            ``hr`` directory; ``-1`` selects every file.
        path: dataset root directory.
        train: when True, random flips / a 90-degree rotation are applied
            identically to both images of a pair; when False the selected
            file names are sorted and no augmentation is applied.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if ``num_images`` exceeds the number of available files
            (raised by ``np.random.choice``).
    """

    def __init__(self, num_images=9000, path=r"images", train=True):
        self.path = path
        self.is_train = train

        if not os.path.exists(self.path):
            # FileNotFoundError subclasses Exception, so handlers written for
            # the previous bare Exception keep working.
            raise FileNotFoundError(f"[!] dataset does not exist: {self.path}")

        # os.listdir returns entries in arbitrary order; sorting first makes
        # the seeded sampling below select the same files on every run.
        self.image_paths = sorted(os.listdir(os.path.join(self.path, "hr")))

        if num_images == -1:
            num_images = len(self.image_paths)

        self.image_file_name = np.random.choice(
            self.image_paths, size=num_images, replace=False
        )

        if not train:
            self.image_file_name = sorted(self.image_file_name)

        # Per-channel normalisation constants applied to both images.
        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])

    def __getitem__(self, item):
        """Load, optionally augment, and normalise the image pair at ``item``."""
        file_name = self.image_file_name[item]
        high_resolution = Image.open(os.path.join(self.path, "hr", file_name)).convert(
            "RGB"
        )
        low_resolution = Image.open(os.path.join(self.path, "lr", file_name)).convert(
            "RGB"
        )

        if self.is_train:
            # One coin flip per augmentation, shared by both images so the
            # low- and high-resolution views stay aligned.
            if random.random() > 0.5:
                high_resolution = TF.vflip(high_resolution)
                low_resolution = TF.vflip(low_resolution)

            if random.random() > 0.5:
                high_resolution = TF.hflip(high_resolution)
                low_resolution = TF.hflip(low_resolution)

            if random.random() > 0.5:
                # NOTE(review): rotate is called without `expand`, so this
                # presumably relies on the images being square - confirm.
                high_resolution = TF.rotate(high_resolution, 90)
                low_resolution = TF.rotate(low_resolution, 90)

        high_resolution = TF.to_tensor(high_resolution)
        low_resolution = TF.to_tensor(low_resolution)

        high_resolution = TF.normalize(high_resolution, self.mean, self.std)
        low_resolution = TF.normalize(low_resolution, self.mean, self.std)

        return {"lr": low_resolution, "hr": high_resolution}

    def __len__(self):
        """Number of selected image pairs."""
        return len(self.image_file_name)


def make_dataloaders(batch_size=32, n_workers=4, shuffle=True, **kwargs):
    """Build a ``DataLoader`` over an ``ESR_Dataset``.

    ``kwargs`` are forwarded unchanged to ``ESR_Dataset``. Pinned memory is
    requested only when CUDA is available.
    """
    return DataLoader(
        ESR_Dataset(**kwargs),
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=torch.cuda.is_available(),
        shuffle=shuffle,
    )

In [None]:
# Run configuration, built with keyword arguments (every key is a valid
# identifier); key order is the same as before.
config = dict(
    image_size=256,
    batch_size=16,
    start_epoch=0,
    num_epoch=100,
    sample_batch_size=1,
    checkpoint_dir="./checkpoints",
    sample_dir="./samples",
    workers=6,
    scale_factor=4,
    num_rrdn_blocks=23,
    nf=64,
    gc=32,
    b1=0.9,
    b2=0.999,
    weight_decay=1e-2,
    # ------ PSNR ------
    p_lr=2e-4,
    p_decay_iter=[25, 50, 75],
    p_perceptual_loss_factor=0,
    p_adversarial_loss_factor=0,
    p_content_loss_factor=1,
    # ------------------
    # ------ ADVR ------
    g_lr=1e-4,
    g_decay_iter=[25, 50, 75],
    g_perceptual_loss_factor=1,
    g_adversarial_loss_factor=5e-3,
    g_content_loss_factor=1e-2,
    # ------------------
    is_psnr_oriented=True,
    load_previous_opt=True,
)

In [None]:
#@title Create Dataloader
training_images =  9000#@param {type:"integer"}
validation_images =  -1#@param {type:"integer"}

# Loader options common to the training and validation splits.
_loader_options = dict(batch_size=config["batch_size"], n_workers=config["workers"])

# Training split: shuffled, with `training_images` files drawn from ./images/train.
esr_dataloader_train = make_dataloaders(
    shuffle=True,
    num_images=training_images,
    path=r"./images/train",
    train=True,
    **_loader_options,
)

# Validation split: fixed order, `validation_images` files from ./images/valid.
esr_dataloader_val = make_dataloaders(
    shuffle=False,
    num_images=validation_images,
    path=r"./images/valid",
    train=False,
    **_loader_options,
)

In [None]:
#@title reload
import trainer
import models
import utils

# Re-import the helper modules so on-disk changes are picked up without
# restarting the kernel (same order as before: utils, models, trainer).
for _module in (utils, models, trainer):
    importlib.reload(_module)

# Make sure the sample output directory exists.
_sample_dir = config["sample_dir"]
if not os.path.exists(_sample_dir):
    os.makedirs(_sample_dir)

In [None]:
import wandb
# Start a Weights & Biases run in the "esrgan-pytorch" project.
wandb.init(project="esrgan-pytorch")

In [None]:
# Per-run overrides; start_epoch is set first because num_epoch is derived from it.
config["start_epoch"] = 1
config.update(
    num_epoch=100 - config["start_epoch"] + 1,
    num_rrdn_blocks=23,
    nf=64,
    p_decay_iter=[25, 50, 75],
    # g_decay_iter=[25, 50, 75],
    is_psnr_oriented=True,
    load_previous_opt=False,
)


# Mirror the run settings into the W&B run config
# (pairs are: W&B attribute name, key in `config`).
hyperparams = wandb.config
for _wandb_name, _config_key in (
    ("epochs", "num_epoch"),
    ("RRDN", "num_rrdn_blocks"),
    ("nf", "nf"),
    ("psnr_oriented", "is_psnr_oriented"),
    ("p_decay_iter", "p_decay_iter"),
    ("g_decay_iter", "g_decay_iter"),
):
    setattr(hyperparams, _wandb_name, config[_config_key])

* **Warmup Training - PSNR: 25.485, SSIM: 0.7215**
* Bicubic Upscale - PSNR: 25.51, SSIM: 0.7185


---


* Best Epoch: 93 -> Dis loss: 0.0093 Gen loss: 1.2519 Per loss: 1.2265 Adv loss: 0.0236 Con loss: 0.0018
  * Validation Set: PSNR: 24.6282, SSIM: 0.69

* Epoch: 111 -> Dis loss: 0.0091 Gen loss: 1.2453 Per loss: 1.2198 Adv loss: 0.0237 Con loss: 0.0018
  * Validation Set: PSNR: 24.6078, SSIM: 0.6922
  * Bicubic Ups: PSNR: 25.5098, SSIM: 0.7185


---



* Best Epoch: 120 -> Dis loss: 0.0089 Gen loss: 1.2454 Per loss: 1.2198 Adv loss: 0.0238 Con loss: 0.0018
  * Validation Set: PSNR: 24.6125, SSIM: 0.6957
  * Bicubic Ups: PSNR: 25.51, SSIM: 0.7182
  * Best val scores till epoch 120 -> PSNR: 24.6195, SSIM: 0.6957


In [None]:
#@title training

import gc

# Release cached GPU memory and run a garbage-collection pass before training.
torch.cuda.empty_cache()
gc.collect()

# Print the effective configuration, one aligned "key: value" line per entry.
for option_name, option_value in config.items():
    print(f"{option_name:30}: {option_value}")

print("\n\n")
print("ESRGAN start")

# Build the trainer from the config and both loaders, then run training.
model = trainer.Trainer(config, esr_dataloader_train, esr_dataloader_val, device)
model_metrics = model.train()

In [None]:
#@title Testing Images

import cv2
from google.colab.patches import cv2_imshow
import numpy as np
import models
import importlib
importlib.reload(models)

checkpoint_number = 9

# map_location lets a checkpoint that was saved on a GPU load on whatever
# device this session actually has (including CPU-only runtimes).
checkpoint = torch.load(rf"/content/checkpoint_{checkpoint_number}.tar", map_location=device)
# NOTE(review): nf=32 / num_res_blocks=11 differ from the training config in
# this notebook (nf=64, num_rrdn_blocks=23) - confirm they match the checkpoint.
Model = models.Generator(channels=3, nf=32, gc=32, num_res_blocks=11, scale=4)
Model.load_state_dict(checkpoint[rf"generator_dict_{checkpoint_number}"])
Model.to(device)
# Inference only: put the module in evaluation mode.
Model.eval()
# print("Generator weights loaded.")


# NOTE(review): the datasets above live under ./images/train and ./images/valid;
# confirm that this path exists.
image_path = r"/content/images/lr/00000.png"


def test_image(image_path, model):
    """Super-resolve one image file with ``model`` and display input and output.

    Args:
        image_path: path of the low-resolution image to read.
        model: generator module; it is called on a ``(1, 3, H, W)`` float32
            RGB tensor scaled to ``[0, 1]``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    cv2_imshow(image)

    # OpenCV loads BGR, while the training dataset feeds RGB (PIL) images.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # True division: the previous floor division collapsed every pixel to 0 or 1.
    image = image.astype(np.float32) / 255.0
    # NOTE(review): the training dataset also applies mean/std normalisation;
    # confirm whether the generator expects normalised input here as well.

    # HWC -> CHW, then add the batch dimension.
    image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
    image = torch.from_numpy(np.expand_dims(image, 0))
    with torch.no_grad():
        # with torch.cuda.amp.autocast():
        high_res = model(image.to(device))

    high_res = high_res.cpu().detach().permute(0, 2, 3, 1).numpy()
    # Convert back to BGR and the 0-255 range expected by cv2_imshow.
    output_bgr = cv2.cvtColor(np.ascontiguousarray(high_res[0]), cv2.COLOR_RGB2BGR)
    cv2_imshow(output_bgr * 255.0)
    print(high_res.shape)

test_image(image_path, Model)


# Model.eval()
# with torch.no_grad():
# with torch.cuda.amp.autocast(enabled=False):
    # out = Model(image)

In [None]:
# !rm -rf /content/samples
# !rm -rf /content/drive/MyDrive/Project-ESRGAN
# !rm -rf checkpoint*