# Keras vs. ONNX model output comparison

This notebook generates predictions for the `sct_example_data` and `sct_testing_data` datasets using both the Keras and ONNX deepseg models. It then compares the two sets of predictions to confirm that the ONNX models reproduce the Keras outputs exactly.

The steps used are as follows:

### 1. Generating Keras predictions from `master`

> Note: To save time, you can instead extract the sample Keras predictions from this [data_keras.zip](https://github.com/spinalcordtoolbox/spinalcordtoolbox/files/8399211/data_keras.zip) into your `$SCT_DIR/data` folder. Then, you can skip straight to generating the ONNX predictions.

1. Install SCT using the `master` branch.
    - Both Keras and Tensorflow should be installed in the `venv_sct` environment.
2. In the first cell, change `dataset_name` to `"sct_example_data"`. Then, go to Kernel -> Restart & Run All.
    - The second cell should output `"Using keras models on b'master\n' branch..."`
    - 16 files ending in `_keras.nii.gz` should be generated. (2 dmri, 3 t1, 5 t2, and 6 t2s)
3. In the first cell, change `dataset_name` to `"sct_testing_data"`. Then, go to Kernel -> Restart & Run All.
    - The same output files should be generated, just for a different dataset.
 
### 2. Generating ONNX predictions from `jn/3735` + comparing Keras/ONNX 

> Note: To save time, you can instead extract the sample ONNX predictions from this [data_onnx.zip](https://github.com/spinalcordtoolbox/spinalcordtoolbox/files/8399209/data_onnx.zip) into your `$SCT_DIR/data` folder. Then, you only need to run the first and last cells of this notebook, rather than all of the cells.

1. Switch to the `jn/3735-replace-tensorflow-with-onnxruntime` branch and install SCT.
   - Neither Keras nor Tensorflow should be installed in the `venv_sct` environment.
   - Note: It may be less of a hassle to keep two separate installations (e.g. in two separate working environments), rather than overwriting the existing installation.
2. Ensure that you have the previously-mentioned `_keras.nii.gz` files in your `data/sct_example_data/` and `data/sct_testing_data/` folders.
3. In the first cell, change `dataset_name` to `"sct_example_data"`. Then, go to Kernel -> Restart & Run All.
    - The second cell should output `"Using onnx models on b'jn/3735-replace-tensorflow-with-onnxruntime\n' branch..."`
    - 16 files ending in `_onnx.nii.gz` should be generated. (2 dmri, 3 t1, 5 t2, and 6 t2s)
    - The last cell of the notebook checks these files against the `_keras.nii.gz` files and prints `True` for each file whose data is identical.
4. In the first cell, change `dataset_name` to `"sct_testing_data"`. Then, go to Kernel -> Restart & Run All.
    - The same output files should be generated, just for a different dataset.
    - The last cell of the notebook checks these files against the `_keras.nii.gz` files in the same way.
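
The expected file counts listed in the steps above can be checked with a short script. The following is a minimal sketch and not part of the notebook itself: `count_suffix_files` is a hypothetical helper, and the folder passed to it is assumed to be the dataset folder (e.g. `$SCT_DIR/data/sct_example_data`).

```python
import os
from collections import Counter


def count_suffix_files(path_dataset, suffix):
    """Count files ending in `_{suffix}.nii.gz`, grouped by subfolder.

    `suffix` is expected to be "keras" or "onnx" (the suffixes used in this notebook).
    """
    ending = f"_{suffix}.nii.gz"
    counts = Counter()
    for dirpath, _dirnames, filenames in os.walk(path_dataset):
        for fname in filenames:
            if fname.endswith(ending):
                # Key each count by the folder path relative to the dataset root
                counts[os.path.relpath(dirpath, path_dataset)] += 1
    return counts
```

If the steps above completed as described, calling this helper with `"keras"` or `"onnx"` should give counts of 2 for `dmri`, 3 for `t1`, 5 for `t2`, and 6 for `t2s`.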

### Dataset setup

In [None]:
from spinalcordtoolbox.scripts import sct_download_data
from spinalcordtoolbox.utils import __data_dir__
import shutil
import os

dataset_name = "sct_example_data"

sct_download_data.main(["-d", dataset_name, "-k"])

# In sct_testing_data, for some reason the T1-weighted file is called 't1w.nii.gz' (instead of 't1.nii.gz', as it
# is in sct_example_data). 't1.nii.gz' is much easier to work with, because 't1' is used for both the folder name
# and the `-c` contrast. So, we generate a copy called 't1.nii.gz'.
if dataset_name == "sct_testing_data":
    shutil.copyfile(os.path.join(__data_dir__, dataset_name, "t1", "t1w.nii.gz"),
                    os.path.join(__data_dir__, dataset_name, "t1", "t1.nii.gz"))

### Git branch setup

In [None]:
import subprocess

branch = subprocess.run("git rev-parse --abbrev-ref HEAD", capture_output=True, shell=True).stdout
if b"jn/3735-replace-tensorflow-with-onnxruntime" in branch:
    suffix = "onnx"
else:
    suffix = "keras"
    
print(f"Using {suffix} models on {branch} branch...")
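
The suffix selection in the cell above reduces to a substring test on the branch name. A minimal sketch of the same logic as a standalone function follows; `suffix_for_branch` and `ONNX_BRANCH` are hypothetical names, and the function is assumed to receive either the raw bytes from `subprocess.run(...).stdout` or an already-decoded string.

```python
ONNX_BRANCH = "jn/3735-replace-tensorflow-with-onnxruntime"


def suffix_for_branch(branch):
    """Return "onnx" on the ONNX branch and "keras" on any other branch."""
    if isinstance(branch, bytes):
        # `capture_output=True` without `text=True` yields bytes, so decode first
        branch = branch.decode()
    return "onnx" if ONNX_BRANCH in branch.strip() else "keras"
```

Keeping the decision in a pure function makes it easy to check without a git checkout. Note that the notebook prints the raw bytes, which is why the expected output in the steps above includes `b'...\n'`.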

### Inference using sct_deepseg_sc models

In [None]:
import os

import onnxruntime as ort

from spinalcordtoolbox.scripts import sct_deepseg_sc, sct_dmri_separate_b0_and_dwi
from spinalcordtoolbox.utils import __data_dir__, sct_dir_local_path
from spinalcordtoolbox.image import Image
from spinalcordtoolbox.deepseg_sc.core import heatmap
from spinalcordtoolbox.resampling import resample_nib

for contrast in ["dwi", "t1", "t2", "t2s"]:
    # 1. Isolate spinal cord segmentation (to evaluate 2D/3D spinal cord segmentation models)
    if contrast == "dwi":
        # "dwi" is handled separately because it has some quirks:
        #   - no '3d kernel' model
        #   - folder name ("dmri") is different from contrast name ("dwi")
        #   - requires preprocessing (4d image -> 3d image)
        path_data = os.path.join(__data_dir__, dataset_name, "dmri")
        sct_dmri_separate_b0_and_dwi.main(["-i", os.path.join(path_data, "dmri.nii.gz"),
                                           "-bvec", os.path.join(path_data, "bvecs.txt"),
                                           "-ofolder", path_data])
        path_in = os.path.join(path_data, "dmri_dwi_mean.nii.gz")
        path_out = os.path.join(path_data, f"{contrast}_seg_2d_{suffix}.nii.gz")
        sct_deepseg_sc.main(["-i", path_in, "-c", contrast, "-kernel", "2d", 
                             "-centerline", "cnn", "-ofolder", path_data, "-o", path_out])
    else:
        # All other contrasts can be processed using the same steps
        path_data = os.path.join(__data_dir__, dataset_name, contrast)
        path_in = os.path.join(path_data, f"{contrast}.nii.gz")
        for kernel_type in ["2d", "3d"]:
            path_out = os.path.join(path_data, f"{contrast}_seg_{kernel_type}_{suffix}.nii.gz")
            sct_deepseg_sc.main(["-i", path_in, "-c", contrast, "-kernel", kernel_type, 
                                 "-centerline", "cnn", "-ofolder", path_data, "-o", path_out])
    
    # 2. Isolate centerline heatmaps (to evaluate centerline detection models)
    # - NB: This section involves quite a lot of copying and pasting from SCT's source code,
    #       in order to call the `heatmap()` function and access the 'im_heatmap' variable directly.
    im = Image(path_in)
    im.change_orientation('RPI')
    im_res = resample_nib(im, new_size=[0.5, 0.5, im.dim[6]], new_size_type='mm', interpolation='linear')
    dct_patch_ctr = {'t2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408},
                     't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659},
                     't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149},
                     'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003}}
    dct_params_ctr = {'t2': {'features': 16, 'dilation_layers': 2},
                      't2s': {'features': 8, 'dilation_layers': 3},
                      't1': {'features': 24, 'dilation_layers': 3},
                      'dwi': {'features': 8, 'dilation_layers': 2}}
    if suffix == "onnx":
        ctr_model_fname = sct_dir_local_path('data', 'deepseg_sc_models', '{}_ctr.onnx'.format(contrast))
        ort_sess = ort.InferenceSession(ctr_model_fname)
        im_heatmap, z_max = heatmap(im=im_res,
                                    contrast_type=contrast,
                                    model=ort_sess,
                                    patch_shape=dct_patch_ctr[contrast]['size'],
                                    mean_train=dct_patch_ctr[contrast]['mean'],
                                    std_train=dct_patch_ctr[contrast]['std'])
    else:
        from spinalcordtoolbox.deepseg_sc.cnn_models import nn_architecture_ctr
        ctr_model_fname = sct_dir_local_path('data', 'deepseg_sc_models', '{}_ctr.h5'.format(contrast))
        ctr_model = nn_architecture_ctr(height=dct_patch_ctr[contrast]['size'][0],
                                        width=dct_patch_ctr[contrast]['size'][1],
                                        channels=1,
                                        classes=1,
                                        features=dct_params_ctr[contrast]['features'],
                                        depth=2,
                                        temperature=1.0,
                                        padding='same',
                                        batchnorm=True,
                                        dropout=0.0,
                                        dilation_layers=dct_params_ctr[contrast]['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)
        im_heatmap, z_max = heatmap(im=im_res,
                                    model=ctr_model,
                                    patch_shape=dct_patch_ctr[contrast]['size'],
                                    mean_train=dct_patch_ctr[contrast]['mean'],
                                    std_train=dct_patch_ctr[contrast]['std'])
    im_heatmap.save(os.path.join(path_data, f"{contrast}_ctr_heatmap_{suffix}.nii.gz"))

### Inference using sct_deepseg_gm models

In [None]:
import os

from spinalcordtoolbox.scripts import sct_deepseg_gm

# NB: This code is the same regardless of whether the model is Keras or ONNX,
#     because the branch changes will take care of everything.
for contrast in ["t2s"]:
    for model in ["large", "challenge"]:
        path_data = os.path.join(__data_dir__, dataset_name, contrast)
        path_in = os.path.join(path_data, f"{contrast}.nii.gz")
        path_out = os.path.join(path_data, f"{contrast}_seg_gm_{model}_{suffix}.nii.gz")
        sct_deepseg_gm.main(['-i', path_in, '-m', model, '-o', path_out])

### Inference using sct_deepseg_lesion models

In [None]:
import os

import numpy as np
import nibabel as nib

from spinalcordtoolbox.scripts import sct_deepseg_lesion
from spinalcordtoolbox.deepseg_lesion.core import segment_3d
from spinalcordtoolbox.utils import __sct_dir__

for contrast in ["t2", "t2_ax", "t2s"]:
    if contrast == "t2s":
        path_data = os.path.join(__data_dir__, dataset_name, "t2s")
    elif contrast in ["t2", "t2_ax"]:
        path_data = os.path.join(__data_dir__, dataset_name, "t2")
    path_in = os.path.join(path_data, f"{contrast}_lesion.nii.gz")
    path_out = os.path.join(path_data, f"{contrast}_seg_lesion_{suffix}.nii.gz")
    
    # create fake data containing:
    # - Spinal cord voxels
    # - CSF voxels
    # - A fake lesion
    data = np.zeros((48, 48, 96))
    xx, yy = np.mgrid[:48, :48]
    circle = (xx - 24) ** 2 + (yy - 24) ** 2
    # iterating slice by slice in z axis to create a cylindrical-shaped CSF + SC
    for zz in range(data.shape[2]): 
        data[:,:,zz] += np.logical_and(circle < 400, circle >= 200) * 2400  # intensity = CSF
        data[:,:,zz] += (circle < 200)                              * 500   # intensity = SC
    data[16:22, 16:22, 64:90]                                       = 1000  # intensity = fake lesion
    # NB: While these shapes and contrast values are appropriate for the "t2" contrast,
    #     they likely aren't appropriate for "t2_ax" (b/c orientation) and "t2s" (b/c intensity values)
    #     So, this step should be amended for those contrast options.
    
    # create image using data and save to file
    affine = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())
    img.save(path_in)
    
    # NB: We can't run the full sct_deepseg_lesion and compare the outputs unless we create 
    #     more realistic lesion data, as the CLI script uses additional steps (e.g. resampling).
    #     Otherwise, we would run this command instead:
    # sct_deepseg_lesion.main(["-i", path_in, "-c", contrast, "-ofolder", path_data])
    
    # So instead, we segment fake lesion data by calling `segment_3d()` directly
    if suffix == "onnx":
        model_path = os.path.join(__sct_dir__, 'data', 'deepseg_lesion_models', f'{contrast}_lesion.onnx')
    else:
        model_path = os.path.join(__sct_dir__, 'data', 'deepseg_lesion_models', f'{contrast}_lesion.h5')
    seg = segment_3d(model_path, contrast, img.copy())
    seg.save(path_out)

### Comparing Keras-generated files to ONNX-generated files

In [None]:
import os

from spinalcordtoolbox.image import Image

path_data = os.path.join(__data_dir__, dataset_name)
data_files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(path_data) for f in filenames]

for filepath in data_files:
    if "_onnx" in filepath:
        keras_filepath = filepath.replace("_onnx.nii.gz", "_keras.nii.gz")
        if not os.path.isfile(keras_filepath):
            print(f"{filepath} present, but missing corresponding {keras_filepath}!")
        else:
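            data_onnx = Image(filepath).data
            data_keras = Image(keras_filepath).data
            # Compare shapes first, so that a shape mismatch is reported as False
            identical = data_onnx.shape == data_keras.shape and bool((data_onnx == data_keras).all())
            print(f"Checking {filepath}... {identical}")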
            
    elif "_keras" in filepath:
        onnx_filepath = filepath.replace("_keras.nii.gz", "_onnx.nii.gz")
        if not os.path.isfile(onnx_filepath):
            print(f"{filepath} present, but missing corresponding {onnx_filepath}!")
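
The file-pairing logic in the cell above can be separated from the voxel comparison. The sketch below is an illustration rather than part of the notebook: `pair_predictions` is a hypothetical helper that walks a dataset folder and returns the matched Keras/ONNX pairs, along with any prediction files that lack a counterpart.

```python
import os

KERAS_END = "_keras.nii.gz"
ONNX_END = "_onnx.nii.gz"


def pair_predictions(path_dataset):
    """Pair `_onnx.nii.gz` files with their `_keras.nii.gz` counterparts.

    Returns (pairs, unmatched): a list of (keras_path, onnx_path) tuples,
    and a list of prediction files whose counterpart is missing.
    """
    pairs, unmatched = [], []
    for dirpath, _dirnames, filenames in os.walk(path_dataset):
        present = set(filenames)
        for fname in sorted(filenames):
            if fname.endswith(ONNX_END):
                keras_name = fname[:-len(ONNX_END)] + KERAS_END
                if keras_name in present:
                    pairs.append((os.path.join(dirpath, keras_name),
                                  os.path.join(dirpath, fname)))
                else:
                    unmatched.append(os.path.join(dirpath, fname))
            elif fname.endswith(KERAS_END):
                onnx_name = fname[:-len(KERAS_END)] + ONNX_END
                if onnx_name not in present:
                    unmatched.append(os.path.join(dirpath, fname))
    return pairs, unmatched
```

Each returned pair could then be loaded with `Image(...)` and compared as in the cell above, while the `unmatched` list gives a single summary of missing files.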