# **Deblurring DIPWTV**

---

This code is mainly based on the DeepRED code available at https://github.com/GaryMataev/DeepRED, with the regularization term changed following the ADMM-DIPTV code available at https://github.com/sedaboni/ADMM-DIPTV.

# Import libs

In [None]:
import os
from threading import Thread  # for running the denoiser in parallel
import queue

import numpy as np
import torch
import torch.optim
from models.skip import skip  # our network

from utils.utils import *  # auxiliary functions
from utils.utils_mine import *
from utils.mine_blur_utils2 import *  
from utils.data import Data  # class that holds img, psnr, time

from skimage.restoration import denoise_nl_means
import random 

# Seed every random generator used in this notebook (Python, NumPy, PyTorch) for repeatable runs.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

from scipy.signal import convolve2d



In [None]:
# got GPU? - if you are not getting the exact article results set CUDNN to False
CUDA_FLAG = True
CUDNN = True 
if CUDA_FLAG:
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # GPU accelerated functionality for common operations in deep neural nets
    torch.backends.cudnn.enabled = CUDNN
    # benchmark mode is good whenever your input sizes for your network do not vary.
    # This way, cudnn will look for the optimal set of algorithms for that particular 
    # configuration (which takes some time). This usually leads to faster runtime.
    # But if your input sizes changes at each iteration, then cudnn will benchmark every
    # time a new size appears, possibly leading to worse runtime performances.
    torch.backends.cudnn.benchmark = CUDNN
    # torch.backends.cudnn.deterministic = True
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

# CONSTANTS

In [None]:
NOISE_SIGMA = 5 #2**.5  # sqrt(2), I haven't tests other options
STD_BLUR    = 1.6
DIM_FILTER  = 21
BLUR_TYPE = 'gauss_blur'  # 'gauss_blur' or 'uniform_blur' that the two only options
GRAY_SCALE = False  # if gray scale is False means we have rgb image, the psnr will be compared on Y. ch.
                    # if gray scale is True it will turn rgb to gray scale
USE_FOURIER = False

# graphs labels:
X_LABELS = ['Iterations']*3
Y_LABELS = ['PSNR between x and net (db)', 'PSNR with original image (db)', 'loss']

# Algorithm NAMES (to get the relevant image: use data_dict[alg_name].img)
# for example use data_dict['Clean'].img to get the clean image
ORIGINAL = 'Clean'
CORRUPTED = 'Blurred'
NLM = 'NLM'
DIP_NLM = 'DIP-WTV'

# Load image for Deblurring

In [None]:
def load_imgs_deblurring(fname, blur_type, noise_sigma,STD_BLUR, DIM_FILTER,plot=False):
    """Load an image, blur it and add white Gaussian noise.

    Args:
        fname: path to the image.
        blur_type: blur type forwarded to get_h ('gauss_blur' or 'uniform_blur').
        noise_sigma: std of the Gaussian noise added after blurring, on the 0-255 scale.
        STD_BLUR: forwarded to get_h (presumably the Gaussian blur std - confirm in get_h).
        DIM_FILTER: forwarded to get_h (presumably the kernel size - confirm in get_h).
        plot: if True, plot the clean and blurred images.
    Returns:
        (data_dict, kernel_torch). data_dict maps ORIGINAL to the clean image and
        CORRUPTED to the blurred, noisy image (clipped to [0, 1], float32) together
        with its PSNR against the clean image. kernel_torch is the blur kernel as
        a torch tensor.
    """
    img_pil, clean = load_and_crop_image(fname)
    if GRAY_SCALE:
        clean = rgb2gray(img_pil)
    kernel_torch = np_to_torch(get_h(blur_type, STD_BLUR, DIM_FILTER))
    observed = torch_to_np(blur_th(np_to_torch(clean), kernel_torch))
    gaussian_noise = np.random.normal(scale=noise_sigma / 255., size=observed.shape)
    observed = np.clip(observed + gaussian_noise, 0, 1).astype(np.float32)
    data_dict = {
        ORIGINAL: Data(clean),
        CORRUPTED: Data(observed, compare_PSNR(clean, observed, on_y=not GRAY_SCALE, gray_scale=GRAY_SCALE)),
    }
    if plot:
        plot_dict(data_dict)
    return data_dict, kernel_torch

In [None]:
# load the image and add noise - for real use send same image file to fclean and fnoisy and ignore psnrs
data_dict,kernel_torch = load_imgs_deblurring('datasets/watercastle.png', BLUR_TYPE, NOISE_SIGMA,STD_BLUR, DIM_FILTER, plot=True)

# ESTIMATING THE NOISE LEVEL

In [None]:
# Laplacian-difference mask of Immerkaer's fast noise estimator.
lap_kernel = np.array([[1, -2, 1], [-2, 4, -2], [1, -2, 1]])
# Images are stored channel-first (C, H, W): axis 1 is the height, axis 2 the width.
# (The previous code read them the other way round.)
h = data_dict[CORRUPTED].img.shape[1]
w = data_dict[CORRUPTED].img.shape[2]

def estimate_variance(img):
    """Estimate the standard deviation of additive Gaussian noise in a 2-D image.

    Immerkaer's fast estimator filters the image with the mask
    [[1, -2, 1], [-2, 4, -2], [1, -2, 1]] (which cancels locally linear content)
    and computes

        sigma = sqrt(pi / 2) * sum(|img * mask|) / (6 * (H - 2) * (W - 2))

    Despite its name the function returns sigma (a standard deviation), in the
    same units as ``img``. The normaliser is taken from the image itself, not
    from module-level globals, so it is correct for any image size.

    Args:
        img: 2-D array of shape (H, W), with H >= 3 and W >= 3.
    Returns:
        The estimated noise standard deviation (float).
    Raises:
        ValueError: if ``img`` is not 2-D or is smaller than 3x3.
    """
    shape = np.shape(img)
    if len(shape) != 2 or min(shape) < 3:
        raise ValueError("estimate_variance expects a 2-D image of at least 3x3, got shape %s" % (shape,))
    mask = np.array([[1, -2, 1], [-2, 4, -2], [1, -2, 1]])
    response = convolve2d(img, mask, mode='valid')
    # The 'valid' response has exactly (H-2)*(W-2) entries, so its size is the normaliser.
    return np.sum(np.abs(response)) * np.sqrt(0.5 * np.pi) / (6 * response.size)

print(data_dict[CORRUPTED].img.shape)
# Noise std estimated on the first channel only, rescaled from the [0, 1] range to 0-255.
NOISE_SIGMA = 255 * estimate_variance(data_dict[CORRUPTED].img[0])
print(NOISE_SIGMA)

# THE NETWORK

In [None]:
def get_network_and_input(img_shape, input_depth=32, pad='reflection',
                          upsample_mode='bilinear', use_interpolate=True, align_corners=False,
                          act_fun='LeakyReLU', skip_n33d=None, skip_n33u=None, skip_n11=4,
                          num_scales=5, downsample_mode='stride', INPUT='noise'):  # 'meshgrid'
    """Build the DIP ``skip`` network and its fixed input for an image of shape ``img_shape``.

    The defaults are the ones the original authors state were used in the DIP article.

    Args:
        img_shape: image shape (ch, x, y). ch sets the number of output channels and
            (x, y) sets the spatial size of the network input.
        input_depth: number of channels of the network input.
        skip_n33d, skip_n33u, skip_n11: channels of the down, up and skip branches.
            An int is repeated ``num_scales`` times; a list is used as given.
            skip_n33d and skip_n33u default to [16, 32, 64, 128, 128]
            (a fresh list per call).
        INPUT: kind of input passed to get_noise ('noise' or 'meshgrid').
        The remaining arguments are forwarded to skip(). need_sigmoid and need_bias
        are always True.
    Returns:
        (net, net_input), both cast to ``dtype``. net_input is detached.
    """
    # Build the defaults per call: a list literal default would be one object shared by every call.
    if skip_n33d is None:
        skip_n33d = [16, 32, 64, 128, 128]
    if skip_n33u is None:
        skip_n33u = [16, 32, 64, 128, 128]

    def _per_scale(channels):
        # An int means "the same channel count at every scale".
        return [channels] * num_scales if isinstance(channels, int) else channels

    n_channels = img_shape[0]
    net = skip(input_depth, n_channels,
               num_channels_down=_per_scale(skip_n33d),
               num_channels_up=_per_scale(skip_n33u),
               num_channels_skip=_per_scale(skip_n11),
               upsample_mode=upsample_mode, use_interpolate=use_interpolate, align_corners=align_corners,
               downsample_mode=downsample_mode, need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun).type(dtype)
    net_input = get_noise(input_depth, INPUT, img_shape[1:]).type(dtype).detach()
    return net, net_input

In [None]:
# Image size and the finite-difference operators of the TV term, moved to the
# frequency domain with psf2otf for that size.
size = data_dict[ORIGINAL].img.shape
h = size[-2]
w = size[-1]
Dh_psf = np.array([[0, 0, 0], [1, -1, 0], [0, 0, 0]])  # horizontal difference
Dv_psf = np.array([[0, 1, 0], [0, -1, 0], [0, 0, 0]])  # vertical difference
Id_psf = np.array([[1]])


def _otf_tensor(psf):
    """Return psf2otf(psf, [h, w]) as a torch tensor.

    The tensor is moved to the GPU only when CUDA_FLAG is set. The previous code
    always called .cuda(), which crashed CPU-only runs.
    """
    otf = torch.from_numpy(psf2otf(psf, [h, w]))
    return otf.cuda() if CUDA_FLAG else otf


Id_DFT = _otf_tensor(Id_psf)
Dh_DFT = _otf_tensor(Dh_psf)
Dv_DFT = _otf_tensor(Dv_psf)

# Transposed operators: the complex conjugate of the OTF.
DhT_DFT = torch.conj(Dh_DFT)
DvT_DFT = torch.conj(Dv_DFT)

# Deep Image Prior via ADMM with weighted TV

In [None]:
def _weighted_tv_shrink(q_h, q_v, weight_num, mu):
    """Weighted isotropic-TV proximal step (vector soft-threshold) used in step 2 of train_via_admm.

    With |q| = sqrt(q_h**2 + q_v**2) and the per-pixel weight weight_num / |q|, returns

        t = max(|q| - weight / mu, 0) / |q| * q,   and t = 0 where |q| == 0.

    The threshold is therefore smaller where the gradient is large (edges are kept).
    This replaces np.clip(x, 0, x), which is an identity for negative x, and avoids
    the inf - inf = NaN that the old code produced where |q| == 0.
    """
    q_norm = np.sqrt(q_h ** 2 + q_v ** 2)
    factor = np.zeros_like(q_norm)
    nonzero = q_norm > 0
    qn = q_norm[nonzero]
    factor[nonzero] = np.clip(qn - (weight_num / qn) / mu, 0, None) / qn
    return factor * q_h, factor * q_v


def train_via_admm(net, net_input, kernel_torch, y, noise_lev, tau, org_img=None,  # y is the noisy image
                   plot_array=None, algorithm_name="", admm_iter=5000, save_path="",  # save_path: unused
                   LR=0.001,                                                             # learning rate
                   mu=0.0008, LR_x=None, noise_factor=0.033,  # LR_x: unused, kept for API compatibility
                   threshold=40, threshold_step=0.01, increase_reg=0.033):              # increase regularization
    """Deblur ``y`` with a Deep Image Prior network and a weighted-TV term, solved via ADMM.

    The gradient of the network output is split off as t ~ D(net(z)), where D = (Dh, Dv)
    are the finite-difference operators. Each iteration:
      1. takes one Adam step on mse(blur(net(z)), y) + mu * (mse(Dh net(z), t_h - u_h) + mse(Dv net(z), t_v - u_v));
      2. updates t with the weighted-TV shrinkage of D(net(z)) + u, then clips it to [-1, 1];
      3. updates the dual variables u <- u + D(net(z)) - t;
    and it accumulates an exponential moving average of the network output.

    Args:
        net: the network to be trained.
        net_input: the network input. Each iteration adds a fresh N(0, 1) perturbation
            scaled by ``noise_factor``.
        kernel_torch: the blur kernel, applied with blur_th.
        y: the blurred, noisy image (numpy array).
        noise_lev: noise std on the image scale (NOISE_SIGMA / 255 in run_and_plot).
        tau: multiplier of rho = tau * noise_lev * sqrt(y.size - 1).
        org_img: the original image, used only to report PSNRs, or None.
        plot_array: iterations at which images are plotted (only when org_img is given).
        algorithm_name: name printed in the progress line.
        admm_iter: number of ADMM iterations.
        save_path: unused.
        LR: Adam learning rate.
        mu: ADMM penalty parameter. It is increased during training (see below).
        LR_x: unused.
        noise_factor: scale of the perturbation added to the network input.
        threshold: when the PSNR between net(z) and y exceeds it, mu grows by
            ``increase_reg`` and the threshold grows by ``threshold_step``.
        threshold_step: amount added to the threshold each time it is exceeded.
        increase_reg: amount added to mu each time the threshold is exceeded.
    Returns:
        (avg, list_psnr, list_stopping):
        - avg: the running average of the output.
        - list_psnr: PSNR of avg against org_img per iteration; empty if org_img is None.
        - list_stopping: ||blur(net(z)) - y|| / rho per iteration.
    """
    if plot_array is None:
        plot_array = set()
    # To print
    list_psnr = []
    list_stopping = []

    mse = torch.nn.MSELoss().type(dtype)  # using MSE loss
    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()  # buffer re-filled with N(0, 1) samples every iteration
    if org_img is not None:
        psnr_y = compare_PSNR(org_img, y, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)  # noisy image psnr
    optimizer = torch.optim.Adam(net.parameters(), lr=LR)

    y_torch = np_to_torch(y).type(dtype)
    mu_h, mu_v = np.zeros_like(y), np.zeros_like(y)  # scaled dual variables u
    t_h, t_v = np.zeros_like(y), np.zeros_like(y)    # split variables t ~ D(net(z))
    # The running average starts from the rounded noisy image; its weight decays as 0.99**i.
    avg = np.rint(y)

    # Loop invariants, taken from y itself rather than from global h/w.
    rho = tau * noise_lev * np.sqrt(y.size - 1)
    weight_den = 6 * y.shape[-2] * y.shape[-1]

    for i in range(1, 1 + admm_iter):
        # step 1, update the network
        optimizer.zero_grad()
        net_input = net_input_saved + (noise.normal_() * noise_factor)
        out = net(net_input)
        out_np = torch_to_np(out)

        Dh_out, Dv_out = D(out, Dh_DFT, Dv_DFT)  # discrete gradient of the output
        Dh_out_np = torch_to_np(Dh_out)
        Dv_out_np = torch_to_np(Dv_out)
        loss_y = mse(blur_th(out, kernel_torch), y_torch)
        loss_x = (mse(Dh_out.type(dtype), np_to_torch(t_h - mu_h).type(dtype))
                  + mse(Dv_out.type(dtype), np_to_torch(t_v - mu_v).type(dtype)))
        total_loss = loss_y + mu * loss_x
        total_loss.backward()
        optimizer.step()

        # step 2, update t with the weighted-TV proximal step
        weight_num = np.linalg.norm(out_np - y) ** 2 / weight_den
        t_h, t_v = _weighted_tv_shrink(Dh_out_np + mu_h, Dv_out_np + mu_v, weight_num, mu)
        np.clip(t_h, -1, 1, out=t_h)
        np.clip(t_v, -1, 1, out=t_v)

        # step 3, update the dual variables
        mu_h = mu_h + (Dh_out_np - t_h)
        mu_v = mu_v + (Dv_out_np - t_v)

        # Averaging:
        avg = avg * .99 + out_np * .01

        stopping = np.sqrt(np.sum(np.square(torch_to_np(blur_th(out.data, kernel_torch)) - y))) / rho
        list_stopping.append(stopping)

        # show psnrs:
        psnr_noisy = compare_PSNR(out_np, y, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)

        # Once the output gets too close to y, strengthen the coupling term.
        if psnr_noisy > threshold:
            mu = mu + increase_reg
            threshold += threshold_step

        if org_img is not None:
            psnr_net, psnr_avg = compare_PSNR(org_img, out_np, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE), compare_PSNR(org_img, avg, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)
            list_psnr.append(psnr_avg)
            print('\r', algorithm_name, '%04d/%04d Loss %f' % (i, admm_iter, total_loss.item()),
                  'psnrs: y: %.2f psnr_noisy: %.2f net: %.2f avg: %.2f' % (psnr_y, psnr_noisy, psnr_net, psnr_avg),
                  'params: stopping: %.2f' % (stopping), end='')
            if i in plot_array:  # plot images
                tmp_dict = {'Clean': Data(org_img),
                            'Noisy': Data(y, psnr_y),
                            'Net': Data(out_np, psnr_net),
                            'avg': Data(avg, psnr_avg),
                            }
                plot_dict(tmp_dict)
        else:
            print('\r', algorithm_name, 'iteration %04d/%04d Loss %f' % (i, admm_iter, total_loss.item()), end='')

    return avg, list_psnr, list_stopping

## Let's Go:

In [None]:
def run_and_plot(name, plot_checkpoints=None, tau=1):
    """Run DIP-WTV deblurring on data_dict[CORRUPTED], then store and plot the result.

    Args:
        name: key under which the result is stored in data_dict (also printed in the log).
        plot_checkpoints: iterations at which intermediate images are plotted.
            Defaults to none.
        tau: multiplier of the stopping threshold. Leave it at 1 if you trust the
            noise estimate in NOISE_SIGMA.
    Returns:
        (denoised_img, list_psnr, list_stopping) as returned by train_via_admm.
    """
    # Never share a mutable default between calls; always hand train_via_admm a container.
    if plot_checkpoints is None:
        plot_checkpoints = set()
    noise_lev = NOISE_SIGMA / 255  # noise std on the [0, 1] image scale
    net, net_input = get_network_and_input(img_shape=data_dict[CORRUPTED].img.shape)
    denoised_img, list_psnr, list_stopping = train_via_admm(
        net, net_input, kernel_torch, data_dict[CORRUPTED].img, noise_lev, tau,
        plot_array=plot_checkpoints, algorithm_name=name,
        org_img=data_dict[ORIGINAL].img)
    # data_dict is only mutated (item assignment), so no `global` statement is needed.
    data_dict[name] = Data(denoised_img, compare_PSNR(data_dict[ORIGINAL].img, denoised_img, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE))
    plot_dict(data_dict)

    return denoised_img, list_psnr, list_stopping


# Iterations at which intermediate images are plotted.
plot_checkpoints = {1, 10, 50, 100, 250, 500, 2000, 3500, 5000}
denoised_img, list_psnr, list_stopping = run_and_plot(name=DIP_NLM, plot_checkpoints=plot_checkpoints)