In [1]:
import time

import torch
import torch.backends.cudnn as cudnn
from torch import nn
from easydict import EasyDict as edict

from models import Generator, Discriminator, TruncatedVGG19
from datasets import SRDataset
from utils import *
from solver import train

In [2]:
# config
# Single EasyDict holding every tunable of this notebook. The whole object is
# passed to SRDataset; individual fields are read by the cells below.
config = edict()

# Data locations and patch geometry
# NOTE(review): these fields are presumably consumed inside SRDataset — confirm
# against datasets.py.
config.csv_folder = '../data/SRDataset'
config.HR_data_folder = '../data/SRDataset/div2k/DIV2K_train_HR'
config.LR_data_folder = '../data/SRDataset/div2k/DIV2K_train_LR_bicubic/X4'
config.crop_size = 96
config.scaling_factor = 4

# Generator parameters
config.G = edict()
config.G.large_kernel_size = 9
config.G.small_kernel_size = 3
config.G.n_channels = 64
config.G.n_blocks = 16

# Discriminator parameters
config.D = edict()
config.D.kernel_size = 3
config.D.n_channels = 64
config.D.n_blocks = 8
config.D.fc_size = 1024

# Learning parameters
config.checkpoint = None # path to model (SRGAN) checkpoint, None if none
config.batch_size = 16
config.start_epoch = 0
config.epochs = 2
config.workers = 4
config.vgg19_i = 5  # the index i in the definition for VGG loss; see paper
config.vgg19_j = 4  # the index j in the definition for VGG loss; see paper
config.beta = 1e-3  # the coefficient to weight the adversarial loss in the perceptual loss
config.print_freq = 50
config.lr = 1e-4

# Default device
# Pick the best backend available at run time instead of hard-coding "mps",
# so the notebook also runs on CUDA or CPU-only machines. The value stays a
# plain string, exactly as before.
if torch.backends.mps.is_available():
    config.device = "mps"
elif torch.cuda.is_available():
    config.device = "cuda"
else:
    config.device = "cpu"
# cuDNN autotuning flag; only relevant when the CUDA backend is in use
cudnn.benchmark = True

In [3]:
# Build the truncated VGG19 used in the loss calculation, cut at the (i, j)
# indices taken from the config, and switch it to evaluation mode right away
truncated_vgg19 = TruncatedVGG19(i=config.vgg19_i, j=config.vgg19_j).eval()
# Echo the module so its layer listing is shown as the cell output
truncated_vgg19



TruncatedVGG19(
  (truncated_vgg19): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), s

In [4]:
# Create the generator/discriminator and their optimizers, either from scratch
# or from a previously saved SRGAN checkpoint.
if config.checkpoint is None:
    # Generator
    # NOTE(review): built with constructor defaults; the values in config.G are
    # not passed in — confirm the defaults in models.py match the config.
    generator = Generator()

    # Initialize generator's optimizer (only parameters that require gradients)
    optimizer_g = torch.optim.SGD(params=filter(lambda p: p.requires_grad, generator.parameters()),
                                   lr=config.lr, momentum=0.8)

    # Discriminator
    # NOTE(review): same as above for config.D.
    discriminator = Discriminator()

    # Initialize discriminator's optimizer (only parameters that require gradients)
    optimizer_d = torch.optim.SGD(params=filter(lambda p: p.requires_grad, discriminator.parameters()),
                                   lr=config.lr, momentum=0.8)

else:
    # The checkpoint written by the training cell stores whole pickled objects
    # (both networks and both optimizers), not state_dicts, so:
    #   * weights_only=False is required on PyTorch versions where the default
    #     is True, otherwise unpickling the modules is refused;
    #   * map_location remaps every stored tensor (weights and optimizer state)
    #     onto the configured device, so a checkpoint saved on another device
    #     can still be resumed.
    # SECURITY: unpickling executes arbitrary code — only load checkpoints
    # from a source you trust.
    checkpoint = torch.load(config.checkpoint,
                            map_location=config.device,
                            weights_only=False)
    # Resume with the epoch after the one that produced the checkpoint
    config.start_epoch = checkpoint['epoch'] + 1
    generator = checkpoint['generator']
    discriminator = checkpoint['discriminator']
    optimizer_g = checkpoint['optimizer_g']
    optimizer_d = checkpoint['optimizer_d']
    print("\nLoaded checkpoint from epoch %d.\n" % (checkpoint['epoch'] + 1))

In [5]:
# Criteria handed to the training routine: binary cross-entropy on raw logits
# for the adversarial term, mean squared error for the content term
adversarial_loss_criterion = nn.BCEWithLogitsLoss()
content_loss_criterion = nn.MSELoss()

In [6]:
# Relocate the three networks and both loss modules onto config.device,
# rebinding every name to the object returned by .to()
(generator,
 discriminator,
 truncated_vgg19,
 content_loss_criterion,
 adversarial_loss_criterion) = [
    module.to(config.device)
    for module in (generator,
                   discriminator,
                   truncated_vgg19,
                   content_loss_criterion,
                   adversarial_loss_criterion)
]

In [7]:
# Custom dataloaders
# Training split of the super-resolution dataset; all dataset options travel
# inside the config object.
train_dataset = SRDataset(split='train', config=config)

# Page-locked (pinned) host memory only speeds up host-to-GPU copies on CUDA.
# Requesting it for any other device brings no benefit and makes the
# DataLoader emit a warning, so enable it only when training on CUDA.
use_pinned_memory = str(config.device).startswith("cuda")

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=config.batch_size,
                                           shuffle=True,
                                           num_workers=config.workers,
                                           pin_memory=use_pinned_memory)

In [8]:
# Epochs
# 0-based epoch at which the learning rate is cut to a tenth.
# The previous test, `epoch == int(config.epochs / 2 + 1)`, pointed one epoch
# past the midpoint and was never true for config.epochs = 2, so the decay
# silently never happened. max(1, ...) keeps a single-epoch run from decaying
# before any training has taken place.
lr_decay_epoch = max(1, config.epochs // 2)

for epoch in range(config.start_epoch, config.epochs):
    # At the halfway point, reduce learning rate to a tenth
    if epoch == lr_decay_epoch:
        adjust_learning_rate(optimizer_g, 0.1)
        adjust_learning_rate(optimizer_d, 0.1)
    # One epoch's training
    train(train_loader=train_loader,
          generator=generator,
          discriminator=discriminator,
          truncated_vgg19=truncated_vgg19,
          content_loss_criterion=content_loss_criterion,
          adversarial_loss_criterion=adversarial_loss_criterion,
          optimizer_g=optimizer_g,
          optimizer_d=optimizer_d,
          epoch=epoch,
          device=config.device,
          beta=config.beta,
          print_freq=config.print_freq)
    # Save checkpoint
    # Overwritten after every epoch. Whole objects are pickled (not
    # state_dicts) because the resume branch of this notebook reads the
    # networks and optimizers back directly from these keys.
    torch.save({'epoch': epoch,
                'generator': generator,
                'discriminator': discriminator,
                'optimizer_g': optimizer_g,
                'optimizer_d': optimizer_d},
                'checkpoint_srgan.pth.tar')

Epoch: [0][0/5000]----Batch Time 3.132 (3.132)----Data Time 1.726 (1.726)----Cont. Loss 0.4878 (0.4878)----Adv. Loss 0.6719 (0.6719)----Disc. Loss 1.4287 (1.4287)
Epoch: [0][50/5000]----Batch Time 0.701 (0.747)----Data Time 0.000 (0.034)----Cont. Loss 0.4172 (0.4619)----Adv. Loss 0.9919 (0.8379)----Disc. Loss 0.9982 (1.2050)
Epoch: [0][100/5000]----Batch Time 0.689 (0.721)----Data Time 0.000 (0.018)----Cont. Loss 0.5545 (0.4507)----Adv. Loss 1.2362 (0.9899)----Disc. Loss 0.7733 (1.0284)
Epoch: [0][150/5000]----Batch Time 0.701 (0.711)----Data Time 0.000 (0.012)----Cont. Loss 0.4524 (0.4498)----Adv. Loss 1.6291 (1.1531)----Disc. Loss 0.4719 (0.8766)
Epoch: [0][200/5000]----Batch Time 0.688 (0.706)----Data Time 0.000 (0.009)----Cont. Loss 0.2751 (0.4485)----Adv. Loss 2.0575 (1.3165)----Disc. Loss 0.2821 (0.7547)
Epoch: [0][250/5000]----Batch Time 0.704 (0.704)----Data Time 0.000 (0.007)----Cont. Loss 0.3101 (0.4463)----Adv. Loss 2.2950 (1.4722)----Disc. Loss 0.2161 (0.6592)
Epoch: [0][30

Epoch: [0][2500/5000]----Batch Time 0.689 (0.694)----Data Time 0.000 (0.001)----Cont. Loss 0.4618 (0.4282)----Adv. Loss 4.7620 (3.7272)----Disc. Loss 0.0160 (0.1082)
Epoch: [0][2550/5000]----Batch Time 0.685 (0.694)----Data Time 0.000 (0.001)----Cont. Loss 0.3906 (0.4277)----Adv. Loss 4.8015 (3.7482)----Disc. Loss 0.0159 (0.1064)
Epoch: [0][2600/5000]----Batch Time 0.685 (0.694)----Data Time 0.000 (0.001)----Cont. Loss 0.4216 (0.4274)----Adv. Loss 4.7786 (3.7687)----Disc. Loss 0.0166 (0.1046)
Epoch: [0][2650/5000]----Batch Time 0.693 (0.694)----Data Time 0.000 (0.001)----Cont. Loss 0.3450 (0.4275)----Adv. Loss 4.8557 (3.7888)----Disc. Loss 0.0144 (0.1030)
Epoch: [0][2700/5000]----Batch Time 0.702 (0.694)----Data Time 0.000 (0.001)----Cont. Loss 0.5169 (0.4275)----Adv. Loss 4.8790 (3.8084)----Disc. Loss 0.0142 (0.1013)
Epoch: [0][2750/5000]----Batch Time 0.694 (0.694)----Data Time 0.000 (0.001)----Cont. Loss 0.2882 (0.4272)----Adv. Loss 4.9129 (3.8278)----Disc. Loss 0.0140 (0.0997)
Epoc

Epoch: [1][0/5000]----Batch Time 2.494 (2.494)----Data Time 1.740 (1.740)----Cont. Loss 0.2939 (0.2939)----Adv. Loss 5.5233 (5.5233)----Disc. Loss 0.0074 (0.0074)
Epoch: [1][50/5000]----Batch Time 0.686 (0.732)----Data Time 0.000 (0.035)----Cont. Loss 0.4179 (0.4172)----Adv. Loss 5.5546 (5.5296)----Disc. Loss 0.0074 (0.0076)
Epoch: [1][100/5000]----Batch Time 0.685 (0.710)----Data Time 0.000 (0.018)----Cont. Loss 0.4462 (0.4146)----Adv. Loss 5.5230 (5.5293)----Disc. Loss 0.0076 (0.0076)
Epoch: [1][150/5000]----Batch Time 0.685 (0.703)----Data Time 0.000 (0.012)----Cont. Loss 0.4878 (0.4117)----Adv. Loss 5.5455 (5.5319)----Disc. Loss 0.0074 (0.0076)
Epoch: [1][200/5000]----Batch Time 0.691 (0.699)----Data Time 0.000 (0.009)----Cont. Loss 0.2407 (0.4113)----Adv. Loss 5.5685 (5.5358)----Disc. Loss 0.0069 (0.0076)
Epoch: [1][250/5000]----Batch Time 0.681 (0.697)----Data Time 0.000 (0.007)----Cont. Loss 0.4119 (0.4126)----Adv. Loss 5.5607 (5.5416)----Disc. Loss 0.0073 (0.0076)
Epoch: [1][30

  out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
  out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
  stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0,
  stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0,
  out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
  out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
  stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0,
  stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0,
  out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
  out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
  stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0,
  stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0,
  out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
  out_channels = (n_

KeyboardInterrupt: 