# **Deblurring DIP**

---

This code is mainly based on the DeepRED code available at https://github.com/GaryMataev/DeepRED, with the regularization term removed.

# Import libs

In [None]:
import os
from threading import Thread  # needed since the denoiser is running in parallel
import queue

import numpy as np
import torch
import torch.optim
from models.skip import skip  # our network

from utils.utils import *  # auxiliary functions
from utils.mine_blur_utils2 import *  # blur functions
from utils.data import Data  # class that holds img, psnr, time

from skimage.restoration import denoise_nl_means

from scipy.signal import convolve2d

In [None]:
# got GPU? - if you are not getting the exact article results set CUDNN to False
CUDA_FLAG = True
CUDNN = True
if CUDA_FLAG:
    # Restrict this process to the first GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # GPU accelerated functionality for common operations in deep neural nets
    torch.backends.cudnn.enabled = CUDNN
    # benchmark mode is good whenever your input sizes for your network do not vary.
    # This way, cudnn will look for the optimal set of algorithms for that particular
    # configuration (which takes some time). This usually leads to faster runtime.
    # But if your input sizes changes at each iteration, then cudnn will benchmark every
    # time a new size appears, possibly leading to worse runtime performances.
    torch.backends.cudnn.benchmark = CUDNN
    # torch.backends.cudnn.deterministic = True

# `dtype` is the tensor type every later `.type(dtype)` call casts to. Only pick the
# CUDA type when a device is really usable; otherwise those casts would raise on a
# CPU-only machine even though the rest of the notebook can run on the CPU.
if CUDA_FLAG and torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
else:
    if CUDA_FLAG:
        print('CUDA_FLAG is set but no CUDA device is available - falling back to CPU tensors')
    dtype = torch.FloatTensor

# Constants

In [None]:
# --- degradation settings -------------------------------------------------
# Std of the additive Gaussian noise on a 0-255 intensity scale
# (load_imgs_deblurring divides it by 255 before adding it to the image).
NOISE_SIGMA = 10
# STD_BLUR and DIM_FILTER are forwarded to get_h(); presumably the Gaussian
# std and the kernel side length - confirm in utils.mine_blur_utils2.
STD_BLUR = 0.8
DIM_FILTER = 11
# Blur model: 'gauss_blur' or 'uniform_blur' are the only two options.
BLUR_TYPE = 'gauss_blur'
# True  -> the RGB image is converted to gray scale.
# False -> the image stays RGB and the PSNR is compared on the Y channel.
GRAY_SCALE = True
USE_FOURIER = False

# --- graph labels ---------------------------------------------------------
X_LABELS = ['Iterations', 'Iterations', 'Iterations']
Y_LABELS = [
    'PSNR between x and net (db)',
    'PSNR with original image (db)',
    'loss',
]

# --- algorithm names ------------------------------------------------------
# Keys of data_dict: use data_dict[alg_name].img to get the relevant image,
# e.g. data_dict['Clean'].img is the clean image.
ORIGINAL = 'Clean'
CORRUPTED = 'Blurred'
DIP_NLM = 'DIP'

# Load image for DeBlurring

In [None]:
def load_imgs_deblurring(fname, blur_type, noise_sigma,STD_BLUR, DIM_FILTER,plot=False):
    """Load an image and build its blurred, noisy counterpart.

    The image is blurred with the kernel returned by ``get_h``, Gaussian
    noise is added, and the result is clipped to [0, 1] and cast to float32.
    Gray-scale conversion is controlled by the module-level ``GRAY_SCALE``
    flag, not by an argument.

    Args:
         fname: path to the image
         blur_type: blur model name forwarded to ``get_h``
                    (the BLUR_TYPE constant lists 'gauss_blur' / 'uniform_blur')
         noise_sigma: std of the noise added after the blur, on a 0-255 scale
         STD_BLUR: forwarded to ``get_h`` (presumably the Gaussian blur std)
         DIM_FILTER: forwarded to ``get_h`` (presumably the kernel size)
         plot: when True the images are shown with ``plot_dict``
    Out:
         (data_dict, kernel_torch): data_dict maps ORIGINAL to the clean image
         and CORRUPTED to the degraded image with its PSNR; kernel_torch is
         the blur kernel converted with ``np_to_torch``
    """
    img_pil, img_np = load_and_crop_image(fname)
    if GRAY_SCALE:
        img_np = rgb2gray(img_pil)

    # blur kernel, kept as a torch tensor because the training loop reuses it
    kernel_torch = np_to_torch(get_h(blur_type, STD_BLUR, DIM_FILTER))

    # degrade: blur -> additive Gaussian noise -> clip to the valid range
    noiseless = torch_to_np(blur_th(np_to_torch(img_np), kernel_torch))
    gaussian_noise = np.random.normal(scale=noise_sigma/255., size=noiseless.shape)
    blurred = np.clip(noiseless + gaussian_noise, 0, 1).astype(np.float32)

    original_entry = Data(img_np)
    corrupted_entry = Data(
        blurred,
        compare_PSNR(img_np, blurred, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE),
    )
    data_dict = {ORIGINAL: original_entry, CORRUPTED: corrupted_entry}

    if plot:
        plot_dict(data_dict)
    return data_dict, kernel_torch

In [None]:
# Load the clean image and build its blurred + noisy version; also keep the blur kernel
# (as returned by load_imgs_deblurring) for the training loop.
data_dict,kernel_torch = load_imgs_deblurring('datasets/skyscraper.jpeg', BLUR_TYPE, NOISE_SIGMA,STD_BLUR, DIM_FILTER, plot=True)

In [None]:
# Notebook cell output: shape of the blurred image stored under the 'Blurred' key
data_dict['Blurred'].img.shape

#  ESTIMATING THE NOISE

In [None]:
# 3x3 mask convolved with the image by estimate_variance(); its entries sum to
# zero, so constant regions produce no response.
lap_kernel = np.array([
    [1, -2, 1],
    [-2, 4, -2],
    [1, -2, 1],
])
# NOTE(review): assuming the image is laid out as (ch, rows, cols), shape[2] is
# the width and shape[1] the height, so these two names look swapped. That is
# harmless wherever only the product (h-2)*(w-2) is needed - confirm before
# using them individually.
h, w = data_dict[CORRUPTED].img.shape[2], data_dict[CORRUPTED].img.shape[1]

def estimate_variance(img):
    """Estimate the noise level of a single-channel image.

    The image is convolved with the module-level ``lap_kernel`` and the mean
    absolute response is scaled by ``sqrt(pi/2) / 6``. Despite the function
    name, the caller stores the result (times 255) as NOISE_SIGMA, i.e. it is
    used as a standard deviation. The formula has the form of Immerkaer's
    fast noise estimator.

    The normalisation uses the size of ``img`` itself, so the function is
    correct for any 2-D input and no longer depends on the globals ``h``/``w``.

    Args:
        img: 2-D array with at least 3 rows and 3 columns.
    Returns:
        The estimated noise level, in the same intensity units as ``img``.
    Raises:
        ValueError: if ``img`` is not 2-D or is smaller than the 3x3 kernel.
    """
    img = np.asarray(img)
    if img.ndim != 2 or min(img.shape) < 3:
        raise ValueError('estimate_variance expects a 2-D image of at least 3x3 pixels')
    response = convolve2d(img, lap_kernel, mode='valid')
    # 'valid' mode yields (rows-2) x (cols-2) samples, which is exactly the
    # count the estimator has to average over.
    return np.sum(np.abs(response)) * np.sqrt(0.5 * np.pi) / (6 * response.size)

# Show the shape of the corrupted image.
print(data_dict[CORRUPTED].img[:,:,:].shape)
# Overwrite the configured NOISE_SIGMA with the value estimated from index 0 of the
# first axis of the corrupted image, rescaled by 255. The corrupted image itself was
# already generated with the original NOISE_SIGMA, so only later cells see the estimate.
NOISE_SIGMA = estimate_variance(data_dict[CORRUPTED].img[0,:,:])*255
print(NOISE_SIGMA)

# THE NETWORK

In [None]:
def get_network_and_input(img_shape, input_depth=32, pad='reflection',
                          upsample_mode='bilinear', use_interpolate=True, align_corners=False,
                          act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4,
                          num_scales=5, downsample_mode='stride', INPUT='noise'):  # 'meshgrid'
    """ Build the skip network and its input tensor for an image of the given shape.
        The defaults are the same params as in the DIP article.
        img_shape - the image shape (ch, x, y); img_shape[0] sets the number of output
                    channels and img_shape[1:] the spatial size passed to get_noise
        skip_n33d / skip_n33u / skip_n11 - an int is repeated num_scales times,
                    anything else is passed to the network unchanged
        returns (net, net_input), both cast to the module-level dtype;
                    net_input is detached
    """
    def per_scale(channels):
        # one channel count per scale, unless the caller already supplied a sequence
        return [channels] * num_scales if isinstance(channels, int) else channels

    architecture = dict(
        num_channels_down=per_scale(skip_n33d),
        num_channels_up=per_scale(skip_n33u),
        num_channels_skip=per_scale(skip_n11),
        upsample_mode=upsample_mode,
        use_interpolate=use_interpolate,
        align_corners=align_corners,
        downsample_mode=downsample_mode,
        need_sigmoid=True,
        need_bias=True,
        pad=pad,
        act_fun=act_fun,
    )
    net = skip(input_depth, img_shape[0], **architecture).type(dtype)

    spatial_size = img_shape[1:]
    net_input = get_noise(input_depth, INPUT, spatial_size).type(dtype).detach()
    return net, net_input

# Deep Image Prior

In [None]:
def train_via_admm(net, net_input, kernel_torch, y, tau, noise_lev,
                   clean_img=None, plot_array=None, algorithm_name="",
                   save_path="", admm_iter=1400, LR=0.004, noise_factor=0.01):

    """ Fit the network so that its blurred output matches the blurred image y.
        Each iteration perturbs the network input with fresh noise, takes one Adam step
        on the MSE between blur(net(input)) and y, and updates an exponential running
        average of the network outputs (weights .99 / .01).

        ## Must Params ##
        net                 - the network to be trained
        net_input           - the network input
        kernel_torch        - the blur kernel (torch tensor) passed to blur_th
        y                   - the blurred image (numpy array indexed as a 3-D array)
        tau                 - scale factor of the stopping-ratio denominator rho
        noise_lev           - noise level used in rho (the caller passes NOISE_SIGMA/255)

        # optional params #
        clean_img           - the original image if exist for psnr compare only, or None (default)
        plot_array          - iteration indices at which the images are plotted
                              (only used when clean_img is given); None means never
        algorithm_name      - the name that would show up while running, just to know what we are running ;)
        save_path           - unused, kept so existing calls keep working
        admm_iter           - total number of iterations
        LR                  - the lr of the network
        noise_factor        - the amount of noise added to the input of the network

        # returns #
        avg                 - running average of the network outputs
        list_psnr           - PSNR of avg w.r.t. clean_img per iteration (empty if clean_img is None)
        list_stopping       - per iteration ||blur(out) - y||_2 / rho, with
                              rho = tau * noise_lev * sqrt(N - 1) and N the number of entries of y;
                              it is only recorded, the loop never stops early
    """
    # a mutable default ({}) would be shared between calls; an empty tuple gives
    # the same "plot nothing" behavior
    if plot_array is None:
        plot_array = ()

    # get optimizer and loss function:
    mse = torch.nn.MSELoss().type(dtype)  # using MSE loss

    # additional noise added to the input:
    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()

    # initialize:
    optimizer = torch.optim.Adam(net.parameters(), lr=LR)  # using ADAM opt
    y_torch = np_to_torch(y).type(dtype)
    avg = np.rint(y)
    list_psnr = []
    list_stopping = []

    # rho depends only on the arguments, so compute it once instead of every iteration
    rho = tau * noise_lev * np.sqrt(y.shape[0] * y.shape[1] * y.shape[2] - 1)

    for i in range(1, 1 + admm_iter):

        # step 1, update network:
        optimizer.zero_grad()
        net_input = net_input_saved + (noise.normal_() * noise_factor)
        out = net(net_input)
        out_np = torch_to_np(out)
        # loss:
        blurred_out = blur_th(out, kernel_torch)
        total_loss = mse(blurred_out, y_torch)
        total_loss.backward()
        optimizer.step()

        # Averaging:
        avg = avg * .99 + out_np * .01

        # stopping ratio: reuse the blurred output of the forward pass (it was computed
        # from the same `out`, so a second blur_th call would give the same values)
        residual = torch_to_np(blurred_out.detach()) - y
        stopping = np.sqrt(np.sum(np.square(residual))) / rho
        list_stopping.append(stopping)

        # show psnrs:
        if clean_img is not None:
            psnr_net = compare_PSNR(clean_img, out_np, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)
            psnr_avg = compare_PSNR(clean_img, avg, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)
            list_psnr.append(psnr_avg)
            print('\r', algorithm_name, '%04d/%04d Loss %f' % (i, admm_iter, total_loss.item()),
                  'psnrs: net: %.2f avg: %.2f' % (psnr_net, psnr_avg), 'stopping: %.2f' %(stopping), end='')
            if i in plot_array:  # plot images
                tmp_dict = {'Clean': Data(clean_img),
                            'Blurred': Data(y),
                            'Net': Data(out_np, psnr_net),
                            'avg': Data(avg, psnr_avg)
                            }
                plot_dict(tmp_dict)
        else:
            print('\r', algorithm_name, 'iteration %04d/%04d Loss %f' % (i, admm_iter, total_loss.item()), end='')

    return avg, list_psnr, list_stopping

## Let's Go:

In [10]:
def run_and_plot(name, plot_checkpoints=None):
    """Train a fresh network on the corrupted image, store the result and plot everything.

    Reads the module-level data_dict, kernel_torch and NOISE_SIGMA. The running-average
    image returned by train_via_admm is stored in data_dict[name] together with its
    PSNR against the clean image, then all entries of data_dict are plotted.

    Args:
        name: key under which the result is stored; also shown in the progress line
        plot_checkpoints: iteration indices at which intermediate images are plotted;
                          None (default) plots none
    Returns:
        (clean, list_psnr, list_stopping) exactly as returned by train_via_admm
    """
    # Avoid a shared mutable default; an empty tuple still supports `i in plot_array`
    # in the training loop.
    if plot_checkpoints is None:
        plot_checkpoints = ()

    # No `global` statement is needed: data_dict is only mutated, never rebound.
    corrupted_img = data_dict[CORRUPTED].img
    original_img = data_dict[ORIGINAL].img

    noise_lev = NOISE_SIGMA/255  # NOISE_SIGMA is on a 0-255 scale, images are in [0, 1]
    tau = 1
    net, net_input = get_network_and_input(img_shape=corrupted_img.shape)
    clean, list_psnr, list_stopping = train_via_admm(net, net_input, kernel_torch, corrupted_img, tau, noise_lev,
                                                     algorithm_name=name, plot_array=plot_checkpoints,
                                                     clean_img=original_img)
    data_dict[name] = Data(clean, compare_PSNR(original_img, clean, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE))
    plot_dict(data_dict)

    return clean, list_psnr, list_stopping


# Iterations at which intermediate images are plotted (the training loop tests `i in plot_array`).
# run_and_plot does not override admm_iter, so the train_via_admm default of 1400 applies and
# the larger entries are never reached.
plot_checkpoints = {1, 10, 100, 200, 400, 600,800, 1000, 1200,1400,1600,1800,2000,2200,2400,2600,2800,3000, 5000, 10000, 20000}
clean,list_psnr,list_stopping = run_and_plot(DIP_NLM, plot_checkpoints)