In [2]:
import os
import easydict
import time

import numpy as np
import cv2
import math

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim

from torch.utils import data as D
from torch.utils.data import DataLoader

from tqdm import tqdm
import glob

In [75]:
opt = easydict.EasyDict({
    "resume": False,
    "resume_best": False,
    
    "multi_gpu": True,
    "use_cuda": True,
    "device": 'cpu',
    
    "n_epochs": 5000,
    "batch_size": 100,
    "start_epoch": 1,
    "batch_size": 128,
    "lr": 1e-4, # Adam: learning rate
    "b1": 0.9, # Adam: decay of first order momentum of gradient
    "b2": 0.999, # Adam: decay of second order momentum of gradient
    
    "checkpoint_dir": './checkpoin_dir',
    "train_dir": '../../data/super-resolution/div2k_100/train_patches_x2lr64',
    "valid_dir": '../../data/super-resolution/div2k_100/valid_patches_x2lr64',
    "test_dir": '../../data/super-resolution/div2k_100/test_images',
    "test_result_dir": '../../data/super-resolution/div2k_100/test_result',
    
    "lr_img": 64,
    "res_scale": 2,
    "n_channels": 3
})

print(opt)

{'resume': False, 'resume_best': False, 'multi_gpu': True, 'use_cuda': True, 'device': 'cpu', 'n_epochs': 5000, 'batch_size': 128, 'start_epoch': 1, 'lr': 0.0001, 'b1': 0.9, 'b2': 0.999, 'checkpoint_dir': './checkpoin_dir', 'train_dir': '../../data/super-resolution/div2k_100/train_patches_x2lr64', 'valid_dir': '../../data/super-resolution/div2k_100/valid_patches_x2lr64', 'test_dir': '../../data/super-resolution/div2k_100/test_images', 'test_result_dir': '../../data/super-resolution/div2k_100/test_result', 'lr_img': 64, 'res_scale': 2, 'n_channels': 3}


In [76]:
def normalize_img(img):
    img = img / 255.
    img = img.astype(np.float32)
    return img

In [77]:
class DatasetFromFolder(data.Dataset):
    def __init__(self, data_dir):
        super(DatasetFromFolder, self).__init__()

        lr_dir = os.path.join(data_dir, 'lr')
        hr_dir = os.path.join(data_dir, 'hr')
        
        self.dsets = {}
        self.dsets['lr'] = [os.path.join(lr_dir, x) for x in os.listdir(lr_dir)]
        self.dsets['hr'] = [os.path.join(hr_dir, x) for x in os.listdir(hr_dir)]
        self.dsets['file_name'] = os.listdir(hr_dir)

    def __getitem__(self, idx):
        input = cv2.imread(self.dsets['lr'][idx])
        target = cv2.imread(self.dsets['hr'][idx])
        file_name = self.dsets['file_name'][idx]

        input = cv2.resize(input, (target.shape[1], target.shape[0]), interpolation=cv2.INTER_CUBIC)
        
        input = normalize_img(input)
        target = normalize_img(target)

        input = np.transpose(input, (2, 0, 1))
        target = np.transpose(target, (2, 0, 1))

        input = torch.from_numpy(input).type(torch.FloatTensor)
        target = torch.from_numpy(target).type(torch.FloatTensor)

        return input, target, file_name

    def __len__(self):
        return len(self.dsets['lr'])

In [78]:
class MeanShift(nn.Conv2d):
    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False

In [79]:
class SRCNN(nn.Module):
    def __init__(self, opt):
        super(SRCNN, self).__init__()
        pix_range = 1.0
        
        self.sub_mean = MeanShift(pix_range)
        self.add_mean = MeanShift(pix_range, sign=1)
        
        self.conv1 = nn.Conv2d(opt.n_channels, 64, kernel_size=9, padding=4)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, opt.n_channels, kernel_size=5, padding=2)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.conv3(x)
        
        out = self.add_mean(x)
        return out

In [80]:
def save_checkpoint(srcnn, epoch, loss):
    checkpoint_dir = os.path.abspath(opt.checkpoint_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_path = os.path.join(checkpoint_dir, "models_epoch_%04d_loss_%.6f.pth" % (epoch, loss))
    
    if torch.cuda.device_count() > 1 and opt.multi_gpu:
        state = {"epoch": epoch, "srcnn": srcnn.module}
    else:
        state = {"epoch": epoch, "srcnn": srcnn}

    torch.save(state, checkpoint_path)
    print("Checkpoint saved to {}".format(checkpoint_path))

In [81]:
def load_model(checkpoint_dir):
    # we will check from last, best, and specific epoch_num model
    checkpoint_list = glob.glob(os.path.join(checkpoint_dir, "*.pth"))
    checkpoint_list.sort()

    if opt.resume_best:
        loss_list = list(map(lambda x: float(os.path.basename(x).split('_')[4][:-4]), checkpoint_list))
        best_loss_idx = loss_list.index(min(loss_list))
        checkpoint_path = checkpoint_list[best_loss_idx]
    else:
        checkpoint_path = checkpoint_list[len(checkpoint_list) - 1]

    srcnn = SRCNN(opt)

    if os.path.isfile(checkpoint_path):
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        
        n_epoch = checkpoint['epoch']
        srcnn.load_state_dict(checkpoint['srcnn'].state_dict())
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(checkpoint_path, n_epoch))
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_path))
        n_epoch = 0

    return n_epoch + 1, srcnn

In [82]:
def train(opt, model, optimizer, data_loader, loss_criterion):
    print("===> Training")
    start_time = time.time()

    total_loss = 0.0
    total_psnr = 0.0

    for iteration, batch in enumerate(tqdm(data_loader), 1):
        x, target = batch[0], batch[1]
        if opt.use_cuda:
            x = x.to(opt.device)
            target = target.to(opt.device)

        out = model(x)

        loss = loss_criterion(out, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss
        psnr = 10 * math.log10(1 / loss.item())
        total_psnr += psnr

    total_loss = total_loss / iteration
    total_psnr = total_psnr / iteration
    
    print("Training %.2fs => Epoch[%d/%d]: Loss: %.5f PSNR: %.5f" %
        (time.time() - start_time, opt.epoch_num, opt.n_epochs, total_loss, total_psnr))

    return (total_loss, total_psnr)

In [83]:
def evaluate(opt, model, data_loader, loss_criterion):
    print("===> Validation")
    start_time = time.time()

    total_loss = 0.0
    total_psnr = 0.0
    with torch.no_grad():
        for iteration, batch in enumerate(tqdm(data_loader), 1):
            x, target = batch[0], batch[1]

            if opt.use_cuda:
                x = x.to(opt.device)
                target = target.to(opt.device)

            out = model(x)
            
            loss = loss_criterion(out, target)
            total_loss += loss

            psnr = 10 * math.log10(1 / loss.item())
            total_psnr += psnr

    total_loss = total_loss / iteration
    total_psnr = total_psnr / iteration
    
    print("Validation %.2fs => Epoch[%d/%d]: Loss: %.5f PSNR: %.5f" %
        (time.time() - start_time, opt.epoch_num, opt.n_epochs, total_loss, total_psnr))

    return (total_loss, total_psnr)

In [84]:
def run_train(opt, training_data_loader, validation_data_loader):
    # Define what device we are using
    if not os.path.exists(opt.checkpoint_dir):
        os.makedirs(opt.checkpoint_dir)

    log_file = os.path.join(opt.checkpoint_dir, "srcnn_log.csv")

    print('[Initialize networks for training]')
    srcnn = SRCNN(opt)
    l2_criterion = nn.MSELoss()
    print(srcnn)

    # optionally resume from a checkpoint
    if opt.resume:
        opt.start_epoch, srcnn = load_model(opt.checkpoint_dir)
    else:
        with open(log_file, mode='w') as f:
            f.write("epoch,train_loss,valid_loss\n")

    print("===> Setting GPU")
    print("CUDA Available: ", torch.cuda.is_available())
    if opt.use_cuda and torch.cuda.is_available():
        opt.use_cuda = True
        opt.device = 'cuda'
    else:
        opt.use_cuda = False
        opt.device = 'cpu'
        
    if torch.cuda.device_count() > 1 and opt.multi_gpu:
        print("Use " + str(torch.cuda.device_count()) + " GPUs")
        srcnn = nn.DataParallel(srcnn)

    if opt.use_cuda:
        srcnn = srcnn.to(opt.device)
        l2_criterion = l2_criterion.to(opt.device)

    print("===> Setting Optimizer")
    optimizer = torch.optim.Adam(srcnn.parameters(),  lr=opt.lr, betas=(opt.b1, opt.b2))

    for epoch in range(opt.start_epoch, opt.n_epochs):
        opt.epoch_num = epoch
        train_loss, train_psnr = train(opt, srcnn, optimizer, training_data_loader, loss_criterion=l2_criterion)
        valid_loss, valid_psnr = evaluate(opt, srcnn, validation_data_loader, loss_criterion=l2_criterion)

        with open(log_file, mode='a') as f:
            f.write("%d,%08f,%08f,%08f,%08f\n" % (
                epoch,
                train_loss,
                train_psnr,
                valid_loss,
                valid_psnr
            ))
        save_checkpoint(srcnn, epoch, valid_loss)

In [None]:
train_dir = opt.train_dir
valid_dir = opt.valid_dir
print("train_dir is: {}".format(train_dir))
print("valid_dir is: {}".format(valid_dir))

target_size = (opt.lr_img * opt.res_scale, opt.lr_img * opt.res_scale)


train_dataset = DatasetFromFolder(train_dir)
valid_dataset = DatasetFromFolder(valid_dir)

training_data_loader = DataLoader(dataset=train_dataset,
                                  batch_size=opt.batch_size,
                                  shuffle=True)
validation_data_loader = DataLoader(dataset=valid_dataset,
                                    batch_size=opt.batch_size,
                                    shuffle=False)

run_train(opt, training_data_loader, validation_data_loader)

In [13]:
from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_psnr as psnr

In [51]:
def get_test_dataset(data_dir):
    lr_dir = os.path.join(data_dir, 'lr')
    hr_dir = os.path.join(data_dir, 'hr')
        
    dsets = {}
    dsets['lr'] = [os.path.join(lr_dir, x) for x in os.listdir(lr_dir)]
    dsets['hr'] = [os.path.join(hr_dir, x) for x in os.listdir(hr_dir)]
    dsets['file_name'] = os.listdir(hr_dir)
    
    return dsets

In [73]:
def test_model(opt, test_dataset):
    
    lr_list = test_dataset['lr']
    hr_list = test_dataset['hr']
    filename_list = test_dataset['file_name']
    
    sr_compare_dir = os.path.join(opt.test_result_dir, "compare")
    sr_result_dir = os.path.join(opt.test_result_dir, "sr")

    if not os.path.exists(sr_result_dir):
        os.makedirs(sr_result_dir)
    if not os.path.exists(sr_compare_dir):
        os.makedirs(sr_compare_dir)

    opt.resume_best = True
    _, srcnn = load_model(opt.checkpoint_dir)
    criterion = nn.MSELoss()

    if torch.cuda.device_count() > 1 and opt.multi_gpu:
        print("Use " + str(torch.cuda.device_count()) + " GPUs")
        srcnn = nn.DataParallel(srcnn)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        srcnn = srcnn.to(opt.device)
        criterion = criterion.to(opt.device)

    hr_img_sz = (opt.lr_img * opt.res_scale, opt.lr_img * opt.res_scale)
    result_img = np.zeros((hr_img_sz[0], hr_img_sz[1] * 3, opt.n_channels))

    with torch.no_grad():
        total_num = 0.
        sum_bicubic_psnr = 0.
        sum_sr_psnr = 0.
        sum_bicubic_ssim = 0.
        sum_sr_ssim = 0.
        
        avg_bicubic_psnr = 0.
        avg_sr_psnr = 0.
        avg_bicubic_ssim = 0.
        avg_sr_ssim = 0.

        start_time = time.time()
        for batch in zip(lr_list, hr_list, filename_list):
            input_path, target_path, file_name = batch[0], batch[1], batch[2]
            
            input = cv2.imread(input_path)
            target = cv2.imread(target_path)
            
            input = cv2.resize(input, (target.shape[1], target.shape[0]), interpolation=cv2.INTER_CUBIC)
            
            input = normalize_img(input)
            target = normalize_img(target)
            
            input = np.transpose(input, (2, 0, 1))
            target = np.transpose(target, (2, 0, 1))
            
            input = input.reshape(1, input.shape[0], input.shape[1], input.shape[2])
            target = target.reshape(1, target.shape[0], target.shape[1], target.shape[2])
            
            input = torch.from_numpy(input).type(torch.FloatTensor)
            target = torch.from_numpy(target).type(torch.FloatTensor)

            if use_cuda:
                input = input.to(opt.device)
#                 target = target.to(opt.device)

            out = srcnn(input)
            
            for i in range(input.size(0)): 
                if use_cuda:
                    input_arr = input[i].to('cpu').data.numpy()
                    sr_arr = out[i].to('cpu').data.numpy()
#                     target_arr = target[i].to('cpu').data.numpy()
                else:
                    input_arr = input[i].data.numpy()
                    sr_arr = out[i].data.numpy()
#                     target_arr = target[i].to('cpu').data.numpy()

                target_arr = target[i].data.numpy()

                compare_img = np.concatenate((input_arr, sr_arr, target_arr), axis=2)
                
#                 bicubic_psnr = psnr(input_arr, target_arr)
#                 print("sr_arr.max():", sr_arr.max())
#                 print("sr_arr.min():", sr_arr.min())
#                 sr_psnr = psnr(sr_arr, target_arr)
                
#                 sum_bicubic_psnr += bicubic_psnr
#                 sum_sr_psnr += sr_psnr
                
#                 bicubic_ssim = ssim(input_arr, target_arr)
#                 sr_ssim = ssim(input_arr, target_arr)
                
#                 sum_bicubic_ssim += bicubic_ssim
#                 sum_sr_ssim += sr_ssim
                
                cv2.imwrite(os.path.join(sr_compare_dir, file_name), compare_img)
                cv2.imwrite(os.path.join(sr_result_dir, file_name), sr_arr)
                
#                 print("Bicubic PSNR: {:.8f}, Bicubic SSIM: {:.8f}, SR PSNR: {:.8f}, SR SSIM: {:.8f}".format(
#                     bicubic_psnr, sr_psnr, sr_ssim, sr_psnr))
#                 total_num += 1

        avg_bicubic_psnr = sum_bicubic_psnr / total_num
        avg_sr_psnr = sum_sr_psnr / total_num
        
        avg_bicubic_ssim = sum_bicubic_psnr / total_num
        avg_sr_ssim = sum_sr_psnr / total_num

        print("Time: {:.2f}".format(time.time() - start_time))
        print("Bicubic PSNR: {:.8f}".format(avg_bicubic_psnr))
        print("Bicubic SSIM: {:.8f}".format(avg_bicubic_ssim))
        print("SR PSNR: {:.8f}".format(avg_sr_psnr))
        print("SR SSIM: {:.8f}".format(avg_sr_ssim))

In [74]:
test_dir = opt.test_dir
print("test_dir is: {}".format(test_dir))

test_dataset = get_test_dataset(test_dir)

test_model(opt, test_dataset)

test_dir is: ../../data/super-resolution/div2k_100/test_images
=> loading checkpoint './checkpoin_dir\models_epoch_0053_loss_0.000780.pth'
=> loaded checkpoint './checkpoin_dir\models_epoch_0053_loss_0.000780.pth' (epoch 53)


NameError: name 'bicubic_psnr' is not defined