In [1]:
# Show which GPU this runtime was assigned, with driver / CUDA versions and current memory use.
!nvidia-smi

Fri Jan 22 08:28:04 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   65C    P8    11W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# Mount Google Drive into the Colab filesystem (prompts for authorization on first run).
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
# @title Download dataset
import os
import shutil
from google_drive_downloader import GoogleDriveDownloader as gdd

if not os.path.exists("/content/images/valid.zip"):
    # https://drive.google.com/file/d/1ux-mv7250XsgZo-vsk4QIedfGbELl9oZ/view?usp=sharing
    # LR - 64x64 images
    gdd.download_file_from_google_drive(file_id='1ux-mv7250XsgZo-vsk4QIedfGbELl9oZ',
                                        dest_path='./images/valid.zip',
                                        unzip=True,
                                        showsize=True,
                                        )
    !rm -rf /content/images/valid.zip

# download scripts
!wget -q https://raw.githubusercontent.com/veb-101/Esrgan-pytorch/master/trainer.py -O trainer.py
!wget -q https://raw.githubusercontent.com/veb-101/Esrgan-pytorch/master/utils.py -O utils.py
!wget -q https://raw.githubusercontent.com/veb-101/Esrgan-pytorch/master/models.py -O models.py


if not os.path.exists("/content/images/train.zip"):
    # https://drive.google.com/file/d/1RGvBO7wVCI4aPNkjy8w-Sl3Kzp2UCqXm/view?usp=sharing
    # HR - 256x256 images
    gdd.download_file_from_google_drive(file_id='1RGvBO7wVCI4aPNkjy8w-Sl3Kzp2UCqXm',
                                        dest_path='./images/train.zip',
                                        unzip=True,
                                        showsize=True,
                                        )
    # !rm -rf /content/images/train.zip

Downloading 1T8IZ4huhbVGW2eVUJOcUFuk7m5HlkEpg into ./images/lr.zip... 
373.5 MiB Done.
Downloading 1HQUi1ZZpPHXMO1swNKywcwwVNyqwYJNu into ./images/hr.zip... 
5.3 GiB Done.
Unzipping...Done.


In [4]:
import os
import random
from PIL import Image
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
import importlib

# The download cell fetches trainer.py / utils.py / models.py — there is no
# train.py, so a plain `import train` fails on a fresh runtime. Alias the trainer
# module as `train` so the `train.Trainer(...)` calls in the training cells work.
import trainer as train


# Seed every RNG this notebook uses: Python `random` (flip/rotate augmentations),
# NumPy (image subset selection in ESR_Dataset) and torch (model initialisation).
seed = 41
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


def get_default_device():
    """Return the CUDA device when PyTorch can see a GPU, else the CPU device."""
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_type)


device = get_default_device()


class ESR_Dataset(Dataset):
    """Paired low/high-resolution image dataset for ESRGAN training.

    Expects ``<path>/hr/<name>`` and ``<path>/lr/<name>`` to hold the HR and LR
    versions of the same image under the same file name. A random subset of
    file names is drawn once at construction time using NumPy's global RNG, so
    it follows ``np.random.seed``.

    Args:
        num_images: how many images to sample. ``None`` uses every candidate;
            a value larger than the number of candidates is clamped (with a
            warning) instead of raising inside ``np.random.choice``.
        path: dataset root containing the ``hr`` and ``lr`` directories.
        train: when True, apply random paired flips / 90-degree rotations.
        image_range: optional index range into the sorted file list, either
            ``(start, end)`` or a single split point ``(k,)``. With a single
            point, training sets use ``[0, k)`` and non-training sets use
            ``[k, len)``, so the same ``k`` yields disjoint train/val subsets.
            ``None`` (default) uses every file.

    Raises:
        FileNotFoundError: if ``path`` does not exist.

    Each item is a dict ``{"lr": tensor, "hr": tensor}`` of RGB tensors
    normalized with the ImageNet mean/std.
    """

    def __init__(self, num_images=9000, path=r"images", train=True, image_range=None):
        import warnings

        self.path = path
        self.is_train = train

        if not os.path.exists(self.path):
            raise FileNotFoundError(f"[!] dataset directory does not exist: {self.path}")

        # os.listdir order is arbitrary; sort so the seeded sample below (and
        # any index-based image_range split) is reproducible across runs.
        self.image_paths = sorted(os.listdir(os.path.join(self.path, "hr")))

        start, end = self._resolve_range(image_range, len(self.image_paths))
        candidates = self.image_paths[start:end]

        if num_images is None:
            num_images = len(candidates)
        elif num_images > len(candidates):
            warnings.warn(
                f"requested {num_images} images but only {len(candidates)} are "
                f"available in {self.path}; using all of them"
            )
            num_images = len(candidates)

        self.image_file_name = np.random.choice(
            candidates, size=num_images, replace=False)

        # ImageNet statistics, applied to both LR and HR tensors.
        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])

    def _resolve_range(self, image_range, total):
        """Turn ``image_range`` into concrete ``(start, end)`` slice bounds."""
        if not image_range:
            return 0, total
        if len(image_range) == 1:
            split = image_range[0]
            return (0, split) if self.is_train else (split, total)
        return image_range[0], image_range[1]

    def _load_rgb(self, subdir, file_name):
        """Open ``<path>/<subdir>/<file_name>`` and return an RGB copy, closing the file."""
        with Image.open(os.path.join(self.path, subdir, file_name)) as image:
            return image.convert("RGB")

    def __getitem__(self, item):
        file_name = self.image_file_name[item]
        high_resolution = self._load_rgb("hr", file_name)
        low_resolution = self._load_rgb("lr", file_name)

        if self.is_train:
            # Each augmentation is applied to both images so the pair stays aligned.
            if random.random() > 0.5:
                high_resolution = TF.vflip(high_resolution)
                low_resolution = TF.vflip(low_resolution)

            if random.random() > 0.5:
                high_resolution = TF.hflip(high_resolution)
                low_resolution = TF.hflip(low_resolution)

            if random.random() > 0.5:
                high_resolution = TF.rotate(high_resolution, 90)
                low_resolution = TF.rotate(low_resolution, 90)

        high_resolution = TF.to_tensor(high_resolution)
        low_resolution = TF.to_tensor(low_resolution)

        high_resolution = TF.normalize(high_resolution, self.mean, self.std)
        low_resolution = TF.normalize(low_resolution, self.mean, self.std)

        return {"lr": low_resolution, "hr": high_resolution}

    def __len__(self):
        return len(self.image_file_name)

In [5]:
# Hyper-parameters shared by both training stages. Keys prefixed `p_` belong to
# the PSNR-oriented warm-up stage and keys prefixed `g_` to the adversarial stage;
# `is_psnr_oriented` presumably tells the trainer which group to use.
config = dict(
    image_size=256,
    batch_size=16,
    start_epoch=0,
    num_epoch=100,
    sample_batch_size=1,
    checkpoint_dir="./checkpoints",
    sample_dir="./samples",
    workers=6,  # DataLoader worker processes
    scale_factor=4,
    num_rrdn_blocks=18,
    nf=64,
    gc=32,
    b1=0.9,
    b2=0.999,
    weight_decay=1e-2,
    # ------ PSNR-oriented warm-up ------
    p_lr=2e-4,
    p_decay_iter=[20, 40, 60, 80],
    p_perceptual_loss_factor=0,
    p_adversarial_loss_factor=0,
    p_content_loss_factor=1,
    # ------ adversarial (GAN) stage ------
    g_lr=1e-4,
    g_decay_iter=[20, 40, 60, 80],
    g_perceptual_loss_factor=1,
    g_adversarial_loss_factor=5e-3,
    g_content_loss_factor=1e-2,
    # -------------------------------------
    is_psnr_oriented=True,
    load_previous_opt=True,
)

In [6]:
import trainer
import models
import utils

# Re-execute the downloaded helper scripts so edits to them are picked up without
# restarting the kernel. Order is kept as utils -> models -> trainer (presumably
# dependencies before the modules that import them).
for _module in (utils, models, trainer):
    importlib.reload(_module)

# Create the output folder for generated samples on first run.
_sample_dir = config["sample_dir"]
if not os.path.exists(_sample_dir):
    os.makedirs(_sample_dir)

In [None]:
# @title warmup training
# Stage 1: PSNR-oriented warm-up (is_psnr_oriented=True uses the p_* settings).
# Pin host memory only when a GPU is present; it speeds up host-to-device copies.
pin = torch.cuda.is_available()

config["start_epoch"] = 0
config["num_epoch"] = 100
config["batch_size"] = 16
config["is_psnr_oriented"] = True
config["load_previous_opt"] = True


esr_dataset_train = ESR_Dataset(
    num_images=9000, path=r"./images/train", train=True)

esr_dataset_val = ESR_Dataset(
    num_images=64, path=r"./images/valid", train=False)

esr_dataloader_train = DataLoader(
    esr_dataset_train,
    config["batch_size"],
    num_workers=config["workers"],
    pin_memory=pin,
    shuffle=True,
)

esr_dataloader_val = DataLoader(
    esr_dataset_val,
    config["batch_size"],
    num_workers=config["workers"],
    pin_memory=pin,
)

for key, value in config.items():
    print(f"{key:30}: {value}")

# Blank lines separate the config dump from the training log. (A lone "\n\"
# would escape the closing quote and leave the string unterminated.)
print("\n\n\n")
print("ESRGAN start")

torch.cuda.empty_cache()
psnr_model = train.Trainer(
    config, esr_dataloader_train, esr_dataloader_val, device)
psnr_model_metrics = psnr_model.train()

image_size                    : 512
batch_size                    : 12
start_epoch                   : 2
num_epoch                     : 10
sample_batch_size             : 1
checkpoint_dir                : ./checkpoints
sample_dir                    : ./samples
workers                       : 6
scale_factor                  : 4
num_rrdn_blocks               : 11
nf                            : 32
gc                            : 32
b1                            : 0.9
b2                            : 0.999
weight_decay                  : 0.01
p_lr                          : 0.0002
p_decay_iter                  : [4, 8, 12, 16]
p_perceptual_loss_factor      : 0
p_adversarial_loss_factor     : 0
p_content_loss_factor         : 1
g_lr                          : 0.0001
g_decay_iter                  : [10, 20, 35, 50]
g_perceptual_loss_factor      : 1
g_adversarial_loss_factor     : 0.005
g_content_loss_factor         : 0.01
is_psnr_oriented              : True
load_previous_opt             : 

In [None]:
# Stage 2: adversarial (GAN) training, resuming at epoch 12 with the g_* settings.
pin = torch.cuda.is_available()


config["start_epoch"] = 12
config["num_epoch"] = 100
config["batch_size"] = 12
config["is_psnr_oriented"] = False
# Do not restore the optimizer state saved by the previous (warm-up) run.
config["load_previous_opt"] = False


# Same on-disk layout as the warm-up stage: ./images/train and ./images/valid,
# each with hr/ and lr/ subdirectories.
esr_dataset_train = ESR_Dataset(num_images=9000,
                                path=r"./images/train",
                                train=True,
                                )

esr_dataset_val = ESR_Dataset(num_images=config["batch_size"],
                              path=r"./images/valid",
                              train=False,
                              )

esr_dataloader_train = DataLoader(esr_dataset_train,
                                  config["batch_size"],
                                  num_workers=config["workers"],
                                  pin_memory=pin,
                                  shuffle=True,)

esr_dataloader_val = DataLoader(esr_dataset_val,
                                config["batch_size"],
                                num_workers=config["workers"],
                                pin_memory=pin,)

for key, value in config.items():
    print(f"{key:30}: {value}")

print("\n\n\n")
print("ESRGAN start")

torch.cuda.empty_cache()
# Not named `trainer`: that would shadow the `trainer` module, and re-running the
# module-reload cell would then call importlib.reload on a Trainer instance.
gan_trainer = train.Trainer(config, esr_dataloader_train, esr_dataloader_val, device)
train_metrics = gan_trainer.train()

image_size                    : 512
batch_size                    : 12
start_epoch                   : 12
num_epoch                     : 100
sample_batch_size             : 1
checkpoint_dir                : ./checkpoints
sample_dir                    : ./samples
workers                       : 6
scale_factor                  : 4
num_rrdn_blocks               : 11
nf                            : 32
gc                            : 32
b1                            : 0.9
b2                            : 0.999
weight_decay                  : 0.01
p_lr                          : 0.0002
p_decay_iter                  : [4, 8, 12, 16]
p_perceptual_loss_factor      : 0
p_adversarial_loss_factor     : 0
p_content_loss_factor         : 1
g_lr                          : 0.0001
g_decay_iter                  : [10, 20, 35, 50]
g_perceptual_loss_factor      : 1
g_adversarial_loss_factor     : 0.005
g_content_loss_factor         : 0.01
is_psnr_oriented              : False
load_previous_opt            

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))


[Epoch 12/111] [Batch 1/750][D loss 0.71243] [G loss 0.91814][adversarial loss 0.0035][perceptual loss 0.9138][content loss 0.0008]
[Epoch 12/111] [Batch 376/750][D loss 0.00129] [G loss 1.3242][adversarial loss 0.0345][perceptual loss 1.2887][content loss 0.0011]
[Epoch 12/111] [Batch 750/750][D loss 1.31021] [G loss 1.47806][adversarial loss 0.0113][perceptual loss 1.4655][content loss 0.0012]
Validation Set: PSNR: 17.452, SSIM:0.51
[Epoch 13/111] [Batch 1/750][D loss 0.62779] [G loss 1.47641][adversarial loss 0.0087][perceptual loss 1.4667][content loss 0.001]
[Epoch 13/111] [Batch 376/750][D loss 0.0022] [G loss 1.20905][adversarial loss 0.0328][perceptual loss 1.1741][content loss 0.0022]
[Epoch 13/111] [Batch 750/750][D loss 0.0019] [G loss 1.26491][adversarial loss 0.0337][perceptual loss 1.2295][content loss 0.0017]
Validation Set: PSNR: 13.646, SSIM:0.181
[Epoch 14/111] [Batch 1/750][D loss 0.00152] [G loss 1.39827][adversarial loss 0.0346][perceptual loss 1.3617][content los

In [None]:
#@title Testing Images

import cv2
from google.colab.patches import cv2_imshow
import numpy as np
import models
import importlib
importlib.reload(models)

checkpoint_number = 9

# map_location lets a checkpoint saved on a GPU runtime be loaded on a CPU-only one.
checkpoint = torch.load(rf"/content/checkpoint_{checkpoint_number}.tar", map_location=device)
# Build the generator with the same hyper-parameters the training cells use, so
# load_state_dict does not fail on shape mismatches. If the checkpoint came from a
# run with different settings, set these to that run's values instead.
Model = models.Generator(
    channels=3,
    nf=config["nf"],
    gc=config["gc"],
    num_res_blocks=config["num_rrdn_blocks"],
    scale=config["scale_factor"],
)
Model.load_state_dict(checkpoint[rf"generator_dict_{checkpoint_number}"])
Model.to(device)
Model.eval()  # inference: disable any training-only layer behaviour


# NOTE(review): this path follows an older layout; with the current download the
# LR images live under /content/images/valid/lr — update the file name as needed.
image_path = r"/content/images/lr/00000.png"

# Must match ESR_Dataset: the model was trained on RGB images normalized with
# these ImageNet statistics.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def _bgr_to_model_input(image_bgr, mean, std):
    """Convert an OpenCV BGR uint8 (H, W, 3) image to a normalized RGB (1, 3, H, W) tensor."""
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    image = (image - mean) / std
    return torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))[None]


def _model_output_to_bgr(output, mean, std):
    """Turn a (1, 3, H, W) RGB model output into a displayable BGR uint8 image.

    NOTE(review): assumes the generator outputs values in the same normalized
    space as the HR targets produced by ESR_Dataset — confirm against trainer.py.
    """
    image = output[0].detach().cpu().permute(1, 2, 0).numpy()
    image = image * std + mean
    image = np.clip(image * 255.0, 0, 255).round().astype(np.uint8)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def test_image(image_path, model, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Super-resolve one image file with ``model`` and display input and output.

    Args:
        image_path: path of the low-resolution image to upscale.
        model: generator already moved to ``device`` and set to eval mode.
        mean, std: per-channel RGB normalization used during training.

    Returns:
        The super-resolved image as a BGR uint8 array.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image_bgr is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f"could not read image: {image_path}")
    cv2_imshow(image_bgr)

    with torch.no_grad():
        high_res = model(_bgr_to_model_input(image_bgr, mean, std).to(device))

    high_res_bgr = _model_output_to_bgr(high_res, mean, std)
    cv2_imshow(high_res_bgr)
    print(high_res_bgr.shape)
    return high_res_bgr


sr_image = test_image(image_path, Model)



In [None]:
# DESTRUCTIVE cleanup, kept commented out on purpose. Uncommenting deletes the
# generated samples, the Project-ESRGAN folder on the mounted Google Drive, and
# any checkpoint* files in the working directory.
# !rm -rf /content/samples
# !rm -rf /content/drive/MyDrive/Project-ESRGAN
# !rm -rf checkpoint*