In [1]:
import os
import easydict
import time

import numpy as np
import cv2
import math

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim

from torch.utils import data as D
from torch.utils.data import DataLoader

from tqdm import tqdm
import glob

In [2]:
# Training / data configuration.
# NOTE: run_train() overwrites "use_cuda" and "device" at runtime depending on
# torch.cuda.is_available(), so the values below are only the requested defaults.
opt = easydict.EasyDict({
    "multi_gpu": True,  # wrap the model in nn.DataParallel when >1 GPU is visible
    "use_cuda": True,
    "device": 'cpu',

    "n_epochs": 5000,
    # Single entry: this key used to be listed twice (100 and 128); 128 was the
    # value actually in effect, so it is kept here at the original position.
    "batch_size": 128,
    "start_epoch": 1,
    "lr": 1e-4, # Adam: learning rate
    "b1": 0.9, # Adam: decay of first order momentum of gradient
    "b2": 0.999, # Adam: decay of second order momentum of gradient

    "resume": True,        # resume from the newest (or best) checkpoint in checkpoint_dir
    "resume_best": False,  # when resuming, pick the lowest-loss checkpoint instead of the newest
    # Directory name (sic, "checkpoin") kept as-is so existing checkpoints are still found.
    "checkpoint_dir": './checkpoin_dir',
    "train_dir": '../../data/super-resolution/div2k_100/train_patches_x2lr64',
    "valid_dir": '../../data/super-resolution/div2k_100/valid_patches_x2lr64',
    "test_dir": '../../data/super-resolution/div2k_100/test_images',

    "lr_img": 64,      # presumably the low-resolution patch side length (cf. "lr64" in the paths)
    "res_scale": 2,    # upscaling factor (HR size = lr_img * res_scale)
    "n_channels": 3
})

print(opt)

{'multi_gpu': True, 'use_cuda': True, 'device': 'cpu', 'n_epochs': 5000, 'batch_size': 128, 'start_epoch': 1, 'lr': 0.0001, 'b1': 0.9, 'b2': 0.999, 'resume': True, 'resume_best': False, 'checkpoint_dir': './checkpoin_dir', 'train_dir': '../../data/super-resolution/div2k_100/train_patches_x2lr64', 'valid_dir': '../../data/super-resolution/div2k_100/valid_patches_x2lr64', 'test_dir': '../../data/super-resolution/div2k_100/test_images', 'lr_img': 64, 'res_scale': 2, 'n_channels': 3}


In [3]:
def normalize_img(img):
    """Convert an image array to float32 scaled into the range [0, 1].

    Pixel values are divided by 255 (the maximum 8-bit value), so a white pixel
    maps to exactly 1.0. That matches the peak value assumed by the
    ``10 * log10(1 / mse)`` PSNR used in training and validation; dividing by
    256 capped the range at 255/256.

    Args:
        img: ``np.ndarray`` of pixel values, presumably uint8 as returned by
            ``cv2.imread``.

    Returns:
        ``np.ndarray`` of dtype float32 with the same shape as ``img``.
    """
    # Cast first so the division happens directly in float32.
    return img.astype(np.float32) / np.float32(255.0)

In [4]:
class DatasetFromFolder(data.Dataset):
    """Paired low-/high-resolution images read from ``<data_dir>/lr`` and ``<data_dir>/hr``.

    Files are paired by their position in the *sorted* directory listings, so
    the i-th LR file name must sort to the same position as its HR partner
    (identical or consistently suffixed names).

    Each item is ``(input, target)``:
      * ``input``: the LR image bicubically resized to the HR image's size;
      * ``target``: the HR image;
    both as float32 CHW tensors scaled by ``normalize_img``. Channel order is
    whatever ``cv2.imread`` returns (BGR); no conversion is done here.
    """

    def __init__(self, data_dir):
        super(DatasetFromFolder, self).__init__()

        lr_dir = os.path.join(data_dir, 'lr')
        hr_dir = os.path.join(data_dir, 'hr')

        # os.listdir returns entries in arbitrary order, which can differ between
        # the two directories; sorting keeps LR/HR pairs aligned by index.
        self.dsets = {
            'lr': [os.path.join(lr_dir, x) for x in sorted(os.listdir(lr_dir))],
            'hr': [os.path.join(hr_dir, x) for x in sorted(os.listdir(hr_dir))],
        }

        if len(self.dsets['lr']) != len(self.dsets['hr']):
            raise ValueError(
                "Mismatched dataset in '{}': {} LR files vs {} HR files".format(
                    data_dir, len(self.dsets['lr']), len(self.dsets['hr'])))

    @staticmethod
    def _read_image(path):
        """Read an image with OpenCV, raising instead of returning None on failure."""
        img = cv2.imread(path)
        if img is None:
            raise IOError("Failed to read image '{}'".format(path))
        return img

    def __getitem__(self, idx):
        lr_img = self._read_image(self.dsets['lr'][idx])
        hr_img = self._read_image(self.dsets['hr'][idx])

        # SRCNN keeps spatial size (padded convs), so the LR image is upsampled
        # to the HR size first. cv2.resize takes (width, height).
        lr_img = cv2.resize(lr_img, (hr_img.shape[1], hr_img.shape[0]),
                            interpolation=cv2.INTER_CUBIC)

        # HWC -> CHW; ascontiguousarray gives torch a dense buffer.
        input = np.ascontiguousarray(np.transpose(normalize_img(lr_img), (2, 0, 1)))
        target = np.ascontiguousarray(np.transpose(normalize_img(hr_img), (2, 0, 1)))

        return torch.from_numpy(input).float(), torch.from_numpy(target).float()

    def __len__(self):
        return len(self.dsets['lr'])

In [5]:
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution that shifts (and optionally scales) each RGB channel.

    Output channel ``c`` is ``(x[c] + sign * rgb_range * rgb_mean[c]) / rgb_std[c]``.
    With the default ``sign=-1`` this subtracts the mean; ``sign=1`` adds it back,
    so with unit std the two variants invert each other.

    Args:
        rgb_range: pixel range the means are scaled by (e.g. 1.0 for [0, 1] data).
        rgb_mean: per-channel means.
        rgb_std: per-channel standard deviations.
        sign: -1 to subtract the mean, +1 to add it.
    """

    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super().__init__(3, 3, kernel_size=1)

        channel_std = torch.Tensor(rgb_std)
        channel_shift = sign * rgb_range * torch.Tensor(rgb_mean)

        with torch.no_grad():
            # Diagonal kernel: each output channel only sees its own input channel.
            self.weight.copy_(torch.eye(3).view(3, 3, 1, 1) / channel_std.view(3, 1, 1, 1))
            self.bias.copy_(channel_shift / channel_std)

        # Fixed normalization layer: never updated by the optimizer.
        for param in self.parameters():
            param.requires_grad_(False)

In [6]:
class SRCNN(nn.Module):
    """Three-layer SRCNN with fixed mean subtraction/addition around the convs.

    Layers: 9x9 conv (n_channels -> 64) + ReLU, 1x1 conv (64 -> 32) + ReLU,
    5x5 conv (32 -> n_channels). All convs are padded to keep the spatial size,
    so the input must already be at the target resolution (e.g. bicubic-upsampled).

    Args:
        opt: options object; only ``opt.n_channels`` is read.
    """

    def __init__(self, opt):
        super().__init__()
        channels = opt.n_channels

        # Attribute names and registration order define the state_dict keys that
        # saved checkpoints rely on -- keep them unchanged.
        self.sub_mean = MeanShift(1.0)
        self.add_mean = MeanShift(1.0, sign=1)

        self.conv1 = nn.Conv2d(channels, 64, kernel_size=9, padding=4)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, channels, kernel_size=5, padding=2)

    def forward(self, x):
        features = self.relu1(self.conv1(self.sub_mean(x)))  # patch extraction
        features = self.relu2(self.conv2(features))          # non-linear mapping
        return self.add_mean(self.conv3(features))           # reconstruction

In [7]:
def save_checkpoint(srcnn, epoch, loss, checkpoint_dir=None):
    """Save ``srcnn`` (the whole module object) and ``epoch`` to a ``.pth`` file.

    The file name is ``models_epoch_%04d_loss_%.6f.pth``; ``load_model`` parses
    the loss back out of it.

    Args:
        srcnn: model to save. An ``nn.DataParallel`` wrapper is unwrapped so the
            checkpoint always holds the bare network.
        epoch: epoch number stored in the checkpoint and in the file name.
        loss: scalar loss (float or single-element tensor) used in the file name.
        checkpoint_dir: output directory; defaults to ``opt.checkpoint_dir``.
    """
    if checkpoint_dir is None:
        checkpoint_dir = opt.checkpoint_dir
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(
        checkpoint_dir, "models_epoch_%04d_loss_%.6f.pth" % (epoch, float(loss)))

    # Decide on the model's actual type: inferring wrapping from device_count()
    # and opt.multi_gpu crashes (no .module) when the model was never wrapped.
    model = srcnn.module if isinstance(srcnn, nn.DataParallel) else srcnn
    state = {"epoch": epoch, "srcnn": model}

    torch.save(state, checkpoint_path)
    print("Checkpoint saved to {}".format(checkpoint_path))

In [8]:
def _find_checkpoint(checkpoint_dir, best=False):
    """Return the checkpoint file to resume from, or ``None`` if there is none.

    Only files named ``models_epoch_<epoch>_loss_<loss>.pth`` (the pattern that
    ``save_checkpoint`` writes) are considered, so unrelated ``.pth`` files cannot
    break the loss parsing. With ``best=True`` the lowest-loss file is chosen
    (earliest epoch on ties); otherwise the highest epoch.
    """
    import re

    pattern = re.compile(r"^models_epoch_(\d+)_loss_(\d+(?:\.\d+)?)\.pth$")
    candidates = []  # (epoch, loss, path)
    for path in glob.glob(os.path.join(checkpoint_dir, "*.pth")):
        match = pattern.match(os.path.basename(path))
        if match:
            candidates.append((int(match.group(1)), float(match.group(2)), path))

    if not candidates:
        return None
    if best:
        return min(candidates, key=lambda c: (c[1], c[0]))[2]
    return max(candidates, key=lambda c: c[0])[2]


def _load_checkpoint_file(checkpoint_path):
    """``torch.load`` a checkpoint that pickles a whole ``nn.Module``, mapped to CPU.

    ``map_location='cpu'`` lets GPU-saved checkpoints load on CPU-only hosts;
    only the weights are copied into a fresh model afterwards. The file stores a
    full module object, so ``weights_only=False`` is passed where supported
    (recent PyTorch defaults it to True). Unpickling can execute arbitrary code:
    only load checkpoints you trust.
    """
    import inspect

    load_kwargs = {'map_location': 'cpu'}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(checkpoint_path, **load_kwargs)


def load_model(checkpoint_dir):
    """Build an SRCNN and, if a checkpoint exists in ``checkpoint_dir``, load its weights.

    Picks the lowest-loss checkpoint when ``opt.resume_best`` is set, otherwise
    the highest-epoch one. An empty or missing directory is not an error:
    training then starts from scratch.

    Returns:
        ``(start_epoch, srcnn)`` where ``start_epoch`` is the saved epoch + 1,
        or 1 when no checkpoint was found.
    """
    checkpoint_path = _find_checkpoint(checkpoint_dir, best=opt.resume_best)
    srcnn = SRCNN(opt)

    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        checkpoint = _load_checkpoint_file(checkpoint_path)

        n_epoch = checkpoint['epoch']
        srcnn.load_state_dict(checkpoint['srcnn'].state_dict())
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(checkpoint_path, n_epoch))
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        n_epoch = 0

    return n_epoch + 1, srcnn

In [9]:
def train(opt, model, optimizer, data_loader, loss_criterion):
    """Run one training epoch and report the mean loss and mean per-batch PSNR.

    PSNR is computed per batch as ``10 * log10(1 / loss)``. This assumes
    ``loss_criterion`` returns a mean squared error on data with peak value 1.0.
    It is then averaged over batches.

    Args:
        opt: options; reads ``use_cuda``, ``device``, ``epoch_num``, ``n_epochs``.
        model: network to train; it is switched to train mode here.
        optimizer: optimizer over ``model``'s parameters.
        data_loader: iterable of ``(input, target)`` batches.
        loss_criterion: loss module, e.g. ``nn.MSELoss()``.

    Returns:
        ``(mean_loss, mean_psnr)`` as Python floats; ``(0.0, 0.0)`` for an empty loader.
    """
    print("===> Training")
    start_time = time.time()
    model.train()

    total_loss = 0.0
    total_psnr = 0.0
    n_batches = 0

    for batch in tqdm(data_loader):
        x, target = batch[0], batch[1]
        if opt.use_cuda:
            x = x.to(opt.device)
            target = target.to(opt.device)

        out = model(x)

        loss = loss_criterion(out, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Use .item(): accumulating the loss tensor itself would chain every
        # iteration's autograd history for the whole epoch.
        batch_loss = loss.item()
        psnr = 10 * math.log10(1 / batch_loss) if batch_loss > 0 else float('inf')
        total_loss += batch_loss
        total_psnr += psnr
        n_batches += 1

    if n_batches:
        total_loss /= n_batches
        total_psnr /= n_batches

    print("Training %.2fs => Epoch[%d/%d]: Loss: %.5f PSNR: %.5f" %
        (time.time() - start_time, opt.epoch_num, opt.n_epochs, total_loss, total_psnr))

    return (total_loss, total_psnr)

In [10]:
def evaluate(opt, model, data_loader, loss_criterion):
    """Run ``model`` over ``data_loader`` without gradients and report mean loss/PSNR.

    PSNR is computed per batch as ``10 * log10(1 / loss)``. This assumes
    ``loss_criterion`` returns a mean squared error on data with peak value 1.0.
    It is then averaged over batches.

    The model is put in eval mode for the pass, and its previous train/eval mode
    is restored afterwards.

    Args:
        opt: options; reads ``use_cuda``, ``device``, ``epoch_num``, ``n_epochs``.
        model: network to evaluate.
        data_loader: iterable of ``(input, target)`` batches.
        loss_criterion: loss module, e.g. ``nn.MSELoss()``.

    Returns:
        ``(mean_loss, mean_psnr)`` as Python floats; ``(0.0, 0.0)`` for an empty loader.
    """
    print("===> Validation")
    start_time = time.time()

    total_loss = 0.0
    total_psnr = 0.0
    n_batches = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(data_loader):
                x, target = batch[0], batch[1]

                if opt.use_cuda:
                    x = x.to(opt.device)
                    target = target.to(opt.device)

                out = model(x)

                batch_loss = loss_criterion(out, target).item()
                psnr = 10 * math.log10(1 / batch_loss) if batch_loss > 0 else float('inf')
                total_loss += batch_loss
                total_psnr += psnr
                n_batches += 1
    finally:
        model.train(was_training)

    if n_batches:
        total_loss /= n_batches
        total_psnr /= n_batches

    print("Validation %.2fs => Epoch[%d/%d]: Loss: %.5f PSNR: %.5f" %
        (time.time() - start_time, opt.epoch_num, opt.n_epochs, total_loss, total_psnr))

    return (total_loss, total_psnr)

In [11]:
def run_train(opt, training_data_loader, validation_data_loader):
    """Build (or resume) SRCNN and train it, validating and checkpointing every epoch.

    Epochs ``opt.start_epoch`` .. ``opt.n_epochs`` are run inclusive, so the last
    epoch matches the ``Epoch[n/n_epochs]`` progress output. One CSV row per
    epoch is appended to ``<checkpoint_dir>/srcnn_log.csv`` with the columns
    ``epoch,train_loss,train_psnr,valid_loss,valid_psnr``.

    Side effects: sets ``opt.start_epoch`` (when resuming), ``opt.use_cuda``,
    ``opt.device`` and ``opt.epoch_num``.
    """
    os.makedirs(opt.checkpoint_dir, exist_ok=True)

    log_file = os.path.join(opt.checkpoint_dir, "srcnn_log.csv")

    print('[Initialize networks for training]')
    srcnn = SRCNN(opt)
    l2_criterion = nn.MSELoss()
    print(srcnn)

    # optionally resume from a checkpoint
    if opt.resume:
        opt.start_epoch, srcnn = load_model(opt.checkpoint_dir)

    # The header must name all five columns written per epoch below. It is also
    # written when resuming without an existing log, so the CSV always has one.
    if not opt.resume or not os.path.isfile(log_file):
        with open(log_file, mode='w') as f:
            f.write("epoch,train_loss,train_psnr,valid_loss,valid_psnr\n")

    # Define what device we are using
    print("===> Setting GPU")
    print("CUDA Available: ", torch.cuda.is_available())
    if opt.use_cuda and torch.cuda.is_available():
        opt.use_cuda = True
        opt.device = 'cuda'
    else:
        opt.use_cuda = False
        opt.device = 'cpu'

    if torch.cuda.device_count() > 1 and opt.multi_gpu:
        print("Use " + str(torch.cuda.device_count()) + " GPUs")
        srcnn = nn.DataParallel(srcnn)

    if opt.use_cuda:
        srcnn = srcnn.to(opt.device)
        l2_criterion = l2_criterion.to(opt.device)

    print("===> Setting Optimizer")
    # NOTE(review): optimizer state is not checkpointed, so Adam moments restart on resume.
    optimizer = torch.optim.Adam(srcnn.parameters(),  lr=opt.lr, betas=(opt.b1, opt.b2))

    for epoch in range(opt.start_epoch, opt.n_epochs + 1):
        opt.epoch_num = epoch
        train_loss, train_psnr = train(opt, srcnn, optimizer, training_data_loader, loss_criterion=l2_criterion)
        valid_loss, valid_psnr = evaluate(opt, srcnn, validation_data_loader, loss_criterion=l2_criterion)

        with open(log_file, mode='a') as f:
            f.write("%d,%08f,%08f,%08f,%08f\n" % (
                epoch,
                train_loss,
                train_psnr,
                valid_loss,
                valid_psnr
            ))
        save_checkpoint(srcnn, epoch, valid_loss)

In [None]:
# Build the train/validation loaders from the configured directories and train.
train_dir, valid_dir = opt.train_dir, opt.valid_dir
print(f"train_dir is: {train_dir}")
print(f"valid_dir is: {valid_dir}")

# HR patch size implied by the config (not used by the visible code below).
target_size = (opt.lr_img * opt.res_scale,) * 2

train_dataset, valid_dataset = DatasetFromFolder(train_dir), DatasetFromFolder(valid_dir)

# Shuffle only the training data; validation order stays fixed across epochs.
training_data_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True)
validation_data_loader = DataLoader(valid_dataset, batch_size=opt.batch_size, shuffle=False)

run_train(opt, training_data_loader, validation_data_loader)

train_dir is: ../../data/super-resolution/div2k_100/train_patches_x2lr64
valid_dir is: ../../data/super-resolution/div2k_100/valid_patches_x2lr64
[Initialize networks for training]
SRCNN(
  (sub_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (add_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (relu1): ReLU()
  (conv2): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (relu2): ReLU()
  (conv3): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
)
=> loading checkpoint './checkpoin_dir\models_epoch_0023_loss_0.000916.pth'
=> loaded checkpoint './checkpoin_dir\models_epoch_0023_loss_0.000916.pth' (epoch 23)
===> Setting GPU
CUDA Available:  True
===> Setting Optimizer
===> Training


  0%|                                                                                          | 0/120 [00:00<?, ?it/s]

In [None]:
def test_model(opt, test_data_loader):
    """Evaluate a saved SRCNN checkpoint on ``test_data_loader`` and write result images.

    Loads the model via ``load_model(opt.checkpoint_dir)``. It prints per-batch
    and running-average PSNR for the (bicubic-upsampled) input and for the SR
    output, both as ``10 * log10(1 / mse)``, i.e. assuming a peak pixel value of
    1.0. For each sample it saves a side-by-side comparison
    (input | SR | target) under ``<opt.result_img_dir>/compare`` and the SR
    output alone under ``<opt.result_img_dir>/sr``.

    NOTE(review): as visible in this notebook the function cannot run yet:
      * the ``opt`` config cell defines no ``result_img_dir``, ``lr_img_y``,
        ``lr_img_x`` or ``channels`` keys;
      * ``Variable`` and ``imsave`` are not imported in the import cell;
      * it reads file names from ``batch[2]``, but
        ``DatasetFromFolder.__getitem__`` returns only ``(input, target)``.
    """
    sr_compare_dir = os.path.join(opt.result_img_dir, "compare")
    sr_result_dir = os.path.join(opt.result_img_dir, "sr")

    if not os.path.exists(sr_result_dir):
        os.makedirs(sr_result_dir)
    if not os.path.exists(sr_compare_dir):
        os.makedirs(sr_compare_dir)

    opt.start_epoch, srcnn = load_model(opt.checkpoint_dir)
    criterion = nn.MSELoss()

    if torch.cuda.device_count() > 1 and opt.multi_gpu:
        print("Use " + str(torch.cuda.device_count()) + " GPUs")
        srcnn = nn.DataParallel(srcnn)

    # NOTE(review): the model is moved to opt.device, but inputs below use .cuda().
    # opt.device is only switched to 'cuda' inside run_train; if this runs without
    # run_train first, the model stays on CPU while the inputs go to the GPU.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        srcnn = srcnn.to(opt.device)
        criterion = criterion.to(opt.device)

    # NOTE(review): result_img allocated here is reassigned inside the loop
    # below, so this preallocated buffer is never used.
    hr_img_sz = (opt.lr_img_y * opt.res_scale, opt.lr_img_x * opt.res_scale)
    result_img = np.zeros((hr_img_sz[0], hr_img_sz[1] * 3, opt.channels))

    # NOTE(review): srcnn.eval() is not called; the current SRCNN has no
    # dropout/batch-norm layers, so this does not change its outputs.
    with torch.no_grad():
        total_num = 0
        sum_bicubic_psnr = 0.0
        sum_sr_psnr = 0.0
        avg_bicubic_psnr = 0.0
        avg_sr_psnr = 0.0

        start_time = time.time()
        for iteration, batch in enumerate(test_data_loader, 1):
            # Variable is the legacy autograd wrapper; on current PyTorch it just returns the tensor.
            input, target, file_names = Variable(batch[0]), Variable(batch[1], requires_grad=False), batch[2]

            if use_cuda:
                input = input.cuda()
                target = target.cuda()

            out = srcnn(input)

            # Baseline: the network input itself (bicubic upsampled) vs. the target.
            bicubic_loss = criterion(input, target)
            sr_loss = criterion(out, target)

            # NOTE(review): raises ZeroDivisionError if a loss is exactly 0.
            bicubic_psnr =  10 * math.log10(1 / bicubic_loss.item())
            sr_psnr =  10 * math.log10(1 / sr_loss.item())

            sum_bicubic_psnr += bicubic_psnr
            sum_sr_psnr += sr_psnr

            avg_bicubic_psnr = sum_bicubic_psnr / iteration
            avg_sr_psnr = sum_sr_psnr / iteration

            print("Time: {:.2f} ===> {}/{}: bicubic_psnr: {}, sr_psnr: {}".format(time.time() - start_time, iteration, len(test_data_loader), bicubic_psnr, sr_psnr))
            print("avg_bicubic_nspr: {}".format(avg_bicubic_psnr))
            print("avg_sr_nspr: {}".format(avg_sr_psnr))

            # NOTE(review): if the loader yields CHW tensors (as DatasetFromFolder
            # does), axis=2 below concatenates along width and the saved arrays are
            # CHW floats; image writers typically expect HWC (and often uint8),
            # so a transpose/scaling is probably needed -- confirm with the chosen imsave.
            for i in range(input.size()[0]):
                input_arr = (input[i].data).cpu().numpy()
                hr_sr_arr = (out[i].data).cpu().numpy()
                hr_target_arr = (target[i]).cpu().data.numpy()

                result_img = np.concatenate((input_arr, hr_sr_arr, hr_target_arr), axis=2)
                # result_img[:, 0:hr_img_sz, :] = input_cv2
                # result_img[:, hr_img_sz:2*hr_img_sz, :] = hr_sr
                # result_img[:, 2*hr_img_sz:3*hr_img_sz, :] = hr_target_arr

                # result_img.save(os.path.join(opt.result_img_dir, "%09d.png" % (total_num)))
                # print(file_names[i])
                imsave(os.path.join(sr_compare_dir, file_names[i]), result_img)
                imsave(os.path.join(sr_result_dir, file_names[i]), hr_sr_arr)
                total_num += 1
                # cv2.imshow('result', result_img)
                # cv2.waitKey()