# **Denoising constrained DRED PGDA**

---

This code is mainly based on DeepRED code available at https://github.com/GaryMataev/DeepRED

This notebook is the implementation of the PGDA version of the following paper: 

**Constrained and unconstrained deep image prior optimization models with automatic regularization** by *Pasquale Cascarano, Giorgia Franchini, Erich Kobler, Federica Porta and Andrea Sebastiani*

# Import libs

In [None]:
import os
from threading import Thread  # for running the denoiser in parallel
import queue

import numpy as np
import torch
import torch.optim
from models.skip import skip  # our network

from utils.utils import *  # auxiliary functions
from utils.data import Data  # class that holds img, psnr, time

from skimage.restoration import denoise_nl_means

from scipy.signal import convolve2d

In [None]:
# Device configuration.
# Got a GPU? If the article results are not reproduced exactly, set CUDNN to False.
CUDA_FLAG = True
CUDNN = True
if not CUDA_FLAG:
    # CPU run: plain float tensors
    dtype = torch.FloatTensor
else:
    # expose only the first GPU to this process
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # cuDNN: GPU accelerated primitives for common deep-network operations
    torch.backends.cudnn.enabled = CUDNN
    # Benchmark mode pays off when the network input size never changes: cuDNN
    # searches once (which takes some time) for the best algorithms for that
    # configuration and then runs faster. With input sizes that vary between
    # iterations it would re-benchmark for every new size and may end up slower.
    torch.backends.cudnn.benchmark = CUDNN
    # torch.backends.cudnn.deterministic = True
    dtype = torch.cuda.FloatTensor

# CONSTANTS

In [None]:
# Std of the synthetic Gaussian noise, on a 0-255 intensity scale
# (load_image divides it by 255 before adding the noise).
SIGMA = 35
GRAY_SCALE = False        # False: the image is kept as RGB and PSNR is compared on the Y channel (on_y=True)
                          # True:  passed as gray_scale=True to compare_PSNR, which is expected to work on gray scale
# graphs labels:
X_LABELS = ['Iterations']*3
Y_LABELS = ['PSNR between x and net (db)', 'PSNR with original image (db)', 'loss']

# Algorithm NAMES - these strings are the keys of data_dict
# (to get the relevant image use data_dict[alg_name].img,
#  for example data_dict['Clean'].img returns the clean image)
ORIGINAL = 'Clean'
CORRUPTED = 'Noisy'
NLM = 'NLM'
DIP_NLM = 'CDRED (NLM)'

# Load image for Denoising

In [None]:
def load_image(fclean, fnoisy=None, sigma=25, plot=False):
    """ Load a clean image and pair it with a noisy version of it.
        fclean - file name of the clean image
        fnoisy - file name of an already noisy image; when None the noisy image
                 is synthesised by adding Gaussian noise to the clean one
        sigma  - std of the synthetic noise on a 0-255 scale (divided by 255
                 before use); ignored when fnoisy is given
        plot   - when True the resulting dictionary is passed to plot_dict
        Returns a dict with two Data entries: ORIGINAL (the clean image) and
        CORRUPTED (the noisy image together with its PSNR against the clean one)
    """
    _, clean_np = load_and_crop_image(fclean)
    if fnoisy is not None:
        _, noisy_np = load_and_crop_image(fnoisy)
    else:
        gaussian = np.random.normal(scale=sigma / 255., size=clean_np.shape)
        # keep the synthetic image inside [0, 1] and store it as float32
        noisy_np = np.clip(clean_np + gaussian, 0, 1).astype(np.float32)

    clean_entry = Data(clean_np)
    noisy_psnr = compare_PSNR(clean_np, noisy_np, on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)
    noisy_entry = Data(noisy_np, noisy_psnr)
    data_dict = {ORIGINAL: clean_entry, CORRUPTED: noisy_entry}

    if plot:
        plot_dict(data_dict)
    return data_dict

In [None]:
# Load the clean image and synthesise its noisy counterpart with std SIGMA (fnoisy is left as None).
# For real use (no clean reference) pass the same image file as both fclean and fnoisy
# and ignore the reported PSNRs.
data_dict = load_image('datasets/Set5/woman_GT.bmp', sigma=SIGMA, plot=True)

#  ESTIMATING THE NOISE

In [None]:
# 3x3 kernel applied to the noisy image for the noise estimate; its entries sum to zero
lap_kernel = np.array([[1,-2,1], [-2, 4, -2], [1,-2,1]])
# Module-level sizes of the last two axes of the noisy image.
# NOTE(review): if the layout is (ch, rows, cols), shape[2] is the width and shape[1] the
# height, i.e. the two names look swapped - confirm. Only the product (h-2)*(w-2) is used
# in this cell, so the swap would not change the estimate.
h=data_dict[CORRUPTED].img[:,:,:].shape[2]
w=data_dict[CORRUPTED].img[:,:,:].shape[1]

def estimate_variance(img, kernel=None):
    """ Estimate the noise level of a single-channel image.
        The image is convolved with a zero-sum 3x3 kernel ('valid' mode), the
        absolute responses are summed, and the sum is scaled by
        sqrt(pi/2) / (6 * (rows-2) * (cols-2)), where rows/cols are the sizes of
        `img` itself (not of any global image).
        NOTE: despite the name, the value is used by the notebook as a noise
        standard deviation (it is multiplied by 255 and stored as NOISE_SIGMA).
        NOTE(review): this looks like Immerkaer's fast noise estimator - confirm.

        img    - 2-D array (one channel), at least 3x3
        kernel - optional convolution kernel; defaults to the module-level
                 lap_kernel. The constant 6 in the normalisation is kept as is,
                 so it presumably only suits the default kernel.
        Returns the estimate as a numpy float, in the same intensity units as img.
        Raises ValueError if img is not 2-D or is smaller than 3x3.
    """
    if kernel is None:
        kernel = lap_kernel
    img = np.asarray(img)
    if img.ndim != 2 or min(img.shape) < 3:
        raise ValueError("estimate_variance expects a 2-D image of at least 3x3 pixels, "
                         "got shape %s" % (img.shape,))
    rows, cols = img.shape
    response = convolve2d(img, kernel, mode='valid')
    abs_sum = np.sum(np.abs(response))
    # (rows-2)*(cols-2) is the number of 'valid' responses of a 3x3 kernel
    return (abs_sum*np.sqrt(0.5*np.pi)/(6*(rows-2)*(cols-2)))

# Show the shape of the noisy image
print(data_dict[CORRUPTED].img[:,:,:].shape)
# Noise level estimated from the first channel only (index 0 on the first axis),
# rescaled by 255 so that it is on the same 0-255 scale as SIGMA
NOISE_SIGMA = estimate_variance(data_dict[CORRUPTED].img[0,:,:])*255
print(NOISE_SIGMA)

# THE NETWORK

In [None]:
def get_network_and_input(img_shape, input_depth=32, pad='reflection',
                          upsample_mode='bilinear', use_interpolate=True, align_corners=False,
                          act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4,
                          num_scales=5, downsample_mode='stride', INPUT='noise'):  # 'meshgrid'
    """ Build the skip network and its input tensor for an image of the given shape.
        The default parameters are the ones used in the DIP article.
        img_shape - the image shape (ch, x, y); ch sets the number of output
                    channels, (x, y) the spatial size of the network input
        skip_n33d / skip_n33u / skip_n11 - an int is repeated num_scales times,
                    any other value is handed to `skip` unchanged
        Returns (net, net_input), both converted to the module-level `dtype`;
        net_input is detached.
    """
    def per_scale(width):
        # one entry per scale when a single int is given, otherwise pass through
        if isinstance(width, int):
            return [width] * num_scales
        return width

    out_channels = img_shape[0]
    spatial_size = img_shape[1:]

    net = skip(input_depth, out_channels,
               num_channels_down=per_scale(skip_n33d),
               num_channels_up=per_scale(skip_n33u),
               num_channels_skip=per_scale(skip_n11),
               upsample_mode=upsample_mode, use_interpolate=use_interpolate, align_corners=align_corners,
               downsample_mode=downsample_mode, need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun)
    net = net.type(dtype)

    net_input = get_noise(input_depth, INPUT, spatial_size)
    net_input = net_input.type(dtype).detach()
    return net, net_input

## The Non Local Means denoiser

In [None]:
def non_local_means(noisy_np_img, sigma, fast_mode=True):
    """ Denoise an image channel by channel with Non-Local-Means.
        noisy_np_img - numpy noisy image, indexed as [channel, :, :]
        sigma        - noise std on a 0-255 scale (divided by 255 before use)
        fast_mode    - forwarded to denoise_nl_means; also selects the cut-off
                       distance (0.6*sigma when True, 0.8*sigma otherwise)
        Returns the denoised image as a float32 numpy array with the channels
        stacked along the first axis.
    """
    import inspect  # local import: only needed to probe the installed scikit-image API

    sigma = sigma / 255.
    cutoff = 0.6 * sigma if fast_mode else 0.8 * sigma
    patch_kw = dict(h=cutoff,              # Cut-off distance, a higher h results in a smoother image
                    sigma=sigma,           # sigma provided
                    fast_mode=fast_mode,   # If True, a fast version is used. If False, the original version is used.
                    patch_size=5,          # 5x5 patches (Size of patches used for denoising.)
                    patch_distance=6)      # 13x13 search area
    # Every channel is denoised as a plain 2-D image. Newer scikit-image releases
    # replaced the `multichannel` keyword with `channel_axis` (and later removed the
    # old one), so use whichever keyword the installed version accepts.
    if 'channel_axis' in inspect.signature(denoise_nl_means).parameters:
        patch_kw['channel_axis'] = None
    else:
        patch_kw['multichannel'] = False

    n_channels = noisy_np_img.shape[0]
    denoised_channels = [denoise_nl_means(noisy_np_img[c, :, :], **patch_kw)
                         for c in range(n_channels)]
    return np.array(denoised_channels, dtype=np.float32)

In [None]:
# Run Non-Local-Means on the noisy image, using the estimated noise level (0-255 scale)
denoised_img = non_local_means(data_dict[CORRUPTED].img, sigma=NOISE_SIGMA)
# Store the result (with its PSNR against the clean image) under the NLM key and plot everything so far
data_dict[NLM] = Data(denoised_img, compare_PSNR(data_dict[ORIGINAL].img, denoised_img,on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE))
plot_dict(data_dict)

# Constrained Deep Image prior via PGDA with RED

In [None]:
def train_via_admm(net, net_input, denoiser_function, y, noise_lev,tau, org_img=None,                      # y is the noisy image
                   plot_array={}, algorithm_name="", admm_iter=5000, save_path="",           # path to save params
                   LR=0.001,                                                                      # learning rate
                   sigma_f=3, update_iter=10, method='fixed_point',   # method: 'fixed_point' or 'grad' or 'mixed'
                   beta=1, mu=0.5,mu_r=1, LR_x=None, noise_factor=0.033,    #0.033 LR_x needed only if method!=fixed_point
                   threshold=40, threshold_step=0.01, increase_reg=0.03):                # increase regularization
    """ Train the network with the constrained RED scheme: the residual (net output - y)
        is tied to a variable r that is projected onto a ball of radius rho.

        ## Must Params ##
        net                 - the network to be trained
        net_input           - the network input
        denoiser_function   - an external denoiser function, used as black box, called as
                              denoiser_function(numpy_image, sigma_f); must return a numpy image
        y                   - the noisy image, a 3-dimensional numpy array
        noise_lev           - the noise level used in the constraint radius
                              (the notebook passes NOISE_SIGMA/255)
        tau                 - multiplier of the radius:
                              rho = tau * noise_lev * sqrt(number of entries of y - 1)

        # optional params #
        org_img             - the original image if exist for psnr compare only, or None (default)
        plot_array          - iteration indices at which images are plotted (only read, never modified;
                              used only when org_img is given)
        admm_iter           - total number of iterations
        save_path           - not used by this function
        LR                  - the lr of the network optimizer (Adam)
        sigma_f             - the sigma to send the denoiser function
        update_iter         - denoised image updated every 'update_iter' iteration
        method              - 'fixed_point' or 'grad' or 'mixed' ('mixed' switches from the
                              fixed point update to the gradient update at admm_iter // 2)
        algorithm_name      - the name that would show up while running, just to know what we are running ;)

        # equation params #
        beta                - regularization parameter (lambda in the article)
        mu                  - weight of the loss term that ties the net output to x - u
        mu_r                - weight of the loss term that ties the residual to r - u_r
        LR_x                - learning rate of the parameter x, required if method!=fixed point
        # more
        noise_factor       - the amount of noise added to the input of the network
        threshold          - when the image become close to the noisy image at this psnr
        increase_reg       - we going to increase regularization (mu and beta) by this amount
        threshold_step     - and keep increasing the threshold by this step

        Returns (avg, list_psnr, list_stopping):
        avg            - exponential running average of the network outputs
        list_psnr      - PSNR of avg against org_img per iteration (empty when org_img is None)
        list_stopping  - per iteration ||net output - y|| / rho

        Raises ValueError for an unknown method, or when LR_x is missing for 'grad'/'mixed'.
        An exception raised inside the denoiser is re-raised here at the next denoiser update.
    """
    # x update method - validated first, so that a bad configuration fails before any work is done:
    if method == 'fixed_point':
        swap_iter = admm_iter + 1
        LR_x = None
    elif method == 'grad':
        swap_iter = -1
    elif method == 'mixed':
        swap_iter = admm_iter // 2
    else:
        raise ValueError("method can be 'fixed_point' or 'grad' or 'mixed' only ")
    if method != 'fixed_point' and LR_x is None:
        raise ValueError("LR_x must be provided when method is 'grad' or 'mixed'")

    # get optimizer and loss function:
    mse = torch.nn.MSELoss().type(dtype)  # using MSE loss
    # additional noise added to the input:
    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()
    if org_img is not None:
        psnr_y = compare_PSNR(org_img, y,on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)  # get the noisy image psnr

    optimizer = torch.optim.Adam(net.parameters(), lr=LR)  # using ADAM opt

    y_torch = np_to_torch(y).type(dtype)

    # x: denoiser-side image, u / u_r: multipliers, r: constrained residual,
    # out_np: last network output, res: last residual (out_np - y)
    x, u, u_r = y.copy(), np.zeros_like(y), np.zeros_like(y)
    r, out_np, res = y.copy(), np.zeros_like(y), y.copy()
    f_x, avg = x.copy(), np.rint(y)
    list_psnr = []
    list_stopping = []

    # radius of the constraint ball; does not change during the iterations
    rho = tau*noise_lev*np.sqrt(y.shape[0]*y.shape[1]*y.shape[2] - 1)

    img_queue = queue.Queue()

    def _denoise_worker(image):
        # Runs in the background thread: always put something in the queue, so that
        # the main thread can never block forever on a denoiser failure.
        try:
            img_queue.put((True, denoiser_function(image, sigma_f)))
        except Exception as exc:  # handed over to the main thread, which re-raises it
            img_queue.put((False, exc))

    def _start_denoiser(image):
        # Launch one background denoising job on the given image.
        thread = Thread(target=_denoise_worker, args=(image,))
        thread.start()
        return thread

    denoiser_thread = _start_denoiser(x.copy())

    for i in range(1, 1 + admm_iter):

        # step 1, update x using a denoiser and the previous network output
        if i % update_iter == 0:  # the denoiser work in parallel
            denoiser_thread.join()
            succeeded, payload = img_queue.get()
            if not succeeded:
                raise payload
            f_x = payload
            denoiser_thread = _start_denoiser(x.copy())

        x_old = x

        if i < swap_iter:
            x = 1 / (beta + mu) * (beta * f_x + mu * (out_np + u))
        else:
            x = x - LR_x * (beta * (x - f_x) + mu * (x - out_np - u))
        np.clip(x, 0, 1, out=x)  # making sure that image is in bounds

        # step 2 projection of r onto the ball of radius rho
        r_old = r
        r = res + u_r
        r_norm = np.sqrt(np.sum(np.square(r)))
        r_norm_2 = r_norm  # norm after the projection; unchanged when r is already inside the ball
        if r_norm > rho:
            r = r*(rho/r_norm)
            r_norm_2 = np.sqrt(np.sum(np.square(r)))

        # step 3, update network:
        optimizer.zero_grad()
        net_input = net_input_saved + (noise.normal_() * noise_factor)
        out = net(net_input)

        # loss (built from the previous iterates x_old and r_old):
        loss_y = mse(out-y_torch, np_to_torch(r_old-u_r).type(dtype))
        loss_x = mse(out, np_to_torch(x_old - u).type(dtype))
        total_loss = mu_r*loss_y + mu * loss_x
        total_loss.backward()
        optimizer.step()

        # output of the updated network on the same input; no gradients are needed here
        with torch.no_grad():
            out = net(net_input)
        out_np = torch_to_np(out)

        res = out_np-y

        # step 4, update u and u_r
        u = u +  0.001*(out_np - x)
        u_r = u_r + 0.001*(res - r)

        # Averaging:
        avg = avg * .99 + out_np * .01

        # ratio between the residual norm and the radius
        # (it is only recorded - no early stopping is applied)
        stopping = np.sqrt(np.sum(np.square(out_np-y)))/ rho
        list_stopping.append(stopping)

        # show psnrs:
        psnr_noisy = compare_PSNR(out_np, y)
        if psnr_noisy > threshold:
            mu = mu + increase_reg
            beta = beta + increase_reg
            threshold += threshold_step
        if org_img is not None:
            psnr_net, psnr_avg = compare_PSNR(org_img, out_np,on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE), compare_PSNR(org_img, avg,on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)
            psnr_x, psnr_x_u = compare_PSNR(org_img, x,on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE), compare_PSNR(org_img, x - u,on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)
            list_psnr.append(psnr_avg)
            print('\r', algorithm_name, '%04d/%04d Loss %f' % (i, admm_iter, total_loss.item()),
                  'psnrs: y: %.2f psnr_noisy: %.2f net: %.2f' % (psnr_y, psnr_noisy, psnr_net),
                  'x: %.2f x-u: %.2f avg: %.2f' % (psnr_x, psnr_x_u, psnr_avg), 'params: rho: %.2f r_norm: %.2f  r_norm2: %.2f stopping: %.2f mu: %.2f' % (rho,r_norm,r_norm_2,stopping,mu), end='')
            if i in plot_array:  # plot images
                tmp_dict = {'Clean': Data(org_img),
                            'Net': Data(out_np, psnr_net),
                            'avg': Data(avg, psnr_avg),
                            'x': Data(x, psnr_x),
                            # the multipliers are rescaled to [0, 1] for display only
                            'u': Data((u - np.min(u)) / (np.max(u) - np.min(u))),
                            'u_r': Data((u_r - np.min(u_r)) / (np.max(u_r) - np.min(u_r))),
                            'r': Data(r)
                            }
                plot_dict(tmp_dict)
        else:
            print('\r', algorithm_name, 'iteration %04d/%04d Loss %f' % (i, admm_iter, total_loss.item()), end='')

    # wait for the last background denoising job before returning
    denoiser_thread.join()
    return avg,list_psnr,list_stopping

## Let's Go:

In [None]:
def run_and_plot(denoiser, name, plot_checkpoints={}):
    """ Train a fresh network on the noisy image with the given denoiser, store the
        result in the global data_dict under `name` and plot the whole dictionary.
        denoiser         - denoiser function handed to train_via_admm
        name             - algorithm name: shown while running and used as the data_dict key
        plot_checkpoints - iterations at which train_via_admm plots intermediate images
        Returns (denoised_img, list_psnr, list_stopping) exactly as train_via_admm returns them.
    """
    global data_dict
    noise_lev = NOISE_SIGMA/255   # estimated noise level brought back from the 0-255 scale
    tau = 1
    noisy_img = data_dict[CORRUPTED].img
    net, net_input = get_network_and_input(img_shape=noisy_img.shape)
    outcome = train_via_admm(net, net_input, denoiser, noisy_img, noise_lev, tau,
                             plot_array=plot_checkpoints, algorithm_name=name,
                             org_img=data_dict[ORIGINAL].img)
    denoised_img, list_psnr, list_stopping = outcome

    final_psnr = compare_PSNR(data_dict[ORIGINAL].img, denoised_img,
                              on_y=(not GRAY_SCALE), gray_scale=GRAY_SCALE)
    data_dict[name] = Data(denoised_img, final_psnr)
    plot_dict(data_dict)

    return denoised_img, list_psnr, list_stopping


# Iterations at which intermediate images are plotted.
# NOTE: run_and_plot leaves admm_iter at the train_via_admm default (5000), so the
# checkpoints above 5000 are never reached in this run.
plot_checkpoints = {1, 10, 50, 100, 250, 500, 2000, 3500, 4000, 4500, 5000,5500,6000,6500,7000,7500,8000} 
# Run the constrained scheme with Non-Local-Means as the denoiser; the result is stored under the DIP_NLM key
denoised_img,list_psnr,list_stopping=run_and_plot(non_local_means, DIP_NLM, plot_checkpoints)  # you may try it with different denoisers