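Code for **super-resolution** (figures $1$ and $5$ of the main paper). Change `factor` to $8$ to reproduce the images in fig. $9$ of the supplementary material. Note that the per-factor `num_iter` / `reg_noise_std` schedule used in the paper is currently disabled in the parameter cell below (fixed values are used instead), so re-enable it if you want the paper's settings.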

You can play with parameters and see how they affect the result. 

In [None]:
"""
*Uncomment if running on colab* 
Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab 
"""

'''
!git clone https://github.com/wei-tianyu/deep-remote-prior
!mv deep-remote-prior/* ./
'''

In [None]:
# Sanity check: list the compute devices visible to TensorFlow (e.g. to confirm
# that a GPU runtime is active on Colab). The cell displays the returned list.
# NOTE(review): TensorFlow is only used for this check; the model code in this
# notebook uses PyTorch.
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

# Import libs

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline
import cv2
import argparse
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
import time
from models import *

import torch
import torch.optim

from skimage.metrics import peak_signal_noise_ratio
from models.downsampler import Downsampler

from utils.sr_utils import *

# cuDNN autotuning switches (only relevant when running on a GPU):
#torch.backends.cudnn.enabled = True
#torch.backends.cudnn.benchmark =True

# Tensor type used for the network, its input, the losses and the downsampler.
# Defaults to CPU tensors (torch.FloatTensor). Set USE_GPU = True to run on CUDA
# when a GPU is available (e.g. on a Colab GPU runtime); without CUDA it falls
# back to the CPU type.
USE_GPU = False
dtype = torch.cuda.FloatTensor if (USE_GPU and torch.cuda.is_available()) else torch.FloatTensor

# Passed to load_LR_HR_imgs_sr; -1 presumably means "keep the original size".
imsize = -1

# Super-resolution factor -- important! It is passed to load_LR_HR_imgs_sr and
# to the Downsampler. Use 8 for the x8 experiments.
factor = 4 # 8
# Set to 'CROP' when the dimensions must be divisible by a power of two
# (32 in this case); None presumably leaves the image size untouched.
#enforse_div32 = 'CROP'
enforse_div32 = None
PLOT = True

# To produce images from the paper we took *_GT.png images from LapSRN viewer for corresponding factor,
# e.g. x4/zebra_GT.png for factor=4, and x8/zebra_GT.png for factor=8
path_to_image = '545test_set/sr/butterfly.png'
# Tag embedded in the names of the metric log files.
picname = 'bbb'



# Load image and baselines

In [None]:
# Starts here
imgs = load_LR_HR_imgs_sr(path_to_image , imsize, factor, enforse_div32)

imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np'] = get_baselines(imgs['LR_pil'], imgs['HR_pil'])

if PLOT:
    plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np']], 4,12);
    print ('PSNR bicubic: %.4f   PSNR nearest: %.4f' %  (
                                        peak_signal_noise_ratio(imgs['HR_np'], imgs['bicubic_np']), 
                                        peak_signal_noise_ratio(imgs['HR_np'], imgs['nearest_np'])))

# Set up parameters and net

In [None]:
# --- Network input ---------------------------------------------------------
input_depth = 32            # depth passed to get_noise / get_net
INPUT = 'noise'             # kind of input passed to get_noise
pad = 'reflection'          # padding mode passed to get_net

# --- Optimisation ----------------------------------------------------------
OPT_OVER = 'net'            # what get_params collects for the optimiser
OPTIMIZER = 'adam'
LR = 0.01
tv_weight = 0.0             # total-variation weight; 0 disables the TV term

# --- Degradation model -----------------------------------------------------
KERNEL_TYPE = 'lanczos2'    # kernel used by the Downsampler

# Schedule used in the paper (currently disabled in favour of the fixed
# values below):
#   factor 2 -> num_iter 1500, reg_noise_std 0.01
#   factor 4 -> num_iter 2000, reg_noise_std 0.02
#   factor 8 -> num_iter 4000, reg_noise_std 0.05
#   other factors were not experimented with.
# reg_noise_std = 0 disables the per-step input perturbation in the closure.
num_iter, reg_noise_std = 2000, 0

In [None]:
# Fixed network input (INPUT = 'noise') with the HR image's spatial size.
# PIL's `.size` is (width, height), hence the swap to (height, width).
net_input = get_noise(input_depth, INPUT, (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()

# Architecture name handed to get_net (the original note lists UNet / ResNet
# as alternatives). It is passed through, so editing it takes effect.
NET_TYPE = 'skip'
net = get_net(input_depth, NET_TYPE, pad,
              skip_n33d=128,
              skip_n33u=128,
              skip_n11=4,
              num_scales=5,
              upsample_mode='bilinear').type(dtype)

# Losses: MSE between the downsampled network output and the LR observation.
mse = torch.nn.MSELoss().type(dtype)

# Alternative (original code): load a pre-corrupted LR image from disk instead.
#img_bbb_pil, img_bbb_np = get_image('/content/corrupt_butterfly.png', -1)
#img_LR_var = np_to_torch(img_bbb_np).type(dtype)

# LR observation the output is fitted to (the corrupted image, as in the DeepRED code).
img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)

# Degradation operator mapping the HR estimate down to LR resolution.
downsampler = Downsampler(n_planes=3, factor=factor, kernel_type=KERNEL_TYPE, phase=0.5, preserve_size=True).type(dtype)

# Define closure and optimize

In [None]:
import torchvision.transforms as transforms

# Names of the tab-separated metric logs that `closure` appends to; each name
# embeds the run tag `picname`.
filename_lr = 'dip_psnr_lr_%s.txt' % picname      # PSNR of downsampled output vs. LR image
filename_hr = 'dip_psnr_hr_%s.txt' % picname      # PSNR of HR output vs. ground truth
filename_loss = 'dip_psnr_loss_%s.txt' % picname  # training loss
filename_time = 'dip_time_%s.txt' % picname       # elapsed wall-clock time
print(filename_lr)
def closure():
    """Run one optimisation step of the deep image prior for super-resolution.

    Module-level state used:
      * ``i`` -- iteration counter, incremented once per call.
      * ``start`` -- wall-clock reference, set on the first iteration (i == 1).
      * ``net_input`` -- replaced by ``net_input_saved`` plus fresh Gaussian
        noise (scaled by ``reg_noise_std``) when ``reg_noise_std > 0``.

    The HR output of ``net`` is downsampled and compared with the LR
    observation ``img_LR_var`` using MSE, plus an optional TV term when
    ``tv_weight > 0``. Gradients are accumulated with ``backward()``.

    On every call the PSNR pair ``[psnr_LR, psnr_HR]`` is appended to
    ``psnr_history``. Every 100 iterations the PSNRs, loss and elapsed time
    are appended to the per-run log files and printed. The intermediate
    result is additionally plotted when ``PLOT`` is set; logging does not
    depend on ``PLOT``.

    Returns:
        total_loss: the loss tensor for this step.
    """
    global i, net_input, start
    if i == 1:
        start = time.time()
    if reg_noise_std > 0:
        # Regularisation: perturb the fixed input with fresh noise every step.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out_HR = net(net_input)
    out_LR = downsampler(out_HR)

    total_loss = mse(out_LR, img_LR_var)
    if tv_weight > 0:
        total_loss += tv_weight * tv_loss(out_HR)

    total_loss.backward()
    # NOTE: elapsed time since the first iteration. It also includes the time
    # spent plotting and logging at earlier checkpoints.
    time_consuming = time.time() - start

    # PSNR of the downsampled output vs. the LR observation, and of the raw
    # (unclipped) HR output vs. the ground truth. A centre-crop PSNR variant
    # was also explored here and did not work well.
    psnr_LR = peak_signal_noise_ratio(imgs['LR_np'], torch_to_np(out_LR))
    psnr_HR = peak_signal_noise_ratio(imgs['HR_np'], torch_to_np(out_HR))

    print('Iteration %05d    PSNR_LR %.3f   PSNR_HR %.3f' % (i, psnr_LR, psnr_HR), '\r', end='')

    # History
    psnr_history.append([psnr_LR, psnr_HR])

    if i % 100 == 0:
        if PLOT:
            out_HR_np = torch_to_np(out_HR)
            plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], np.clip(out_HR_np, 0, 1)], factor=13, nrow=3)
        # Metric logs are written independently of PLOT, so turning figures
        # off (e.g. for timing runs) does not silently drop the results.
        for log_name, value in ((filename_lr, psnr_LR),
                                (filename_hr, psnr_HR),
                                (filename_loss, total_loss.item()),
                                (filename_time, time_consuming)):
            with open(log_name, 'a') as f:
                f.write(str(value) + "\t")
        print(i, "runtime:", time_consuming)
        print('Iteration %05d    Loss %f   psnr_LR: %f   psnr_HR: %f' % (i, total_loss.item(), psnr_LR, psnr_HR), '\n')
    i += 1

    return total_loss

# Clean copy of the network input (the base that closure perturbs when
# reg_noise_std > 0), plus a same-shaped scratch buffer that closure refills
# with noise.normal_().
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()

psnr_history = []  # closure appends one [psnr_LR, psnr_HR] pair per call
i = 1              # iteration counter read and incremented by closure

p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

In [None]:
# Final reconstruction. Evaluate on the saved, un-perturbed input: the global
# `net_input` may hold the last noise-perturbed copy when reg_noise_std > 0.
# No gradients are needed here, so skip building an autograd graph.
with torch.no_grad():
    out_HR_np = np.clip(torch_to_np(net(net_input_saved)), 0, 1)
result_deep_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])

# For the paper we actually took `_bicubic.png` files from the LapSRN viewer and used `result_deep_prior` as our result
plot_image_grid([imgs['HR_np'],
                 imgs['bicubic_np'],
                 out_HR_np], factor=4, nrow=1);

In [None]:
# Final value of the global counter: it starts at 1 and closure increments it
# once per call, so this equals the number of closure calls + 1.
print(i)