# Setup


In [None]:
from google.colab import drive
drive.mount('/content/drive')
import sys
sys.path.insert(1, '/content/drive/My Drive/personal-deep-decoder')

# Decoder model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def add_module(self, module):
    """Append ``module`` to the container ``self`` under the next numeric name.

    The name is the 1-based position the module will occupy, i.e.
    ``str(len(self) + 1)``, so ``self`` must support ``len()``.  Inside the
    body ``self.add_module`` resolves to the method of the object itself,
    not to this module-level function.
    """
    next_name = str(len(self) + 1)
    self.add_module(next_name, module)


# Expose the helper on every nn.Module as ``.add(module)``.
setattr(torch.nn.Module, "add", add_module)


def conv(in_f, out_f, kernel_size, stride=1, pad='zero'):
    """Build a bias-free 2-D convolution wrapped in an ``nn.Sequential``.

    Args:
        in_f: number of input channels.
        out_f: number of output channels.
        kernel_size: side length of the square kernel.
        stride: convolution stride.
        pad: ``'reflection'`` pads with ``nn.ReflectionPad2d`` in front of an
            unpadded convolution; any other value uses the convolution's own
            (zero) padding.

    Returns:
        ``nn.Sequential`` holding the optional padding layer followed by the
        convolution.  The padding amount is ``int((kernel_size - 1) / 2)``.
    """
    amount = int((kernel_size - 1) / 2)
    stages = []

    if pad == 'reflection':
        # Padding is done by a separate layer, so the convolution adds none.
        stages.append(nn.ReflectionPad2d(amount))
        amount = 0

    stages.append(nn.Conv2d(in_f, out_f, kernel_size, stride, padding=amount, bias=False))
    return nn.Sequential(*stages)

def decodernw(
        num_output_channels=3,
        num_channels_up=[128]*5,
        filter_size_up=1,
        need_sigmoid=True,
        pad='reflection',
        upsample_mode='bilinear',
        act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True)
        bn_before_act = False,
        bn_affine = True,
        upsample_first = True,
        ):
    """Build the decoder network as an ``nn.Sequential``.

    The channel list is extended by two copies of its last entry; one block
    (convolution, optional x2 upsampling, activation, batch norm) is created
    for every consecutive pair of channel counts, followed by a 1x1 output
    convolution and an optional sigmoid.

    Args:
        num_output_channels: channels produced by the final 1x1 convolution.
        num_channels_up: sequence (list or tuple) of channel counts; it is
            copied, never modified.
        filter_size_up: kernel size for the block convolutions, either one
            value for all blocks or a list/tuple indexed per block.
        need_sigmoid: append ``nn.Sigmoid`` to the output when true.
        pad: padding mode forwarded to ``conv``.
        upsample_mode: mode for ``nn.Upsample``; ``'none'`` disables
            upsampling.
        act_fun: activation module; the same instance is inserted in every
            block.
        bn_before_act: place the batch norm before (True) or after (False)
            the activation.
        bn_affine: ``affine`` flag of the batch-norm layers.
        upsample_first: when true each block is convolution then upsampling
            (no upsampling after the last block); otherwise upsampling then
            convolution (no upsampling before the first block).

    Returns:
        The assembled ``nn.Sequential``.
    """
    # Work on a fresh list so tuples are accepted and neither the caller's
    # sequence nor the shared default is ever touched.
    num_channels_up = list(num_channels_up)
    num_channels_up = num_channels_up + [num_channels_up[-1], num_channels_up[-1]]
    n_scales = len(num_channels_up)

    if not isinstance(filter_size_up, (list, tuple)):
        filter_size_up = [filter_size_up] * n_scales

    model = nn.Sequential()
    n_blocks = n_scales - 1

    for i in range(n_blocks):
        block_conv = conv(num_channels_up[i], num_channels_up[i + 1], filter_size_up[i], 1, pad=pad)

        if upsample_first:
            model.add(block_conv)
            if upsample_mode != 'none' and i != n_blocks - 1:
                model.add(nn.Upsample(scale_factor=2, mode=upsample_mode))
        else:
            if upsample_mode != 'none' and i != 0:
                model.add(nn.Upsample(scale_factor=2, mode=upsample_mode))
            model.add(block_conv)

        # Every block gets activation + batch norm (the former guard
        # ``i != len(num_channels_up) - 1`` could never be false here).
        if bn_before_act:
            model.add(nn.BatchNorm2d(num_channels_up[i + 1], affine=bn_affine))
        model.add(act_fun)
        if not bn_before_act:
            model.add(nn.BatchNorm2d(num_channels_up[i + 1], affine=bn_affine))

    model.add(conv(num_channels_up[-1], num_output_channels, 1, pad=pad))
    if need_sigmoid:
        model.add(nn.Sigmoid())

    return model



# Residual block
class ResidualBlock(nn.Module):
    """Bias-free 1x1 convolution with an identity skip connection.

    The input is added in place to the convolution output, so the input must
    be broadcastable onto the output shape (in practice ``in_f == out_f``).
    """

    def __init__(self, in_f, out_f):
        super().__init__()
        self.conv = nn.Conv2d(in_f, out_f, 1, 1, padding=0, bias=False)

    def forward(self, x):
        """Return ``conv(x)`` with ``x`` added to it in place."""
        result = self.conv(x)
        result.add_(x)
        return result

def resdecoder(
        num_output_channels=3,
        num_channels_up=[128]*5,
        filter_size_up=1,
        need_sigmoid=True,
        pad='reflection',
        upsample_mode='bilinear',
        act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True)
        bn_before_act = False,
        bn_affine = True,
        ):
    """Build a decoder made of ``ResidualBlock`` stages as an ``nn.Sequential``.

    For every entry of ``num_channels_up`` the network contains a residual
    block, an optional x2 upsampling and the activation.  These are followed
    by one more residual block with activation, a 1x1 output convolution and
    an optional sigmoid.

    Args:
        num_output_channels: channels produced by the final 1x1 convolution.
        num_channels_up: sequence (list or tuple) of channel counts; it is
            copied, never modified.
        filter_size_up: accepted for signature parity with ``decodernw``;
            not used (residual blocks always use 1x1 convolutions).
        need_sigmoid: append ``nn.Sigmoid`` to the output when true.
        pad: padding mode forwarded to ``conv`` for the output convolution.
        upsample_mode: mode for ``nn.Upsample``; ``'none'`` disables
            upsampling.
        act_fun: activation module; the same instance is inserted everywhere.
        bn_before_act: not used (this network has no batch-norm layers).
        bn_affine: not used (this network has no batch-norm layers).

    Returns:
        The assembled ``nn.Sequential``.
    """
    # Work on a fresh list so tuples are accepted and neither the caller's
    # sequence nor the shared default is ever touched.
    num_channels_up = list(num_channels_up)
    num_channels_up = num_channels_up + [num_channels_up[-1], num_channels_up[-1]]

    model = nn.Sequential()

    for i in range(len(num_channels_up) - 2):
        model.add(ResidualBlock(num_channels_up[i], num_channels_up[i + 1]))

        if upsample_mode != 'none':
            model.add(nn.Upsample(scale_factor=2, mode=upsample_mode))

        # The activation follows every stage (the former guard
        # ``i != len(num_channels_up) - 1`` could never be false here).
        model.add(act_fun)

    # Extra residual stage at the final resolution.
    model.add(ResidualBlock(num_channels_up[-1], num_channels_up[-1]))
    model.add(act_fun)

    model.add(conv(num_channels_up[-1], num_output_channels, 1, pad=pad))

    if need_sigmoid:
        model.add(nn.Sigmoid())

    return model


## Study the decoder net

In [None]:
from torchsummary import summary

In [None]:
k=128
num_channels = [k]*6
output_depth=1
net = decodernw(output_depth ,num_channels_up=num_channels,upsample_first=True).type(torch.cuda.FloatTensor)

In [None]:
summary(net, input_size=(k, 16, 16))

## Study the resdecoder

In [None]:
net = resdecoder(output_depth, num_channels_up=num_channels).type(torch.cuda.FloatTensor)

In [None]:
summary(net, input_size=(64, 16, 16))

# Conv decoder model

In [None]:
import torch
import torch.nn as nn
import numpy as np
import copy

def add_module(self, module):
    """Append ``module`` to the container ``self`` under the next numeric name.

    The name is the 1-based position the module will occupy, i.e.
    ``str(len(self) + 1)``, so ``self`` must support ``len()``.  Inside the
    body ``self.add_module`` resolves to the method of the object itself,
    not to this module-level function.
    """
    next_name = str(len(self) + 1)
    self.add_module(next_name, module)


# Expose the helper on every nn.Module as ``.add(module)``.
setattr(torch.nn.Module, "add", add_module)

class conv_model(nn.Module):
    """Convolutional decoder built from repeated upsample/conv/act/BN stages.

    ``net1`` holds ``num_layers - 1`` stages, each consisting of
    ``nn.Upsample`` to ``hidden_size[i]``, a ``num_channels -> num_channels``
    convolution, the activation and a batch norm.  ``net2`` maps the result to
    ``out_depth`` channels with a 1x1 convolution, optionally preceded by one
    more conv/act/BN group (``need_last``).

    Args:
        num_layers: number of layers; ``net1`` gets ``num_layers - 1`` stages.
        strides: per-stage convolution strides (indexed by stage).
        num_channels: channel count of every hidden layer.
        out_depth: number of output channels.
        hidden_size: per-stage target spatial sizes for the upsampling.
        upsample_mode: mode passed to ``nn.Upsample``.
        act_fun: activation module; the same instance is reused in all stages.
        bn_affine: ``affine`` flag of the batch-norm layers.
        bias: whether the convolutions have a bias term.
        need_last: add an extra conv/act/BN group in front of the 1x1 output
            convolution.
        kernel_size: kernel size of the stage convolutions.
    """

    # Modes for which ``align_corners`` may be set; for the others (e.g.
    # 'nearest') torch.nn.functional.interpolate raises a ValueError.
    _ALIGNABLE_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, num_layers, strides, num_channels, out_depth, hidden_size, upsample_mode, act_fun, bn_affine=True, bias=False, need_last=False, kernel_size=3):
        super(conv_model, self).__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.upsample_mode = upsample_mode
        self.act_fun = act_fun
        # Numeric names (1-based positions in net1) of the batch-norm layers
        # of every stage except the last one.
        self.layer_inds = []
        # Set by forward(): copy of net1's output, i.e. the input of net2.
        self.combinations = None

        same_pad = (kernel_size - 1) // 2
        # Only request corner alignment when the mode supports it, so that
        # 'nearest' (the default of convdecoder) does not fail in forward().
        align_corners = True if upsample_mode in self._ALIGNABLE_MODES else None

        cntr = 1
        net1 = nn.Sequential()
        for i in range(num_layers - 1):
            self._append(net1, nn.Upsample(size=hidden_size[i], mode=upsample_mode, align_corners=align_corners))
            cntr += 1

            self._append(net1, nn.Conv2d(num_channels, num_channels, kernel_size, strides[i], padding=same_pad, bias=bias))
            cntr += 1

            self._append(net1, act_fun)
            cntr += 1

            self._append(net1, nn.BatchNorm2d(num_channels, affine=bn_affine))
            if i != num_layers - 2:  # the last stage is not recorded
                self.layer_inds.append(cntr)
            cntr += 1

        net2 = nn.Sequential()

        nic = num_channels

        if need_last:  # original code default False, but convdecoder passes True
            # Same stride as the last stage of net1 (previously read from the
            # loop variable left over after the loop).
            last_stride = strides[num_layers - 2]
            self._append(net2, nn.Conv2d(nic, num_channels, kernel_size, last_stride, padding=same_pad, bias=bias))
            self._append(net2, act_fun)
            self._append(net2, nn.BatchNorm2d(num_channels, affine=bn_affine))
            nic = num_channels

        self._append(net2, nn.Conv2d(nic, out_depth, 1, 1, padding=0, bias=bias))

        self.net1 = net1  # actual convdecoder network
        self.net2 = net2  # (default setting) one-layer net converting number of channels

    @staticmethod
    def _append(seq, layer):
        """Append ``layer`` to ``seq`` under the next 1-based numeric name."""
        seq.add_module(str(len(seq) + 1), layer)

    def forward(self, x, scale_out=1):
        """Run ``x`` through ``net1`` then ``net2`` and scale the result.

        A shallow copy of ``net1``'s output is kept in ``self.combinations``.
        Returns ``net2(net1(x)) * scale_out``.
        """
        out1 = self.net1(x)
        self.combinations = copy.copy(out1)
        out2 = self.net2(out1)
        return out2 * scale_out

    def up_sample(self, img):
        """Upsample ``img`` to the final hidden size using ``upsample_mode``."""
        samp_block = nn.Upsample(size=self.hidden_size[-1], mode=self.upsample_mode)
        return samp_block(img)

def convdecoder(
        in_size, #default [16,16]
        out_size,#default [256,256]
        out_depth, #default 3
        num_layers, #default 6
        strides, #default [1]*6,
        num_channels, #default 64

        kernel_size=3,
        upsample_mode='nearest', #default 'bilinear', 
        act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 
        bn_affine = True,
        nonlin_scales=False,
        bias=False,
        need_last=True, #False,
        ):
    """Compute the hidden-layer sizes and build the matching ``conv_model``.

    The spatial size grows from ``in_size`` to ``out_size`` over
    ``num_layers - 1`` steps.  By default each step multiplies the size by a
    constant per-axis factor (rounded up); the last entry is ``out_size``
    itself.  For example, with 8 layers, input [8,4] and output [640,368]
    give
        [(15, 8), (28, 15), (53, 28), (98, 53), (183, 102), (343, 193), (640, 368)]
    With ``nonlin_scales`` the sizes are instead spaced linearly (rounded up)
    between the first scaled size and ``out_size``.

    All remaining arguments are forwarded unchanged to ``conv_model``.
    Arguments of the original implementation that are not needed here
    (skips, intermeds, pad, ...) were dropped; see decoder_conv_old.py for
    the original code.
    """
    n_steps = num_layers - 1

    # Per-step growth factor along each axis, e.g. (1.87, 1.91).
    scale_x = (out_size[0]/in_size[0])**(1./n_steps)
    scale_y = (out_size[1]/in_size[1])**(1./n_steps)

    if nonlin_scales:
        xs = np.ceil(np.linspace(scale_x * in_size[0], out_size[0], n_steps))
        ys = np.ceil(np.linspace(scale_y * in_size[1], out_size[1], n_steps))
        hidden_size = [(int(x), int(y)) for x, y in zip(xs, ys)]
    else:
        hidden_size = []
        for n in range(1, n_steps):
            size_x = int(np.ceil(scale_x**n * in_size[0]))
            size_y = int(np.ceil(scale_y**n * in_size[1]))
            hidden_size.append((size_x, size_y))
        hidden_size.append(out_size)

    return conv_model(num_layers, strides, num_channels, out_depth, hidden_size,
                      upsample_mode=upsample_mode,
                      act_fun=act_fun,
                      bn_affine=bn_affine,
                      bias=bias,
                      need_last=need_last,
                      kernel_size=kernel_size)

## Visualize convdecoder

In [None]:
model = convdecoder(
        [16,16], #default [16,16]
        [512,512],#default [256,256]
        1, #default 3
        6, #default 6
        [1]*6, #default [1]*6,
        64, #default 64

        kernel_size=3,
        upsample_mode='nearest', #default 'bilinear', 
        act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 
        bn_affine = True,
        nonlin_scales=False,
        bias=False,
        need_last=True, #False,
        ).type(torch.cuda.FloatTensor)

In [None]:
summary(model, input_size=(64, 16, 16))

# Define loss functions

## Reconstruction loss

In [None]:
def loss_reconstruction(img, true_img):
  """Return the mean squared error between ``img`` and ``true_img``."""
  return torch.nn.MSELoss()(img, true_img)

## Autocorrelation loss

In [None]:
def autocorrelation(img):
  """Return the circular 2-D autocorrelation of the standardised input.

  The input is standardised with the mean and (unbiased) standard deviation
  taken over *all* of its elements.  The autocorrelation is then obtained
  from the inverse FFT of the power spectrum over the last two dimensions
  and divided by ``H*W - 1``.

  Args:
    img: real tensor of shape (B, C, H, W).

  Returns:
    Tensor of shape (B, C, H, W) on the same device as ``img``; entry
    ``[..., 0, 0]`` is the zero-lag value.
  """
  _, _, H, W = img.shape
  standardized = (img - torch.mean(img)) / torch.std(img)

  if hasattr(torch, "fft") and hasattr(torch.fft, "fft2"):
    # torch.rfft / torch.irfft were removed in PyTorch 1.8; the torch.fft
    # module is their replacement and transforms the last two dims.
    spectrum = torch.fft.fft2(standardized)
    power = spectrum.real ** 2 + spectrum.imag ** 2
    fast_auto_corr = torch.fft.ifft2(power).real
  else:
    # Legacy API of older PyTorch releases: the trailing dimension of size 2
    # holds the real and imaginary parts.
    fft = torch.rfft(standardized, signal_ndim=2, onesided=False)
    fft_square = torch.zeros_like(fft)  # same device/dtype as the input
    fft_square[..., 0] = fft[..., 0] ** 2 + fft[..., 1] ** 2
    fast_auto_corr = torch.irfft(fft_square, signal_ndim=2, onesided=False)

  return fast_auto_corr / (H * W - 1)

def loss_ac(img):
  """Mean squared distance between the autocorrelation of ``img`` and a delta.

  The target is 1 at zero lag and 0 at every other lag, for every image in
  the batch and every channel (previously only batch 0 / channel 0 received
  the 1, which is identical for a single-channel, single-image input).

  Args:
    img: tensor of shape (B, C, H, W).

  Returns:
    Scalar tensor: mean over all elements of the squared difference.
  """
  autocorr_img = autocorrelation(img)
  # zeros_like keeps the target on the same device/dtype as the
  # autocorrelation instead of assuming CUDA.
  target = torch.zeros_like(autocorr_img)
  target[:, :, 0, 0] = 1.
  return (autocorr_img - target).pow(2).mean()

## Partial autocorrelation loss

In [None]:
def autocorrelation(img):
  """Return the circular 2-D autocorrelation of the standardised input.

  The input is standardised with the mean and (unbiased) standard deviation
  taken over *all* of its elements.  The autocorrelation is then obtained
  from the inverse FFT of the power spectrum over the last two dimensions
  and divided by ``H*W - 1``.

  Args:
    img: real tensor of shape (B, C, H, W).

  Returns:
    Tensor of shape (B, C, H, W) on the same device as ``img``; entry
    ``[..., 0, 0]`` is the zero-lag value.
  """
  _, _, H, W = img.shape
  standardized = (img - torch.mean(img)) / torch.std(img)

  if hasattr(torch, "fft") and hasattr(torch.fft, "fft2"):
    # torch.rfft / torch.irfft were removed in PyTorch 1.8; the torch.fft
    # module is their replacement and transforms the last two dims.
    spectrum = torch.fft.fft2(standardized)
    power = spectrum.real ** 2 + spectrum.imag ** 2
    fast_auto_corr = torch.fft.ifft2(power).real
  else:
    # Legacy API of older PyTorch releases: the trailing dimension of size 2
    # holds the real and imaginary parts.
    fft = torch.rfft(standardized, signal_ndim=2, onesided=False)
    fft_square = torch.zeros_like(fft)  # same device/dtype as the input
    fft_square[..., 0] = fft[..., 0] ** 2 + fft[..., 1] ** 2
    fast_auto_corr = torch.irfft(fft_square, signal_ndim=2, onesided=False)

  return fast_auto_corr / (H * W - 1)

def loss_partial_ac(img, max_lag=10):
  """Autocorrelation loss restricted to the first ``max_lag`` lags per axis.

  Same as the full autocorrelation loss (squared distance to a delta that is
  1 at zero lag for every image and channel), but the mean is taken only
  over the ``[:max_lag, :max_lag]`` corner of the lag grid.

  Args:
    img: tensor of shape (B, C, H, W).
    max_lag: number of lags kept along each spatial axis (default 10, the
      previously hard-coded window).

  Returns:
    Scalar tensor: mean squared difference over the selected window.
  """
  autocorr_img = autocorrelation(img)
  # zeros_like keeps the target on the same device/dtype as the
  # autocorrelation instead of assuming CUDA.
  target = torch.zeros_like(autocorr_img)
  target[:, :, 0, 0] = 1.
  window = (autocorr_img - target)[:, :, :max_lag, :max_lag]
  return window.pow(2).mean()

## Total variation loss

In [None]:
def loss_tv(img):
  """Return the anisotropic total variation of ``img``.

  Sums the absolute differences between neighbouring entries along the
  height and width axes over the whole (B, C, H, W) tensor and divides the
  total by ``(H - 1) * (W - 1)``.
  """
  _, _, height, width = img.shape
  vertical = (img[:, :, 1:, :] - img[:, :, :-1, :]).abs().sum()
  horizontal = (img[:, :, :, 1:] - img[:, :, :, :-1]).abs().sum()
  normalizer = (height - 1) * (width - 1)
  return (vertical + horizontal) / normalizer

# Define fit function

In [None]:
from torch.autograd import Variable
import torch
import torch.optim
import copy
import numpy as np
from scipy.linalg import hadamard
import matplotlib.pyplot as plt

from helpers import *

dtype = torch.cuda.FloatTensor
#dtype = torch.FloatTensor
           

def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=500, decay_factor=0.65):
    """Set a step-decayed learning rate on every parameter group.

    The rate is ``init_lr * decay_factor ** (epoch // lr_decay_epoch)``, i.e.
    it is multiplied by ``decay_factor`` (0.65 by default) every
    ``lr_decay_epoch`` epochs.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts.
        epoch: current epoch / iteration index.
        init_lr: learning rate at epoch 0.
        lr_decay_epoch: number of epochs between two decays; must be non-zero.
        decay_factor: multiplicative factor applied at each decay.

    Returns:
        The same ``optimizer`` object, with ``'lr'`` updated in every group.
    """
    lr = init_lr * (decay_factor ** (epoch // lr_decay_epoch))

    # Announce the rate whenever a new decay step starts.
    if epoch % lr_decay_epoch == 0:
        print('LR is set to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer


def _build_optimizer(name, params, lr, weight_decay):
    """Create the optimizer selected by ``name`` ('SGD', 'adam' or 'LBFGS')."""
    if name == 'SGD':
        print("optimize with SGD", lr)
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    if name == 'adam':
        print("optimize with adam", lr)
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if name == 'LBFGS':
        print("optimize with LBFGS", lr)
        return torch.optim.LBFGS(params, lr=lr)
    # Fail here with a clear message instead of hitting an unbound
    # ``optimizer`` later in the training loop.
    raise ValueError("Unknown OPTIMIZER {!r}; expected 'SGD', 'adam' or 'LBFGS'".format(name))


def fit(net,
        img_noisy_var,
        num_channels,
        img_clean_var,
        loss_type='l2',
        num_iter = 5000,
        LR = 0.01,
        OPTIMIZER='adam',
        opt_input = False,
        reg_noise_std = 0,
        reg_noise_decayevery = 100000,
        mask_var = None,
        apply_f = None,
        lr_decay_epoch = 0,
        net_input = None,
        net_input_gen = "random",
        find_best=False,
        weight_decay=0,
        lambda_ac=0,
        lambda_tv=0,
        debug=False
       ):
    """Fit ``net`` so that its output on a fixed input matches ``img_noisy_var``.

    The training loss is
    ``l_loss + lambda_ac * loss_partial_ac(img_noisy_var - out) + lambda_tv * loss_tv(out)``
    where ``l_loss`` is the MSE for ``loss_type='l2'`` and the summed absolute
    error otherwise.  Every 100 iterations the loss is printed and the current
    output is displayed with matplotlib.

    Args:
        net: network to optimise.
        img_noisy_var: target tensor the output is fitted to.
        num_channels: channel list; its length sets the assumed total
            upsampling factor ``2**len(num_channels)`` and its first entry
            the channel count of a generated input.
        img_clean_var: clean reference, used for monitoring only.
        loss_type: ``'l2'`` or ``'l1'``; any other value prints an error and
            falls through to the L1 branch.
        num_iter: number of optimisation steps.
        LR: learning rate.
        OPTIMIZER: ``'SGD'``, ``'adam'`` or ``'LBFGS'``.
        opt_input: also optimise the network input.
        reg_noise_std: standard deviation of Gaussian noise added to the
            saved input at every step (0 disables it); multiplied by 0.7
            every ``reg_noise_decayevery`` iterations.
        reg_noise_decayevery: see ``reg_noise_std``.
        mask_var, apply_f, net_input_gen: accepted but currently unused.
        lr_decay_epoch: when non-zero, ``exp_lr_scheduler`` is applied with
            this period.
        net_input: optional network input; when ``None`` a uniform random
            input scaled by 1/10 is generated.
        find_best: keep a deep copy of the network with the best training
            loss and return it instead of the final one.
        weight_decay: weight decay for SGD / Adam.
        lambda_ac: weight of the partial autocorrelation term.
        lambda_tv: weight of the total-variation term.
        debug: also print the individual loss terms every 100 iterations.

    Returns:
        ``(mse_wrt_noisy, mse_wrt_truth, net_input_saved, net)``: the
        per-iteration training loss (including regularisers), the
        per-iteration MSE against ``img_clean_var``, the saved network input
        and the (best or final) network.

    Raises:
        ValueError: if ``OPTIMIZER`` is not one of the supported names.
    """
    if loss_type not in ['l1', 'l2']:
        print("Error: loss_type is {} but should be either l1 or l2".format(loss_type))

    if net_input is not None:
        print("input provided")
    else:
        # feed uniform noise into the network
        totalupsample = 2**len(num_channels)
        in_dim2 = int(img_clean_var.data.shape[2]/totalupsample)
        in_dim3 = int(img_clean_var.data.shape[3]/totalupsample)
        shape = [1, num_channels[0], in_dim2, in_dim3]
        print("shape: ", shape)
        net_input = Variable(torch.zeros(shape))
        net_input.data.uniform_()
        net_input.data *= 1./10

    print(net_input.shape)

    net_input_saved = net_input.data.clone()
    noise = net_input.data.clone()
    p = list(net.parameters())

    if opt_input:  # optimize over the input as well
        net_input.requires_grad = True
        p += [net_input]

    mse_wrt_noisy = np.zeros(num_iter)
    mse_wrt_truth = np.zeros(num_iter)

    optimizer = _build_optimizer(OPTIMIZER, p, LR, weight_decay)

    mse = torch.nn.MSELoss()
    noise_energy = mse(img_noisy_var, img_clean_var)

    if find_best:
        best_net = copy.deepcopy(net)
        best_loss = 1000000.
        best_iter = 0

    for i in range(num_iter):

        # Value comparison: ``is not 0`` tested object identity with a literal.
        if lr_decay_epoch != 0:
            optimizer = exp_lr_scheduler(optimizer, i, init_lr=LR, lr_decay_epoch=lr_decay_epoch)
        if reg_noise_std > 0:
            if i % reg_noise_decayevery == 0:
                reg_noise_std *= 0.7
            net_input = Variable(net_input_saved + (noise.normal_() * reg_noise_std))

        def closure():
            optimizer.zero_grad()
            out = net(net_input.type(dtype))

            # training loss (mask_var / apply_f are not applied here)
            if loss_type == 'l2':
                l_loss = loss_reconstruction(out, img_noisy_var)
            else:
                l_loss = torch.abs(out-img_noisy_var).sum()
            ac_loss = loss_partial_ac(img_noisy_var-out)
            tv_loss = loss_tv(out)
            loss = l_loss + lambda_ac*ac_loss + lambda_tv*tv_loss

            if i%100==0:
                print("{} iterations: loss={}".format(i, loss))
                if debug:
                    print("    l_loss: {}".format(l_loss))
                    print("    AC loss: {}".format(ac_loss))
                    print("    TV loss: {}".format(tv_loss))
                if out.shape[1] != 1:
                    plt.imshow(np.clip(out.data.cpu().numpy()[0].transpose(1, 2, 0), 0, 1))
                else:
                    plt.imshow(np.clip(out.data.cpu().numpy()[0][0, :, :], 0, 1), cmap='gray')
                plt.show()

            loss.backward()
            mse_wrt_noisy[i] = loss.data.cpu().numpy()

            # the actual loss
            true_loss = mse(Variable(out.data, requires_grad=False), img_clean_var)
            mse_wrt_truth[i] = true_loss.data.cpu().numpy()
            if i % 10 == 0:
                # Output on the noise-free saved input, for monitoring only.
                out2 = net(Variable(net_input_saved).type(dtype))
                loss2 = mse(out2, img_clean_var)
                print ('Iteration %05d    Train loss %f  Actual loss %f Actual loss orig %f  Noise Energy %f' % (i, loss.data,true_loss.data,loss2.data,noise_energy.data), '\r', end='')
            return loss

        loss = optimizer.step(closure)

        if find_best:
            # a new best net must improve the training loss by at least 0.5 percent
            if best_loss > 1.005*loss.data:
                best_loss = loss.data
                best_net = copy.deepcopy(net)
                best_iter = i

    if find_best:
        net = best_net
        print("Best loss at iteration {}".format(best_iter))
    return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net

        ### weight regularization
        #if orth_reg > 0:
        #    for name, param in net.named_parameters():
                # consider all the conv weights, but the last one which only combines colors
        #        if '.1.weight' in name and str( len(net)-1 ) not in name:
        #            param_flat = param.view(param.shape[0], -1)
        #            sym = torch.mm(param_flat, torch.t(param_flat))
        #            sym -= Variable(torch.eye(param_flat.shape[0])).type(dtype)
        #            loss = loss + (orth_reg * sym.sum().type(dtype) )
        ###

# Experiment

In [None]:
# Location and name of the image to denoise.
path = '/content/drive/My Drive/Full_Images/'
img_name = "img_54"

img_path = path + img_name + ".png"
img_pil = Image.open(img_path)
# Convert to a numpy array and keep only its first entry along axis 0,
# stored as a (1, dim1, dim2) array.
# NOTE(review): assumes pil_to_np returns a channel-first array - confirm in helpers.
img_np = pil_to_np(img_pil)
img_np = img_np[0].reshape((1, img_np.shape[1], img_np.shape[2]))
# Drop the first two and the last two entries along axis 1.
# NOTE(review): presumably so that this dimension is divisible by the
# decoder's total upsampling factor - confirm.
img_np = img_np[:, 2:-2, :]
# Clean reference image converted with np_to_var and cast to the global dtype.
img_clean_var = np_to_var(img_np).type(dtype)

def get_noisy_img(sig=15, noise_same = False):
    """Return a noisy copy of the global ``img_np`` as array and as tensor.

    Zero-mean Gaussian noise with standard deviation ``sig / 255`` is added
    and the result is clipped to [0, 1] and cast to float32.

    Args:
        sig: noise standard deviation on a 0-255 scale.
        noise_same: draw one noise plane and repeat it along axis 0 instead
            of drawing independent noise for every entry.

    Returns:
        ``(img_noisy_np, img_noisy_var)``: the noisy numpy array and its
        ``np_to_var`` conversion cast to the global ``dtype``.
    """
    sigma = sig/255.
    draw_shape = img_np.shape[1:] if noise_same else img_np.shape
    noise = np.random.normal(scale=sigma, size=draw_shape)
    if noise_same:
        # replicate the single plane along axis 0
        noise = np.array([noise] * img_np.shape[0])

    noisy = np.clip(img_np + noise, 0, 1).astype(np.float32)
    return noisy, np_to_var(noisy).type(dtype)


img_noisy_np, img_noisy_var = get_noisy_img()  
output_depth = img_np.shape[0] 
print("Image size: ", img_np.shape)

def denoise(img_noisy_var, k=128, numit = 1900, rn = 0.0, find_best=True, upsample_first = True, loss_type='l2', lambda_ac=0, lambda_tv=0, debug=False):
    """Fit a ``decodernw`` network to ``img_noisy_var`` and return its output.

    Args:
        img_noisy_var: noisy target tensor.
        k: channel count of each of the five decoder layers.
        numit: number of optimisation iterations.
        rn: ``reg_noise_std`` passed to ``fit``.
        find_best: passed to ``fit``.
        upsample_first: passed to ``decodernw``.
        loss_type, lambda_ac, lambda_tv, debug: passed to ``fit``.

    Returns:
        ``(out_img_np, mse_t)``: the first batch entry of the network output
        on the saved input as a numpy array, and the per-iteration MSE with
        respect to the global ``img_clean_var``.
    """
    # NOTE(review): the output channel count is read from dimension 3 of the
    # noisy tensor; for a (B, C, H, W) tensor that is the width, not the
    # channel count - confirm against the layout produced by np_to_var.
    output_depth = img_noisy_var.shape[3]
    channels = [k] * 5

    # Alternative architectures tried here: resdecoder(...) and convdecoder(...).
    net = decodernw(output_depth, num_channels_up=channels, upsample_first=upsample_first).type(dtype)

    fit_kwargs = dict(
        net=net,
        num_channels=channels,
        reg_noise_std=rn,
        num_iter=numit,
        img_noisy_var=img_noisy_var,
        img_clean_var=img_clean_var,
        find_best=find_best,
        loss_type=loss_type,
        lambda_ac=lambda_ac,
        lambda_tv=lambda_tv,
        debug=debug,
    )
    _, mse_t, ni, net = fit(**fit_kwargs)

    prediction = net(ni.type(dtype))
    out_img_np = prediction.data.cpu().numpy()[0]
    return out_img_np, mse_t

def myimgshow(plt,img):
  """Show a channel-first image on ``plt`` (any object with ``imshow``).

  Values are clipped to [0, 1].  An image whose first axis has length 1 is
  drawn as a 2-D grayscale image; otherwise the first axis is moved to the
  end before drawing.
  """
  if img.shape[0] == 1:
    plane = np.clip(img[0, :, :], 0, 1)
    plt.imshow(plane, cmap='gray')
    return
  channels_last = img.transpose(1, 2, 0)
  plt.imshow(np.clip(channels_last, 0, 1))


def plot_results(out_img_np, img_np, img_noisy_np):
    """Show the original, noisy and denoised images side by side.

    The noisy and denoised panels are titled with the value returned by
    ``psnr`` with respect to ``img_np``.

    Args:
        out_img_np: denoised image (channel-first array).
        img_np: original image (channel-first array).
        img_noisy_np: noisy observation (channel-first array).
    """
    fig = plt.figure(figsize = (15,15)) # 15 x 15 inch figure, one row of three panels

    panels = (
        (131, img_np, 'Original image'),
        (132, img_noisy_np, "Noisy observation, PSNR: %.2f" % psnr(img_np, img_noisy_np)),
        # The value comes from psnr(), so label it PSNR (was "SNR").
        (133, out_img_np, "Denoised image, PSNR: %.2f" % psnr(img_np, out_img_np)),
    )
    for position, image, title in panels:
        ax = fig.add_subplot(position)
        myimgshow(ax, image)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

## Experiment with Gaussian

In [None]:
import cv2
from skimage import io, img_as_float
from skimage.filters import gaussian

In [None]:
# Conv
gaussian_kernel = np.array([[1/16, 1/8, 1/16], [1/8, 1/4, 1/8], [1/16, 1/8, 1/16]])
out_conv_img_np = cv2.filter2D(img_noisy_np, -1, gaussian_kernel, borderType=cv2.BORDER_CONSTANT)
plot_results(out_conv_img_np, img_np, img_noisy_np)

# Gaussian
for sigma in [0.01, 0.05, 0.1, 0.3, 0.5, 1, 2]:
  print("sigma =", sigma)
  out_gaussian_img_np = gaussian(img_noisy_np, sigma=sigma, mode='constant', cval=0.0)
  plot_results(out_gaussian_img_np, img_np, img_noisy_np)

## Experiment with Bilateral

In [None]:
from skimage.restoration import denoise_bilateral
out_img_np = denoise_bilateral(np.squeeze(img_noisy_np), sigma_spatial=5, multichannel=False)
plot_results(out_img_np.reshape((1, out_img_np.shape[0], out_img_np.shape[1])), img_np, img_noisy_np)

In [None]:
from skimage.restoration import denoise_bilateral
out_img_np = denoise_bilateral(np.squeeze(img_noisy_np), sigma_spatial=10, multichannel=False)
plot_results(out_img_np.reshape((1, out_img_np.shape[0], out_img_np.shape[1])), img_np, img_noisy_np)

In [None]:
from skimage.restoration import denoise_bilateral
out_img_np = denoise_bilateral(np.squeeze(img_noisy_np), sigma_spatial=15, multichannel=False)
# Restore the leading singleton axis from the result's own shape (as the
# sibling bilateral cells do) instead of assuming a 512 x 512 image.
plot_results(out_img_np.reshape((1, out_img_np.shape[0], out_img_np.shape[1])), img_np, img_noisy_np)

## Experiments with Non-local Means

In [None]:
from skimage.restoration import denoise_nl_means, estimate_sigma

In [None]:
sigma_est = np.mean(estimate_sigma(img_noisy_np, multichannel=True))
out_img_np = denoise_nl_means(img_noisy_np, h=1.15 * sigma_est, fast_mode=True,
                               patch_size=5, patch_distance=3, multichannel=False)
if len(out_img_np.shape)==3:
  plot_results(out_img_np, img_np, img_noisy_np)
else:
  plot_results(out_img_np.reshape((1, out_img_np.shape[0], out_img_np.shape[1])), img_np, img_noisy_np)

In [None]:
out_img_np = denoise_nl_means(img_noisy_np, h=0.6*15./255., fast_mode=True,
                               patch_size=5, patch_distance=6, multichannel=False)
if len(out_img_np.shape)==3:
  plot_results(out_img_np, img_np, img_noisy_np)
else:
  plot_results(out_img_np.reshape((1, out_img_np.shape[0], out_img_np.shape[1])), img_np, img_noisy_np)

## Experiments with BM3D

In [None]:
! pip install bm3d

In [None]:
import bm3d

In [None]:
for sigma_psd in [0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5]:
  BM3D_denoised_image = bm3d.bm3d(np.squeeze(np.transpose(img_noisy_np, (1, 2, 0))), sigma_psd=sigma_psd, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)
  print("sigma_psd = ", sigma_psd)
  if len(BM3D_denoised_image.shape)==3:
    plot_results(np.transpose(BM3D_denoised_image, (2, 0, 1)), img_np, img_noisy_np)
  else:
    plot_results(BM3D_denoised_image.reshape((1, BM3D_denoised_image.shape[0], BM3D_denoised_image.shape[1])), img_np, img_noisy_np)

In [None]:
BM3D_denoised_image = bm3d.bm3d(np.squeeze(np.transpose(img_noisy_np, (1, 2, 0))), sigma_psd=0.05, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)

In [None]:
BM3D_denoised_image

In [None]:
n, bins, patches = plt.hist(BM3D_denoised_image.reshape(-1), bins=100)
plt.show()

In [None]:
plt.imshow(BM3D_denoised_image>0.18)

In [None]:
for threshold in np.linspace(0.0,0.2, 20):
  modified_BM3D_denoised_image = BM3D_denoised_image * (BM3D_denoised_image>threshold)
  plot_results(modified_BM3D_denoised_image.reshape((1, modified_BM3D_denoised_image.shape[0], modified_BM3D_denoised_image.shape[1])), img_np, img_noisy_np)

## Experiments with pure TV denoising

In [None]:
import numpy as np
import scipy
import scipy.misc
import matplotlib.pyplot as plt
try:
    from skimage.restoration import denoise_tv_chambolle
except ImportError:
    # skimage < 0.12
    from skimage.filters import denoise_tv_chambolle

In [None]:
out_img_np = denoise_tv_chambolle(img_noisy_np, weight=0.02)
plot_results(out_img_np, img_np, img_noisy_np)

In [None]:
out_img_np = denoise_tv_chambolle(img_noisy_np, weight=0.1)
plot_results(out_img_np, img_np, img_noisy_np)

In [None]:
out_img_np = denoise_tv_chambolle(img_noisy_np, weight=0.15)
plot_results(out_img_np, img_np, img_noisy_np)

## Experiments with L2 Norm

In [None]:
out_img_np, mse_t = denoise(img_noisy_var, k=64, numit = 5000, rn = 0.0, loss_type='l2', lambda_ac=0, lambda_tv=0, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

## Experiments with L1 Norm


In [None]:
out_img_np, mse_t = denoise(img_noisy_var, k=128, numit = 5000, rn = 0.0, loss_type='l1', lambda_ac=0, lambda_tv=0, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

## Experiments with L1 Norm + TV

In [None]:
13298.24609375/0.03646152466535568

In [None]:
out_img_np, mse_t = denoise(img_noisy_var, k=64, numit = 5000, rn = 0.0, loss_type='l1', lambda_ac=0, lambda_tv=1e5, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

In [None]:
out_img_np, mse_t = denoise(img_noisy_var, k=64, numit = 5000, rn = 0.0, loss_type='l1', lambda_ac=0, lambda_tv=5e4, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

## Experiments with L1 Norm + TV + AC

In [None]:
14160.3076171875/6.154358743515331e-06

In [None]:
out_img_np, mse_t = denoise(img_noisy_var, k=64, numit = 5000, rn = 0.0, loss_type='l1', lambda_ac=1e9, lambda_tv=1e5, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

In [None]:
out_img_np, mse_t = denoise(img_noisy_var, k=64, numit = 5000, rn = 0.0, loss_type='l1', lambda_ac=5e8, lambda_tv=1e5, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

In [None]:
out_img_np, mse_t = denoise(img_noisy_var, k=64, numit = 5000, rn = 0.0, loss_type='l1', lambda_ac=1e8, lambda_tv=1e5, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

## Experiments with MSE + TV + Partial AC

In [None]:
img_noisy_np,img_noisy_var = get_noisy_img(sig=30, noise_same=False)
out_img_np, mse_t = denoise(img_noisy_var, k=128, numit = 1900, rn = 0.0, lambda_ac=0, lambda_tv=1e-7, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

In [None]:
ac = autocorrelation(torch.tensor((out_img_np-img_np).reshape((1,1,512,512))))

In [None]:
plt.imshow(ac.cpu().numpy().reshape(512,512)[:10,:10], cmap='gray')

In [None]:
loss_partial_ac(torch.tensor((out_img_np-img_np).reshape((1,1,512,512))))

In [None]:
img_noisy_np,img_noisy_var = get_noisy_img(sig=30, noise_same=False)
out_img_np, mse_t = denoise(img_noisy_var, k=128, numit = 1900, rn = 0.0, lambda_ac=5e2, lambda_tv=1e-7, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

In [None]:
img_noisy_np,img_noisy_var = get_noisy_img(sig=30, noise_same=False)
out_img_np, mse_t = denoise(img_noisy_var, k=128, numit = 1900, rn = 0.0, lambda_ac=1e3, lambda_tv=1e-7, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

In [None]:
img_noisy_np,img_noisy_var = get_noisy_img(sig=30, noise_same=False)
out_img_np, mse_t = denoise(img_noisy_var, k=128, numit = 1900, rn = 0.0, lambda_ac=2e3, lambda_tv=1e-7, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

In [None]:
img_noisy_np,img_noisy_var = get_noisy_img(sig=30, noise_same=False)
out_img_np, mse_t = denoise(img_noisy_var, k=128, numit = 1900, rn = 0.0, lambda_ac=5e3, lambda_tv=1e-7, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)

## Experiments with 3D deep decoder

In [None]:
# Stack the first channel of 192 PNG files into a (292, 288, 192) volume.
img = np.zeros((292, 288, 192))
for i in range(192):
  img_i = plt.imread("drive/My Drive/Full_Images/img_"+str(i)+".png")
  img[:,:,i] = img_i[:,:,0]
# Add a leading singleton axis and keep indices 53..56 of the last axis.
img_np = img.reshape((1, 292, 288, 192))
img_np = img_np[:,:,:,53:57]
img_clean_var = np_to_var(img_np).type(dtype)
# Additive Gaussian noise with standard deviation 15/255, clipped to [0, 1].
sigma = 15./255.
noise = np.random.normal(scale=sigma, size=img_np.shape)
img_noisy_np = np.clip(img_np+noise, 0, 1).astype(np.float32)
img_noisy_var = np_to_var(img_noisy_np).type(dtype)
# Size of the last axis (4 after the slice above); rebinds the notebook-level output_depth.
output_depth = img_np.shape[3] 

In [None]:
out_img_np, mse_t = denoise(img_noisy_var, k=64, numit = 5000, rn = 0.0, loss_type='l2', lambda_ac=0, lambda_tv=0, debug=True)
plot_results(out_img_np, img_np, img_noisy_np)