In [1]:
import os
import numpy as np
import torch
import cv2
from skimage.metrics import structural_similarity as compare_ssim
import time
from utils import load_model_ISTA, load_model_DRS, load_model_CNN
from utils import find_test_image_circular_pad as denoise
import matplotlib.pyplot as plt
import bm3d

def proj(im_input, minval, maxval):
    """Project ``im_input`` element-wise onto the interval [minval, maxval].

    Values above ``maxval`` are first capped to ``maxval``; values below
    ``minval`` are then raised to ``minval`` (so ``minval`` wins if the
    bounds are given inverted).
    """
    capped_above = np.minimum(im_input, maxval)
    return np.maximum(capped_above, minval)

def psnr(x, im_orig):
    """Peak signal-to-noise ratio (in dB) of ``x`` against ``im_orig``.

    Assumes a peak signal value of 1 (images scaled to [0, 1]). Identical
    inputs give a zero MSE, for which numpy yields ``inf``.
    """
    mse = np.mean(np.square(x - im_orig))
    return -10 * np.log10(mse)


In [2]:
# A^T: adjoint of the nearest-neighbour down-sampling operator A (funcA).
def funcAtranspose(im_input, mask, fx, fy):
    """Apply A^T: zero-filled up-sampling of a 2-D image.

    Pixel (i, j) of ``im_input`` is placed at (i * row_step, j * col_step) of
    an otherwise zero image, where row_step = 1/fy and col_step = 1/fx.
    This matches funcA, whose ``cv2.resize`` applies ``fx`` to the width
    (columns) and ``fy`` to the height (rows).

    Args:
        im_input: 2-D array (low-resolution grid).
        mask: unused; kept so both operators share one signature.
        fx: horizontal down-sampling factor; assumed to be 1/integer (e.g. 0.5).
        fy: vertical down-sampling factor; assumed to be 1/integer (e.g. 0.5).

    Returns:
        Array of shape (m * row_step, n * col_step) with the dtype of ``im_input``.
    """
    m, n = im_input.shape
    # round() guards against 1/f landing just below an integer (int() would truncate).
    row_step = int(round(1 / fy))
    col_step = int(round(1 / fx))
    im_output = np.zeros((m * row_step, n * col_step), im_input.dtype)
    im_output[::row_step, ::col_step] = im_input
    return im_output

# A: forward (down-sampling) operator of the super-resolution model.
def funcA(im_input, mask, fx, fy):
    """Apply A: nearest-neighbour down-sampling of a 2-D image via ``cv2.resize``.

    Args:
        im_input: 2-D array (high-resolution grid).
        mask: unused; kept so both operators share one signature.
        fx: scale factor for the width (columns), e.g. 0.5 for 2x down-sampling.
        fy: scale factor for the height (rows).

    Returns:
        The down-sampled array produced by ``cv2.resize`` with INTER_NEAREST.
    """
    return cv2.resize(im_input, (0, 0), fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)


In [3]:
# Denoisers understood by pnp_fbs_superresolution, grouped by how they are applied.
_PROPOSED_DENOISERS = ("Proposed_ISTA", "Proposed_DRS")
_RESIDUAL_CNN_DENOISERS = ("DnCNN", "RealSN_DnCNN", "SimpleCNN", "RealSN_SimpleCNN")
_SUPPORTED_DENOISERS = _PROPOSED_DENOISERS + _RESIDUAL_CNN_DENOISERS + ("BM3D",)


def _apply_denoiser(model, xtilde, denoiser, sigma, stride):
    """Run one denoising step on the 2-D array ``xtilde``; returns a 2-D array of the same shape."""
    m, n = xtilde.shape
    if denoiser == "BM3D":
        # BM3D operates on numpy arrays directly; no torch/GPU round-trip needed.
        return bm3d.bm3d(xtilde, sigma_psd=sigma / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES)

    if denoiser not in _PROPOSED_DENOISERS and denoiser not in _RESIDUAL_CNN_DENOISERS:
        raise ValueError("unknown denoiser: {!r}".format(denoiser))

    xtilde_torch = torch.from_numpy(np.reshape(xtilde, (1, 1, m, n))).type(torch.FloatTensor).cuda()
    if denoiser in _PROPOSED_DENOISERS:
        # NOTE(review): 64 and 500 are positional arguments of the project helper
        # find_test_image_circular_pad (imported as `denoise`); their meaning is not visible here.
        x = denoise(xtilde_torch, model, stride, 64, 500).cpu().numpy()
        return np.reshape(x, (m, n))

    # The model output is treated as a residual and subtracted from the input.
    # no_grad: calling .numpy() on an output that requires grad raises a RuntimeError.
    with torch.no_grad():
        res = model(xtilde_torch).cpu().numpy()
    return xtilde - np.reshape(res, (m, n))


def pnp_fbs_superresolution(model, im_input, fx, fy, mask, denoiser, **opts):
    """Plug-and-play forward-backward splitting (PnP-FBS) for super-resolution.

    Iterates  x <- D(x - rho * (A^T A x - A^T b))  with A = funcA, A^T = funcAtranspose,
    b = ``im_input`` and D the selected denoiser.

    Args:
        model: denoiser model passed to the denoising step (ignored for "BM3D").
        im_input: 2-D low-resolution observation b.
        fx, fy: down-sampling factors forwarded to funcA / funcAtranspose (e.g. 0.5 for 2x).
        mask: forwarded to funcA / funcAtranspose.
        denoiser: one of "Proposed_ISTA", "Proposed_DRS", "DnCNN", "RealSN_DnCNN",
            "SimpleCNN", "RealSN_SimpleCNN", "BM3D".
        **opts: ``rho`` (step size, default 1.0), ``maxitr`` (iterations, default 100),
            ``sigma`` (default 5; BM3D uses sigma/255 as its noise std),
            ``stride`` (default 8; passed to the proposed denoisers' helper).
            ``lamda`` and ``verbose`` are accepted but currently ignored.

    Returns:
        The 2-D high-resolution estimate after ``maxitr`` iterations.

    Raises:
        ValueError: if ``denoiser`` is not one of the supported names.
    """
    if denoiser not in _SUPPORTED_DENOISERS:
        raise ValueError("unknown denoiser: {!r}".format(denoiser))

    rho = opts.get('rho', 1.0)
    maxitr = opts.get('maxitr', 100)
    sigma = opts.get('sigma', 5)
    stride = opts.get('stride', 8)

    # Initialization: y = A^T b on the high-resolution grid; x0 = resize of b to that grid.
    y = funcAtranspose(im_input, mask, fx, fy)
    m, n = y.shape
    # cv2.resize takes dsize as (width, height), i.e. (columns, rows).
    x = cv2.resize(im_input, (n, m))

    for _ in range(maxitr):
        # Gradient of the data term: A^T A x - A^T b.
        gradx = funcAtranspose(funcA(x, mask, fx, fy), mask, fx, fy) - y
        # Gradient step followed by the denoiser.
        x = _apply_denoiser(model, x - rho * gradx, denoiser, sigma, stride)

    return x


In [5]:
# ---- experiment configuration ----
dir_name = 'data/Set12/'
image_arr = ['{:02d}.png'.format(k) for k in range(1, 13)]  # '01.png' ... '12.png'

K = 2  # super-resolution factor: the output is K times larger in each dimension
fx = 1. / K
fy = 1. / K
noise_level = 5.0 / 255.0  # std of the additive Gaussian noise (images are scaled to [0, 1])

maxitr = 10
sigma = 5
rho = 1.
opts = dict(sigma=sigma, rho=rho, maxitr=maxitr, verbose=True)

SEED = 0
rng = np.random.default_rng(SEED)  # fixed seed so noisy inputs and scores are reproducible

# ---- load each denoiser once; the same weights are reused for every image ----
denoiser_specs = [
    ("Proposed_ISTA", "models/proposed_denoiser_ISTA_sigma" + str(sigma) + ".pth", load_model_ISTA),
    ("Proposed_DRS", "models/proposed_denoiser_DRS_sigma" + str(sigma) + ".pth", load_model_DRS),
]
models = {}
for denoiser, path, loader in denoiser_specs:
    models[denoiser] = loader(sigma, path)

for imagename in image_arr:
    input_str = os.path.join(dir_name, imagename)
    print(input_str)

    # ---- load the ground truth (grayscale, scaled to [0, 1]) ----
    im_gray = cv2.imread(input_str, cv2.IMREAD_GRAYSCALE)
    if im_gray is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
        raise FileNotFoundError('cannot read image: ' + input_str)
    im_orig = im_gray / 255.0
    m, n = im_orig.shape

    # ---- blur the image (circular boundary handling) ----
    # NOTE(review): im_blur is not used below -- the observation is built from the
    # un-blurred im_orig and funcA applies no blur. Confirm whether the blur was
    # meant to be part of the degradation model.
    kernel = cv2.getGaussianKernel(9, 1)
    mask = np.outer(kernel, kernel.transpose())
    w = mask.shape[1]
    r = (w - 1) // 2
    im_padded = cv2.copyMakeBorder(im_orig, r, r, r, r, borderType=cv2.BORDER_WRAP)
    im_blur = cv2.filter2D(im_padded, -1, mask)[r:r + m, r:r + n]

    # ---- add noise ----
    gauss = rng.normal(0.0, noise_level, im_orig.shape)
    im_noisy = im_orig + gauss
    im_noisy2 = np.clip(im_noisy, 0., 1.)  # clipped copy; the solver receives the unclipped im_noisy
    # Reference for the metrics: bicubic K-times upsampling of the clean image.
    bicubic_img = cv2.resize(im_orig, None, fx=K, fy=K, interpolation=cv2.INTER_CUBIC)

    # ---- plug and play with each denoiser ----
    for denoiser, model in models.items():
        print(denoiser)
        out = pnp_fbs_superresolution(model, im_noisy, fx, fy, mask, denoiser, **opts)

        # ---- results ----
        psnr_ours = psnr(out, bicubic_img)
        ssim_ours = compare_ssim(out, bicubic_img, data_range=1.)
        print('sigma = {}, rho = {} - PSNR: {}, SSIM = {}'.format(sigma, rho, psnr_ours, ssim_ours))
