# Training the DivBlurring model

In this notebook we train the DivBlurring model. For the training we use the PyTorch Lightning framework on synthetic data. This synthetic data was generated based on realistic data.

## Import packages:

In [None]:
import warnings
warnings.filterwarnings('ignore')
import torch
import os
import urllib
import zipfile
from torch.distributions import normal
import matplotlib.pyplot as plt, numpy as np, pickle
from scipy.stats import norm
from tifffile import imread
import sys
from sklearn.feature_extraction import image
from tqdm import tqdm
sys.path.append('../../')

In [None]:
import numpy as np
import time
from glob import glob
from tifffile import imsave
from sklearn.cluster import MeanShift
from matplotlib import pyplot as plt
from IPython.display import clear_output
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import pytorch_lightning as pl

In [None]:
import logging
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

In [None]:
import torch.optim as optim
from collections import OrderedDict
from torch.nn import init
import matplotlib.pyplot as plt
import datetime

In [None]:
from torch.utils.data import Dataset, DataLoader

# Default floating point type used by this notebook.
dtype = torch.float

# Use the first GPU when CUDA is usable and fall back to the CPU otherwise, so
# the notebook does not report a CUDA device on a machine that has none.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
import utilities
import loss_function
import network

## Download and load the DATA

In [None]:
# Make sure the data is available in the data folder.

DATA_DIR = "./data/data4Sai/"
NUM_DATA_FILES = 1  # the data may be split over several numbered files; raise this to load more

# Read every file once and join the stacks along the first axis in a single
# step, instead of copying frame by frame through a Python list.
stacks = []
for i in range(NUM_DATA_FILES):
    datai = imread(DATA_DIR + 'BluryNoisy_tubulins_' + str(i + 1) + '_SOFImodel.tif')
    stacks.append(datai)
observation = np.concatenate(stacks, axis=0)

print(observation.shape)



In [None]:
# Preview the first image of the stack, then report how many images were loaded.
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_title('Single raw Image')
ax.imshow(observation[0], cmap='gray')
plt.show()

print(f"Total number imgs in the give dataset:{observation.shape[0]}")

Split the data: 85% for training and 15% for validation. We assign the larger share to training because our synthetic dataset contains 7000 images.

In [None]:
# Split the loaded image stack into a training and a validation subset using the
# project helper. The split ratio is decided inside utilities.get_split_data
# (the accompanying text states 85% / 15% — TODO confirm against the helper).
train_images, val_images = utilities.get_split_data(observation)

## Convert data to tensor for training

In [None]:
# utilities.preprocess returns the training and validation tensors together with
# data_mean and data_std, which are later handed to train_network.
# NOTE(review): presumably data_mean / data_std are the normalisation statistics
# of the data — confirm in utilities.preprocess.
x_train_tensor, x_val_tensor, data_mean, data_std = utilities.preprocess(train_images, val_images)

## Train the DivBlurring model

In [None]:
# --- Settings forwarded to train_network --------------------------------------
n_depth = 2
batch_size = 32
max_epochs = 150
gaussian_noise_std = 10
noise_model = None

# --- Where the model is stored -------------------------------------------------
model_name = 'models_DivBlurring_PCReg_1e3'  # a name used to identify the model
basedir = 'models_DivBlurring_PCReg_1e3'  # the base directory in which our model will live

# --- Additional settings (not referenced by the cells shown in this notebook) --
real_noise = False
reg_parameter = 1e-3
sigma = 3
shape = 256
img_per_each_epochs = []

In [None]:


class MyDataset(Dataset):
    """Map-style dataset that pairs each input sample with its target.

    ``X`` and ``y`` are stored unchanged on ``data`` and ``target``. Both must
    support integer indexing; the dataset length is taken from ``X`` alone.
    """

    def __init__(self, X, y):
        self.data, self.target = X, y

    def __len__(self):
        # Only the inputs determine the length; ``target`` is not consulted.
        return len(self.data)

    def __getitem__(self, index):
        # Return the (input, target) pair stored at the same position.
        return self.data[index], self.target[index]

In [None]:
def create_dataloaders(x_train_tensor,x_val_tensor,batch_size):
    """Convert the training and validation data into dataloaders.

    Each dataset uses the same tensor as input and as target, so every batch
    is a pair ``(x, x)``.

    Args:
        x_train_tensor: indexable collection of training samples.
        x_val_tensor: indexable collection of validation samples.
        batch_size: number of samples per batch for both loaders.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    train_dataset = MyDataset(x_train_tensor, x_train_tensor)
    val_dataset = MyDataset(x_val_tensor, x_val_tensor)
    # Training batches are reshuffled every epoch.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Validation data is iterated in a fixed order: shuffling brings no benefit
    # for evaluation and makes validation batches differ from epoch to epoch.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

## Network Structure

In [None]:
def train_network(x_train_tensor, x_val_tensor, batch_size, data_mean, data_std, gaussian_noise_std, 
                  noise_model,n_depth, max_epochs, model_name, basedir, log_info=False,
                  max_collapse_retries=20, max_annealing_retries=None):
    """Train the model, retrying while training reports a posterior collapse.

    Training is delegated to ``utilities.create_model_and_train``, which returns
    a ``(collapse_flag, vae)`` pair. While the flag is set, training is restarted
    without KL annealing up to ``max_collapse_retries`` times. If the flag is
    still set afterwards, training is repeated with KL annealing turned on.

    Args:
        x_train_tensor: training data, also used as the training target.
        x_val_tensor: validation data, also used as the validation target.
        batch_size: batch size of both dataloaders.
        data_mean, data_std, gaussian_noise_std, noise_model, n_depth, max_epochs:
            forwarded unchanged to ``utilities.create_model_and_train``.
        model_name: prefix of the checkpoint files (``<model_name>_best`` and
            ``<model_name>_last``).
        basedir: directory for checkpoints and TensorBoard logs; created if missing.
        log_info: when False, the ``pl.utilities.distributed`` logger is set to
            ERROR level.
        max_collapse_retries: number of attempts without KL annealing before
            switching to KL annealing.
        max_annealing_retries: optional cap on the attempts with KL annealing;
            ``None`` (default) retries without limit.

    Returns:
        The model returned by the last call to ``utilities.create_model_and_train``.

    Raises:
        RuntimeError: if ``max_annealing_retries`` is set and every attempt with
            KL annealing still reports a collapse.
    """
    train_loader, val_loader = create_dataloaders(x_train_tensor, x_val_tensor, batch_size)

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(basedir, exist_ok=True)

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=basedir,
        filename=model_name+'_best',
        save_last=True,
        save_top_k=1,
        mode='min',
    )
    checkpoint_callback.CHECKPOINT_NAME_LAST = model_name+"_last"
    logger = TensorBoardLogger(basedir, name= "", version="", default_hp_metric=False)
    # The full weights summary is requested regardless of log_info.
    weights_summary = "full"
    if not log_info:
        pl.utilities.distributed.log.setLevel(logging.ERROR)

    def _train_once(kl_annealing):
        # One complete training run; returns (collapse_flag, vae).
        return utilities.create_model_and_train(basedir, data_mean, data_std, gaussian_noise_std, noise_model,
                                                n_depth, max_epochs, logger, checkpoint_callback,
                                                train_loader, val_loader, kl_annealing=kl_annealing,
                                                weights_summary=weights_summary)

    collapse_flag = True
    posterior_collapse_count = 0
    while collapse_flag and posterior_collapse_count < max_collapse_retries:
        collapse_flag, vae = _train_once(kl_annealing=False)
        if collapse_flag:
            posterior_collapse_count += 1

    if collapse_flag:
        print("Posterior collapse limit reached, attempting training with KL annealing turned on!")
        annealing_attempts = 0
        while collapse_flag:
            # Optional safety net so a persistently collapsing model cannot loop forever.
            if max_annealing_retries is not None and annealing_attempts >= max_annealing_retries:
                raise RuntimeError(
                    "Posterior collapse persisted after %d training attempts with KL annealing."
                    % annealing_attempts
                )
            collapse_flag, vae = _train_once(kl_annealing=True)
            annealing_attempts += 1
    return vae

In [None]:
# Launch the training with the settings configured above and keep the returned model.
vae = train_network(
    x_train_tensor=x_train_tensor,
    x_val_tensor=x_val_tensor,
    batch_size=batch_size,
    data_mean=data_mean,
    data_std=data_std,
    gaussian_noise_std=gaussian_noise_std,
    noise_model=noise_model,
    n_depth=n_depth,
    max_epochs=max_epochs,
    model_name=model_name,
    basedir=basedir,
    log_info=False,
)