# Training the DivBlurring model

In this notebook we train the DivBlurring model. Our approach, "DivBlurring", is an inverse image-reconstruction model: it reconstructs biomedical images by treating deblurring and denoising as an inverse problem. Training uses the PyTorch Lightning framework on synthetic data, which was generated from realistic data.
The task is to reconstruct the desired image from noisy and blurry observations, using generative approaches such as variational autoencoders in combination with a physical (blur) model.


## Import packages:

In [None]:
import warnings
warnings.filterwarnings('ignore')
import torch
import os
from torch.distributions import normal
import matplotlib.pyplot as plt, numpy as np, pickle
from tifffile import imread
import sys
from tqdm import tqdm
sys.path.append('../../')

In [None]:
import numpy as np
from torch.autograd import Variable
import pytorch_lightning as pl
import logging
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt
dtype = torch.float
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook can still be executed (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
import utilities
import loss_function
from Network import network

## Load and split the data.

In [None]:
# Make sure the data is available in the data folder.
# The data may be spread over several numbered TIFF files
# ("BluryNoisy_tubulins_<k>_SOFImodel.tif", k = 1..n_files); increase n_files to
# load more of them. Each file is stacked along the first axis.
data_dir = "./data/data4Sai/"
n_files = 1

stacks = []
for file_idx in range(1, n_files + 1):
    datai = imread(os.path.join(data_dir, f"BluryNoisy_tubulins_{file_idx}_SOFImodel.tif"))
    stacks.append(datai)
# Concatenating whole stacks avoids building a Python list with one entry per frame.
observation = np.concatenate(stacks, axis=0)

print(observation.shape)



In [None]:
# Show the first raw frame and report how many frames the dataset contains.
fig = plt.figure(figsize=(6, 6))
plt.title('Single raw Image')
plt.imshow(observation[0], cmap='gray')
plt.show()

n_frames = observation.shape[0]
print(f"Total number imgs in the give dataset:{n_frames}")

Split the data: 85% for training and 15% for validation.

In [None]:
# Partition the loaded frames into a training subset and a validation subset.
split = utilities.get_split_data(observation)
train_images, val_images = split

## Convert data to tensor for training

In [None]:
# Convert both splits to tensors. The returned data_mean / data_std are kept
# because they are passed on to train_network below.
prepared = utilities.preprocess(train_images, val_images)
x_train_tensor, x_val_tensor, data_mean, data_std = prepared

## Required parameters for training.

In [None]:
n_depth = 2  # Network depth (number of layers), forwarded to the model.
batch_size = 32  # Mini-batch size for the training and validation loaders.
max_epochs = 1  # Maximum number of training epochs. NOTE(review): 1 looks like a smoke-test value.
real_noise = False  # Flag for a measured (real) noise model; not referenced in the cells shown here.
noise_model = None  # Predefined noise model; None means only the Gaussian noise level below is used.
gaussian_noise_std = 10  # Assumed standard deviation of the Gaussian noise (known-noise case).
sigma = 3  # Standard deviation of the Gaussian blur kernel, in pixels (NOT a FWHM, NOT nanometres).
shape = 256  # Side length, in pixels, of the square images; also sizes the blur kernel.
img_per_each_epochs = []  # Intended to store images at selected iterations for later analysis.

## Convolution generation.

The cells below generate the Gaussian blur kernel and transform it into the frequency domain with the fast Fourier transform (FFT). In the loss function, the predicted image is also transformed into the frequency domain and multiplied element-wise with the transformed kernel, which is equivalent to a (circular) convolution in the image domain. The result is then transformed back into the image domain with the inverse fast Fourier transform (IFFT).

In [None]:
def convolution(sigma, shape):
    """Return the Fourier transform of a normalised Gaussian blur kernel.

    The kernel is laid out on a ``shape x shape`` grid in FFT order: the origin
    is at index ``[0, 0]`` and negative offsets wrap around to the end of each
    axis. Multiplying an image's ``fft2`` by the result and applying ``ifft2``
    therefore performs a centred, circular Gaussian blur.

    Parameters
    ----------
    sigma : float
        Standard deviation of the Gaussian, in pixels. Must be positive.
    shape : int
        Side length of the square grid (image size in pixels). Must be a
        positive integer; odd sizes are supported.

    Returns
    -------
    numpy.ndarray
        Real array of shape ``(shape, shape)``. ``hf[0, 0] == 1`` because the
        spatial kernel is normalised to sum to one.

    Raises
    ------
    ValueError
        If ``sigma`` is not positive or ``shape`` is not a positive integer.
    """
    sigma = float(sigma)
    n = int(shape)
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    if n <= 0 or n != shape:
        raise ValueError(f"shape must be a positive integer, got {shape!r}")

    # Circular distance of every index from the origin: 0, 1, ..., n//2, ..., 2, 1.
    # The kernel must satisfy h[k] == h[-k] so that it is centred and its FFT is
    # purely real. A layout of 0..n/2 followed by -n/2..-2 shifts the negative
    # half by one pixel (and has the wrong length for odd n).
    idx = np.arange(n)
    t = np.minimum(idx, n - idx)
    Y, X = np.meshgrid(t, t)
    h = np.exp(-(X ** 2 + Y ** 2) / (2.0 * sigma ** 2))
    h = h / np.sum(h)
    # For this symmetric kernel the imaginary part is zero up to rounding error.
    return np.real(np.fft.fft2(h))

In [None]:
# Blur kernel (PSF) expressed in the frequency domain, for the configured blur
# width and image size.
hf = convolution(sigma=sigma, shape=shape)

In [None]:
# Train a model, retrying whenever posterior collapse is detected.

def train_network(x_train_tensor, x_val_tensor, batch_size, data_mean, data_std, gaussian_noise_std,
                  noise_model, hf, reg_parameter, method, n_depth, max_epochs, model_name, basedir,
                  log_info=False, *, max_collapse_retries=20, max_annealing_retries=None):
    """Train a DivBlurring model with checkpointing and TensorBoard logging.

    Each attempt is delegated to ``utilities.create_model_and_train``, which
    returns ``(collapse_flag, model)``. While ``collapse_flag`` is set, training
    is retried without KL annealing up to ``max_collapse_retries`` times. If it
    is still set after that, training is retried with KL annealing enabled.

    Parameters
    ----------
    x_train_tensor, x_val_tensor :
        Training / validation data, wrapped into data loaders of ``batch_size``.
    data_mean, data_std :
        Data statistics (from ``utilities.preprocess``), forwarded to training.
    gaussian_noise_std, noise_model :
        Noise description forwarded to training.
    hf : numpy.ndarray
        Blur kernel in the frequency domain.
    reg_parameter :
        Regularisation weight(s) for ``method``.
    method : str
        Name of the training variant, e.g. ``'DivBlurring_l1'``.
    n_depth, max_epochs :
        Network depth and maximum number of epochs.
    model_name : str
        Prefix of the checkpoint files (``<model_name>_best``, ``<model_name>_last``).
    basedir : str
        Directory for checkpoints and TensorBoard logs; created if missing.
    log_info : bool
        If False, Lightning's distributed logger is limited to errors.
    max_collapse_retries : int
        Maximum number of attempts without KL annealing (default 20).
    max_annealing_retries : int or None
        Maximum number of attempts with KL annealing. ``None`` (default) retries
        until a run finishes without collapse, which may never terminate.

    Returns
    -------
    The model from the last training attempt (``None`` only if both retry
    limits are 0).
    """
    print("---------------------------------------------")
    print("The method: "+str(method))
    print("Regularisation parameter: "+str(reg_parameter))
    print("---------------------------------------------")
    train_loader, val_loader = utilities.create_dataloaders(x_train_tensor, x_val_tensor, batch_size)
    os.makedirs(basedir, exist_ok=True)

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=basedir,
        filename=model_name + '_best',
        save_last=True,
        save_top_k=1,
        mode='min',
    )
    # Name the "last" checkpoint after the model instead of Lightning's default "last".
    checkpoint_callback.CHECKPOINT_NAME_LAST = model_name + "_last"
    logger = TensorBoardLogger(basedir, name="", version="", default_hp_metric=False)
    weights_summary = "full"
    if not log_info:
        pl.utilities.distributed.log.setLevel(logging.ERROR)

    def _attempt(kl_annealing):
        """Run one training attempt and return ``(collapse_flag, model)``."""
        return utilities.create_model_and_train(
            basedir, data_mean, data_std, gaussian_noise_std, noise_model, hf,
            reg_parameter, method, n_depth, max_epochs, logger, checkpoint_callback,
            train_loader, val_loader, kl_annealing=kl_annealing,
            weights_summary=weights_summary)

    # Phase 1: retry without KL annealing while collapse is reported.
    collapse_flag, vae = True, None
    attempts = 0
    while collapse_flag and attempts < max_collapse_retries:
        collapse_flag, vae = _attempt(kl_annealing=False)
        attempts += 1

    # Phase 2: fall back to KL annealing if collapse persisted.
    if collapse_flag:
        print("Posterior collapse limit reached, attempting training with KL annealing turned on!")
        annealing_attempts = 0
        while collapse_flag and (max_annealing_retries is None
                                 or annealing_attempts < max_annealing_retries):
            collapse_flag, vae = _attempt(kl_annealing=True)
            annealing_attempts += 1
        if collapse_flag:
            print("Posterior collapse persisted with KL annealing; returning the last model.")
    return vae

## Train the DivBlurring model

In [None]:
# One row per experiment: (method, model name / output folder, regularisation parameter(s)).
# The last variant combines two regularisers and therefore takes a pair of weights.
experiments = [
    ('DivBlurring', 'models_DivBlurring', 0),
    ('DivBlurring_l1', 'models_DivBlurring_l1Regu_1e10', 1e-10),
    ('DivBlurring_l2', 'models_DivBlurring_l2Regu_1e10', 1e-10),
    ('DivBlurring_PCReg_1e3', 'models_DivBlurring_PCReg_1e3', 1e-3),
    ('DivBlurring_PCReg_1e5', 'models_DivBlurring_PCReg_1e5', 1e-5),
    ('DivBlurring_PCReg_l1', 'models_DivBlurring_PC_l1X_Reg_1e3_1e10', [1e-3, 1e-10]),
]
method = [row[0] for row in experiments]
model_name = [row[1] for row in experiments]
# Each model is saved in a directory named after the model itself.
basedir = model_name
reg_parameter = [row[2] for row in experiments]

In [None]:
# Train every configured method with its predefined regularisation parameter(s).
# The four configuration lists are parallel, so they must have equal length.
if not (len(method) == len(model_name) == len(basedir) == len(reg_parameter)):
    raise ValueError("method, model_name, basedir and reg_parameter must have the same length")

for run_method, run_name, run_dir, run_reg in zip(method, model_name, basedir, reg_parameter):
    # `vae` keeps only the most recently trained model; every run also writes
    # its checkpoints to its own run_dir.
    vae = train_network(x_train_tensor, x_val_tensor, batch_size, data_mean, data_std,
                        gaussian_noise_std, noise_model, hf=hf, reg_parameter=run_reg,
                        method=run_method, n_depth=n_depth, max_epochs=max_epochs,
                        model_name=run_name, basedir=run_dir, log_info=False)

## After successful training, each model is saved in its respective `basedir`.