# Notes

This code can perform the following tasks:


*   Tune a CNN to directly reconstruct PET images from Sinograms (find a set of hyperparameters)
*   Train a network with a given set of hyperparameters
*   Test the network and record MSE and SSIM values for each image tested
*   Visualize the data and test results
*   Plot training curves, metric histograms, example images

The code is organized into sections. The important sections that you can edit are:


> **User Parameters** - Edit important user parameters and decide what the code will do

> **Configuration Dicts: Supervisory** - Dictionary for supervised learning. Make sure this matches the CNN loaded by the checkpoint file, if you are loading from a checkpoint.

In addition to these, you may find that running a single cell is useful when all variables/classes/files have been loaded into memory. This can be quicker than running everything from scratch.

The cells nested under **Analysis Functions** each have their own changeable parameters.

*Notes:*

*1) Ray Tune in particular changes frequently. If you run this code after the authors have stopped maintaining it and hit errors, the likely cause is a Ray Tune class, method, or function that has since changed. Unfortunately, such changes happen regularly, as the library is relatively new.*

*2) This code was originally written to tune/train/test not just sinogram-to-image supervisory networks (sinogram-->image), but also image-to-sinogram supervisory networks, GANs, CycleGANs, and Cycle + Supervisory networks. These latter capabilities have not been kept up to date, but much of the code for them survives. In the future, the code may be updated to support them once again.*
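
Because of note 1, it can help to record which package versions a working run used, so that later breakage can be traced to an upgrade. The following is a minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical and not part of this notebook.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Packages installed or imported by this notebook's Imports cell.
print(installed_versions(["ray", "torch", "hyperopt", "tensorboardX"]))
```

Saving this dictionary next to the tuning results gives a reference point if a later Ray Tune release changes an API the notebook relies on.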


# Imports

In [None]:
## Ray Tune ##
!pip install ray
!pip install -U tensorboardX
!pip install -U hyperopt    # HyperOpt search algorithm (Tree-structured Parzen Estimator). Another popular option is 'Optuna'

from ray import air, tune, train
from ray.air import session
from ray.tune.schedulers import ASHAScheduler
from ray.tune.schedulers import FIFOScheduler # First in/first out scheduler
from ray.tune import ResultGrid, JupyterNotebookReporter, CLIReporter
from ray.tune.search.hyperopt import HyperOptSearch    # Search Algorithm (current)
#from ray.tune.suggest.ax import AxSearch               # Search Algorithm (couldn't make this work)
#from ray.tune.suggest.bayesopt import BayesOptSearch   # Search Algorithm (couldn't make this work)

## Pytorch ##
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
torch.manual_seed(0)  # For testing purposes

## Torchvision ##
from torchvision.utils import make_grid
from torchvision import transforms

## Numpy/MatPlotLib ##
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.pyplot import savefig
import matplotlib.gridspec as gridspec

## Pandas ##
import pandas as pd

## SciKit-Image ##
from skimage import metrics
from skimage.metrics import structural_similarity
from skimage.transform import radon, iradon
from skimage import morphology
from skimage.morphology import opening, erosion
#from skimage.restoration import denoise_bilateral, denoise_tv_chambolle, denoise_wavelet

## SciPy ##
#from scipy.stats import moment as compute_moment

## Python ##
import os
import time
import sys
#from IPython.display import display, clear_output


## Colab ##
from google.colab import drive
drive.mount('/content/drive') # Mount Google Drive

## Determine what Hardware we have ##
device = ('cuda' if torch.cuda.is_available() else 'cpu')

Mounted at /content/drive
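
The imports cell above assumes it is running in Google Colab, where `google.colab` is importable. A hedged sketch of a guard that lets the same cell run elsewhere is shown below; the function name `mount_drive_if_colab` is an assumption for illustration, and outside Colab the directory variables in **User Parameters** would need to point at local paths instead.

```python
def mount_drive_if_colab(mount_point='/content/drive'):
    """Mount Google Drive when running in Colab; do nothing elsewhere.

    Returns True if the Drive mount was requested, False if google.colab is unavailable.
    """
    try:
        from google.colab import drive  # Only importable inside a Colab runtime
    except ImportError:
        return False
    drive.mount(mount_point)
    return True


in_colab = mount_drive_if_colab()
print('Running in Colab:', in_colab)
```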


# User Parameters

In [1]:
#####################
### General Setup ###
#####################

# Basic Options #
run_mode='tune'  # Options: 'tune' / 'train' / 'test' / 'visualize' (visualize data set)
sino_size=180          # Resize input sinograms to this size (integer). Sinograms are square, which was found to give the best results.
sino_channels=3       # Number of channels (sinograms). Options: 1, 3. Unless using scattered coincidences, set to 1.
image_size=90         # Image size (Options: 90, 180). Images are square.
image_channels=1      # Number of channels (images)
train_type='SUP'      # 'SUP' / 'GAN' / 'CYCLESUP' / 'CYCLEGAN' = (Supervisory only/GAN/Cycle consistency+supervisory/CycleGAN)
train_SI=True         # If training GAN or SUP, set True to train Gen_SI (Sinogram-->Image), or False to train Gen_IS (Image-->Sinogram)

# Global Directories #
#data_dir=      '/content/drive/MyDrive/Repository/PET_Data/'                # Where training/testing/tuning data is located
data_dir = '/content/drive/MyDrive/Colab/Working/sets/'
local_dir=     '/content/drive/MyDrive/Colab/Working/'                              # Directories not explicitly assigned are created in this directory
plot_dir=      '/content/drive/MyDrive/Colab/Working/Plots/'                        # Directory to save plots
checkpoint_dir='/content/drive/MyDrive/Colab/Working/Checkpoints-temp'              # If not using Ray Tune (not tuning), PyTorch saves and
                                                                                    #   loads checkpoint file from here
                                                                                    # All checkpoint files (for training, testing, visualizing) save
                                                                                    #   the states for a particular network.
                                                                                    # Therefore, the hyperparameters for the loaded CNN must match the data in the checkpoint file.
                                                                                    # The configuration dictionary, which contains these
                                                                                    #   hyperparameter values, is set in the 'Supervisory' cell, below.
############
## Tuning ##
############
# Note: When tuning, ALWAYS select "restart session and run all" from Runtime menu in Google Colab, or there may be bugs.
tune_scheduler = 'ASHA'     # Use 'FIFO' for simple first in/first out (every trial trains to the end), or 'ASHA' to stop poorly performing trials early.
tune_dataframe_dir= '/content/drive/MyDrive/Colab/Working/Dataframes-TuneTemp'  # This directory should already exist. The code will not make this for you.
tune_csv_file='frame-tunedOnLowSSIM-tunedSSIM-ASHA' # .csv file to save tuning dataframe to
tune_exp_name='search-Temp'                         # Experiment directory: Ray tune (and Tensorboard) write to this directory, relative to the local dir this notebook is run from.
tune_dataframe_fraction=0.33 # Fraction of the max tuning steps (tune_max_t) at which values are saved to the tuning dataframe.
tune_restore=False          # Restore a run (from the file tune_exp_name in local_dir). Use this if a tuning run terminated early for some reason.
tune_max_t = 10             # Maximum number of reports per network. For even training example reporting (reports made at a constant number of training
                            # examples), 20 is a good number for ASHA. For FIFO, 10 is a good number.
                            # For constant batch size reporting (tune_even_reporting=False), 35 works well.
tune_minutes = 30           # How long to run RayTune. 180 minutes is good for 90x90 input. 210 minutes for 180x180.
tune_for = 'SSIM'           # Tune for which optimization metric?: 'MSE', 'SSIM', or 'CUSTOM'
                            # (user defined, defined later in code).
tune_even_reporting=True    # Set to True to report to Ray Tune after a constant number of training examples,
                            # regardless of batch size.
tune_iter_per_report=10     # Default value = 10.
                            # If tune_even_reporting = True, this is the number of training iterations per Ray Tune report for a batch size of 512.
                            # For a batch size of 256, the iterations/report would be twice this number; for a batch size of 128, four
                            # times; etc.
                            # If tune_even_reporting = False, this is the number of batches per report (30 works pretty well).
tune_augment=True           # Augment data (on the fly) for tuning?
num_CPUs=4                  # Number of CPUs to use
num_GPUs=1                  # Number of GPUs to use



## Select Data Files ##
## ----------------- ##
#tune_sino_file=  'tune_sino-10k.npy'
#tune_image_file= 'tune_image-10k.npy'
#tune_sino_file= 'train_sino-highMSE-17500.npy'
#tune_image_file='train_image-highMSE-17500.npy'
#tune_sino_file= 'train_sino-lowMSE-17500.npy'
#tune_image_file='train_image-lowMSE-17500.npy'
tune_sino_file= 'train_sino-lowSSIM-17500.npy'
tune_image_file='train_image-lowSSIM-17500.npy'



##############
## Training ##
##############
train_load_state=False      # Set to True to load pretrained weights. Use if training terminated early.
train_save_state=False      # Save network weights to train_checkpoint_file during training
train_checkpoint_file='checkpoint-tunedLowSSIM-trainedHighSSIM-100epochs' # Checkpoint file to load or save to
#train_checkpoint_file='checkpoint-90x1-tunedLDM_w10s8-b5c-6epochs'
#train_checkpoint_file='checkpoint-90x1-tunedLDM_w5s2-6epochs'
training_epochs = 100      # Number of training epochs.
train_augment=True         # Augment data (on the fly) for training?
train_display_step=10      # Number of steps per visualization. Good values: 20 for supervised learning or GAN, 10 for cycle-consistent training.
train_sample_division=1    # To evenly sample the training set by a given factor, set this to an integer greater than 1 (ex: to sample every other example, set to 2)
train_show_times=False     # Show calculation times during training?

## Select Data Files ##
## ----------------- ##
#train_sino_file= 'train_sino-70k.npy'
#train_image_file='train_image-70k.npy'
#train_sino_file= 'train_sino-highMSE-17500.npy'
#train_image_file='train_image-highMSE-17500.npy'
#train_sino_file= 'train_sino-lowMSE-17500.npy'
#train_image_file='train_image-lowMSE-17500.npy'
train_sino_file= 'train_sino-lowSSIM-17500.npy'
train_image_file= 'train_image-lowSSIM-17500.npy'



###########
# Testing #
###########
#test_dataframe_dir= '/content/drive/MyDrive/Colab/Working/Dataframes-Test-Quartile'   # Directory for metric dataframes
test_dataframe_dir= '/content/drive/MyDrive/Colab/Working/Dataframes-TestOnFull'  # Directory for metric dataframes
test_csv_file = 'combined-tunedLowSSIM-trainedLowSSIM-onTestSet-wMLEM' # csv dataframe file to save testing results to
test_checkpoint_file='checkpoint-tunedLowSSIM-trainedLowSSIM-100epochs' # Checkpoint to load model for testing
test_display_step=15        # Make this a larger number to save a bit of time (displays images/metrics less often)
test_batch_size=25          # This doesn't affect the final metrics, just the metrics displayed as testing proceeds
test_chunk_size=875         # How many examples to test at once. NOTE: This should be a multiple of test_batch_size AND divide the test set size evenly.
testset_size=35000          # Size of the set to test. This must be <= the number of examples in your test set file.
test_begin_at=0             # Begin testing at this example number.
compute_MLEM=True           # Compute a simple MLEM reconstruction from the sinograms when running testing.
                            # This takes a lot longer. If set to false, only FBP is calculated.
test_set_type='test'        # Set to 'test' to test on the test set. Set to 'train' to test on the training set.
# Defaults
test_merge_dataframes=True       # Merge the smaller/chunked dataframes at the end of the test run into one large dataframe?
test_show_times=False       # Show calculation times?
test_shuffle=False          # Shuffle test set when testing?
test_sample_division=1      # To evenly sample the test set by a given factor, set this to an integer greater than 1.

## Select Data Files ##
## ----------------- ##
test_sino_file=  'test_sino-35k.npy'
test_image_file= 'test_image-35k.npy'
#test_sino_file= 'test_sino-highMSE-8750.npy'
#test_image_file= 'test_image-highMSE-8750.npy'
#test_sino_file= 'test_sino-lowMSE-8750.npy'
#test_image_file= 'test_image-lowMSE-8750.npy'



####################
## Visualize Data ##
####################

#visualize_checkpoint_file='checkpoint-90x1-tunedMSE-fc6-6epochs' # Checkpoint file to load/save
visualize_checkpoint_file='checkpoint-tunedHigh-trainedHigh-100epochs'
visualize_batch_size = 10   # Set value to exactly 120 to see a large grid of images OR <=10 for reconstructions
                            #  and ground truth with matched color scales
visualize_offset=0          # Image to begin at. Set to 0 to start at beginning.
visualize_type='train'      # Set to 'test' or 'train' to visualize the test set or training set, respectively
visualize_shuffle=True      # Shuffle data set when visualizing?



####################################################################
## Assign Values for Various Scenarios: visualize/tune/train/test ##
####################################################################

tune_sino_path=os.path.join(data_dir, tune_sino_file)
tune_image_path=os.path.join(data_dir, tune_image_file)
train_sino_path=os.path.join(data_dir, train_sino_file)
train_image_path=os.path.join(data_dir, train_image_file)
test_sino_path=os.path.join(data_dir, test_sino_file)
test_image_path=os.path.join(data_dir, test_image_file)

## Run-Type Specific Assignments ##
if run_mode=='tune':
    sino_path=tune_sino_path
    image_path=tune_image_path
    augment=tune_augment
    shuffle = True
    num_epochs=1000         # Tuning stops when the iteration count reaches tune_max_t (set above). num_epochs is set to a large number so tuning doesn't terminate early.
    load_state=False
    save_state=False
    checkpoint_file = ''    # Leave this empty. The checkpoint path is constructed regardless, so this ensures that no error occurs.
    offset=0
    show_times=False
    sample_division=1
    tune_dataframe_path = os.path.join(tune_dataframe_dir, tune_csv_file+'.csv')

if run_mode=='train':
    sino_path=train_sino_path
    image_path=train_image_path
    augment=train_augment
    shuffle=True
    num_epochs=training_epochs
    load_state=train_load_state
    save_state=train_save_state
    checkpoint_file = train_checkpoint_file
    offset=0
    show_times=train_show_times
    sample_division=train_sample_division

if run_mode=='test':
    if test_set_type=='test': # Test on test set
        sino_path=test_sino_path
        image_path=test_image_path
    else: # Test on training set
        sino_path=train_sino_path
        image_path=train_image_path
    augment = False
    shuffle = test_shuffle
    num_epochs=1
    load_state=True # Set to True to load pretrained weights.
    save_state=False # Do not save network weights to checkpoint file as we are only testing.
    checkpoint_file = test_checkpoint_file
    offset=0
    show_times=test_show_times
    sample_division=test_sample_division

if run_mode=='visualize':
    if visualize_type=='test':
        sino_path=test_sino_path   # Visualize the test set
        image_path=test_image_path
    else:
        sino_path=train_sino_path  # Visualize the training set
        image_path=train_image_path
    augment = False
    shuffle = visualize_shuffle
    num_epochs=1
    load_state=True
    save_state=False
    checkpoint_file = visualize_checkpoint_file
    show_times=False
    offset=visualize_offset
    sample_division=1

## Assign dataframe and checkpoint paths ##
test_dataframe_path = os.path.join(test_dataframe_dir, test_csv_file+'.csv')
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file) # Requires checkpoint_file, so this is done after the run_mode-specific assignments
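
The comments on `tune_iter_per_report` describe an inverse scaling with batch size: 10 iterations per report at a batch size of 512, twice that at 256, four times that at 128. A minimal sketch of that arithmetic follows. The function name `iterations_per_report` and the rounding for batch sizes that do not divide evenly are assumptions; the notebook's own calculation may differ.

```python
def iterations_per_report(batch_size, tune_iter_per_report=10, reference_batch_size=512):
    """Iterations between reports so each report covers about the same number of examples.

    Assumption: results are rounded to the nearest whole iteration, with a minimum of 1.
    """
    examples_per_report = tune_iter_per_report * reference_batch_size
    return max(1, round(examples_per_report / batch_size))


for batch_size in (512, 256, 128):
    print(batch_size, iterations_per_report(batch_size))
```

With the default values, every report then covers roughly 10 * 512 = 5120 training examples regardless of the batch size a trial samples.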

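
The testing parameters carry two divisibility constraints: `test_chunk_size` should be a multiple of `test_batch_size`, and it should divide `testset_size` evenly. The check below is a sketch of how these could be verified before a long test run; `check_test_chunking` is a hypothetical helper, not a function defined in this notebook.

```python
def check_test_chunking(testset_size, test_chunk_size, test_batch_size):
    """Return a list of human-readable problems; an empty list means the settings are consistent."""
    problems = []
    if test_chunk_size % test_batch_size != 0:
        problems.append(
            f"test_chunk_size ({test_chunk_size}) is not a multiple of test_batch_size ({test_batch_size})"
        )
    if testset_size % test_chunk_size != 0:
        problems.append(
            f"test_chunk_size ({test_chunk_size}) does not divide testset_size ({testset_size}) evenly"
        )
    return problems


# Values taken from the Testing block of the User Parameters cell.
print(check_test_chunking(testset_size=35000, test_chunk_size=875, test_batch_size=25))
```

With the values above, 875 = 35 * 25 and 35000 = 40 * 875, so both constraints hold.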

# Configuration Dicts

---


## Supervisory

In this cell, assign the correct hyperparameter dictionary to config_SUP_SI. This dictionary of hyperparameters determines the form of the network that will be trained, tested, or visualized (when doing supervised learning, Sinogram-->Image). You will usually find these hyperparameters by tuning and then examining the best-performing networks in TensorBoard.
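
Since a hand-edited dictionary can easily drop or misspell a key, a quick structural check may be useful before training. The sketch below assumes the required keys are exactly those that appear in the config_SUP_SI dictionaries in this section; `config_key_problems` is a hypothetical helper. It only compares key names, so it cannot detect values that disagree with a checkpoint.

```python
REQUIRED_SUP_SI_KEYS = {
    "SI_dropout", "SI_exp_kernel", "SI_gen_fill", "SI_gen_final_activ",
    "SI_gen_hidden_dim", "SI_gen_mult", "SI_gen_neck", "SI_gen_z_dim",
    "SI_layer_norm", "SI_normalize", "SI_pad_mode", "SI_scale",
    "batch_size", "gen_b1", "gen_b2", "gen_lr", "sup_criterion",
}


def config_key_problems(config, required=REQUIRED_SUP_SI_KEYS):
    """Return (missing, unexpected) key lists for a supervisory config dictionary."""
    present = set(config)
    missing = sorted(required - present)
    unexpected = sorted(present - required)
    return missing, unexpected


# A deliberately incomplete example with one misspelled key.
missing, unexpected = config_key_problems({"batch_size": 512, "gen_learning_rate": 0.0004})
print("missing:", missing)
print("unexpected:", unexpected)
```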

If training supervisory loss networks only, you don't need to worry about the other dictionaries in this section (GANs, Cycle-Consistent). You also don't need to worry about "Search Spaces", as this is simply a dictionary of the search space that Ray Tune uses when tuning. Feel free to look at it though, to see how I set up the search space. The last section (Set Correct Config) is where the configuration dictionary gets assigned. The dictionary is either a searchable space, if tuning, or a set of fixed hyperparameters, if training, testing, or visualizing the data set.
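
The choice described above, a search space when tuning and fixed hyperparameters otherwise, can be sketched as a small function. This is an illustration under assumptions, not the notebook's own "Set Correct Config" code: `select_config` is a hypothetical name, and the search-space values are placeholder strings standing in for Ray Tune sampling objects.

```python
def select_config(run_mode, fixed_config, search_space):
    """Return the search space when tuning, otherwise the fixed hyperparameters."""
    if run_mode == 'tune':
        return search_space
    if run_mode in ('train', 'test', 'visualize'):
        return fixed_config
    raise ValueError(f"Unknown run_mode: {run_mode!r}")


# Excerpt of a fixed config, and a stand-in for the corresponding search space.
fixed = {"batch_size": 512, "gen_lr": 0.0003914885622973457}
space = {"batch_size": "<search range>", "gen_lr": "<search range>"}

config = select_config('train', fixed, space)
print(config)
```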

In [None]:
### Below networks were tuned on whole dataset ###
# 1x90x90, Tuned for MSE - fc6 #
'''
config_SUP_SI={
  "SI_dropout": False,
  "SI_exp_kernel": 4,
  "SI_gen_fill": 0,
  "SI_gen_final_activ": None,
  "SI_gen_hidden_dim": 14,
  "SI_gen_mult": 2.3737518721494038,
  "SI_gen_neck": 5,
  "SI_gen_z_dim": 300,
  "SI_layer_norm": "instance",
  "SI_normalize": True,
  "SI_pad_mode": "zeros",
  "SI_scale": 8100,
  "batch_size": 266,
  "gen_b1": 0.5194977285709309,
  "gen_b2": 0.4955647195661826,
  "gen_lr": 0.0006569034263698925,
  "sup_criterion": nn.MSELoss()
}
'''
'''
# 1x90x90, Tuned for MAE (mean absolute error) - b08 #
config_SUP_SI={
  "SI_dropout": True,
  "SI_exp_kernel": 3,
  "SI_gen_fill": 0,
  "SI_gen_final_activ": nn.Tanh(),
  "SI_gen_hidden_dim": 29,
  "SI_gen_mult": 3.4493572412953926,
  "SI_gen_neck": 5,
  "SI_gen_z_dim": 92,
  "SI_layer_norm": "instance",
  "SI_normalize": True,
  "SI_pad_mode": "zeros",
  "SI_scale": 8100,
  "batch_size": 184,
  "gen_b1": 0.41793988944151467,
  "gen_b2": 0.15133808988276928,
  "gen_lr": 0.0012653525173041019,
  "sup_criterion": nn.L1Loss()
}
'''
'''
# 1x90x90, Tuned for SSIM - 14d #
config_SUP_SI = {
  "SI_dropout": False,
  "SI_exp_kernel": 4,
  "SI_gen_fill": 0,
  "SI_gen_final_activ": nn.Tanh(),
  "SI_gen_hidden_dim": 23,
  "SI_gen_mult": 1.6605902406330195,
  "SI_gen_neck": 5,
  "SI_gen_z_dim": 789,
  "SI_layer_norm": "instance",
  "SI_normalize": True,
  "SI_pad_mode": "zeros",
  "SI_scale": 8100,
  "batch_size": 71,
  "gen_b1": 0.2082092731474774,
  "gen_b2": 0.27147903136187507,
  "gen_lr": 0.0005481469822215635,
  "sup_criterion": nn.MSELoss()
}
'''
'''
# 1x90x90, Tuned for Local Distributions Metric, 5x5 window, stride 2
config_SUP_SI={
  "SI_dropout": True,
  "SI_exp_kernel": 3,
  "SI_gen_fill": 2,
  "SI_gen_final_activ": nn.Sigmoid(),
  "SI_gen_hidden_dim": 18,
  "SI_gen_mult": 2.4691388140182475,
  "SI_gen_neck": 11,
  "SI_gen_z_dim": 444,
  "SI_layer_norm": "instance",
  "SI_normalize": True,
  "SI_pad_mode": "zeros",
  "SI_scale": 8100,
  "batch_size": 33,
  "gen_b1": 0.8199882799898334,
  "gen_b2": 0.1207854128656507,
  "gen_lr": 0.0001095057659925285,
  "sup_criterion": nn.BCEWithLogitsLoss()
}
'''
'''
# 1x90x90, Tuned for Local Distributions Metric, 10x10 window, stride 8 (b5c)
config_SUP_SI={
  "SI_dropout": False,
  "SI_exp_kernel": 4,
  "SI_gen_fill": 0,
  "SI_gen_final_activ": None,
  "SI_gen_hidden_dim": 9,
  "SI_gen_mult": 2.1547197646081444,
  "SI_gen_neck": 5,
  "SI_gen_z_dim": 344,
  "SI_layer_norm": "batch",
  "SI_normalize": False,
  "SI_pad_mode": "zeros",
  "SI_scale": 8100,
  "batch_size": 47,
  "gen_b1": 0.31108788447029295,
  "gen_b2": 0.3445239707919786,
  "gen_lr": 0.0007561178182660596,
  "sup_criterion": nn.L1Loss()
}
'''


### Below networks were tuned on 1/4 of dataset (high MSE or low MSE) ####

# 1x90x90, Tuned for SSIM, highSSIM quartile, - c867539
config_SUP_SI = {
  "SI_dropout": False,
  "SI_exp_kernel": 3,
  "SI_gen_fill": 0,
  "SI_gen_final_activ": nn.Tanh(),
  "SI_gen_hidden_dim": 14,
  "SI_gen_mult": 3.1366081867376066,
  "SI_gen_neck": 5,
  "SI_gen_z_dim": 1235,
  "SI_layer_norm": "instance",
  "SI_normalize": True,
  "SI_pad_mode": "reflect",
  "SI_scale": 8100,
  "batch_size": 512,
  "gen_b1": 0.36092827701745117,
  "gen_b2": 0.2959809747063715,
  "gen_lr": 0.0003914885622973457,
  "sup_criterion": nn.MSELoss()
}

'''
# 1x90x90, Tuned for MSE, lowMSE quartile - d3c
config_SUP_SI = {
  "SI_dropout": False,
  "SI_exp_kernel": 3,
  "SI_gen_fill": 0,
  "SI_gen_final_activ": nn.Tanh(),
  "SI_gen_hidden_dim": 10,
  "SI_gen_mult": 3.5952046080348117,
  "SI_gen_neck": 5,
  "SI_gen_z_dim": 1144,
  "SI_layer_norm": "batch",
  "SI_normalize": True,
  "SI_pad_mode": "zeros",
  "SI_scale": 8100,
  "batch_size": 338,
  "gen_b1": 0.21119520045946658,
  "gen_b2": 0.3219437242478679,
  "gen_lr": 0.0012228287967471555,
  "sup_criterion": nn.L1Loss()
}
'''
'''
# 1x90x90, Tuned for MSE, highMSE quartile - 66e
config_SUP_SI = {
  "SI_dropout": False,
  "SI_exp_kernel": 4,
  "SI_gen_fill": 0,
  "SI_gen_final_activ": nn.Tanh(),
  "SI_gen_hidden_dim": 13,
  "SI_gen_mult": 2.427097790975542,
  "SI_gen_neck": 1,
  "SI_gen_z_dim": 1943,
  "SI_layer_norm": "instance",
  "SI_normalize": True,
  "SI_pad_mode": "zeros",
  "SI_scale": 8100,
  "batch_size": 399,
  "gen_b1": 0.5173104983713961,
  "gen_b2": 0.5269533977675209,
  "gen_lr": 0.00042406256400739315,
  "sup_criterion": nn.MSELoss()
}
'''
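
Every supervisory config above sets `SI_scale` to 8100, which is 90*90, the pixel count of a 90x90 image. The search-space comments later in this notebook describe `SI_normalize` as making an image's pixel values sum to 1 and `SI_scale` as the value they sum to after multiplication. A minimal sketch of that arithmetic is given below, using plain Python lists to stay dependency-free. The function `normalize_and_scale` and its handling of all-zero images are assumptions; the notebook's tensor code may do this differently.

```python
def normalize_and_scale(image, scale):
    """Rescale a 2D image (list of rows) so its pixel values sum to `scale`.

    Assumption: an all-zero image is returned unchanged to avoid dividing by zero.
    """
    total = sum(sum(row) for row in image)
    if total == 0:
        return [row[:] for row in image]
    return [[pixel / total * scale for pixel in row] for row in image]


image = [[0.0, 2.0],
         [1.0, 1.0]]
scaled = normalize_and_scale(image, scale=90 * 90)
print(sum(sum(row) for row in scaled))
```

With a scale equal to the pixel count, the mean pixel value of a normalized image is 1.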


## GANs + Cycle

In [None]:
## Best Configs for GANs ##

config_GAN_SI = { # Older, this still outperforms the more recent tuning
    'SI_disc_adv_criterion': nn.MSELoss(),
    'SI_normalize': True, # True
    'SI_scale': 1400, # 1      # Added later
    'SI_gen_neck': 1, # 1
    'SI_layer_norm': 'batch',
    'SI_pad_mode': 'reflect',
    'SI_dropout': False,
    'SI_exp_kernel': 3,
    'SI_gen_fill': 0,
    'SI_gen_mult': 1.41,
    'SI_gen_z_dim': 115,
    'SI_gen_final_activ': nn.Sigmoid(),
    'SI_disc_patchGAN': True,
    'SI_gen_hidden_dim': 46,
    'SI_disc_hidden_dim': 25,
    'SI_disc_b1': 0.102081,
    'SI_disc_b2': 0.999,
    # Gets Overwritten Below
    'SI_disc_lr': 0.000167384,
    'batch_size': 78,
    'gen_adv_criterion': nn.MSELoss(),
    'gen_lr': 0.000167384,
    'gen_b1': 0.102081,
    'gen_b2': 0.999,
    }

config_GAN_IS = { # new, looks good by step 400, somewhat blocky. May be outperforming config_GAN_SI
  "IS_disc_adv_criterion": nn.BCEWithLogitsLoss(),
  "IS_disc_b1": 0.3335905891003811,
  "IS_disc_b2": 0.999,
  "IS_disc_hidden_dim": 11,
  "IS_disc_patchGAN": True,
  "IS_dropout": False,
  "IS_exp_kernel": 3,
  "IS_gen_fill": 0,
  "IS_gen_final_activ": None,
  "IS_gen_hidden_dim": 15,
  "IS_gen_mult": 3,
  "IS_gen_neck": 11,
  "IS_gen_z_dim": 5,
  "IS_layer_norm": "instance",
  "IS_normalize": False,
  "IS_pad_mode": "reflect",
  "IS_scale": 1,
  # Gets Overwritten Below
  "IS_disc_lr": 0.00021705437338035208,
  "batch_size": 88,
  "gen_adv_criterion": nn.MSELoss(),
  "gen_b1": 0.46293297275979556,
  "gen_b2": 0.999,
  "gen_lr": 0.00042810775483742824
}

'''
# this config looks decent at step 1100, worse at 1440, better at 1900, etc. (variable). It isn't blocky.
config_GAN_IS_old = {
    "batch_size": 82,
    "gen_adv_criterion": nn.BCEWithLogitsLoss(),
    "gen_lr": 3.365297856241193e-05,
    "gen_b1": 0.11790916451301556,
    "gen_b2": 0.999,

    "IS_disc_adv_criterion": nn.BCEWithLogitsLoss(),
    "IS_normalize": False, # FALSE
    'IS_scale': 1, # 1
    'IS_gen_mult': 3,
    'IS_gen_fill': 0,
    'IS_gen_neck': 5, # Wide neck
    "IS_gen_z_dim": 115,
    'IS_layer_norm': 'instance',
    'IS_pad_mode': 'reflect',
    'IS_dropout': False,
    'IS_exp_kernel': 4,
    "IS_gen_final_activ": nn.Tanh(), # nn.Tanh()
    "IS_disc_patchGAN": True,
    "IS_gen_hidden_dim": 16,
    "IS_disc_hidden_dim": 19,
    "IS_disc_lr": 0.00020392229473545828,
    "IS_disc_b1": 0.35984156365558084,
    "IS_disc_b2": 0.999,
    }
'''


In [None]:
## I've looked at configurations in 'search-CycleGAN' through (and including) '4a92'

config_CYCLEGAN={ # Works, yeah! ("4a92")
    "IS_disc_adv_criterion": nn.BCEWithLogitsLoss(),
    "IS_disc_b1": 0.3335905891003811,
    "IS_disc_b2": 0.999,
    "IS_disc_hidden_dim": 11,
    "IS_disc_lr": 0.0006554051278271163,
    "IS_disc_patchGAN": True,
    "IS_dropout": False,
    "IS_exp_kernel": 3,
    "IS_gen_fill": 0,
    "IS_gen_final_activ": None,
    "IS_gen_hidden_dim": 15,
    "IS_gen_mult": 3,
    "IS_gen_neck": 11,
    "IS_gen_z_dim": 5,
    "IS_layer_norm": "instance",
    "IS_normalize": False,
    "IS_pad_mode": "reflect",
    "IS_scale": 1,
    "SI_disc_adv_criterion": nn.MSELoss(),
    "SI_disc_b1": 0.102081,
    "SI_disc_b2": 0.999,
    "SI_disc_hidden_dim": 25,
    "SI_disc_lr": 0.0005793968896471209,
    "SI_disc_patchGAN": True,
    "SI_dropout": False,
    "SI_exp_kernel": 3,
    "SI_gen_fill": 0,
    "SI_gen_final_activ": nn.Sigmoid(),
    "SI_gen_hidden_dim": 46,
    "SI_gen_mult": 1.41,
    "SI_gen_neck": 1,
    "SI_gen_z_dim": 115,
    "SI_layer_norm": "batch",
    "SI_normalize": True,
    "SI_pad_mode": "reflect",
    "SI_scale": 1400,
    "batch_size": 91,
    "cycle_criterion": nn.MSELoss(),
    "gen_adv_criterion": nn.MSELoss(),
    "gen_b1": 0.1610671788990834,
    "gen_b2": 0.999,
    "gen_lr": 0.0023450700434171526,
    "lambda_adv": 1,
    "lambda_cycle": 1, #1
    "lambda_sup": 0, # 0
    "sup_criterion": nn.L1Loss()
    }

'''
## Was best for training the CycleGAN all at once ##
# Below config is "SM_1662", the lowest optim_metric in 9h run, 90x90 symmetrical (not symmetrized parameters) networks.
# It was trained on IO_channels==3 but seems to work fine for IO_channels==1. Also, both discriminators use the same architecture,
# which is really better suited for the sinogram (Disc_S_90).
#
# Lessons Learned:
# 1) Utilized: different size necks, final activations, channels, patchGAN
# 2) NOT Utilized: fill Conv2d layers, different adv_criterion (for disc loss), different normalizations

config={ # Symmetrize == FALSE (final activations don't match). This was the best over full tune train time (9 hours). Use two Sinogram discriminators.
'batch_size': 107,
'gen_b1': 0.339,
'gen_b2': 0.999,
'gen_lr': 0.000103,
"cycle_criterion": nn.L1Loss(),
"sup_criterion": nn.L1Loss(),
"gen_adv_criterion": nn.KLDivLoss(),
"lambda_adv": 1,
"lambda_cycle": 2,
"lambda_sup": 0,

"IS_disc_adv_criterion": nn.MSELoss(),
"IS_disc_b1": 0.19520417398460468,
"IS_disc_b2": 0.999,
"IS_disc_hidden_dim": 23,
"IS_disc_lr": 0.0022230036964765274,  # disc_lr is 10X faster than SI
"IS_disc_patchGAN": False,            # true for SI (make sense, images can be more true/false in patches)
"IS_gen_fill": 0,                     # fill=0 for both IS and SI
"IS_gen_final_activ": nn.Sigmoid(),   # tuned final activations opposite than for GANs
"IS_gen_hidden_dim": 8,               # IS much less complex than SI (8 vs 16 hidden_dim)
"IS_gen_mult": 3,                     # mult=3 for both IS and SI
"IS_gen_z_dim": 5,
"IS_normalize": True,                 # both are normalized here
"IS_scale": 1,                        # OMG, this is weird. We are normalizing both, but the SI image scale is 1400x the IS. Could this be why final activation is now Sigmoid?

'IS_layer_norm': 'batch', # Batch
'IS_pad_mode': 'reflect',
'IS_dropout': False,
"IS_gen_neck": 5,            # 2
'IS_exp_kernel': 4,          # 4

"SI_disc_adv_criterion": nn.MSELoss(),
"SI_disc_b1": 0.30423542819878224,
"SI_disc_b2": 0.999,
"SI_disc_hidden_dim": 23,
"SI_disc_lr": 0.00020737432489437965,
"SI_disc_patchGAN": True,
"SI_gen_fill": 0,
"SI_gen_final_activ": nn.Tanh(),
"SI_gen_hidden_dim": 22,
"SI_gen_mult": 3,
"SI_gen_z_dim": 1195,                 # Represents an 8x drop in information into narrowest part of neck
"SI_normalize": True, # True
"SI_scale": 1400,

'SI_layer_norm': 'batch',
'SI_pad_mode': 'reflect',
'SI_dropout': False,
"SI_gen_neck": 1,            # 1
'SI_exp_kernel': 4,          # 4
}
'''
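
The CycleGAN configs above hold three kinds of keys in one dictionary: `SI_`-prefixed keys for the Sinogram-->Image network, `IS_`-prefixed keys for the Image-->Sinogram network, and unprefixed keys shared by both. A sketch of separating them by prefix is shown below; `split_by_prefix` is a hypothetical helper for inspection, and the example dictionary is a short excerpt of config_CYCLEGAN.

```python
def split_by_prefix(config, prefixes=("SI_", "IS_")):
    """Group config entries by key prefix; keys matching no prefix go under 'shared'."""
    groups = {prefix: {} for prefix in prefixes}
    groups["shared"] = {}
    for key, value in config.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                groups[prefix][key] = value
                break
        else:
            # No prefix matched: the setting applies to the whole training run.
            groups["shared"][key] = value
    return groups


excerpt = {
    "IS_gen_neck": 11,
    "IS_scale": 1,
    "SI_gen_neck": 1,
    "SI_scale": 1400,
    "batch_size": 91,
    "lambda_cycle": 1,
}
groups = split_by_prefix(excerpt)
for name, entries in groups.items():
    print(name, entries)
```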


## Search Spaces
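
The cell in this section builds per-network search spaces and merges them into one dictionary for Ray Tune. The sketch below imitates that pattern without Ray, so the sampling behavior can be inspected directly. Everything in it is a stand-in: `choice`, `uniform`, and `lograndint` are simplified local approximations of the Ray Tune functions of the same name (with `lograndint` assumed to treat the upper bound as exclusive), and the `batch_size` range in `space_SUP` is hypothetical.

```python
import math
import random


def choice(options):
    return lambda rng: rng.choice(options)


def uniform(low, high):
    return lambda rng: rng.uniform(low, high)


def lograndint(low, high):
    """Integer in [low, high) whose logarithm is roughly uniformly distributed."""
    def draw(rng):
        value = math.exp(rng.uniform(math.log(low), math.log(high)))
        return max(low, min(int(value), high - 1))  # Clamp against floating-point edge cases
    return draw


# Generator entries mirror a few lines of config_RAY_SI; constants pass through unchanged.
space_SI = {
    'SI_normalize': choice([True, False]),
    'SI_scale': 90 * 90,
    'SI_gen_mult': uniform(1.1, 4),
    'SI_gen_neck': choice([1, 5, 11]),
    'SI_gen_z_dim': lograndint(64, 4000),
}
space_SUP = {'batch_size': lograndint(32, 513)}  # Hypothetical range, for illustration only

# Later dictionaries override earlier ones on key collisions.
search_space = {**space_SI, **space_SUP}


def sample(space, rng):
    """Draw one trial config: call samplers, copy constants."""
    return {key: (value(rng) if callable(value) else value) for key, value in space.items()}


trial = sample(search_space, random.Random(0))
print(trial)
```

Log-uniform sampling spends as many draws between 64 and 512 as between 512 and 4000, which suits parameters such as `SI_gen_z_dim` whose useful values span orders of magnitude.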

In [None]:
#################################################################################################################################################################
## (config_RAY_SI OR config_RAY_IS) gets combined with (config_RAY_SUP or config_RAY_GAN) to form a single hyperparameter space for searching a single network ##
#################################################################################################################################################################

## Note: For the Coursera CycleGAN:
# gen_adv_criterion = disc_adv_criterion = nn.MSELoss()
# cycle_criterion = ident_criterion = nn.L1Loss()
# for notes on momentum, see: https://distill.pub/2017/momentum/

config_RAY_SI = { # Dictionary for Generator: Sinogram-->Image
    # Data Loading
    'SI_normalize': tune.choice([True, False]),                 # Normalize dataloader outputs and outputs of generator?
                                                                # If so, the pixel values in the image all add up to 1.
    'SI_scale': 90*90,                                          # If normalizing the pixel images, multiply images by this value.
                                                                # The pixel values will then add up to this number.
    # Generator Network
    'SI_gen_mult': tune.uniform(1.1, 4),                        # Factor by which to multiply channels/block as one moves towards the center of the network
    'SI_gen_fill': tune.choice([0,1,2]),                        # Number of constant-sized Conv2d layers/block
    'SI_gen_neck': tune.choice([1,5,11]),                       # Size of network neck: 1 = smallest, 11 = largest
    'SI_gen_z_dim': tune.lograndint(64, 4000),                  # If network utilizes smallest neck size (1x1 = a dense layer), this is the number of channels in the neck
    'SI_layer_norm': tune.choice(['batch', 'instance','none']), # Layer normalization type ('none' seems to be, in practice, never chosen by tuning)
    'SI_pad_mode': tune.choice(['zeros', 'reflect']),           # Padding type
    'SI_dropout': tune.choice([True,False]),                    # Implement dropout in network? (without cross-validation, this is likely never chosen)
    'SI_exp_kernel': tune.choice([3,4]),                        # Expanding kernel size: 3x3 or 4x4
    'SI_gen_final_activ':  tune.choice([nn.Tanh(), nn.Sigmoid(), None]), # Options: tune.choice([nn.Tanh(), nn.Sigmoid(), None]),
                                                                # Could add: nn.ReLU6(), nn.Hardsigmoid(), nn.ReLU(), nn.PReLU(), None
    'SI_gen_hidden_dim': tune.lograndint(2, 30),                # Generator channel scaling factor. Larger numbers give more total channels.
    # Discriminator Network
    'SI_disc_hidden_dim': tune.lograndint(10, 30),              # Discriminator channel scaling factor
    'SI_disc_patchGAN': tune.choice([True, False]),             # Use PatchGAN or not
    # Discriminator Optimizer
    'SI_disc_lr': tune.loguniform(1e-4,1e-2),
    'SI_disc_b1': tune.loguniform(0.1, 0.999),
    'SI_disc_b2': tune.loguniform(0.1, 0.999),
    'SI_disc_adv_criterion': tune.choice([nn.MSELoss(), nn.BCEWithLogitsLoss()]), # Possible options: tune.choice([nn.MSELoss(), nn.KLDivLoss(), nn.BCEWithLogitsLoss()]),
    }

config_RAY_IS = { # Dictionary for Generator: Image-->Sinogram
    # Data Loading
    'IS_normalize': False, # tune.choice([True, False]), # Normalize outputs or not
    'IS_scale': 90*90,
    # Generator Network
    'IS_gen_mult': tune.uniform(1.1, 4),
    'IS_gen_fill': tune.choice([0,1,2]),
    'IS_gen_neck': tune.choice([1,5,11]),
    'IS_gen_z_dim': tune.lograndint(64, 4000),
    'IS_layer_norm': tune.choice(['batch', 'instance','none']),
    'IS_pad_mode': tune.choice(['zeros', 'reflect']),
    'IS_dropout': tune.choice([True,False]),
    'IS_exp_kernel': tune.choice([3,4]),
    'IS_gen_final_activ': tune.choice([nn.Tanh(), nn.Sigmoid(), None]), # nn.ReLU6(), nn.Hardsigmoid(), nn.ReLU(), nn.PReLU(), None
    'IS_gen_hidden_dim': tune.lograndint(2, 30),
    # Discriminator Network
    'IS_disc_hidden_dim': tune.lograndint(10, 30),
    'IS_disc_patchGAN': tune.choice([True, False]),
    # Discriminator Optimizer
    'IS_disc_lr': tune.loguniform(1e-4,1e-2),
    'IS_disc_b1': tune.loguniform(0.1, 0.999),
    'IS_disc_b2': tune.loguniform(0.1, 0.999),
    'IS_disc_adv_criterion': tune.choice([nn.MSELoss(), nn.BCEWithLogitsLoss()]),
    }

config_RAY_SUP = { # This dictionary may be merged with either config_RAY_IS or config_RAY_SI to form a single dictionary for supervisory learning
    # NEW: New parameters added to config_RAY_SI (related to generator optimizer)
    'batch_size': tune.choice([32, 64, 128, 256, 512]), # tune.lograndint(30, 400),
    'gen_lr': tune.loguniform(1e-4,1e-2),
    'gen_b1': tune.loguniform(0.1, 0.999),
    'gen_b2': tune.loguniform(0.1, 0.999),
    'sup_criterion': tune.choice([nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.L1Loss(), nn.KLDivLoss(reduction='batchmean')]), # Not SI or IS because used for both
    # OVERWRITES: overwrites values from config_RAY_SI or config_RAY_IS. This is done so time isn't wasted looking for unused hyperparameters.
    'SI_disc_hidden_dim': 1,
    'SI_disc_patchGAN': 1,
    'SI_disc_lr': 1,
    'SI_disc_b1': 1,
    'SI_disc_b2': 1,
    'SI_disc_adv_criterion': 1,
    'IS_disc_hidden_dim': 1,
    'IS_disc_patchGAN': 1,
    'IS_disc_lr': 1,
    'IS_disc_b1': 1,
    'IS_disc_b2': 1,
    'IS_disc_adv_criterion': 1,
    }

config_RAY_GAN = { # This is MERGED with either config_RAY_IS or config_RAY_SI to form a single dictionary for a generative adversarial network.
    # NEW
    'batch_size': tune.choice([32, 64, 128, 256, 512]),  #tune.lograndint(30, 400),
    'gen_lr': tune.loguniform(1e-4,1e-2),
    'gen_b1': tune.loguniform(0.1, 0.999),
    'gen_b2': 0.999, #tune.loguniform(0.1, 0.999),
    'gen_adv_criterion': tune.choice([nn.MSELoss(), nn.BCEWithLogitsLoss()]),
    }

config_GAN_RAY_cycle = { # Mixed New/Overwrites (when combined with config_SI/config_IS) to form a single dictionary for a cycle-consistent generative adversarial network.
    # NEW
    'cycle_criterion': tune.choice([nn.MSELoss(), nn.L1Loss()]),
    'sup_criterion': tune.choice([nn.MSELoss(), nn.KLDivLoss(reduction='batchmean'), nn.L1Loss(), nn.BCEWithLogitsLoss()]),
    'lambda_adv': 1,
    'lambda_sup': 0,
    'lambda_cycle': 1,
    # OVERWRITES
    'gen_adv_criterion': nn.MSELoss(), #tune.choice([nn.MSELoss(), nn.KLDivLoss(), nn.BCEWithLogitsLoss()]),
    'IS_disc_lr': tune.loguniform(1e-4,1e-2),
    'SI_disc_lr': tune.loguniform(1e-4,1e-2),
    'batch_size': tune.choice([32, 64, 128, 256, 512]),
    'gen_lr': tune.loguniform(0.5e-4,1e-2),
    'gen_b1': tune.loguniform(0.1, 0.999),
    'gen_b2': 0.999, #tune.loguniform(0.1, 0.999),
    }

config_SUP_RAY_cycle = { # Mixed New/Overwrites (when combined with config_SI/config_IS) to form a single dictionary for a cycle-consistent, partially supervised network.
    # NEW
    'cycle_criterion': tune.choice([nn.MSELoss(), nn.L1Loss()]),
    'lambda_adv': 0,
    'lambda_sup': 1,
    'lambda_cycle':  tune.uniform(0, 10),
    # OVERWRITES
    'batch_size': tune.choice([32, 64, 128, 256, 512]),
    'gen_lr': tune.loguniform(0.5e-4,1e-2),
    'gen_b1': tune.loguniform(0.1, 0.999), # DCGan uses 0.5, https://distill.pub/2017/momentum/
    'gen_b2': tune.loguniform(0.1, 0.999),
    'sup_criterion': tune.choice([nn.MSELoss(), nn.KLDivLoss(), nn.L1Loss(), nn.BCEWithLogitsLoss()]),
    # NOT USED
    'gen_adv_criterion': nn.MSELoss(), #tune.choice([nn.MSELoss(), nn.KLDivLoss(), nn.BCEWithLogitsLoss()]),
    'IS_disc_lr': 1e-4, #tune.loguniform(1e-4,1e-2),
    'SI_disc_lr': 1e-4, #tune.loguniform(1e-4,1e-2),
    }
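
Many of the search spaces above use `tune.loguniform` and `tune.lograndint`. As a rough illustration of what log-uniform sampling means (a plain-Python sketch, not Ray Tune's implementation; the helper names `sample_loguniform` and `sample_lograndint` are hypothetical, and the upper bound of the integer sampler is assumed to be exclusive):

```python
import math
import random

def sample_loguniform(low, high, rng):
    """Draw a float whose logarithm is uniformly distributed on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

def sample_lograndint(low, high, rng):
    """Draw an integer in [low, high) by flooring a log-uniform float."""
    value = int(sample_loguniform(low, high, rng))
    return max(low, min(value, high - 1))  # clamp against floating-point round-off at the edges

rng = random.Random(0)  # fixed seed so the sketch is repeatable

# Learning rates drawn like 'gen_lr': each decade receives about the same share of samples.
lrs = [sample_loguniform(1e-4, 1e-2, rng) for _ in range(10_000)]
frac_low_decade = sum(lr < 1e-3 for lr in lrs) / len(lrs)
print(f"fraction in [1e-4, 1e-3): {frac_low_decade:.2f}")

# Channel counts drawn like 'SI_gen_z_dim'.
z_dims = [sample_lograndint(64, 4000, rng) for _ in range(1_000)]
print(min(z_dims), max(z_dims))
```

With a range spanning two decades, roughly half of the draws land in the lower decade, which is why log-uniform spaces suit parameters such as learning rates whose useful values span orders of magnitude.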

## Set Correct Config

In [None]:
## Combine Dictionaries ##
if run_mode in ('train', 'test', 'visualize'):
    if train_type=='SUP':
        if train_SI==True:
            config = config_SUP_SI
        if train_SI==False:
            config = config_SUP_IS
    if train_type=='GAN':
        if train_SI==True:
            config = config_GAN_SI
        if train_SI==False:
            config = config_GAN_IS
    if train_type=='CYCLEGAN':
        config = config_CYCLEGAN
    if train_type=='CYCLESUP':
        config = config_CYCLESUP

if run_mode=='tune':
    if train_type=='SUP':
        if train_SI==True:
            config = {**config_RAY_SI, **config_RAY_SUP}
        if train_SI==False:
            config = {**config_RAY_IS, **config_RAY_SUP}
    if train_type=='GAN':
        if train_SI==True:
            config = {**config_RAY_SI, **config_RAY_GAN}
        if train_SI==False:
            config = {**config_RAY_IS, **config_RAY_GAN}
    if train_type=='CYCLESUP':
        config = {**config_SUP_SI, **config_SUP_IS, **config_SUP_RAY_cycle}
    if train_type=='CYCLEGAN':
        config = {**config_GAN_SI, **config_GAN_IS, **config_GAN_RAY_cycle}

print(config)

{'SI_normalize': <ray.tune.search.sample.Categorical object at 0x7f18738031d0>, 'SI_scale': 8100, 'SI_gen_mult': <ray.tune.search.sample.Float object at 0x7f18737f8610>, 'SI_gen_fill': <ray.tune.search.sample.Categorical object at 0x7f18738009d0>, 'SI_gen_neck': <ray.tune.search.sample.Categorical object at 0x7f187bfa56d0>, 'SI_gen_z_dim': <ray.tune.search.sample.Integer object at 0x7f1873637bd0>, 'SI_layer_norm': <ray.tune.search.sample.Categorical object at 0x7f1a00253090>, 'SI_pad_mode': <ray.tune.search.sample.Categorical object at 0x7f1873670b50>, 'SI_dropout': <ray.tune.search.sample.Categorical object at 0x7f18736a8a10>, 'SI_exp_kernel': <ray.tune.search.sample.Categorical object at 0x7f18736a8b10>, 'SI_gen_final_activ': <ray.tune.search.sample.Categorical object at 0x7f18736a8dd0>, 'SI_gen_hidden_dim': <ray.tune.search.sample.Integer object at 0x7f187bf9da90>, 'SI_disc_hidden_dim': 1, 'SI_disc_patchGAN': 1, 'SI_disc_lr': 1, 'SI_disc_b1': 1, 'SI_disc_b2': 1, 'SI_disc_adv_criteri
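
The tune-mode merges rely on the fact that, when dictionaries are unpacked into a new literal, a key from a later dictionary replaces the value from an earlier one while keeping the key's original position. A minimal sketch with toy stand-ins (the `toy_*` names and placeholder values are hypothetical, not the notebook's real search spaces):

```python
# Toy stand-ins for config_RAY_SI and config_RAY_SUP (placeholder values only).
toy_RAY_SI = {
    "SI_gen_fill": "<search space>",
    "SI_disc_lr": "<search space>",   # discriminator setting, unused in supervised learning
}
toy_RAY_SUP = {
    "gen_lr": "<search space>",
    "SI_disc_lr": 1,                  # constant that overwrites the search space above
}

# Later dictionaries win on duplicate keys.
merged = {**toy_RAY_SI, **toy_RAY_SUP}
print(merged)

# Reversing the order would keep the search space instead of the constant.
reversed_merge = {**toy_RAY_SUP, **toy_RAY_SI}
print(reversed_merge["SI_disc_lr"])
```

This is why the discriminator entries print as the constant `1` in the merged supervisory config: the tuner has nothing to sample for them.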

# Classes

## Dataset



In [None]:
def NpArrayDataLoader(image_array, sino_array, config, image_size = 90, sino_size=90, image_channels=1, sino_channels=1, augment=False, index=0):
    '''
    Function to load an image and a sinogram. Returns 4 pytorch tensors: the original dataset sinogram and image,
    and scaled and (optionally) normalized sinograms and images.

    image_array:    image numpy array
    sino_array:     sinogram numpy array
    config:         configuration dictionary with hyperparameters
    image_size:     shape to resize image to (for output)
    image_channels: number of channels for output images
    sino_size:      shape to resize sinograms to (for output)
    sino_channels:  number of channels in output sinograms
    augment:        perform data augmentation?
    index:          index of the image/sinogram pair to grab
    '''
    ## Set Normalization Variables ##
    if (train_type=='GAN') or (train_type=='SUP'):
        if train_SI==True:
            SI_normalize=config['SI_normalize']
            SI_scale=config['SI_scale']
            IS_normalize=False     # If the Sinogram-->Image network (SI) is being trained, don't waste time normalizing sinograms
            IS_scale=1             # If the Sinogram-->Image network (SI) is being trained, don't waste time scaling sinograms
        else:
            IS_normalize=config['IS_normalize']
            IS_scale=config['IS_scale']
            SI_normalize=False
            SI_scale=1
    else: # If a cycle-consistent network, normalize & scale everything
        IS_normalize=config['IS_normalize']
        SI_normalize=config['SI_normalize']
        IS_scale=config['IS_scale']
        SI_scale=config['SI_scale']

    ## Data Augmentation Functions ##
    def RandRotate(image_multChannel, sinogram_multChannel):
        '''
        Function for randomly rotating an image and its sinogram. If the image intersects the edge of the FOV, no rotation is applied.

        image_multChannel:    image to rotate. Shape: (C, H, W)
        sinogram_multChannel: sinogram to rotate. Shape: (C, H, W)
        '''


        def IntersectCircularBorder(image):
            '''
            Function for determining whether an image intersects a circular boundary inscribed within the square FOV.
            This function is not currently used.
            '''
            y_max = image.shape[1]
            x_max = image.shape[2]

            r_max = y_max/2.0
            x_center = (x_max-1)/2.0 # the -1 comes from the fact that the coordinates of a pixel start at 0, not 1
            y_center = (y_max-1)/2.0

            margin_sum = 0
            for y in range(0, y_max):
                for x in range(0, x_max):
                    if r_max < ((x-x_center)**2 + (y-y_center)**2)**0.5 :
                        margin_sum += torch.sum(image[:,y,x]).item()

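            return_value = False if margin_sum == 0 else True # True means the image extends beyond the inscribed circle (same convention as IntersectSquareBorder)
            return return_value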

        def IntersectSquareBorder(image):
            '''
            Function for determining whether the image intersects the edge of the square FOV. If it does not, then the image
            is fully specified by the sinogram and data augmentation can be performed. If the image does
            intersect the edge of the image then some of it may be cropped outside the FOV. In this case,
            augmentation via rotation should not be performed as the rotated image may not be fully described by the sinogram.
            Looks at all channels in the image.
            '''
            max_idx = image.shape[1]-1
            margin_sum = torch.sum(image[:,0,:]).item() + torch.sum(image[:,max_idx,:]).item() \
                        +torch.sum(image[:,:,0]).item() + torch.sum(image[:,:,max_idx]).item()
            return_value = False if margin_sum == 0 else True
            return return_value

        if IntersectSquareBorder(image_multChannel) == False:
            bins = sinogram_multChannel.shape[2]
            bins_shifted = np.random.randint(0, bins)
            angle = int(bins_shifted * 180/bins)

            image_multChannel = transforms.functional.rotate(image_multChannel, angle, fill=0) # Rotate image. Fill in unspecified pixels with zeros.
            sinogram_multChannel = torch.roll(sinogram_multChannel, bins_shifted, dims=(2,)) # Cycle (or 'Roll') sinogram by that angle along dimension 2.
            sinogram_multChannel[:,:, 0:bins_shifted] = torch.flip(sinogram_multChannel[:,:,0:bins_shifted], dims=(1,)) # flip the cycled portion of the sinogram vertically

        return image_multChannel, sinogram_multChannel

    def VerticalFlip(image_multChannel, sinogram_multChannel):
        image_multChannel = torch.flip(image_multChannel,dims=(1,)) # Flip image vertically
        sinogram_multChannel = torch.flip(sinogram_multChannel,dims=(1,2)) # Flip sinogram horizontally and vertically
        return image_multChannel, sinogram_multChannel

    def HorizontalFlip(image_multChannel, sinogram_multChannel):
        image_multChannel = torch.flip(image_multChannel, dims=(2,)) # Flip image horizontally
        sinogram_multChannel = torch.flip(sinogram_multChannel, dims=(2,)) # Flip sinogram horizontally
        return image_multChannel, sinogram_multChannel

    ## Select Data, Convert to Pytorch Tensors ##
    image_multChannel = torch.from_numpy(image_array[index,:]) # image_multChannel.shape = (C, X, Y)
    sinogram_multChannel = torch.from_numpy(sino_array[index,:]) # sinogram_multChannel.shape = (C, X, Y)

    ## Run Data Augmentation on Selected Data. ##
    if augment==True:
        image_multChannel, sinogram_multChannel = RandRotate(image_multChannel, sinogram_multChannel)           # Rotates image by a random angle, unless the image touches the edge of the FOV
        if np.random.choice([True, False]): # Half of the time, flips the image vertically
            image_multChannel, sinogram_multChannel = VerticalFlip(image_multChannel, sinogram_multChannel)
        if np.random.choice([True, False]): # Half of the time, flips the image horizontally
            image_multChannel, sinogram_multChannel = HorizontalFlip(image_multChannel, sinogram_multChannel)

    ## Create A Set of Resized Outputs ##
    sinogram_multChannel_resize = transforms.Resize(size = (sino_size, sino_size), antialias=True)(sinogram_multChannel)
    image_multChannel_resize    = transforms.Resize(size = (image_size, image_size), antialias=True)(image_multChannel)

    ## (Optional) Normalize Resized Outputs Along Channel Dimension ##
    if SI_normalize:
        a = torch.reshape(image_multChannel_resize, (image_channels,-1))
        a = nn.functional.normalize(a, p=1, dim = 1)
        image_multChannel_resize = torch.reshape(a, (image_channels, image_size, image_size))
    if IS_normalize:
        a = torch.reshape(sinogram_multChannel_resize, (sino_channels,-1))                     # Flattens each sinogram. Each channel is normalized.
        a = nn.functional.normalize(a, p=1, dim = 1)                      # Normalizes along dimension 1 (values for each of the 3 channels)
        sinogram_multChannel_resize = torch.reshape(a, (sino_channels, sino_size, sino_size))  # Reshapes sinograms back into squares.

    ## Adjust Output Channels of Resized Outputs ##
    if image_channels==1:
        image_out = image_multChannel_resize                 # For image_channels = 1, the image is just left alone
    else:
        image_out = image_multChannel_resize.repeat(image_channels,1,1)   # This could be altered to account for RGB images, etc.

    if sino_channels==1:
        sino_out = sinogram_multChannel_resize[0:1,:]        # Selects 1st sinogram channel only. Using 0:1 preserves the channels dimension.
    else:
        sino_out = sinogram_multChannel_resize               # Keeps full sinogram with all channels

    # Returns both original and altered sinograms and images, assigned to CPU or GPU
    return sinogram_multChannel.to(device), IS_scale*sino_out.to(device), image_multChannel.to(device), SI_scale*image_out.to(device)

class NpArrayDataSet(Dataset):
    '''
    Class for loading data from .npy files, given file path strings and a set of optional transformations.
    In the dataset used in our first two conference papers, the data repeat every 17500 steps but with different augmentations.
    For the dataset with FORE rebinning, the dataset contains no augmented examples; all augmentation is performed on the fly.
    '''
    def __init__(self, image_path, sino_path, config, image_size = 90, sino_size=90, image_channels=1, sino_channels=1,
                 augment=False, offset=0, num_examples=-1, sample_division=1):
        '''
        image_path:         path to images in data set
        sino_path:          path to sinograms in data set
        config:             configuration dictionary with hyperparameters
        image_size:         shape to resize image to (for output)
        image_channels:     number of channels in images
        sino_size:          shape to resize sinograms to (for output)
        sino_channels:      number of channels in sinograms (for photopeak sinograms, this is 1)
        augment:            Set True to perform on-the-fly augmentation of data set. Set False to not perform augmentation.
        offset:             To begin dataset at beginning of the datafile, set offset=0. To begin on the second image, offset = 1, etc.
        num_examples:       Max number of examples to load into dataset. Set to -1 to load the maximum number from the numpy array.
        sample_division:    set to 1 to use every example, 2 to use every other example, etc. (Ex: if sample_division=2, the dataset will be half the size.)
        '''

        ## Load Data to Arrays ##
        image_array = np.load(image_path, mmap_mode='r')       # We use memmaps to significantly speed up the loading.
        sino_array = np.load(sino_path, mmap_mode='r')

        ## Set Instance Variables ##
        if num_examples==-1:
            self.image_array = image_array[offset:,:]
            self.sino_array = sino_array[offset:,:]
        else:
            self.image_array = image_array[offset : offset + num_examples, :]
            self.sino_array = sino_array[offset : offset + num_examples, :]

        self.config = config
        self.image_size = image_size
        self.sino_size = sino_size
        self.image_channels = image_channels
        self.sino_channels = sino_channels
        self.augment = augment
        self.sample_division = sample_division

    def __len__(self):
        length = len(self.image_array) // self.sample_division
        return length

    def __getitem__(self, idx):

        idx = idx*self.sample_division

        sino_ground, sino_ground_scaled, image_ground, image_ground_scaled = NpArrayDataLoader(self.image_array, self.sino_array, self.config, self.image_size,
                                                                                self.sino_size, self.image_channels, self.sino_channels,
                                                                                augment=self.augment, index=idx)

        return sino_ground, sino_ground_scaled, image_ground, image_ground_scaled
        # Returns both original, as well as altered, sinograms and images
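
The sinogram half of `RandRotate` can be illustrated without PyTorch. The sketch below assumes, as the angle calculation in `RandRotate` suggests, that the last axis of the sinogram indexes projection angle over 180 degrees and the axis before it indexes radial position. The helper `roll_and_flip` is a hypothetical, list-based stand-in for the `torch.roll` plus `torch.flip` pair, operating on a single channel:

```python
def roll_and_flip(sino, shift):
    """
    sino:  list of rows (radial positions), each a list over angular bins
    shift: number of angular bins to cycle by (0 <= shift <= number of bins)
    """
    n_rows = len(sino)
    n_angles = len(sino[0])

    # Cyclic shift along the angle axis: new[r][a] = old[r][(a - shift) % n_angles]
    rolled = [[row[(a - shift) % n_angles] for a in range(n_angles)] for row in sino]

    # The columns that wrapped around (the first `shift` columns) are flipped
    # along the radial axis.
    out = [row[:] for row in rolled]
    for a in range(shift):
        for r in range(n_rows):
            out[r][a] = rolled[n_rows - 1 - r][a]
    return out

# Tiny 3 (radial) x 4 (angular) example sinogram.
sino = [[1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12]]

for row in roll_and_flip(sino, 1):
    print(row)
```

Shifting by the full number of angular bins leaves every column in place but flips all of them radially, which corresponds to viewing the same object after a 180 degree rotation.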

## Generators

In [None]:
######################################
##### Block Generating Functions #####
######################################

def contract_block(in_channels, out_channels, kernel_size, stride, padding=0, padding_mode='reflect', fill=0, norm='batch', drop=False):
    '''
    Function to construct a single "contracting block." Each contracting block consists of one 2D convolutional layer, which decreases
    the size (height and width) of the data. There are then up to three 2D convolution layers which do not change the height or width
    (e.g. "constant size layers").

    in_channels:    number of channels at the input of contracting block
    out_channels:   number of channels at the output of contracting block
    kernel_size:    size of the kernel for the 1st 2D convolutional layer in the contracting block
    stride:         stride of the convolution for the 1st 2D convolutional layer in the contracting block
    padding:        amount of padding for the 1st 2D convolutional layer in the contracting block
    padding_mode:   padding mode (options: "zeros", "reflect")
    fill:           number of "constant size" 2D convolutional layers
    norm:           type of layer normalization ("batch", "instance", or "none")
    drop:           include dropout layers in the contracting block? (True or False)
    '''

    if norm=='batch':    norm = nn.BatchNorm2d(out_channels)
    if norm=='instance': norm = nn.InstanceNorm2d(out_channels)
    if norm=='none':     norm = nn.Sequential()
    dropout = nn.Dropout() if drop==True else nn.Sequential()

    # Note: for the contracting block, normalization & dropout follow convolutional layers. For expanding blocks, the order is reversed.
    block1 =  nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, padding_mode=padding_mode), norm, dropout, nn.ReLU())
    if fill==0:
        block2 = nn.Sequential() # If fill=0, there are no "constant size" convolutional layers, and so block2 is empty.
    elif fill==1:
        block2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode), norm, dropout, nn.ReLU())
    elif fill==2:
        block2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode), norm, dropout, nn.ReLU(),
                                nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode), norm, dropout, nn.ReLU())
    elif fill==3:
        block2 = nn.Sequential(nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode), norm, dropout, nn.ReLU(),
                                nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode), norm, dropout, nn.ReLU(),
                                nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode), norm, dropout, nn.ReLU())
    return nn.Sequential(block1, block2)

def expand_block(in_channels, out_channels, kernel_size=3, stride=2, padding=0, output_padding=0, padding_mode='zeros', fill=0, norm='batch', drop=False, final_layer=False):
    '''
    Function to construct a single "expanding block." Each expanding block consists of one 2D transposed convolution layer which increases
    the size of the incoming data (height and width). There are then up to three 2D convolution layers which do not change the height or
    width (e.g. "constant size layers").

    in_channels:    number of channels at the input of the expanding block
    out_channels:   number of channels at the output of the expanding block
    kernel_size:    size of the kernel for the 1st 2D transposed convolutional layer in the expanding block
    stride:         stride of the convolution for the 1st 2D transposed convolutional layer in the expanding block
    padding:        amount of padding for the 1st 2D transposed convolutional layer in the expanding block
    output_padding: additional size added to one side of the output of the 1st 2D transposed convolutional layer
    padding_mode:   padding mode (ex: "zeros", "reflect")
    fill:           number of "constant size" 2D convolutional layers
    norm:           type of layer normalization ("batch", "instance", or "none")
    drop:           include dropout in the expanding block? (True or False)
    final_layer:    Is this the final block of the network? (True or False) If True, the trailing normalization, dropout, and activation are left off.
    '''

    if norm=='batch':       norm = nn.BatchNorm2d(out_channels)
    if norm=='instance':    norm = nn.InstanceNorm2d(out_channels)
    if norm=='none':        norm = nn.Sequential()
    dropout = nn.Dropout() if drop==True else nn.Sequential()

    # Note: for the expanding block, normalization & dropout precede convolutional layers in blocks 2-3. For contracting blocks, the order is reversed.
    block1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding, padding_mode=padding_mode)
    if fill==0:
        block2 = nn.Sequential()
    elif fill==1:
        block2 = nn.Sequential(norm, dropout, nn.ReLU(), nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode))
    elif fill==2:
        block2 = nn.Sequential(norm, dropout, nn.ReLU(), nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode),
                                norm, dropout, nn.ReLU(), nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode))
    elif fill==3:
        block2 = nn.Sequential(norm, dropout, nn.ReLU(), nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode),
                                norm, dropout, nn.ReLU(), nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode),
                                norm, dropout, nn.ReLU(), nn.Conv2d(out_channels, out_channels, 3, 1, 1, padding_mode=padding_mode))

    if final_layer==False: # If not the final layer, I add normalization, dropout and activation.
        block3 = nn.Sequential(norm, dropout, nn.ReLU())
    else:                  # Otherwise, I leave off the normalization, dropout, and activation. This allows me to do it explicitly
                           # at the end of the network using tuned parameters.
        block3 = nn.Sequential()
    return nn.Sequential(block1, block2, block3)

###########################
##### Generator Class #####
###########################

class Generator(nn.Module):
    def __init__(self, config, gen_SI=True, input_size=90, input_channels=3, output_channels=3):
        '''
        A class to generate a 90x90-->90x90 or 180x180-->90x90 encoder-decoder network. The role of each item in the "config" dictionary is commented below. In addition, the class constructor takes the following as inputs:

        gen_SI:             Equals True if the generator generates images from sinograms. Equals false if the generator generates sinograms from images.
                            In a cycle-consistent network, this class generates two networks from the same config dictionary. Hence, the need
                            for this parameter.
        input_size:         size of the input (90 or 180).
        input_channels:     number of generator input channels
        output_channels:    number of generator output channels
        '''

        super(Generator, self).__init__()

        ## Set Instance Variables ##
        self.output_channels = output_channels

        ## If gen_SI == True, we use the "SI.." keys from the config dictionary to construct the generator network. ##
        if gen_SI:
            # The following instance variables are defined since these will be used in the forward() method below. #
            self.final_activation = config['SI_gen_final_activ']    # {nn.Tanh(), nn.Sigmoid(), None}
                                                                    # Type of activation function employed at the very end of network
            self.normalize=config['SI_normalize']                   # {True, False} : Normalization
            self.scale=config['SI_scale']                           # Scale factor by which the output is multiplied,
                                                                    #    if the output is first normalized

            ## The following variables are used in the network constructor, and not the forward() method, so there is no need for instance variables.

            neck=config['SI_gen_neck'] #            {1,5,11} :          Width of narrowest part (neck) of the network. The smaller the number, the narrower the neck.
            exp_kernel=config['SI_exp_kernel'] #    {3,4} :             Square kernel width (or height) for the expanding part of the network.
            z_dim=config['SI_gen_z_dim'] #          (Any positive integer) : Number of channels in the network neck, if neck=1. If neck=5 or 11, this parameter isn't used.
            hidden_dim=config['SI_gen_hidden_dim']# (Any real number) : scales all channels in network by the same linear factor. Larger hidden_dim -->more complex network
            fill=config['SI_gen_fill'] #            {0,1,2,3} :         Number of "constant size" 2D convolutions in each block
            mult=config['SI_gen_mult'] #            (Any real number) : Multiplicative factor by which network channels increase as the layers decrease in height & width
            norm=config['SI_layer_norm'] #          {'instance', 'batch', 'none'} : Type of layer normalization
            pad=config['SI_pad_mode'] #             {'zeros', 'reflect'} :          Type of padding in each layer/block
            drop=config['SI_dropout'] #             {True, False} :                 Whether dropout is used in the network

        #If gen_SI == False, we use the "IS.." keys from the config dictionary to construct the generator network. ##
        else:
            self.final_activation = config['IS_gen_final_activ']
            self.normalize=config['IS_normalize']
            self.scale=config['IS_scale']

            neck=config['IS_gen_neck']
            exp_kernel=config['IS_exp_kernel']
            z_dim=config['IS_gen_z_dim']
            hidden_dim=config['IS_gen_hidden_dim']
            fill=config['IS_gen_fill']
            mult=config['IS_gen_mult']
            norm=config['IS_layer_norm']
            pad=config['IS_pad_mode']
            drop=config['IS_dropout']

        ## Abbreviations used for Block Definitions -- used to make code less awkward ##
        in_chan = input_channels
        out_chan = output_channels

        dim_0 = int(hidden_dim*mult**0) # Number of output channels of 1st block/input channels of 2nd block
        dim_1 = int(hidden_dim*mult**1) # Number of output channels of 2nd block/input channels of 3rd block
        dim_2 = int(hidden_dim*mult**2) # Follows pattern above
        dim_3 = int(hidden_dim*mult**3)
        dim_4 = int(hidden_dim*mult**4)
        dim_5 = int(hidden_dim*mult**5)

        ### Block Definitions ###

        ## Build the Contracting Path ##
        # The formula for the output size of a 2D convolution (nn.Conv2d) in PyTorch is as follows (the result is rounded down):
        # Hf = floor([Hi+2*padding-dilation*(kernel-1)-1]/stride + 1) = floor([Hi+2*padding-kernel]/stride + 1) (for dilation=1)

        if input_size==180:
            self.contract = nn.Sequential(
                # nn.Conv2d: Hf = floor([Hi+2*padding-dilation*(kernel-1)-1]/stride + 1) = floor([Hi+2*padding-kernel]/stride + 1) (for dilation=1)
                # Input shape: (in_chan, 180, 180)
                contract_block(in_chan, dim_0, 3, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm, drop=drop), # H = [180+2-3]/2 + 1 = 90.5 : a 180x180 input gives a 90x90 output
                contract_block(dim_0,   dim_1, 3, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm, drop=drop), # H = [90+2-3]/2 + 1 = 45.5  : a 90x90 input gives a 45x45 output
                contract_block(dim_1,   dim_2, 3, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm, drop=drop), # H = [45+2-3]/2 + 1 = 23    : a 45x45 input gives a 23x23 output
                contract_block(dim_2,   dim_2, 4, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm, drop=drop), # H = [23+2-4]/2 + 1 = 11.5  : a 23x23 input gives a 11x11 output
            )
        elif input_size==90:
            self.contract = nn.Sequential(
                contract_block(in_chan, dim_0, 3, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm, drop=drop), # H = [90+2-3]/2 + 1 = 45.5  : a 90x90 input gives a 45x45 output
                contract_block(dim_0,   dim_1, 3, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm, drop=drop), # H = [45+2-3]/2 + 1 = 23    : a 45x45 input gives a 23x23 output
                contract_block(dim_1,   dim_2, 4, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm, drop=drop), # H = [23+2-4]/2 + 1 = 11.5  : a 23x23 input gives a 11x11 output
            )

        ## Build the Neck. There are 3 options ##
        # neck=1 gives the narrowest (1x1) neck #
        if neck==1:
            self.neck = nn.Sequential(
                contract_block(dim_2, dim_3, 4, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm, drop=drop), # H = [11+2-4]/2 + 1 = 5.5
                contract_block(dim_3, dim_4, 3, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm           ), # H = [5+2*1-3]/2 + 1 = 3
                contract_block(dim_4, z_dim, 3, stride=1, padding=0,                   fill=0,    norm='batch'        ), # H = 1   ||norm is set to 'batch' because 'instance' won't work on 1x1 layer
                expand_block(  z_dim, dim_4, 3, stride=2, padding=0,                   fill=fill, norm=norm           ), # H = [1-1]*2+3 = 3
                expand_block(  dim_4, dim_3, 4, stride=2, padding=2, output_padding=1, fill=fill, norm=norm           ), # H = [3-1]*2+4-2*2+1 = 5
            )

        # neck=5 gives the middle width (5x5) neck #
        if neck==5:
            self.neck = nn.Sequential(
                contract_block(dim_2, dim_3, 4, stride=2, padding=1, padding_mode=pad, fill=fill, norm=norm, drop=drop), # H = [11+2-4]/2 + 1 = 5.5
                contract_block(dim_3, dim_3, 5, stride=1, padding=2, padding_mode=pad,            norm=norm           ), # H = [5+2*2-5]/1 + 1 = 5 (Constant Block)
                contract_block(dim_3, dim_3, 5, stride=1, padding=2, padding_mode=pad,            norm=norm           ), # H = [5+2*2-5]/1 + 1 = 5 (Constant Block)
                contract_block(dim_3, dim_3, 5, stride=1, padding=2, padding_mode=pad,            norm=norm           ), # H = [5+2*2-5]/1 + 1 = 5 (Constant Block)
                #contract_block(dim_3, dim_3, kernel_size=5, stride=1, padding=2, padding_mode=pad, norm=norm), # H = [5+2*2-5]/1 + 1 = 5 (Constant Block) # Add this next tuning!
            )

        # neck=11 gives the thickest (11x11) neck #
        if neck==11:
            self.neck = nn.Sequential(
                contract_block(dim_2, dim_2, kernel_size=5, stride=1, padding=2, padding_mode=pad, norm=norm), # H = [11+2*2-5]/1 + 1 = 11 (Constant Block)
                contract_block(dim_2, dim_2, kernel_size=5, stride=1, padding=2, padding_mode=pad, norm=norm), # H = [11+2*2-5]/1 + 1 = 11 (Constant Block)
                contract_block(dim_2, dim_2, kernel_size=5, stride=1, padding=2, padding_mode=pad, norm=norm), # H = [11+2*2-5]/1 + 1 = 11 (Constant Block)
                contract_block(dim_2, dim_2, kernel_size=5, stride=1, padding=2, padding_mode=pad, norm=norm), # H = [11+2*2-5]/1 + 1 = 11 (Constant Block)
                contract_block(dim_2, dim_2, kernel_size=5, stride=1, padding=2, padding_mode=pad, norm=norm), # H = [11+2*2-5]/1 + 1 = 11 (Constant Block)
            )

        ## Build the Expanding Blocks ##
        # The formula for the output size of a transposed convolution (nn.ConvTranspose2d) in PyTorch is as follows:
        # Hf = (Hi-1)*stride -2*padding +dilation*(kernel-1) +output_padding+1
        #    = (Hi-1)*stride +kernel -2*padding +output_padding (for dilation=1)

        # For neck=1 or 5, the output from previous layers is 5x5. Therefore, these can use the same expanding blocks #
        if (neck==1 or neck==5):
            if exp_kernel==3:
            # Expanding block for neck=1 or 5, expanding kernel size = 3
                self.expand = nn.Sequential(
                    expand_block(dim_3, dim_2,                      kernel_size=3, stride=2, padding=0, output_padding=0, fill=fill, norm=norm), # H = (5-1)*2  +3         = 11
                    expand_block(dim_2, dim_1,                      kernel_size=3, stride=2, padding=1, output_padding=1, fill=fill, norm=norm), # H = (11-1)*2 +3 -2*1 +1 = 22
                    expand_block(dim_1, dim_0,                      kernel_size=3, stride=2, padding=0, output_padding=0, fill=fill, norm=norm), # H = (22-1)*2 +3         = 45
                    expand_block(dim_0, out_chan, final_layer=True, kernel_size=3, stride=2, padding=1, output_padding=1, fill=fill, norm=norm), # H = (45-1)*2 +3 -2*1 +1 = 90
                )

            elif exp_kernel==4:
            # Expanding block for neck=1 or 5, expanding kernel size = 4
                self.expand = nn.Sequential(
                    expand_block(dim_3, dim_2,                      kernel_size=4, stride=2, padding=1, output_padding=1, fill=fill, norm=norm),  # H = (5-1)*2  +4 -2*1 +1 = 11
                    expand_block(dim_2, dim_1,                      kernel_size=4, stride=2, padding=1, output_padding=0, fill=fill, norm=norm),  # H = (11-1)*2 +4 -2*1    = 22
                    expand_block(dim_1, dim_0,                      kernel_size=4, stride=2, padding=1, output_padding=1, fill=fill, norm=norm),  # H = (22-1)*2 +4 -2*1 +1 = 45
                    expand_block(dim_0, out_chan, final_layer=True, kernel_size=4, stride=2, padding=1, output_padding=0, fill=fill, norm=norm),  # H = (45-1)*2 +4 -2*1    = 90
                )

        # For neck=11, the output is 11x11. This neck requires its own expanding blocks #
        if neck==11:
            if exp_kernel==3:
            # Expanding block for neck=11, expanding kernel size = 3
                self.expand = nn.Sequential(
                    expand_block(dim_2, dim_1,                      kernel_size=3, stride=2, padding=1, output_padding=1, fill=fill, norm=norm),  # H = (11-1)*2 +3 -2*1 +1 = 22
                    expand_block(dim_1, dim_0,                      kernel_size=3, stride=2, padding=0, output_padding=0, fill=fill, norm=norm),  # H = (22-1)*2 +3         = 45
                    expand_block(dim_0, out_chan, final_layer=True, kernel_size=3, stride=2, padding=1, output_padding=1, fill=fill, norm=norm),  # H = (45-1)*2 +3 -2*1 +1 = 90
                )

            if exp_kernel==4:
            # Expanding block for neck=11, expanding kernel size = 4
                self.expand = nn.Sequential(
                    expand_block(dim_2, dim_1,                      kernel_size=4, stride=2, padding=1, output_padding=0, fill=fill, norm=norm),  # H = (11-1)*2 +4 -2*1    = 22
                    expand_block(dim_1, dim_0,                      kernel_size=4, stride=2, padding=1, output_padding=1, fill=fill, norm=norm),  # H = (22-1)*2 +4 -2*1 +1 = 45
                    expand_block(dim_0, out_chan, final_layer=True, kernel_size=4, stride=2, padding=1, output_padding=0, fill=fill, norm=norm),  # H = (45-1)*2 +4 -2*1    = 90
                )

    def forward(self, input):
        # This method gets run when the network is called to produce an output from an input #

        batch_size = len(input)  # Get batch size

        a = self.contract(input) # Run input through contracting blocks
        a = self.neck(a)         # Run output from contracting blocks through the neck
        a = self.expand(a)       # Run output from the neck through the expanding blocks

        if self.final_activation:   # Optional final activations
            a = self.final_activation(a)
        if self.normalize:          # Optionally normalize
            a = torch.reshape(a,(batch_size, self.output_channels, 90**2)) # Flattens each image
            a = nn.functional.normalize(a, p=1, dim = 2)
            a = torch.reshape(a,(batch_size, self.output_channels , 90, 90)) # Reshapes images back into square matrices
            a = self.scale*a        # If normalizing, multiply the outputs by a scale factor

        return a                    # Return the output
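
The `H = ...` comments in the generator can be checked mechanically. The sketch below implements PyTorch's documented output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` and traces a 90x90 input through the contracting path, the first neck block, and the `exp_kernel=3` expanding path. It assumes, as the comments above suggest, that `contract_block` wraps `nn.Conv2d` and `expand_block` wraps `nn.ConvTranspose2d`. The helper names `conv_out` and `conv_transpose_out` are hypothetical and do not exist elsewhere in this notebook.

```python
def conv_out(h, kernel, stride=1, padding=0, dilation=1):
    """Output size of nn.Conv2d along one spatial dimension (division rounds down)."""
    return (h + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out(h, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output size of nn.ConvTranspose2d along one spatial dimension."""
    return (h - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# Contracting path for input_size == 90: (kernel, stride, padding) for each block
h = 90
contract_sizes = []
for kernel, stride, padding in [(3, 2, 1), (3, 2, 1), (4, 2, 1)]:
    h = conv_out(h, kernel, stride, padding)
    contract_sizes.append(h)

# First neck block, shared by neck=1 and neck=5 (kernel 4, stride 2, padding 1)
neck_in = conv_out(h, kernel=4, stride=2, padding=1)

# Expanding path for exp_kernel == 3: (kernel, stride, padding, output_padding)
e = neck_in
expand_sizes = []
for kernel, stride, padding, out_pad in [(3, 2, 0, 0), (3, 2, 1, 1), (3, 2, 0, 0), (3, 2, 1, 1)]:
    e = conv_transpose_out(e, kernel, stride, padding, out_pad)
    expand_sizes.append(e)

print(contract_sizes, neck_in, expand_sizes)
```

Running a trace like this before editing a kernel size, padding, or `output_padding` value is a cheap way to confirm that the expanding path still ends at 90x90.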

## Discriminators

In [None]:
#################################
#### SINOGRAMS DISCRIMINATOR ####
#################################

class Disc_S_90(nn.Module):
    '''
    Through experimentation it has been found that sinogram discriminators work best with a fat network neck.
    This class takes as input a 90x90 sinogram.
    '''
    def __init__(self, config, disc_I=True, input_channels=3):
        super(Disc_S_90, self).__init__()

        hidden_dim=config['IS_disc_hidden_dim']
        patchGAN=config['IS_disc_patchGAN']

        ## Sequence 1 ##
        self.seq1 = nn.Sequential(
            # Sinogram Shape: (in_channels,90,90)
            # nn.Conv2d: Hf = [Hi+2*padding-dilation*(kernel-1)-1]/stride + 1
            #               = [Hi+2*padding-kernel]/stride + 1 (for dilation=1)

            # Feature Map Block
            nn.Conv2d(in_channels=sino_channels, out_channels=hidden_dim, kernel_size=7, padding=3, padding_mode='reflect'),

            # Contracting Block without normalization:
            # H1 = (90-4)/2+1 = 44
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim*2, kernel_size=4, stride=2, padding=0, padding_mode='reflect'),
                nn.LeakyReLU(negative_slope=0.2),

            # Contracting Blocks:
            # H1 = (44-4)/2+1 = 21
            nn.Conv2d(in_channels=hidden_dim*2, out_channels=hidden_dim*3, kernel_size=4, stride=2, padding=0, padding_mode='reflect'),
                nn.InstanceNorm2d(hidden_dim*3), nn.LeakyReLU(negative_slope=0.2),
            # H1 = (21-4)/2+1 = 9.5 = 9
            nn.Conv2d(in_channels=hidden_dim*3, out_channels=hidden_dim*4, kernel_size=4, stride=2, padding=0, padding_mode='reflect'),
                nn.InstanceNorm2d(hidden_dim*4), nn.LeakyReLU(negative_slope=0.2),
            # H1 = (9-4)/2+1 = 3.5 = 3
            nn.Conv2d(in_channels=hidden_dim*4, out_channels=hidden_dim*5, kernel_size=4, stride=2, padding=0, padding_mode='reflect'),
                nn.InstanceNorm2d(hidden_dim*5), nn.LeakyReLU(negative_slope=0.2),
        )

        ## PatchGAN ##
        if patchGAN==True:
            # H = [3+2*1-3]/1+1 = 3 (3x3x3 matrix)
            self.seq2 = nn.Sequential(
                nn.Conv2d(hidden_dim*5, hidden_dim*5, kernel_size=3, padding=1, padding_mode='reflect'),
                    nn.BatchNorm2d(hidden_dim*5), nn.LeakyReLU(negative_slope=0.2),
            )
        else:
            self.seq2 = nn.Sequential(
                # H0 = (3-3)/1+1 = 1 (1x1x3 matrix)
                nn.Conv2d(in_channels=hidden_dim*5, out_channels=hidden_dim*5, kernel_size=3),
                    nn.BatchNorm2d(hidden_dim*5), nn.LeakyReLU(negative_slope=0.2),
            )
        ## 1x1 Convolution ##
        self.seq3 = nn.Conv2d(hidden_dim * 5, sino_channels, kernel_size=1)

    def forward(self, image):
        a = self.seq1(image)
        b = self.seq2(a) # a tensor
        c = self.seq3(b)
        #return disc_pred.view(len(disc_pred), -1) # returns a flattened tensor
        return c.squeeze()

##############################
#### IMAGES DISCRIMINATOR ####
##############################

class Disc_I_90(nn.Module):
    def __init__(self, config, disc_I=True, input_channels=3):
        super(Disc_I_90, self).__init__()

        hidden_dim=config['SI_disc_hidden_dim']
        patchGAN=config['SI_disc_patchGAN']

        ## Sequence 1 ##
        self.seq1 = nn.Sequential(
            # Image Shape: (1,90,90)
            # nn.Conv2d: Hf = [Hi+2*padding-dilation*(kernel-1)-1]/stride + 1
            #               = [Hi+2*padding-kernel]/stride + 1 (for dilation=1)

            # H = [90-4]/2+1 = 44
            nn.Conv2d(in_channels=input_channels, out_channels=hidden_dim, kernel_size=4, stride=2),
                nn.BatchNorm2d(hidden_dim), nn.LeakyReLU(negative_slope=0.2),
            # H = [44-4]/2+1 = 21
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim*2, kernel_size=4, stride=2),
                nn.BatchNorm2d(hidden_dim*2), nn.LeakyReLU(negative_slope=0.2),
            # H = [21-4]/2+1 = 9.5 = 9
            nn.Conv2d(in_channels=hidden_dim*2, out_channels=hidden_dim*4, kernel_size=4, stride=2),
                nn.BatchNorm2d(hidden_dim*4), nn.LeakyReLU(negative_slope=0.2),
            # H = [9+2-4]/2+1 = 4.5 = 4
            nn.Conv2d(in_channels=hidden_dim*4, out_channels=hidden_dim*4, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(hidden_dim*4), nn.LeakyReLU(negative_slope=0.2),
        )

        ## Sequence 2 ##
        if patchGAN==True:
            # H = [4+2-3]/1+1 = 4
            self.seq2=nn.Conv2d(hidden_dim*4, 1, kernel_size=3, padding=1, padding_mode='reflect')
        else:
            # H = [4-4]/2+1 = 1
            self.seq2=nn.Conv2d(hidden_dim*4, 1, kernel_size=4, stride=2)

    def forward(self, image):

        a = self.seq1(image)
        disc_pred = self.seq2(a) # a tensor
        #return disc_pred.view(len(disc_pred), -1) # returns a flattened tensor
        return disc_pred.squeeze()

# Functions

## Cropping & Weight Initialization

In [None]:
def crop_single_image_by_size(image, crop_size=-1):
    '''
    Function to crop a single image to a square shape, with even margins around the edges.

    image:       Input image tensor of shape [height, width]
    crop_size:   Edge size of (square) image to keep. The edges are discarded.
    '''
    x_size = image.shape[1]

    margin_low = int((x_size-crop_size)/2.0)  # (90-71)/2 = 19/2 = 9.5 -->9
    margin_high = x_size-crop_size-margin_low # 90-71-9 = 10

    pix_min = 0 + margin_low
    pix_max = x_size - margin_high

    image = image[pix_min : pix_max , pix_min : pix_max]

    return image

def crop_single_image_by_factor(image, crop_factor=1):
    '''
    Function to crop a single image by a factor, with even margins around the edges.

    image:       Input image tensor of shape [height, width]
    crop_factor: Fraction of image to keep. The image is trimmed so the edges are discarded.
    '''
    x_size = image.shape[1]
    y_size = image.shape[0]

    x_margin = int(x_size*(1-crop_factor)/2)
    y_margin = int(y_size*(1-crop_factor)/2)

    x_min = 0 + x_margin
    x_max = x_size - x_margin
    y_min = 0 + y_margin
    y_max = y_size - y_margin

    return image[y_min:y_max , x_min:x_max] # Use the 'image' argument ('image_tensor' is not defined in this function)


def crop_image_tensor_with_corner(batch, crop_size, corner=(0,0)):
    '''
    Function which returns a smaller, cropped version of a tensor (multiple images)

    batch:       a batch of images with dimensions: (num_images, channel, y_dimension, x_dimension)
    corner:      upper-left corner of window
    crop_size:   size of cropping window (int)
    '''

    y_min = corner[0]
    y_max = corner[0]+crop_size
    x_min = corner[1]
    x_max = corner[1]+crop_size

    return batch[:, :, y_min:y_max , x_min:x_max ]


def crop_image_tensor_by_factor(image_tensor, crop_factor=1):
    '''
    Function to crop an image tensor, with even margins around the edges.

    image_tensor:   Input image tensor of shape [image number, channel, height, width]
    crop_factor:    Fraction of image to keep. The images are trimmed so the edges are discarded.
    '''
    x_size = image_tensor.shape[3]
    y_size = image_tensor.shape[2]

    x_margin = int(x_size*(1-crop_factor)/2)
    y_margin = int(y_size*(1-crop_factor)/2)

    x_min = 0 + x_margin
    x_max = x_size - x_margin
    y_min = 0 + y_margin
    y_max = y_size - y_margin

    return image_tensor[:,:, y_min:y_max , x_min:x_max ]

def weights_init(m): # 'm' represents layers in the generator or discriminator.

    #Function for initializing network weights to normal distribution, with mean 0 and s.d. 0.02
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
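
The crop-by-factor functions above all use the same margin arithmetic: the margin is `int(size*(1-crop_factor)/2)`, truncated toward zero, and it is removed from both sides. The sketch below isolates that arithmetic on plain Python lists so the rounding behaviour is easy to inspect. The names `center_crop_bounds` and `center_crop` are hypothetical helpers introduced only for illustration.

```python
def center_crop_bounds(size, crop_factor):
    """Return (low, high) slice bounds that keep the central part of an axis of length `size`."""
    margin = int(size * (1 - crop_factor) / 2)  # Truncates toward zero, as in the functions above
    return margin, size - margin


def center_crop(image, crop_factor):
    """Crop a 2D list-of-lists image symmetrically about its centre."""
    y_min, y_max = center_crop_bounds(len(image), crop_factor)
    x_min, x_max = center_crop_bounds(len(image[0]), crop_factor)
    return [row[x_min:x_max] for row in image[y_min:y_max]]


# A 90-pixel axis cropped by sqrt(2)/2 keeps pixels 13..76, i.e. 64 pixels
bounds_90 = center_crop_bounds(90, 2**0.5 / 2)

# Because the margin is truncated, a 10x10 image cropped by 0.5 keeps 6x6 pixels, not 5x5
image = [[10 * y + x for x in range(10)] for y in range(10)]
cropped = center_crop(image, crop_factor=0.5)

print(bounds_90, len(cropped), len(cropped[0]))
```

The truncation means the kept size can be slightly larger than `size*crop_factor`, which matters when a downstream step expects an exact shape.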

## Reconstructions & Projection

In [None]:
def iradon_MLEM(sino_ground, azi_angles=None, max_iter=15, circle=True, crop_factor=2**0.5/2):
    '''
    Function to reconstruct a single PET image from a single sinogram using ML-EM.

    sino_ground:    sinogram (photopeak). This is a numpy array with minimum values of 0. Shape: (H,W)
    azi_angles:     list of azimuthal angles for sinogram. If set to None, angles are assumed to span [0,180)
    max_iter:       Maximum number of iterations for ML-EM algorithm.
    circle:         circle=True: The projection data spans the width (or height) of the activity distribution, and the reconstructed image is circular.
                    circle=False: The projection data (sinograms) spans the corner-to-corner line of the activity distribution, and the reconstructed image is square.
    crop_factor:    Fraction of the image to keep after performing ML-EM. For ML-EM performed on a 90x90 sinogram, the
                    output image will be 90x90. However, it is necessary to crop this to 64x64 to get the same FOV
                    as the dataset. This means the image must be cropped by a factor of sqrt(2)/2.
    '''
    if azi_angles is None: # 'is None' also works when azi_angles is a numpy array ('==' would compare element-wise)
        num_angles = sino_ground.shape[1] # Width
        azi_angles=np.linspace(0, 180, num_angles, endpoint=False)

    ## Create Sensitivity Image ##
    sino_ones = np.ones(sino_ground.shape)
    sens_image = iradon(sino_ones, azi_angles, circle=circle, filter_name=None)

    if circle==False:
        def modify_sens(image, const_factor=0.9, slope=0.03):
            '''
            Modifies an image so that the area in the central FOV remains constant, but values at edges are attenuated.
            image               image to modify
            const_factor        fraction of the image to leave alone
            slope               increase this to attenuate images at the edges more
            '''
            def shape_piecewise(r, const_value, slope):
                if r <= const_value:
                    return 1
                else:
                    return 1+slope*(r-const_value)

            y_max = image.shape[0]
            x_max = image.shape[1]

            const_dist = const_factor*x_max/2 # radius over which image remains constant

            x_center = (x_max-1)/2.0 # the -1 comes from the fact that the coordinates of a pixel start at 0, not 1
            y_center = (y_max-1)/2.0

            for y in range(0, y_max):
                for x in range(0, x_max):
                    r = ((x-x_center)**2 + (y-y_center)**2)**0.5

                    total_factor = shape_piecewise(r, const_dist, slope) # creates a circular shaped piece-wise
                    #total_factor = shape_piecewise(abs(x-x_center), const_dist, slope) * shape_piecewise(abs(y-y_center), const_dist, slope) # square-shaped piecewise
                    #total_factor = shape_piecewise(abs(y-y_center), const_dist, slope) # vertical only

                    image[y,x] = image[y,x]*total_factor

            return image
        sens_image = modify_sens(sens_image)

    ## Create blank reconstruction ##
    image_recon  = np.ones(sens_image.shape)

    for iter in range(max_iter):

        if circle==True:
            sens_image = sens_image + 0.001 # Guarantees the denominator is >0

        sino_recon = radon(image_recon, azi_angles, circle=circle) #
        sino_recon[sino_recon==0]=1000 # Set a limit on the denominator (next line)
        sino_ratio = sino_ground / (sino_recon) #
        image_ratio = iradon(sino_ratio, azi_angles, circle=circle, filter_name=None) / sens_image
        image_ratio[image_ratio>1.5]=1.5 # Sets limit on backprojected ratio, on how fast image can grow. Threshold and set value should equal each other (good value=1.5)
        image_recon = image_recon * image_ratio
        image_recon[image_recon<0]=0 # Sets floor on image pixels. No need to adjust.

        #footprint = morphology.disk(1)
        #image_recon = opening(image_recon, footprint)

    image_cropped = crop_single_image_by_factor(image_recon, crop_factor=crop_factor)
    #image_cropped = crop_single_image_by_size(image_recon, crop_size=crop_size)

    return image_cropped

def reconstruct(sinogram_tensor, config, image_size=90, recon_type='FBP', circle=True):
    '''
    Function for calculating a reconstructed PET image tensor, given a sinogram_tensor. One image is reconstructed for
    each sinogram in the sinogram_tensor.

    sinogram_tensor:    Tensor of sinograms of size (number of images)x(channels)x(height)x(width).
                        Only the first channel (photopeak) is used for reconstruction here.
    config:             configuration dictionary
    image_size:         size of output (images are resized to this shape)
    recon_type:         Can be set to 'MLEM' for maximum-likelihood expectation maximization, or 'FBP' for
                        filtered back-projection.
    circle              circle=True: The projection data spans the width (or height) of the activity distribution, and the reconstructed image is circular.
                        circle=False: The projection data (sinograms) spans the corner-to-corner line of the activity distribution, and the reconstructed image is square.

    Function returns a tensor of reconstructed images. Returned images are resized, and optionally normalized and scaled (according to the keys in the configuration dictionary)
    '''
    normalize = config["SI_normalize"]
    scale = config['SI_scale']

    photopeak_array = torch.clamp(sinogram_tensor[:,0,:,:], min=0).detach().cpu().numpy()  # Here, we collapse the channel dimension.
    # Note: there really should be no need to clamp the sinogram, as it should contain no negative values, but might as well.

    ## Reconstruct Individual Sinograms ##
    first=True
    for sino in photopeak_array[0:,]:
        if recon_type == 'FBP':
            image = iradon(sino.squeeze(), # Sinogram is now 2D
                        circle=False, # For an unknown reason, circle=False gives better reconstructions here. Maybe due to errors introduced in interpolation.
                        preserve_range=True,
                        filter_name='cosine' # Options: 'ramp', 'shepp-logan', 'cosine', 'hamming', 'hann'
                        )
        else:
            image = iradon_MLEM(sino, circle=circle)

        ## Morphologic Opening - removes outlier pixels that can cause problems with image normalization
        #footprint = morphology.disk(1)
        #image = opening(image, footprint)

        ## Concatenate Images ##
        image = np.expand_dims(image, axis=0) # Add a dimension to the beginning of the reconstructed image
        if first==True:
            image_array = image
            first=False
        else:
            image_array = np.append(image_array, image, axis=0)

    ## For All Images: create resized/dimensioned Torch tensor ##
    image_array = np.expand_dims(image_array, axis=1)        # Creates channels dimension
    a = torch.from_numpy(image_array)                        # Converts to Torch tensor
    a = torch.clamp(a, min=0)                                # You HAVE to clamp before normalizing or the negative values throw it off.
    a = transforms.Resize(size = (image_size, image_size), antialias=True)(a) # Resize tensor

    ## Normalize Entire Tensor ##
    if normalize:
        batch_size = len(a)
        a = torch.reshape(a,(batch_size, 1, image_size**2)) # Flattens each image
        a = nn.functional.normalize(a, p=1, dim = 2)
        a = torch.reshape(a,(batch_size, 1 , image_size, image_size)) # Reshapes images back into square matrices
        a = scale*a

    return a.to(device)

def project(image_tensor, circle=False, theta=-1):
    '''
    Perform the forward radon transform to calculate projections from images. Returns an array of sinograms.

    image_tensor:   tensor of PET images
    theta:          numpy array of projection angles. Default is [0,180)
    '''
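    image_collapsed = torch.clamp(image_tensor[:,0,:,:], min=0).detach().cpu().numpy() # Shape: (N,H,W). No squeeze, so a batch of one image keeps its batch dimension

    if np.isscalar(theta) and theta==-1: # Comparing an array of angles to -1 would give an ambiguous truth value
        theta = np.arange(0,180)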

    first=True
    for image in image_collapsed[0:,]:
        sino = radon(image,
                    circle=circle,
                    preserve_range=True,
                    theta=theta,
                    )
        sino = np.moveaxis(np.atleast_3d(sino), 2, 0) # Adds a blank axis and moves it to the beginning
        if first==True:
            sino_array=sino
            first=False
        else:
            sino_array = np.append(sino_array, sino, axis=0)

    return torch.from_numpy(sino_array)

'''
### Functions that are no longer used ###

def FBP2(sinogram_tensor, config, image_size = 90, circle=circle):
    #This is an alternative filtered back-projection implementation. Not currently used.

    normalize = config["SI_normalize"]
    scale = config['SI_scale']

    photopeak_array = sinogram_tensor[:,0,:,:].detach().squeeze().cpu().numpy()
    # Note: there's no need to clamp the sinogram as it contains no negative values.

    first=True
    for sino in photopeak_array[0:,]:
        image = iradon(sino,
                    circle=circle,
                    preserve_range=True,
                    filter_name='cosine' # Options: 'ramp', 'shepp-logan', 'cosine', 'hamming', 'hann'
                    )

        ## For Individual Images: create resized/dimensioned Torch tensors ##
        image = np.expand_dims(image, axis = 0) # Creates an extra dimension at beginning for images in the batch.
        image = torch.from_numpy(image)             # I convert the array to a tensor so I can perform the the resizing and clamping below
        image = torch.clamp(image, min=0)           # Clamping
        image = (transforms.Resize(size = (image_size, image_size))(image)).numpy() # I convert back to Numpy so I can use the append function later.

        ## Normalize each individual image ##
        if normalize==True:
            image = image/np.sum(image)

        if first==True:
            image_array = image
            first=False
        else:
            image_array = np.append(image_array, image, axis=0)

    image_array = np.expand_dims(image_array, axis=1) # Creates channels dimension

    return scale*torch.from_numpy(image_array).to(device)


def shape_smooth(r,R=10000): # NOTE: this function is currently not used

    #Returns a smoothly varying function. At r=0, returns 1. As r increases, returned value decreases.
    #The function is a portion of a circle.

    return (R**2-r**2)**0.5+1-R


def weights_init(m): # 'm' represents layers in the generator or discriminator.

    #Function for initializing network weights to a uniform distribution on [0, 0.0001]

    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.uniform_(m.weight, 0, 0.0001)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.uniform_(m.weight, 0, 0.0001)
        torch.nn.init.constant_(m.bias, 0.02)

'''
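
The multiplicative update at the heart of `iradon_MLEM` can be illustrated on a toy problem. The sketch below is not the notebook's implementation: it replaces `radon`/`iradon` with an explicit, hypothetical 4x3 system matrix `A`, uses pure Python lists, and omits the ratio cap and the sensitivity-image modification used above. Each iteration forward-projects the current estimate, divides the measured data by that projection, back-projects the ratio, and divides by the sensitivity image (the back-projection of ones).

```python
def forward_project(A, x):
    """Compute A @ x for a list-of-lists system matrix A."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in A]


def back_project(A, y):
    """Compute A^T @ y."""
    return [sum(A[i][j] * y[i] for i in range(len(A))) for j in range(len(A[0]))]


def mlem(A, y, n_iter=50):
    """Toy ML-EM: x <- x * A^T(y / (A x)) / (A^T 1)."""
    sens = back_project(A, [1.0] * len(A))  # Sensitivity image
    x = [1.0] * len(A[0])                   # Uniform initial estimate
    for _ in range(n_iter):
        proj = forward_project(A, x)
        # Guard against division by zero in the measured/estimated ratio
        ratio = [y_i / p_i if p_i > 0 else 0.0 for y_i, p_i in zip(y, proj)]
        corr = back_project(A, ratio)
        x = [x_j * c_j / s_j for x_j, c_j, s_j in zip(x, corr, sens)]
    return x


# Hypothetical system: three "pixels", four measurements (each pixel alone, plus their sum)
A = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0],
     [0.0, 0.0, 1.0],
     [1.0, 1.0, 1.0]]
x_true = [2.0, 5.0, 1.0]
y = forward_project(A, x_true)  # Noise-free measurements
x_est = mlem(A, y)
print(x_est)
```

With noise-free, consistent data such as this, the estimate approaches `x_true`. With real sinograms the data are noisy, which is one reason the notebook's version limits the number of iterations and caps the back-projected ratio.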


## Display Images

In [None]:
def show_single_unmatched_tensor(image_tensor, grid=False, cmap='inferno', fig_size=1):
    '''
    Function for visualizing images. The images are displayed, each with their own colormap scaling, so quantitative comparisons are not possible.
    Send only the images you want plotted to this function. Works with both single channel and multi-channel images.
    If using the single-channel grid option, it plots 120 images in a 15x8 grid.

    image_tensor:   image tensor of shape [num, chan, height, width]
    grid:           If True, displays images in a 15x8 grid (120 images in total). If false, images are displayed in a horizontal line.
    cmap:           Matplotlib color map
    fig_size:       figure size
    '''
    print(f'Shape: {image_tensor.shape} // Min: {torch.min(image_tensor)} // Max: {torch.max(image_tensor)} \
    //Mean Sum (per image): {torch.sum(image_tensor).item()/(image_tensor.shape[0]*image_tensor.shape[1])} // Sum (a single image): {torch.sum(image_tensor[0,0,:])}')

    #image_tensor = image_tensor.detach().squeeze(.cpu()
    image_tensor = image_tensor.detach().cpu()
    image_tensor = torch.clamp(image_tensor, min=0)

    num = image_tensor.size(dim=0)
    chan = image_tensor.size(dim=1)

    ## Plot 3-Channel Images ##
    #image_np = image_grid.mean(dim=0).squeeze().numpy() # This also works!

    ## Plot Multi-Channel Images ##
    if chan!=1:
        #123
        print(f'Mean (Ch 0): {torch.mean(image_tensor[:,0,:,:])} // Mean (Ch 1): {torch.mean(image_tensor[:,1,:,:])} // Mean (Ch 2): {torch.mean(image_tensor[:,2,:,:])}')

        # Plot Grid #
        if grid:
            fig, ax = plt.subplots(num, chan, figsize=(fig_size*chan, fig_size*num), constrained_layout=True) # figsize is (width, height): 'chan' columns, 'num' rows
            for N in range(0, num): # Iterate over image number
                for C in range(0, chan): # Iterate over channels
                    img = image_tensor[N,C,:,:]
                    ax[N,C].axis('off')
                    ax[N,C].imshow(img.squeeze(), cmap=cmap)

        # Plot in-Line #
        else:
            fig, ax = plt.subplots(1, num*(chan+1), figsize=(fig_size*num*(chan+1), fig_size), constrained_layout=True) # figsize is (width, height): one row of num*(chan+1) panels
            i=0
            for N in range(0, num): # Iterate over image number
                for C in range(0, chan): # Iterate over channels
                    img = image_tensor[N,C,:,:]
                    ax[i].axis('off')
                    ax[i].imshow(img.squeeze(), cmap=cmap)
                    i+=1
                blank = torch.ones_like(img)
                ax[i].axis('off')
                ax[i].imshow(blank.squeeze())
                i+=1

    ## Plot 1 Channel Images ##
    else:
        # Plot Grid #
        # Note: This plots 120 images at a time!
        if grid:
            cols, rows = 15, 8
        # Plot in-Line #
        else:
            rows = 1
            cols = image_tensor.shape[0]

        figure=plt.figure(figsize=(cols*fig_size,rows*fig_size))

        for i in range(0, min(cols*rows, num)): # Do not index past the number of images supplied
            img = image_tensor[i]             # Shape: (1, height, width)
            figure.add_subplot(rows,cols,i+1) # Matplotlib indices start at 1
            plt.axis("off")
            plt.imshow(img.squeeze(), cmap=cmap)

    plt.show()


def show_multiple_matched_tensors(*image_tensors, cmap='inferno', fig_size=0.8):
    '''
    Function for visualizing images from multiple tensors. Each image is "matched" with images from the other tensors,
    and each matched set of images (one from each tensor) is plotted with the same colormap in a column.
    Send only the images you want plotted to this function. Works with both single channel and multi-channel images.

    image_tensors:  list of tensors, each of which may contain multiple images.
    '''
    for tensor in image_tensors:
        # Begin by printing statistics for each tensor
        print(f'Shape: {tensor.shape} // Min: {torch.min(tensor)} // Max: {torch.max(tensor)} \
        // Mean: {torch.mean(tensor)} // Mean Sum (per image): {torch.sum(tensor).item()/(tensor.shape[0]*tensor.shape[1])} // Sum (a single image): {torch.sum(tensor[0,0,:])}')

    combined_tensor = torch.cat(image_tensors, dim=0).detach().cpu()
    combined_tensor = torch.clamp(combined_tensor, min=0)

    num_rows = len(image_tensors)           # The number of rows equals the number of tensors (images to match)
    num_cols = len(image_tensors[0])        # The length of the zeroth element (of the list) is the number of images in a tensor.
    num_chan = image_tensors[0].size(dim=1) # Equivalent to: image_tensors[0].shape[1]

    ## Plot 1 Channel Images ##
    if num_chan==1:
        fig, ax = plt.subplots(num_rows, num_cols, squeeze=False, figsize=(fig_size*num_cols, fig_size*num_rows), constrained_layout=True)
        #fig, ax = plt.subplots(num_rows, num_cols, constrained_layout=True)

        i=0 # i = column number
        for col in range(0, num_cols): # Iterate over column number. All images in a column will have the same colormap.
            img_list=[]
            min_list=[]
            max_list=[]

            # Construct image list and normalization object for matched images in a column (iterating over rows) #
            for row in range(0,num_rows):                               # We iterate over rows in order
                img = combined_tensor[row*num_cols+col, 0 ,:,:]         # Grab the correct image (zeroth channel for 1-channel images)
                img_list.append(img)                                    # We construct a new image list for each column
                min_list.append(torch.min(img).item())                  # Create list of image minimums
                max_list.append(torch.max(img).item())                  # Create list of image maximums
            norm = Normalize(vmin=min(min_list), vmax=max(max_list))    # We construct a normalization object with min/max = min/max pixel value for all images in list

            # Plot normalized images in a single column (iterating over rows) #
            for row in range(0,num_rows):
                ax[row, i].axis('off')
                ax[row, i].imshow(img_list[row].squeeze(), cmap=cmap, norm=norm) # Squeeze gets rid of extra channel dimension
            i+=1

    ## Plot Multi-Channel Images ##
    else:
        print(f'Mean (Ch 0): {torch.mean(combined_tensor[:,0,:,:])} // Mean (Ch 1): {torch.mean(combined_tensor[:,1,:,:])} // Mean (Ch 2): {torch.mean(combined_tensor[:,2,:,:])}')

        #if num_cols>3:  # Restricts to 3-channels. You could get rid of this without an issue.
        #    num_cols=3

        # Construct figure and axes. Note: 'num_chan+1' arises from the divider blank image btw. each multi-channel image
        fig, ax = plt.subplots(num_rows, num_cols*(num_chan+1), squeeze=False, figsize=(fig_size*num_cols*(num_chan+1), fig_size*num_rows), constrained_layout=True)

        i=0
        for col in range(0, num_cols):      # Iterate over column number
            for chan in range(0, num_chan): # Iterate over channels
                img_list=[]
                min_list=[]
                max_list=[]

                # Iterates over rows (one row per tensor) to construct an image list and normalization object for a single column. All matched images have the same channel. #
                for row in range(0,num_rows):
                    img = combined_tensor[row*num_cols+col, chan ,:,:] # Constructs an image list where each row has the same channel #
                    img_list.append(img)
                    min_list.append(torch.min(img).item())
                    max_list.append(torch.max(img).item())
                norm = Normalize(vmin=min(min_list), vmax=max(max_list))

                # Iterates over rows to plot matched images in a single column. These share the same channel. #
                for row in range(0,num_rows):
                    ax[row, i].axis('off')
                    ax[row, i].imshow(img_list[row].squeeze(), cmap=cmap, norm=norm) # Squeeze gets rid of extra channel dimension
                i+=1

            # After all channels have been iterated, the complete multi-channel image has been plotted. Now we plot a divider before the next image #
            for row in range(0,num_rows):
                blank = torch.ones_like(img)
                ax[row, i].axis('off')
                ax[row, i].imshow(blank.squeeze())
            i+=1

    plt.show()

def show_single_commonmap_tensor(image_tensor, nrow=15, figsize=(27,18), cmap='inferno'):
    '''
    Function for visualizing images from one tensor, all of which will be plotted with the same scaled colormap. Only works with single-channel image tensors.

    image_tensor:  image tensor. The number of images should be evenly divisible by nrow.
    nrow:          number of images displayed in each row of the grid (i.e., the number of columns)
    figsize:       figure size
    cmap:          color map
    '''
    tensor = torch.clamp(image_tensor, min=0).detach().cpu()
    image_grid = make_grid(tensor, nrow=nrow)  # from torchvision.utils import make_grid

    #print(f'Shape: {tensor.shape} // Min: {torch.min(tensor)} // Max: {torch.max(tensor)} \
    #// Mean: {torch.mean(tensor)} // Mean Sum (per image): {torch.sum(tensor).item()/(tensor.shape[0]*tensor.shape[1])} // Sum (a single image): {torch.sum(tensor[0,0,:])}')

    fig, ax = plt.subplots(1,1, figsize=figsize)
    ax.axis('off')

    image_grid = image_grid[0,:].squeeze()
    #plt.imshow(image_grid, cmap=cmap)
    im = ax.imshow(image_grid, cmap=cmap)
    #fig.colorbar(im, ax=ax)
    plt.show()

def show_multiple_commonmap_tensors(*image_tensors, cmap='inferno'):
    '''
    Function for visualizing images from multiple tensors, all of which will be plotted with the same scaled colormap. Only works with single-channel image tensors.

    *image_tensors: list of image tensors, all of which should contain the same number of images. Only send the number of images you want to plot to this function.
    '''
    # Print tensor statistics #
    for tensor in image_tensors:
        print(f'Shape: {tensor.shape} // Min: {torch.min(tensor)} // Max: {torch.max(tensor)} \
        // Mean: {torch.mean(tensor)} // Mean Sum (per image): {torch.sum(tensor).item()/(tensor.shape[0]*tensor.shape[1])} // Sum (a single image): {torch.sum(tensor[0,0,:])}')

    num_rows = len(image_tensors)
    num_columns = len(image_tensors[0])
    # Combine tensors into one & clamp #
    combined_tensor = torch.cat(image_tensors, dim=0).detach().cpu()
    combined_tensor = torch.clamp(combined_tensor, min=0)
    # Make a grid of the tensors #
    image_grid = make_grid(combined_tensor, nrow=num_columns) # Note: nrow is the number of images displayed in each row (i.e., the number of columns)

    # Determine figure size #
    print('num_rows:', num_rows)
    fig, ax = plt.subplots(1,1, figsize=(30,1*num_rows))
    #fig, ax = plt.subplots(1,1, figsize=(30,7))

    ax.axis('off')

    image_grid = image_grid[0,:].squeeze()
    im = ax.imshow(image_grid, cmap=cmap)
    fig.colorbar(im, ax=ax)
    plt.show()
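
The column-matched plotting in `show_multiple_matched_tensors` rests on two ideas: image (row, col) sits at index `row*num_cols + col` of the concatenated tensor, and every image in a column shares one color range. The sketch below isolates that logic, with NumPy arrays standing in for the tensors. `column_color_limits` is a hypothetical helper written for illustration and is not part of the notebook.

```python
import numpy as np

def column_color_limits(*batches):
    """Return one (vmin, vmax) pair per column for matched image batches.

    Each batch has shape [num, chan, height, width]; only channel 0 is used.
    """
    combined = np.concatenate(batches, axis=0)  # batches are stacked one after another
    num_rows = len(batches)                     # one row per batch
    num_cols = batches[0].shape[0]              # one column per image in a batch
    limits = []
    for col in range(num_cols):
        # Image (row, col) sits at index row*num_cols + col in the stacked array
        matched = [combined[row * num_cols + col, 0] for row in range(num_rows)]
        vmin = min(float(img.min()) for img in matched)
        vmax = max(float(img.max()) for img in matched)
        limits.append((vmin, vmax))
    return limits

truth = np.array([[[[0.0, 1.0]]], [[[2.0, 3.0]]]])  # 2 images, 1 channel, 1x2 pixels
recon = np.array([[[[0.5, 4.0]]], [[[1.0, 2.5]]]])
print(column_color_limits(truth, recon))            # [(0.0, 4.0), (1.0, 3.0)]
```

Each pair could then be passed to `Normalize(vmin=..., vmax=...)`, as the plotting function does, so that a ground truth image and its reconstruction are directly comparable by eye.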

## Metrics

In [None]:
##################################################
## Functions for Calculating Metrics Dataframes ##
##################################################

## Calculate Arbitrary Metric ##
def calculate_metric(batch_A, batch_B, img_metric_function, return_dataframe=False, label='default', crop_factor=1):
    '''
    Function which calculates metric values for two batches of images.
    Returns either the average metric value for the batch or a dataframe of individual image metric values.

    batch_A:                tensor of images to compare [num, chan, height, width]
    batch_B:                tensor of images to compare [num, chan, height, width]
    img_metric_function:    a function which calculates a metric (MSE, SSIM, etc.) from two INDIVIDUAL images
    return_dataframe:       If False, then the average is returned.
                            Otherwise both the average, and a dataframe containing the metric values of the images in the batches, are returned.
    label:                  what to call dataframe, if it is created
    crop_factor:            factor by which to crop both batches of images. 1 = whole image is retained.
    '''

    if crop_factor != 1:
        batch_A = crop_image_tensor_by_factor(batch_A, crop_factor=crop_factor)
        batch_B = crop_image_tensor_by_factor(batch_B, crop_factor=crop_factor)

    length = len(batch_A)
    metric_avg = 0
    metric_list = []

    for i in range(length):
        image_A = batch_A[i:i+1,:,:,:] # Using i:i+1 instead of just i preserves the dimensionality of the array
        image_B = batch_B[i:i+1,:,:,:]

        metric_value = img_metric_function(image_A, image_B)
        metric_avg += metric_value/length
        if return_dataframe==True:
            metric_list.append(metric_value)

    metric_frame = pd.DataFrame({label : metric_list})

    if return_dataframe==False:
        return metric_avg
    else:
        return metric_frame, metric_avg


def update_tune_dataframe(tune_dataframe, model, config, mean_CNN_MSE, mean_CNN_SSIM, mean_CNN_CUSTOM):
    '''
    Function to update the tune_dataframe for each trial run that makes it partway through the tuning process.

    tune_dataframe      a dataframe that stores model and IQA metric information for a particular trial
    model               model being trained (in tuning)
    config              configuration dictionary
    mean_CNN_MSE        mean MSE for the CNN
    mean_CNN_SSIM       mean SSIM for the CNN
    mean_CNN_CUSTOM     mean custom metric for the CNN

    '''
    # Extract values from config dictionary
    SI_dropout =        config['SI_dropout']
    SI_exp_kernel =     config['SI_exp_kernel']
    SI_gen_fill =       config['SI_gen_fill']
    SI_gen_hidden_dim = config['SI_gen_hidden_dim']
    SI_gen_neck =       config['SI_gen_neck']
    SI_layer_norm =     config['SI_layer_norm']
    SI_normalize =      config['SI_normalize']
    SI_pad_mode =       config['SI_pad_mode']
    batch_size =        config['batch_size']
    gen_lr =            config['gen_lr']

    # Calculate number of trainable weights in CNN
    num_params = sum(map(torch.numel, model.parameters()))

    # Concatenate Dataframe
    add_frame = pd.DataFrame({'SI_dropout': SI_dropout, 'SI_exp_kernel': SI_exp_kernel, 'SI_gen_fill': SI_gen_fill, 'SI_gen_hidden_dim': SI_gen_hidden_dim,
                            'SI_gen_neck': SI_gen_neck, 'SI_layer_norm': SI_layer_norm, 'SI_normalize': SI_normalize, 'SI_pad_mode': SI_pad_mode, 'batch_size': batch_size,
                            'gen_lr': gen_lr, 'num_params': num_params, 'mean_CNN_MSE': mean_CNN_MSE, 'mean_CNN_SSIM': mean_CNN_SSIM, 'mean_CNN_CUSTOM': mean_CNN_CUSTOM}, index=[0])

    tune_dataframe = pd.concat([tune_dataframe, add_frame], axis=0)

    # Save Dataframe to File
    tune_dataframe.to_csv(tune_dataframe_path, index=False)

    return tune_dataframe


def reconstruct_images_and_update_test_dataframe(sino_tensor, image_size, CNN_output, ground_image, test_dataframe, config):
    '''
    Function which: A) performs reconstructions (FBP and possibly ML-EM)
                    B) constructs a dataframe of metric values (MSE & SSIM) for these reconstructions, and also for the CNN output, with respect to the ground truth image.
                    C) concatenates this with the test dataframe passed to this function
                    D) returns the concatenated dataframe, mean metric values, and reconstructions

    sino_tensor:    sinogram tensor of shape [num, chan, height, width]
    image_size:     image_size
    CNN_output:     CNN reconstructions
    ground_image:   ground truth images
    test_dataframe: dataframe to append metric values to
    config:         general config dictionary
    '''

    # Construct Outputs #
    FBP_output = reconstruct(sino_tensor, config, image_size=image_size, recon_type='FBP')
    if compute_MLEM==True:
        MLEM_output = reconstruct(sino_tensor, config, image_size=image_size, recon_type='MLEM')
    else: # If not looking at ML-EM, don't waste time computing the MLEM images, which can take a while.
        MLEM_output = FBP_output # Placeholder: the ML-EM columns will then simply repeat the FBP values

    # Dataframes: build dataframes for every reconstruction technique/metric combination #
    batch_CNN_MSE,  mean_CNN_MSE   = calculate_metric(ground_image, CNN_output, MSE,  return_dataframe=True, label='MSE (Network)')
    batch_CNN_SSIM,  mean_CNN_SSIM = calculate_metric(ground_image, CNN_output, SSIM, return_dataframe=True, label='SSIM (Network)')
    batch_FBP_MSE,  mean_FBP_MSE   = calculate_metric(ground_image, FBP_output, MSE,  return_dataframe=True, label='MSE (FBP)')
    batch_FBP_SSIM,  mean_FBP_SSIM = calculate_metric(ground_image, FBP_output, SSIM, return_dataframe=True, label='SSIM (FBP)')
    batch_MLEM_MSE, mean_MLEM_MSE  = calculate_metric(ground_image, MLEM_output, MSE, return_dataframe=True, label='MSE (ML-EM)')
    batch_MLEM_SSIM, mean_MLEM_SSIM= calculate_metric(ground_image, MLEM_output, SSIM,return_dataframe=True, label='SSIM (ML-EM)')

    # Concatenate batch dataframes and larger running test dataframe
    add_frame = pd.concat([batch_CNN_MSE, batch_FBP_MSE, batch_MLEM_MSE, batch_CNN_SSIM, batch_FBP_SSIM, batch_MLEM_SSIM], axis=1)
    test_dataframe = pd.concat([test_dataframe, add_frame], axis=0)

    # Return a whole lot of stuff
    return test_dataframe, mean_CNN_MSE, mean_CNN_SSIM, mean_FBP_MSE, mean_FBP_SSIM, mean_MLEM_MSE, mean_MLEM_SSIM, FBP_output, MLEM_output

######################
## Metric Functions ##
######################

## Metrics which take only single images as inputs ##
## ----------------------------------------------- ##
def SSIM(image_A, image_B, win_size=-1):
    '''
    Function to return the SSIM for two 2D images.

    image_A:        pytorch tensor for a single image
    image_B:        pytorch tensor for a single image
    win_size:       window size to use when computing the SSIM. This must be an odd number. If =-1, the full size of the image is used (or full size-1 so it's odd).
    '''

    if win_size == -1:   # The default shape of the window size is the same size as the image.
        x = image_A.shape[2]
        win_size = (x if x % 2 == 1 else x-1) # Guarantees the window size is odd.

    image_A_npy = image_A.detach().squeeze().cpu().numpy()
    image_B_npy = image_B.detach().squeeze().cpu().numpy()

    max_value = max([np.amax(image_A_npy, axis=(0,1)), np.amax(image_B_npy, axis=(0,1))])   # Find maximum among the images
    min_value = min([np.amin(image_A_npy, axis=(0,1)), np.amin(image_B_npy, axis=(0,1))])   # Find minimum among the images
    data_range = max_value-min_value

    SSIM_image = structural_similarity(image_A_npy, image_B_npy, data_range=data_range, gaussian_weights=False, use_sample_covariance=False, win_size=win_size)

    return SSIM_image

## Metrics which take either batches or images as inputs ##
## ----------------------------------------------------- ##
def MSE(image_A, image_B):
    '''
    Function to return the mean square error for two 2D images (or two batches of images).

    image_A:        pytorch tensor for a single image (or a batch of images)
    image_B:        pytorch tensor for a single image (or a batch of images)
    '''
    return torch.mean((image_A-image_B)**2).item()

def NMSE(image_A, image_B):
    '''
    Function to return the normalized mean square error for two 2D images (or two batches of images).

    image_A:        pytorch tensor for a single image or a batch of images (reference)
    image_B:        pytorch tensor for a single image (or a batch of images)
    '''
    return (torch.mean((image_A-image_B)**2)/torch.mean(image_A**2)).item()

def MAE(image_A, image_B):
    '''
    Function to return the mean absolute error for two 2D images (or two batches of images).

    image_A:        pytorch tensor for a single image (or a batch of images)
    image_B:        pytorch tensor for a single image (or a batch of images)
    '''
    return torch.mean(torch.abs(image_A-image_B)).item()

def calculate_moments(batch_A, batch_B, window_size = 10, stride=10, dataframe=False):
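    '''
    Function to return the three statistical moment scores (mean, second central moment, third central moment) for two image tensors.
    Scores are computed within square windows and summed over all window positions.

    batch_A:        tensor of reference images [num, chan, height, width]
    batch_B:        tensor of images to compare [num, chan, height, width]
    window_size:    side length (in pixels) of each square window
    stride:         step (in pixels) between neighboring windows
    dataframe:      currently unused
    '''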
    ## Nested Functions ##

    def compare_moments(win_A, win_B, moment):
        def compute_moment(win, moment, axis=1):
            mean_value = np.mean(win, axis=axis)
            if moment == 1:
                return mean_value
            else:
                mean_array = np.array([mean_value] * win.shape[1]).T  # Repeat each image's mean across all of its pixels so it can be subtracted element-wise
                moment = np.mean((win - mean_array)**moment, axis=1)
                return moment

        batch_size = win_A.shape[0]


        reshape_A = (torch.reshape(win_A, (batch_size, -1))).detach().cpu().numpy()
        reshape_B = (torch.reshape(win_B, (batch_size, -1))).detach().cpu().numpy()

        moment_A = compute_moment(reshape_A, moment=moment)
        moment_B = compute_moment(reshape_B, moment=moment)
        moment_score = np.mean(np.absolute(moment_A-moment_B)/(np.absolute(moment_A)+0.1))

        '''
        print('===============================')
        print('MOMENT: ', moment)
        print('moment_A shape: ', moment_A.shape)
        print('moment_A mean: ', np.mean(moment_A))
        print('moment_B shape: ', moment_B.shape)
        print('moment_B mean: ', np.mean(moment_B))
        print('moment_score, |moment_A-moment_B|/(moment_A+0.1) : ', moment_score)
        '''
        return moment_score

    ## Code ##
    image_size = batch_A.shape[2]

    num_windows = int((image_size)/stride) # Upper bound on the number of windows (reached when window_size <= stride)
    while (num_windows-1)*stride + window_size > image_size: # Reduce the count until the last window fits inside the image
        num_windows -= 1

    moment_1_running_score = 0
    moment_2_running_score = 0
    moment_3_running_score = 0

    for i in range(0, num_windows):
        for j in range(0, num_windows):
            corner = (i*stride, j*stride)

            win_A = crop_image_tensor_with_corner(batch_A, window_size, corner)
            win_B = crop_image_tensor_with_corner(batch_B, window_size, corner)

            moment_1_score = compare_moments(win_A, win_B, moment=1)
            moment_2_score = compare_moments(win_A, win_B, moment=2)
            moment_3_score = compare_moments(win_A, win_B, moment=3)

            moment_1_running_score += moment_1_score
            moment_2_running_score += moment_2_score
            moment_3_running_score += moment_3_score

    return moment_1_running_score, moment_2_running_score, moment_3_running_score

def LDM(batch_A, batch_B):
    '''
    Calculate the local distributions metric (LDM) for two batches of images
    '''

    score_1, score_2, score_3 = calculate_moments(batch_A, batch_B, window_size=5, stride=5)

    score_1 = score_1*1
    score_2 = score_2*1
    score_3 = score_3*1

    '''
    print('Scores')
    print('====================')
    print(score_1)
    print(score_2)
    print(score_3)
    '''

    return score_1+score_2+score_3

def custom_metric(batch_A, batch_B):
    return 0
    #return MSE(batch_A, batch_B)



###############################################
## Average or Batch Metrics: Good for GANs   ##
###############################################

# Range #
def range_metric(real, fake):
    '''
    Computes a simple metric which penalizes "fake" images in a batch for having a range different than the "real" images in a batch.
    Only a single metric number is returned.
    '''
    range_real = torch.max(real).item()-torch.min(real).item()
    range_fake = torch.max(fake).item()-torch.min(fake).item()

    return abs(range_real-range_fake)/(range_real+.1)

# Average #
def avg_metric(real, fake):
    '''
    Computes a simple metric which penalizes "fake" images in a batch for having an average value different than the "real" images in a batch.
    Only a single metric number is returned.
    '''
    avg_metric = abs((torch.mean(real).item()-torch.mean(fake).item())/(torch.mean(real)+.1).item())
    return avg_metric

# Pixel Variation #
def pixel_dist_metric(real, fake):
    '''
    Computes a metric which penalizes "fake" images for having a pixel distance different than the "real" images.

    real: real image tensor
    fake: fake image tensor
    '''
    def pixel_dist(image_tensor):
        '''
        Function for computing the pixel distance (standard deviation from mean) for a batch of images.
        For simplicity, it only looks at the 0th channel.
        '''
        array = image_tensor[:,0,:,:].detach().cpu().numpy().squeeze()
        sd = np.std(array, axis=0)
        avg=np.mean(sd)
        return(avg)

    pix_dist_fake = pixel_dist(fake)
    pix_dist_real = pixel_dist(real)

    return abs((pix_dist_real-pix_dist_fake)/(pix_dist_real+.1)) # The +0.1 in the denominators guarantees we don't divide by zero

###################
## Old Functions ##
###################


def LDM_OLD(real, fake, crop_size = 10, stride=10):
    '''
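    Function to return the local distributions metric for two batches of images.

    real:           pytorch tensor of real images [num, chan, height, width]
    fake:           pytorch tensor of fake images [num, chan, height, width]
    crop_size:      side length (in pixels) of each square cropping window
    stride:         step (in pixels) between neighboring cropping windows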
    '''
    image_size = real.shape[2]

    i_max = int((image_size)/stride) # Maximum number of windows occurs when the stride equals the crop_size
    while (i_max-1)*stride + crop_size > image_size: # If stride < crop_size, fewer crops fit, so we solve for the number of crops
        i_max -= 1

    def crop_image_tensor_with_corner(A, corner=(0,0), crop_size=1):
        '''
        Function which returns a small, cropped version of an image.

        A           a batch of images with dimensions: (num_images, channel, height, width)
        corner      upper-left corner of window
        crop_size   size of cropping window
        '''
        x_min = corner[1]
        x_max = corner[1]+crop_size
        y_min = corner[0]
        y_max = corner[0]+crop_size
        return A[:,:, y_min:y_max , x_min:x_max ]

    running_dist_score = 0
    running_avg_score = 0

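    for i in range(0, i_max):
        for j in range(0, i_max): # j_max was never defined. Images are assumed square, so the window count is the same along both axes.
            corner = (i*stride, j*stride) # Step by the stride, which is what i_max was computed from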
            win_real = crop_image_tensor_with_corner(real, corner, crop_size)
            win_fake = crop_image_tensor_with_corner(fake, corner, crop_size)

            #range_score = range_metric(win_real, win_fake)
            avg_score = avg_metric(win_real, win_fake)
            pixel_dist_score = pixel_dist_metric(win_real, win_fake)

            running_dist_score += pixel_dist_score
            running_avg_score += avg_score

    combined_score = running_dist_score + running_avg_score

    return combined_score
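
The per-image loop in `calculate_metric` can be tried out in isolation. The sketch below is a minimal NumPy version, with arrays standing in for PyTorch tensors: each metric is evaluated on one image at a time (sliced as `i:i+1` to keep all four dimensions) and the values are then averaged. The lowercase names `mse`, `nmse`, `mae`, and `per_image_metric` are stand-ins chosen for this example, not the notebook's functions.

```python
import numpy as np

def mse(a, b):
    """Mean square error."""
    return float(np.mean((a - b) ** 2))

def nmse(reference, b):
    """Mean square error normalized by the mean square of the reference."""
    return float(np.mean((reference - b) ** 2) / np.mean(reference ** 2))

def mae(a, b):
    """Mean absolute error."""
    return float(np.mean(np.abs(a - b)))

def per_image_metric(batch_a, batch_b, metric):
    """Apply `metric` image by image; return the list of values and their mean."""
    values = [metric(batch_a[i:i + 1], batch_b[i:i + 1]) for i in range(len(batch_a))]
    return values, sum(values) / len(values)

# Two 2x2 single-channel images, shape [2, 1, 2, 2]
ground = np.array([[[[1.0, 2.0], [3.0, 4.0]]], [[[2.0, 2.0], [2.0, 2.0]]]])
# Shift the first image up by 1 and the second down by 1
recon = ground + np.array([1.0, -1.0]).reshape(2, 1, 1, 1)

mse_values, mse_mean = per_image_metric(ground, recon, mse)
nmse_values, nmse_mean = per_image_metric(ground, recon, nmse)
mae_values, mae_mean = per_image_metric(ground, recon, mae)
print(mse_values, nmse_values, mae_values)
```

Both images have the same MSE and MAE here, but different NMSE values, because NMSE divides by the mean square of each reference image.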


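The window bookkeeping in `calculate_moments` is easy to get wrong, so here is a small pure-Python sketch of it, assuming square images. It counts how many windows fit along one axis, lists their upper-left corners, and scores one pair of windows with the same relative difference used in `compare_moments`. All four function names are hypothetical, and the moment is taken over a flat list of pixel values rather than per image in a batch.

```python
def count_windows(image_size, window_size, stride):
    """Number of window positions along one axis that fit inside the image."""
    num_windows = image_size // stride  # upper bound on the count
    while num_windows > 0 and (num_windows - 1) * stride + window_size > image_size:
        num_windows -= 1                # drop windows that would overhang the edge
    return num_windows

def window_corners(image_size, window_size, stride):
    """Upper-left (row, column) corner of every window on a square image."""
    n = count_windows(image_size, window_size, stride)
    return [(i * stride, j * stride) for i in range(n) for j in range(n)]

def central_moment(values, order):
    """Mean for order 1, otherwise the central moment of the given order."""
    mean = sum(values) / len(values)
    if order == 1:
        return mean
    return sum((v - mean) ** order for v in values) / len(values)

def moment_score(values_a, values_b, order):
    """Relative moment difference; the +0.1 avoids division by zero."""
    m_a = central_moment(values_a, order)
    m_b = central_moment(values_b, order)
    return abs(m_a - m_b) / (abs(m_a) + 0.1)

print(count_windows(180, 5, 5))                           # 36
print(window_corners(10, 4, 2)[-1])                       # (6, 6)
print(moment_score([1.0, 2.0, 3.0], [2.0, 3.0, 4.0], 2))  # 0.0
```

The last call returns zero because shifting every pixel by a constant changes the mean but leaves the second central moment unchanged.
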
## Loss Functions

In [None]:
def get_supervisory_loss(fake_X, real_X, sup_criterion):
    '''
    Function to calculate the supervisory loss.

    fake_X:         fake image tensor (Terminology from GANs. For supervisory networks, it's arbitrary whether fake_X or real_X are ground truths or reconstructions)
    real_X:         real image tensor
    sup_criterion   loss function. Will be a Pytorch object.
    '''
    #print('Calc supervisory loss')
    sup_loss = sup_criterion(fake_X, real_X)
    return sup_loss

def get_disc_loss(fake_X, real_X, disc_X, adv_criterion):
    '''
    Function to calculate the discriminator loss. Used to train the discriminator.
    '''
    disc_fake_pred = disc_X(fake_X.detach()) # Detach generator from fake batch
    disc_fake_loss = adv_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred)) # The discriminator should predict 0 for fakes.
    disc_real_pred = disc_X(real_X)
    disc_real_loss = adv_criterion(disc_real_pred, torch.ones_like(disc_real_pred)) # The discriminator should predict 1 for real images.
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    return disc_loss

def get_gen_adversarial_loss(real_X, gen_XY, disc_Y, adv_criterion):
    '''
    Function to calculate the adversarial loss (for gen_XY) and fake_Y (from real_X).
    '''
    fake_Y = gen_XY(real_X)
    disc_fake_pred = disc_Y(fake_Y)
    adversarial_loss = adv_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred)) # The generator is penalized when the discriminator correctly labels its output as fake
    return adversarial_loss, fake_Y

def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    '''
    Function to calculate the cycle-consistency loss (for gen_YX).
    '''
    cycle_X = gen_YX(fake_Y)
    cycle_loss = cycle_criterion(cycle_X, real_X)
    return cycle_loss, cycle_X

def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B, config):
    '''
    Function to calculate the total generator loss. Used to train the generators.
    '''
    supervisory_criterion = config['sup_criterion']
    cycle_criterion = config['cycle_criterion']
    gen_adversarial_criterion = config['gen_adv_criterion']
    lambda_adv = config['lambda_adv']
    lambda_sup = config['lambda_sup']
    lambda_cycle = config['lambda_cycle']

    # Adversarial Loss
    if lambda_adv != 0: # To save resources, we only run this code if lambda_adv != 0
        adv_loss_AB, fake_B = get_gen_adversarial_loss(real_A, gen_AB, disc_B, gen_adversarial_criterion)
        adv_loss_BA, fake_A = get_gen_adversarial_loss(real_B, gen_BA, disc_A, gen_adversarial_criterion)
        adv_loss = adv_loss_AB+adv_loss_BA
    else: # Even if we don't compute adversarial losses, we still need fake_A and fake_B for later code
        fake_A = gen_BA(real_B)
        fake_B = gen_AB(real_A)

    # Supervisory Loss
    if lambda_sup != 0: # To save resources, we only run this code if lambda_sup != 0
        sup_loss_AB = get_supervisory_loss(fake_B, real_B, supervisory_criterion)
        sup_loss_BA = get_supervisory_loss(fake_A, real_A, supervisory_criterion)
        sup_loss = sup_loss_AB+sup_loss_BA

    # Cycle-consistency Loss -- get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion)
    cycle_loss_AB, cycle_B = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)
    cycle_loss_BA, cycle_A = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cycle_loss = cycle_loss_AB+cycle_loss_BA

    # Total Generator Loss. Terms with a zero weight were never computed, so they are reported as 0.
    gen_loss = lambda_cycle*cycle_loss
    adv_value = 0
    sup_value = 0
    if lambda_adv != 0:
        gen_loss = gen_loss + lambda_adv*adv_loss
        adv_value = adv_loss.item()
    if lambda_sup != 0:
        gen_loss = gen_loss + lambda_sup*sup_loss
        sup_value = sup_loss.item()
    return gen_loss, adv_value, sup_value, cycle_loss.item(), cycle_A, cycle_B

### Functions for Asymmetric/Separate (Older) ###
'''
def get_gen_adv_loss(fake_X, disc_X, adv_criterion):
    print('Calc generative adversarial loss')
    disc_fake_pred = disc_X(fake_X)
    adversarial_loss = adv_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred)) # Called only from get_gen_loss
    return adversarial_loss

def get_sup_loss(fake_X, real_X, sup_criterion):
    print('Calc supervisory loss')
    sup_loss = sup_criterion(fake_X, real_X)
    return sup_loss

def get_cycle_loss(fake_I, gen_IS, low_rez_S, cycle_criterion):
    print('Calc cycle loss')
    cycle_S = gen_IS(fake_I)
    cycle_loss = cycle_criterion(cycle_S, low_rez_S)
'''

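The total generator loss is a weighted sum of adversarial, supervisory, and cycle-consistency terms, and a term whose weight is zero need not be evaluated at all. The sketch below shows that pattern in plain Python, with floats standing in for loss tensors. `weighted_generator_loss` and `make_term` are hypothetical helpers for illustration. Unlike `get_gen_loss`, which always computes the cycle loss, this version skips any term whose weight is zero.

```python
def weighted_generator_loss(terms, weights):
    """Sum weight*loss over all terms, skipping terms whose weight is zero.

    terms:   dict mapping a name to a zero-argument callable that computes the loss
    weights: dict mapping the same names to their lambda weights
    """
    total = 0.0
    report = {}
    for name, weight in weights.items():
        if weight == 0:
            report[name] = 0       # never computed, so report 0
            continue
        value = terms[name]()      # only evaluated when it contributes
        report[name] = value
        total += weight * value
    return total, report

computed = []  # records which terms were actually evaluated

def make_term(name, value):
    """Build a stand-in loss that logs its own evaluation."""
    def term():
        computed.append(name)
        return value
    return term

terms = {'adv': make_term('adv', 0.7), 'sup': make_term('sup', 0.2), 'cycle': make_term('cycle', 0.5)}
weights = {'adv': 0, 'sup': 1, 'cycle': 10}
total, report = weighted_generator_loss(terms, weights)
print(total, report, computed)
```

With `lambda_adv = 0`, the adversarial term is never called, which mirrors the resource-saving checks on `lambda_adv` and `lambda_sup` in `get_gen_loss`.
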
## Timing

In [None]:
def display_times(label_string, init_time, show_times):
    '''
    Function to display the time it takes to perform individual steps in the code. This can be helpful when trying to streamline things. Also returns the current time, which can be used to reset the timer in the code that calls the function.

    init_time:      initiation time when the process started
    label_string:   string to label the displayed time
    show_times:     show times or not
    '''
    current_time = time.time()

    if show_times == True:
        print(f'{label_string} (ms): {(current_time-init_time)*1000}')

    return current_time
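
Because `display_times` returns the current time, feeding that return value into the next call times each step separately rather than cumulatively. The following usage sketch includes a compact copy of the function so that it runs on its own; the two timed steps are arbitrary stand-ins for steps such as data loading and a training iteration.

```python
import time

def display_times(label_string, init_time, show_times):
    # Compact copy of the notebook function so this example is self-contained
    current_time = time.time()
    if show_times:
        print(f'{label_string} (ms): {(current_time - init_time) * 1000}')
    return current_time

timer = time.time()                            # start the clock
squares = [n * n for n in range(100_000)]      # stand-in for a data-loading step
timer_after_load = display_times('Load', timer, show_times=True)
total = sum(squares)                           # stand-in for a training step
timer_after_step = display_times('Step', timer_after_load, show_times=True)
```

Passing `show_times=False` silences the printout while still returning the time, so the calls can stay in place when timing output is not wanted.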

# Tune/Train/Test Functions

## SUP Loss Only

In [None]:
def train_Supervisory_Sym(config, offset=0, num_examples=-1, sample_division=1):
    '''
    Function to train or test a network with supervisory loss only. Also used for visualizing data in the dataset.
    '''
    print('Dataset offset:', offset)
    print('Dataset num_examples:', num_examples)
    print('Dataset sample_division:', sample_division)

    ############################
    ### Initialize Variables ###
    ############################

    ## Grab some values and assign to local variables ##
    sup_criterion=config['sup_criterion']
    scale=config['SI_scale'] if train_SI==True else config['IS_scale']

    ## If Tuning ##
    if run_mode=='tune':
        batch_size=config['batch_size']   #config['batch_size']=tune.choice([32, 64, 128, 256, 512, 1024])
        batch_mult = 512/batch_size if tune_even_reporting == True else 1
        display_step = tune_iter_per_report*batch_mult # Larger batch size --> fewer training iterations per report to RayTune

        if tune_restore==False:
            tune_dataframe = pd.DataFrame({'SI_dropout': [], 'SI_exp_kernel': [], 'SI_gen_fill': [], 'SI_gen_hidden_dim': [], 'SI_gen_neck': [], 'SI_layer_norm': [], 'SI_normalize': [],'SI_pad_mode': [], 'batch_size': [], 'gen_lr': [], 'num_params': [], 'mean_CNN_MSE': [], 'mean_CNN_SSIM': [], 'mean_CNN_CUSTOM': []})
            tune_dataframe.to_csv(tune_dataframe_path, index=False)
        else:
            tune_dataframe = pd.read_csv(tune_dataframe_path)

    ## If Training ##
    elif run_mode=='train':
        batch_size=config['batch_size']
        display_step = train_display_step

    ## If Testing ##
    elif run_mode=='test':
        batch_size = config['batch_size'] = test_batch_size  # If we don't override the batch size in the config dictionary, the same batch size will be used as was used to train the network. Therefore, we override it.
        display_step = test_display_step
        test_dataframe = pd.DataFrame({'MSE (Network)' : [],  'MSE (FBP)': [],  'MSE (ML-EM)': [],'SSIM (Network)' : [], 'SSIM (FBP)': [], 'SSIM (ML-EM)': []})

    ## If Visualizing ##
    elif run_mode=='visualize':
        batch_size = config['batch_size'] = visualize_batch_size # If we don't override the batch size in the config dictionary, the same batch size will be used as was used to train the network. Therefore, we override it.
        display_step = 1

    ## Define running variables ##
    mean_gen_loss = 0; mean_CNN_SSIM = 0 ; mean_CNN_MSE = 0 ; mean_CNN_CUSTOM = 0; report_num = 1  # First report to RayTune is report_num=1.

    ###########################
    ### Instantiate Classes ###
    ###########################

    # Generator #
    if train_SI==True:
        gen =  Generator(config=config, gen_SI=True,  input_size=sino_size, input_channels=sino_channels,  output_channels=image_channels).to(device)
    else:
        gen =  Generator(config=config, gen_SI=False, input_size=image_size, input_channels=image_channels, output_channels=sino_channels ).to(device)

    # Optimizer #
    gen_opt = torch.optim.Adam(gen.parameters(), lr=config['gen_lr'], betas=(config['gen_b1'], config['gen_b2']))

    # Dataloader #
    dataloader = DataLoader(
        NpArrayDataSet(image_path=image_path, sino_path=sino_path, config=config, image_size=image_size, image_channels=image_channels,
                       sino_size=sino_size, sino_channels=sino_channels, augment=augment, offset=offset, num_examples=num_examples, sample_division=sample_division),
        batch_size=batch_size,
        shuffle=shuffle
    )


    ##############################
    ### Set Initial Conditions ###
    ##############################

    ## If loading checkpoint (training, testing or visualizing). For tuning, load_state=False (always). ##
    if load_state==True:
        checkpoint = torch.load(checkpoint_path) # checkpoint is a dictionary of dictionaries
        gen.load_state_dict(checkpoint['gen_state_dict'])
        gen_opt.load_state_dict(checkpoint['gen_opt_state_dict'])

        # If testing or visualizing, start from the beginning #
        if run_mode=='test' or run_mode=='visualize':
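            gen.eval()  # Evaluation mode --> dropout is disabled and normalization layers use their stored statistics. Gradients are dropped separately via .detach() below.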
            start_epoch=0; end_epoch=1; batch_step = 0

        # If training, pick up where we left off #
        elif run_mode=='train':
            start_epoch = checkpoint['epoch'] # Note: if interrupted, this epoch may be trained more than once
            end_epoch = training_epochs
            batch_step = checkpoint['batch_step'] # Note: because training is done with shuffling (unless you alter it), stopping partway through a training epoch will result in the network seeing some training examples more than once, and some not at all.

    ## If starting from scratch ##
    else:
        gen = gen.apply(weights_init)
        start_epoch=0 ; batch_step = 0
        end_epoch=num_epochs  # =1000 for tuning (Ray Tune terminates before you hit this), =training_epochs for training, =1 for testing or visualizing


    ## Initialize timestamps to keep track of calculation times ##
    time_init_full = time.time()   # This is reset at the display time so that the full step time is displayed (see below).
    time_init_loader = time.time()  # This is reset at the display time, but also reset at the end of the inner "for loop", so that only displays the data loading time.

    ########################
    ### Loop over Epochs ###
    ########################

    ### Loop over Epochs ###
    for epoch in range(start_epoch, end_epoch):

        #########################
        ### Loop Over Batches ###
        #########################
        for sino_ground, sino_ground_scaled, image_ground, image_ground_scaled in iter(dataloader): # Dataloader returns the batches. Loop over batches within epochs.

            # Show times #
            current_time = display_times('loader time', time_init_loader, show_times) # current_time is a dummy variable that isn't used in this loop
            time_init_full = display_times('FULL STEP TIME', time_init_full, show_times) # This step resets time_init_full after displaying the time, so this displays the full time to run the loop over a batch.

            # Assign inputs and targets #
            if train_SI==True:
                target=image_ground_scaled
                input=sino_ground_scaled
            else:
                target=sino_ground_scaled
                input=image_ground_scaled

            #######################
            ## Calculate Outputs ##
            #######################

            ## If Tuning or Training, train one step ##
            if run_mode=='tune' or run_mode=='train':
                time_init_train = time.time() # Initialize timestamp for training duration

                gen_opt.zero_grad()
                CNN_output = gen(input)

                if run_mode=='train' and torch.sum(CNN_output[1,0,:]) < 0: # Lets you know if the network starts outputting predominantly negative values (checks the second image in the batch).
                    print('PIXEL VALUES SUM TO A NEGATIVE NUMBER. IF THIS CONTINUES FOR A WHILE, YOU MAY NEED TO RESTART')

                # Update gradients
                gen_loss = sup_criterion(CNN_output, target)
                gen_loss.backward()
                gen_opt.step()
                # Keep track of the average generator loss
                mean_gen_loss += gen_loss.item() / display_step

                current_time = display_times('training time', time_init_train, show_times)

            ## If Testing or Visualizing, calculate output only ##
            else:
                CNN_output=gen(input).detach()

            # Increment batch_step
            batch_step += 1

            ####################################
            ### Run-Type Specific Operations ###
            ####################################
            time_init_metrics=time.time()


            ## If Tuning or Training ##
            # We only calculate the mean value of the metrics, but not dataframes or reconstructions. Mean values are used to calculate the optimization metrics #
            if (run_mode == 'tune') or (run_mode=='train'):

                mean_CNN_SSIM += calculate_metric(target, CNN_output, SSIM)/ display_step # The SSIM function can only take single images as inputs, not batches, so we use a wrapper function and pass batches to it.
                mean_CNN_MSE +=  calculate_metric(target, CNN_output, MSE) / display_step # The MSE function can take either single images or batches. We use the wrapper for consistency.

                time_init_custom=time.time()
                # Custom metrics can take a long time to calculate, so we don't use a wrapper (which would loop through individual images in calculations.)
                mean_CNN_CUSTOM += custom_metric(target, CNN_output) / display_step
                current_time = display_times('Custom metric time', time_init_custom, show_times)

            ## If Testing ##
            # We reconstruct images and we calculate metric dataframes #
            if run_mode == 'test':
                test_dataframe, mean_CNN_MSE, mean_CNN_SSIM, mean_FBP_MSE, mean_FBP_SSIM, mean_MLEM_MSE, mean_MLEM_SSIM, FBP_output, MLEM_output =  reconstruct_images_and_update_test_dataframe(
                    input, image_size, CNN_output, image_ground_scaled, test_dataframe, config)

            ## If Visualizing ##
            if run_mode=='visualize':
                # We calculate reconstructions but not metric values. #
                FBP_output =  reconstruct(input, config, image_size=image_size, recon_type='FBP')
                MLEM_output = reconstruct(input, config, image_size=image_size, recon_type='MLEM')


            # Show metric calculation time #
            current_time = display_times('metrics time', time_init_metrics, show_times)

            ######################################
            ### VISUALIZATION / REPORTING CODE ###
            ######################################

            if batch_step % display_step == 0: # and (batch_step > 0 or run_mode != 'tune'):

                time_init_visualization=time.time()

                example_num = batch_step*batch_size

                ## If Tuning ##
                if run_mode=='tune':

                    session.report({'MSE':mean_CNN_MSE, 'SSIM':mean_CNN_SSIM, 'CUSTOM':mean_CNN_CUSTOM, 'example_number': example_num, 'batch_step':batch_step, 'epoch':epoch}) # Report to RayTune multiple times per trial

                    if int(tune_dataframe_fraction*tune_max_t) == report_num: # We only update tune_dataframe once per trial
                        tune_dataframe = update_tune_dataframe(tune_dataframe, gen, config, mean_CNN_MSE, mean_CNN_SSIM, mean_CNN_CUSTOM)

                    report_num +=1

                ## If Training ##
                if run_mode == 'train':
                    # Display Batch Metrics #
                    print('================Training===================')
                    print(f'CURRENT PROGRESS: epoch: {epoch} / batch_step: {batch_step} / image #: {example_num}')
                    print(f'mean_gen_loss: {mean_gen_loss}')
                    print(f'mean_CNN_MSE : {mean_CNN_MSE}')
                    print(f'mean_CNN_SSIM: {mean_CNN_SSIM}')
                    print(f'mean_CNN_CUSTOM: {mean_CNN_CUSTOM}')
                    print('===========================================')
                    print('Last Batch MSE: ', calculate_metric(target, CNN_output, MSE))
                    print('Last Batch SSIM: ', calculate_metric(target, CNN_output, SSIM))

                    # Display Inputs & Reconstructions#
                    print('Input:')
                    show_single_unmatched_tensor(input[0:9])
                    print('Target/Output:')
                    show_multiple_matched_tensors(target[0:9], CNN_output[0:9])

                ## If Testing ##
                if run_mode == 'test':
                    # Display Batch Metrics #
                    print('==================Testing==================')
                    print(f'mean_CNN_MSE/mean_MLEM_MSE/mean_FBP_MSE : {mean_CNN_MSE}/{mean_MLEM_MSE}/{mean_FBP_MSE}')
                    print(f'mean_CNN_SSIM/mean_MLEM_SSIM/mean_FBP_SSIM: {mean_CNN_SSIM}/{mean_MLEM_SSIM}/{mean_FBP_SSIM}')
                    print('===========================================')

                    # Display Inputs & Reconstructions #
                    print('Input')
                    show_single_unmatched_tensor(input[0:9])
                    print('Target/Output/MLEM/FBP:')
                    show_multiple_matched_tensors(target[0:9], CNN_output[0:9], MLEM_output[0:9], FBP_output[0:9])

                ## If Visualizing ##
                if run_mode == 'visualize':
                    if visualize_batch_size==120:
                        print(f'visualize_offset: {visualize_offset}, Image Number (batch_step*120): {batch_step*120}')
                        show_single_unmatched_tensor(target, grid=True, cmap='inferno', fig_size=1)
                    else:
                        print('Input:')
                        show_single_unmatched_tensor(input[0:visualize_batch_size])
                        print('Target/ML-EM/FBP/Output:')
                        show_multiple_matched_tensors(target[0:visualize_batch_size], MLEM_output[0:visualize_batch_size], FBP_output[0:visualize_batch_size], CNN_output[0:visualize_batch_size])


                # Save State -- This does not occur with every batch used in training, to save resources #
                if save_state:
                    print('Saving model!')
                    torch.save({
                        'epoch': epoch,
                        'batch_step': batch_step,
                        'gen_state_dict': gen.state_dict(),
                        'gen_opt_state_dict': gen_opt.state_dict(),
                        }, checkpoint_path)

                # Zero running stats -- occurs once per visualization step #
                mean_gen_loss = 0 ; mean_CNN_SSIM = 0 ; mean_CNN_MSE = 0 ; mean_CNN_CUSTOM=0

                # Show visualization time #
                current_time = display_times('visualization time', time_init_visualization, show_times)


            # Time step to display loader time
            time_init_loader = time.time()


    ############################################
    ### Complete end of Train Function Tasks ###
    ############################################

    # Save Network State (Training) #
    if save_state:
        print('Saving model!')
        path = os.path.join(checkpoint_dir, checkpoint_file)
        torch.save({
            'epoch': epoch+1, # If we are saving after an epoch is completed, and we pick up training later, we have to start at the next epoch.
            'batch_step': batch_step,
            'gen_state_dict': gen.state_dict(),  # dictionary of dictionaries!
            'gen_opt_state_dict': gen_opt.state_dict(),
            }, path)

    # If testing, return dataframe #
    if run_mode=='test':
        return test_dataframe
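
The function above accumulates per-batch metrics into running means and reports them every `display_step` batches, with `display_step` scaled against a reference batch size of 512 when tuning. The sketch below isolates that pattern in plain Python. The names `scaled_display_step` and `RunningMean` are hypothetical helpers introduced only for illustration, and the rounding rule is an assumption rather than something the notebook prescribes.

```python
def scaled_display_step(iters_per_report, batch_size, reference_batch_size=512):
    """Batches between reports, scaled so that each report covers roughly the
    same number of training examples regardless of batch size."""
    step = round(iters_per_report * reference_batch_size / batch_size)
    return max(1, step)  # never report less often than once per batch


class RunningMean:
    """Accumulates per-batch values and yields their mean once per window."""

    def __init__(self, window):
        self.window = window
        self.total = 0.0
        self.count = 0

    def update(self, value):
        """Add one per-batch value. Returns the window mean when the window
        is full (and resets the accumulator), otherwise returns None."""
        self.total += value
        self.count += 1
        if self.count == self.window:
            mean = self.total / self.window
            self.total, self.count = 0.0, 0
            return mean
        return None


# Usage: a batch size of 1024 halves the number of batches per report.
step = scaled_display_step(iters_per_report=10, batch_size=1024)
tracker = RunningMean(step)
reports = []
for batch_value in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]:
    window_mean = tracker.update(batch_value)
    if window_mean is not None:
        reports.append(window_mean)
```

Keeping the step an integer matters because a fractional `display_step` would make the modulo test fire less often than the divisor used for the running mean assumes.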

## GAN / CYCLE

In [None]:
'''
Note: It makes no sense to "test" a GAN or use SSIM since there is nothing to compare it to. Therefore, this functionality is left out here.
Also, now that you've assigned checkpoint_dir and test_dataframe_dir in the "User Parameters" cell, you can get rid of the path constructions below.

'''
def train_test_GAN(config, checkpoint_dir=None, load_state=False, save_state=False):
    '''
    Note: Arguments are set to False/None to ensure that when RayTune calls train(), states are not saved/loaded
    Note: you may want to use 'model.train()' to put model back into training mode if you put it into eval mode at some point...
    '''
    print('Training GAN only!!')

    ## Grab from Config ##

    batch_size=config['batch_size']
    gen_adv_criterion=config['gen_adv_criterion']
    scale=config['SI_scale'] if train_SI==True else config['IS_scale']

    ## Tensorboard ##
    writer=SummaryWriter(tensorboard_dir)

    # Generators/Discriminators #

    ## These are the original networks, and work great with 71x71 images ##
    #disc = Disc_I_Orig(config=config).to(device)
    #gen =  Gen_SI_Orig(config=config).to(device)

    ## These are the modified networks, for 90x90, and also work great ##
    #disc = Disc_I_Orig_90(config=config).to(device)
    #gen = Gen_SI_Orig_90(config=config).to(device)

    if train_SI==True:
        ## Now let's try a flex generator and Gen_SI_Orig_90 discriminator ##
        disc_adv_criterion=config['SI_disc_adv_criterion']
        disc = Disc_I_90(config=config, input_channels=image_channels).to(device)
        gen =  Gen_90(config=config, gen_SI=True, input_channels=sino_channels, output_channels=image_channels).to(device)
        gen_opt = torch.optim.Adam(gen.parameters(), lr=config['gen_lr'], betas=(config['gen_b1'], config['gen_b2'])) #betas are optional inputs
        disc_opt = torch.optim.Adam(disc.parameters(), lr=config['SI_disc_lr'], betas=(config['SI_disc_b1'], config['SI_disc_b2']))
    else:
        disc_adv_criterion=config['IS_disc_adv_criterion']
        disc = Disc_S_90(config=config, input_channels=sino_channels).to(device)
        gen =  Gen_90(config=config, gen_SI=False, input_channels=image_channels, output_channels=sino_channels).to(device)
        gen_opt = torch.optim.Adam(gen.parameters(), lr=config['gen_lr'], betas=(config['gen_b1'], config['gen_b2'])) #betas are optional inputs
        disc_opt = torch.optim.Adam(disc.parameters(), lr=config['IS_disc_lr'], betas=(config['IS_disc_b1'], config['IS_disc_b2']))

    ## Load Data ##
    dataloader = DataLoader(
        NpArrayDataSet(image_path=image_path, sino_path=sino_path, config=config, resize_size=resize_size, image_channels=image_channels, sino_channels=sino_channels, offset=True),
        batch_size=batch_size,
        shuffle=shuffle
    )

    ## Load Checkpoint ##
    if checkpoint_dir and load_state:
        # Load dictionary
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
        checkpoint = torch.load(checkpoint_path)
        # Load values from dictionary
        start_epoch = checkpoint['epoch'] #If interrupted, this epoch may be trained more than once
        end_epoch = start_epoch + num_epochs
        batch_step = checkpoint['batch_step']
        gen.load_state_dict(checkpoint['gen_state_dict'])
        gen_opt.load_state_dict(checkpoint['gen_opt_state_dict'])
        disc.load_state_dict(checkpoint['disc_state_dict'])
        disc_opt.load_state_dict(checkpoint['disc_opt_state_dict'])
    else:
        print('Starting from scratch')
        start_epoch=0
        end_epoch=num_epochs
        batch_step = 0
        gen = gen.apply(weights_init)
        disc = disc.apply(weights_init) # Both gen & disc inherit nn.Module functionality (.apply())

    ## Loop Over Epochs ##
    for epoch in range(start_epoch, end_epoch):
        pix_dist_real_array = np.array([]) # Reset every epoch
        mean_gen_loss = 0  # Reset every display step, but I define it here so it's available later
        mean_disc_loss = 0 # Reset every display step
        mean_pix_metric = 0  # Reset every display step
        time_init_full = time.time()

        ## Loop Over Batches ##
        for sino, sino_ground_scaled, image, image_ground_scaled in iter(dataloader): # Dataloader returns the batches. Loop over batches within epochs.

            print(f'FULL step (time): {(time.time()-time_init_full)*1000}')
            time_init_full = time.time()

            if train_SI==True:
                real=image_ground_scaled
                noise=sino_ground_scaled
            else:
                real=sino_ground_scaled
                noise=image_ground_scaled

            #print(f'Real Type: {real.dtype}, Real Shape:  {real.shape}')
            #print(f'Noise Type: {noise.dtype}, Noise Shape:  {noise.shape}')
            #cur_batch_size = len(real)

            ## UPDATE DISCRIMINATOR ##
            disc_opt.zero_grad()                    # Zero gradients before every batch #
            disc_real_pred = disc(real)             # Predictions on Real Images #

            with torch.no_grad(): # We won't be optimizing the generator here, so disabling gradients saves on resources
                fake = gen(noise)
            disc_fake_pred = disc(fake.detach())

            disc_real_loss = disc_adv_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
            disc_fake_loss = disc_adv_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))

            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            disc_loss.backward(retain_graph=True) # retain_graph=True is set so that we can perform gradient calculations using "backward" twice:
                                                  # you need to compute gradients of discriminator in order to obtain gradients of generator, later.
                                                  # Otherwise, for performance reasons, you can't do this.
            disc_opt.step()

            # Keep track of the average discriminator loss
            mean_disc_loss += disc_loss.item() / display_step

            ## UPDATE GENERATOR ##
            gen_opt.zero_grad()
            # Generator adversarial loss
            disc_fake_pred = disc(gen(noise))
            gen_loss = gen_adv_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
            # Update gradients
            gen_loss.backward()
            gen_opt.step()
            # Keep track of the average generator loss
            mean_gen_loss += gen_loss.item() / display_step #gen_loss.item() reduces tensor to scalar. It updates loss per display step

            ## PIXEL DISTANCE METRIC ##
            pix_dist_fake = pixel_dist(fake)
            pix_dist_real = pixel_dist(real)
            pix_dist_real_array = np.append(pix_dist_real_array, pix_dist_real)
            pix_dist_real_avg = np.mean(pix_dist_real_array)
            #pix_dist_real_avg = 0.00029 # determined experimentally
            pix_metric = abs((pix_dist_real_avg-pix_dist_fake)/pix_dist_real_avg)
            mean_pix_metric += pix_metric / display_step

            # Increment batch_step (without this, the display/report block below never runs) #
            batch_step += 1

            ## visualization CODE ##
            if batch_step % display_step == 0 and batch_step > 0: # runs if batch_step is a multiple of the display step

                # Calculate Individual Loss Terms #
                loss_balance=abs(mean_gen_loss-mean_disc_loss)
                r_metric= range_metric(real, fake)
                a_metric= avg_metric(real, fake)

                # Metric Loss #
                optim_metric=0.5*loss_balance+mean_pix_metric+a_metric #+r_metric

                ## REPORT AND SAVE STATE ##
                # Report #
                if run_mode=='tune':
                    tune.report(batch_step=batch_step, epoch=epoch,
                                mean_gen_loss=mean_gen_loss, mean_disc_loss=mean_disc_loss, loss_balance=loss_balance,
                                range_metric=r_metric, avg_metric = a_metric, mean_pix_metric=mean_pix_metric, optim_metric=optim_metric
                                )
                else:
                    # Display Stats #
                    print(f'===========================================\nEPOCH: {epoch}, STEP: {batch_step}')

                    print(f'Real Image Batch Min: {torch.min(real)} // Max: {torch.max(real)} // Mean: {torch.mean(real)} // Sum: {torch.sum(real).item()}')
                    print(f'Fake Image Batch Min: {torch.min(fake)} // Max: {torch.max(fake)} // Mean: {torch.mean(fake)} // Sum: {torch.sum(fake).item()}')
                    print(f'mean_gen_loss: {mean_gen_loss} // mean_disc_loss: {mean_disc_loss}')
                    print(f'loss_balance: {loss_balance}')
                    print(f'mean_pixel_metric: {mean_pix_metric}')
                    print(f'range_metric: {r_metric}')
                    print(f'avg_metric: {a_metric}')
                    print(f'optim_metric: {optim_metric}')

                    # visualize Images #
                    print('Reals: ')
                    show_single_unmatched_tensor(real)
                    print('Fakes: ')
                    show_single_unmatched_tensor(fake)

                    writer.add_scalar('generator loss', mean_gen_loss, batch_step)
                    writer.add_scalar('discriminator loss', mean_disc_loss, batch_step)
                    writer.add_scalar('loss balance', loss_balance, batch_step)
                    writer.add_scalar('pixel distance loss', mean_pix_metric, batch_step)
                    #writer.add_image("real", make_grid(real_image_tensor[:25], nrow=5, normalize=True)) # [:num_images]=[0:num_images]
                    #writer.add_image("fake", make_grid(fake_image_tensor[:25], nrow=5, normalize=True))
                    writer.flush()

                # Save State #
                if checkpoint_dir and save_state:
                    path = os.path.join(checkpoint_dir, checkpoint_file)
                    torch.save({
                        'epoch': epoch,
                        'batch_step': batch_step,
                        'gen_state_dict': gen.state_dict(),
                        'gen_opt_state_dict': gen_opt.state_dict(),
                        'disc_state_dict': disc.state_dict(),
                        'disc_opt_state_dict': disc_opt.state_dict(),
                        }, path)

                # Zero Stats #
                mean_disc_loss = 0
                mean_gen_loss = 0
                mean_pix_metric = 0

    ## At the end of the epoch loop, we do a final save of the model ##
    if checkpoint_dir and save_state:
        path = os.path.join(checkpoint_dir, checkpoint_file)
        torch.save({
            'epoch': epoch,
            'batch_step': batch_step,
            'gen_state_dict': gen.state_dict(),
            'gen_opt_state_dict': gen_opt.state_dict(),
            'disc_state_dict': disc.state_dict(),
            'disc_opt_state_dict': disc_opt.state_dict(),
            }, path)
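
The GAN loop above scores each batch by how far the generated pixel distance sits from a running average of the real pixel distance, then folds that into a single tuning metric. The sketch below reproduces only that arithmetic on plain floats. `RelativeDistanceMetric` and `combined_optim_metric` are hypothetical names, and the inputs stand in for whatever `pixel_dist` and `avg_metric` return in the notebook, which is not shown in this section.

```python
class RelativeDistanceMetric:
    """Relative gap between a fake-batch distance and the running mean of
    the real-batch distances seen so far."""

    def __init__(self):
        self.real_sum = 0.0
        self.real_count = 0

    def update(self, dist_real, dist_fake):
        """Fold in one real distance, then return |real_avg - fake| / real_avg."""
        self.real_sum += dist_real
        self.real_count += 1
        real_avg = self.real_sum / self.real_count
        return abs((real_avg - dist_fake) / real_avg)


def combined_optim_metric(mean_gen_loss, mean_disc_loss, mean_pix_metric, a_metric):
    """Mirrors optim_metric = 0.5*loss_balance + mean_pix_metric + a_metric."""
    loss_balance = abs(mean_gen_loss - mean_disc_loss)
    return 0.5 * loss_balance + mean_pix_metric + a_metric


# Usage with made-up numbers (not results from the notebook).
metric = RelativeDistanceMetric()
first = metric.update(dist_real=2.0, dist_fake=1.0)   # real_avg = 2.0
second = metric.update(dist_real=4.0, dist_fake=3.0)  # real_avg = 3.0
score = combined_optim_metric(1.0, 0.5, 0.25, 0.25)
```

Because the real average is recomputed from every batch seen so far in the epoch, early batches produce a noisier reference than later ones.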

In [None]:
## Note: This function still needs to be updated for SSIM and testing with the test set. See 'START HERE' comment below.

def train_test_CYCLE(config, checkpoint_dir=None, load_state=False, save_state=False):
    '''
    Note: Arguments are set to None/False to ensure that when RayTune calls train(), states are not saved/loaded. Saving states during tuning uses up way too much hard drive space.
    Note: you may want to use 'model.train()' to put model back into training mode if you put it into eval mode at some point...
    '''

    ## Grab Stuff from Config Dict. ##
    batch_size = config['batch_size']
    gen_b1 = config['gen_b1']
    gen_b2 = config['gen_b2']
    gen_lr = config['gen_lr']
    scale=config['SI_scale'] if train_SI==True else config['IS_scale']

    ## Tensorboard ##
    writer=SummaryWriter(tensorboard_dir)

    ## Initialize Generators/Discriminator/Summary Writer ##
    disc_I = Disc_I_90(config=config, input_channels=image_channels).to(device) # Image discriminator (same class as used in train_test_GAN)
    disc_S = Disc_S_90(config=config, input_channels=sino_channels).to(device)
    gen_SI = Gen_90(config=config, gen_SI=True, input_channels=sino_channels, output_channels=image_channels).to(device)
    gen_IS = Gen_90(config=config, gen_SI=False, input_channels=image_channels, output_channels=sino_channels).to(device)

    gen_both_opt = torch.optim.Adam(list(gen_SI.parameters()) + list(gen_IS.parameters()), lr=gen_lr, betas=(gen_b1, gen_b2)) # Common optimizer
    disc_I_opt = torch.optim.Adam(disc_I.parameters(), lr=config['SI_disc_lr'], betas=(config['SI_disc_b1'], config['SI_disc_b2']))
    disc_S_opt = torch.optim.Adam(disc_S.parameters(), lr=config['IS_disc_lr'], betas=(config['IS_disc_b1'], config['IS_disc_b2']))

    ## Load Data ##
    dataloader = DataLoader(
        NpArrayDataSet(image_path=image_path, sino_path=sino_path, config=config, resize_size=resize_size, image_channels=image_channels, sino_channels=sino_channels),
        batch_size=batch_size,
        shuffle=shuffle
    )

    ## Load Checkpoint ##
    if checkpoint_dir and load_state:
        # Load dictionary
        checkpoint = torch.load(os.path.join(checkpoint_dir, checkpoint_file))
        # Load values from dictionary
        start_epoch = checkpoint['epoch'] #If interrupted, this epoch may be trained more than once
        end_epoch = start_epoch + num_epochs
        batch_step = checkpoint['batch_step']
        gen_SI.load_state_dict(checkpoint['gen_SI_state_dict'])
        gen_IS.load_state_dict(checkpoint['gen_IS_state_dict'])
        gen_both_opt.load_state_dict(checkpoint['gen_both_opt_state_dict'])
        disc_I.load_state_dict(checkpoint['disc_I_state_dict'])
        disc_S.load_state_dict(checkpoint['disc_S_state_dict'])
        disc_I_opt.load_state_dict(checkpoint['disc_I_opt_state_dict'])
        disc_S_opt.load_state_dict(checkpoint['disc_S_opt_state_dict'])
        if run_mode=='test':
            gen_SI.eval()
            gen_IS.eval()

    ## START HERE WITH UPDATING THIS FUNCTION FOR SSIM AND TEST SET FUNCTIONALITY

    else:
        print('Starting from scratch')
        start_epoch=0
        end_epoch=num_epochs
        batch_step = 0
        gen_SI = gen_SI.apply(weights_init)
        gen_IS = gen_IS.apply(weights_init)
        disc_I = disc_I.apply(weights_init)
        disc_S = disc_S.apply(weights_init)

    ## Loop Over Epochs ##
    for epoch in range(start_epoch, end_epoch):

        # Following variables reset every display step. The line below only establishes these variables, it does not reset them.
        mean_disc_loss, mean_adv_loss, mean_sup_loss, mean_cycle_loss, mean_pix_metric, mean_range_metric, mean_avg_metric = 0,0,0,0,0,0,0

        ## Loop Over Batches ##

        time_init_full = time.time()
        #time_init_loader = time.time()

        for sino, sino_ground_scaled, image, image_ground_scaled in iter(dataloader): # Dataloader returns the batches. Loop over batches within epochs.

            #print(f'iter dataloader (time): {(time.time()-time_init_loader)*1000}')
            #print(f'FULL step (time): {(time.time()-time_init_full)*1000}')
            time_init_full = time.time()

            real_S = sino_ground_scaled
            real_I = image_ground_scaled

            ## Update Networks ##

            # Update Discriminators #
            # Image Discriminator #
            disc_I_opt.zero_grad() # Zero out the gradient before backpropagation
            with torch.no_grad(): # We won't be optimizing the generator here, so disabling gradients saves on resources
                fake_I = gen_SI(real_S)

            disc_I_loss = get_disc_loss(fake_I, real_I, disc_I, config['SI_disc_adv_criterion'])
            disc_I_loss.backward(retain_graph=True) # Compute gradients
            disc_I_opt.step() # Update discriminator weights

            # Sinogram Discriminator #
            disc_S_opt.zero_grad() # Zero out the gradient before backpropagation
            with torch.no_grad(): # We won't be optimizing the generator here, so disabling gradients saves on resources
                fake_S = gen_IS(real_I)
            disc_S_loss = get_disc_loss(fake_S, real_S, disc_S, config['IS_disc_adv_criterion'])
            disc_S_loss.backward(retain_graph=True) # Compute gradients
            disc_S_opt.step() # Update discriminator weights

            # Generators #
            gen_both_opt.zero_grad()
            gen_loss, adv_loss, sup_loss, cycle_loss, cycle_I, cycle_S = get_gen_loss(real_I, real_S, gen_IS, gen_SI, disc_I, disc_S, config)
            gen_loss.backward() # Compute gradients
            gen_both_opt.step() # Update weights of both generators

            #print(f'update generator (time)): {(time.time()-time_init_gen)*1000}')

            ## Metrics ##
            # Pixel Distance #
            pix_metric_I = pixel_metric(real_I, fake_I)
            pix_metric_S = pixel_metric(real_S, fake_S)
            p_metric = pix_metric_I + pix_metric_S

            # Range Metric #
            range_metric_I = range_metric(real_I, fake_I)
            range_metric_S = range_metric(real_S, fake_S)
            r_metric = range_metric_I+range_metric_S

            # Average Metric #
            avg_metric_I = avg_metric(real_I, fake_I)
            avg_metric_S = avg_metric(real_S, fake_S)
            a_metric = avg_metric_I + avg_metric_S

            ## Running Statistics ##
            # Mean loss terms #
            mean_disc_loss    += (abs(disc_I_loss.item()) + abs(disc_S_loss.item())) / display_step
            mean_adv_loss     += abs(adv_loss) / display_step
            mean_sup_loss     += abs(sup_loss) / display_step
            mean_cycle_loss   += abs(cycle_loss) / display_step
            mean_pix_metric   += p_metric / display_step
            mean_range_metric += r_metric / display_step
            mean_avg_metric   += a_metric / display_step

            ## visualization CODE ##
            if batch_step % display_step == 1 and batch_step > 0: # runs one step after each multiple of the display step (remainder of 1)

                # Optim_Metric #
                MS_Error = MSE(real_I, fake_I)
                loss_balance=abs(mean_adv_loss-mean_disc_loss)
                #optim_metric = 0.5*loss_balance+mean_cycle_loss+mean_pix_metric #+mean_avg_metric #+mean_range_metric
                optim_metric = MS_Error

                # Prune #
                #gen_SI = prune_gen(gen_SI)
                #gen_IS = prune_gen(gen_IS)

                ## Report  to Ray Tune ##
                if run_mode=='tune':
                    tune.report(batch_step=batch_step, epoch=epoch,
                                mean_adv_loss=mean_adv_loss, mean_disc_loss=mean_disc_loss, loss_balance=loss_balance,
                                mean_sup_loss=mean_sup_loss,
                                mean_cycle_loss=mean_cycle_loss,
                                mean_pix_metric=mean_pix_metric,
                                mean_avg_metric=mean_avg_metric,
                                optim_metric=optim_metric
                                )
                ## Display Stats & Images ##
                else:
                    print(f'================================================================================\nEPOCH: {epoch}, STEP: {batch_step}, Batch Size: {batch_size}')

                    lambda_adv, lambda_sup, lambda_cycle = config['lambda_adv'], config['lambda_sup'], config['lambda_cycle']

                    print(f'MSE (Images):  {MS_Error}')
                    print(f'lambda * Mean Adversarial Loss: {lambda_adv*mean_adv_loss}')
                    print(f'lambda * Mean Supervisory Loss: {lambda_sup*mean_sup_loss}')
                    print(f'lambda * Mean Cycle Loss      : {lambda_cycle*mean_cycle_loss}')
                    print(f'mean_disc_loss: {mean_disc_loss} // mean_adv_loss: {mean_adv_loss} // loss_balance (M) {loss_balance}')
                    print(f'mean_pix_metric (M): {mean_pix_metric}')
                    print(f'range_metric (M): {mean_range_metric}')
                    print(f'avg_metric: {mean_avg_metric}')
                    print(f'optim_metric: {optim_metric}')

                    ## visualize Images ##
                    # Images #
                    print('Ground Truth Images:')
                    show_single_unmatched_tensor(real_I)
                    print('Generated PET Images:')
                    show_single_unmatched_tensor(fake_I)
                    print('Cycle PET Images:')
                    show_single_unmatched_tensor(cycle_I)

                    # Sinograms #
                    print('Ground Truth Sinograms:')
                    show_single_unmatched_tensor(real_S) # low_rez_S = real
                    print('Generated Sinograms:')
                    show_single_unmatched_tensor(fake_S)
                    print('Cycle Sinograms:')
                    show_single_unmatched_tensor(cycle_S)

                    # Less interesting #
                    '''
                    print('Resized Model Images:')
                    show_single_unmatched_tensor(resized_I[0:9])
                    print('FBP, Full-Rez Sinograms, resized (90x90):')
                    show_single_unmatched_tensor(FBP_I[0:9])

                    print('Hi-Rez Sinograms:')
                    show_single_unmatched_tensor(high_rez_S)
                    print('Sinogram of Ground Truth Images:')
                    show_single_unmatched_tensor(project(ground_I))
                    print('Sinogram of Generated PET:')
                    show_single_unmatched_tensor(project(fake_I))
                    print('FBP, Low-Rez Sinograms:')
                    show_single_unmatched_tensor(reconstruct(low_rez_S))
                    '''

                    writer.add_scalar('mean adversarial loss', mean_adv_loss, batch_step)
                    writer.add_scalar('discriminator loss', mean_disc_loss, batch_step)
                    writer.add_scalar('loss balance', loss_balance, batch_step)
                    writer.add_scalar('pixel distance loss', mean_pix_metric, batch_step)
                    writer.add_scalar('cycle loss', mean_cycle_loss, batch_step)
                    writer.add_scalar('supervisory loss (ground)', mean_sup_loss, batch_step)
                    writer.flush()

                # Save State #
                if checkpoint_dir and save_state:
                    path = os.path.join(checkpoint_dir, checkpoint_file)
                    torch.save({
                        'epoch': epoch,
                        'batch_step': batch_step,
                        'gen_SI_state_dict': gen_SI.state_dict(),
                        'gen_IS_state_dict': gen_IS.state_dict(),
                        'gen_both_opt_state_dict': gen_both_opt.state_dict(),
                        'disc_I_state_dict': disc_I.state_dict(),
                        'disc_S_state_dict': disc_S.state_dict(),
                        'disc_I_opt_state_dict': disc_I_opt.state_dict(),
                        'disc_S_opt_state_dict': disc_S_opt.state_dict(),
                        }, path)

                # Zero Stats #
                mean_adv_loss = 0  # Should balance with mean_disc_loss (below)
                mean_disc_loss = 0 # Should balance with mean_adv_loss (above)
                mean_sup_loss = 0 # Accumulated above, so it must be reset with the other running means
                mean_sup_loss_model = 0
                mean_sup_loss_ground = 0 #
                mean_cycle_loss = 0 # Better performing models will minimize this
                mean_pix_metric = 0 # Reasonable to minimize this for tuning purposes
                mean_range_metric=0
                mean_avg_metric=0

            batch_step += 1 #updates with every batch


            time_init_loader=time.time()
#call model.eval() before test set
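
The running means in the loop above add `value / display_step` on every batch and are zeroed after each report, so each reported value is an average over one window of batches. A minimal sketch of that bookkeeping follows. `RunningMeans` is a hypothetical helper, not part of this notebook; keeping all accumulators in one dictionary means none can be missed at reset time.

```python
class RunningMeans:
    """Hypothetical helper: averages each named value over a window of `display_step` updates."""

    def __init__(self, names, display_step):
        self.display_step = display_step
        self.totals = {name: 0.0 for name in names}

    def update(self, **values):
        # Same arithmetic as `mean_x += x / display_step` in the training loop.
        for name, value in values.items():
            self.totals[name] += value / self.display_step

    def report_and_reset(self):
        # Return the window averages, then zero every accumulator in one place.
        snapshot = dict(self.totals)
        for name in self.totals:
            self.totals[name] = 0.0
        return snapshot


means = RunningMeans(['sup_loss', 'pix_metric'], display_step=4)
for loss in [1.0, 2.0, 3.0, 2.0]:          # made-up per-batch losses
    means.update(sup_loss=loss, pix_metric=loss * 0.5)
window = means.report_and_reset()
print(window)
```

The averages are only meaningful if a report happens after exactly `display_step` updates; a shorter first window would be divided by the full `display_step` and so under-reported.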


## Tune Function

In [None]:
def Tune(tune_max_t=40, trainable='SUP', grace_period=1):
    '''
    This function is called to tune the "trainable" function, given:

    tune_max_t:     maximum number of time units (in this case, number of reports) per trial.
    trainable:      which training function to tune: 'SUP', 'GAN', or 'CYCLE'.
    grace_period:   minimum number of Ray Tune reports to run before aborting a trial due to poor performance.
    '''

    ## What am I tuning for? ##
    if tune_for=='MSE':     # Values for these metric labels are passed to RayTune in the training function: session.report(...)
        optim_metric='MSE'
        min_max='min' # minimize MSE
    elif tune_for=='SSIM':
        optim_metric='SSIM'
        min_max='max' # maximize SSIM
    elif tune_for=='CUSTOM':
        optim_metric='CUSTOM'
        min_max='min'
    else:
        raise ValueError(f"Unknown tune_for: {tune_for!r} (expected 'MSE', 'SSIM', or 'CUSTOM')")

    print('===================')
    print('tune_max_t:', tune_max_t)
    print('optim_metric:',optim_metric)
    print('min_max:', min_max)
    print('grace_period:', grace_period)
    print('tune_minutes', tune_minutes)  # Set in "User Parameters".
    print('===================')

    ## Reporters ##
    reporter1 = CLIReporter( # This reporter currently isn't used.
        metric_columns=[optim_metric,'batch_step'])

    reporter = JupyterNotebookReporter(
        overwrite=True,                                           # Overwrite subsequent reporter tables in output (so there is no scrolling)
        metric_columns=[optim_metric,'batch_step','example_num'], # Values for both 'batch_step' and 'example_num' are passed to RayTune
        metric=optim_metric,                                      # Which metric is used to determine best trial? (a string, not a list)
        mode=min_max,                                             # 'min' or 'max'; used with 'metric' when sorting the table
        sort_by_metric=True,                                      # Order reporter table by metric
    )

    ## Trial Scheduler and Run Config ##
    if tune_scheduler == 'ASHA':
        scheduler = ASHAScheduler(
            time_attr='training_iteration', # "Time" is measured in training iterations. 'training_iteration' is a RayTune keyword (not passed in session.report(...)).
            max_t=tune_max_t, # (default=40). Maximum time units per trial (units = time_attr). Note: Ray Tune will by default run a maximum of 100 display steps (reports) per trial
            metric=optim_metric, # This is the label in a dictionary passed to RayTune (in session.report(...))
            mode=min_max,
            grace_period=grace_period, # Train for a minimum number of time_attr units. Set in Tune() arguments.
            #reduction_factor=2
            )
        run_config=air.RunConfig(       # How to perform the run
            name=tune_exp_name,         # Ray checkpoints saved to this file, relative to local_dir. Set in "User Parameters"
            storage_path=local_dir,     # Local directory. Set in "User Parameters"
            progress_reporter=reporter, # Specified above
            failure_config=air.FailureConfig(fail_fast=False), # default = False. Keeps running if there is an error.
            checkpoint_config=air.CheckpointConfig(
                num_to_keep=10,         # Maximum number of checkpoints that are kept per run (for each trial)
                checkpoint_score_attribute=optim_metric,  # Determines which checkpoints are kept on disk.
                checkpoint_score_order=min_max
                )
            )
    else:
        scheduler = FIFOScheduler()     # First in/first out scheduler
        run_config=train.RunConfig(
            stop={'training_iteration': tune_max_t}, # When using the FIFO scheduler, we must explicitly specify the stopping criterion.
            name=tune_exp_name,         # Ray checkpoints saved to this file, relative to local_dir
            storage_path=local_dir,     # Local directory
            progress_reporter=reporter,
            failure_config=air.FailureConfig(fail_fast=False), # default = False
            checkpoint_config=air.CheckpointConfig(
                num_to_keep=10,         # Maximum number of checkpoints that are kept per run.
                checkpoint_score_attribute=optim_metric,  # Determines which checkpoints are kept on disk.
                checkpoint_score_order=min_max)
        )
        '''
        run_config=train.RunConfig(       # How to perform the run
            name=tune_exp_name,              # Ray checkpoints saved to this file, relative to local_dir
            storage_path=local_dir,     # Local directory
            progress_reporter=reporter,
            failure_config=air.FailureConfig(fail_fast=False), # default = False
            checkpoint_config=air.CheckpointConfig(
                num_to_keep=10,         # Maximum number of checkpoints that are kept per run.
                checkpoint_score_attribute=optim_metric,  # Determines which checkpoints are kept on disk.
                checkpoint_score_order=min_max,
                stop={"time_total_s": 5})
            #    stop={"training_iteration": tune_max_t}) # The FIFO scheduler does not have a stopping criterion, so this stops the trial.
            )
        '''

    ## HyperOpt Search Algorithm ##
    search_alg = HyperOptSearch(metric=optim_metric, mode=min_max)  # It's also possible to pass the search space directly to the search algorithm here.
                                                                    # But then the search space needs to be defined in terms of the specific search algorithm methods, rather than letting RayTune translate.

    ## Which trainable do you want to use? ##
    if trainable=='SUP':
        trainable_with_resources = tune.with_resources(train_Supervisory_Sym, {"CPU":num_CPUs,"GPU":num_GPUs}) # train_Supervisory_Sym is a function of the config dictionary, but we don't state that explicitly.
    elif trainable=='GAN':
        trainable_with_resources = tune.with_resources(train_test_GAN, {"CPU":num_CPUs,"GPU":num_GPUs})
    elif trainable=='CYCLE':
        trainable_with_resources = tune.with_resources(train_test_CYCLE, {"CPU":num_CPUs,"GPU":num_GPUs})
    else:
        raise ValueError(f"Unknown trainable: {trainable!r} (expected 'SUP', 'GAN', or 'CYCLE')")

    ## If starting from scratch ##
    if tune_restore==False:

        # Initialize a blank tuner object
        tuner = tune.Tuner(
                trainable_with_resources,       # The objective function w/ resources
                param_space=config,             # Let RayTune know what parameter space (dictionary) to search over.
                tune_config=tune.TuneConfig(    # How to perform the search
                    num_samples=-1,
                    time_budget_s=tune_minutes*60, # time_budget is in seconds
                    scheduler=scheduler,
                    search_alg=search_alg,
                    ),
                run_config=run_config
                )

    ## If loading from a checkpoint ##
    else:
        # Load the tuner
        tuner = tune.Tuner.restore(
            path=os.path.join(local_dir, tune_exp_name), # Path where previous run is checkpointed
            trainable=trainable_with_resources,
            resume_unfinished = False
            )

    result_grid: ResultGrid = tuner.fit()
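
To get a feel for how `grace_period` and `tune_max_t` interact under ASHA, the sketch below lists the report counts at which a trial could be compared against its peers and stopped early. It assumes the usual successive-halving rule (rungs at `grace_period * reduction_factor**k`, none beyond `max_t`) and a reduction factor of 4, which is the documented default for `ASHAScheduler`. `asha_rungs` is a hypothetical helper, and Ray's actual rung placement may differ between versions.

```python
def asha_rungs(grace_period=1, max_t=40, reduction_factor=4):
    """Hypothetical helper: report counts at which successive halving compares trials."""
    if grace_period < 1 or reduction_factor <= 1:
        raise ValueError('grace_period must be >= 1 and reduction_factor > 1')
    rungs = []
    t = grace_period
    while t <= max_t:
        rungs.append(t)          # trials reaching this many reports are ranked against each other
        t *= reduction_factor    # the next rung is reduction_factor times further along
    return rungs


default_rungs = asha_rungs(grace_period=1, max_t=40)   # the defaults of Tune() above
patient_rungs = asha_rungs(grace_period=5, max_t=40)
print(default_rungs)
print(patient_rungs)
```

Under these assumptions, a larger `grace_period` gives slow-starting configurations more reports before their first comparison, at the cost of spending more of the time budget on poor trials.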


## Test CNN (by chunks)

In [None]:
def test_by_chunks(test_begin_at=0, test_chunk_size=5000, testset_size = 35000, sample_division=1, part_name='batch_dataframe_part_',
         test_merge_dataframes=False, test_csv_file='combined_dataframe'):
    '''
    Splits up testing the CNN (on a test set) into smaller chunks so that computer time-outs don't result in lost work.

    test_begin_at:      Where to begin the testing. You set this to >0 if the test terminates early and you need to pick up partway through the test set.
    test_chunk_size:    How many examples to test in each chunk
    testset_size:       Number of examples that you wish to test. This can be less than the number of examples in the dataset file but not more.
    sample_division:    To test every example, set to 1. To test every other example, set to 2, and so forth.
    part_name:          Root of the dataframe part filenames (containing testing results) that will be saved. A number is appended to each when saved.
    test_merge_dataframes:  Set to True to merge the smaller parts dataframes into a larger dataframe once the smaller parts have finished calculating.
                            Otherwise, you can use the merge_test_chunks function below at a later time.
    test_csv_file:      Filename for the merged dataframe (passed to merge_test_chunks). Only used if test_merge_dataframes is True.
    '''

    label_num = test_begin_at // test_chunk_size # Which numbered dataframe parts file you start at.

    for index in range(test_begin_at, testset_size, test_chunk_size):

        save_filename = part_name+str(label_num)+'.csv'

        print('###############################################')
        print('################# Working on:', save_filename)
        print('################# Starting at example: ', index)
        print('###############################################')

        # Since run_mode=='test', the training function returns a test dataframe. #
        chunk_dataframe = train_Supervisory_Sym(config, offset=index, num_examples=test_chunk_size, sample_division=sample_division)
        chunk_dataframe_path = os.path.join(test_dataframe_dir, save_filename)
        chunk_dataframe.to_csv(chunk_dataframe_path, index=False)
        label_num += 1

    if test_merge_dataframes:
        max_index = -(-testset_size // test_chunk_size) - 1 # Ceiling division, so a final partial chunk is also merged
        merge_test_chunks(max_index, part_name=part_name, test_csv_file=test_csv_file)


def merge_test_chunks(max_index, part_name='batch_dataframe_part_', test_csv_file='combined_dataframe'):
    '''
    Function for merging smaller dataframes (which contain metrics for individual images) into a single larger dataframe.

    max_index:      index of the last part file. Parts 0 through max_index are merged.
    part_name:      root of part filenames (not including the numbers appended to the end)
    test_csv_file:  filename (without the .csv extension) for the combined dataframe
    '''

    ## Build list of filenames ##
    names = [part_name+str(i)+'.csv' for i in range(0, max_index+1)]

    ## Concatenate parts dataframes ##
    frames = []
    for name in names:
        add_path = os.path.join(test_dataframe_dir, name)
        print('Concatenating: ', add_path)
        frames.append(pd.read_csv(add_path))
    test_dataframe = pd.concat(frames, axis=0, ignore_index=True)

    ## Save Result ##
    test_dataframe_path = os.path.join(test_dataframe_dir, test_csv_file+'.csv') # Use the filename passed in, as documented above
    test_dataframe.to_csv(test_dataframe_path, index=False)
    print(test_dataframe.describe()) # A bare describe() inside a function displays nothing

#merge_test_chunks(34)
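
The chunking arithmetic in `test_by_chunks` can be checked without running the network. The sketch below lists the part filename and starting offset for each chunk. `chunk_plan` is a hypothetical helper, not a function in this notebook. It also shows that when `testset_size` is not a multiple of `test_chunk_size`, the loop still starts one last chunk inside the test set, so the number of part files is the ceiling of the ratio.

```python
def chunk_plan(test_begin_at=0, test_chunk_size=5000, testset_size=35000,
               part_name='batch_dataframe_part_'):
    """Hypothetical helper: (filename, offset) for every chunk test_by_chunks would run."""
    label_num = test_begin_at // test_chunk_size   # first part number, as in test_by_chunks
    plan = []
    for offset in range(test_begin_at, testset_size, test_chunk_size):
        plan.append((f'{part_name}{label_num}.csv', offset))
        label_num += 1
    return plan


full = chunk_plan()                                   # defaults from test_by_chunks
resumed = chunk_plan(test_begin_at=10000)             # picking up after an interrupted run
uneven = chunk_plan(test_chunk_size=10000)            # 35000 is not a multiple of 10000
print(len(full), full[-1])
print(resumed[0])
print(len(uneven), uneven[-1])
```

Resuming only lines up with the existing part files when `test_begin_at` is a multiple of `test_chunk_size`; otherwise the computed label would overwrite or skip a part.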

# Run

In [None]:
if run_mode=='tune':
    if train_type=="SUP":
        print('Tuning w/ Supervisory Only!')
        time.sleep(3)
        Tune(tune_max_t=tune_max_t, trainable='SUP', grace_period=1) # for 90-90, tune_max_t=35 | for 180-71, tune_max_t=25 | for LDM, tune_max_t=25
    if train_type=='GAN':
        print('Tuning a GAN!')
        time.sleep(3)
        Tune(tune_max_t=tune_max_t, trainable='GAN', grace_period=1)
    if train_type=='CYCLESUP' or train_type=='CYCLEGAN':
        print('Tuning a Cycle!')
        time.sleep(3)
        Tune(tune_max_t=tune_max_t, trainable='CYCLE', grace_period=1)
elif (run_mode=='train') or (run_mode=='visualize'):
    if train_type=="SUP":
        train_Supervisory_Sym(config=config, offset=offset, num_examples=-1, sample_division=sample_division)
    if train_type=='GAN':
        train_test_GAN(config=config)
    if train_type=='CYCLESUP' or train_type=='CYCLEGAN':
        train_test_CYCLE(config=config)
elif run_mode=='test':
    test_by_chunks(test_begin_at=test_begin_at, test_chunk_size=test_chunk_size, testset_size=testset_size, sample_division=sample_division, part_name='batch_dataframe_part_', test_merge_dataframes=test_merge_dataframes, test_csv_file=test_csv_file)




# Analysis Functions

## Plot: Tuning Curves

In [None]:
def PlotFrame(experiment_path, ax, x_ticks, x_label, y_ticks, y_label, xlim=None, ylim=None, logy=False, max_plot_num=-1):
    '''
    This function plots the metric dataframes for each trial in a tuning experiment.

    experiment_path:    path to the experiment directory
    ax:                 Matplotlib axis object on which to plot the dataframes
    x_ticks:            name of the dataframe column to plot on the x-axis
    x_label:            x-axis title
    y_ticks:            name of the dataframe column to plot on the y-axis
    y_label:            y-axis title
    xlim:               (lower, upper) limits for the x-axis. Set to None to set no limits.
    ylim:               (lower, upper) limits for the y-axis. Set to None to set no limits.
    logy:               use a logarithmic scale for the y-axis?
    max_plot_num:       maximum number of dataframes to plot. Set to -1 to plot all dataframes.
    '''
    restored_tuner = tune.Tuner.restore(experiment_path,
                                        trainable = tune.with_resources(train_Supervisory_Sym, {"CPU":4,"GPU":1}))
    result_grid = restored_tuner.get_results()

    for i, result in enumerate(result_grid):
        if max_plot_num >= 0 and i >= max_plot_num: # Stop once max_plot_num dataframes have been plotted
            break
        #print(i)
        #label = f"lr={result.config['lr']:.3f}, momentum={result.config['momentum']}"
        try: # Keeps plotting even if there is an error with one of the plots
            result.metrics_dataframe.plot(x=x_ticks, y=y_ticks, ax=ax, label='test', legend=False, xlim=xlim, ylim=ylim,
                                          logy=logy, fontsize=ticksize) # 'ticksize' is a variable set outside of the function (see below)
        except Exception as error:
            print('Error Plotting:', error)
    ax.set_ylabel(y_label, fontsize=fontsize) # 'fontsize' is a variable set outside of the function (see below)
    ax.set_xlabel(x_label, fontsize=fontsize)

    return result_grid


#####################
## Plot Appearance ##
#####################

## Paths ##
#tune_exp_name='search-Full-tunedMSE-SPIE'
#tune_exp_name='search-Full-tunedLDM_w5s2_meanWeighted'
tune_exp_name='search-Full-tunedLDM_w5s5_evenWeighted'
#tune_exp_name='search-Full-tunedMSE-AHSA_scheduler'
#tune_exp_name='search-Quartile-lowSSIM-tunedSSIM-D'

plot_save_name='figure-tuning'    # Save tuning plot to this filename (do not include extension)
plot_dir= '/content/drive/MyDrive/Colab/Working/Plots/'

## Defaults ##
local_dir='/content/drive/MyDrive/Colab/Working/'
experiment_path = f"{local_dir}{tune_exp_name}"

## Figure ##
'''
titlesize=14
fontsize=10
ticksize=8 # font for ticks. Set to None for default
dpi=800
fig_size=(10,2) # Figure Size
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=fig_size, dpi=dpi)
ax1 = axs[0] ; ax2 = axs[1]
'''
titlesize=13
fontsize=12
ticksize=10
dpi=800
figsize=(10,8)

fig = plt.figure(figsize=figsize, dpi=dpi)
gs = gridspec.GridSpec(ncols=100, nrows=100)

# Three Stacked Axes (top, middle, bottom) #
ax1 = fig.add_subplot(gs[0:25,   0:100])
ax2 = fig.add_subplot(gs[38:62,   0:100])
ax3 = fig.add_subplot(gs[75:100,  0:100])

###########
## Plots ##
###########

#result_grid = PlotFrame(experiment_path, axs[0], 'example_number', 'Example Number', 'MSE', 'MSE', ylim=ylim_MSE, logy=True)
result_grid = PlotFrame(experiment_path, ax1, 'batch_step', 'Batch Step', 'MSE', 'MSE', ylim=(4,20), logy=True)
ax1.set_title('(A) MSE Learning Curves', fontsize=titlesize)

result_grid = PlotFrame(experiment_path, ax2, 'batch_step', 'Batch Step', 'SSIM', 'SSIM', ylim=(0,0.8), logy=False)
ax2.set_title('(B) SSIM Learning Curves', fontsize=titlesize)

result_grid = PlotFrame(experiment_path, ax3, 'batch_step', 'Batch Step', 'CUSTOM', 'Local Distributions Metric', ylim=(300,500))
ax3.set_title('(C) LDM Learning Curves', fontsize=titlesize)


#save_path = plot_dir+plot_save_name+'.svg'
#savefig(save_path, bbox_inches='tight')

##########################
## Pick out Best Result ##
##########################

logdir = result_grid.get_best_result("SSIM", mode="max")
print('##################')
print('## Best Result! ##')
print('##################')
print(logdir)
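
`get_best_result` needs both a metric name and a mode because "best" means lowest for MSE but highest for SSIM. The sketch below mimics that selection on plain dictionaries. `pick_best` and the trial values are made up for illustration and do not come from any run of this notebook.

```python
def pick_best(results, metric, mode):
    """Hypothetical stand-in for a best-trial lookup over a list of metric dicts."""
    if mode not in ('min', 'max'):
        raise ValueError("mode must be 'min' or 'max'")
    scored = [r for r in results if metric in r]   # skip trials that never reported the metric
    if not scored:
        raise ValueError(f'no result reported {metric!r}')
    chooser = min if mode == 'min' else max
    return chooser(scored, key=lambda r: r[metric])


# Illustrative numbers only.
trials = [
    {'trial': 'a', 'MSE': 0.42, 'SSIM': 0.81},
    {'trial': 'b', 'MSE': 0.35, 'SSIM': 0.78},
    {'trial': 'c', 'MSE': 0.51, 'SSIM': 0.86},
]
best_mse = pick_best(trials, 'MSE', 'min')['trial']
best_ssim = pick_best(trials, 'SSIM', 'max')['trial']
print(best_mse, best_ssim)
```

As the example shows, the best trial by one metric need not be the best by another, so the metric passed to `get_best_result` should match the one the experiment was tuned for.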

## Plot: Tuning Stats

In [None]:
tune_dataframe_dir= '/content/drive/MyDrive/Colab/Working/Dataframes-Tune-Full'
tune_csv_file = 'frame-tunedMSE-ASHA'

tune_dataframe_path = os.path.join(tune_dataframe_dir, tune_csv_file+'.csv')
tune_dataframe = pd.read_csv(tune_dataframe_path)

## Describe Dataframes ##

#plt.scatter(tune_dataframe['num_params'], tune_dataframe['mean_CNN_MSE'])
#plt.scatter(tune_dataframe['num_params'][1:], tune_dataframe['mean_CNN_MSE'][1:])

tune_dataframe.plot.scatter('num_params', 'mean_CNN_MSE', ylim=(0,5))
tune_dataframe.plot.scatter('gen_lr', 'mean_CNN_MSE', ylim=(0,5))
tune_dataframe.plot.scatter('batch_size', 'mean_CNN_MSE', ylim=(0,5))

'''
plt.scatter(tune_dataframe['num_params'], tune_dataframe['mean_CNN_MSE'], ylim=(0,1))
plt.xlabel('Number of Parameters')
plt.ylabel('MSE')
plt.show()
'''

tune_dataframe.describe()

## Load: Test Dataframes


In [None]:
# tunedMSE #
test_dataframe_dir1= '/content/drive/MyDrive/Colab/Working/Dataframes-TestOnFull'
#test_csv_file1 = 'combined-tunedFullMSE-trainedFull-onTrainingSet-noMLEM'   # Use this dataframe to determine thresholds for sorting training set by metrics
#test_csv_file1 = 'combined-tunedFullMSE-trainedFull-onTestSet-wMLEM'       # Use this dataframe to determine thresholds for sorting test set by metrics
#test_csv_file1 = 'combined-tunedFullSSIM-trainedFull-onTestSet-wMLEM'
test_csv_file1 = 'combined-tunedHighMSE-trainedHighMSE-onTestSet-wMLEM'

#test_dataframe_dir2= '/content/drive/MyDrive/Colab/Working/Dataframes-Test-Quartile-MSE'
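test_dataframe_dir2= '/content/drive/MyDrive/Colab/Working/Dataframes-TestOnFull' # Must be defined: used to build dataframe_path2 below. Change if your second dataframe lives elsewhere.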
#test_csv_file2 = 'combined-tunedFullSSIM-trainedFull-onTestSet-wMLEM'
#test_csv_file2 = 'combined-tunedHighMSE-trainedHighMSE-onTestSet-wMLEM'
test_csv_file2 = 'combined-tunedLowSSIM-trainedLowSSIM-onTestSet-wMLEM'

# Read Dataframes from File #
dataframe_path1 = os.path.join(test_dataframe_dir1, test_csv_file1+'.csv')
dataframe1 = pd.read_csv(dataframe_path1)
dataframe_path2 = os.path.join(test_dataframe_dir2, test_csv_file2+'.csv')
dataframe2 = pd.read_csv(dataframe_path2)

## Describe Dataframes ##

#frame_picked = dataframe[dataframe["SSIM (ML-EM)"]>dataframe["SSIM (FBP)"]]
#frame_picked = dataframe[dataframe["SSIM (Network)"]>dataframe["SSIM (ML-EM)"]]

#frame_picked = dataframe[dataframe["MSE (Network)"]<dataframe["MSE (ML-EM)"]]
#frame_picked = dataframe[dataframe["MSE (ML-EM)"]<dataframe["MSE (FBP)"]]

#frame_picked = dataframe1[dataframe1["MSE (FBP)"]>0.95908]
#frame_picked = dataframe1[dataframe1["MSE (FBP)"]<0.330922]
frame_picked = dataframe1[dataframe1["SSIM (FBP)"]<0.837850]

#dataframe1.describe()
dataframe2.describe()
#frame_picked.describe()
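
The commented-out `frame_picked` filters above compare reconstruction methods row by row. A small, dependency-free sketch of the same comparison follows, counting how often the network beats ML-EM. The CSV rows are synthetic and `fraction_better` is a hypothetical helper; only the column names follow those used in this notebook's dataframes.

```python
import csv
import io


def fraction_better(rows, candidate_column, reference_column, lower_is_better=True):
    """Hypothetical helper: share of rows where the candidate column beats the reference column."""
    wins = 0
    total = 0
    for row in rows:
        a, b = float(row[candidate_column]), float(row[reference_column])
        wins += (a < b) if lower_is_better else (a > b)   # ties count as "not better"
        total += 1
    return wins / total if total else float('nan')


# Synthetic rows for illustration only.
sample_csv = io.StringIO(
    'MSE (Network),MSE (ML-EM),SSIM (Network),SSIM (ML-EM)\n'
    '0.20,0.35,0.91,0.88\n'
    '0.40,0.30,0.85,0.90\n'
    '0.10,0.25,0.95,0.89\n'
    '0.33,0.50,0.87,0.87\n'
)
rows = list(csv.DictReader(sample_csv))
mse_share = fraction_better(rows, 'MSE (Network)', 'MSE (ML-EM)')
ssim_share = fraction_better(rows, 'SSIM (Network)', 'SSIM (ML-EM)', lower_is_better=False)
print(mse_share, ssim_share)
```

With pandas, the equivalent is the mean of a boolean comparison between two columns, which is the same mask the `frame_picked` lines use for selection.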

### Plot: Test Dataframes

In [None]:
# Define Plotting Functions #
def plot_hist_1D(ax, dataframe, title, x_label, y_label, column_1, column_2, xlim, ylim, bins=400, alpha=0.5):
    '''
    Plots overlaid histograms of two columns of a dataframe. Rows with a value outside xlim in either column are dropped first.
    '''
    dataframe = dataframe[dataframe[column_1]>xlim[0]] # Only grab those elements of dataframe that are within the correct limits
    dataframe = dataframe[dataframe[column_1]<xlim[1]]
    dataframe = dataframe[dataframe[column_2]>xlim[0]]
    dataframe = dataframe[dataframe[column_2]<xlim[1]]

    dataframe[[column_1, column_2]].plot.hist(xlim=xlim, ylim=ylim, bins=bins, alpha=alpha, ax=ax, fontsize=fontsize)
    ax.set_title(title, fontsize=titlesize)  # Add a title to the axis.
    ax.set_xlabel(x_label, fontsize=fontsize)  # Add an x-label to the axis.
    ax.set_ylabel(y_label, fontsize=fontsize)  # Add a y-label to the axis.

def plot_hist_2D(ax, dataframe, title, x_label, y_label, x_column, y_column, xlim=(0,1), ylim=(0,1), gridsize=None):
    '''
    Plots a hexagonal bin plot of two columns of a dataframe.

    dataframe   the dataframe from which to grab the data
    x_label     title for the x-axis
    y_label     title for the y-axis
    x_column    column to plot on the x-axis. This must match a column label in the dataframe.
    y_column    column to plot on the y-axis. This must match a column label in the dataframe.
    gridsize    number of hexagons in the x-direction (passed to DataFrame.plot.hexbin)
    '''
    dataframe = dataframe[dataframe[x_column]>xlim[0]]
    dataframe = dataframe[dataframe[x_column]<xlim[1]]
    dataframe = dataframe[dataframe[y_column]>ylim[0]]
    dataframe = dataframe[dataframe[y_column]<ylim[1]]
    dataframe.plot.hexbin(ax=ax, x=x_column, y=y_column, xlim=xlim, ylim=ylim, gridsize=gridsize, fontsize=ticksize)

    ax.set_title(title, fontsize=titlesize)  # Add a title to the axis.
    ax.set_xlabel(x_label, fontsize=fontsize)  # Add an x-label to the axis.
    ax.set_ylabel(y_label, fontsize=fontsize)  # Add a y-label to the axis.
    ax.plot(xlim, ylim, linestyle='--') # plot dividing line

## Specify Plotting Parameters ##
plot_type = 2 # 1 = histograms, 2 = bin plots, 3 = both

column_MSE_1 = 'MSE (ML-EM)'
#column_MSE_1 = 'MSE (FBP)'
column_MSE_2 = 'MSE (Network)'
#column_MSE_2 = 'MSE (FBP)'

column_SSIM_1 = 'SSIM (ML-EM)'
#column_SSIM_1 = 'SSIM (FBP)'
column_SSIM_2 = 'SSIM (Network)'
#column_SSIM_2 = 'SSIM (FBP)'


titlesize=12
fontsize=9
ticksize=7
dpi=800

if plot_type == 1 or plot_type == 2:
    figsize=(8,6) # 17,5
    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = gridspec.GridSpec(ncols=100, nrows=100)

    # Top Row Axes #
    ax1 = fig.add_subplot(gs[0:42,   0:43])
    ax2 = fig.add_subplot(gs[0:42,   57:100])

    # Bottom Row Axes #
    ax3 = fig.add_subplot(gs[58:100, 0:43])
    ax4 = fig.add_subplot(gs[58:100, 57:100])

    if plot_type == 1:
        plot_hist_1D(ax1, dataframe1, '(1) CNN-A: MSE Histogram',  'MSE', 'frequency', column_MSE_1 , column_MSE_2, xlim=(0,4), ylim=(0,5000), bins=40)
        plot_hist_1D(ax2, dataframe1, '(2) CNN-A: SSIM Histogram', 'SSIM','frequency', column_SSIM_1, column_SSIM_2, xlim=(0.6,1), ylim=(0,4000), bins=40)
        plot_hist_1D(ax3, dataframe2, '(3) CNN-B: MSE Histogram',  'MSE', 'frequency', column_MSE_1 , column_MSE_2,  xlim=(0,4), ylim=(0,5000),  bins=40)
        plot_hist_1D(ax4, dataframe2, '(4) CNN-B: SSIM Histogram', 'SSIM','frequency', column_SSIM_1, column_SSIM_2, xlim=(0.6,1), ylim=(0,4000), bins=40)
    if plot_type == 2:
        plot_hist_2D(ax1, dataframe1, '(1) CNN-A: MSE Bin Plot', column_MSE_1, 'MSE (CNN-A)', column_MSE_1 , column_MSE_2,(0,1.5), (0,1.5), gridsize=60)
        plot_hist_2D(ax2, dataframe1, '(2) CNN-A: SSIM Bin Plot',column_SSIM_1, 'SSIM (CNN-A)', column_SSIM_1, column_SSIM_2, (.7,1), (.7,1), gridsize=100)
        plot_hist_2D(ax3, dataframe2, '(3) CNN-B: MSE Bin Plot', column_MSE_1, 'MSE (CNN-B)', column_MSE_1 , column_MSE_2, (0,1.5), (0,1.5), gridsize=60)
        plot_hist_2D(ax4, dataframe2, '(4) CNN-B: SSIM Bin Plot', column_SSIM_1, 'SSIM (CNN-B)', column_SSIM_1, column_SSIM_2, (.7,1), (.7,1), gridsize=100)

if plot_type == 3:
    figsize=(15,6) # 17,5
    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = gridspec.GridSpec(ncols=100, nrows=100)

    # Top Row Axes #
    ax1 = fig.add_subplot(gs[0:42,   0:18]) # 20
    ax2 = fig.add_subplot(gs[0:42,   25:47]) # 22
    ax3 = fig.add_subplot(gs[0:42,   53:74]) # 20
    ax4 = fig.add_subplot(gs[0:42,   80:100]) # 22

    # Bottom Row Axes #
    ax5 = fig.add_subplot(gs[58:100, 0:18]) # -5-
    ax6 = fig.add_subplot(gs[58:100, 25:47]) # -3 - -3-
    ax7 = fig.add_subplot(gs[58:100, 53:74]) # -5-
    ax8 = fig.add_subplot(gs[58:100, 80:100])

    plot_hist_1D(ax1, dataframe1, '(1) CNN-A: MSE Histogram',  'MSE', 'frequency', column_MSE_1 , column_MSE_2, xlim=(0,4), ylim=(0,5000), bins=40)
    plot_hist_1D(ax2, dataframe1, '(3) CNN-A: SSIM Histogram', 'SSIM','frequency', column_SSIM_1, column_SSIM_2, xlim=(0.6,1), ylim=(0,4000), bins=40)
    plot_hist_2D(ax3, dataframe1, '(5) CNN-A: MSE Bin Plot', column_MSE_1, 'MSE (CNN-A)', column_MSE_1 , column_MSE_2,(0,1.5), (0,1.5), gridsize=60)
    plot_hist_2D(ax4, dataframe1, '(7) CNN-A: SSIM Bin Plot',column_SSIM_1, 'SSIM (CNN-A)', column_SSIM_1, column_SSIM_2, (.7,1), (.7,1), gridsize=100)

    plot_hist_1D(ax5, dataframe2, '(2) CNN-B: MSE Histogram',  'MSE', 'frequency', column_MSE_1 , column_MSE_2,  xlim=(0,4), ylim=(0,5000),  bins=40)
    plot_hist_1D(ax6, dataframe2, '(4) CNN-B: SSIM Histogram', 'SSIM','frequency', column_SSIM_1, column_SSIM_2, xlim=(0.6,1), ylim=(0,4000), bins=40)
    plot_hist_2D(ax7, dataframe2, '(6) CNN-B: MSE Bin Plot', column_MSE_1, 'MSE (CNN-B)', column_MSE_1 , column_MSE_2, (0,1.5), (0,1.5), gridsize=60)
    plot_hist_2D(ax8, dataframe2, '(8) CNN-B: SSIM Bin Plot', column_SSIM_1, 'SSIM (CNN-B)', column_SSIM_1, column_SSIM_2, (.7,1), (.7,1), gridsize=100)

save_path = os.path.join(plot_dir, 'figure-histograms.png')
plt.savefig(save_path, bbox_inches='tight')


## Plot: Example Images

In [None]:
### User Parameters ###
#######################

## Indexes of Example Images ##
#-----------------------------#
# Panel 1: Network performs much better for all images
#indexes = [2280+4, 13*120, 187*120+13, 151 * 120, 240+37, 147*120, 187*120+17, 239*120]

# Panel 2: Network performs somewhat better:
#indexes = [1073, 840+48, 108 * 120,  147*120+33, 153*120, 1440+71, 13560+17, 153*120, 224*120+161]

# Panel 3: ML-EM performs better
#indexes = [960+97, 1268*120, 111*120]

#Panel 5: Panel 1 + Panel 2 --> image 0-8 (network much better)
#indexes = [2280+4, 13*120, 151 * 120, 240+37, 147*120, 187*120+17, 239*120, 1073]
#indexes = [840+48+1, 108 * 120,  147*120+33, 153*120+2, 1440+71, 13560+17, 224*120+161+5]

# Panel 6: Used in SPIE Paper
#indexes = [2280+4, 240+37, 187*120+17, 239*120, 1073, 840+48+1, 153*120+2,  13*120, 224*120+161+5]
#indexes = [240+37, 187*120+17, 239*120, 1073, 840+48+1, 224*120+161+5, 153*120+2, 13*120] # Final Cut

# Final Panel (nine images max)
indexes = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110]

## Paths ##

image_path = '/content/drive/MyDrive/Repository/PET_Data/test_image-35k.npy'
sino_path =   '/content/drive/MyDrive/Repository/PET_Data/test_sino-35k.npy'


checkpoint_dir = '/content/drive/MyDrive/Colab/Working/Checkpoints-trainOnFull'

checkpoint_file_SSIM= 'checkpoint-tunedSSIM-14d-6epochs'
checkpoint_file_MSE = 'checkpoint-tunedMSE-fc6-6epochs'
checkpoint_file_MAE = 'checkpoint-tunedMAE-b08-6epochs'
checkpoint_file_LDM = 'checkpoint-tunedLDM-w10s8-b5c-6epochs'
checkpoint_file_LDM_batch='checkpoint-tunedLDM_batch-f9f-6epochs'

## Dimensions ##
image_size=90
sino_size=90
image_channels=1
sino_channels=1

## CNNs ##
config_MSE= { # 1x90x90, Tuned for MSE - fc6 #
    "SI_dropout": False, "SI_exp_kernel": 4, "SI_gen_fill": 0, "SI_gen_final_activ": None, "SI_gen_hidden_dim": 14, "SI_gen_mult": 2.3737518721494038,
    "SI_gen_neck": 5, "SI_gen_z_dim": 300, "SI_layer_norm": "instance", "SI_normalize": True, "SI_pad_mode": "zeros", "SI_scale": 8100,
    "batch_size": 266, "gen_b1": 0.5194977285709309, "gen_b2": 0.4955647195661826, "gen_lr": 0.0006569034263698925, "sup_criterion": nn.MSELoss() }

config_SSIM = { # 1x90x90, Tuned for SSIM - 14d #
    "SI_dropout": False, "SI_exp_kernel": 4, "SI_gen_fill": 0, "SI_gen_final_activ": nn.Tanh(), "SI_gen_hidden_dim": 23, "SI_gen_mult": 1.6605902406330195,
    "SI_gen_neck": 5, "SI_gen_z_dim": 789, "SI_layer_norm": "instance", "SI_normalize": True, "SI_pad_mode": "zeros", "SI_scale": 8100, "batch_size": 71, "gen_b1": 0.2082092731474774,
    "gen_b2": 0.27147903136187507, "gen_lr": 0.0005481469822215635, "sup_criterion": nn.MSELoss()}

config_MAE = { # 1x90x90, Tuned for MAE, - b08 #
    "SI_dropout": True, "SI_exp_kernel": 3, "SI_gen_fill": 0, "SI_gen_final_activ": nn.Tanh(), "SI_gen_hidden_dim": 29, "SI_gen_mult": 3.4493572412953926,
    "SI_gen_neck": 5, "SI_gen_z_dim": 92, "SI_layer_norm": "instance", "SI_normalize": True, "SI_pad_mode": "zeros", "SI_scale": 8100, "batch_size": 184,
    "gen_b1": 0.41793988944151467, "gen_b2": 0.15133808988276928, "gen_lr": 0.0012653525173041019, "sup_criterion": nn.L1Loss() }

config_LDM={ # 1x90x90, Tuned for CUSTOM = LDM (image statistics)
    "SI_dropout": False, "SI_exp_kernel": 4, "SI_gen_fill": 0, "SI_gen_final_activ": None, "SI_gen_hidden_dim": 9, "SI_gen_mult": 2.1547197646081444,
    "SI_gen_neck": 5, "SI_gen_z_dim": 344, "SI_layer_norm": "batch", "SI_normalize": True, "SI_pad_mode": "zeros", "SI_scale": 8100, "batch_size": 47,
    "gen_b1": 0.31108788447029295, "gen_b2": 0.3445239707919786, "gen_lr": 0.0007561178182660596, "sup_criterion": nn.L1Loss()}

config_LDM_batch={ # 1x90x90, Tuned for CUSTOM = LDM (batch statistics)
    "SI_dropout": False, "SI_exp_kernel": 3, "SI_gen_fill": 0, "SI_gen_final_activ": nn.Tanh(), "SI_gen_hidden_dim": 19, "SI_gen_mult": 2.70340867805694,
    "SI_gen_neck": 1, "SI_gen_z_dim": 1616, "SI_layer_norm": "instance", "SI_normalize": True, "SI_pad_mode": "zeros", "SI_scale": 8100, "batch_size": 363,
    "gen_b1": 0.20393974474424928, "gen_b2": 0.6490512100839003, "gen_lr": 0.0004491464075393307, "sup_criterion": nn.MSELoss()}

## Defaults ##
train_type='SUP'
train_SI=True
image_array = np.load(image_path, mmap_mode='r')       # self.image_tensor.shape=(#examples x1x71x71)
sino_array = np.load(sino_path, mmap_mode='r')     # self.sinogram_tensor.shape=(#examples x3x101x180)

## Build Image & Sino Tensors ##
def BuildImageSinoTensors(image_array, sino_array, config, indexes):
    '''
    Build image and sinogram tensors with images and sinograms determined by the indexes list.
    '''
    first=True
    i=0
    for idx in indexes:
        sino_ground, sino_ground_scaled, image_ground, image_ground_scaled = NpArrayDataLoader(image_array, sino_array, config,
                                                                                image_size = image_size, sino_size=sino_size,
                                                                                image_channels=image_channels,
                                                                                sino_channels=sino_channels,
                                                                                augment=False, index=idx)
        # If first time through the loop, create blank tensors (for sino & image) with the correct shape
        if first==True:
            image_tensor = torch.zeros(len(indexes), image_ground_scaled.shape[0], image_ground_scaled.shape[1], image_ground_scaled.shape[2]).to(device)
            sino_tensor  = torch.zeros(len(indexes), sino_ground_scaled.shape[0],  sino_ground_scaled.shape[1],  sino_ground_scaled.shape[2]).to(device)
            first=False
        # Fill the tensors with images & sinograms
        image_tensor[i,:] = image_ground_scaled
        sino_tensor[i,:]  = sino_ground_scaled
        i+=1
    return image_tensor, sino_tensor

## CNN Outputs ##
def CNN_reconstruct(sino_tensor, config, checkpoint_file_name, input_size=90, input_channels=1, output_channels=1):
    '''
    Construct CNN reconstructions of images of a sinogram tensor.
    '''
    gen =  Generator(config=config, gen_SI=True, input_size=input_size, input_channels=input_channels, output_channels=output_channels).to(device)
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file_name)
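    checkpoint = torch.load(checkpoint_path, map_location=device)   # map_location lets a checkpoint saved on another device load here
    gen.load_state_dict(checkpoint['gen_state_dict'])
    gen.eval()
    with torch.no_grad():   # inference only: skip building the autograd graph
        return gen(sino_tensor)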

## Outputs ###
image_tensor, sino_tensor = BuildImageSinoTensors(image_array, sino_array, config_MSE, indexes)

MLEM_output = reconstruct(sino_tensor, config_MSE, image_size=image_size, recon_type='MLEM', circle=True)
CNN_output_MSE = CNN_reconstruct(sino_tensor, config_MSE, checkpoint_file_MSE)
CNN_output_SSIM = CNN_reconstruct(sino_tensor, config_SSIM, checkpoint_file_SSIM)
CNN_output_MAE = CNN_reconstruct(sino_tensor, config_MAE, checkpoint_file_MAE)
CNN_output_LDM = CNN_reconstruct(sino_tensor, config_LDM, checkpoint_file_LDM)
CNN_output_LDM_batch = CNN_reconstruct(sino_tensor, config_LDM_batch, checkpoint_file_LDM_batch)

#MLEM_output2 = reconstruct(sino_tensor, config1, image_size=image_size, recon_type='MLEM', circle=False)
#FBP_output =  reconstruct(sino_tensor, config1, image_size=image_size, recon_type='FBP')

#############
## Metrics ##
#############

'''
frame_SSIM_MLEM, placeholder = calculate_metric(MLEM_output, image_tensor, SSIM, dataframe = True, label='MLEM, SSIM')
frame_MSE_MLEM, placeholder =  calculate_metric(MLEM_output, image_tensor, MSE, dataframe = True, label='MLEM, MSE')
print('################### MLEM ###################')
print(frame_SSIM_MLEM.T)
print(frame_MSE_MLEM.T)


frame_SSIM_tunedMSE, placeholder = calculate_metric(CNN_output_MSE, image_tensor, SSIM, dataframe = True, label='TunedMSE, SSIM')
frame_MSE_tunedMSE, placeholder =  calculate_metric(CNN_output_MSE, image_tensor, MSE, dataframe = True, label='TunedMSE, MSE')
print('################### CNN (MSE) ##################')
print(frame_SSIM_tunedMSE.T)
print(frame_MSE_tunedMSE.T)

frame_SSIM_tunedSSIM, placeholder = calculate_metric(CNN_output_SSIM, image_tensor, SSIM, dataframe = True, label='TunedSSIM, SSIM')
frame_MSE_tunedSSIM, placeholder =  calculate_metric(CNN_output_SSIM, image_tensor, MSE, dataframe = True, label='TunedSSIM, MSE')
frame_LDM_tunedSSIM, placeholder =  calculate_metric(CNN_output_SSIM, image_tensor, custom_metric, dataframe = True, label='TunedSSIM, LDM')
print('################### CNN (SSIM) ##################')
print(frame_SSIM_tunedSSIM.T)
print(frame_MSE_tunedSSIM.T)
print(frame_LDM_tunedSSIM.T)

frame_SSIM_tunedLDM, placeholder = calculate_metric(CNN_output_LDM, image_tensor, SSIM, dataframe = True, label='TunedLDM, SSIM')
frame_MSE_tunedLDM, placeholder =  calculate_metric(CNN_output_LDM, image_tensor, MSE, dataframe = True, label='TunedLDM, MSE')
frame_LDM_tunedLDM, placeholder =  calculate_metric(CNN_output_LDM, image_tensor, custom_metric, dataframe = True, label='TunedLDM, LDM')
print('################### CNN (LDM) ##################')
print(frame_SSIM_tunedLDM.T)
print(frame_MSE_tunedLDM.T)
print(frame_LDM_tunedLDM.T)
'''

####################
## Display Images ##
####################


#show_multiple_matched_tensors(image_tensor, CNN_output_MSE, CNN_output_LDM, CNN_output_MAE, MLEM_output, CNN_output_SSIM, fig_size=1.0)
#show_multiple_matched_tensors(image_tensor, MLEM_output, CNN_output_MSE, CNN_output_MAE, CNN_output_SSIM, CNN_output_LDM_batch, CNN_output_LDM, fig_size=1.0)
#show_multiple_matched_tensors(image_tensor, fig_size=1.0)

#print('Ground Truth/MLEM/CNN1/CNN2')
show_single_unmatched_tensor(sino_tensor)
show_multiple_matched_tensors(image_tensor, MLEM_output, CNN_output_MSE, CNN_output_MAE, fig_size=1.0)

# Sort: Dataset by Metric

In [None]:
def sort_DataSet(config, load_image_path, load_sino_path, save_image_path, save_sino_path, max_save_index, metric_function, threshold, threshold_min_max, num_examples=-1, visualize=False):
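    '''
    Filter a dataset by comparing each ground-truth image with its FBP reconstruction, and save the examples that pass.

    metric_function:    metric computed between the scaled ground-truth image and the FBP reconstruction (e.g. MSE or SSIM)
    threshold:          cutoff value for the metric
    threshold_min_max:  'min' keeps examples with metric > threshold; any other value keeps examples with metric < threshold
    max_save_index:     number of examples allocated in the output arrays
    num_examples:       passed through to NpArrayDataSet
    visualize:          if True, print and display details for every example tested

    Returns the memory-mapped sinogram and image arrays (save_sino_array, save_image_array).
    '''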
    # Variables #
    scale=config['SI_scale'] if train_SI==True else config['IS_scale']

    # Dataloader #
    dataloader = DataLoader(
        NpArrayDataSet(image_path=load_image_path, sino_path=load_sino_path, config=config, image_size=image_size, image_channels=image_channels,
                       sino_size=sino_size, sino_channels=sino_channels, num_examples=num_examples),
        batch_size=1,
        shuffle=True
    )

    ### Loop Over Batches ###
    first = True
    saved_idx = 0
    for sino_ground, sino_ground_scaled, image_ground, image_ground_scaled in iter(dataloader): # Dataloader returns the batches. Loop over batches within epochs.

        # Open memory map if first time through the loop #
        if first==True:
            save_image_array_shape = (max_save_index, image_ground_scaled.shape[1], image_ground_scaled.shape[2], image_ground_scaled.shape[3])
            save_sino_array_shape = (max_save_index, sino_ground_scaled.shape[1], sino_ground_scaled.shape[2], sino_ground_scaled.shape[3])
            print('save_image_array_shape: ', save_image_array_shape)
            print('save_sino_array_shape: ', save_sino_array_shape)

            save_image_array = np.lib.format.open_memmap(save_image_path, mode='w+', shape=save_image_array_shape, dtype=np.float32)
            save_sino_array =  np.lib.format.open_memmap(save_sino_path , mode='w+', shape=save_sino_array_shape,  dtype=np.float32)
            first=False

        # Test the image to see if fits the criteria #
        FBP_output =  reconstruct(sino_ground_scaled, config, image_size=image_size, recon_type='FBP')
        image_metric = metric_function(image_ground_scaled, FBP_output)

        if threshold_min_max == 'min':
            keep = bool(image_metric > threshold)   # threshold is a lower bound
        else:
            keep = bool(image_metric < threshold)   # threshold is an upper bound

        if keep:
            save_sino_array[saved_idx] = sino_ground_scaled.cpu().numpy()
            save_image_array[saved_idx] = image_ground_scaled.cpu().numpy()
            saved_idx += 1
            print('Current index (for next image): ', saved_idx)
            if saved_idx >= max_save_index:   # output arrays are full; writing another row would raise IndexError
                break

        if visualize==True:
            # Visualize the rejected or accepted sample #
            print('==================================')
            print('Image Metric: ', image_metric)
            print('Threshold: ', threshold)
            print('Keep?: ', keep)
            print('Current index (for next image): ', saved_idx)
            print('Saved Arrays:')
            print('image_ground_scaled / FBP_output / sino_ground_scaled')
            show_multiple_matched_tensors(image_ground_scaled, FBP_output)
            show_multiple_matched_tensors(sino_ground_scaled)
            show_multiple_matched_tensors(torch.from_numpy(save_sino_array[0:9]))
            show_multiple_matched_tensors(torch.from_numpy(save_image_array[0:9]))

    return save_sino_array, save_image_array

## Changeable Variables ##

load_sino_path = '/content/drive/MyDrive/Repository/PET_Data/train_sino-70k.npy'
load_image_path = '/content/drive/MyDrive/Repository/PET_Data/train_image-70k.npy'
save_sino_path = '/content/drive/MyDrive/Repository/PET_Data/quartile_data/train_sino-lowSSIM-17500.npy'
save_image_path = '/content/drive/MyDrive/Repository/PET_Data/quartile_data/train_image-lowSSIM-17500.npy'
'''
metric_function = MSE
max_save_index = 17500
threshold = 0.330922
threshold_min_max = 'min'
'''
metric_function = SSIM
max_save_index = 17500
threshold = 0.837850  # MSE (min): 0.330922, SSIM (max): 0.837850
threshold_min_max = 'max'

## Run & Verify Result ##
save_sino_array, save_image_array = sort_DataSet(config, load_image_path, load_sino_path, save_image_path, save_sino_path, max_save_index,
                                                   metric_function, threshold, threshold_min_max=threshold_min_max, visualize=False)

sino_ground_scaled
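
The keep/reject rule used in `sort_DataSet` can be looked at separately from the data loading. The following is a minimal sketch, not the notebook's implementation: `select_indices` and the sample SSIM values are hypothetical, and the cap on the number of kept examples is an assumption about how a full output array should be handled.

```python
def select_indices(metrics, threshold, threshold_min_max, max_keep):
    """Return the indices of examples that pass the threshold, capped at max_keep.

    threshold_min_max == 'min': the threshold is a lower bound (keep metric > threshold).
    any other value:            the threshold is an upper bound (keep metric < threshold).
    """
    kept = []
    for idx, metric in enumerate(metrics):
        if len(kept) >= max_keep:      # the output array would be full; stop early
            break
        if threshold_min_max == 'min':
            keep = metric > threshold
        else:
            keep = metric < threshold
        if keep:
            kept.append(idx)
    return kept


# Hypothetical SSIM values for five examples
ssim_values = [0.91, 0.80, 0.84, 0.79, 0.70]
low_ssim = select_indices(ssim_values, threshold=0.837850, threshold_min_max='max', max_keep=2)
print(low_ssim)
```

With `threshold_min_max='max'` the sketch selects the low-SSIM examples, which matches the `lowSSIM` file names used above. The last value (0.70) also passes the threshold but is skipped because the cap of two has already been reached.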

### Save Datasets & Check

In [None]:
#save_sino_path = '/content/drive/MyDrive/Repository/PET_Data/quartile_data/train_sino-lowSSIM-17500.npy'
#save_image_path = '/content/drive/MyDrive/Repository/PET_Data/quartile_data/train_image-lowSSIM-17500.npy'

# Print sorted array shape & display a few images #
print('save_sino_array.shape: ', save_sino_array.shape)
print('save_image_array.shape: ', save_image_array.shape)

print('save_sino_array sample images')
show_multiple_matched_tensors(torch.from_numpy(save_sino_array[500:509]))
print('save_image_array sample images')
show_multiple_matched_tensors(torch.from_numpy(save_image_array[500:509]))


# Save the sorted array to disk #
save_sino_array.flush()
save_image_array.flush()
#np.save(save_sino_path, save_sino_array)
#np.save(save_image_path, save_image_array)

# Load the saved array and make sure it's the same size/has the same images #
load_sino_array = np.load(save_sino_path, mmap_mode='r')
load_image_array = np.load(save_image_path, mmap_mode='r')
print('load_sino_array.shape: ', load_sino_array.shape)
print('load_image_array.shape: ', load_image_array.shape)

print('load_sino_array sample images')
show_multiple_matched_tensors(torch.from_numpy(load_sino_array[500:509]))
print('load_image_array sample images')
show_multiple_matched_tensors(torch.from_numpy(load_image_array[500:509]))
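
Because the output files are pre-allocated with `max_save_index` rows, any rows beyond the last saved example stay zero-filled on disk. Below is a minimal sketch of the write, flush, and reload cycle on temporary files, assuming NumPy is available. The tiny shapes, the file names, and the trimming step are illustrative and not part of the notebook.

```python
import os
import tempfile

import numpy as np

max_save_index = 6   # rows pre-allocated on disk
saved_idx = 4        # rows actually written (hypothetical count)

tmp_dir = tempfile.mkdtemp()
full_path = os.path.join(tmp_dir, 'full.npy')
trimmed_path = os.path.join(tmp_dir, 'trimmed.npy')

# Pre-allocate a .npy file on disk and fill only the first saved_idx rows
full = np.lib.format.open_memmap(full_path, mode='w+',
                                 shape=(max_save_index, 1, 2, 2), dtype=np.float32)
for i in range(saved_idx):
    full[i] = i + 1
full.flush()

# Reload as a read-only memory map and count rows that were never written
reloaded = np.load(full_path, mmap_mode='r')
unused_rows = int((reloaded.reshape(max_save_index, -1) == 0).all(axis=1).sum())

# Copy only the filled rows into a correctly sized file
trimmed = np.lib.format.open_memmap(trimmed_path, mode='w+',
                                    shape=(saved_idx, 1, 2, 2), dtype=np.float32)
trimmed[:] = reloaded[:saved_idx]
trimmed.flush()

print(reloaded.shape, unused_rows, np.load(trimmed_path, mmap_mode='r').shape)
```

Checking the reloaded shape alone does not reveal unfilled rows, since the shape is fixed when the file is created. Comparing `saved_idx` with `max_save_index` is one way to notice them.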


# Experimenting

In [None]:
## Find what GPU I'm using ##

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
'''
import sys
print('a')
sys.exit()
print('b')
'''
'''
print(train_image_path)
print(train_sino_path)
print(test_image_path)
print(test_sino_path)
'''
print(shuffle)

# Notes:
Change next
===========
tune_even_reporting=False


For high/low MSE experiments
============================
- Tuned networks for 180 minutes each.

- Trained for 100 epochs using on-the-fly augmentation.

- See notes in checkpoint folder.


For LDM, window = 5, stride = 2
===============================
tune_max_t = 20            

tune_minutes = 180      

tune_display_step=12    

tune_augment=False


GPUs
====
From best to worst:

V100 - 6.92/hr

L4 - 2.15/hr

T4 - 1.7/hr

v6e-1 TPU - 4.21/hr

v5e-1 TPU - 4.11/hr

v2-8 TPU - 1.82/hr
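
As a rough worked example, the rates above can be combined with the 180-minute tuning time noted earlier to estimate what one tuning run consumes on each accelerator. This is a sketch only: `run_cost` is a hypothetical helper, and the unit is whatever the "/hr" figures denote.

```python
# Hourly rates copied from the list above (unit as given by the "/hr" figures)
rates_per_hour = {
    'V100': 6.92,
    'L4': 2.15,
    'T4': 1.7,
    'v6e-1 TPU': 4.21,
    'v5e-1 TPU': 4.11,
    'v2-8 TPU': 1.82,
}
tune_minutes = 180   # tuning time per network, from the notes above


def run_cost(rate_per_hour, minutes):
    """Consumption of a single run: hourly rate times run length in hours."""
    return rate_per_hour * minutes / 60


costs = {name: round(run_cost(rate, tune_minutes), 2) for name, rate in rates_per_hour.items()}
for name, cost in costs.items():
    print(f'{name}: {cost}')
```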

TensorBoard works for all experiments except the last one.
My plotting function no longer works for any of the experiments.