### Import all the necessary libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data import Dataset

from torchvision import models
from collections import namedtuple


import numpy as np
import os
from PIL import Image
from skimage.color import rgb2ycbcr
from skimage.metrics import peak_signal_noise_ratio
from tqdm import tqdm
import random

### Dataset

In [2]:
class mydata(Dataset):
    def __init__(self, LR_path, GT_path, in_memory = True, transform = None):
        
        self.LR_path = LR_path
        self.GT_path = GT_path
        self.in_memory = in_memory
        self.transform = transform
        
        self.LR_img = sorted(os.listdir(LR_path))
        self.GT_img = sorted(os.listdir(GT_path))
        
        if in_memory:
            self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr)).convert("RGB")).astype(np.uint8) for lr in self.LR_img]
            self.GT_img = [np.array(Image.open(os.path.join(self.GT_path, gt)).convert("RGB")).astype(np.uint8) for gt in self.GT_img]
        
    def __len__(self):
        
        return len(self.LR_img)
        
    def __getitem__(self, i):
        
        img_item = {}
        
        if self.in_memory:
            GT = self.GT_img[i].astype(np.float32)
            LR = self.LR_img[i].astype(np.float32)
            
        else:
            GT = np.array(Image.open(os.path.join(self.GT_path, self.GT_img[i])).convert("RGB"))
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB"))

        img_item['GT'] = (GT / 127.5) - 1.0
        img_item['LR'] = (LR / 127.5) - 1.0
                
        if self.transform is not None:
            img_item = self.transform(img_item)
            
        img_item['GT'] = img_item['GT'].transpose(2, 0, 1).astype(np.float32)
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)
        
        return img_item
    
    
class testOnly_data(Dataset):
    def __init__(self, LR_path, in_memory = True, transform = None):
        
        self.LR_path = LR_path
        self.LR_img = sorted(os.listdir(LR_path))
        self.in_memory = in_memory
        if in_memory:
            self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr))) for lr in self.LR_img]
        
    def __len__(self):
        
        return len(self.LR_img)
        
    def __getitem__(self, i):
        
        img_item = {}
        
        if self.in_memory:
            LR = self.LR_img[i]
            
        else:
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])))

        img_item['LR'] = (LR / 127.5) - 1.0                
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)
        
        return img_item


class crop(object):
    def __init__(self, scale, patch_size):
        
        self.scale = scale
        self.patch_size = patch_size
        
    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']
        ih, iw = LR_img.shape[:2]
        
        ix = random.randrange(0, iw - self.patch_size +1)
        iy = random.randrange(0, ih - self.patch_size +1)
        
        tx = ix * self.scale
        ty = iy * self.scale
        
        LR_patch = LR_img[iy : iy + self.patch_size, ix : ix + self.patch_size]
        GT_patch = GT_img[ty : ty + (self.scale * self.patch_size), tx : tx + (self.scale * self.patch_size)]
        
        return {'LR' : LR_patch, 'GT' : GT_patch}

class augmentation(object):
    
    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']
        
        hor_flip = random.randrange(0,2)
        ver_flip = random.randrange(0,2)
        rot = random.randrange(0,2)
    
        if hor_flip:
            temp_LR = np.fliplr(LR_img)
            LR_img = temp_LR.copy()
            temp_GT = np.fliplr(GT_img)
            GT_img = temp_GT.copy()
            
            del temp_LR, temp_GT
        
        if ver_flip:
            temp_LR = np.flipud(LR_img)
            LR_img = temp_LR.copy()
            temp_GT = np.flipud(GT_img)
            GT_img = temp_GT.copy()
            
            del temp_LR, temp_GT
            
        if rot:
            LR_img = LR_img.transpose(1, 0, 2)
            GT_img = GT_img.transpose(1, 0, 2)
        
        
        return {'LR' : LR_img, 'GT' : GT_img}
        



### VGG19 for Perceptual Loss

In [3]:
class vgg19(nn.Module):
    
    def __init__(self, pre_trained = True, require_grad = False):
        super(vgg19, self).__init__()
        self.vgg_feature = models.vgg19(pretrained = pre_trained).features
        self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature]
        self.vgg_layer = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 
                         'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
                         'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
                         'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
                         'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5']
        
        if not require_grad:
            for parameter in self.parameters():
                parameter.requires_grad = False
        
    def forward(self, x):
        
        conv1_1 = self.seq_list[0](x)
        relu1_1 = self.seq_list[1](conv1_1)
        conv1_2 = self.seq_list[2](relu1_1)
        relu1_2 = self.seq_list[3](conv1_2)
        pool1 = self.seq_list[4](relu1_2)
        
        conv2_1 = self.seq_list[5](pool1)
        relu2_1 = self.seq_list[6](conv2_1)
        conv2_2 = self.seq_list[7](relu2_1)
        relu2_2 = self.seq_list[8](conv2_2)
        pool2 = self.seq_list[9](relu2_2)
        
        conv3_1 = self.seq_list[10](pool2)
        relu3_1 = self.seq_list[11](conv3_1)
        conv3_2 = self.seq_list[12](relu3_1)
        relu3_2 = self.seq_list[13](conv3_2)
        conv3_3 = self.seq_list[14](relu3_2)
        relu3_3 = self.seq_list[15](conv3_3)
        conv3_4 = self.seq_list[16](relu3_3)
        relu3_4 = self.seq_list[17](conv3_4)
        pool3 = self.seq_list[18](relu3_4)
        
        conv4_1 = self.seq_list[19](pool3)
        relu4_1 = self.seq_list[20](conv4_1)
        conv4_2 = self.seq_list[21](relu4_1)
        relu4_2 = self.seq_list[22](conv4_2)
        conv4_3 = self.seq_list[23](relu4_2)
        relu4_3 = self.seq_list[24](conv4_3)
        conv4_4 = self.seq_list[25](relu4_3)
        relu4_4 = self.seq_list[26](conv4_4)
        pool4 = self.seq_list[27](relu4_4)
        
        conv5_1 = self.seq_list[28](pool4)
        relu5_1 = self.seq_list[29](conv5_1)
        conv5_2 = self.seq_list[30](relu5_1)
        relu5_2 = self.seq_list[31](conv5_2)
        conv5_3 = self.seq_list[32](relu5_2)
        relu5_3 = self.seq_list[33](conv5_3)
        conv5_4 = self.seq_list[34](relu5_3)
        relu5_4 = self.seq_list[35](conv5_4)
        pool5 = self.seq_list[36](relu5_4)
        
        vgg_output = namedtuple("vgg_output", self.vgg_layer)
        
        vgg_list = [conv1_1, relu1_1, conv1_2, relu1_2, pool1, 
                         conv2_1, relu2_1, conv2_2, relu2_2, pool2,
                         conv3_1, relu3_1, conv3_2, relu3_2, conv3_3, relu3_3, conv3_4, relu3_4, pool3,
                         conv4_1, relu4_1, conv4_2, relu4_2, conv4_3, relu4_3, conv4_4, relu4_4, pool4,
                         conv5_1, relu5_1, conv5_2, relu5_2, conv5_3, relu5_3, conv5_4, relu5_4, pool5]
        
        out = vgg_output(*vgg_list)
        
        
        return out
        



### Loss Definitions

In [4]:
class MeanShift(nn.Conv2d):
    def __init__(
        self, rgb_range = 1,
        norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(norm_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(norm_mean) / std
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        for p in self.parameters():
            p.requires_grad = False


class perceptual_loss(nn.Module):

    def __init__(self, vgg):
        super(perceptual_loss, self).__init__()
        self.normalization_mean = [0.485, 0.456, 0.406]
        self.normalization_std = [0.229, 0.224, 0.225]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.transform = MeanShift(norm_mean = self.normalization_mean, norm_std = self.normalization_std).to(self.device)
        self.vgg = vgg
        self.criterion = nn.MSELoss()
    def forward(self, HR, SR, layer = 'relu5_4'):
        ## HR and SR should be normalized [0,1]
        hr = self.transform(HR)
        sr = self.transform(SR)
        
        hr_feat = getattr(self.vgg(hr), layer)
        sr_feat = getattr(self.vgg(sr), layer)
        
        return self.criterion(hr_feat, sr_feat), hr_feat, sr_feat

class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]



### Model Definintion

In [5]:
class _conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
        super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels, 
                               kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True)
        
        self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02)
        self.bias.data = torch.zeros((out_channels))
        
        for p in self.parameters():
            p.requires_grad = True
        

class conv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True):
        super(conv, self).__init__()
        m = []
        m.append(_conv(in_channels = in_channel, out_channels = out_channel, 
                               kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True))
        
        if BN:
            m.append(nn.BatchNorm2d(num_features = out_channel))
        
        if act is not None:
            m.append(act)
        
        self.body = nn.Sequential(*m)
        
    def forward(self, x):
        out = self.body(x)
        return out
    

        
class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True):
        super(ResBlock, self).__init__()
        m = []
        m.append(conv(channels, channels, kernel_size, BN = True, act = act))
        m.append(conv(channels, channels, kernel_size, BN = True, act = None))
        self.body = nn.Sequential(*m)
        
    def forward(self, x):
        res = self.body(x)
        res += x
        return res
    
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)):
        super(BasicBlock, self).__init__()
        m = []
        
        self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act)
        for i in range(num_res_block):
            m.append(ResBlock(out_channels, kernel_size, act))
        
        m.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None))
        
        self.body = nn.Sequential(*m)
        
    def forward(self, x):
        res = self.conv(x)
        out = self.body(res)
        out += res
        
        return out
        
class Upsampler(nn.Module):
    def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)):
        super(Upsampler, self).__init__()
        m = []
        m.append(conv(channel, channel * scale * scale, kernel_size))
        m.append(nn.PixelShuffle(scale))
    
        if act is not None:
            m.append(act)
        
        self.body = nn.Sequential(*m)
    
    def forward(self, x):
        out = self.body(x)
        return out

class discrim_block(nn.Module):
    def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)):
        super(discrim_block, self).__init__()
        m = []
        m.append(conv(in_feats, out_feats, kernel_size, BN = True, act = act))
        m.append(conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2))
        self.body = nn.Sequential(*m)
        
    def forward(self, x):
        out = self.body(x)
        return out


In [6]:
class MiniSRGAN(nn.Module):
    
    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, num_block = 8, act = nn.PReLU(), scale=4):
        super(MiniSRGAN, self).__init__()
        
        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act)
        
        resblocks = [ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)]
        self.body = nn.Sequential(*resblocks)
        
        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None)
        
        if(scale == 4):
            upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)]
        else:
            upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)]

        self.tail = nn.Sequential(*upsample_blocks)
        
        self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh())
        
    def forward(self, x):
        
        x = self.conv01(x)
        _skip_connection = x
        
        x = self.body(x)
        x = self.conv02(x)
        feat = x + _skip_connection
        
        x = self.tail(feat)
        x = self.last_conv(x)
        
        return x, feat
     
class Discriminator(nn.Module):
    
    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, act = nn.LeakyReLU(inplace = True), num_of_block = 3, patch_size = 96):
        super(Discriminator, self).__init__()
        self.act = act
        
        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act)
        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act, stride = 2)
        
        body = [discrim_block(in_feats = n_feats * (2 ** i), out_feats = n_feats * (2 ** (i + 1)), kernel_size = 3, act = self.act) for i in range(num_of_block)]    
        self.body = nn.Sequential(*body)
        
        self.linear_size = ((patch_size // (2 ** (num_of_block + 1))) ** 2) * (n_feats * (2 ** num_of_block))
        
        tail = []
        
        tail.append(nn.Linear(self.linear_size, 1024))
        tail.append(self.act)
        tail.append(nn.Linear(1024, 1))
        tail.append(nn.Sigmoid())
        
        self.tail = nn.Sequential(*tail)
        
        
    def forward(self, x):
        
        x = self.conv01(x)
        x = self.conv02(x)
        x = self.body(x)        
        x = x.view(-1, self.linear_size)
        x = self.tail(x)
        
        return x



### Training and testing functions

In [None]:
def train(LR_path, GT_path, batch_size=16, res_num=16, num_workers=0, scale=4, L2_coeff=1.0, adv_coeff=1e-3, tv_loss_coeff=0.0, pre_train_epoch=8000, fine_train_epoch=4000, patch_size=24, feat_layer='relu5_4', vgg_rescale_coeff=0.006, in_memory=True, generator_path=None, fine_tuning=False):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    transform  = transforms.Compose([crop(scale, patch_size), augmentation()])
    dataset = mydata(GT_path=GT_path, LR_path=LR_path, in_memory=in_memory, transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    
    generator = MiniSRGAN()
    
    if fine_tuning:        
        generator.load_state_dict(torch.load(generator_path))
        print("pre-trained model is loaded")
        print("path : %s"%(generator_path))
        
    generator = generator.to(device)
    generator.train()
    
    l2_loss = nn.MSELoss()
    g_optim = optim.Adam(generator.parameters(), lr=1e-4)
        
    pre_epoch = 0
    fine_epoch = 0
    
    #### Train using L2_loss
    # Initialize best loss to a very large value
    best_loss = np.inf
    #pre_epoch = 3526
    #generator.load_state_dict(torch.load('./minisrgan_weights/pre_trained_model_latest.pt'))

    while pre_epoch < pre_train_epoch:
        for i, tr_data in enumerate(tqdm(loader, desc=f"Epoch {pre_epoch+1}/{pre_train_epoch}")):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            output, _ = generator(lr)
            loss = l2_loss(gt, output)

            g_optim.zero_grad()
            loss.backward()
            g_optim.step()

        pre_epoch += 1

        if pre_epoch % 2 == 0:
            torch.save(generator.state_dict(), './minisrgan_weights/pre_trained_model_latest.pt')
            print(pre_epoch)
            print(loss.item())
            print('=========')

        # Save the model if it has the best loss so far
        if loss.item() < best_loss:
            best_loss = loss.item()
            torch.save(generator.state_dict(), './minisrgan_weights/best_pre_trained_model.pt')
            print(f"New best model saved with loss: {best_loss}")

        if pre_epoch % 800 == 0:
            torch.save(generator.state_dict(), './minisrgan_weights/pre_trained_model_%03d.pt' % pre_epoch)
    
    #### Train using perceptual & adversarial loss
    vgg_net = vgg19().to(device)
    vgg_net = vgg_net.eval()
    
    discriminator = Discriminator(patch_size=patch_size * scale)
    discriminator = discriminator.to(device)
    discriminator.train()
    
    d_optim = optim.Adam(discriminator.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.StepLR(g_optim, step_size=2000, gamma=0.1)
    
    VGG_loss = perceptual_loss(vgg_net)
    cross_ent = nn.BCELoss()
    tv_loss = TVLoss()
    real_label = torch.ones((batch_size, 1)).to(device)
    fake_label = torch.zeros((batch_size, 1)).to(device)
    
    generator.load_state_dict(torch.load('./minisrgan_weights/pre_trained_model_latest.pt'))

    print('Training Discriminator and Generator...')
    while fine_epoch < fine_train_epoch:
        
        scheduler.step()
        
        # Initialize tqdm progress bar
        for i, tr_data in enumerate(tqdm(loader, desc=f"Epoch {fine_epoch+1}/{fine_train_epoch}")):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

                            
            ## Training Discriminator
            output, _ = generator(lr)
            fake_prob = discriminator(output)
            real_prob = discriminator(gt)
            
            d_loss_real = cross_ent(real_prob, real_label)
            d_loss_fake = cross_ent(fake_prob, fake_label)
            
            d_loss = d_loss_real + d_loss_fake

            g_optim.zero_grad()
            d_optim.zero_grad()
            d_loss.backward()
            d_optim.step()
            
            ## Training Generator
            output, _ = generator(lr)
            fake_prob = discriminator(output)
            
            _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer=feat_layer)
            
            L2_loss = l2_loss(output, gt)
            percep_loss = vgg_rescale_coeff * _percep_loss
            adversarial_loss = adv_coeff * cross_ent(fake_prob, real_label)
            total_variance_loss = tv_loss_coeff * tv_loss(vgg_rescale_coeff * (hr_feat - sr_feat)**2)
            
            g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss
            
            if g_loss.item() < best_loss:
                best_loss = g_loss.item()
                torch.save(generator.state_dict(), './minisrgan_weights/best_trained_model.pt')
                print(f"New best model saved with loss: {best_loss}")

            
            g_optim.zero_grad()
            d_optim.zero_grad()
            g_loss.backward()
            g_optim.step()

            
        fine_epoch += 1

        if fine_epoch % 2 == 0:
            print(fine_epoch)
            print(g_loss.item())
            print(d_loss.item())
            print('=========')
            torch.save(generator.state_dict(), './minisrgan_weights/latest_trained_model.pt')

        if fine_epoch % 500 == 0:
            torch.save(generator.state_dict(), './minisrgan_weights/SRGAN_gene_%03d.pt' % fine_epoch)
            torch.save(discriminator.state_dict(), './minisrgan_weights/SRGAN_discrim_%03d.pt' % fine_epoch)


def test(LR_path, GT_path, num_workers=0, scale=4, generator_path=None):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = mydata(GT_path=GT_path, LR_path=LR_path, in_memory=False, transform=None)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    
    generator = MiniSRGAN()
    generator.load_state_dict(torch.load(generator_path))
    generator = generator.to(device)
    generator.eval()
    
    f = open('./result_minisrgan.txt', 'w')
    psnr_list = []
    
    with torch.no_grad():
        for i, te_data in enumerate(loader):
            gt = te_data['GT'].to(device)
            lr = te_data['LR'].to(device)

            bs, c, h, w = lr.size()
            gt = gt[:, :, : h * scale, : w * scale]

            output, _ = generator(lr)

            output = output[0].cpu().numpy()
            output = np.clip(output, -1.0, 1.0)
            gt = gt[0].cpu().numpy()

            output = (output + 1.0) / 2.0
            gt = (gt + 1.0) / 2.0

            output = output.transpose(1, 2, 0)
            gt = gt.transpose(1, 2, 0)

            y_output = rgb2ycbcr(output)[scale:-scale, scale:-scale, :1]
            y_gt = rgb2ycbcr(gt)[scale:-scale, scale:-scale, :1]
            
            psnr_value = peak_signal_noise_ratio(y_output / 255.0, y_gt / 255.0, data_range=1.0)
            psnr_list.append(psnr_value)
            f.write(f'PSNR: {psnr_value:.4f}\n')

            result = Image.fromarray((output * 255.0).astype(np.uint8))
            result.save('./result_minisrgan/res_%04d.png' % i)

        f.write(f'Average PSNR: {np.mean(psnr_list):.4f}')
        f.close()


def test_only(LR_path, num_workers=0, generator_path=None):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = testOnly_data(LR_path=LR_path, in_memory=False, transform=None)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    
    generator = MiniSRGAN()
    generator.load_state_dict(torch.load(generator_path))
    generator = generator.to(device)
    generator.eval()
    
    with torch.no_grad():
        for i, te_data in enumerate(loader):
            lr = te_data['LR'].to(device)
            output, _ = generator(lr)
            output = output[0].cpu().numpy()
            output = (output + 1.0) / 2.0
            output = output.transpose(1, 2, 0)
            result = Image.fromarray((output * 255.0).astype(np.uint8))
            result.save('./result_minisrgan/res_%04d.png' % i)

### The original weights were trained as scripts (due to faster training). The following training is for 0 epochs to 106 epochs for reference

In [8]:
train(GT_path='/media/moose/Main Volume/DIV2K_Complete/DIV2K_train',LR_path='/media/moose/Main Volume/DIV2K_Complete/DIV2K_train_LR_bicubic/X4')

Epoch 1/8000: 100%|██████████| 50/50 [00:15<00:00,  3.18it/s]


New best model saved with loss: 0.0630432516336441


Epoch 2/8000: 100%|██████████| 50/50 [00:16<00:00,  3.05it/s]


2
0.06989946961402893


Epoch 3/8000: 100%|██████████| 50/50 [00:16<00:00,  2.99it/s]


New best model saved with loss: 0.05066428706049919


Epoch 4/8000: 100%|██████████| 50/50 [00:16<00:00,  2.96it/s]


4
0.027863571420311928
New best model saved with loss: 0.027863571420311928


Epoch 5/8000: 100%|██████████| 50/50 [00:16<00:00,  2.98it/s]


New best model saved with loss: 0.019608577713370323


Epoch 6/8000: 100%|██████████| 50/50 [00:17<00:00,  2.94it/s]


6
0.03305332735180855


Epoch 7/8000: 100%|██████████| 50/50 [00:16<00:00,  2.99it/s]


New best model saved with loss: 0.018714990466833115


Epoch 8/8000: 100%|██████████| 50/50 [00:16<00:00,  2.95it/s]


8
0.01531557273119688
New best model saved with loss: 0.01531557273119688


Epoch 9/8000: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Epoch 10/8000: 100%|██████████| 50/50 [00:13<00:00,  3.83it/s]


10
0.021924301981925964


Epoch 11/8000: 100%|██████████| 50/50 [00:18<00:00,  2.68it/s]


New best model saved with loss: 0.013780037872493267


Epoch 12/8000: 100%|██████████| 50/50 [00:16<00:00,  2.98it/s]


12
0.02578768879175186


Epoch 13/8000: 100%|██████████| 50/50 [00:16<00:00,  3.08it/s]
Epoch 14/8000: 100%|██████████| 50/50 [00:16<00:00,  3.08it/s]


14
0.016483230516314507


Epoch 15/8000: 100%|██████████| 50/50 [00:16<00:00,  3.05it/s]
Epoch 16/8000: 100%|██████████| 50/50 [00:16<00:00,  3.04it/s]


16
0.023890763521194458


Epoch 17/8000: 100%|██████████| 50/50 [00:16<00:00,  3.03it/s]
Epoch 18/8000: 100%|██████████| 50/50 [00:16<00:00,  3.12it/s]


18
0.010968178510665894
New best model saved with loss: 0.010968178510665894


Epoch 19/8000: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s]
Epoch 20/8000: 100%|██████████| 50/50 [00:16<00:00,  2.94it/s]


20
0.02725946716964245


Epoch 21/8000: 100%|██████████| 50/50 [00:16<00:00,  3.12it/s]
Epoch 22/8000: 100%|██████████| 50/50 [00:13<00:00,  3.83it/s]


22
0.023250719532370567


Epoch 23/8000: 100%|██████████| 50/50 [00:16<00:00,  2.96it/s]


New best model saved with loss: 0.010198301635682583


Epoch 24/8000: 100%|██████████| 50/50 [00:16<00:00,  2.96it/s]


24
0.015099535696208477


Epoch 25/8000: 100%|██████████| 50/50 [00:15<00:00,  3.20it/s]
Epoch 26/8000: 100%|██████████| 50/50 [00:19<00:00,  2.61it/s]


26
0.014168597757816315


Epoch 27/8000: 100%|██████████| 50/50 [00:18<00:00,  2.76it/s]


New best model saved with loss: 0.008731229230761528


Epoch 28/8000: 100%|██████████| 50/50 [00:17<00:00,  2.86it/s]


28
0.012123379856348038


Epoch 29/8000: 100%|██████████| 50/50 [00:13<00:00,  3.61it/s]
Epoch 30/8000: 100%|██████████| 50/50 [00:13<00:00,  3.61it/s]


30
0.018614832311868668


Epoch 31/8000: 100%|██████████| 50/50 [00:13<00:00,  3.70it/s]


New best model saved with loss: 0.006930181290954351


Epoch 32/8000: 100%|██████████| 50/50 [00:12<00:00,  3.88it/s]


32
0.01767265424132347


Epoch 33/8000: 100%|██████████| 50/50 [00:12<00:00,  3.93it/s]
Epoch 34/8000: 100%|██████████| 50/50 [00:14<00:00,  3.49it/s]


34
0.019116433337330818


Epoch 35/8000: 100%|██████████| 50/50 [00:13<00:00,  3.83it/s]
Epoch 36/8000: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]


36
0.012117142789065838


Epoch 37/8000: 100%|██████████| 50/50 [00:12<00:00,  3.91it/s]
Epoch 38/8000: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


38
0.012259937822818756


Epoch 39/8000: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]
Epoch 40/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]


40
0.009144810028374195


Epoch 41/8000: 100%|██████████| 50/50 [00:12<00:00,  3.95it/s]
Epoch 42/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]


42
0.013747246004641056


Epoch 43/8000: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]
Epoch 44/8000: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


44
0.014122581109404564


Epoch 45/8000: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]
Epoch 46/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]


46
0.008083941414952278


Epoch 47/8000: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]
Epoch 48/8000: 100%|██████████| 50/50 [00:14<00:00,  3.48it/s]


48
0.007141957525163889


Epoch 49/8000: 100%|██████████| 50/50 [00:14<00:00,  3.43it/s]
Epoch 50/8000: 100%|██████████| 50/50 [00:13<00:00,  3.68it/s]


50
0.00769816292449832


Epoch 51/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]
Epoch 52/8000: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]


52
0.01099731121212244


Epoch 53/8000: 100%|██████████| 50/50 [00:13<00:00,  3.83it/s]
Epoch 54/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]


54
0.010885214433073997


Epoch 55/8000: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]
Epoch 56/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]


56
0.014050130732357502


Epoch 57/8000: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]
Epoch 58/8000: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]


58
0.030997512862086296


Epoch 59/8000: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]
Epoch 60/8000: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]


60
0.007158256135880947


Epoch 61/8000: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]
Epoch 62/8000: 100%|██████████| 50/50 [00:12<00:00,  3.93it/s]


62
0.008513914421200752


Epoch 63/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]
Epoch 64/8000: 100%|██████████| 50/50 [00:12<00:00,  3.86it/s]


64
0.016217032447457314


Epoch 65/8000: 100%|██████████| 50/50 [00:12<00:00,  3.85it/s]
Epoch 66/8000: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]


66
0.020708324387669563


Epoch 67/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]
Epoch 68/8000: 100%|██████████| 50/50 [00:12<00:00,  3.85it/s]


68
0.01338091492652893


Epoch 69/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]
Epoch 70/8000: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


70
0.009677101857960224


Epoch 71/8000: 100%|██████████| 50/50 [00:13<00:00,  3.78it/s]
Epoch 72/8000: 100%|██████████| 50/50 [00:13<00:00,  3.66it/s]


72
0.03464870527386665


Epoch 73/8000: 100%|██████████| 50/50 [00:13<00:00,  3.78it/s]
Epoch 74/8000: 100%|██████████| 50/50 [00:13<00:00,  3.59it/s]


74
0.02053173817694187


Epoch 75/8000: 100%|██████████| 50/50 [00:14<00:00,  3.41it/s]
Epoch 76/8000: 100%|██████████| 50/50 [00:13<00:00,  3.70it/s]


76
0.020234061405062675


Epoch 77/8000: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s]
Epoch 78/8000: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s]


78
0.01577788032591343


Epoch 79/8000: 100%|██████████| 50/50 [00:14<00:00,  3.53it/s]
Epoch 80/8000: 100%|██████████| 50/50 [00:14<00:00,  3.57it/s]


80
0.009388983249664307


Epoch 81/8000: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]
Epoch 82/8000: 100%|██████████| 50/50 [00:12<00:00,  4.05it/s]


82
0.014110139571130276


Epoch 83/8000: 100%|██████████| 50/50 [00:12<00:00,  4.04it/s]
Epoch 84/8000: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]


84
0.010958055965602398


Epoch 85/8000: 100%|██████████| 50/50 [00:13<00:00,  3.81it/s]
Epoch 86/8000: 100%|██████████| 50/50 [00:12<00:00,  4.02it/s]


86
0.013147730380296707


Epoch 87/8000: 100%|██████████| 50/50 [00:12<00:00,  3.99it/s]
Epoch 88/8000: 100%|██████████| 50/50 [00:14<00:00,  3.45it/s]


88
0.02685728296637535


Epoch 89/8000: 100%|██████████| 50/50 [00:14<00:00,  3.42it/s]
Epoch 90/8000: 100%|██████████| 50/50 [00:13<00:00,  3.80it/s]


90
0.007255207747220993


Epoch 91/8000: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s]
Epoch 92/8000: 100%|██████████| 50/50 [00:14<00:00,  3.45it/s]


92
0.008225744590163231


Epoch 93/8000: 100%|██████████| 50/50 [00:13<00:00,  3.72it/s]
Epoch 94/8000: 100%|██████████| 50/50 [00:12<00:00,  3.90it/s]


94
0.030857080593705177


Epoch 95/8000: 100%|██████████| 50/50 [00:14<00:00,  3.53it/s]
Epoch 96/8000: 100%|██████████| 50/50 [00:14<00:00,  3.46it/s]


96
0.017454542219638824


Epoch 97/8000: 100%|██████████| 50/50 [00:14<00:00,  3.49it/s]
Epoch 98/8000: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]


98
0.016219528391957283


Epoch 99/8000: 100%|██████████| 50/50 [00:14<00:00,  3.48it/s]
Epoch 100/8000: 100%|██████████| 50/50 [00:14<00:00,  3.48it/s]


100
0.010794473811984062


Epoch 101/8000: 100%|██████████| 50/50 [00:12<00:00,  3.86it/s]
Epoch 102/8000: 100%|██████████| 50/50 [00:13<00:00,  3.78it/s]


102
0.006341397762298584
New best model saved with loss: 0.006341397762298584


Epoch 103/8000: 100%|██████████| 50/50 [00:13<00:00,  3.73it/s]
Epoch 104/8000: 100%|██████████| 50/50 [00:12<00:00,  3.86it/s]


104
0.010456353425979614


Epoch 105/8000: 100%|██████████| 50/50 [00:13<00:00,  3.84it/s]
Epoch 106/8000:  76%|███████▌  | 38/50 [00:12<00:05,  2.27it/s]

In [None]:
test(GT_path='../../../../Set5/image_SRF_4/HR/', LR_path='../../../../Set5/image_SRF_4/LR/', generator_path='./minisrgan_weights/best_trained_model.pt')