# Training Notebook for the Learnable Fourier Filter for the HeLA $[0,\pi]$ dataset

This notebook contains the code to train the learnable Fourier filter (LFF). The results are discussed in **Section 2.3** of the paper.

The **configs** dictionary contains the configurations required to train the complex-valued CNN end-to-end.

### Keywords in the configs dictionary:

* **device**: The device the model trains on
* **model**: The model type used for training
* **neuron_size**: Size of an element of the input image (m)
* **img_size**: Size of the input image before padding. After the dictionary is defined, it is automatically reassigned to the padded size, img_size × shrink_factor
* **shrink_factor**: The padding factor of the input image
* **n_layers**: Number of filter layers in the model. For the LFF it is fixed to 1.
* **save_results_local**: Number of epochs between local saves of the results
* **filter_circular**: Indicates if the filter should be circular (True: circular, False: square)
* **input_circular**: Indicates if the input should be circular (True: circular, False: square).
* **ring_optimize**: Option to optimize the filter in a ring-like format (bool)
* **ring_step**: Number of pixels in a ring
* **output_scale**: Scaling factor of the reconstructed image
* **output_scale_learnable**: Option to learn the scaling factor of the reconstructed image in the Learned Transformation loss (bool)
* **learning_rate**: Learning rate for the optimizer
* **epochs**: Number of epochs for training
* **loss_func**: Loss function, given as a string. Options: 'BerHu(\'mean\',0.95).to(device)', 'nn.MSELoss().to(device)'
* **train_batch_size**: Training batch size
* **torch_seed**: PyTorch seed for random number generation
* **learn_type**: Indicates which filter coefficients (amplitude, phase, or both) are learned as the weights of the filter. Options:
    * no: Non-learnable weights
    * amp: Only the amplitude coefficients will be learned
    * phase: Only the phase coefficients will be learned
    * both: Both types of coefficients will be learned
* **testing**: For code testing purposes (bool)
* **task_type**: Indicates the task. Defaults to 'phase2intensity' (phase-to-intensity conversion)
* **get_dataloaders**: Name of the dataloader function for the dataset. Options:
    * MNIST                   : 'get_mnist_dataloaders'
    * HeLA and HeLA $[0,\pi]$ : 'get_qpm_np_dataloaders'
    * Bacteria                : 'get_bacteria_dataloaders'
* **angle_max**: Maximum phase value that can be set in the dataset, given as a string. Options:
    * MNIST         : 'np.pi'
    * HeLA          : '2*np.pi'
    * HeLA $[0,\pi]$ : 'np.pi'
    * Bacteria      : '2*np.pi'
* **dataset_debug_opts**: Supporting options:
    * 'clip_phase' : Clip the phase to the angle given in **angle_max**
    * 'clip_phase@phase_set_pi' : Clip the phase to the angle given in **angle_max** and then set the maximum angle to $\pi$ (this option is only available for the HeLA dataset)
* **exp_name**: Experiment name. Results of each run will be saved in a folder with this name within the /results folder
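
The dataset-dependent keywords above can be collected in a small lookup, which may help when adapting this notebook to another dataset. The following is only a minimal sketch: `DATASET_OPTIONS`, `LEARN_TYPES`, `dataset_entries`, `padded_size`, and the `'HeLA_pi'` key are hypothetical names introduced for illustration and are not part of the repository. The values are copied from the lists above.

```python
# Hypothetical lookup of the dataset-dependent config entries listed above.
# None of these names exist in the repository; they only illustrate the options.
DATASET_OPTIONS = {
    'MNIST':    {'get_dataloaders': 'get_mnist_dataloaders',    'angle_max': 'np.pi'},
    'HeLA':     {'get_dataloaders': 'get_qpm_np_dataloaders',   'angle_max': '2*np.pi'},
    'HeLA_pi':  {'get_dataloaders': 'get_qpm_np_dataloaders',   'angle_max': 'np.pi'},
    'Bacteria': {'get_dataloaders': 'get_bacteria_dataloaders', 'angle_max': '2*np.pi'},
}

# Allowed values of the learn_type keyword
LEARN_TYPES = ('no', 'amp', 'phase', 'both')


def dataset_entries(dataset, learn_type='both'):
    """Return the config entries for a dataset after checking learn_type."""
    if learn_type not in LEARN_TYPES:
        raise ValueError(f"learn_type must be one of {LEARN_TYPES}, got {learn_type!r}")
    entries = dict(DATASET_OPTIONS[dataset])  # copy so the lookup stays unchanged
    entries['learn_type'] = learn_type
    return entries


def padded_size(img_size, shrink_factor):
    """Image size after padding, as reassigned to img_size in this notebook."""
    return img_size * shrink_factor


entries = dataset_entries('HeLA_pi')
print(entries['get_dataloaders'], entries['angle_max'])
print(padded_size(32, 8))  # 256
```

With the values used in this notebook (img_size of 32 and shrink_factor of 8), the model operates on a 256 × 256 padded input.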


### Reproducing Results

```
Run all cells
```

In [None]:
%load_ext autoreload
%autoreload 2

# Importing useful libraries
import sys
sys.path.append('../../')

from modules.train import train_and_log
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def circ(n_neurons_input, neuron_size, delta_fr=16):
    '''
        Function to obtain the filter mask with the central region
        
            Args:
                n_neurons_input : Number of neurons in the spatial domain input | int 
                neuron_size     : Size of a neuron in the spatial domain | float 
                delta_fr        : Radius of the central region of the filter in the frequency domain | float

            Returns:
                central_filter  : Filter mask | torch.Tensor     
    '''

    dx= neuron_size
    N= n_neurons_input
    
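    # Creating the fx, fy grid from integer indices so that each axis has exactly N samples
    # (torch.arange with a floating-point step can return N+1 samples because of rounding)
    idx = torch.arange(N)
    fx = ((idx - N/2) / (N*dx)).view(1,N) # Increases from left to right
    fy = ((N/2 - idx) / (N*dx)).view(N,1) # Decreases from top to bottom
    
    # Boolean (N,N) mask, True inside the circle of radius delta_fr (obtained by broadcasting)
    central_filter = (fx**2 + fy**2 <= (delta_fr)**2)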

    return central_filter

def get_weights(A,B,theta,circ_filter):
    '''
        Function to obtain Generalized Phase Contrast (GPC) filter weights:
        H = B*exp(1j*theta) inside the central region and H = A outside it
        
            Args:
                A           : Filter amplitude coefficient in the outer region | float 
                B           : Filter amplitude coefficient in the central region | float 
                theta       : The applied phase shift at the central region of the filter (rad) | float
                circ_filter : The filter mask with the central region | torch.Tensor

            Returns:
                H           : Generalized Phase Contrast filter weights | torch.Tensor     
    '''
    
    C = torch.tensor(B * (1/A) * np.exp(1j * theta) - 1)
    H = A * (1 + C*circ_filter)

    return H

In [None]:
# Keep GPC_init set to False at all times
GPC_init = False # True: initialize the weights with the GPC filter; False: initialize the weights randomly

configs = {
    'device': 'cuda:0',
    'model': 'fourier_model',

    'neuron_size': 3.164e-07,
    'img_size': 32,
    'shrink_factor': 8,
    'n_layers': 1,
    'save_results_local':100,
    'filter_circular': True,
    'input_circular' : True,
    
    'GPC_init':GPC_init,
    'ring_optimize':False,
    'ring_step':1,

    'output_scale':1.0,
    'output_scale_learnable':True,

    'learning_rate': 0.01,
    'epochs': 300,
    'loss_func': 'BerHu(\'mean\',0.95).to(device)',
    'train_batch_size': 32,
    'torch_seed': 10, 
    'learn_type': 'both', 

    'testing': False,

    'task_type': 'phase2intensity',

    'get_dataloaders' : 'get_qpm_np_dataloaders',
    'angle_max': 'np.pi',
    'dataset_debug_opts': 'clip_phase@phase_set_pi'
}


configs['img_size'] = configs['img_size'] * configs['shrink_factor'] # Reassigning the image size based on the padding factor

if configs['testing']:
    configs['exp_name'] = "test_HeLA_pi" # Use a different name for code test experiments
else:
    configs['exp_name'] =  f"network({configs['model']})@dataset(HeLA_pi)"  # Use a different name for training experiments
    
train_and_log(configs) # Training
