In [15]:
import os

# Expose only the GPU with index 1 to CUDA in this process.
# This has to happen before CUDA is initialised, hence its own cell ahead of the torch imports.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [16]:
import time
import torch.backends.cudnn as cudnn
from torch import nn
from models import Generator, Discriminator, TruncatedVGG19
from datasets import SRDataset
from utils import *
import csv
from datetime import datetime
start_time = datetime.now()  # wall-clock start; subtracted from end_time at the bottom of the script

# Data parameters
data_folder = './'  # folder with JSON data files
crop_size = 96  # crop size of target HR images
scaling_factor = 2  # the scaling factor for the generator; the input LR images will be downsampled from the target HR images by this factor

# Generator parameters
large_kernel_size_g = 9  # kernel size of the first and last convolutions which transform the inputs and outputs
small_kernel_size_g = 3  # kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks
n_channels_g = 64  # number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks
n_blocks_g = 16  # number of residual blocks
srresnet_checkpoint = './../E100Bt8_srresnet.pth.tar'  # filepath of the trained SRResNet checkpoint used for initialization

# Discriminator parameters
kernel_size_d = 3  # kernel size in all convolutional blocks
n_channels_d = 64  # number of output channels in the first convolutional block, after which it is doubled in every 2nd block thereafter
n_blocks_d = 8  # number of convolutional blocks
fc_size_d = 1024  # size of the first fully connected layer

# Learning parameters
checkpoint = None  # path to model (SRGAN) checkpoint, None if none; main() rebinds this to the loaded checkpoint dict when resuming
batch_size = 64  # batch size
start_epoch = 0  # start at this epoch; overwritten by main() when resuming from a checkpoint
iterations = 2e5  # number of training iterations (a float); main() only uses it to derive epoch indices
workers = 0  # number of workers for loading data in the DataLoader
vgg19_i = 5  # the index i in the definition for VGG loss; see paper or models.py
vgg19_j = 4  # the index j in the definition for VGG loss; see paper or models.py
beta = 1e-3  # the coefficient to weight the adversarial loss in the perceptual loss
print_freq = 1  # print training status (and write one CSV row) once every __ batches
lr = 1e-4  # learning rate
grad_clip = None  # clip if gradients are exploding; None disables clipping
index = 0  # starting value of the CSV row counter that main() passes to train() as idx

# Default device
# NOTE(review): `torch` is not imported explicitly in this cell — presumably it comes in via
# `from utils import *`; confirm, or add an explicit `import torch`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Enable cuDNN benchmark mode
cudnn.benchmark = True


def main():
    """
    Training.

    Builds the generator and discriminator with their Adam optimizers (or restores all four from
    a saved SRGAN checkpoint), then trains them epoch by epoch. The per-batch losses are written
    to a CSV log, and a checkpoint is saved after every epoch.

    Reads the module-level configuration values and rebinds the globals ``start_epoch``,
    ``epoch`` and ``checkpoint``.
    """
    global start_epoch, epoch, checkpoint, srresnet_checkpoint

    # Initialize model or load checkpoint
    if checkpoint is None:
        # Generator
        generator = Generator(large_kernel_size=large_kernel_size_g,
                              small_kernel_size=small_kernel_size_g,
                              n_channels=n_channels_g,
                              n_blocks=n_blocks_g,
                              scaling_factor=scaling_factor)

        # Initialize generator network with pretrained SRResNet
        generator.initialize_with_srresnet(srresnet_checkpoint=srresnet_checkpoint)

        # Initialize generator's optimizer
        optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad, generator.parameters()),
                                       lr=lr)

        # Discriminator
        discriminator = Discriminator(kernel_size=kernel_size_d,
                                      n_channels=n_channels_d,
                                      n_blocks=n_blocks_d,
                                      fc_size=fc_size_d)

        # Initialize discriminator's optimizer
        optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad, discriminator.parameters()),
                                       lr=lr)

    else:
        # NOTE(review): the checkpoint holds whole pickled model/optimizer objects (see torch.save
        # at the end of the epoch loop), so only load files from a trusted source.
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        generator = checkpoint['generator']
        discriminator = checkpoint['discriminator']
        optimizer_g = checkpoint['optimizer_g']
        optimizer_d = checkpoint['optimizer_d']
        print("\nLoaded checkpoint from epoch %d.\n" % (checkpoint['epoch'] + 1))

    # Truncated VGG19 network to be used in the loss calculation
    truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)
    truncated_vgg19.eval()

    # Loss functions
    content_loss_criterion = nn.MSELoss()
    adversarial_loss_criterion = nn.BCEWithLogitsLoss()

    # Move to default device
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    truncated_vgg19 = truncated_vgg19.to(device)
    content_loss_criterion = content_loss_criterion.to(device)
    adversarial_loss_criterion = adversarial_loss_criterion.to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='imagenet-norm')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers,
                                               pin_memory=True)

    # Total number of epochs to train for
    # NOTE(review): `epochs` is derived from `iterations` but never used; the loop below is
    # hard-coded to stop at epoch 10. Confirm which of the two is intended.
    epochs = int(iterations // len(train_loader) + 1)

    # The `with` block guarantees the CSV log is flushed and closed even if training raises;
    # newline='' is what the csv module expects so that it controls line endings itself.
    with open('./../img/loss_ep20_bt8.csv', 'w', newline='') as csvfile:
        filewriter = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        filewriter.writerow(['Nomor', 'Epoch', 'Cont. Loss', 'Adv. Loss', 'Disc. Loss'])

        # Epochs
        for epoch in range(start_epoch, 10):

            # At the halfway point, reduce learning rate to a tenth
            # NOTE(review): this epoch index is derived from `iterations`; with the loop capped at
            # 10 it is only reached if it is below 10.
            if epoch == int((iterations / 2) // len(train_loader) + 1):
                adjust_learning_rate(optimizer_g, 0.1)
                adjust_learning_rate(optimizer_d, 0.1)

            # One epoch's training
            # NOTE(review): `index` is passed by value and never updated here, so the 'Nomor'
            # column restarts from index + 1 in every epoch.
            train(train_loader=train_loader,
                  generator=generator,
                  discriminator=discriminator,
                  truncated_vgg19=truncated_vgg19,
                  content_loss_criterion=content_loss_criterion,
                  adversarial_loss_criterion=adversarial_loss_criterion,
                  optimizer_g=optimizer_g,
                  optimizer_d=optimizer_d,
                  epoch=epoch,
                  idx=index,
                  filewriter=filewriter)

            # Save checkpoint (overwrites the same file every epoch)
            torch.save({'epoch': epoch,
                        'generator': generator,
                        'discriminator': discriminator,
                        'optimizer_g': optimizer_g,
                        'optimizer_d': optimizer_d},
                       'E500Bt64_srgan.pth.tar')


def train(train_loader, generator, discriminator, truncated_vgg19, content_loss_criterion, adversarial_loss_criterion,
          optimizer_g, optimizer_d, epoch, idx, filewriter):
    """
    One epoch's training.

    For every batch, the generator is updated first on the perceptual loss (content loss plus
    ``beta`` times the adversarial loss), then the discriminator is updated on its own adversarial
    loss. Every ``print_freq`` batches a status line is printed and one row is written to the CSV log.

    :param train_loader: train dataloader
    :param generator: generator
    :param discriminator: discriminator
    :param truncated_vgg19: truncated VGG19 network
    :param content_loss_criterion: content loss function (Mean Squared-Error loss)
    :param adversarial_loss_criterion: adversarial loss function (Binary Cross-Entropy loss)
    :param optimizer_g: optimizer for the generator
    :param optimizer_d: optimizer for the discriminator
    :param epoch: epoch number
    :param idx: value the CSV row counter starts from; it is incremented locally before each row is
                written and the caller's variable is not updated
    :param filewriter: csv writer that receives one row (row number, epoch, content loss,
                       generator adversarial loss, discriminator loss) per logged batch
    """
    # Set to train mode
    generator.train()
    discriminator.train()  # training mode enables batch normalization

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses_c = AverageMeter()  # content loss
    losses_a = AverageMeter()  # adversarial loss in the generator
    losses_d = AverageMeter()  # adversarial loss in the discriminator

    start = time.time()

    # Batches
    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to default device
        # Shapes below are presumed from the dataset configuration (crop_size, scaling_factor) — confirm against SRDataset
        lr_imgs = lr_imgs.to(device)  # (N, 3, crop_size / scaling_factor, crop_size / scaling_factor), imagenet-normed
        hr_imgs = hr_imgs.to(device)  # (N, 3, crop_size, crop_size), imagenet-normed

        # GENERATOR UPDATE

        # Generate
        sr_imgs = generator(lr_imgs)  # same size as hr_imgs, in [-1, 1]
        sr_imgs = convert_image(sr_imgs, source='[-1, 1]', target='imagenet-norm')  # same size, imagenet-normed

        # Calculate VGG feature maps for the super-resolved (SR) and high resolution (HR) images
        sr_imgs_in_vgg_space = truncated_vgg19(sr_imgs)
        hr_imgs_in_vgg_space = truncated_vgg19(hr_imgs).detach()  # detached because they're constant, targets

        # Discriminate super-resolved (SR) images
        sr_discriminated = discriminator(sr_imgs)  # (N)

        # Calculate the Perceptual loss
        content_loss = content_loss_criterion(sr_imgs_in_vgg_space, hr_imgs_in_vgg_space)
        adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.ones_like(sr_discriminated))
        perceptual_loss = content_loss + beta * adversarial_loss

        # Back-prop.
        optimizer_g.zero_grad()
        perceptual_loss.backward()

        # Clip gradients, if necessary
        if grad_clip is not None:
            clip_gradient(optimizer_g, grad_clip)

        # Update generator
        optimizer_g.step()

        # Keep track of loss
        losses_c.update(content_loss.item(), lr_imgs.size(0))
        losses_a.update(adversarial_loss.item(), lr_imgs.size(0))

        # DISCRIMINATOR UPDATE

        # Discriminate super-resolution (SR) and high-resolution (HR) images
        hr_discriminated = discriminator(hr_imgs)
        sr_discriminated = discriminator(sr_imgs.detach())
        # But didn't we already discriminate the SR images earlier, before updating the generator (G)? Why not just use that here?
        # Because, if we used that, we'd be back-propagating (finding gradients) over the G too when backward() is called
        # It's actually faster to detach the SR images from the G and forward-prop again, than to back-prop. over the G unnecessarily
        # See FAQ section in the tutorial

        # Binary Cross-Entropy loss
        adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.zeros_like(sr_discriminated)) + \
                           adversarial_loss_criterion(hr_discriminated, torch.ones_like(hr_discriminated))

        # Back-prop.
        # Gradients left in the discriminator by the generator's backward pass are cleared here
        optimizer_d.zero_grad()
        adversarial_loss.backward()

        # Clip gradients, if necessary
        if grad_clip is not None:
            clip_gradient(optimizer_d, grad_clip)

        # Update discriminator
        optimizer_d.step()

        # Keep track of loss
        losses_d.update(adversarial_loss.item(), hr_imgs.size(0))

        # Keep track of batch times
        batch_time.update(time.time() - start)

        # Reset start time
        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]----'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})----'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})----'
                  'Cont. Loss {loss_c.val:.4f} ({loss_c.avg:.4f})----'
                  'Adv. Loss {loss_a.val:.4f} ({loss_a.avg:.4f})----'
                  'Disc. Loss {loss_d.val:.4f} ({loss_d.avg:.4f})'.format(epoch,
                                                                          i,
                                                                          len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time,
                                                                          loss_c=losses_c,
                                                                          loss_a=losses_a,
                                                                          loss_d=losses_d))
            # Log the latest (not averaged) loss values of this batch to the CSV file
            idx = idx + 1
            filewriter.writerow([idx, epoch,
                                 '{loss_c.val:.4f}'.format(loss_c=losses_c),
                                 '{loss_a.val:.4f}'.format(loss_a=losses_a),
                                 '{loss_d.val:.4f}'.format(loss_d=losses_d)])

    # No explicit `del` of the batch tensors here: they are locals and are released when the
    # function returns, and deleting them raised UnboundLocalError if the loader yielded no batches.


# Run the training when this file is executed as a script / notebook cell.
if __name__ == '__main__':
    main()

# Report the total wall-clock time elapsed since start_time was recorded.
end_time = datetime.now()
print('Duration: %s' % (end_time - start_time,))


Loaded weights from pre-trained SRResNet.

Epoch: [0][0/36]----Batch Time 0.479 (0.479)----Data Time 0.080 (0.080)----Cont. Loss 0.0403 (0.0403)----Adv. Loss 0.6628 (0.6628)----Disc. Loss 1.3912 (1.3912)
Epoch: [0][1/36]----Batch Time 0.468 (0.473)----Data Time 0.088 (0.084)----Cont. Loss 0.3169 (0.1786)----Adv. Loss 2.4977 (1.5803)----Disc. Loss 2.5913 (1.9913)
Epoch: [0][2/36]----Batch Time 0.452 (0.466)----Data Time 0.074 (0.080)----Cont. Loss 0.0663 (0.1412)----Adv. Loss 0.1663 (1.1090)----Disc. Loss 2.0993 (2.0273)
Epoch: [0][3/36]----Batch Time 0.454 (0.463)----Data Time 0.074 (0.079)----Cont. Loss 0.0653 (0.1222)----Adv. Loss 0.2293 (0.8891)----Disc. Loss 1.8676 (1.9874)
Epoch: [0][4/36]----Batch Time 0.452 (0.461)----Data Time 0.074 (0.078)----Cont. Loss 0.0683 (0.1114)----Adv. Loss 0.6390 (0.8390)----Disc. Loss 1.3948 (1.8689)
Epoch: [0][5/36]----Batch Time 0.453 (0.459)----Data Time 0.075 (0.077)----Cont. Loss 0.0733 (0.1051)----Adv. Loss 1.0815 (0.8795)----Disc. Loss 1.5040

Epoch: [1][15/36]----Batch Time 0.454 (0.458)----Data Time 0.075 (0.075)----Cont. Loss 0.0414 (0.0434)----Adv. Loss 0.6907 (0.7030)----Disc. Loss 1.3807 (1.3840)
Epoch: [1][16/36]----Batch Time 0.460 (0.458)----Data Time 0.075 (0.075)----Cont. Loss 0.0484 (0.0437)----Adv. Loss 0.6885 (0.7022)----Disc. Loss 1.3823 (1.3839)
Epoch: [1][17/36]----Batch Time 0.455 (0.458)----Data Time 0.075 (0.075)----Cont. Loss 0.0407 (0.0435)----Adv. Loss 0.7078 (0.7025)----Disc. Loss 1.3812 (1.3837)
Epoch: [1][18/36]----Batch Time 0.456 (0.457)----Data Time 0.075 (0.075)----Cont. Loss 0.0525 (0.0440)----Adv. Loss 0.7130 (0.7031)----Disc. Loss 1.3795 (1.3835)
Epoch: [1][19/36]----Batch Time 0.454 (0.457)----Data Time 0.075 (0.075)----Cont. Loss 0.0347 (0.0435)----Adv. Loss 0.7156 (0.7037)----Disc. Loss 1.3789 (1.3833)
Epoch: [1][20/36]----Batch Time 0.457 (0.457)----Data Time 0.076 (0.075)----Cont. Loss 0.0377 (0.0433)----Adv. Loss 0.7277 (0.7048)----Disc. Loss 1.3787 (1.3831)
Epoch: [1][21/36]----Batch T

Epoch: [2][30/36]----Batch Time 0.456 (0.456)----Data Time 0.075 (0.074)----Cont. Loss 0.0355 (0.0382)----Adv. Loss 1.0959 (0.9345)----Disc. Loss 1.3765 (1.3032)
Epoch: [2][31/36]----Batch Time 0.455 (0.456)----Data Time 0.073 (0.074)----Cont. Loss 0.0334 (0.0381)----Adv. Loss 1.1011 (0.9397)----Disc. Loss 1.3603 (1.3050)
Epoch: [2][32/36]----Batch Time 0.456 (0.456)----Data Time 0.075 (0.074)----Cont. Loss 0.0436 (0.0382)----Adv. Loss 1.0990 (0.9445)----Disc. Loss 1.3458 (1.3062)
Epoch: [2][33/36]----Batch Time 0.456 (0.456)----Data Time 0.075 (0.074)----Cont. Loss 0.0404 (0.0383)----Adv. Loss 1.0203 (0.9468)----Disc. Loss 1.3041 (1.3062)
Epoch: [2][34/36]----Batch Time 0.458 (0.456)----Data Time 0.076 (0.074)----Cont. Loss 0.0415 (0.0384)----Adv. Loss 0.7964 (0.9425)----Disc. Loss 1.2394 (1.3042)
Epoch: [2][35/36]----Batch Time 0.339 (0.453)----Data Time 0.053 (0.074)----Cont. Loss 0.0365 (0.0383)----Adv. Loss 0.6563 (0.9367)----Disc. Loss 1.2573 (1.3033)
Epoch: [3][0/36]----Batch Ti

Epoch: [4][9/36]----Batch Time 0.458 (0.458)----Data Time 0.075 (0.075)----Cont. Loss 0.0402 (0.0377)----Adv. Loss 0.8319 (0.6557)----Disc. Loss 1.3473 (1.3691)
Epoch: [4][10/36]----Batch Time 0.456 (0.458)----Data Time 0.074 (0.075)----Cont. Loss 0.0358 (0.0375)----Adv. Loss 0.8523 (0.6736)----Disc. Loss 1.3438 (1.3668)
Epoch: [4][11/36]----Batch Time 0.457 (0.457)----Data Time 0.074 (0.075)----Cont. Loss 0.0345 (0.0373)----Adv. Loss 0.9148 (0.6937)----Disc. Loss 1.3432 (1.3648)
Epoch: [4][12/36]----Batch Time 0.457 (0.457)----Data Time 0.075 (0.075)----Cont. Loss 0.0342 (0.0370)----Adv. Loss 0.8499 (0.7057)----Disc. Loss 1.3341 (1.3625)
Epoch: [4][13/36]----Batch Time 0.456 (0.457)----Data Time 0.074 (0.075)----Cont. Loss 0.0341 (0.0368)----Adv. Loss 0.7981 (0.7123)----Disc. Loss 1.3126 (1.3589)
Epoch: [4][14/36]----Batch Time 0.456 (0.457)----Data Time 0.074 (0.075)----Cont. Loss 0.0323 (0.0365)----Adv. Loss 0.7424 (0.7143)----Disc. Loss 1.2796 (1.3536)
Epoch: [4][15/36]----Batch Ti

Epoch: [5][24/36]----Batch Time 0.456 (0.460)----Data Time 0.075 (0.075)----Cont. Loss 0.0341 (0.0379)----Adv. Loss 0.4521 (1.7620)----Disc. Loss 1.5894 (1.9611)
Epoch: [5][25/36]----Batch Time 0.457 (0.460)----Data Time 0.074 (0.074)----Cont. Loss 0.0380 (0.0379)----Adv. Loss 0.6186 (1.7180)----Disc. Loss 1.4982 (1.9433)
Epoch: [5][26/36]----Batch Time 0.457 (0.460)----Data Time 0.075 (0.075)----Cont. Loss 0.0397 (0.0380)----Adv. Loss 0.8539 (1.6860)----Disc. Loss 1.4935 (1.9266)
Epoch: [5][27/36]----Batch Time 0.456 (0.460)----Data Time 0.074 (0.074)----Cont. Loss 0.0329 (0.0378)----Adv. Loss 0.9836 (1.6609)----Disc. Loss 1.4881 (1.9109)
Epoch: [5][28/36]----Batch Time 0.461 (0.460)----Data Time 0.075 (0.074)----Cont. Loss 0.0319 (0.0376)----Adv. Loss 1.1425 (1.6430)----Disc. Loss 1.5560 (1.8987)
Epoch: [5][29/36]----Batch Time 0.459 (0.460)----Data Time 0.075 (0.075)----Cont. Loss 0.0389 (0.0377)----Adv. Loss 1.2065 (1.6285)----Disc. Loss 1.5600 (1.8874)
Epoch: [5][30/36]----Batch T

Epoch: [7][3/36]----Batch Time 0.457 (0.463)----Data Time 0.075 (0.075)----Cont. Loss 0.0415 (0.0373)----Adv. Loss 2.4043 (1.9801)----Disc. Loss 0.3105 (0.4331)
Epoch: [7][4/36]----Batch Time 0.457 (0.462)----Data Time 0.075 (0.075)----Cont. Loss 0.0360 (0.0371)----Adv. Loss 2.3141 (2.0469)----Disc. Loss 0.2707 (0.4006)
Epoch: [7][5/36]----Batch Time 0.457 (0.461)----Data Time 0.074 (0.075)----Cont. Loss 0.0356 (0.0368)----Adv. Loss 2.5648 (2.1332)----Disc. Loss 0.2240 (0.3712)
Epoch: [7][6/36]----Batch Time 0.461 (0.461)----Data Time 0.074 (0.075)----Cont. Loss 0.0383 (0.0370)----Adv. Loss 2.6369 (2.2052)----Disc. Loss 0.1855 (0.3446)
Epoch: [7][7/36]----Batch Time 0.462 (0.461)----Data Time 0.075 (0.075)----Cont. Loss 0.0329 (0.0365)----Adv. Loss 2.8851 (2.2902)----Disc. Loss 0.1413 (0.3192)
Epoch: [7][8/36]----Batch Time 0.456 (0.460)----Data Time 0.074 (0.075)----Cont. Loss 0.0343 (0.0363)----Adv. Loss 3.5203 (2.4268)----Disc. Loss 0.1172 (0.2968)
Epoch: [7][9/36]----Batch Time 0.4

Epoch: [8][18/36]----Batch Time 0.458 (0.459)----Data Time 0.075 (0.075)----Cont. Loss 0.0373 (0.0371)----Adv. Loss 4.1382 (2.8272)----Disc. Loss 0.0963 (0.3313)
Epoch: [8][19/36]----Batch Time 0.456 (0.459)----Data Time 0.074 (0.075)----Cont. Loss 0.0348 (0.0369)----Adv. Loss 4.2538 (2.8985)----Disc. Loss 0.1065 (0.3200)
Epoch: [8][20/36]----Batch Time 0.458 (0.459)----Data Time 0.073 (0.075)----Cont. Loss 0.0325 (0.0367)----Adv. Loss 4.1836 (2.9597)----Disc. Loss 0.0519 (0.3073)
Epoch: [8][21/36]----Batch Time 0.457 (0.459)----Data Time 0.075 (0.075)----Cont. Loss 0.0324 (0.0365)----Adv. Loss 4.5089 (3.0301)----Disc. Loss 0.0339 (0.2948)
Epoch: [8][22/36]----Batch Time 0.461 (0.459)----Data Time 0.073 (0.074)----Cont. Loss 0.0359 (0.0365)----Adv. Loss 4.9194 (3.1123)----Disc. Loss 0.0539 (0.2844)
Epoch: [8][23/36]----Batch Time 0.456 (0.459)----Data Time 0.074 (0.074)----Cont. Loss 0.0354 (0.0365)----Adv. Loss 4.5031 (3.1702)----Disc. Loss 0.0246 (0.2735)
Epoch: [8][24/36]----Batch T

Epoch: [9][33/36]----Batch Time 0.456 (0.459)----Data Time 0.075 (0.075)----Cont. Loss 0.0397 (0.0383)----Adv. Loss 3.2386 (4.0118)----Disc. Loss 0.2046 (0.4134)
Epoch: [9][34/36]----Batch Time 0.461 (0.459)----Data Time 0.078 (0.075)----Cont. Loss 0.0366 (0.0382)----Adv. Loss 3.6153 (4.0005)----Disc. Loss 0.2465 (0.4087)
Epoch: [9][35/36]----Batch Time 0.341 (0.456)----Data Time 0.055 (0.074)----Cont. Loss 0.0344 (0.0381)----Adv. Loss 2.6556 (3.9734)----Disc. Loss 0.2581 (0.4056)
