In [1]:
import torch
from torch.utils.data import DataLoader

from CNN.dataset import *
from CNN.model import *
from CNN.utils import *

In [2]:
# Evaluation device: use the first CUDA GPU when one is available, otherwise
# fall back to the CPU so the notebook can still be executed without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Names of the benchmark test sets that the evaluation cells iterate over.
# Each name is passed as `type=` to `test_dataset` together with
# root_path='./datasets/for_test'.
data_lst = ['Set5', 'Set14', 'BSD100']

# **LaUD (finetuned)**

In [3]:
# Build the fine-tuned LaUD networks (rudp=3, ch=64) for scales 2, 4 and 8 and
# restore their weights from ./CNN/Models.
# map_location='cpu' makes torch.load deserialize the tensors on the CPU, so
# loading no longer depends on the device the checkpoint was saved from; the
# network is moved to `device` only after its weights are in place.
LaUDx2 = LaUD(rudp=3, scale=2, ch=64)
LaUDx2.load_state_dict(torch.load('./CNN/Models/LaUD_finetuned_X2_3RUDP.pt', map_location='cpu'))
LaUDx2 = LaUDx2.to(device)

LaUDx4 = LaUD(rudp=3, scale=4, ch=64)
LaUDx4.load_state_dict(torch.load('./CNN/Models/LaUD_finetuned_X4_3RUDP.pt', map_location='cpu'))
LaUDx4 = LaUDx4.to(device)

LaUDx8 = LaUD(rudp=3, scale=8, ch=64)
LaUDx8.load_state_dict(torch.load('./CNN/Models/LaUD_finetuned_X8_3RUDP.pt', map_location='cpu'))
LaUDx8 = LaUDx8.to(device)

In [4]:
# Evaluate every fine-tuned LaUD model on every benchmark set in `data_lst`
# and print one report with the per-set average PSNR and SSIM.
models = [LaUDx2, LaUDx4, LaUDx8]
report_parts = []
for testdata in data_lst:
    report_parts.append('\n[For {}, PSNR and SSIM]\n'.format(testdata))
    for idx, m in enumerate(models):
        # `models` is ordered x2, x4, x8, so the scale is 2 ** (position + 1).
        mag = 2 ** (idx + 1)

        # All resizing / cropping / augmentation switches are disabled.
        testset = test_dataset(
            root_path='./datasets/for_test', type=testdata,
            is_resize=False, resize_h=None, resize_w=None,
            is_rcrop=False, crop_h=None, crop_w=None, scale=mag,
            is_rrot=False, rand_hori_flip=False, rand_vert_flip=False,
            grayscale=False, norm=True,
        )
        testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
        n_batches = len(testloader)

        psnr_sum = 0
        ssim_sum = 0
        m.eval()
        with torch.no_grad():
            for batch in testloader:
                # batch[0] is used as the HR reference, batch[1] as the LR input.
                hr_img = batch[0].to(device)
                lr_img = batch[1].to(device)

                sr_lst, det_lst = m(lr_img)
                # Only the last entry of the returned SR list is scored.
                sr_img = sr_lst[-1]

                # PSNR crops `mag` border pixels; SSIM is called with crop_border=0.
                psnr_sum += cal_psnr(hr_img, sr_img, crop_border=mag, minmax='-1_1',
                                     clamp=True, gray_scale=True, ver='YCrCb_BT601')
                ssim_sum += cal_ssim(hr_img, sr_img, crop_border=0, minmax='-1_1',
                                     filter_size=11, filter_sigma=1.5, clamp=True,
                                     grayscale=True, ver='YCrCb_BT601')

        report_parts.append(
            'LaUDX{}_3RUDP: {:.2f} dB and {:.4f}\n'.format(mag, psnr_sum / n_batches, ssim_sum / n_batches)
        )

info = ''.join(report_parts)
print(info)


[For Set5, PSNR and SSIM]
LaUDX2_3RUDP: 38.45 dB and 0.9625
LaUDX4_3RUDP: 32.82 dB and 0.9021
LaUDX8_3RUDP: 27.51 dB and 0.7882

[For Set14, PSNR and SSIM]
LaUDX2_3RUDP: 34.65 dB and 0.9256
LaUDX4_3RUDP: 29.06 dB and 0.7939
LaUDX8_3RUDP: 25.34 dB and 0.6569

[For BSD100, PSNR and SSIM]
LaUDX2_3RUDP: 32.54 dB and 0.9042
LaUDX4_3RUDP: 27.89 dB and 0.7472
LaUDX8_3RUDP: 25.04 dB and 0.6102



# **Ablation Models**

In [5]:
# Upscaling factor shared by the ablation models and their evaluation cell.
mag = 2

# Restore the four x2 ablation networks.
# map_location='cpu' makes torch.load deserialize the tensors on the CPU, so
# loading does not depend on the device the checkpoint was saved from; each
# network is moved to `device` after its weights are loaded.
plain = PlainNet(ch=64)
plain.load_state_dict(torch.load('./CNN/Models/PlainNet_pret_ablation_X2.pt', map_location='cpu'))
plain = plain.to(device)

plain_wdet = PlainNet_wdet(ch=64)
plain_wdet.load_state_dict(torch.load('./CNN/Models/PlainNet_wdet_pret_ablation_X2.pt', map_location='cpu'))
plain_wdet = plain_wdet.to(device)

laud_wo_det = LaUD_wo_det(rudp=3, scale=mag, ch=64)
laud_wo_det.load_state_dict(torch.load('./CNN/Models/LaUD_wo_det_pret_ablation_X2_3RUDP.pt', map_location='cpu'))
laud_wo_det = laud_wo_det.to(device)

laud = LaUD(rudp=3, scale=mag, ch=64)
laud.load_state_dict(torch.load('./CNN/Models/LaUD_pret_X2_3RUDP.pt', map_location='cpu'))
laud = laud.to(device)

In [6]:
# Evaluate the four ablation networks on every benchmark set in `data_lst`
# and print the per-set average PSNR / SSIM of each network.
#
# One table drives the whole cell instead of four copy-pasted pipelines:
#   (report label, network, whether the forward pass returns (sr_lst, det_lst))
# Networks flagged False return the SR list directly.
ablation_models = [
    ('PlainNet', plain, False),
    ('PlainNet_wdet', plain_wdet, True),
    ('LaUD_wo_det', laud_wo_det, False),
    ('LaUD', laud, True),
]

# Switching to eval mode does not depend on the dataset, so do it once.
for _label, net, _returns_det in ablation_models:
    net.eval()

for testdata in data_lst:
    # All resizing / cropping / augmentation switches are disabled.
    testset = test_dataset(root_path='./datasets/for_test', type=testdata,
                           is_resize=False, resize_h=None, resize_w=None, is_rcrop=False, crop_h=None, crop_w=None, scale=mag,
                           is_rrot=False, rand_hori_flip=False, rand_vert_flip=False, grayscale=False, norm=True)
    testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
    n_batches = len(testloader)

    # label -> [running PSNR sum, running SSIM sum]
    metric_sums = {label: [0, 0] for label, _net, _returns_det in ablation_models}

    with torch.no_grad():
        for data in testloader:
            # data[0] is used as the HR reference, data[1] as the LR input.
            hr_img, lr_img = data[0].to(device), data[1].to(device)

            for label, net, returns_det in ablation_models:
                if returns_det:
                    sr_lst, _det_lst = net(lr_img)
                else:
                    sr_lst = net(lr_img)

                # Only the last entry of the SR list is scored.
                # PSNR crops `mag` border pixels; SSIM is called with crop_border=0.
                metric_sums[label][0] += cal_psnr(hr_img, sr_lst[-1], crop_border=mag, minmax='-1_1', clamp=True, gray_scale=True, ver='YCrCb_BT601')
                metric_sums[label][1] += cal_ssim(hr_img, sr_lst[-1], crop_border=0, minmax='-1_1', filter_size=11, filter_sigma=1.5, clamp=True, grayscale=True, ver='YCrCb_BT601')

    info = '[For {} X{}, PSNR and SSIM]\n'.format(testdata, mag)
    for label, _net, _returns_det in ablation_models:
        psnr_sum, ssim_sum = metric_sums[label]
        info += '{}: {:.4f} dB and {:.4f}\n'.format(label, psnr_sum/n_batches, ssim_sum/n_batches)
    print(info)

[For Set5 X2, PSNR and SSIM]
PlainNet: 38.0857 dB and 0.9613
PlainNet_wdet: 38.2841 dB and 0.9619
LaUD_wo_det: 38.3154 dB and 0.9620
LaUD: 38.4237 dB and 0.9625

[For Set14 X2, PSNR and SSIM]
PlainNet: 33.9680 dB and 0.9205
PlainNet_wdet: 34.2761 dB and 0.9224
LaUD_wo_det: 34.6050 dB and 0.9250
LaUD: 34.7677 dB and 0.9256

[For BSD100 X2, PSNR and SSIM]
PlainNet: 32.2910 dB and 0.9015
PlainNet_wdet: 32.4013 dB and 0.9029
LaUD_wo_det: 32.4888 dB and 0.9037
LaUD: 32.5504 dB and 0.9045



# **Ablation in Appendix**

In [7]:
# Upscaling factor for the appendix ablation on the number of RUDP stages.
mag = 2
# NOTE: this rebinds the notebook-wide `data_lst` and adds 'Urban100'; cells
# executed after this one (including re-runs of earlier cells) see the longer list.
data_lst = ['Set5', 'Set14', 'BSD100', 'Urban100']

# Restore the x2 LaUD networks built with rudp=1, 2 and 3.
# map_location='cpu' makes torch.load deserialize the tensors on the CPU, so
# loading does not depend on the device the checkpoint was saved from; each
# network is moved to `device` after its weights are loaded.
LaUD_1RUDP = LaUD(rudp=1, scale=mag, ch=64)
LaUD_1RUDP.load_state_dict(torch.load('./CNN/Models/LaUD_pret_X2_1RUDP.pt', map_location='cpu'))
LaUD_1RUDP = LaUD_1RUDP.to(device)

LaUD_2RUDP = LaUD(rudp=2, scale=mag, ch=64)
LaUD_2RUDP.load_state_dict(torch.load('./CNN/Models/LaUD_pret_X2_2RUDP.pt', map_location='cpu'))
LaUD_2RUDP = LaUD_2RUDP.to(device)

LaUD_3RUDP = LaUD(rudp=3, scale=mag, ch=64)
LaUD_3RUDP.load_state_dict(torch.load('./CNN/Models/LaUD_pret_X2_3RUDP.pt', map_location='cpu'))
LaUD_3RUDP = LaUD_3RUDP.to(device)

In [9]:
# Evaluate the LaUD networks built with 1, 2 and 3 RUDP stages on every
# benchmark set in `data_lst` and print their average PSNR / SSIM.
#
# (report label, network) pairs; every network here returns (sr_lst, det_lst).
# The labels use the 'RUDP' spelling of the model names (the printed label
# previously read 'RUPD', which was a typo).
rudp_models = [
    ('LaUD_1RUDP', LaUD_1RUDP),
    ('LaUD_2RUDP', LaUD_2RUDP),
    ('LaUD_3RUDP', LaUD_3RUDP),
]

# Switching to eval mode does not depend on the dataset, so do it once.
for _label, net in rudp_models:
    net.eval()

for testdata in data_lst:
    # All resizing / cropping / augmentation switches are disabled.
    testset = test_dataset(root_path='./datasets/for_test', type=testdata,
                           is_resize=False, resize_h=None, resize_w=None, is_rcrop=False, crop_h=None, crop_w=None, scale=mag,
                           is_rrot=False, rand_hori_flip=False, rand_vert_flip=False, grayscale=False, norm=True)
    testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
    n_batches = len(testloader)

    # label -> [running PSNR sum, running SSIM sum]
    metric_sums = {label: [0, 0] for label, _net in rudp_models}

    with torch.no_grad():
        for data in testloader:
            # data[0] is used as the HR reference, data[1] as the LR input.
            hr_img, lr_img = data[0].to(device), data[1].to(device)

            for label, net in rudp_models:
                sr_lst, _det_lst = net(lr_img)

                # Only the last entry of the SR list is scored.
                # PSNR crops `mag` border pixels; SSIM is called with crop_border=0.
                metric_sums[label][0] += cal_psnr(hr_img, sr_lst[-1], crop_border=mag, minmax='-1_1', clamp=True, gray_scale=True, ver='YCrCb_BT601')
                metric_sums[label][1] += cal_ssim(hr_img, sr_lst[-1], crop_border=0, minmax='-1_1', filter_size=11, filter_sigma=1.5, clamp=True, grayscale=True, ver='YCrCb_BT601')

    info = '[For {} X{}, PSNR and SSIM]\n'.format(testdata, mag)
    for label, _net in rudp_models:
        psnr_sum, ssim_sum = metric_sums[label]
        info += '{}: {:.4f} dB and {:.4f}\n'.format(label, psnr_sum/n_batches, ssim_sum/n_batches)
    print(info)

[For Set5 X2, PSNR and SSIM]
LaUD_1RUPD: 38.1402 dB and 0.9613
LaUD_2RUPD: 38.2802 dB and 0.9619
LaUD_3RUPD: 38.4237 dB and 0.9625

[For Set14 X2, PSNR and SSIM]
LaUD_1RUPD: 34.0151 dB and 0.9209
LaUD_2RUPD: 34.4278 dB and 0.9238
LaUD_3RUPD: 34.7677 dB and 0.9256

[For BSD100 X2, PSNR and SSIM]
LaUD_1RUPD: 32.3038 dB and 0.9015
LaUD_2RUPD: 32.4483 dB and 0.9035
LaUD_3RUPD: 32.5504 dB and 0.9045

[For Urban100 X2, PSNR and SSIM]
LaUD_1RUPD: 32.5285 dB and 0.9396
LaUD_2RUPD: 33.4187 dB and 0.9476
LaUD_3RUPD: 34.0834 dB and 0.9529



In [10]:
# Upscaling factor and test sets for the loss-function ablation.
mag = 2
data_lst = ['Set5', 'Set14', 'BSD100', 'Urban100']

# Restore the three x2 networks compared in the loss ablation.
# map_location='cpu' makes torch.load deserialize the tensors on the CPU, so
# loading does not depend on the device the checkpoint was saved from; each
# network is moved to `device` after its weights are loaded.
laud_wo_det = LaUD_wo_det(rudp=3, scale=mag, ch=64)
laud_wo_det.load_state_dict(torch.load('./CNN/Models/LaUD_wo_det_pret_ablation_X2_3RUDP.pt', map_location='cpu'))
laud_wo_det = laud_wo_det.to(device)

laud_l2Loss = LaUD(rudp=3, scale=mag, ch=64)
laud_l2Loss.load_state_dict(torch.load('./CNN/Models/LaUD_l2Loss_for_det_pret_ablation_X2_3RUDP.pt', map_location='cpu'))
laud_l2Loss = laud_l2Loss.to(device)

laud = LaUD(rudp=3, scale=mag, ch=64)
laud.load_state_dict(torch.load('./CNN/Models/LaUD_pret_X2_3RUDP.pt', map_location='cpu'))
laud = laud.to(device)

In [11]:
# Evaluate the three loss-ablation networks on every benchmark set in
# `data_lst` and print their average PSNR / SSIM.
#
# One table drives the whole cell instead of three copy-pasted pipelines:
#   (report label, network, whether the forward pass returns (sr_lst, det_lst))
# Networks flagged False return the SR list directly.
loss_ablation_models = [
    ('LaUD_SR_l1', laud_wo_det, False),
    ('LaUD_SR_l1_Det_l2', laud_l2Loss, True),
    ('LaUD_SR_l1_Det_l1', laud, True),
]

# Switching to eval mode does not depend on the dataset, so do it once.
for _label, net, _returns_det in loss_ablation_models:
    net.eval()

for testdata in data_lst:
    # All resizing / cropping / augmentation switches are disabled.
    testset = test_dataset(root_path='./datasets/for_test', type=testdata,
                           is_resize=False, resize_h=None, resize_w=None, is_rcrop=False, crop_h=None, crop_w=None, scale=mag,
                           is_rrot=False, rand_hori_flip=False, rand_vert_flip=False, grayscale=False, norm=True)
    testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)
    n_batches = len(testloader)

    # label -> [running PSNR sum, running SSIM sum]
    metric_sums = {label: [0, 0] for label, _net, _returns_det in loss_ablation_models}

    with torch.no_grad():
        for data in testloader:
            # data[0] is used as the HR reference, data[1] as the LR input.
            hr_img, lr_img = data[0].to(device), data[1].to(device)

            for label, net, returns_det in loss_ablation_models:
                if returns_det:
                    sr_lst, _det_lst = net(lr_img)
                else:
                    sr_lst = net(lr_img)

                # Only the last entry of the SR list is scored.
                # PSNR crops `mag` border pixels; SSIM is called with crop_border=0.
                metric_sums[label][0] += cal_psnr(hr_img, sr_lst[-1], crop_border=mag, minmax='-1_1', clamp=True, gray_scale=True, ver='YCrCb_BT601')
                metric_sums[label][1] += cal_ssim(hr_img, sr_lst[-1], crop_border=0, minmax='-1_1', filter_size=11, filter_sigma=1.5, clamp=True, grayscale=True, ver='YCrCb_BT601')

    info = '[For {} X{}, PSNR and SSIM]\n'.format(testdata, mag)
    for label, _net, _returns_det in loss_ablation_models:
        psnr_sum, ssim_sum = metric_sums[label]
        info += '{}: {:.4f} dB and {:.4f}\n'.format(label, psnr_sum/n_batches, ssim_sum/n_batches)
    print(info)

[For Set5 X2, PSNR and SSIM]
LaUD_SR_l1: 38.3154 dB and 0.9620
LaUD_SR_l1_Det_l2: 38.3899 dB and 0.9627
LaUD_SR_l1_Det_l1: 38.4237 dB and 0.9625

[For Set14 X2, PSNR and SSIM]
LaUD_SR_l1: 34.6050 dB and 0.9250
LaUD_SR_l1_Det_l2: 34.5614 dB and 0.9252
LaUD_SR_l1_Det_l1: 34.7677 dB and 0.9256

[For BSD100 X2, PSNR and SSIM]
LaUD_SR_l1: 32.4888 dB and 0.9037
LaUD_SR_l1_Det_l2: 32.5058 dB and 0.9041
LaUD_SR_l1_Det_l1: 32.5504 dB and 0.9045

[For Urban100 X2, PSNR and SSIM]
LaUD_SR_l1: 33.6879 dB and 0.9497
LaUD_SR_l1_Det_l2: 33.8258 dB and 0.9514
LaUD_SR_l1_Det_l1: 34.0834 dB and 0.9529

