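Code for the **"Blind restoration of a JPEG-compressed image"** and **"Blind image denoising"** figures. Select `fname` below to switch between the two. Note that the image-loading cell always adds synthetic noise of level `sigma`; for JPEG restoration, where the input is already degraded, that step has to be disabled as well.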

- To see the network overfit the noise, set `num_iter` to a large value.

# Import libs

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import numpy as np
from models.skip2 import skip2
from models.skip3 import skip3
from models.residual import DilatedResidualNetwork
from models import get_net

import torch
import torch.optim

from skimage.measure import compare_psnr
from utils.denoising_utils import *

# torch.backends.cudnn.enabled = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
dtype = torch.cuda.FloatTensor
# dtype = torch.DoubleTensor

imsize =-1
PLOT = True
sigma = 25
sigma_ = sigma/255.

In [None]:
# Select the input image. Set the DIP_FNAME environment variable to override
# the default below without editing this cell (e.g. for headless runs).
#
## deJPEG
# NOTE: the image-loading cell always adds synthetic noise, so for JPEG
# restoration that step has to be disabled as well.
# fname = 'data/denoising/snail.jpg'

## denoising
# fname = '/home/semion/master_thesis/code/data/25/F16_GT.png'
# fname_noisy = '/home/semion/master_thesis/code/data/25/F16_noisy.png'
fname = os.environ.get('DIP_FNAME', '/home/semion/master_thesis/code/PGDIP/data/ref/house.png')

# Load image

In [None]:
# Load the clean reference image and build its noisy counterpart.
# crop_image(..., d=32) presumably crops both sides to a multiple of 32 so the
# network's down/upsampling lines up -- confirm in utils.
img_pil = crop_image(get_image(fname, imsize)[0], d=32)
img_np = pil_to_np(img_pil)

# Alternative: load a pre-noised image from disk instead of synthesising noise.
# img_noisy_pil = crop_image(get_image(fname_noisy, imsize)[0], d=32)
# img_noisy_np = pil_to_np(img_noisy_pil)

# Add synthetic noise of level `sigma_` (distribution defined by get_noisy_image
# -- assumed Gaussian, confirm). The result is unpacked into a PIL image and a
# numpy array, per the variable names.
img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)

if PLOT:
    plot_image_grid([img_np, img_noisy_np], 4, 6);


# Setup

In [None]:
INPUT = 'noise'  # 'meshgrid'
pad = 'reflection'
OPT_OVER = 'net'  # 'net,input'

# Std of the fresh noise `closure` adds to the network input at every step.
reg_noise_std = 1./30.  # set to 1./20. for sigma=50
LR = 0.01

OPTIMIZER = 'adam'  # 'LBFGS'
show_every = 50     # plotting period; also gates the backtracking check in `closure`

num_iter = 1800     # use a large value to watch the network overfit the noise
input_depth = 128   # channels of the random input code
figsize = 4
exp_weight = 0.99   # weight of the exponential moving average of the outputs

# Alternative: the library `skip` network.
# net = get_net(input_depth, 'skip', pad,
#               skip_n33d=128,
#               skip_n33u=128,
#               skip_n11=[0, 4, 4, 4, 4],
#               num_scales=5,
#               upsample_mode='bilinear').type(dtype)

NSIZE = 5
chdown = [128]*NSIZE   # passed as num_channels_up; e.g. [128, 64, 32, 16, 8]
chskip = [4]*NSIZE     # passed as num_channels_skip; e.g. [0]*NSIZE or [0, 0, 0, 0, 4]
# Multiplier applied to the random input code. It is also written into the
# saved PSNR plot's filename, so it must equal the scale actually used below.
# (It used to be 1e1 while the active input was multiplied by a literal 1,
# which mislabelled the plots; the downsampled-input variant below may want
# 1e1 again -- confirm before switching.)
z_scale = 1.

ARCH = "base"
# Alternative architectures:
# net = DilatedResidualNetwork(input_channels=3,
#                             down_channels=chdown,
#                             down_stride =[2]*NSIZE,
#                             down_dilation = [1]*NSIZE).type(dtype)
# net = skip3(input_depth, 3, num_channels_up = chdown, num_channels_skip=chskip, ...same kwargs as below...)
# net = skip2(input_depth, 3, num_channels_up = [512, 256, 128, 64, 32, 16, 8], num_channels_skip=[4]*NSIZE, ...)
# net = skip2(input_depth, 3, num_channels_up = [10, 20], num_channels_skip=[1, 2], ...)
net = skip2(input_depth, 3, num_channels_up=chdown, num_channels_skip=chskip,
            upsample_mode='bilinear', need_sigmoid=True, need_bias=True,
            pad=pad, act_fun="LeakyReLU").type(dtype)

# Fixed random input code at full image resolution (PIL's size is (width, height)).
# Downsampled variants:
# net_input = get_noise(input_depth, INPUT, (img_pil.size[1]/2**NSIZE, img_pil.size[0]/2**NSIZE)).type(dtype).detach() * z_scale
# net_input = get_noise(input_depth, INPUT, (img_pil.size[1]/2**2, img_pil.size[0]/2**2)).type(dtype).detach() * 10000
net_input = get_noise(input_depth, INPUT, (img_pil.size[1], img_pil.size[0])).type(dtype).detach() * z_scale

# Compute number of parameters
s = sum(p.numel() for p in net.parameters())
print ('Number of params: %d' % s)
# print(net)

# Loss
mse = torch.nn.MSELoss().type(dtype)

img_noisy_var = np_to_torch(img_noisy_np).type(dtype)
# net_input = img_noisy_var

# Optimize

In [None]:
# Untouched copy of the input code; `closure` perturbs it afresh every step.
net_input_saved = net_input.detach().clone()
# Same-shaped scratch buffer that `closure` refills in place via normal_().
noise = net_input.detach().clone()

# Training state shared with `closure` through `global`.
out_avg = last_net = None    # running output average / last weight snapshot
i, psrn_noisy_last = 0, 0    # iteration counter / PSNR vs. noisy image at last check
psnr = []                    # per-iteration PSNR against the clean image
def closure():
    """Run one training step of the network and log / plot progress.

    Handed to `optimize` below, which presumably calls it once per iteration
    and then steps the optimizer (confirm in utils). Reads and updates the
    notebook-level state declared `global` here.

    Returns:
        The MSE between the network output and the noisy image, or that loss
        multiplied by zero when the run falls back to the last checkpoint.
    """
    global i, out_avg, psrn_noisy_last, last_net, psnr, net_input

    if reg_noise_std > 0:
        # Perturb the saved input code with fresh Gaussian noise every step.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Smoothing: exponential moving average of the outputs.
    if exp_weight is not None:
        if out_avg is None:
            out_avg = out.detach()
        else:
            out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)
    # With smoothing disabled (exp_weight is None) out_avg stays None; use the
    # raw output for the "smoothed" metric and plot instead of crashing.
    out_smooth = out.detach() if out_avg is None else out_avg

    total_loss = mse(out, img_noisy_var)
    total_loss.backward()

    # Move each output to the host once and reuse it for every metric.
    out_host = out.detach().cpu().type(torch.FloatTensor).numpy()[0]
    out_smooth_host = out_smooth.detach().cpu().type(torch.FloatTensor).numpy()[0]
    psrn_noisy = compare_psnr(img_noisy_np, out_host)
    psrn_gt    = compare_psnr(img_np, out_host)
    psrn_gt_sm = compare_psnr(img_np, out_smooth_host)
    psnr.append(psrn_gt)
    print ('Iteration %05d    Loss %f   PSNR_noisy: %f   PSRN_gt: %f PSNR_gt_sm: %f' % (i, total_loss.item(), psrn_noisy, psrn_gt, psrn_gt_sm), '\r', end='')
    if PLOT and i % show_every == 0:
        plot_image_grid([np.clip(torch_to_np(out), 0, 1), np.clip(torch_to_np(out_smooth), 0, 1)], factor=figsize, nrow=1)
        # Draw the PSNR curve on its own figure rather than on whatever
        # figure happens to be current after plot_image_grid.
        plt.figure()
        plt.plot(psnr)
        plt.show()

    # Backtracking: on iterations that are not multiples of `show_every`,
    # compare the PSNR against the noisy image with the last checkpoint and
    # restore the checkpoint on a drop of more than 5 dB.
    if i % show_every:
        if last_net is not None and psrn_noisy - psrn_noisy_last < -5:
            print('Falling back to previous checkpoint.')

            for new_param, net_param in zip(last_net, net.parameters()):
                # copy_ moves the CPU snapshot to the parameter's own device,
                # so this works for both CUDA and CPU runs.
                net_param.data.copy_(new_param)
                # Discard the gradient computed at the rejected weights so the
                # optimizer step that follows does not apply it to the
                # restored ones.
                net_param.grad = None

            return total_loss*0
        else:
            # Explicit copy: for CPU parameters `.cpu()` returns the same
            # storage, so the snapshot would alias the live weights and a
            # fallback would restore nothing.
            last_net = [x.detach().to('cpu', copy=True) for x in net.parameters()]
            psrn_noisy_last = psrn_noisy

    i += 1

    return total_loss

p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

# Plot the full PSNR history on a fresh figure: `closure` plots partial PSNR
# curves during training, and one of those may still be the current figure,
# which would otherwise be drawn over and saved together with this curve.
plt.figure()
plt.plot(psnr)
plt.savefig("psnr_{}_{}_{}_{}_{}.png".format(ARCH, chdown, chskip, z_scale, LR))
# Peak memory allocated by tensors on the CUDA device, in MiB.
print(torch.cuda.max_memory_allocated()/1024**2)

In [None]:
# Final reconstruction. Gradients are not needed here, so skip building the
# autograd graph for this full-resolution forward pass (saves memory).
# NOTE: when reg_noise_std > 0, `net_input` is the last noise-perturbed input
# assigned inside `closure`, not `net_input_saved`.
with torch.no_grad():
    out_np = var_to_np(net(net_input))
q = plot_image_grid([np.clip(out_np, 0, 1), img_np], factor=13);

* Batch norm, cuDNN disabled: best PSNR 31.97 dB; PSNR after 1800 iterations: 31.15 dB
* Instance norm, cuDNN disabled: best PSNR 32.19 dB; PSNR after 1800 iterations: 32.19 dB


### 