# Training Notebook for the Complex-Valued CNN on the HeLA $[0,\pi]$ Dataset

This notebook contains the code to train the complex-valued CNN used to test the feasibility of linearly converting phase to intensity all-optically (as discussed in **Section 2.2** of the paper).
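
To see why a complex-valued network is needed, note that a pure phase object $u = e^{i\phi}$ has intensity $|u|^2 = 1$ everywhere, so a detector cannot see $\phi$ directly. A linear operation on the complex field, such as a convolution, mixes neighbouring values so that the output intensity depends on phase differences. The sketch below illustrates this with a hypothetical 1-D helper (`complex_conv1d`) and an arbitrary two-tap kernel chosen for illustration; it is not the kernel learned by the model, and the model in this notebook operates on 2-D images.

```python
import cmath

def complex_conv1d(signal, kernel):
    """'Valid' 1-D cross-correlation of a complex signal with a complex kernel.

    Illustrative only; the model in this notebook works on 2-D images.
    """
    k = len(kernel)
    return [
        sum(kernel[j] * signal[i + j] for j in range(k))
        for i in range(len(signal) - k + 1)
    ]

def intensity(field):
    """What a detector measures: the squared magnitude of each field value."""
    return [abs(v) ** 2 for v in field]

# A pure phase object: unit amplitude, phase values in [0, pi].
phases = [0.0, 0.5, 1.0, 2.0, 3.0, 3.0]
field = [cmath.exp(1j * p) for p in phases]
print(intensity(field))  # every value is ~1.0, so the phase is invisible

# Arbitrary two-tap difference kernel, chosen for illustration only.
kernel = [1.0, -1.0]
out = complex_conv1d(field, kernel)
print(intensity(out))  # now depends on phase differences between neighbours
```

With this kernel the output intensity is $2 - 2\cos(\Delta\phi) = 4\sin^2(\Delta\phi/2)$, which is not linear in phase; whether a learned complex-valued network can make the mapping approximately linear is what this notebook tests.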

The **configs** dictionary contains the configurations required to train the complex-valued CNN end-to-end.

### Keys in the configs dictionary:

* **device**: The device on which the model is trained (e.g., 'cuda:0')
* **model**: The model type used for training
* **img_size**: Size of the input image after padding
* **shrink_factor**: Padding factor of the input image. Shrinking is not supported for the _complex_cnn_ model type, so this value must always be 1.
* **n_layers**: Number of convolutional layers in the model
* **n_channels**: Number of channels in each convolutional layer. This applies to all layers except the last one
* **all_bias**: Whether to add the bias term in all layers (bool)
* **last_bias**: Whether to add the bias term only in the last layer (bool)
* **kernel_size**: Convolutional kernel size
* **output_scale**: Scaling factor of the reconstructed image
* **output_scale_learnable**: Option to learn the scaling factor of the reconstructed image in the Learned Transformation loss (bool)
* **output_bias**: Bias added to the reconstructed image
* **output_bias_learnable**: Option to learn the bias value added to the reconstructed image in the Learned Transformation loss (bool)
* **save_results_local**: Interval, in epochs, at which results are saved locally
* **learning_rate**: Learning rate for the optimizer
* **epochs**: Number of epochs for training
* **loss_func**: Loss function, given as a string. Options: `'BerHu(\'mean\',0.95).to(device)'`, `'nn.MSELoss().to(device)'`
* **train_batch_size**: Training batch size
* **torch_seed**: PyTorch seed for random number generation
* **task_type**: Task type. 'phase2intensity' (the default) selects the phase-to-intensity conversion task
* **testing**: Whether this is a code-test run (bool). Test runs are saved under a separate experiment name
* **get_dataloaders**: Dataloader function for the chosen dataset. Supported options:
    * MNIST                   : 'get_mnist_dataloaders'
    * HeLA and HeLA $[0,\pi]$ : 'get_qpm_np_dataloaders'
    * Bacteria                : 'get_bacteria_dataloaders'
* **angle_max**: Maximum phase value in the dataset. Options:
    * MNIST         : 'np.pi'
    * HeLA          : '2*np.pi'
    * HeLA $[0,\pi]$ : 'np.pi'
    * Bacteria      : '2*np.pi'
* **dataset_debug_opts**: Dataset preprocessing options. Supported values:
    * 'clip_phase' : Clip the phase to the value given by **angle_max**
    * 'clip_phase@phase_set_pi' : Clip the phase to the value given by **angle_max**, then set the maximum angle to $\pi$ (this option is only available for the HeLA dataset)
* **exp_name**: Experiment name. The results of each run are saved in a folder with this name inside the /results folder. It is set in the code cell, based on **testing**, rather than in the `configs` literal
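
The default **loss_func** is a BerHu (reverse Huber) loss, which behaves like L1 for small residuals and like a scaled L2 for large ones: $B(x) = |x|$ if $|x| \le c$, otherwise $B(x) = (x^2 + c^2)/(2c)$. The meaning of the `0.95` argument to the repository's `BerHu` class is not documented in this notebook, so the plain-Python sketch below **assumes** the threshold $c$ is a fraction (`threshold_frac`) of the largest absolute residual, a common choice in the depth-estimation literature (often 0.2). `berhu_loss` and its parameters are hypothetical and may not match the repository's implementation.

```python
def berhu_loss(pred, target, threshold_frac=0.2, reduction="mean"):
    """Reverse Huber (BerHu) loss on flat, non-empty lists of floats.

    Assumption: c = threshold_frac * max|pred - target|. This may differ
    from how the repository's BerHu class chooses its threshold.
    """
    if not pred or len(pred) != len(target):
        raise ValueError("pred and target must be non-empty and equal length")
    residuals = [abs(p - t) for p, t in zip(pred, target)]
    c = threshold_frac * max(residuals)
    if c == 0.0:
        # Degenerate threshold: the quadratic branch is undefined, use plain L1.
        losses = residuals
    else:
        # L1 below the threshold, scaled L2 above it (continuous at r == c).
        losses = [r if r <= c else (r * r + c * c) / (2 * c) for r in residuals]
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")

print(berhu_loss([0.0, 0.1, 1.0], [0.0, 0.0, 0.0]))
```

Here $c = 0.2$, so the residuals 0 and 0.1 fall in the L1 branch and the residual 1.0 in the quadratic branch ($1.04/0.4 = 2.6$), giving a mean loss of about 0.9.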

### Reproducing Results

```
Run all cells
```
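
Before running all cells, it can help to sanity-check the configuration against the rules listed above. The sketch below assumes `configs` is a plain dictionary like the one in the code cell; `check_configs`, `EXPECTED_ANGLE_MAX`, and `VALID_DEBUG_OPTS` are hypothetical names written from the keyword descriptions and are not part of `modules.train`.

```python
# Hypothetical helper: checks a configs dict against the documented rules.
# Not part of modules.train; derived only from the keyword descriptions.

EXPECTED_ANGLE_MAX = {
    "get_mnist_dataloaders": {"np.pi"},
    "get_qpm_np_dataloaders": {"2*np.pi", "np.pi"},  # HeLA and HeLA [0, pi]
    "get_bacteria_dataloaders": {"2*np.pi"},
}
VALID_DEBUG_OPTS = {"clip_phase", "clip_phase@phase_set_pi"}

def check_configs(configs):
    """Return a list of human-readable problems (empty if none are found)."""
    problems = []
    if configs.get("model") == "complex_cnn" and configs.get("shrink_factor") != 1:
        problems.append("shrink_factor must be 1 for the complex_cnn model")

    loader = configs.get("get_dataloaders")
    if loader not in EXPECTED_ANGLE_MAX:
        problems.append(f"unknown dataloader: {loader!r}")
    elif configs.get("angle_max") not in EXPECTED_ANGLE_MAX[loader]:
        problems.append(f"angle_max {configs.get('angle_max')!r} not listed for {loader}")

    opts = configs.get("dataset_debug_opts")
    if opts is not None and opts not in VALID_DEBUG_OPTS:
        problems.append(f"unsupported dataset_debug_opts: {opts!r}")
    if opts == "clip_phase@phase_set_pi" and loader != "get_qpm_np_dataloaders":
        problems.append("clip_phase@phase_set_pi is only available for the HeLA dataset")
    return problems

example = {
    "model": "complex_cnn",
    "shrink_factor": 1,
    "get_dataloaders": "get_qpm_np_dataloaders",
    "angle_max": "np.pi",
    "dataset_debug_opts": "clip_phase@phase_set_pi",
}
print(check_configs(example))  # -> []
```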

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../')

from modules.train import train_and_log

configs = {
    'device': 'cuda:0',
    'model' : 'complex_cnn',

    'img_size': 32,
    'shrink_factor': 1,  # must stay 1: shrinking is not supported for complex_cnn
    'n_layers': 5,
    'n_channels':1,

    'all_bias': False, 
    'last_bias':True, 
    'kernel_size':3, 
    
    'output_scale':1.0, 
    'output_scale_learnable':False, 
    'output_bias':0.0,
    'output_bias_learnable':False,
    

    'save_results_local':1, 

    'learning_rate': 0.01,
    'epochs': 50,
    'loss_func': 'BerHu(\'mean\',0.95).to(device)', 
    'train_batch_size': 32,
    'torch_seed': 10,
    'task_type' : 'phase2intensity',

    'testing' : False, 

    'get_dataloaders' : 'get_qpm_np_dataloaders', 

    'angle_max': 'np.pi', 
    'dataset_debug_opts': 'clip_phase@phase_set_pi'
}

if configs['testing']:
    configs['exp_name'] = "test_HelaPi" # Use a different name for code test experiments
else:
    configs['exp_name'] = f"network({configs['model']})@dataset(HeLApi)" # Use a different name for training experiments

train_and_log(configs) # Training
