# Using pyapetnet to predict anatomy-guided MAP PET reconstructions in image space

In this notebook, we will learn how to use the pre-trained models included in the pyapetnet package to predict anatomy-guided MAP PET reconstructions from (simulated) PET OSEM and T1 MR images.

In this tutorial, we will have a closer look at:
- loading pre-trained models
- loading nifti data
- pre-processing nifti data
- feeding the pre-processed data into the pre-trained model
- saving and visualizing the results

**If you install pyapetnet from PyPI using ```pip install pyapetnet```**, a command line tool is created that does all those steps in one go. Moreover, it can read and write DICOM data.

More details on pyapetnet are available here:
- https://doi.org/10.1016/j.neuroimage.2020.117399 (NeuroImage publication on pyapetnet)
- https://github.com/gschramm/pyapetnet/ (github repository of pyapetnet)

## (1) Preparation: Install the pyapetnet package

Before running this notebook, make sure that the pyapetnet package is installed.
This can be done via <br>
```pip install pyapetnet``` <br> 
which will install the package and all its dependencies (e.g. tensorflow). We recommend using a separate virtual environment.


## (2) Data used in this demo

In this tutorial, we will use simulated PET and MR data that are based on the brainweb phantom.
The nifti files used in this tutorial are available at <br>
https://github.com/gschramm/pyapetnet/tree/master/demo_data <br>
By changing ```pet_fname``` or ```mr_fname``` other input data sets can be used.


## (3) Loading modules
In the next cell, we load all required Python modules, e.g. tensorflow to load the pre-trained model and pyapetnet for data preprocessing.

In [None]:
import nibabel as nib
import json
import os
import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

import pyapetnet
from pyapetnet.preprocessing import preprocess_volumes
from pyapetnet.utils         import load_nii_in_ras

## (4) Specification of input parameters
In the next cell, we specify the required input parameters:
- ```model_name``` (name of the pre-trained model shipped with the pyapetnet package)
- ```pet_fname / mr_fname``` (paths of the PET and MR input nifti files)
- ```coreg_inputs``` whether to apply rigid coregistration between PET and MR volumes using mutual information
- ```crop_mr``` whether to crop both volumes to the bounding box of the MR (useful to limit memory usage)
- ```output_fname``` absolute path of the nifti file for the output
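
To illustrate what cropping to the bounding box of the MR means, here is a minimal NumPy sketch. It is not pyapetnet's implementation: the helper `crop_to_bounding_box`, the threshold that defines the MR support, and the toy volumes are all assumptions made for illustration.

```python
import numpy as np

def crop_to_bounding_box(reference, *others, threshold=0.0):
    """Crop `reference` and all same-shaped `others` to the bounding box
    of the voxels where reference > threshold (hypothetical helper)."""
    mask = reference > threshold
    if not mask.any():
        raise ValueError('reference volume has no voxel above the threshold')
    idx = np.argwhere(mask)          # (n_voxels, 3) array of voxel indices
    start = idx.min(axis=0)          # first index of the box along each axis
    stop = idx.max(axis=0) + 1       # one past the last index along each axis
    slices = tuple(slice(a, b) for a, b in zip(start, stop))
    return tuple(vol[slices] for vol in (reference, *others))

# toy "MR" with a non-zero block and a random "PET" on the same grid
mr_toy = np.zeros((10, 12, 14))
mr_toy[2:7, 3:9, 4:10] = 1.0
pet_toy = np.random.default_rng(0).random(mr_toy.shape)

mr_crop, pet_crop = crop_to_bounding_box(mr_toy, pet_toy)
print(mr_toy.shape, '->', mr_crop.shape, pet_crop.shape)
```

Since both volumes are cut with the same slices, they stay on a common grid, while the number of voxels the CNN has to process shrinks.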


In [None]:
# inputs (adapt to your needs)

# the name of the trained CNN
model_name = '200824_mae_osem_psf_bet_10'

# we use simulated demo data included in pyapetnet (based on the brainweb phantom)
mydata_dir = '.'
pet_fname  = os.path.join(mydata_dir, 'brainweb_06_osem.nii')
mr_fname   = os.path.join(mydata_dir, 'brainweb_06_t1.nii')

# preprocessing parameters

coreg_inputs = True  # rigidly coregister PET and MR using mutual information
crop_mr      = True   # crop the input to the support of the MR (saves memory + speeds up the computation)

# the name of the output file
output_fname = f'prediction_{model_name}.nii'

## (5) Load the pre-trained CNN (model)
Now we can load the pre-trained model. pyapetnet includes a few pre-trained models that are all installed
at <br>
```os.path.join(os.path.dirname(pyapetnet.__file__),'trained_models')```<br>
where ```pyapetnet.__file__``` points to the install path of pyapetnet.
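
As a rough sketch of how the available models could be listed, the snippet below scans a directory for sub-directories that contain a `config.json`. The helper `list_model_dirs` is hypothetical, and the assumption that every shipped model is stored as such a sub-directory is inferred from the model loaded in this notebook. The demo runs on a temporary directory so that it does not depend on an installed pyapetnet.

```python
import os
import tempfile

def list_model_dirs(base_dir):
    """Return the sorted names of all sub-directories of base_dir
    that contain a config.json (hypothetical helper)."""
    return sorted(
        name for name in os.listdir(base_dir)
        if os.path.isfile(os.path.join(base_dir, name, 'config.json'))
    )

# demo on a temporary directory with two fake "models" and one unrelated folder
with tempfile.TemporaryDirectory() as tmp:
    for name in ('model_b', 'model_a'):
        os.makedirs(os.path.join(tmp, name))
        with open(os.path.join(tmp, name, 'config.json'), 'w') as f:
            f.write('{}')
    os.makedirs(os.path.join(tmp, 'not_a_model'))
    found = list_model_dirs(tmp)

print(found)
```

With pyapetnet installed, the same helper could be pointed at ```os.path.join(os.path.dirname(pyapetnet.__file__),'trained_models')```.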

A more detailed description of all models can be found at <br>
https://github.com/gschramm/pyapetnet/blob/master/pyapetnet/trained_models/model_description.md

The dummy dictionary ```custom_objects``` is needed since the model definition depends on two custom loss functions (related to SSIM). For inference the loss functions are not needed, which is why we pass a dummy dictionary.

Last but not least, we read the internal voxel size used to train the model. This is necessary to correctly pre-process the input data (which usually comes with a different voxel size).
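
To see why the internal voxel size matters, the following sketch computes the zoom factors needed to bring an input grid onto the training grid. Only the key `internal_voxsize` is taken from this notebook. The config content, the input voxel size (assumed to be in mm), and the input shape are made-up values, and the interpolation inside pyapetnet may be implemented differently.

```python
import json
import numpy as np

# hypothetical content of a model config file
demo_cfg = json.loads('{"internal_voxsize": 1.0}')

# a scalar voxel size is expanded to an isotropic 3-element vector
demo_training_voxsize = demo_cfg['internal_voxsize'] * np.ones(3)

# hypothetical input volume: 2 x 2 x 3 voxels on a 128 x 128 x 90 grid
demo_input_voxsize = np.array([2.0, 2.0, 3.0])
demo_input_shape = np.array([128, 128, 90])

# a zoom factor > 1 means the volume is upsampled along that axis
zoom_factors = demo_input_voxsize / demo_training_voxsize
resampled_shape = np.round(demo_input_shape * zoom_factors).astype(int)

print('zoom factors   :', zoom_factors)
print('resampled shape:', resampled_shape)
```

The resampled shape gives a quick estimate of the memory needed for the prediction before any data is loaded.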

In [None]:
# load the trained CNN and its internal voxel size used for training
model_abs_path = os.path.join(os.path.dirname(pyapetnet.__file__),'trained_models',model_name)

model = tf.keras.models.load_model(model_abs_path, custom_objects = {'ssim_3d_loss': None,'mix_ssim_3d_mae_loss': None})
                   
# load the voxel size used for training
with open(os.path.join(model_abs_path,'config.json')) as f:
  cfg = json.load(f)
  training_voxsize = cfg['internal_voxsize']*np.ones(3)

## (6) Load and preprocess the input PET and MR volumes

Next, we load the data from the input nifti files. The preprocessing function rigidly coregisters the inputs,
interpolates the volumes to the internal voxel size of the CNN, crops the volumes to the MR support, and does an intensity normalization (division by the 99.99% percentile). We use the 99.99% percentile rather than the maximum since it is more robust for noisy (PET) volumes.

**The voxel size of the input volumes is deduced from the affine transformation stored in the nifti header. Make sure that the affine stored there is correct.**
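
A quick way to check an affine is to compute the voxel size it implies. The sketch below uses the standard relation that the voxel size along each axis is the Euclidean length of the corresponding column of the upper-left 3x3 block of the affine. The helper name and the example affine are assumptions. nibabel offers the same computation as `nibabel.affines.voxel_sizes`.

```python
import numpy as np

def voxel_sizes_from_affine(affine):
    """Voxel size along each array axis, i.e. the norm of each column
    of the 3x3 rotation/zoom part of a 4x4 affine (hypothetical helper)."""
    return np.sqrt((affine[:3, :3] ** 2).sum(axis=0))

# hypothetical affine of a volume with 2 x 2 x 3 voxels (first axis flipped)
demo_affine = np.array([[-2.0, 0.0, 0.0,   90.0],
                        [ 0.0, 2.0, 0.0, -126.0],
                        [ 0.0, 0.0, 3.0,  -72.0],
                        [ 0.0, 0.0, 0.0,    1.0]])

demo_voxsize = voxel_sizes_from_affine(demo_affine)
print(demo_voxsize)
```

If the printed values do not match the known acquisition or reconstruction voxel size, the header should be fixed before running the preprocessing.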

In [None]:
# load and preprocess the input PET and MR volumes
pet, pet_affine = load_nii_in_ras(pet_fname)
mr, mr_affine   = load_nii_in_ras(mr_fname)

# preprocess the input volumes (coregistration, interpolation and intensity normalization)
pet_preproc, mr_preproc, o_aff, pet_max, mr_max = preprocess_volumes(pet, mr, 
  pet_affine, mr_affine, training_voxsize, perc = 99.99, coreg = coreg_inputs, crop_mr = crop_mr)

## (7) Show and check pre-processed input data

Before passing the PET and MR input volumes to the loaded CNN, it is a good idea to check whether both volumes were correctly pre-processed. If the pre-processing was successful, the volumes should be well aligned, should be interpolated to the internal voxel size of the CNN, and their 99.99% percentile should be 1.
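
The percentile check can be illustrated on synthetic data. This is a minimal sketch of a percentile-based normalization, not the code inside `preprocess_volumes`. The random test volume and all variable names are assumptions.

```python
import numpy as np

# synthetic positive-valued "image" standing in for a noisy PET volume
rng = np.random.default_rng(42)
demo_vol = rng.gamma(shape=2.0, scale=50.0, size=(32, 32, 32))

# scale factor: the 99.99% percentile of the voxel values
demo_scale = np.percentile(demo_vol, 99.99)
demo_norm = demo_vol / demo_scale

# after normalization the 99.99% percentile is 1 (up to rounding errors)
print(f'scale factor              : {demo_scale:.3f}')
print(f'99.99% percentile (normed): {np.percentile(demo_norm, 99.99):.3f}')

# multiplying by the stored scale factor restores the original units
demo_restored = demo_norm * demo_scale
```

A few voxels above the percentile end up with values slightly larger than 1, which is expected and why the displays use ```vmax = 1```.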

In [None]:
print(f'PET 99.99% percentile {np.percentile(pet_preproc,99.99):.3f}')
print(f'MR 99.99% percentile {np.percentile(mr_preproc,99.99):.3f}')

fig, ax = plt.subplots(2,3, figsize = (9,6))
ax[0,0].imshow(pet_preproc[:,::-1,pet_preproc.shape[2]//2].T, cmap = plt.cm.Greys, vmax = 1)
ax[0,1].imshow(pet_preproc[:,pet_preproc.shape[1]//2,::-1].T, cmap = plt.cm.Greys, vmax = 1)
ax[0,2].imshow(pet_preproc[pet_preproc.shape[0]//2,:,::-1].T, cmap = plt.cm.Greys, vmax = 1)
ax[1,0].imshow(mr_preproc[:,::-1,pet_preproc.shape[2]//2].T, cmap = plt.cm.Greys_r, vmax = 1)
ax[1,1].imshow(mr_preproc[:,pet_preproc.shape[1]//2,::-1].T, cmap = plt.cm.Greys_r, vmax = 1)
ax[1,2].imshow(mr_preproc[pet_preproc.shape[0]//2,:,::-1].T, cmap = plt.cm.Greys_r, vmax = 1)

for axx in ax.flatten(): axx.set_axis_off()

ax[0,1].set_title('preprocessed input PET')
ax[1,1].set_title('preprocessed input MR')


fig.tight_layout()

## (8) Running the actual CNN prediction

Once the data is read and preprocessed, we can run the actual prediction.
The input to the pyapetnet models is a Python list containing two "tensors" (the preprocessed PET and MR volumes). The dimensions of both tensors are (1,n0,n1,n2,1) where n0,n1,n2 are the spatial dimensions of the pre-processed volumes. The leftmost dimension is the batch size (1 in our case) and the rightmost dimension is the number of input channels / features (1 in our case).

We decided to input two (1,n0,n1,n2,1) tensors instead of one (1,n0,n1,n2,2) tensor, since in the first layers we decided to learn separate PET and MR features.

Based on the design of the model, there is no restriction on the spatial input shape (n0,n1,n2) provided that enough GPU/CPU memory is available.

Using a recent Nvidia GPU, this step should take roughly 1s.
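
The tensor shapes described above can be checked with plain NumPy before any model is involved. In this sketch the helper `to_model_input` and the small zero-filled volumes are assumptions, and a dummy array stands in for the model output.

```python
import numpy as np

def to_model_input(vol):
    """Add a leading batch axis and a trailing channel axis:
    (n0, n1, n2) -> (1, n0, n1, n2, 1) (hypothetical helper)."""
    return vol[np.newaxis, ..., np.newaxis]

# small stand-ins for the pre-processed PET and MR volumes
demo_pet = np.zeros((8, 9, 10), dtype=np.float32)
demo_mr = np.ones((8, 9, 10), dtype=np.float32)

# the model input is a list of two separate single-channel tensors
demo_x = [to_model_input(demo_pet), to_model_input(demo_mr)]
print([t.shape for t in demo_x])

# dummy array with the shape of a prediction for one batch element;
# squeeze() removes the batch and channel axes again
demo_out = np.zeros((1, 8, 9, 10, 1), dtype=np.float32).squeeze()
print(demo_out.shape)
```

Note that `squeeze()` removes every axis of length 1, so a volume with a spatial dimension of 1 would also lose that axis. Indexing with `[0, ..., 0]` is a stricter alternative.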

In [None]:
# the actual CNN prediction
x = [np.expand_dims(np.expand_dims(pet_preproc,0),-1), np.expand_dims(np.expand_dims(mr_preproc,0),-1)]
pred = model.predict(x).squeeze()

## (9) Undo the intensity normalization

We undo the intensity normalization that was applied during pre-processing.

In [None]:
pred *= pet_max

## (10) Save the volumes

We save the pre-processed volumes and the prediction to nifti files.

In [None]:
nib.save(nib.Nifti1Image(pet_preproc, o_aff), 'pet_preproc.nii')
nib.save(nib.Nifti1Image(mr_preproc, o_aff), 'mr_preproc.nii')
nib.save(nib.Nifti1Image(pred, o_aff), output_fname)

## (11) Display the input and the prediction

Finally we display the results.


In [None]:
fig, ax = plt.subplots(3,3, figsize = (9,9))
ax[0,0].imshow(pet_preproc[:,::-1,pet_preproc.shape[2]//2].T, cmap = plt.cm.Greys, vmax = 1)
ax[0,1].imshow(pet_preproc[:,pet_preproc.shape[1]//2,::-1].T, cmap = plt.cm.Greys, vmax = 1)
ax[0,2].imshow(pet_preproc[pet_preproc.shape[0]//2,:,::-1].T, cmap = plt.cm.Greys, vmax = 1)
ax[1,0].imshow(mr_preproc[:,::-1,pet_preproc.shape[2]//2].T, cmap = plt.cm.Greys_r, vmax = 1)
ax[1,1].imshow(mr_preproc[:,pet_preproc.shape[1]//2,::-1].T, cmap = plt.cm.Greys_r, vmax = 1)
ax[1,2].imshow(mr_preproc[pet_preproc.shape[0]//2,:,::-1].T, cmap = plt.cm.Greys_r, vmax = 1)
ax[2,0].imshow(pred[:,::-1,pet_preproc.shape[2]//2].T, cmap = plt.cm.Greys, vmax = pet_max)
ax[2,1].imshow(pred[:,pet_preproc.shape[1]//2,::-1].T, cmap = plt.cm.Greys, vmax = pet_max)
ax[2,2].imshow(pred[pet_preproc.shape[0]//2,:,::-1].T, cmap = plt.cm.Greys, vmax = pet_max)
for axx in ax.flatten(): axx.set_axis_off()

ax[0,1].set_title('pre-processed input PET')
ax[1,1].set_title('pre-processed input MR')
ax[2,1].set_title('predicted MAP Bowsher')

fig.tight_layout()