In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from srresnet import _NetG


In [2]:
from torchvision.datasets import ImageFolder
from torchvision import transforms as tf
from torch.utils.data import Dataset, DataLoader

In [3]:
class SRDataset(Dataset):
    """Paired low-/high-resolution training crops drawn from an image folder.

    Every item is a random crop of side ``cropsize`` taken from one image
    (the high-resolution target) together with a copy of that same crop
    resized down to ``ds_size`` (the low-resolution input).

    Args:
        root: Directory handed to ``ImageFolder``.
        cropsize: Side length, in pixels, of the random HR crop.
        downsample_factor: Ratio between the HR and LR side lengths.

    Raises:
        ValueError: If ``downsample_factor`` is not positive, or if the
            resulting LR side length would be smaller than one pixel.
    """

    def __init__(self,root, cropsize=512, downsample_factor=4):
        # Validate before touching the file system so a bad argument fails
        # immediately with a clear message instead of a ZeroDivisionError
        # here or an opaque resize error on the first item.
        if downsample_factor <= 0:
            raise ValueError(
                "downsample_factor must be positive, got {!r}".format(downsample_factor))
        ds_size = int(cropsize/downsample_factor)
        if ds_size < 1:
            raise ValueError(
                "cropsize={!r} with downsample_factor={!r} leaves no pixels in the "
                "low-resolution image".format(cropsize, downsample_factor))

        # NOTE(review): RandomCrop presumably rejects images smaller than
        # cropsize -- confirm every image under root is large enough.
        self.imagefolder = ImageFolder(root,transform=tf.RandomCrop(cropsize))
        self.cropsize = cropsize
        self.ds_size = ds_size
        # Built once here rather than once per item.
        self._to_tensor = tf.ToTensor()

    def __len__(self):
        """Number of images found by the underlying ``ImageFolder``."""
        return len(self.imagefolder)

    def __getitem__(self,idx):
        """Return ``(LR_tensor, HR_tensor)`` for the image at ``idx``.

        The crop position is random, so repeated reads of the same index
        yield different patches. The folder's class label is discarded.
        """
        HR_img,_ = self.imagefolder[idx]
        # Resize reads self.ds_size on every call, as the original code did.
        LR_img = tf.Resize(size=self.ds_size)(HR_img)
        return (self._to_tensor(LR_img), self._to_tensor(HR_img))
    

In [36]:
import argparse
import math, random
import os, time

In [5]:
# Training settings.
# The help texts for --step and --threads now state the defaults actually
# configured here (they previously advertised 500 and 1 respectively).
parser = argparse.ArgumentParser(description="PyTorch SRResNet")
parser.add_argument("--batchSize", type=int, default=16, help="training batch size")
parser.add_argument("--nEpochs", type=int, default=500, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4")
parser.add_argument("--step", type=int, default=200, help="Decays the learning rate by a factor of 10 every n epochs, Default: n=200")
parser.add_argument("--cuda", action="store_true", help="Use cuda?")
parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
# Stored on the namespace as `start_epoch` (argparse turns the dash into an underscore).
parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--threads", type=int, default=0, help="Number of worker processes for the data loader (0 loads in the main process), Default: 0")
parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")
parser.add_argument("--vgg_loss", action="store_true", help="Use content loss?")
parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")

_StoreAction(option_strings=['--gpus'], dest='gpus', nargs=None, const=None, default='0', type=<class 'str'>, choices=None, help='gpu ids (default: 0)', metavar=None)

In [32]:
# Stand-in for a command line inside the notebook: request CUDA and leave
# every other option at its default.
args = ['--cuda']

# parse_known_args returns (namespace, leftover_args). The leftovers get a
# real name instead of `_`: assigning `_` at the top level of an IPython
# kernel stops IPython from updating its `_` / `__` / `___` last-output
# shortcuts for the rest of the session.
opt, unknown_args = parser.parse_known_args(args)
print(opt)

Namespace(batchSize=16, cuda=True, gpus='0', lr=0.0001, nEpochs=500, pretrained='', resume='', start_epoch=1, step=200, threads=0, vgg_loss=False)


In [33]:
# Training set of random HR/LR crop pairs, served in shuffled mini-batches.
dataset = SRDataset('./data/DIV2K_train_HR/')
training_data_loader = DataLoader(
    dataset=dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=opt.threads,
)

In [45]:
def adjust_learning_rate(optimizer, epoch):
    """Return the step-decayed learning rate for ``epoch``.

    The base rate ``opt.lr`` is multiplied by 0.1 once for every ``opt.step``
    whole epochs contained in ``epoch``. ``optimizer`` is accepted but neither
    read nor modified here; applying the returned rate is left to the caller.
    """
    completed_decays = epoch // opt.step
    decay_factor = 0.1 ** completed_decays
    return opt.lr * decay_factor

def train(training_data_loader, optimizer, model, criterion, epoch):
    """Run one training epoch over ``training_data_loader``.

    Reads the module-level ``opt`` namespace (``cuda``, ``vgg_loss``) and, when
    ``opt.vgg_loss`` is set, the module-level ``netContent`` network.
    NOTE(review): ``netContent`` is not assigned in the cells shown here, so
    it must be defined before enabling ``--vgg_loss``.

    Args:
        training_data_loader: Sized iterable yielding ``(input, target)`` batches.
        optimizer: Optimizer whose parameter groups receive this epoch's
            learning rate before the first batch is processed.
        model: Network being trained; switched to training mode.
        criterion: Loss callable applied to ``(output, target)`` and, with the
            content loss enabled, to the ``netContent`` outputs as well.
        epoch: Epoch number used for logging; ``epoch - 1`` is what gets
            passed to the learning-rate schedule.
    """
    lr = adjust_learning_rate(optimizer, epoch-1)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    print("Epoch={}, lr={}".format(epoch, optimizer.param_groups[0]["lr"]))
    model.train()
    # Start of the current 5-iteration logging window.
    iter_start = time.time()
    for iteration, batch in enumerate(training_data_loader, 1):
        lr_batch, hr_batch = batch

        if opt.cuda:
            lr_batch = lr_batch.cuda()
            hr_batch = hr_batch.cuda()

        output = model(lr_batch)
        loss = criterion(output, hr_batch)

        if opt.vgg_loss:
            content_input = netContent(output)
            # Detached so no gradient flows through the target branch.
            content_target = netContent(hr_batch).detach()
            content_loss = criterion(content_input, content_target)

        optimizer.zero_grad()

        if opt.vgg_loss:
            netContent.zero_grad()
            # The graph through `model` is shared with `loss`, which is
            # back-propagated right after, so it must be retained here.
            content_loss.backward(retain_graph=True)

        loss.backward()

        optimizer.step()

        if iteration%5 == 0:
            if opt.vgg_loss:
                # .item() rather than .data[0]: the losses are 0-dim tensors,
                # and indexing a 0-dim tensor raises IndexError.
                print("===> Epoch[{}]({}/{}): Loss: {:.5} Content_loss {:.5}".format(epoch, iteration, len(training_data_loader), loss.item(), content_loss.item()))
            else:
                cur_time = time.time()
                print("===> Epoch[{}]({}/{}): Loss: {:.5}, {:.5} sec".format(epoch, iteration, \
                                                        len(training_data_loader), loss.item(),cur_time-iter_start))
                iter_start = cur_time
def save_checkpoint(model, epoch, checkpoint_dir="checkpoint/"):
    """Write ``model`` and ``epoch`` to ``<checkpoint_dir>model_epoch_<epoch>.pth``.

    The whole model object (not only its ``state_dict``) is stored under the
    ``"model"`` key; the resume code in this notebook relies on that layout
    when it calls ``checkpoint["model"].state_dict()``.

    Args:
        model: Object to store; handed to ``torch.save`` unchanged.
        epoch: Epoch number, stored under ``"epoch"`` and used in the file name.
        checkpoint_dir: Directory that receives the file. It is created,
            including missing parents, when it does not exist yet.
    """
    model_out_path = os.path.join(checkpoint_dir, "model_epoch_{}.pth".format(epoch))
    state = {"epoch": epoch ,"model": model}
    # exist_ok avoids the race between a separate exists() check and creation.
    os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save(state, model_out_path)

    print("Checkpoint saved to {}".format(model_out_path))

In [None]:
# Training driver: select the device and seed the RNGs, build the model and
# loss, optionally resume from a checkpoint, then train and save a checkpoint
# after every epoch. All names below are ordinary notebook globals, so no
# `global` declaration is needed at this level.

cuda = opt.cuda
if cuda:
    print("=> use gpu id: '{}'".format(opt.gpus))
    # NOTE(review): presumably only effective if CUDA has not been
    # initialised earlier in this kernel -- confirm when changing --gpus.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    if not torch.cuda.is_available():
        raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

# A fresh seed is drawn on every run and printed so the run can be repeated.
opt.seed = random.randint(1, 10000)
print("Random Seed: ", opt.seed)
torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)

print("===> Building model")
model = _NetG()
# Summed (not averaged) squared error. `reduction="sum"` is the supported
# spelling of the deprecated `size_average=False`.
criterion = nn.MSELoss(reduction="sum")

print("===> Setting GPU")
if cuda:
    model = model.cuda()
    criterion = criterion.cuda()

# optionally resume from a checkpoint
if opt.resume:
    if os.path.isfile(opt.resume):
        print("=> loading checkpoint '{}'".format(opt.resume))
        # SECURITY: torch.load unpickles arbitrary objects -- only resume from
        # checkpoint files you trust.
        # NOTE(review): checkpoints written by save_checkpoint hold the whole
        # model object; recent PyTorch releases may refuse to unpickle that by
        # default -- confirm against the installed version.
        checkpoint = torch.load(opt.resume)
        opt.start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["model"].state_dict())
    else:
        # Training continues from scratch when the file is missing.
        print("=> no checkpoint found at '{}'".format(opt.resume))

# NOTE: loading of --pretrained weights is disabled below, so that option is
# parsed but currently has no effect.
# # optionally copy weights from a checkpoint
# if opt.pretrained:
#     if os.path.isfile(opt.pretrained):
#         print("=> loading model '{}'".format(opt.pretrained))
#         weights = torch.load(opt.pretrained)
#         model.load_state_dict(weights['model'].state_dict())
#     else:
#         print("=> no model found at '{}'".format(opt.pretrained))

print("===> Setting Optimizer")
# Only model weights are checkpointed; the optimizer state starts fresh here
# even when resuming.
optimizer = optim.Adam(model.parameters(), lr=opt.lr)

print("===> Training")
for epoch in range(opt.start_epoch, opt.nEpochs + 1):
    train(training_data_loader, optimizer, model, criterion, epoch)
    save_checkpoint(model, epoch)

=> use gpu id: '0'
Random Seed:  6683
===> Building model
===> Setting GPU
===> Setting Optimizer
===> Training
Epoch=1, lr=0.0001




===> Epoch[1](5/50): Loss: 6.9615e+06, 38.578 sec
===> Epoch[1](10/50): Loss: 4.4263e+06, 38.078 sec
===> Epoch[1](15/50): Loss: 2.6299e+06, 40.091 sec
===> Epoch[1](20/50): Loss: 1.7302e+06, 39.888 sec
===> Epoch[1](25/50): Loss: 1.2659e+06, 42.608 sec
===> Epoch[1](30/50): Loss: 1.0826e+06, 39.608 sec
===> Epoch[1](35/50): Loss: 8.4497e+05, 42.51 sec
===> Epoch[1](40/50): Loss: 7.2048e+05, 43.045 sec
===> Epoch[1](45/50): Loss: 5.9632e+05, 42.384 sec
===> Epoch[1](50/50): Loss: 4.9184e+05, 42.812 sec
Checkpoint saved to checkpoint/model_epoch_1.pth
Epoch=2, lr=0.0001
===> Epoch[2](5/50): Loss: 4.8988e+05, 38.561 sec
===> Epoch[2](10/50): Loss: 3.9906e+05, 36.034 sec
===> Epoch[2](15/50): Loss: 4.3013e+05, 38.633 sec
===> Epoch[2](20/50): Loss: 4.5071e+05, 41.193 sec
===> Epoch[2](25/50): Loss: 3.8299e+05, 49.189 sec
===> Epoch[2](30/50): Loss: 4.0605e+05, 43.073 sec
===> Epoch[2](35/50): Loss: 3.0479e+05, 45.39 sec
===> Epoch[2](40/50): Loss: 2.9909e+05, 41.06 sec
===> Epoch[2](45/50

===> Epoch[15](15/50): Loss: 1.2422e+05, 42.727 sec
===> Epoch[15](20/50): Loss: 8.7748e+04, 45.023 sec
===> Epoch[15](25/50): Loss: 1.3539e+05, 42.092 sec
===> Epoch[15](30/50): Loss: 7.1163e+04, 45.776 sec
===> Epoch[15](35/50): Loss: 5.8893e+04, 43.98 sec
===> Epoch[15](40/50): Loss: 8.3425e+04, 44.402 sec
===> Epoch[15](45/50): Loss: 1.1226e+05, 42.16 sec
===> Epoch[15](50/50): Loss: 7.5799e+04, 41.968 sec
Checkpoint saved to checkpoint/model_epoch_15.pth
Epoch=16, lr=0.0001
===> Epoch[16](5/50): Loss: 8.2504e+04, 42.77 sec
===> Epoch[16](10/50): Loss: 6.4892e+04, 37.698 sec
===> Epoch[16](15/50): Loss: 9.6299e+04, 44.776 sec
===> Epoch[16](20/50): Loss: 6.5285e+04, 48.17 sec
===> Epoch[16](25/50): Loss: 6.4993e+04, 46.671 sec
===> Epoch[16](30/50): Loss: 1.0002e+05, 42.818 sec
===> Epoch[16](35/50): Loss: 5.416e+04, 44.489 sec
===> Epoch[16](40/50): Loss: 8.534e+04, 44.939 sec
===> Epoch[16](45/50): Loss: 5.4126e+04, 47.058 sec
===> Epoch[16](50/50): Loss: 8.6746e+04, 43.107 sec
C

===> Epoch[29](15/50): Loss: 6.8529e+04, 46.001 sec
===> Epoch[29](20/50): Loss: 6.1478e+04, 45.689 sec
===> Epoch[29](25/50): Loss: 7.7218e+04, 45.243 sec
===> Epoch[29](30/50): Loss: 6.568e+04, 47.982 sec
===> Epoch[29](35/50): Loss: 7.1655e+04, 47.535 sec
===> Epoch[29](40/50): Loss: 5.789e+04, 47.197 sec
===> Epoch[29](45/50): Loss: 7.089e+04, 45.84 sec
===> Epoch[29](50/50): Loss: 7.3177e+04, 47.315 sec
Checkpoint saved to checkpoint/model_epoch_29.pth
Epoch=30, lr=0.0001
===> Epoch[30](5/50): Loss: 7.5143e+04, 46.242 sec
===> Epoch[30](10/50): Loss: 6.2938e+04, 38.35 sec
===> Epoch[30](15/50): Loss: 9.024e+04, 42.73 sec
===> Epoch[30](20/50): Loss: 8.8133e+04, 45.78 sec
===> Epoch[30](25/50): Loss: 6.4394e+04, 46.007 sec
===> Epoch[30](30/50): Loss: 7.7392e+04, 45.902 sec
===> Epoch[30](35/50): Loss: 5.9338e+04, 40.879 sec
===> Epoch[30](40/50): Loss: 8.0174e+04, 45.205 sec
===> Epoch[30](45/50): Loss: 8.6627e+04, 43.97 sec
===> Epoch[30](50/50): Loss: 5.711e+04, 44.618 sec
Check

===> Epoch[43](15/50): Loss: 5.8968e+04, 41.812 sec
===> Epoch[43](20/50): Loss: 7.0219e+04, 47.077 sec
===> Epoch[43](25/50): Loss: 5.7639e+04, 43.622 sec
===> Epoch[43](30/50): Loss: 7.0916e+04, 48.237 sec
===> Epoch[43](35/50): Loss: 5.0032e+04, 41.334 sec
===> Epoch[43](40/50): Loss: 3.9835e+04, 45.541 sec
===> Epoch[43](45/50): Loss: 4.1159e+04, 44.885 sec
===> Epoch[43](50/50): Loss: 4.6625e+04, 47.077 sec
Checkpoint saved to checkpoint/model_epoch_43.pth
Epoch=44, lr=0.0001
===> Epoch[44](5/50): Loss: 4.952e+04, 35.116 sec
===> Epoch[44](10/50): Loss: 4.9444e+04, 40.5 sec
===> Epoch[44](15/50): Loss: 7.9869e+04, 43.99 sec
===> Epoch[44](20/50): Loss: 5.2326e+04, 44.191 sec
===> Epoch[44](25/50): Loss: 6.1542e+04, 45.505 sec
===> Epoch[44](30/50): Loss: 4.8606e+04, 48.862 sec
===> Epoch[44](35/50): Loss: 3.6115e+04, 42.437 sec
===> Epoch[44](40/50): Loss: 5.5455e+04, 44.119 sec
===> Epoch[44](45/50): Loss: 5.5143e+04, 49.154 sec
===> Epoch[44](50/50): Loss: 5.0833e+04, 48.784 sec

===> Epoch[57](15/50): Loss: 4.4363e+04, 46.972 sec
===> Epoch[57](20/50): Loss: 6.6825e+04, 54.679 sec
===> Epoch[57](25/50): Loss: 3.5238e+04, 41.916 sec
===> Epoch[57](30/50): Loss: 2.3898e+04, 45.521 sec
===> Epoch[57](35/50): Loss: 6.1028e+04, 46.717 sec
===> Epoch[57](40/50): Loss: 5.8975e+04, 41.06 sec
===> Epoch[57](45/50): Loss: 6.789e+04, 47.76 sec
===> Epoch[57](50/50): Loss: 5.7815e+04, 46.188 sec
Checkpoint saved to checkpoint/model_epoch_57.pth
Epoch=58, lr=0.0001
===> Epoch[58](5/50): Loss: 5.0534e+04, 39.858 sec
===> Epoch[58](10/50): Loss: 6.5478e+04, 42.867 sec
===> Epoch[58](15/50): Loss: 4.7609e+04, 46.605 sec
===> Epoch[58](20/50): Loss: 4.9202e+04, 47.171 sec
===> Epoch[58](25/50): Loss: 3.1351e+04, 49.526 sec
===> Epoch[58](30/50): Loss: 4.5442e+04, 47.175 sec
===> Epoch[58](35/50): Loss: 4.0005e+04, 48.735 sec
===> Epoch[58](40/50): Loss: 2.9117e+04, 49.519 sec
===> Epoch[58](45/50): Loss: 5.0135e+04, 47.697 sec
===> Epoch[58](50/50): Loss: 6.6482e+04, 44.775 se

===> Epoch[71](15/50): Loss: 3.4317e+04, 42.889 sec
===> Epoch[71](20/50): Loss: 6.4946e+04, 45.455 sec
===> Epoch[71](25/50): Loss: 3.6394e+04, 50.53 sec
===> Epoch[71](30/50): Loss: 6.8225e+04, 46.146 sec
===> Epoch[71](35/50): Loss: 6.162e+04, 48.989 sec
===> Epoch[71](40/50): Loss: 7.6342e+04, 48.541 sec
===> Epoch[71](45/50): Loss: 5.0728e+04, 48.704 sec
===> Epoch[71](50/50): Loss: 3.8961e+04, 44.03 sec
Checkpoint saved to checkpoint/model_epoch_71.pth
Epoch=72, lr=0.0001
===> Epoch[72](5/50): Loss: 6.3833e+04, 37.507 sec
===> Epoch[72](10/50): Loss: 4.0216e+04, 41.477 sec
===> Epoch[72](15/50): Loss: 5.4614e+04, 48.768 sec
===> Epoch[72](20/50): Loss: 5.4236e+04, 46.426 sec
===> Epoch[72](25/50): Loss: 4.525e+04, 46.992 sec
===> Epoch[72](30/50): Loss: 3.0835e+04, 51.896 sec
===> Epoch[72](35/50): Loss: 3.4769e+04, 49.619 sec
===> Epoch[72](40/50): Loss: 4.6309e+04, 44.398 sec
===> Epoch[72](45/50): Loss: 5.2316e+04, 47.647 sec
===> Epoch[72](50/50): Loss: 4.4204e+04, 48.152 sec

===> Epoch[85](15/50): Loss: 6.4493e+04, 43.841 sec
===> Epoch[85](20/50): Loss: 5.5414e+04, 47.594 sec
===> Epoch[85](25/50): Loss: 4.9092e+04, 48.103 sec
===> Epoch[85](30/50): Loss: 6.8237e+04, 46.401 sec
===> Epoch[85](35/50): Loss: 4.4536e+04, 46.114 sec
===> Epoch[85](40/50): Loss: 6.2341e+04, 45.577 sec
===> Epoch[85](45/50): Loss: 6.0822e+04, 45.579 sec
===> Epoch[85](50/50): Loss: 5.8822e+04, 49.963 sec
Checkpoint saved to checkpoint/model_epoch_85.pth
Epoch=86, lr=0.0001
===> Epoch[86](5/50): Loss: 3.8827e+04, 35.399 sec
===> Epoch[86](10/50): Loss: 4.3035e+04, 42.988 sec
===> Epoch[86](15/50): Loss: 5.9283e+04, 46.551 sec
===> Epoch[86](20/50): Loss: 4.6279e+04, 52.206 sec
===> Epoch[86](25/50): Loss: 3.801e+04, 48.819 sec
===> Epoch[86](30/50): Loss: 5.0917e+04, 50.284 sec
===> Epoch[86](35/50): Loss: 7.1288e+04, 44.022 sec
===> Epoch[86](40/50): Loss: 4.0935e+04, 44.798 sec
===> Epoch[86](45/50): Loss: 4.6332e+04, 47.219 sec
===> Epoch[86](50/50): Loss: 5.9775e+04, 54.874 

===> Epoch[99](15/50): Loss: 3.7214e+04, 47.943 sec
===> Epoch[99](20/50): Loss: 3.9526e+04, 49.612 sec
===> Epoch[99](25/50): Loss: 8.26e+04, 54.73 sec
===> Epoch[99](30/50): Loss: 6.5459e+04, 46.642 sec
===> Epoch[99](35/50): Loss: 2.2757e+04, 47.216 sec
===> Epoch[99](40/50): Loss: 6.249e+04, 47.83 sec
===> Epoch[99](45/50): Loss: 3.6848e+04, 43.076 sec
===> Epoch[99](50/50): Loss: 4.8021e+04, 46.133 sec
Checkpoint saved to checkpoint/model_epoch_99.pth
Epoch=100, lr=0.0001
===> Epoch[100](5/50): Loss: 3.5715e+04, 34.353 sec
===> Epoch[100](10/50): Loss: 5.0779e+04, 41.28 sec
===> Epoch[100](15/50): Loss: 4.2053e+04, 45.117 sec
===> Epoch[100](20/50): Loss: 4.4528e+04, 49.96 sec
===> Epoch[100](25/50): Loss: 3.5751e+04, 47.866 sec
===> Epoch[100](30/50): Loss: 5.0278e+04, 49.596 sec
===> Epoch[100](35/50): Loss: 3.6322e+04, 49.351 sec
===> Epoch[100](40/50): Loss: 6.871e+04, 48.844 sec
===> Epoch[100](45/50): Loss: 4.752e+04, 47.681 sec
===> Epoch[100](50/50): Loss: 3.8067e+04, 53.5

===> Epoch[113](5/50): Loss: 3.9653e+04, 38.311 sec
===> Epoch[113](10/50): Loss: 2.958e+04, 43.347 sec
===> Epoch[113](15/50): Loss: 3.6289e+04, 48.591 sec
===> Epoch[113](20/50): Loss: 6.0201e+04, 55.522 sec
===> Epoch[113](25/50): Loss: 5.7952e+04, 47.702 sec
===> Epoch[113](30/50): Loss: 3.5365e+04, 47.233 sec
===> Epoch[113](35/50): Loss: 4.7984e+04, 47.958 sec
===> Epoch[113](40/50): Loss: 4.2773e+04, 46.818 sec
===> Epoch[113](45/50): Loss: 9.9195e+04, 48.867 sec
===> Epoch[113](50/50): Loss: 6.7708e+04, 44.416 sec
Checkpoint saved to checkpoint/model_epoch_113.pth
Epoch=114, lr=0.0001
===> Epoch[114](5/50): Loss: 3.9884e+04, 40.794 sec
===> Epoch[114](10/50): Loss: 5.4499e+04, 44.349 sec
===> Epoch[114](15/50): Loss: 4.1009e+04, 44.943 sec
===> Epoch[114](20/50): Loss: 4.7241e+04, 46.785 sec
===> Epoch[114](25/50): Loss: 3.627e+04, 45.647 sec
===> Epoch[114](30/50): Loss: 3.0269e+04, 47.844 sec
===> Epoch[114](35/50): Loss: 5.3099e+04, 48.174 sec
===> Epoch[114](40/50): Loss: 3

===> Epoch[126](45/50): Loss: 3.9191e+04, 40.664 sec
===> Epoch[126](50/50): Loss: 6.7072e+04, 46.371 sec
Checkpoint saved to checkpoint/model_epoch_126.pth
Epoch=127, lr=0.0001
===> Epoch[127](5/50): Loss: 6.1487e+04, 35.653 sec
===> Epoch[127](10/50): Loss: 5.3008e+04, 40.141 sec
===> Epoch[127](15/50): Loss: 7.0678e+04, 42.229 sec
===> Epoch[127](20/50): Loss: 5.0942e+04, 46.841 sec
===> Epoch[127](25/50): Loss: 3.7428e+04, 45.177 sec
===> Epoch[127](30/50): Loss: 3.6417e+04, 43.917 sec
===> Epoch[127](35/50): Loss: 4.7144e+04, 45.327 sec
===> Epoch[127](40/50): Loss: 5.2675e+04, 45.762 sec
===> Epoch[127](45/50): Loss: 4.6096e+04, 49.49 sec
===> Epoch[127](50/50): Loss: 3.4254e+04, 48.112 sec
Checkpoint saved to checkpoint/model_epoch_127.pth
Epoch=128, lr=0.0001
===> Epoch[128](5/50): Loss: 6.0155e+04, 36.839 sec
===> Epoch[128](10/50): Loss: 2.7917e+04, 41.829 sec
===> Epoch[128](15/50): Loss: 4.5499e+04, 42.652 sec
===> Epoch[128](20/50): Loss: 3.7595e+04, 48.914 sec
===> Epoch[