# **Recursive Deep Prior Video: a Super Resolution algorithm** 
# **for Time-Lapse Microscopy of organ-on-chip experiments**


##### If you use this code, please consider citing our paper, available at the following link:

https://arxiv.org/abs/2011.09855


## Step by step...

1. Download the GitHub repository.
2. Add the repository to your Google Drive.
3. Mount your Google Drive and set the working directory to your repository path.

In [1]:
# Mount Google Drive (force_remount=True re-mounts even if already mounted) and
# change into the repository's code folder, so that the relative paths used in
# the following cells (models/, utils/, data/example/, reconstruction/,
# updating_model/) resolve.
# NOTE: the path below is the authors' Drive layout; change it to wherever you
# placed the repository.
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
%cd /content/drive/My Drive/Deep-Video-TV/code_github

# pytorch-msssim provides the `ssim` function imported in the next cell.
!pip install pytorch-msssim

Mounted at /content/drive/
/content/drive/My Drive/Deep-Video-TV/code_github
Collecting pytorch-msssim
  Downloading https://files.pythonhosted.org/packages/9d/d3/3cb0f397232cf79e1762323c3a8862e39ad53eca0bb5f6be9ccc8e7c070e/pytorch_msssim-0.2.1-py3-none-any.whl
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-0.2.1


4. Import the libraries.

In [32]:
# import libs
from __future__ import print_function
import matplotlib.pyplot as plt
import matplotlib
import time
import argparse
import os
import imageio
from PIL import Image
import numpy as np
from models import *
import torch
import torch.optim
from skimage.measure import compare_psnr, compare_ssim, compare_mse
from models.downsampler import Downsampler
from pytorch_msssim import ssim
from utils.sr_utils import *
# import EarlyStopping
from pytorchtools import EarlyStopping

# Use cuDNN and let it benchmark/autotune convolution algorithms.
torch.backends.cudnn.enabled, torch.backends.cudnn.benchmark = True, True

# CUDA float32 tensor type; modules and tensors below are moved to the GPU
# via `.type(dtype)`.
dtype = torch.cuda.FloatTensor

# To pin a specific GPU, set os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# and os.environ["CUDA_VISIBLE_DEVICES"] (e.g. "0") before CUDA is initialised.

5. Set the parameters.

In [37]:
# Plotting switch (not referenced by the cells shown here).
PLOT = False

# --- Network and network-input configuration --------------------------------
# `input_depth` is the channel count of the generated network input;
# INPUT selects the input kind passed to get_noise, pad and NET_TYPE go to
# get_net, OPT_OVER to get_params.
input_depth = 3
INPUT, pad, OPT_OVER, NET_TYPE = 'noise', 'reflection', 'net', 'skip'

# --- Super-resolution model ---------------------------------------------------
# Total-variation weight. Note that it is rescaled: the loss divides it by the
# number of HR pixels before multiplying the TV term.
# (Alternative values noted in the original comment: 0 and 0.0000001.)
tv_weight = 0.0083
factor = 4                 # scale factor given to the image loader and Downsampler
KERNEL_TYPE = 'lanczos2'   # Downsampler kernel_type
enforse_div32 = 'CROP'     # passed to load_LR_HR_imgs_sr (presumably crop-to-multiple-of-32 mode; confirm)

# --- Optimisation ---------------------------------------------------------------
LR, OPTIMIZER = 0.001, 'adam'
patience = 50                     # EarlyStopping patience
num_iter = 1000                   # iteration budget handed to optimize_early_stop
num_iter_start_early_stop = 500   # early stopping is consulted only after this many steps
reg_noise_std = 0.03              # std of the fresh noise added to the net input each step

In [38]:
# Initialise the bookkeeping used by the per-frame reconstruction loop.

# Frame indices to process (frame1.png and frame2.png).
rg = range(1, 3)

# One row per frame with 5 columns: frame number, iterations run, last SSIM,
# last PSNR, last loss.
eval_array = np.zeros(shape=(len(rg), 5))

# Per-frame row to be saved, output folder, and progress-print interval.
list_l, src_dir, print_every = [], 'reconstruction/', 2

In [None]:
# Reconstruct each frame in `rg` with a deep-prior network. After each frame
# the weights are saved, and the next frame starts from them (recursive
# warm start).
for num in rg:

    net = get_net(input_depth, NET_TYPE, pad,
                  upsample_mode='bilinear').type(dtype)

    # Weight initialisation: later frames reuse the checkpoint written at the
    # end of the previous frame and may trigger early stopping sooner.
    # NOTE(review): `num > 1` assumes the sequence starts at frame 1. If `rg`
    # starts later, this loads whatever checkpoint is on disk (confirm intent).
    if num > 1:
        net.load_state_dict(torch.load('updating_model/recursive_model.pt'))
        patience = 50
        num_iter = 1000
        num_iter_start_early_stop = 300

    net.cuda()

    # Load the HR frame and the LR image derived from it (helper defined elsewhere).
    path_to_image = 'data/example/frame' + str(num) + '.png'
    imgs = load_LR_HR_imgs_sr(path_to_image, -1, factor, enforse_div32)

    # HR pixel count (H * W of the C x H x W array); normalises the TV term.
    im_HR_size = imgs['HR_np'].shape[1] * imgs['HR_np'].shape[2]

    # Random network input with the HR spatial size (PIL size is (W, H)).
    net_input = get_noise(input_depth, INPUT,
                          (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()

    early_stopping = EarlyStopping(patience=patience, verbose=True)

    # Loss
    mse = torch.nn.MSELoss().type(dtype)

    img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)
    img_HR_var = np_to_torch(imgs['HR_np']).type(dtype)
    downsampler = Downsampler(n_planes=3, factor=factor, kernel_type=KERNEL_TYPE,
                              phase=0.5, preserve_size=True).type(dtype)

    running_loss = 0.0
    loss_values = []

    def closure():
        """Run one optimisation step for the current frame.

        Perturbs the network input with fresh noise (when reg_noise_std > 0),
        computes the LR data term plus the optional TV term, backpropagates,
        records loss and HR metrics, and advances the iteration counter `i`.

        Returns:
            (total_loss, flag_early_stopping). The flag is the EarlyStopping
            state once `i` exceeds num_iter_start_early_stop, otherwise 0.
        """
        # Only these module-level names are rebound here; the rest are read.
        global i, net_input, running_loss

        if reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * reg_noise_std)

        out_HR = net(net_input)
        out_LR = downsampler(out_HR)

        # Data term: the downsampled prediction must match the observed LR image.
        total_loss = mse(out_LR, img_LR_var)

        if tv_weight > 0:
            # Choose tvi_loss for isotropic TV or tva_loss for anisotropic TV.
            total_loss += (tv_weight / im_HR_size) * tvi_loss(out_HR)

        total_loss.backward()
        running_loss = total_loss.item()
        loss_values.append(running_loss)

        # Early stopping is consulted only after the warm-up iterations.
        if i > num_iter_start_early_stop:
            early_stopping(running_loss, net)
            flag_early_stopping = early_stopping.early_stop
        else:
            flag_early_stopping = 0

        # Metrics against the ground-truth HR frame.
        ssim_HR = ssim(out_HR, img_HR_var, data_range=1, size_average=True).item()
        psnr_HR = compare_psnr(imgs['HR_np'], torch_to_np(out_HR))
        metric_history.append([ssim_HR, psnr_HR])

        if i % print_every == 0:
            print('num_frame %d Iteration %05d  PSNR_HR %.3f SSIM_HR %.3f' % (num, i, psnr_HR, ssim_HR))

        i += 1

        return total_loss, flag_early_stopping

    metric_history = []
    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()
    i = 0
    p = get_params(OPT_OVER, net, net_input)
    optimize_early_stop(OPTIMIZER, p, closure, LR, num_iter)

    # Final prediction clipped to [0, 1]; no gradients are needed here.
    with torch.no_grad():
        out_HR_np = np.clip(torch_to_np(net(net_input)), 0, 1)

    # Evaluation row: frame, iterations run, last SSIM, last PSNR, last loss.
    list_l = [num, i, metric_history[-1][0], metric_history[-1][1], loss_values[-1]]
    eval_array[num - rg[0], :] = np.array(list_l)

    # C x H x W prediction -> H x W x C image (assumes 3 output channels).
    # Kept in float64 so the uint8 truncation matches a float64 buffer.
    out_new = np.transpose(out_HR_np, (1, 2, 0)).astype(np.float64)
    out_new_save = (255 * out_new).astype(np.uint8)

    im = Image.fromarray(out_new_save)
    name_image = os.path.join(src_dir, 'frame' + str(num) + '.png')
    im.save(name_image)

    # Checkpoint used as the warm start of the next frame.
    torch.save(net.state_dict(), 'updating_model/recursive_model.pt')

    # Drop per-frame objects before building the next network.
    del p, noise, net_input, net_input_saved, imgs, mse, downsampler, list_l, metric_history, loss_values, out_HR_np, out_new, out_new_save, im, early_stopping

# Write the per-frame evaluation table (one row per frame) into src_dir.
# NOTE: np.savetxt's default delimiter is a single space, so despite the .csv
# extension the columns are space-separated.
eval_name = 'all_frames_evaluation.csv'
np.savetxt(fname=os.path.join(src_dir, eval_name), X=eval_array)

# The table is no longer needed once it is on disk.
del eval_array