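Code for the **"Blind restoration of a JPEG-compressed image"** and **"Blind image denoising"** figures. Select `fname` below to switch between the two. In this version the image is denoised patch by patch (see `patch_size` and `patch_stride`), and the overlapping patch outputs are then blended back into a single image.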

- To observe overfitting, set `num_iter` (the number of optimization steps per patch) to a large value.

# Import libraries

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import datetime
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import numpy as np
from models import *
from models.skip_deep import skip_deep

import torch
import torch.optim

from skimage.measure import compare_psnr
from utils.denoising_utils import *

# Global numeric settings ------------------------------------------------------
# cuDNN stays on; benchmark mode lets it autotune convolution algorithms.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Every network and tensor below is cast with .type(dtype), i.e. placed on a CUDA device.
dtype = torch.cuda.FloatTensor

imsize = -1  # forwarded to get_image(); presumably -1 keeps the native size (confirm)
PLOT = True  # show image grids while loading and optimising

# Synthetic noise level: sigma is on the 0-255 scale, sigma_ is the same value
# divided by 255 (presumably for images stored in [0, 1]).
sigma = 25
sigma_ = sigma / 255.

In [None]:
# Input image selection: uncomment one candidate below, or set the DIP_FNAME
# environment variable to override the default without editing the notebook.

## deJPEG
# fname = 'data/denoising/snail.jpg'

## denoising
# fname = '/home/semion/master_thesis/code/data/25/F16_GT.png'
# fname = '/home/semion/master_thesis/code/data/25/kodim02_GT.png'
# fname = '/home/semion/master_thesis/code/PGDIP/data/denoising/IM   (path truncated in the original notes)
# fname = './grid_hex.png'
# Pre-noised F16 image, only referenced from commented-out loading code:
# fname_noisy = '/home/semion/master_thesis/code/data/25/F16_noisy.png'
fname = os.environ.get('DIP_FNAME', '/home/semion/master_thesis/code/PGDIP/data/ref/house.png')

# Load image

In [None]:
# Add synthetic noise
# Load the clean reference image, cropped via crop_image(..., d=32), and build a
# noisy copy at level sigma_. Both are kept as PIL images and as arrays.
orig_img_pil = crop_image(get_image(fname, imsize)[0], d=32)
orig_img_np = pil_to_np(orig_img_pil)
(orig_img_noisy_pil,
 orig_img_noisy_np) = get_noisy_image(orig_img_np, sigma_)

# Alternative: use a pre-noised file instead of synthetic noise.
# img_noisy_pil = crop_image(get_image(fname_noisy, imsize)[0], d=32)
# img_noisy_np = pil_to_np(img_noisy_pil)

if PLOT:
    plot_image_grid([orig_img_np, orig_img_noisy_np], 4, 6)


# Setup

In [None]:
INPUT = 'noise'  # network input type passed to get_noise(); alternative: 'meshgrid'
pad = 'reflection'  # padding mode forwarded to skip_deep()
OPT_OVER = 'net'  # what get_params() optimises; alternative: 'net,input'

# Std of the Gaussian jitter added to the network input at every step.
# Alternative: 1./300.; set to 1./20. for sigma=50.
reg_noise_std = 1./30.
LR = 0.01  # learning rate passed to optimize()

OPTIMIZER = 'adam'  # alternative: 'LBFGS'
show_every = 100  # plotting interval, in iterations

num_iter = 4000  # optimisation steps per patch; large values expose overfitting
input_depth = 32  # number of channels of the random network input
figsize = 12  # passed to plot_image_grid() as `factor`
exp_weight = 0.9  # weight of the exponential moving average of the output
scales = 5

# Overlapping patch grid: patch_size x patch_size tiles whose top-left corners
# are patch_stride pixels apart along each spatial axis.
patch_size = 128
patch_stride = 16

# orig_img_np is indexed as (C, H, W) throughout this notebook.
_img_h, _img_w = orig_img_np.shape[1], orig_img_np.shape[2]

# Explicit errors rather than asserts (asserts vanish under `python -O`).
# An image smaller than one patch would otherwise yield a non-positive patch
# count, i.e. no patches at all and an all-NaN stitched result.
if patch_stride <= 0:
    raise ValueError('patch_stride must be positive, got %d' % patch_stride)
for _axis_name, _axis_len in (('height', _img_h), ('width', _img_w)):
    if _axis_len < patch_size or (_axis_len - patch_size) % patch_stride != 0:
        raise ValueError(
            'image %s %d does not fit the patch grid: need %s >= patch_size (%d) '
            'and (%s - patch_size) divisible by patch_stride (%d)'
            % (_axis_name, _axis_len, _axis_name, patch_size, _axis_name, patch_stride))

# Number of patch positions along rows (x) and columns (y).
patch_x_n = (_img_h - patch_size) // patch_stride + 1
patch_y_n = (_img_w - patch_size) // patch_stride + 1

# Denoise every patch independently with a freshly initialised network and keep
# the exponentially smoothed output (out_avg) of each one in patches[x][y].
# Presumably each entry is (1, 3, patch_size, patch_size); the stitching cell
# assumes 3 channels (skip_deep is built with 3 outputs) -- confirm for other images.

# Loss is the same for every patch, so build it once.
mse = torch.nn.MSELoss().type(dtype)

patches = []
for x in range(patch_x_n):
    patch_row = []
    for y in range(patch_y_n):
        start_time = datetime.datetime.now()

        # Bounds of patch (x, y): rows are axis 1, columns axis 2 of the (C, H, W) arrays.
        x_start = x * patch_stride
        x_end = x_start + patch_size
        y_start = y * patch_stride
        y_end = y_start + patch_size

        img_noisy_np = orig_img_noisy_np[:, x_start:x_end, y_start:y_end]
        img_np = orig_img_np[:, x_start:x_end, y_start:y_end]

        depth = 32  # length of the per-level channel lists passed to skip_deep
        net = skip_deep(input_depth, 3, num_channels_down=[8] * depth,
                        num_channels_up=[8] * depth,
                        num_channels_skip=[4] * depth,
                        filter_size_down=3,
                        filter_size_up=3,
                        upsample_mode='bilinear', downsample_mode='stride',
                        need_sigmoid=True, need_bias=True, pad=pad,
                        act_fun="LeakyReLU").type(dtype)

        net_input = get_noise(input_depth, INPUT, (img_np.shape[1], img_np.shape[2])).type(dtype).detach()

        s = sum(np.prod(list(p.size())) for p in net.parameters())
        print('Number of params: %d' % s)

        img_noisy_var = np_to_torch(img_noisy_np).type(dtype)

        # Per-patch optimisation state (read/written by closure() via `global`,
        # because this loop runs at notebook top level).
        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()
        out_avg = None
        last_net = None
        psrn_noisy_last = 0

        i = 0
        psnr = []

        def closure():
            """Run one optimisation step on the current patch and return the loss."""
            # net_input must be declared global: the assignment below would
            # otherwise make it local and raise UnboundLocalError whenever
            # reg_noise_std == 0.
            global i, out_avg, psrn_noisy_last, last_net, psnr, net_input

            if reg_noise_std > 0:
                # Input regularisation: perturb the fixed input at every step.
                net_input = net_input_saved + (noise.normal_() * reg_noise_std)

            out = net(net_input)

            # Exponential moving average of the network output.
            if exp_weight is not None:
                if out_avg is None:
                    out_avg = out.detach()
                else:
                    out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

            total_loss = mse(out, img_noisy_var)
            total_loss.backward()

            # Copy the outputs to the host once and reuse them for every PSNR.
            out_host = out.detach().cpu().type(torch.FloatTensor).numpy()[0]
            out_avg_host = out_avg.detach().cpu().type(torch.FloatTensor).numpy()[0]
            psrn_noisy = compare_psnr(img_noisy_np, out_host)
            psrn_gt = compare_psnr(img_np, out_host)
            psrn_gt_sm = compare_psnr(img_np, out_avg_host)
            psnr.append(psrn_gt)
            print('Iteration %05d    Loss %f   PSNR_noisy: %f   PSRN_gt: %f PSNR_gt_sm: %f' % (i, total_loss.item(), psrn_noisy, psrn_gt, psrn_gt_sm), '\r', end='')
            if PLOT and ((i % show_every == 0) or (i == num_iter - 1)):
                out_np = torch_to_np(out)
                plot_image_grid([
                    np.clip(out_np, 0, 1), np.clip(torch_to_np(out_avg), 0, 1),
                    np.clip(img_noisy_np, 0, 1), np.clip(img_np, 0, 1)
                ], factor=figsize, nrow=4, interpolation=None)
                plt.plot(psnr)

            # Backtracking (skipped when i is a multiple of show_every): if the PSNR
            # against the noisy target dropped by more than 5 since the last
            # accepted step, restore the weights saved at that step and return a
            # zeroed loss. Requires a saved checkpoint to exist.
            if i % show_every:
                if last_net is not None and psrn_noisy - psrn_noisy_last < -5:
                    print('Falling back to previous checkpoint.')

                    for new_param, net_param in zip(last_net, net.parameters()):
                        net_param.data.copy_(new_param.cuda())

                    return total_loss * 0
                else:
                    last_net = [param.data.cpu() for param in net.parameters()]
                    psrn_noisy_last = psrn_noisy

            i += 1

            return total_loss

        p = get_params(OPT_OVER, net, net_input)
        optimize(OPTIMIZER, p, closure, LR, num_iter)
        print("finished, time: {}".format(datetime.datetime.now() - start_time))
        # The smoothed output is this patch's result.
        patch_row.append(out_avg)
    patches.append(patch_row)



In [None]:
# Blend the per-patch results into one image. Each patch is weighted by a mask
# that ramps linearly from 0 to 1 over its first patch_stride rows/columns and
# from 1 to 0 over its last ones; sides lying on the image border keep weight 1.
# Dividing the weighted sum by the accumulated weights yields the final image.
img_h, img_w = orig_img_np.shape[1], orig_img_np.shape[2]
out = torch.zeros(1, 3, img_h, img_w).type(dtype)
out_counter = torch.zeros(1, 3, img_h, img_w).type(dtype)

fade_1d = torch.linspace(0.0, 1.0, patch_stride).type(dtype)
rev_fade_1d = torch.linspace(1.0, 0.0, patch_stride).type(dtype)
# Row ramps, shape (1, 3, patch_stride, patch_size): unsqueeze(1) turns the 1-D
# ramp into a column so that it varies along the row axis after expand().
x_fade_2d = fade_1d.unsqueeze(1).expand(1, 3, patch_stride, patch_size)
x_rev_fade_2d = rev_fade_1d.unsqueeze(1).expand(1, 3, patch_stride, patch_size)
# Column ramps, shape (1, 3, patch_size, patch_stride).
y_fade_2d = fade_1d.expand(1, 3, patch_size, patch_stride)
y_rev_fade_2d = rev_fade_1d.expand(1, 3, patch_size, patch_stride)

for x in range(patch_x_n):
    for y in range(patch_y_n):
        x_start = x * patch_stride
        x_end = x_start + patch_size
        y_start = y * patch_stride
        y_end = y_start + patch_size

        patch = patches[x][y]
        fade_matrix = torch.ones_like(patch)
        if x != 0:
            fade_matrix[:, :, :patch_stride, :] *= x_fade_2d
        if x != patch_x_n - 1:
            fade_matrix[:, :, -patch_stride:, :] *= x_rev_fade_2d
        if y != 0:
            fade_matrix[:, :, :, :patch_stride] *= y_fade_2d
        if y != patch_y_n - 1:
            fade_matrix[:, :, :, -patch_stride:] *= y_rev_fade_2d

        out[:, :, x_start:x_end, y_start:y_end] += fade_matrix * patch
        out_counter[:, :, x_start:x_end, y_start:y_end] += fade_matrix

# A pixel with zero total weight (e.g. patch_stride == patch_size, or no patches
# at all) would become NaN in the division below; fail loudly instead.
if bool((out_counter == 0).any()):
    raise RuntimeError('some pixels received zero blending weight; '
                       'check patch_size / patch_stride')
out /= out_counter

print(torch.cuda.max_memory_allocated()/1024**2)  # peak CUDA memory in MiB
q = plot_image_grid([np.clip(torch_to_np(out), 0, 1)], factor=8)
# Last expression: the notebook displays the PSNR of the stitched result.
compare_psnr(orig_img_np, torch_to_np(out))

* batch norm, nocudnn: max 31.97, 1800: 31.15
* instance norm, nocudnn: max 32.19, 1800: 32.19


### Inspect the last processed patch

In [None]:
# Visual check of the LAST patch processed by the training loop: img_np,
# img_noisy_np and out_avg still refer to its final (x, y) iteration.
# Grid: clean, noisy, noise residual / smoothed output, output error, output-vs-noisy residual.
# Residuals are signed, so they are shown shifted by +0.5 (mid-grey = zero)
# instead of clipped at 0, which would hide every negative value.
residual_gain = 1  # raise to amplify faint residuals
out_avg_np = torch_to_np(out_avg)
q = plot_image_grid([
    img_np,
    img_noisy_np,
    np.clip((img_noisy_np - img_np) * residual_gain + 0.5, 0, 1),
    np.clip(out_avg_np, 0, 1),
    np.clip((img_np - out_avg_np) * residual_gain + 0.5, 0, 1),
    np.clip((img_noisy_np - out_avg_np) * residual_gain + 0.5, 0, 1),
    ], factor=13, nrow=3)

In [None]:
# Scratch check of the fade-ramp broadcasting used when stitching patches:
#   fade_1d.expand(1, 3, rows, patch_stride)             -> ramp along the last axis
#   fade_1d.unsqueeze(1).expand(1, 3, patch_stride, cols) -> ramp along the rows
# expand() can only grow new leading dims or size-1 dims, so the 1-D ramp must
# become a column (unsqueeze(1)) before being expanded to (..., patch_stride, 4).
fade_1d = torch.linspace(0.0, 1.0, patch_stride)
rev_fade_1d = torch.linspace(1.0, 0.0, patch_stride)
print(fade_1d.shape)
x_fade_2d = fade_1d.unsqueeze(1).expand(1, 3, patch_stride, 4)
print(x_fade_2d.shape)
# x_rev_fade_2d = rev_fade_1d.unsqueeze(1).expand(1, 3, patch_stride, patch_size)