<a href="https://colab.research.google.com/github/Bantami/All-Optical-QPM/blob/main/Colab/lff_pretrained_model_inference_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Colab Setup Script


*   Download the repository, datasets, and pretrained models
*   Install the required pip packages


In [None]:
# Clone the project repository; its `modules/` package is imported below after
# adding `All-Optical-QPM` to sys.path.
# NOTE(review): re-running this cell in the same runtime makes `git clone` report that
# the destination already exists; that is harmless here because `!` does not raise.
!git clone https://github.com/wadduwagelab/All-Optical-QPM.git

# Make the repository's Colab setup script executable, then run it
# (per the notes above it downloads data/models and installs pip packages).
!chmod 755 All-Optical-QPM/colab_setup.sh
!All-Optical-QPM/colab_setup.sh

### Import Libraries

In [None]:
import sys
sys.path.append('All-Optical-QPM')


from torch import nn
from collections import OrderedDict
from torchvision.utils import make_grid
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage.transform import resize
from torchvision import datasets, transforms
import math
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import glob
import os
import json

from modules.dataloaders import *
from modules.diffraction import *
from modules.fourier_model import *
from modules.eval_metrics import *
from modules.vis_utils import *

## Use Pretrained Models: Model Selection and Loading

In [None]:
# Human-readable model name -> checkpoint file stem (loaded as models/<stem>.pth).
# Names are "<family> <dataset>" and stems are "<dataset stem><family suffix>":
#   AP  -> Amp+Phase LFF models
#   P   -> Phase LFF models
#   LRF -> LRF models
pretrained_models = {
    f'{family} {dataset}': f'{stem}{suffix}'
    for family, suffix in (
        ('AP',  '_lff_ap'),
        ('P',   '_lff_p'),
        ('LRF', '_lrf'),
    )
    for dataset, stem in (
        ('MNIST',         'MNIST'),
        ('MNIST [0,2Pi]', 'MNIST_2pi'),
        ('HeLa [0,Pi]',   'Hela_pi'),
        ('HeLa [0,2Pi]',  'Hela_2pi'),
        ('Bacteria',      'Bacteria'),
    )
}

In [None]:
def initiate_model(model_name, device=None, models_dir='models/'):
  '''
      Load a pretrained checkpoint, rebuild its model and create its test dataloader.

        Args:
              model_name  : A key of `pretrained_models`, i.e. '<family> <dataset>' such as
                            'AP MNIST', 'P HeLa [0,Pi]' or 'LRF Bacteria'
              device      : Torch device string for the model and data. Defaults to 'cuda:0'
                            when CUDA is available, otherwise 'cpu'.
              models_dir  : Directory containing the '<checkpoint>.pth' files (default 'models/')

        Returns:
              cfg         : Configurations dictionary stored in the checkpoint, updated with
                            'device', 'spos' and 'epos' (crop window of the original image)
              model       : Initiated model, in eval mode
              test_loader : Dataloader containing the test images

        Raises:
              KeyError    : if `model_name` is not a key of `pretrained_models`
  '''
  if device is None:
      # Previously hard-coded to 'cuda:0', which fails on CPU-only runtimes.
      device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

  checkpoint_name = pretrained_models[model_name]
  checkpoint_path = os.path.join(models_dir, f'{checkpoint_name}.pth')

  # NOTE(review): torch.load unpickles the file and the cfg strings below are passed to
  # eval(); only load checkpoints that come from a trusted source.
  saved = torch.load(checkpoint_path, map_location=device)  # Loading pretrained model

  cfg = saved['cfg']
  cfg['device'] = device

  # cfg['model'] names the model class (resolved from the star-imported project modules).
  model = eval(cfg['model'])(cfg).to(device)
  model.load_state_dict(saved['state_dict'])
  model.eval()

  torch.manual_seed(cfg['torch_seed'])

  shrinkFactor = cfg.get('shrink_factor', 1)
  img_size     = cfg['img_size']

  if shrinkFactor != 1:
      # Start and end position of the original image within the padded image
      csize = int(img_size/shrinkFactor)
      spos  = int((img_size - csize)/2)
      epos  = spos + csize
  else:
      spos = 0
      epos = img_size

  # cfg['get_dataloaders'] names the dataloader factory function.
  dataloader = eval(cfg['get_dataloaders'])
  print(cfg['get_dataloaders'])

  _, _, test_loader = dataloader(cfg['img_size'], cfg['train_batch_size'], cfg['torch_seed'],
                                 task_type=cfg['task_type'], shrinkFactor=shrinkFactor, cfg=cfg)
  print(len(test_loader))

  cfg['spos'] = spos
  cfg['epos'] = epos

  return cfg, model, test_loader

### Run Inference on Unseen Data

In [None]:
def _circular_mask(img_size, shrinkFactor, device):
  '''
      Build the circular aperture mask applied to the input field.

        Returns a (1, img_size, img_size) float32 tensor on `device`. It is 1 where
        x**2 + y**2 <= rc**2 and 0 elsewhere, with rc = (img_size//2)//shrinkFactor.
        x runs left to right and y runs top to bottom in pixel coordinates centred on the image.
  '''
  rc = (img_size//2)//shrinkFactor
  xc = torch.arange(-img_size//2, img_size//2, 1)
  xc = torch.tile(xc, (1, img_size)).view(img_size, img_size)   # every row holds the x coordinates
  yc = torch.arange(img_size//2, -img_size//2, -1).view(img_size, 1)
  yc = torch.tile(yc, (1, img_size)).view(img_size, img_size)   # every row holds a single y coordinate
  return (xc**2 + yc**2 <= rc**2).to(torch.float32).view(1, img_size, img_size).to(device)


def inference(cfg, model, test_loader):
  '''
      Run a pretrained model over the whole test set, print the mean SSIM and plot one example.

        Args:
              cfg          : Configurations dictionary returned by `initiate_model`
                             (reads 'spos', 'epos', 'angle_max', 'img_size', 'device',
                             'get_dataloaders' and optionally 'shrink_factor', 'input_circular')
              model        : Initiated model; called as `model(field)` and expected to return
                             `(reconstructed_field, output_scale)`
              test_loader  : Dataloader yielding `(x, y)` batches of test images

        Raises:
              ValueError   : if `test_loader` is empty, or no batch is large enough to
                             contain the example image that is plotted
  '''
  spos         = cfg['spos']
  epos         = cfg['epos']
  angle_max    = eval(cfg['angle_max'])  # stored in cfg as an expression string
  shrinkFactor = cfg.get('shrink_factor', 1)
  img_size     = cfg['img_size']
  inp_circular = cfg.get('input_circular', False)  # If the input field is propagated through a circular aperture
  device       = cfg['device']
  loader_name  = cfg['get_dataloaders']

  if inp_circular:
      circ = _circular_mask(img_size, shrinkFactor, device)
  else:
      circ = torch.ones(1, img_size, img_size).to(device)
  crop_circ = circ[:, spos:epos, spos:epos]  # loop invariant

  # Index (within a batch) of the example image plotted at the end.
  if loader_name == 'get_qpm_np_dataloaders':
      vis_idx = 6
  elif loader_name == 'get_bacteria_dataloaders':
      vis_idx = 13
  else:
      vis_idx = 10

  ssim_scores = []
  vis = None  # (pred_out, ground_truth, out_scale) of the batch used for plotting

  with torch.no_grad():  # pure inference: do not build an autograd graph per batch
      for x, y in test_loader:  # Test loop
          gt = x[:, 0].to(device) * circ  # Ground truth image (input)
          pred_img, out_scale = model(gt)

          pred_img = pred_img[:, spos:epos, spos:epos]  # Crop the reconstructed image
          gt       = gt[:, spos:epos, spos:epos]        # Crop the ground truth image

          pred_out = (out_scale * (pred_img.abs()**2) * crop_circ).to(torch.float32)

          if loader_name == "get_mnist_dataloaders":
              gt_angle = (gt.angle() % (2*np.pi)) / angle_max
          else:
              # Clip angle to [0, angle_max]; y holds the original phase image
              y = torch.clip(y, min=0, max=angle_max).to(device) * circ
              gt_angle = y[:, 0][:, spos:epos, spos:epos] / angle_max
          ground_truth = gt.abs() + 1j*gt_angle

          ssim_scores.append(ssim_pytorch(pred_out, gt_angle, k=11, range_independent=False))

          # Keep the latest batch that contains the example image (a loader's final batch
          # may be smaller); the first batch is kept as a fallback.
          if vis is None or len(pred_out) > vis_idx:
              vis = (pred_out, ground_truth, out_scale)

  if vis is None:
      raise ValueError('test_loader yielded no batches')

  # NOTE(review): assumes ssim_pytorch returns a CPU scalar that np.mean can average — confirm
  print("========\nMean SSIM = ", np.mean(ssim_scores))

  pred_out, ground_truth, out_scale = vis
  if len(pred_out) <= vis_idx:
      raise ValueError(f'Need at least {vis_idx + 1} images in a test batch to plot an example, got {len(pred_out)}')

  # Slice (rather than index) before dividing so out_scale broadcasts exactly as it does over a batch.
  pred_img = (pred_out[vis_idx:vis_idx + 1] / out_scale)[0]
  gt_angle = ground_truth[vis_idx].detach().cpu().imag  # normalised phase stored in the imaginary part

  plt.figure(figsize=(4,4))
  plt.title("Groundtruth Phase")
  plt.imshow(gt_angle.numpy(), vmin=0)
  plt.colorbar()

  plt.figure(figsize=(4,4))
  plt.imshow(pred_img.abs().detach().cpu().numpy(), vmin=0)
  plt.colorbar()
  plt.title('Reconstructed : Intensity')

# Amp+Phase LFF Models

#### MNIST

In [None]:
# Pretrained Amp+Phase LFF model for MNIST: load it, then evaluate on the test set.
cfg, model, test_loader = initiate_model(model_name='AP MNIST')
inference(cfg=cfg, model=model, test_loader=test_loader)

#### MNIST [0, 2$\pi$]

In [None]:
# Pretrained Amp+Phase LFF model for MNIST [0, 2Pi]: load it, then evaluate on the test set.
cfg, model, test_loader = initiate_model(model_name='AP MNIST [0,2Pi]')
inference(cfg=cfg, model=model, test_loader=test_loader)

#### HeLa [0, $\pi$]

In [None]:
# Pretrained Amp+Phase LFF model for HeLa [0, Pi]: load it, then evaluate on the test set.
cfg, model, val_loader = initiate_model(model_name='AP HeLa [0,Pi]')
inference(cfg=cfg, model=model, test_loader=val_loader)

#### HeLa [0, $2\pi$]

In [None]:
# Pretrained Amp+Phase LFF model for HeLa [0, 2Pi]: load it, then evaluate on the test set.
cfg, model, val_loader = initiate_model(model_name='AP HeLa [0,2Pi]')
inference(cfg=cfg, model=model, test_loader=val_loader)

#### Bacteria

In [None]:
# Pretrained Amp+Phase LFF model for Bacteria: load it, then evaluate on the test set.
cfg, model, val_loader = initiate_model(model_name='AP Bacteria')
inference(cfg=cfg, model=model, test_loader=val_loader)

# Phase LFF Models

#### MNIST

In [None]:
# Pretrained Phase LFF model for MNIST: load it, then evaluate on the test set.
cfg, model, test_loader = initiate_model(model_name='P MNIST')
inference(cfg=cfg, model=model, test_loader=test_loader)

#### MNIST [0, 2$\pi$]

In [None]:
# Pretrained Phase LFF model for MNIST [0, 2Pi]: load it, then evaluate on the test set.
cfg, model, test_loader = initiate_model(model_name='P MNIST [0,2Pi]')
inference(cfg=cfg, model=model, test_loader=test_loader)

#### HeLa [0, $\pi$]

In [None]:
# Pretrained Phase LFF model for HeLa [0, Pi]: load it, then evaluate on the test set.
cfg, model, val_loader = initiate_model(model_name='P HeLa [0,Pi]')
inference(cfg=cfg, model=model, test_loader=val_loader)

#### HeLa [0, 2$\pi$]

In [None]:
# Pretrained Phase LFF model for HeLa [0, 2Pi]: load it, then evaluate on the test set.
cfg, model, val_loader = initiate_model(model_name='P HeLa [0,2Pi]')
inference(cfg=cfg, model=model, test_loader=val_loader)

#### Bacteria

In [None]:
# Pretrained Phase LFF model for Bacteria: load it, then evaluate on the test set.
cfg, model, val_loader = initiate_model(model_name='P Bacteria')
inference(cfg=cfg, model=model, test_loader=val_loader)

# LRF Models

#### MNIST

In [None]:
# Pretrained LRF model for MNIST: load it, then evaluate on the test set.
cfg, model, test_loader = initiate_model(model_name='LRF MNIST')
inference(cfg=cfg, model=model, test_loader=test_loader)

#### MNIST [0, 2$\pi$]

In [None]:
# Pretrained LRF model for MNIST [0, 2Pi]. The key must match `pretrained_models`
# exactly (no space after the comma); 'LRF MNIST [0, 2Pi]' raised KeyError.
cfg, model, test_loader = initiate_model('LRF MNIST [0,2Pi]')
inference(cfg, model, test_loader)

#### HeLa [0, $\pi$]

In [None]:
# Pretrained LRF model for HeLa [0, Pi]: load it, then evaluate on the test set.
cfg, model, val_loader = initiate_model(model_name='LRF HeLa [0,Pi]')
inference(cfg=cfg, model=model, test_loader=val_loader)

#### HeLa [0, 2$\pi$]

In [None]:
# Pretrained LRF model for HeLa [0, 2Pi]: load it, then evaluate on the test set.
cfg, model, val_loader = initiate_model(model_name='LRF HeLa [0,2Pi]')
inference(cfg=cfg, model=model, test_loader=val_loader)

#### Bacteria

In [None]:
# Pretrained LRF model for Bacteria: load it, then evaluate on the test set.
cfg, model, val_loader = initiate_model(model_name='LRF Bacteria')
inference(cfg=cfg, model=model, test_loader=val_loader)