<a href="https://colab.research.google.com/github/Bantami/All-Optical-QPM/blob/main/Colab/GPC_baseline_inference_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Colab setup scripts


*   Download the repository, dataset and models
*   Install the pip packages


In [None]:
!git clone https://github.com/Bantami/All-Optical-QPM.git

!chmod 755 All-Optical-QPM/colab_setup.sh
!All-Optical-QPM/colab_setup.sh
!mkdir results

In [None]:
import sys
sys.path.append('All-Optical-QPM')

import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import shutil

from torchvision.utils import make_grid
from mpl_toolkits.axes_grid1 import make_axes_locatable

from modules.dataloaders import *
from modules.eval_metrics import *

### Required custom functions:

In [None]:
def circ(n_neurons_input, neuron_size, delta_fr,device='cpu'):
    '''
        Function to obtain the (soft-edged) filter mask selecting the central
        disk of the frequency plane.

            Args:
                n_neurons_input : Number of neurons per side of the spatial domain input | int
                neuron_size     : Size of a neuron in the spatial domain | float
                delta_fr        : Radius of the central region of the filter in the frequency domain | float or 0-d torch.Tensor
                device          : Device to do the inference on
            Returns:
                central_filter  : (n_neurons_input, n_neurons_input) mask, ~1 inside the
                                  central disk and ~0 outside | torch.Tensor
    '''

    dx = neuron_size
    N = n_neurons_input
    dtype = torch.get_default_dtype()

    # Frequency samples k/(N*dx) built from integer indices: a float-step
    # torch.arange(-1/(2dx), 1/(2dx), 1/(N*dx)) can round its end point and
    # return N+1 samples, which would break the (N, N) grid.
    k = torch.arange(N, dtype=torch.float64) - N / 2
    fx = (k / (N * dx)).to(dtype).to(device).view(1, N)    # columns: -1/(2dx) upwards
    fy = (-k / (N * dx)).to(dtype).to(device).view(N, 1)   # rows: +1/(2dx) downwards

    # Positive inside the disk of radius delta_fr, negative outside. The +100
    # offset (kept from the original) moves the sigmoid edge slightly outward.
    circle = ((delta_fr ** 2) - (fx ** 2 + fy ** 2)) + 100

    # Sigmoid gives a soft step; a hard mask would be circle.clamp(0, 1).
    central_filter = torch.sigmoid(circle)

    return central_filter

def input_circle(n_i,sf, circle = False, device='cpu'):
    '''
        Build the mask applied to the (padded) input field.

            Args:
                n_i    : Number of neurons per side of the spatial domain input | int
                sf     : Shrink factor; the disk radius is (n_i//2)//sf | int
                circle : If True return a binary disk, otherwise an all-ones mask | bool
                device : The device on which the model runs

            Returns:
                circ   : (1, n_i, n_i) float32 mask for the input | torch.Tensor
    '''

    if not circle:
        return torch.ones(1, n_i, n_i).to(device)

    radius = (n_i // 2) // sf  # radius of the input region of interest

    # Integer pixel coordinates: x grows left-to-right, y decreases top-to-bottom.
    cols = torch.arange(-n_i//2, n_i//2, 1).view(1, n_i)
    rows = torch.arange(n_i//2, -n_i//2, -1).view(n_i, 1)

    inside = (cols ** 2 + rows ** 2) <= radius ** 2  # broadcasts to (n_i, n_i)
    return inside.to(torch.float32).view(1, n_i, n_i).to(device)

def make_circular(img,device='cpu'):
    '''
        Build a binary disk mask inscribed in the filter's square grid.

            Args:
                img    : The filter; only img.shape[0] (side length) is used | torch.Tensor
                device : Device the mask is built on before being moved back to the CPU

            Returns:
                circ   : (1, size, size) float32 circular mask, always on the CPU | torch.Tensor
    '''

    size = img.shape[0]
    radius = size // 2  # inscribed disk (shrink factor of 1)

    xs = torch.arange(-size//2, size//2, 1).view(1, size)
    ys = torch.arange(size//2, -size//2, -1).view(size, 1)

    disk = ((xs ** 2 + ys ** 2) <= radius ** 2).to(torch.float32)
    return disk.view(1, size, size).to(device).detach().cpu()

In [None]:
def _crop_bounds(img_size, shrink_factor):
    '''
        Return the (start, end) indices that locate the original, un-padded image
        inside the padded img_size x img_size frame.
    '''
    if shrink_factor == 1:
        return 0, img_size
    crop_size = int(img_size / shrink_factor)
    start = int((img_size - crop_size) / 2)
    return start, start + crop_size


def _gpc_filter(experiment, img_size, neuron_size, device):
    '''
        Build the GPC transfer function: ~A outside the central disk of radius dfr
        and ~B*exp(j*theta) inside it, restricted to an inscribed circular aperture.

            Args:
                experiment  : [A, B, theta, dfr, output_scale] filter configuration
                img_size    : Side length of the (padded) frequency grid | int
                neuron_size : Size of a neuron in the spatial domain | float
                device      : Device on which the filter is built

            Returns:
                H           : Complex (img_size, img_size) filter | torch.Tensor
    '''
    A, B, theta, dfr = (torch.tensor(value) for value in experiment[:4])
    H = A + (B * torch.exp(1j * theta) - A) * circ(img_size, neuron_size, dfr, device)
    aperture = make_circular(H.abs()).to(device)[0]
    return H * aperture


def inference_loop(cfg, experiment, dataset_debug_opt, max_angle, data_loader):
    '''
        Function to run the test set through the GPC filter, print the mean SSIM
        and plot one groundtruth/reconstruction pair.

            Args:
                cfg               : Configurations dictionary (not modified; a copy is passed to the dataloader)
                experiment        : The filter configuration [A, B, theta, dfr, output_scale]
                dataset_debug_opt : Supporting option to clip the dataset phase values
                max_angle         : Expression string for the maximum phase value, e.g. 'np.pi' or '2*np.pi'
                data_loader       : The name of the dataloader factory in this namespace

            Raises:
                NameError         : If no function named data_loader is defined
                ValueError        : If the test loader yields no batches
    '''
    # Copy so per-run keys don't leak into the caller's shared cfg.
    cfg = dict(cfg, dataset_debug_opts=dataset_debug_opt, angle_max=max_angle)

    img_size = cfg['img_size']
    shrinkFactor = cfg['shrink_factor']
    device = cfg['device']
    output_scale = torch.tensor(experiment[4])
    # max_angle is a trusted, notebook-supplied expression; cfg keeps the raw string.
    angle_max = eval(max_angle)

    H = _gpc_filter(experiment, img_size, cfg['neuron_size'], device).unsqueeze(dim=0).to(device)
    spos, epos = _crop_bounds(img_size, shrinkFactor)
    incircle = input_circle(img_size, shrinkFactor, circle = True, device=device) # Circular mask for the input
    incircle_crop = incircle[:, spos:epos, spos:epos]

    # Resolve the dataloader factory by name rather than eval-ing arbitrary code.
    loader_factory = globals().get(data_loader)
    if loader_factory is None:
        raise NameError(f"name '{data_loader}' is not defined")
    _, _, test_loader = loader_factory(img_size,
                                       cfg['train_batch_size'],
                                       task_type= 'phase2intensity',
                                       shrinkFactor = shrinkFactor,
                                       cfg = cfg)

    # For these datasets the phase groundtruth is taken from the target y.
    phase_from_target = data_loader in ('get_qpm_np_dataloaders', 'get_bacteria_dataloaders')

    ssim11_rd = []
    pred_out = ground_truth = None
    with torch.no_grad():
        for x, y in test_loader:
            field = x[:, 0].to(device) * incircle  # drop channel dimension, apply input mask

            # Shift only the spatial axes; shifting the batch axis too is pointless.
            spectrum = torch.fft.fftshift(torch.fft.fft2(field), dim=(-2, -1))
            out = torch.fft.ifft2(torch.fft.ifftshift(spectrum * H, dim=(-2, -1))).to(torch.complex64)

            out = out[:, spos:epos, spos:epos]        # crop the reconstructed image
            field = field[:, spos:epos, spos:epos]    # crop the input field

            if phase_from_target:
                # CLIP ANGLE TO -> [0, angle_max]; y holds the original phase image
                y = torch.clip(y, min= 0, max= angle_max).to(device) * incircle
                gt_angle = y[:, 0][:, spos:epos, spos:epos] / angle_max
            else:
                gt_angle = (field.angle() % (2 * np.pi)) / angle_max
            # Pack |field| and normalized phase into one complex tensor for plotting.
            ground_truth = field.abs() + 1j * gt_angle

            pred_out = output_scale * (out.abs() ** 2) * incircle_crop

            ssim11_rd.append(ssim_pytorch(pred_out, gt_angle, k= 11, range_independent = False))

    if pred_out is None:
        raise ValueError(f"'{data_loader}' produced an empty test loader")

    # NOTE(review): unweighted mean over per-batch values; assumes ssim_pytorch
    # returns a host-side scalar (np.mean would fail on CUDA tensors) - confirm.
    print("========\nMean SSIM = ", np.mean(ssim11_rd))

    # Plot one sample from a window of the LAST batch.
    # NOTE(review): a short final batch may make this indexing raise IndexError.
    s, e = (6, 10) if data_loader == 'get_qpm_np_dataloaders' else (10, 14)
    sample = 3 if data_loader == 'get_bacteria_dataloaders' else 0

    pred_img = (pred_out[s:e] / output_scale)[sample]
    gt_angle = ground_truth[s:e].detach().cpu().imag[sample]

    plt.figure(figsize=(4,4))
    plt.title("Groundtruth Phase")
    plt.imshow(gt_angle.numpy(),vmin=0)
    plt.colorbar()

    plt.figure(figsize=(4,4))
    plt.imshow(pred_img.abs().detach().cpu().numpy(),vmin=0)
    plt.colorbar()
    plt.title('Reconstructed : Intensity')


### Amp+Phase and Phase GPC filter configurations for each dataset are stored in the `experiment` dictionary in the following format:

``` 'dataset_type' : [A, B, theta, dfr, output_scale] ```

* **dataset_type** : The name of the dataset along with the filter type where applicable.

> _MNIST_pi_ - MNIST [0, $\pi$]

> _MNIST_2pi_ - MNIST [0, 2$\pi$]

> _HeLa_pi_ - HeLa [0, $\pi$]

> _HeLa_2pi_ - HeLa [0, 2$\pi$]

> _Bacteria_ - Bacteria

> In the phase-only filter configuration, each dataset name carries the suffix **_phase** (e.g. _MNIST_pi_phase_).
* **A**: Amplitude of the transmission coefficient in the outer region of the GPC filter
* **B**: Amplitude of the transmission coefficient in the central region of the GPC filter
* **theta**: The applied phase shift to the light falling onto the central region of the filter
* **dfr**: Radius of the central region of the filter in the Fourier plane
* **output_scale**: Scaling factor applied to the reconstructed image

### Keys in the `cfg` dictionary:

* **device**: The device the model runs on
* **lambda_**: Wavelength of the light (m)
* **neuron_size**: Size of an element of the input image (m)
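* **img_size**: Side length (in pixels) of the input image after padding
* **shrink_factor**: Padding factor of the input image; the original image occupies the central `img_size / shrink_factor` pixels of the padded input.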
* **n_layers**: Number of optical layers in the model
* **train_batch_size**: Training batch size
* **torch_seed**: Pytorch seed for rand functions
* **task_type**: Indicates phase-to-intensity conversion task. Defaults to 'phase2intensity'

In [None]:
# Amp+Phase and Phase GPC filter configurations for each dataset:
# 'name': [A, B, theta, dfr, output_scale]
experiment = {
    'MNIST_pi': [0.2640, 0.9652, 2.8833, 37037.7695, 3.7200],
    'MNIST_pi_phase': [1.0, 1.0, 2.6852, 74075.5391, 0.235],
    'MNIST_2pi': [0.164, 0.536, 3.2093, 37037.7695, 8.2],
    'MNIST_2pi_phase': [1.0, 1.0, 3.2507, 61729.6133, 0.205],
    'HeLa_pi': [0.5548, 0.9747, 1.5568, 53371.5, 3.3567],
    'HeLa_pi_phase': [1.0, 1.0, 1.5475, 51433.4219, 1.6629],
    'HeLa_2pi': [0.5001, 0.9843, 1.8024, 53371.5, 3.6208],
    'HeLa_2pi_phase': [1.0, 1.0, 1.8291, 61729.6133, 1.],
    'Bacteria': [0.16, 0.96, 3.0229, 24691.8457, 8.44],
    'Bacteria_phase': [1.0, 1.0, 2.9502, 51433.4219, 0.1934]
}

cfg = {
    # Fall back to the CPU when no GPU is available (e.g. a CPU-only Colab runtime).
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',

    'lambda_': 6.328e-07,      # wavelength (m)
    'neuron_size': 3.164e-07,  # size of an input image element (m)
    'img_size': 256,           # side length of the padded input (pixels)
    'shrink_factor': 8,        # padding factor of the input image
    'n_layers': 1,

    'train_batch_size': 32,
    'torch_seed': 10,

    'task_type': 'phase2intensity',
}

# Amp+Phase GPC

#### MNIST [0, $\pi$]

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['MNIST_pi'],
               dataset_debug_opt='clip_phase',
               max_angle='np.pi',
               data_loader='get_mnist_dataloaders')

#### MNIST [0, 2$\pi$]

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['MNIST_2pi'],
               dataset_debug_opt='clip_phase',
               max_angle='2*np.pi',
               data_loader='get_mnist_dataloaders')

#### HeLa [0, $\pi$]

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['HeLa_pi'],
               dataset_debug_opt='clip_phase@phase_set_pi',
               max_angle='np.pi',
               data_loader='get_qpm_np_dataloaders')

#### HeLa [0, 2$\pi$]

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['HeLa_2pi'],
               dataset_debug_opt='clip_phase',
               max_angle='2*np.pi',
               data_loader='get_qpm_np_dataloaders')

#### Bacteria

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['Bacteria'],
               dataset_debug_opt='clip_phase',
               max_angle='2*np.pi',
               data_loader='get_bacteria_dataloaders')

# Phase GPC

#### MNIST [0, $\pi$]

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['MNIST_pi_phase'],
               dataset_debug_opt='clip_phase',
               max_angle='np.pi',
               data_loader='get_mnist_dataloaders')

#### MNIST [0, 2$\pi$]

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['MNIST_2pi_phase'],
               dataset_debug_opt='clip_phase',
               max_angle='2*np.pi',
               data_loader='get_mnist_dataloaders')

#### HeLa [0, $\pi$]

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['HeLa_pi_phase'],
               dataset_debug_opt='clip_phase@phase_set_pi',
               max_angle='np.pi',
               data_loader='get_qpm_np_dataloaders')

#### HeLa [0, 2$\pi$]

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['HeLa_2pi_phase'],
               dataset_debug_opt='clip_phase',
               max_angle='2*np.pi',
               data_loader='get_qpm_np_dataloaders')

#### Bacteria

In [None]:
inference_loop(cfg=cfg,
               experiment=experiment['Bacteria_phase'],
               dataset_debug_opt='clip_phase',
               max_angle='2*np.pi',
               data_loader='get_bacteria_dataloaders')