<a href="https://colab.research.google.com/github/purvilmehta06/Image-Super-Resolution-Using-GAN-SRGAN/blob/main/SRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SRGAN implementation in PyTorch

Reference implementation: [leftthomas/SRGAN](https://github.com/leftthomas/SRGAN)

Paper: [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/pdf/1609.04802.pdf)

Dataset: 
* DIV2K - Valid
  * [HR](https://data.vision.ee.ethz.ch/cvl/DIV2K/validation_release/DIV2K_valid_HR.zip)
  * [LR](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip)
* DIV2K - Train
  * [HR](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip)
  * [LR](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip)

Students: 
1. Ruchit Vithani (201701070)
2. Purvil Mehta (201701073)
3. Bhargey Mehta (201701074)
4. Kushal Shah (201701111)

## Paths for datasets, saved weights and results

In [None]:
# --- Dataset locations on the mounted Google Drive ---
# Folders of DIV2K high-resolution images. The LR inputs are produced on the
# fly by the dataset classes (bicubic downscaling), so only HR folders are needed.
train_path = '/content/drive/MyDrive/DL_project_old/data/DIV2K_train_HR'
val_path = '/content/drive/MyDrive/DL_project_old/data/DIV2K_valid_HR'

# Checkpoints to load: G is used by buildG() for inference; both can be used to
# resume training via the commented-out load_state_dict lines in the training cell.
G_weights_load = '/content/drive/MyDrive/DL_new/epochs/netG_epoch_4_100.pth'
D_weights_load = '/content/drive/MyDrive/DL_new/epochs/netD_epoch_4_100.pth'

# Output folder for generated SR images. File names are appended by plain string
# concatenation, so this path must keep its trailing '/'.
imgs_save = '/content/drive/MyDrive/DL_new/training_results/sr/'

# Output folders for the generator/discriminator checkpoints saved during
# training (also concatenated with file names -> trailing '/' required).
G_weights_save = '/content/drive/MyDrive/DL_new/epochs/'
D_weights_save = '/content/drive/MyDrive/DL_new/epochs/'

# Output folder for the per-epoch loss/score/PSNR/SSIM/MSE statistics CSV.
out_stat_path = '/content/drive/MyDrive/DL_new/statistics/'

# Library Initialisation 

In [None]:
!pip install pytorch_ssim

Collecting pytorch_ssim
  Downloading https://files.pythonhosted.org/packages/dc/78/f6cfa15ff7c66de5bb0873fb4bd699ff8024a0b00a94babbd216e64202b7/pytorch_ssim-0.1.tar.gz
Building wheels for collected packages: pytorch-ssim
  Building wheel for pytorch-ssim (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-ssim: filename=pytorch_ssim-0.1-cp36-none-any.whl size=2027 sha256=ba25fedcbd8cbba9ae4f5b5924efc4ddc2e402ebc0ecb411515caa46b925b34e
  Stored in directory: /root/.cache/pip/wheels/86/60/c8/85a73ea90dcf1d39d5d7f94d83988511f0370229dee641bb79
Successfully built pytorch-ssim
Installing collected packages: pytorch-ssim
Successfully installed pytorch-ssim-0.1


In [None]:
import matplotlib.pyplot as plt
import os
from os import listdir
from os.path import join
from tqdm import tqdm
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
import torch
from torch import nn
from torchvision.models.vgg import vgg16,vgg19
from google.colab import files
import math
import torchvision
import argparse
from math import log10
import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
import tensorflow as tf
import pytorch_ssim

In [None]:
# Mount Google Drive at /content/drive (Colab only) so that the dataset,
# checkpoint and output paths under /content/drive/MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Supporting Functions 

## Dataset Loader

In [None]:
def is_image_file(filename):
    """Return True if *filename* has a PNG or JPEG extension.

    The check is case-insensitive, so mixed-case extensions such as
    ``'photo.Jpg'`` are accepted as well as ``'.png'`` / ``'.PNG'``.
    """
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))


def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round *crop_size* down to the nearest multiple of *upscale_factor*.

    This guarantees that an HR crop of the returned size downscales to an
    integer LR size with no remainder.
    """
    _, leftover = divmod(crop_size, upscale_factor)
    return crop_size - leftover


def train_hr_transform(crop_size):
    """Build the HR training pipeline: a random *crop_size* crop, then ToTensor."""
    pipeline = [RandomCrop(crop_size), ToTensor()]
    return Compose(pipeline)


def train_lr_transform(crop_size, upscale_factor):
    """Build the LR training pipeline applied to an HR crop tensor.

    The tensor is converted back to a PIL image, bicubic-resized to
    ``crop_size // upscale_factor`` and converted to a tensor again.
    """
    lr_side = crop_size // upscale_factor
    pipeline = [
        ToPILImage(),
        Resize(lr_side, interpolation=Image.BICUBIC),
        ToTensor(),
    ]
    return Compose(pipeline)


def display_transform():
    """Build a pipeline that maps an image tensor to a 400x400 tensor for display.

    The shorter side is resized to 400, then a 400x400 center crop is taken.
    """
    side = 400
    pipeline = [ToPILImage(), Resize(side), CenterCrop(side), ToTensor()]
    return Compose(pipeline)

class TrainDatasetFromFolder(Dataset):
    """Training dataset of (LR, HR) patch pairs cut from the images in a folder.

    Each item is a random HR crop of side ``crop_size`` (rounded down to a
    multiple of ``upscale_factor``) together with its bicubic-downscaled LR
    version; both are returned as tensors produced by ``ToTensor``.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        # Keep only image files (Drive folders may contain e.g. desktop.ini,
        # which Image.open cannot read) and sort for a reproducible index order.
        self.image_filenames = sorted(
            join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)
        )
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        # Convert to 3-channel RGB (the generator's first conv expects 3 input
        # channels; grayscale/RGBA files would give 1- or 4-channel tensors) and
        # close the underlying file handle right away.
        with Image.open(self.image_filenames[index]) as img:
            rgb_image = img.convert('RGB')
        hr_image = self.hr_transform(rgb_image)
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)


class ValDatasetFromFolder(Dataset):
    """Validation dataset yielding (LR, bicubic-restored HR, HR) tensors per image.

    Each image is center-cropped to the largest square whose side is a multiple
    of ``upscale_factor``. LR is the bicubic downscale of that crop, and the
    "restored" image is LR bicubic-upscaled back to the crop size.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        # Keep only image files and sort so the (unshuffled) validation order is
        # reproducible across runs and file systems.
        self.image_filenames = sorted(
            join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)
        )

    def __getitem__(self, index):
        # Force 3-channel RGB and close the file handle once the pixels are loaded.
        with Image.open(self.image_filenames[index]) as img:
            hr_image = img.convert('RGB')

        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)

## Loss Initialisation

In [None]:
class GeneratorLoss(nn.Module):
    """SRGAN generator objective: pixel MSE + 0.001 * adversarial + 0.006 * VGG loss.

    The perceptual term is the MSE between activations of the first 34 modules
    of an ImageNet-pretrained VGG19 ``features`` stack, whose weights are frozen.
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        backbone = vgg19(pretrained=True)
        feature_extractor = nn.Sequential(*list(backbone.features)[:34]).eval()
        # Freeze VGG: it only serves as a fixed feature extractor.
        feature_extractor.requires_grad_(False)
        self.loss_network = feature_extractor
        self.mse_loss = nn.MSELoss()

    def forward(self, out_labels, out_images, target_images):
        """Return the weighted generator loss.

        Args:
            out_labels: Discriminator output(s) for the generated images
                (probabilities; 1e-6 is added before the log to avoid log(0)).
            out_images: Generated SR images.
            target_images: Ground-truth HR images.
        """
        non_saturating_adv = (-torch.log(out_labels + 1e-6)).mean()
        vgg_features_mse = self.mse_loss(
            self.loss_network(out_images), self.loss_network(target_images)
        )
        pixel_mse = self.mse_loss(out_images, target_images)
        return pixel_mse + 0.001 * non_saturating_adv + 0.006 * vgg_features_mse


## Model Initialisation

In [None]:
class Generator(nn.Module):
    """SRGAN generator that upsamples its input by ``scale_factor``.

    Layout:
    - 9x9 conv + PReLU.
    - 5 residual blocks.
    - 3x3 conv + BatchNorm, with a global skip connection from the first stage.
    - ``log2(scale_factor)`` x2 pixel-shuffle upsampling blocks.
    - A final 9x9 conv to 3 channels.

    The output is returned as-is (no tanh/sigmoid), so values may leave [0, 1].

    Args:
        scale_factor: Upscaling factor; must be a positive power of two.

    Raises:
        ValueError: If ``scale_factor`` is not a positive power of two.
    """

    def __init__(self, scale_factor):
        # Each UpsampleBLock doubles the resolution, so only powers of two can be
        # represented. Previously e.g. 3 was silently truncated to a x2 model.
        if scale_factor < 1:
            raise ValueError('scale_factor must be a positive power of two, got %r' % (scale_factor,))
        # math.log2 is exact for powers of two (math.log(x, 2) is not guaranteed to be).
        upsample_block_num = int(math.log2(scale_factor))
        if 2 ** upsample_block_num != scale_factor:
            raise ValueError('scale_factor must be a positive power of two, got %r' % (scale_factor,))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        # Global skip connection; `+` is out-of-place, so no clone is needed.
        block8 = self.block8(block1 + block7)

        return block8


class Discriminator(nn.Module):
    """SRGAN discriminator returning, per sample, the probability the image is real.

    Structure:
    - A conv + LeakyReLU stem.
    - Seven conv-BN-LeakyReLU stages; the stride-2 ones halve the resolution.
    - Global average pooling.
    - Two 1x1 convs acting as a dense head.
    - A sigmoid.

    The output has shape ``(batch,)``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=False),
        ]
        # (in_channels, out_channels, stride) of each conv-BN-LeakyReLU stage.
        # Module order is kept identical so state_dict keys ("net.<i>.*") match.
        stage_specs = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        for in_ch, out_ch, stride in stage_specs:
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=False),
            ])
        layers.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2, inplace=False),
            nn.Conv2d(1024, 1, kernel_size=1),
        ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x).view(x.size(0))
        return torch.sigmoid(logits)

class ResidualBlock(nn.Module):
    """Residual unit: conv-BN-PReLU-conv-BN branch added back onto the input.

    Both convolutions are 3x3 with padding 1 and keep ``channels`` unchanged,
    so the output has the same shape as the input.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=channels)

    def forward(self, x):
        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return x.clone() + branch


class UpsampleBLock(nn.Module):
    """Sub-pixel upsampling by ``up_scale``.

    A 3x3 conv expands the input to ``in_channels * up_scale**2`` channels, then
    PixelShuffle trades those channels for resolution, followed by a PReLU.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        return self.prelu(self.pixel_shuffle(self.conv(x)))

# Training Process

In [None]:
# Train SRGAN for NUM_EPOCHS epochs on random CROP_SIZE HR crops, validate on the
# DIV2K validation images after every epoch, and every 10 epochs save both
# networks plus a CSV of the per-epoch statistics.
CROP_SIZE = 96
UPSCALE_FACTOR = 4
NUM_EPOCHS = 60

train_set = TrainDatasetFromFolder(train_path, crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder(val_path, upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

# In case you want to continue training from previously saved weights:
# netG.load_state_dict(torch.load(G_weights_load))
# netD.load_state_dict(torch.load(D_weights_load))

generator_criterion = GeneratorLoss()

if torch.cuda.is_available():
    netG.cuda()
    netD.cuda()
    generator_criterion.cuda()

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': [], 'mse' : []}

# Create every output folder up front so the periodic saves cannot fail.
for output_dir in (imgs_save, G_weights_save, D_weights_save, out_stat_path):
    os.makedirs(output_dir, exist_ok=True)

for epoch in range(1, NUM_EPOCHS + 1):
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

    netG.train()
    netD.train()

    for data, target in train_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        real_img = target
        z = data
        if torch.cuda.is_available():
            real_img = real_img.cuda()
            z = z.cuda()
        fake_img = netG(z)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # fake_img is detached, so this backward pass does not reach G and
        # needs no retain_graph. D is stepped BEFORE G's loss is back-propagated;
        # otherwise G's adversarial gradient (which pushes D(G(z)) towards 1)
        # would be accumulated into D's .grad and applied by optimizerD.
        netD.zero_grad()
        real_out = netD(real_img).mean()
        fake_out_d = netD(fake_img.detach()).mean()
        d_loss = -(torch.log(real_out + 1e-6) + torch.log(1 - fake_out_d + 1e-6))
        d_loss.backward()
        optimizerD.step()

        ############################
        # (2) Update G network: minimize -log(D(G(z))) + Perception Loss + Image Loss
        ###########################
        # The gradients this leaves in netD are discarded by netD.zero_grad()
        # at the start of the next D step.
        netG.zero_grad()
        fake_out_g = netD(fake_img).mean()
        g_loss = generator_criterion(fake_out_g, fake_img, real_img)
        g_loss.backward()
        optimizerG.step()

        # Batch statistics; D(x) and D(G(z)) are the discriminator's scores
        # from the D step, i.e. before this batch's D update.
        running_results['g_loss'] += g_loss.item() * batch_size
        running_results['d_loss'] += d_loss.item() * batch_size
        running_results['d_score'] += real_out.item() * batch_size
        running_results['g_score'] += fake_out_d.item() * batch_size

        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
            epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
            running_results['g_loss'] / running_results['batch_sizes'],
            running_results['d_score'] / running_results['batch_sizes'],
            running_results['g_score'] / running_results['batch_sizes']))

    netG.eval()

    with torch.no_grad():
        val_bar = tqdm(val_loader)
        valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        for val_lr, _, val_hr in val_bar:
            batch_size = val_lr.size(0)
            valing_results['batch_sizes'] += batch_size
            lr = val_lr
            hr = val_hr
            if torch.cuda.is_available():
                lr = lr.cuda()
                hr = hr.cuda()
            sr = netG(lr)

            # .item() keeps the running sums as Python floats; a (CUDA) tensor
            # here would be written verbatim as "tensor(...)" into the CSV.
            batch_mse = ((sr - hr) ** 2).mean().item()
            valing_results['mse'] += batch_mse * batch_size
            batch_ssim = pytorch_ssim.ssim(sr, hr).item()
            valing_results['ssims'] += batch_ssim * batch_size
            mean_mse = valing_results['mse'] / valing_results['batch_sizes']
            # ToTensor images lie in [0, 1], so the PSNR peak value is 1 (not the
            # max of whichever image happens to be the current one).
            valing_results['psnr'] = float('inf') if mean_mse == 0 else 10 * log10(1.0 / mean_mse)
            valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
            val_bar.set_description(
                desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                    valing_results['psnr'], valing_results['ssim']))

    # save loss\scores\psnr\ssim
    results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
    results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
    results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
    results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
    results['psnr'].append(valing_results['psnr'])
    results['ssim'].append(valing_results['ssim'])
    # Record the mean validation MSE (like PSNR/SSIM), not the per-image sum.
    results['mse'].append(valing_results['mse'] / valing_results['batch_sizes'])

    if epoch % 10 == 0:
        # save model parameters
        torch.save(netG.state_dict(), G_weights_save + 'netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        torch.save(netD.state_dict(), D_weights_save + 'netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))

        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                    'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim'], 'MSE' : results['mse']},
            index=range(1, epoch + 1))
        data_frame.to_csv(out_stat_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')

# generator parameters: 734219
# discriminator parameters: 5215425


[1/60] Loss_D: 1.1772 Loss_G: 0.0339 D(x): 0.5724 D(G(z)): 0.4258: 100%|██████████| 50/50 [01:34<00:00,  1.89s/it]
[converting LR images to SR images] PSNR: 19.0732 dB SSIM: 0.4964: 100%|██████████| 100/100 [00:36<00:00,  2.73it/s]
[2/60] Loss_D: 1.2809 Loss_G: 0.0185 D(x): 0.5579 D(G(z)): 0.4775: 100%|██████████| 50/50 [01:23<00:00,  1.67s/it]
[converting LR images to SR images] PSNR: 19.2237 dB SSIM: 0.4962: 100%|██████████| 100/100 [00:37<00:00,  2.70it/s]
[3/60] Loss_D: 1.4001 Loss_G: 0.0165 D(x): 0.5232 D(G(z)): 0.5146: 100%|██████████| 50/50 [01:11<00:00,  1.42s/it]
[converting LR images to SR images] PSNR: 17.7141 dB SSIM: 0.5328: 100%|██████████| 100/100 [00:37<00:00,  2.69it/s]
[4/60] Loss_D: 1.4289 Loss_G: 0.0142 D(x): 0.5156 D(G(z)): 0.5239: 100%|██████████| 50/50 [01:07<00:00,  1.35s/it]
[converting LR images to SR images] PSNR: 19.1857 dB SSIM: 0.5639: 100%|██████████| 100/100 [00:37<00:00,  2.69it/s]
[5/60] Loss_D: 1.4430 Loss_G: 0.0138 D(x): 0.5965 D(G(z)): 0.5891: 100%|

# Validation - Generate Results

In [None]:
def buildG(UPSCALE_FACTOR=4):
    """Build a Generator and load the trained weights from ``G_weights_load``.

    The model is put in eval mode, as in the training loop's validation pass,
    so that BatchNorm uses its learned running statistics. In train mode it
    would use (and update) the statistics of whatever image is being upscaled.
    Weights are mapped to the GPU when one is available and to the CPU otherwise.

    Args:
        UPSCALE_FACTOR: Upscaling factor the checkpoint was trained with.

    Returns:
        The generator in eval mode, on the selected device.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    netG = Generator(UPSCALE_FACTOR)
    netG.load_state_dict(torch.load(G_weights_load, map_location=device))
    netG.to(device)
    netG.eval()

    return netG

# Load the trained generator once; test_on_single_image() below reads this global.
netG = buildG()
def test_on_single_image(path,UPSCALE_FACTOR=4, index=0):
    """Super-resolve one image with the global ``netG`` and save the result.

    The image at *path* is downscaled by *UPSCALE_FACTOR* to synthesize an LR
    input, passed through ``netG``, and written to
    ``imgs_save + 'sr-final_<index>.png'``.

    Args:
        path: Path to a high-resolution input image.
        UPSCALE_FACTOR: Downscaling factor for the LR input. It should match
            the factor ``netG`` was built with. (It was previously ignored,
            with 4 hard-coded.)
        index: Number used in the output file name.
    """
    # Convert to 3-channel RGB (the generator expects 3 input channels) and
    # close the file handle once the pixels are loaded.
    with Image.open(path) as src:
        img = src.convert('RGB')
    width, height = img.size
    # Bicubic downscaling matches how the LR training/validation inputs were made.
    lr_img = img.resize((width // UPSCALE_FACTOR, height // UPSCALE_FACTOR), Image.BICUBIC)

    # Run on whatever device the model lives on instead of assuming cuda:0.
    device = next(netG.parameters()).device
    lr_tensor = ToTensor()(lr_img).unsqueeze(0).to(device)  # add batch dimension
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        sr_tensor = netG(lr_tensor)

    os.makedirs(imgs_save, exist_ok=True)
    img_path = imgs_save +'sr-final_'+ str(index) + '.png'
    torchvision.utils.save_image(sr_tensor, img_path)
    #files.download(img_path)
    #files.download(img_path)

In [None]:
dataset_dir = '/content/drive/MyDrive/test_images/hr'
# Skip non-image entries (e.g. desktop.ini), which Image.open cannot read.
# Outputs are numbered from 1 in directory-listing order.
image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
for i, path in enumerate(image_filenames, start=1):
    print(path)
    test_on_single_image(path, 4, i)

/content/drive/MyDrive/test_images/hr/hr1.png
/content/drive/MyDrive/test_images/hr/hr2.png
/content/drive/MyDrive/test_images/hr/hr3.png
/content/drive/MyDrive/test_images/hr/hr4.png
/content/drive/MyDrive/test_images/hr/hr5.png
/content/drive/MyDrive/test_images/hr/hr6.png
/content/drive/MyDrive/test_images/hr/hr7.png
/content/drive/MyDrive/test_images/hr/hr8.png
/content/drive/MyDrive/test_images/hr/hr9.png
/content/drive/MyDrive/test_images/hr/hr10.png
/content/drive/MyDrive/test_images/hr/hr11.png


# Earlier training run (epochs 101-200) with per-epoch validation on the 100 DIV2K validation images (commented out)

In [None]:
# import argparse
# import os
# from math import log10

# import pandas as pd
# import torch.optim as optim
# import torch.utils.data
# import torchvision.utils as utils
# from torch.autograd import Variable
# from torch.utils.data import DataLoader
# from tqdm import tqdm

# import pytorch_ssim

# CROP_SIZE = 96
# UPSCALE_FACTOR = 4
# NUM_EPOCHS = 100

# train_set = TrainDatasetFromFolder('/content/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
# val_set = ValDatasetFromFolder('/content/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
# train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)
# val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

# netG = Generator(UPSCALE_FACTOR)
# print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
# netD = Discriminator()
# print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

# netG.load_state_dict(torch.load('/content/drive/MyDrive/DL_project (1)/epochs/netG_epoch_4_99.pth'))
# netD.load_state_dict(torch.load('/content/drive/MyDrive/DL_project (1)/epochs/netD_epoch_4_99.pth'))

# generator_criterion = GeneratorLoss()

# if torch.cuda.is_available():
#     netG.cuda()
#     netD.cuda()
#     generator_criterion.cuda()

# optimizerG = optim.Adam(netG.parameters())
# optimizerD = optim.Adam(netD.parameters())

# results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

# # torch.autograd.set_detect_anomaly(True)

# for epoch in range(101, NUM_EPOCHS + 101):
#     train_bar = tqdm(train_loader)
#     running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

#     netG.train()
#     netD.train()

#     for data, target in train_bar:
#         g_update_first = True
#         batch_size = data.size(0)
#         running_results['batch_sizes'] += batch_size

#         ############################
#         # (1) Update D network: maximize D(x)-1-D(G(z))
#         ###########################
#         real_img = torch.Tensor(target)
#         if torch.cuda.is_available():
#             real_img = real_img.cuda()
#         z = torch.Tensor(data)
#         if torch.cuda.is_available():
#             z = z.cuda()
#         fake_img = netG(z)

#         netD.zero_grad()
#         real_out_1 = netD(real_img)
#         real_out = torch.mean(real_out_1)
#         fake_out_1 = netD(fake_img)
#         fake_out = torch.mean(fake_out_1)
#         d_loss = 1 - real_out + fake_out
#         d_loss.backward(retain_graph=True)
        

#         ############################
#         # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss
#         ###########################
#         netG.zero_grad()
#         g_loss = generator_criterion(fake_out, fake_img, real_img)
#         g_loss.backward()
        
#         fake_img = netG(z)
#         fake_out = netD(fake_img).mean()
        
#         optimizerD.step()
#         optimizerG.step()

#         # loss for current batch before optimization 
#         running_results['g_loss'] += g_loss.item() * batch_size
#         running_results['d_loss'] += d_loss.item() * batch_size
#         running_results['d_score'] += real_out.item() * batch_size
#         running_results['g_score'] += fake_out.item() * batch_size

#         train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
#             epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
#             running_results['g_loss'] / running_results['batch_sizes'],
#             running_results['d_score'] / running_results['batch_sizes'],
#             running_results['g_score'] / running_results['batch_sizes']))
         

#     netG.eval()
#     out_path = '/content/drive/MyDrive/DL_project (1)/training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
#     if not os.path.exists(out_path):
#         os.makedirs(out_path)
    
#     with torch.no_grad():
#         val_bar = tqdm(val_loader)
#         valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
#         val_images = []
#         for val_lr, val_hr_restore, val_hr in val_bar:
#             batch_size = val_lr.size(0)
#             valing_results['batch_sizes'] += batch_size
#             lr = val_lr
#             hr = val_hr
#             if torch.cuda.is_available():
#                 lr = lr.cuda()
#                 hr = hr.cuda()
#             sr = netG(lr)
    
#             batch_mse = ((sr - hr) ** 2).data.mean()
#             valing_results['mse'] += batch_mse * batch_size
#             batch_ssim = pytorch_ssim.ssim(sr, hr).item()
#             valing_results['ssims'] += batch_ssim * batch_size
#             valing_results['psnr'] = 10 * log10((hr.max()**2) / (valing_results['mse'] / valing_results['batch_sizes']))
#             valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
#             val_bar.set_description(
#                 desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
#                     valing_results['psnr'], valing_results['ssim']))
            
#             val_images.extend(
#                 [display_transform()(lr.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
#                     display_transform()(sr.data.cpu().squeeze(0))])
            

#         val_images = torch.stack(val_images)
#         val_images = torch.chunk(val_images, val_images.size(0) // 15)
#         val_save_bar = tqdm(val_images, desc='[saving training results]')
#         index = 1
#         for image in val_save_bar:
#             image = utils.make_grid(image, nrow=3, padding=5)
#             utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
#             index += 1

#     # save model parameters
#     torch.save(netG.state_dict(), '/content/drive/MyDrive/DL_project (1)/epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
#     torch.save(netD.state_dict(), '/content/drive/MyDrive/DL_project (1)/epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
#     # save loss\scores\psnr\ssim
#     results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
#     results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
#     results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
#     results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
#     results['psnr'].append(valing_results['psnr'])
#     results['ssim'].append(valing_results['ssim'])

#     if epoch % 10 == 0 and epoch != 0:
#         out_path = '/content/drive/MyDrive/DL_project (1)/statistics/'
#         data_frame = pd.DataFrame(
#             data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
#                     'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
#             index=range(1, epoch + 1))
#         data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')