## Enhanced Super-Resolution GAN (ESRGAN) training
### Based on ESRGAN paper: https://arxiv.org/abs/1809.00219

In [55]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19
from torchvision.utils import save_image
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
import torchvision.transforms.functional as TF
import random
import gc


# ---- Run-mode switches --------------------------------------------------------
LOAD_DISC = True    # resume the discriminator from CHECKPOINT_DISC
LOAD_GEN = True     # resume the generator from CHECKPOINT_GEN

SAVE_MODEL = True   # write both checkpoints from main()

# True: full GAN training; False: pre-train the generator alone on pixel loss.
DISC_TRAIN = True

# ---- File locations (relative names are joined onto WORK_PATH) -----------------
WORK_PATH = "c:/ESRGAN/"
IMAGE_PATH = "HI_RES/"            # ImageFolder root with the high-res training images
CHECKPOINT_GEN = "gen_esrgan.pth"
CHECKPOINT_DISC = "disc_esrgan.pth"
TEST_IMAGES = "test_images/"      # inputs read by plot_examples
RESULT_IMAGES = "result_images/"  # outputs written by plot_examples

# ---- Optimisation ----------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1000
BATCH_SIZE = 16
NUM_WORKERS = 8
PLOT_FREQUENCY = 100  # epochs between sample-image dumps

# ---- Image geometry --------------------------------------------------------------
HIGH_RES = 128               # side of the square high-res training crop
LOW_RES = HIGH_RES // 4      # generator input side (the generator upsamples x4)
DISK_CONV = HIGH_RES // 16   # discriminator feature-map side after four stride-2 convs
IMG_CHANNELS = 3

# ---- Generator loss weights (used only when DISC_TRAIN is True) ------------------
pixel_weight = 0.01
content_weight = 1.0
adversarial_weight = 0.005

# ---- Reproducibility / runtime ---------------------------------------------------
random.seed(42)
torch.manual_seed(42)

# Let cuDNN benchmark convolution algorithms for the fixed crop size
# (faster, but not bit-for-bit reproducible).
torch.backends.cudnn.benchmark = True

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Download low-res validation dataset
#!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip
#!unzip DIV2K_valid_LR_bicubic_X4.zip

In [None]:
# Download hi-res training dataset
#!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
#!unzip DIV2K_train_HR.zip 

In [57]:
# High-res target transform. Normalize(mean=0, std=1) computes (x - 0) / 1,
# i.e. it is an identity placeholder where a real normalisation could go.
highres_transform = tt.Compose(
    [
        tt.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
    ]
)

# Low-res input transform: bicubic downscale to LOW_RES x LOW_RES. train_fn
# applies it to the batched high-res tensor, so LR and HR stay aligned.
lowres_transform = tt.Compose(
    [
        tt.Resize(size=[LOW_RES, LOW_RES], interpolation=tt.InterpolationMode.BICUBIC),
        tt.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
    ]
)

# Per-image augmentation applied when loading; both the high-res target and
# the low-res input are derived from its output.
# Rotations are restricted to multiples of 90 degrees. An arbitrary angle
# applied to a square crop leaves zero-filled (black) corners, and the
# generator would then be trained to reproduce them. A 90-degree step maps
# the square crop onto itself. Combined with the random horizontal flip, it
# covers all eight flip/rotation variants without any padding. Only built-in
# torchvision transforms are used so the pipeline stays picklable for
# DataLoader worker processes.
both_transforms = tt.Compose(
    [
        tt.RandomCrop(size=[HIGH_RES, HIGH_RES]),
        tt.RandomHorizontalFlip(p=0.5),
        tt.RandomChoice(
            [tt.RandomRotation(degrees=(angle, angle)) for angle in (0, 90, 180, 270)]
        ),
        tt.ToTensor(),
    ]
)

# Inference transform for the low-res test images: PIL -> float tensor in
# [0, 1], followed by the same identity Normalize as above.
test_transform = tt.Compose(
    [
        tt.ToTensor(),
        tt.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
    ]
)

In [58]:
# Simple Convolution block without Batch Normalization
class Conv_Block_No_BN(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(Conv_Block_No_BN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwargs),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        out = self.conv(x)
        return out

In [59]:
# Residual Dense Block (RDB): five 3x3 convolutions with dense connections
class Residual_Dense_Block(nn.Module):
    """Densely connected block with a scaled residual skip.

    Each of conv1..conv4 receives the concatenation of the block input and
    every earlier convolution's output, and contributes ``channels`` new maps.
    conv5 has no activation. It maps the full concatenation back to
    ``in_channels``; its output is scaled by ``residual_beta`` and added to
    the block input.

    Args:
        in_channels: channels of the input (and output) feature map.
        channels: growth channels added by each of conv1..conv4.
        residual_beta: scale applied to conv5's output before the skip add.
    """

    def __init__(self, in_channels=64, channels=32, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta

        conv_args = dict(out_channels=channels, kernel_size=3, stride=1, padding=1)
        # Attribute names conv1..conv5 define the checkpoint keys; creation
        # order is kept so parameter initialisation consumes RNG identically.
        self.conv1 = Conv_Block_No_BN(in_channels=in_channels, **conv_args)
        self.conv2 = Conv_Block_No_BN(in_channels=in_channels + 1 * channels, **conv_args)
        self.conv3 = Conv_Block_No_BN(in_channels=in_channels + 2 * channels, **conv_args)
        self.conv4 = Conv_Block_No_BN(in_channels=in_channels + 3 * channels, **conv_args)
        self.conv5 = nn.Conv2d(
            in_channels=in_channels + 4 * channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # Grow the list of feature maps; each layer sees everything produced so far.
        features = [x]
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(layer(torch.cat(features, dim=1)))
        fused = self.conv5(torch.cat(features, dim=1))
        return x + fused * self.residual_beta

In [60]:
# Residual-in-Residual Dense Block (RRDB)
class RRDB(nn.Module):
    """Three Residual_Dense_Blocks in sequence, wrapped in a scaled outer residual.

    Args:
        in_channels: channels of the input/output feature map.
        channels: growth channels of each inner dense block.
        residual_beta: residual scale. It is used for the outer skip and is
            also passed to every inner block. Previously the inner blocks
            silently kept their own default, so a non-default value was
            only partly applied. With the default (0.2) behaviour is
            unchanged.
    """

    def __init__(self, in_channels=64, channels=32, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta

        # The attribute name `rrdb` and the Sequential indices define the checkpoint keys.
        self.rrdb = nn.Sequential(
            *[Residual_Dense_Block(in_channels, channels, residual_beta) for _ in range(3)]
        )

    def forward(self, x):
        return self.rrdb(x) * self.residual_beta + x

In [61]:
# Convolution block with batch normalisation (used by the Discriminator)
class Conv_Block(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(negative_slope=0.2).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution and normalised by BN.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride, bias, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwargs)
        norm = nn.BatchNorm2d(num_features=out_channels)
        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # The Sequential layout under `conv` defines the checkpoint keys.
        self.conv = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.conv(x)

In [62]:
# Upsampling block: nearest-neighbour resize, then 3x3 conv + LeakyReLU
class Upsample_Block(nn.Module):
    """Enlarge a feature map by ``scale_factor`` while keeping its channel count.

    ``F.interpolate(mode="nearest")`` performs the resize. A 3x3 convolution
    followed by LeakyReLU(0.2) then refines the enlarged map.

    Args:
        channels: input and output channel count.
        scale_factor: spatial magnification passed to ``F.interpolate``.
    """

    def __init__(self, channels, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        refine = nn.Conv2d(in_channels=channels, out_channels=channels,
                           kernel_size=3, stride=1, padding=1, bias=True)
        self.upsample = nn.Sequential(refine, nn.LeakyReLU(negative_slope=0.2, inplace=True))

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        return self.upsample(enlarged)

In [63]:
# Generator model (RRDB trunk + x4 upsampling)
class Generator(nn.Module):
    """ESRGAN-style generator.

    Pipeline: 3x3 conv -> ``num_blocks`` RRDBs -> 3x3 conv, with a long skip
    connection around the RRDB trunk -> two x2 Upsample_Blocks -> 3x3 conv,
    LeakyReLU, 3x3 conv back to ``in_channels`` channels.

    The output is 4x the input's height and width. The last layer is a plain
    convolution, so output values are not clamped to any range.

    Args:
        in_channels: image channels for both input and output.
        num_channels: feature width of the trunk.
        num_blocks: number of RRDBs in the trunk.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=23):
        super().__init__()

        def conv3x3(cin, cout):
            return nn.Conv2d(in_channels=cin, out_channels=cout,
                             kernel_size=3, stride=1, padding=1, bias=True)

        # Submodules are created in the original order (same RNG use, same
        # checkpoint keys). The triple-'d' in `conv_midddle` is kept on purpose
        # so existing checkpoints still load.
        self.conv_initial = conv3x3(in_channels, num_channels)
        self.residuals_sequence = nn.Sequential(*(RRDB(num_channels) for _ in range(num_blocks)))
        self.conv_midddle = conv3x3(num_channels, num_channels)
        self.upsample = nn.Sequential(Upsample_Block(num_channels), Upsample_Block(num_channels))
        self.conv_final = nn.Sequential(
            conv3x3(num_channels, num_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            conv3x3(num_channels, in_channels),
        )

    def forward(self, x):
        shallow = self.conv_initial(x)
        trunk = self.conv_midddle(self.residuals_sequence(shallow))
        return self.conv_final(self.upsample(trunk + shallow))

In [64]:
# Discriminator model (VGG-style convolutional classifier)
class Discriminator(nn.Module):
    """Image discriminator that returns one raw logit per image, shape (batch, 1).

    The first Linear layer has 512 * DISK_CONV * DISK_CONV inputs, where
    DISK_CONV is a module-level constant (HIGH_RES // 16). The network is
    therefore sized for HIGH_RES x HIGH_RES inputs. Shape comments below
    assume HIGH_RES = 128.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # (in_channels) x 128 x 128 -> (64) x 128 x 128
        self.conv_initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        # (in, out, stride) for each Conv_Block; every stride-2 layer halves H and W.
        layer_specs = [
            (64, 64, 2),    # (64)  x 64 x 64
            (64, 128, 1),   # (128) x 64 x 64
            (128, 128, 2),  # (128) x 32 x 32
            (128, 256, 1),  # (256) x 32 x 32
            (256, 256, 2),  # (256) x 16 x 16
            (256, 512, 1),  # (512) x 16 x 16
            (512, 512, 2),  # (512) x 8 x 8
        ]
        self.conv_sequence = nn.Sequential(*[
            Conv_Block(in_channels=cin, out_channels=cout, kernel_size=3,
                       stride=stride, padding=1, bias=False)
            for cin, cout, stride in layer_specs
        ])

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * DISK_CONV * DISK_CONV, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        features = self.conv_sequence(self.conv_initial(x))
        return self.classifier(features)

In [65]:
# Scaled Kaiming initialisation (ESRGAN scales initial weights down, default 0.1)
def initialize_weights(model, scale=0.1):
    """Re-initialise every Conv2d and Linear layer of ``model`` in place.

    Each such layer's weight is drawn with ``nn.init.kaiming_normal_`` and
    then multiplied by ``scale``. Its bias, if present, is set to 0. Layers
    of any other type are left untouched.

    Args:
        model: module whose sub-modules (including itself) are visited via ``modules()``.
        scale: multiplier applied to the freshly drawn weights.
    """
    scaled_types = (nn.Conv2d, nn.Linear)
    for module in model.modules():
        if not isinstance(module, scaled_types):
            continue
        nn.init.kaiming_normal_(module.weight.data)
        module.weight.data *= scale
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

In [66]:
def test():
    """Smoke test: run a random low-res batch through Generator and Discriminator and print both output shapes."""
    # Five random LOW_RES x LOW_RES inputs (32x32 when HIGH_RES is 128).
    batch = torch.randn((5, 3, LOW_RES, LOW_RES))
    generated = Generator()(batch)
    scores = Discriminator()(generated)
    for tensor in (generated, scores):
        print(tensor.shape)

#test()

In [67]:
# Perceptual (content) loss on frozen VGG19 features, following the ESRGAN paper
class VGGLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two image batches.

    The extractor is ``vgg19(pretrained=True).features[:35]``. In torchvision's
    VGG19 layer numbering, this ends on conv5_4 before its ReLU. It is put in
    eval mode, moved to DEVICE, and all its parameters are frozen.

    Both inputs are standardised with the ImageNet mean/std (registered as
    buffers) before entering VGG. Inputs are presumably RGB in [0, 1];
    TODO confirm for callers other than train_fn.
    """

    def __init__(self):
        super().__init__()
        extractor = vgg19(pretrained=True).features[:35]
        self.vgg = extractor.eval().to(DEVICE)
        self.loss = nn.L1Loss()

        # VGG only measures the loss; it is never trained.
        self.vgg.requires_grad_(False)

        # ImageNet statistics, shaped (1, 3, 1, 1) to broadcast over N x 3 x H x W.
        imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer("mean", imagenet_mean)
        self.register_buffer("std", imagenet_std)

    def forward(self, input, target):
        def standardize(images):
            return images.sub(self.mean).div(self.std)

        input_features = self.vgg(standardize(input))
        target_features = self.vgg(standardize(target))
        return self.loss(input_features, target_features)

In [68]:
# Persist model + optimizer state so training can be resumed later
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth"):
    """Write ``model`` and ``optimizer`` state dicts to ``WORK_PATH/filename``.

    The saved object is a dict with keys ``"state_dict"`` and ``"optimizer"``,
    which is the layout load_checkpoint reads back.
    """
    path = os.path.join(WORK_PATH, filename)
    print("=> Saving checkpoint to " + path)
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        path,
    )

# Restore model + optimizer state written by save_checkpoint
def load_checkpoint(filename, model, optimizer, lr):
    """Load ``WORK_PATH/filename`` into ``model`` and ``optimizer`` in place.

    Tensors are mapped onto DEVICE. After the optimizer state is restored,
    each param group's learning rate is overwritten with ``lr``, so the rate
    passed in replaces whatever rate was saved.
    NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    """
    path = os.path.join(WORK_PATH, filename)
    print("=> Loading checkpoint from " + path)
    state = torch.load(path, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])
    for group in optimizer.param_groups:
        group.update(lr=lr)

# Super-resolve the low-res test images and save the results
def plot_examples(gen):
    """Run ``gen`` on every file in WORK_PATH/TEST_IMAGES and save the outputs.

    Each result is written to WORK_PATH/RESULT_IMAGES under the same file
    name. The generator runs in eval mode without gradients. Its previous
    train/eval mode is restored afterwards, even if an image fails.

    Args:
        gen: generator module; inputs are moved to DEVICE.
    """
    source_dir = os.path.join(WORK_PATH, TEST_IMAGES)
    target_dir = os.path.join(WORK_PATH, RESULT_IMAGES)
    # Make sure the output directory exists before save_image writes into it.
    os.makedirs(target_dir, exist_ok=True)

    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            for name in sorted(os.listdir(source_dir)):
                src_path = os.path.join(source_dir, name)
                if not os.path.isfile(src_path):
                    continue  # skip sub-directories
                # The context manager closes the file handle immediately.
                # convert("RGB") gives the 3-channel generator a 3-channel input
                # even for palette, grayscale or RGBA files.
                with Image.open(src_path) as img:
                    image = img.convert("RGB")
                upscaled_img = gen(test_transform(image).unsqueeze(0).to(DEVICE))
                save_image(upscaled_img, os.path.join(target_dir, name))
    finally:
        gen.train(was_training)
        gc.collect()
        torch.cuda.empty_cache()

In [69]:
# Training procedure: one pass over the loader
def train_fn(loader, disc, gen, opt_gen, opt_disc, psnr_loss_f, pixel_loss_f, content_loss_f, adversarial_loss_f, epoch):
    """Run one training epoch and return mean generator loss, discriminator loss and PSNR.

    With DISC_TRAIN False only the generator is trained, on the pixel loss
    alone, and the reported discriminator loss is 0. With DISC_TRAIN True each
    batch first updates the discriminator and then the generator, using the
    relativistic average GAN (RaGAN) losses described in the ESRGAN paper.

    Args:
        loader: yields (images, labels) batches; the labels are ignored.
        disc, gen: discriminator and generator modules.
        opt_gen, opt_disc: their optimizers.
        psnr_loss_f: MSE criterion, used only for the PSNR metric.
        pixel_loss_f, content_loss_f: generator reconstruction criteria.
        adversarial_loss_f: criterion applied to raw logit differences
            (main passes BCEWithLogitsLoss).
        epoch: zero-based epoch index, used for the progress-bar label.

    Returns:
        Tuple (mean generator loss, mean discriminator loss, mean PSNR in dB).
    """
    loop = tqdm(loader, leave=True)
    loop.set_description(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")

    loss_d_per_epoch = []
    loss_g_per_epoch = []
    psnr_per_epoch = []

    for images, _labels in loop:
        high_res = highres_transform(images).to(DEVICE)
        low_res = lowres_transform(images).to(DEVICE)

        fake = gen(low_res)

        # ---- Discriminator step (RaGAN) ----
        if DISC_TRAIN:
            # Re-enable discriminator gradients (the generator step switches them off).
            for p in disc.parameters():
                p.requires_grad = True
            opt_disc.zero_grad()

            disc_real = disc(high_res)
            # detach(): this step must not push gradients into the generator.
            disc_fake = disc(fake.detach())

            # Real logits should exceed the mean fake logit (target 1).
            # Fake logits should fall below the mean real logit (target 0).
            disc_loss_real = adversarial_loss_f(disc_real - torch.mean(disc_fake), torch.ones_like(disc_real))
            disc_loss_fake = adversarial_loss_f(disc_fake - torch.mean(disc_real), torch.zeros_like(disc_fake))

            # Both losses share the discriminator graph through the mean cross
            # terms, so the first backward must keep the graph alive.
            disc_loss_real.backward(retain_graph=True)
            disc_loss_fake.backward()
            opt_disc.step()

            loss_disc = (disc_loss_real + disc_loss_fake).detach()

        # ---- Generator step ----
        # Freeze the discriminator so generator backprop does not fill its .grad.
        for p in disc.parameters():
            p.requires_grad = False

        opt_gen.zero_grad()
        if DISC_TRAIN:
            disc_real = disc(high_res)
            disc_fake = disc(fake)

            pixel_loss = pixel_weight * pixel_loss_f(fake, high_res)
            content_loss = content_weight * content_loss_f(fake, high_res)
            # Symmetric RaGAN generator loss. Real and fake swap roles relative
            # to the discriminator loss, and BOTH terms are used. The real term
            # still depends on the generator through torch.mean(disc_fake), so
            # the generator also learns from real images. (The previous code
            # used only the fake term.) Averaging keeps the result on the scale
            # of a single BCE term.
            adv_loss_real = adversarial_loss_f(disc_real - torch.mean(disc_fake), torch.zeros_like(disc_real))
            adv_loss_fake = adversarial_loss_f(disc_fake - torch.mean(disc_real), torch.ones_like(disc_fake))
            adversarial_loss = adversarial_weight * (adv_loss_real + adv_loss_fake) / 2
            gen_loss = pixel_loss + content_loss + adversarial_loss
        else:
            # PSNR-oriented pre-training: plain pixel loss, no GAN involved.
            gen_loss = pixel_loss_f(fake, high_res)
            loss_disc = torch.zeros(1)

        gen_loss.backward()
        opt_gen.step()

        # Metric only, so no autograd graph is needed. The peak value of 1.0
        # matches ToTensor's [0, 1] range for the targets.
        with torch.no_grad():
            psnr = 10. * torch.log10(1. / psnr_loss_f(fake, high_res))

        g_value, d_value, psnr_value = gen_loss.item(), loss_disc.item(), psnr.item()
        loss_g_per_epoch.append(g_value)
        loss_d_per_epoch.append(d_value)
        psnr_per_epoch.append(psnr_value)
        loop.set_postfix(loss_g=g_value, loss_d=d_value, psnr=psnr_value)

    return np.mean(loss_g_per_epoch), np.mean(loss_d_per_epoch), np.mean(psnr_per_epoch)

In [70]:
def main():
    """Build data, models, losses and optimizers, then train for NUM_EPOCHS epochs.

    Every PLOT_FREQUENCY epochs, sample images are written. When SAVE_MODEL
    is set, both checkpoints are also saved at that point, so an interrupted
    run can resume via LOAD_GEN / LOAD_DISC instead of losing all progress.
    Loss and PSNR curves are plotted at the end.
    """
    dataset = ImageFolder(os.path.join(WORK_PATH, IMAGE_PATH), transform=both_transforms)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=NUM_WORKERS,
        # Keep worker processes alive across epochs instead of re-spawning
        # NUM_WORKERS processes every epoch (only valid when workers are used).
        persistent_workers=NUM_WORKERS > 0,
    )

    # Models: only the generator gets the scaled ESRGAN initialisation.
    gen = Generator(in_channels=3).to(DEVICE)
    disc = Discriminator(in_channels=3).to(DEVICE)
    initialize_weights(gen)

    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))

    psnr_loss_f = nn.MSELoss().to(DEVICE)
    pixel_loss_f = nn.L1Loss().to(DEVICE)
    content_loss_f = VGGLoss().to(DEVICE)
    # The discriminator ends in a Linear layer (raw logits), hence the *WithLogits* loss.
    adversarial_loss_f = nn.BCEWithLogitsLoss().to(DEVICE)

    gen.train()
    disc.train()

    # Resuming overwrites the fresh initialisation above.
    if LOAD_GEN:
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
    if LOAD_DISC:
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

    # Per-epoch history for the final plots.
    losses_g = []
    losses_d = []
    psnr = []

    for epoch in range(NUM_EPOCHS):
        losses_g_e, losses_d_e, psnr_e = train_fn(
            loader, disc, gen, opt_gen, opt_disc,
            psnr_loss_f, pixel_loss_f, content_loss_f, adversarial_loss_f, epoch,
        )

        losses_g.append(losses_g_e)
        losses_d.append(losses_d_e)
        psnr.append(psnr_e)

        if (epoch + 1) % PLOT_FREQUENCY == 0:
            print(f"Plotting samples for epoch {epoch+1}")
            plot_examples(gen)
            if SAVE_MODEL:
                # Periodic checkpoint: without it a crash loses the whole run.
                save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
                save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)
            gc.collect()
            torch.cuda.empty_cache()

    # Final checkpoint (also covers the epochs after the last periodic save).
    if SAVE_MODEL:
        save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
        save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

    gc.collect()
    torch.cuda.empty_cache()
    plot_examples(gen)

    # Show Losses
    plt.figure(figsize=(15, 6))
    plt.plot(losses_d, '-')
    plt.plot(losses_g, '-')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Discriminator', 'Generator'])
    plt.title('Losses')
    plt.show()
    # Show PSNR
    plt.figure(figsize=(15, 6))
    plt.plot(psnr, '-')
    plt.xlabel('epoch')
    plt.ylabel('psnr')
    plt.title('PSNR')
    plt.show()


# Entry-point guard. DataLoader workers started with "spawn" (the default on
# Windows) re-import the main module, so an unguarded main() would be re-run
# in every worker when this is executed as a script. In a notebook __name__
# is "__main__", so training still starts as before.
if __name__ == "__main__":
    main()

=> Loading checkpoint from c:/ESRGAN/gen_esrgan.pth


Epoch [1/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.25it/s, loss_d=0, loss_g=0.0275, psnr=24.8]
Epoch [2/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0259, psnr=25.1]
Epoch [3/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0239, psnr=26.1]
Epoch [4/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0179, psnr=27.9]
Epoch [5/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d=0, loss_g=0.0215, psnr=25.5]
Epoch [6/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0207, psnr=27.4]
Epoch [7/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.36it/s, loss_d=0, loss_g=0.0277, psnr=24.2]
Epoch [8/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s, loss_d=0, loss_g=0.0145, psnr=29.7]
Epoch [9/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0299, psnr=24.5]
Epoch [10/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0371, psnr=22.7]

Plotting samples for epoch 100


Epoch [101/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.43it/s, loss_d=0, loss_g=0.0222, psnr=26]  
Epoch [102/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.40it/s, loss_d=0, loss_g=0.0263, psnr=25.2]
Epoch [103/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0148, psnr=29]  
Epoch [104/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0185, psnr=27.2]
Epoch [105/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.28it/s, loss_d=0, loss_g=0.0314, psnr=23.6]
Epoch [106/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.34it/s, loss_d=0, loss_g=0.0218, psnr=27]  
Epoch [107/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0344, psnr=23.4]
Epoch [108/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0283, psnr=24]  
Epoch [109/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s, loss_d=0, loss_g=0.0196, psnr=27.9]
Epoch [110/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g

Plotting samples for epoch 200


Epoch [201/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.40it/s, loss_d=0, loss_g=0.0279, psnr=24.8]
Epoch [202/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d=0, loss_g=0.0229, psnr=26.5]
Epoch [203/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s, loss_d=0, loss_g=0.0197, psnr=27.6]
Epoch [204/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0147, psnr=28.7]
Epoch [205/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.28it/s, loss_d=0, loss_g=0.0324, psnr=24.5]
Epoch [206/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0299, psnr=24.8]
Epoch [207/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.38it/s, loss_d=0, loss_g=0.0148, psnr=28.6]
Epoch [208/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0311, psnr=23.4]
Epoch [209/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0274, psnr=24.7]
Epoch [210/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.36it/s, loss_d=0, loss_g

Plotting samples for epoch 300


Epoch [301/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.42it/s, loss_d=0, loss_g=0.0247, psnr=25.4]
Epoch [302/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.36it/s, loss_d=0, loss_g=0.0193, psnr=26.1]
Epoch [303/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0152, psnr=29.6]
Epoch [304/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.33it/s, loss_d=0, loss_g=0.0234, psnr=25.3]
Epoch [305/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.36it/s, loss_d=0, loss_g=0.0271, psnr=24.1]
Epoch [306/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.022, psnr=25.9] 
Epoch [307/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0292, psnr=24.2]
Epoch [308/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0216, psnr=27.2]
Epoch [309/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s, loss_d=0, loss_g=0.0182, psnr=26.7]
Epoch [310/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.34it/s, loss_d=0, loss_g

Plotting samples for epoch 400


Epoch [401/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.42it/s, loss_d=0, loss_g=0.0356, psnr=23.2]
Epoch [402/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.38it/s, loss_d=0, loss_g=0.0262, psnr=25.1]
Epoch [403/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.28it/s, loss_d=0, loss_g=0.0246, psnr=25.3]
Epoch [404/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.36it/s, loss_d=0, loss_g=0.014, psnr=30.1] 
Epoch [405/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d=0, loss_g=0.026, psnr=24.1] 
Epoch [406/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0188, psnr=27.1]
Epoch [407/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s, loss_d=0, loss_g=0.021, psnr=25.9] 
Epoch [408/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.33it/s, loss_d=0, loss_g=0.039, psnr=22.6] 
Epoch [409/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d=0, loss_g=0.026, psnr=25.6] 
Epoch [410/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g

Plotting samples for epoch 500


Epoch [501/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.39it/s, loss_d=0, loss_g=0.0258, psnr=24.4]
Epoch [502/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.36it/s, loss_d=0, loss_g=0.03, psnr=23]     
Epoch [503/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0184, psnr=25.7]
Epoch [504/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.26it/s, loss_d=0, loss_g=0.0178, psnr=27.7]
Epoch [505/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.021, psnr=26.3] 
Epoch [506/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0262, psnr=25]  
Epoch [507/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0307, psnr=24.8]
Epoch [508/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.26it/s, loss_d=0, loss_g=0.0304, psnr=24.2]
Epoch [509/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d=0, loss_g=0.0221, psnr=26]  
Epoch [510/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.28it/s, loss_d=0, loss_

Plotting samples for epoch 600


Epoch [601/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.37it/s, loss_d=0, loss_g=0.0204, psnr=27]  
Epoch [602/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.36it/s, loss_d=0, loss_g=0.0245, psnr=25.3]
Epoch [603/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d=0, loss_g=0.0272, psnr=24.6]
Epoch [604/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0154, psnr=29.7]
Epoch [605/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.34it/s, loss_d=0, loss_g=0.0206, psnr=26.7]
Epoch [606/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.02, psnr=26.8]  
Epoch [607/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.37it/s, loss_d=0, loss_g=0.0202, psnr=27.3]
Epoch [608/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d=0, loss_g=0.0324, psnr=24]  
Epoch [609/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.37it/s, loss_d=0, loss_g=0.0267, psnr=23.5]
Epoch [610/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d=0, loss_g

Plotting samples for epoch 700


Epoch [701/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.39it/s, loss_d=0, loss_g=0.0295, psnr=24.8]
Epoch [702/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.29it/s, loss_d=0, loss_g=0.0266, psnr=24.6]
Epoch [703/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s, loss_d=0, loss_g=0.0187, psnr=27.6]
Epoch [704/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0222, psnr=25.6]
Epoch [705/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.29it/s, loss_d=0, loss_g=0.0222, psnr=25.9]
Epoch [706/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0302, psnr=24.4]
Epoch [707/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s, loss_d=0, loss_g=0.0147, psnr=28.9]
Epoch [708/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0208, psnr=27.6]
Epoch [709/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0197, psnr=27.6]
Epoch [710/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g

Plotting samples for epoch 800


Epoch [801/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.43it/s, loss_d=0, loss_g=0.0163, psnr=27.9]
Epoch [802/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.37it/s, loss_d=0, loss_g=0.0277, psnr=24.2]
Epoch [803/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.28it/s, loss_d=0, loss_g=0.0219, psnr=26.2]
Epoch [804/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0187, psnr=27.7]
Epoch [805/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0134, psnr=31.4]
Epoch [806/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s, loss_d=0, loss_g=0.024, psnr=25.1] 
Epoch [807/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.29it/s, loss_d=0, loss_g=0.0282, psnr=24.8]
Epoch [808/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.30it/s, loss_d=0, loss_g=0.032, psnr=23]   
Epoch [809/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.29it/s, loss_d=0, loss_g=0.014, psnr=29.9] 
Epoch [810/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.33it/s, loss_d=0, loss_g

Plotting samples for epoch 900


Epoch [901/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.41it/s, loss_d=0, loss_g=0.0204, psnr=26.3]
Epoch [902/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.37it/s, loss_d=0, loss_g=0.0247, psnr=26.2]
Epoch [903/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0286, psnr=24.5] 
Epoch [904/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0319, psnr=23.6]
Epoch [905/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0293, psnr=24.7]
Epoch [906/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0308, psnr=24.3]
Epoch [907/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.34it/s, loss_d=0, loss_g=0.029, psnr=24.6] 
Epoch [908/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d=0, loss_g=0.029, psnr=24.7] 
Epoch [909/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0224, psnr=26.3]
Epoch [910/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.34it/s, loss_d=0, loss_

Plotting samples for epoch 1000


Epoch [1001/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.41it/s, loss_d=0, loss_g=0.0288, psnr=24.1]
Epoch [1002/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.38it/s, loss_d=0, loss_g=0.0244, psnr=24.3]
Epoch [1003/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.33it/s, loss_d=0, loss_g=0.0158, psnr=27.9]
Epoch [1004/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0223, psnr=26.2]
Epoch [1005/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0231, psnr=25.9]
Epoch [1006/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0221, psnr=26.3]
Epoch [1007/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.34it/s, loss_d=0, loss_g=0.0247, psnr=25.9]
Epoch [1008/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0251, psnr=25.8]
Epoch [1009/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0342, psnr=23.1]
Epoch [1010/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d

Plotting samples for epoch 1100


Epoch [1101/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.41it/s, loss_d=0, loss_g=0.0301, psnr=24.9]
Epoch [1102/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.38it/s, loss_d=0, loss_g=0.0162, psnr=29.3]
Epoch [1103/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.33it/s, loss_d=0, loss_g=0.0225, psnr=25.9]
Epoch [1104/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.28it/s, loss_d=0, loss_g=0.0248, psnr=25.8]
Epoch [1105/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0207, psnr=26.3]
Epoch [1106/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.31it/s, loss_d=0, loss_g=0.0147, psnr=28.9]
Epoch [1107/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.34it/s, loss_d=0, loss_g=0.027, psnr=25]   
Epoch [1108/4000]: 100%|██████████| 50/50 [00:15<00:00,  3.32it/s, loss_d=0, loss_g=0.0294, psnr=24.2]
Epoch [1109/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.34it/s, loss_d=0, loss_g=0.0187, psnr=28.4]
Epoch [1110/4000]: 100%|██████████| 50/50 [00:14<00:00,  3.35it/s, loss_d

In [None]:
# Drop unreachable Python objects first, then hand cached CUDA blocks back to the driver.
for _release in (gc.collect, torch.cuda.empty_cache):
    _release()

In [None]:
###############################