# Keras vs. ONNX model output comparison

This notebook needs to be run separately on two different installations of SCT:

1. Installed from the `master` branch (with TensorFlow/Keras installed)
2. Installed from the `jn/3735-replace-tensorflow-with-onnxruntime` branch (with no TF/Keras installed)
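What differs between the two installations is the inference call. A Keras model is called with `model.predict(x)`. An ONNX Runtime `InferenceSession` is called with `session.run(output_names, input_feed)`, where `input_feed` maps input names (from `session.get_inputs()`) to arrays. The sketch below shows a hypothetical `OnnxPredictor` wrapper that gives a session a Keras-style `predict` method. It only illustrates the API difference and is not necessarily how SCT wires the models internally. A stand-in `_FakeSession` (with an assumed input name `input_1`) lets it run without onnxruntime or a model file.

```python
import numpy as np


class OnnxPredictor:
    """Hypothetical wrapper giving an ONNX Runtime session a Keras-style predict()."""

    def __init__(self, session):
        # A real session would come from onnxruntime.InferenceSession(path_to_model)
        self.session = session
        self.input_name = session.get_inputs()[0].name

    def predict(self, x):
        # run(None, ...) returns a list with every model output; keep the first one
        outputs = self.session.run(None, {self.input_name: x.astype(np.float32)})
        return outputs[0]


class _FakeInput:
    def __init__(self, name):
        self.name = name


class _FakeSession:
    """Stand-in mimicking the two InferenceSession methods used above."""

    def get_inputs(self):
        return [_FakeInput("input_1")]

    def run(self, output_names, input_feed):
        x = input_feed["input_1"]
        return [1.0 / (1.0 + np.exp(-x))]  # pretend the model is a sigmoid


predictor = OnnxPredictor(_FakeSession())
probs = predictor.predict(np.zeros((1, 80, 80, 1)))
print(probs.shape, probs.dtype, float(probs.max()))
```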

In [1]:
from spinalcordtoolbox.scripts import sct_download_data

sct_download_data.main(["-d", "sct_example_data", "-k"])

Status: 100%|██████████| 44.3M/44.3M [00:07<00:00, 5.55MB/s]


Done!



In [14]:
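import subprocess

# Tag output files by branch: "onnx" on the ONNX Runtime branch, "keras" otherwise.
# The command is passed as a list so that it runs without shell=True on every OS.
branch = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"],
                        capture_output=True, text=True).stdout.strip()
if branch == "jn/3735-replace-tensorflow-with-onnxruntime":
    suffix = "onnx"
else:
    suffix = "keras"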
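A more defensive variant of the branch check is sketched below. The helper names `current_branch` and `pick_suffix` are hypothetical. If git is missing or the working directory is not a repository, the sketch falls back on whether TensorFlow is importable. That fallback is an assumption that only holds if, as described at the top, the ONNX installation has no TensorFlow.

```python
import importlib.util
import subprocess

ONNX_BRANCH = "jn/3735-replace-tensorflow-with-onnxruntime"


def current_branch():
    """Return the checked-out git branch, or None if git or the repo is unavailable."""
    try:
        result = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"],
                                capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return None
    return result.stdout.strip()


def pick_suffix(branch=None, tensorflow_installed=None):
    """Choose the output-file suffix ("onnx" or "keras").

    Prefer the branch name; otherwise fall back on whether TensorFlow is importable,
    assuming (as this notebook does) that the ONNX installation has no TensorFlow.
    """
    if branch is not None:
        return "onnx" if branch == ONNX_BRANCH else "keras"
    if tensorflow_installed is None:
        tensorflow_installed = importlib.util.find_spec("tensorflow") is not None
    return "keras" if tensorflow_installed else "onnx"


suffix = pick_suffix(current_branch())
print(suffix)
```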

### sct_deepseg_sc models

In [3]:
import os

import onnxruntime as ort

from spinalcordtoolbox.scripts import sct_deepseg_sc, sct_dmri_separate_b0_and_dwi
from spinalcordtoolbox.utils import __data_dir__, sct_dir_local_path
from spinalcordtoolbox.image import Image
from spinalcordtoolbox.deepseg_sc.core import heatmap
from spinalcordtoolbox.resampling import resample_nib

for contrast in ["dwi", "t1", "t2", "t2s"]:
    # 1. Isolate spinal cord segmentation (to evaluate 2D/3D spinal cord segmentation models)
    if contrast == "dwi":
        # "dwi" is done separately because it has some quirks:
        #   - no '3d kernel' model
        #   - folder name ("dmri") is different from contrast name ("dwi")
        #   - requires preprocessing (4d image -> 3d image)
        path_data = os.path.join(__data_dir__, "sct_example_data", "dmri")
        sct_dmri_separate_b0_and_dwi.main(["-i", os.path.join(path_data, "dmri.nii.gz"),
                                           "-bvec", os.path.join(path_data, "bvecs.txt"),
                                           "-ofolder", path_data])
        path_in = os.path.join(path_data, "dmri_dwi_mean.nii.gz")
        path_out = os.path.join(path_data, f"{contrast}_seg_2d_{suffix}.nii.gz")
        sct_deepseg_sc.main(["-i", path_in, "-c", contrast, "-kernel", "2d", 
                             "-centerline", "cnn", "-ofolder", path_data, "-o", path_out])
    else:
        path_data = os.path.join(__data_dir__, "sct_example_data", contrast)
        path_in = os.path.join(path_data, f"{contrast}.nii.gz")
        for kernel_type in ["2d", "3d"]:
            path_out = os.path.join(path_data, f"{contrast}_seg_{kernel_type}_{suffix}.nii.gz")
            sct_deepseg_sc.main(["-i", path_in, "-c", contrast, "-kernel", kernel_type, 
                                 "-centerline", "cnn", "-ofolder", path_data, "-o", path_out])
    
    # 2. Isolate centerline heatmaps (to evaluate centerline detection models)
    # This part reproduces some of sct_deepseg_sc's internal steps so that im_heatmap can be accessed directly
    im = Image(path_in)
    im_res = resample_nib(im, new_size=[0.5, 0.5, im.dim[6]], new_size_type='mm', interpolation='linear')
    ctr_model_fname = sct_dir_local_path('data', 'deepseg_sc_models', '{}_ctr.onnx'.format(contrast))
    ort_sess = ort.InferenceSession(ctr_model_fname)
    dct_patch_ctr = {'t2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408},
                     't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659},
                     't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149},
                     'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003}}
    im_heatmap, z_max = heatmap(im=im_res,
                                contrast_type=contrast,
                                model=ort_sess,
                                patch_shape=dct_patch_ctr[contrast]['size'],
                                mean_train=dct_patch_ctr[contrast]['mean'],
                                std_train=dct_patch_ctr[contrast]['std'])
    im_heatmap.save(os.path.join(path_data, f"{contrast}_ctr_heatmap_{suffix}.nii.gz"))


Input parameters:
  input file ............c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\dmri\dmri.nii.gz
  bvecs file ............c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\dmri\bvecs.txt
  bvals file ............
  average ...............1

Copy files into temporary folder...
cp c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\dmri\bvecs.txt C:\Users\Joshua\AppData\Local\Temp\sct-20220321145834.524073-dmri_separate-cmuqub69\bvecs

Get dimensions data...
.. 96 x 36 x 15 x 35


Identify b=0 and DWI images...
  Transpose bvecs...
  Number of b=0: 5 [0, 31, 32, 33, 34]
  Number of DWI: 30 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]

Split along T dimension...

Merge b=0...

Average b=0...

Average DWI...

Generate output files...


To avoid intensity overflow due to convertion to +uint8+, intensity will be rescaled to the maximum quantization scale



Remove temporary files...
rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321145834.524073-dmri_separate-cmuqub69

Finished! Elapsed time: 1s
rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321145835.438370-cmt8dw_5


Compute shape analysis: 100%|################| 15/15 [00:00<00:00, 286.30iter/s]
To avoid intensity overflow due to convertion to +uint8+, intensity will be rescaled to the maximum quantization scale


rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321145835.332108-yua4n3u4


To avoid intensity overflow due to convertion to +uint8+, intensity will be rescaled to the maximum quantization scale


rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321145902.874241-dd8ujo6q


Compute shape analysis: 100%|##############| 175/175 [00:00<00:00, 183.89iter/s]
Found isolated voxels on slice 0, Removing them


rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321145836.378817-qcq5i37g


To avoid intensity overflow due to convertion to +uint8+, intensity will be rescaled to the maximum quantization scale


rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321145947.608961-uugvltco


Compute shape analysis: 100%|##############| 196/196 [00:01<00:00, 176.88iter/s]


rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321145921.855755-tkhclfcf


To avoid intensity overflow due to convertion to +uint8+, intensity will be rescaled to the maximum quantization scale
To avoid intensity overflow due to convertion to +uint8+, intensity will be rescaled to the maximum quantization scale


rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321150041.918643-pefm80xf


Filling a hole in the segmentation around z_slice #:6
Filling a hole in the segmentation around z_slice #:14
Filling a hole in the segmentation around z_slice #:17
Filling a hole in the segmentation around z_slice #:31
Filling a hole in the segmentation around z_slice #:35
Filling a hole in the segmentation around z_slice #:39
No properties for slice: 6
No properties for slice: 7
No properties for slice: 8
No properties for slice: 9
No properties for slice: 10
No properties for slice: 11
No properties for slice: 15
No properties for slice: 16
No properties for slice: 18
No properties for slice: 19
No properties for slice: 20
No properties for slice: 21
No properties for slice: 22
No properties for slice: 23
No properties for slice: 24
No properties for slice: 25
No properties for slice: 26
No properties for slice: 27
No properties for slice: 32
No properties for slice: 33
No properties for slice: 36
No properties for slice: 40
No properties for slice: 41
Compute shape analysis: 100%|##

rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321150033.653294-7ofmttbb




rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321150058.409462-zrrfrfou


Compute shape analysis: 100%|##############| 206/206 [00:00<00:00, 221.08iter/s]


rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321150050.320015-ssajy2gw




rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321150120.302605-19sg9s2i


No properties for slice: 1
Compute shape analysis: 100%|################| 27/27 [00:00<00:00, 143.00iter/s]


rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321150118.680486-6iavm3xi




rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321150124.059029-1pijt9oa


Compute shape analysis: 100%|################| 25/25 [00:00<00:00, 134.44iter/s]


rm -rf C:\Users\Joshua\AppData\Local\Temp\sct-20220321150122.405088-48zrwx29
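For `dwi`, the cell first collapses the 4D diffusion series into a 3D image with `sct_dmri_separate_b0_and_dwi`. The log above shows the idea: with no bvals file, volumes are sorted into b=0 and DWI (here 5 b=0 volumes at indices 0 and 31 to 34, and 30 DWI volumes), and the DWI volumes are averaged into `dmri_dwi_mean.nii.gz`. The sketch below approximates that logic with NumPy on a small synthetic array. It assumes that a volume counts as b=0 when its b-vector has (near-)zero norm, with an assumed threshold of 0.01, and that bvecs are shaped `(nt, 3)` or `(3, nt)`. `separate_b0_and_dwi` is a hypothetical function, not SCT's code.

```python
import numpy as np


def separate_b0_and_dwi(data_4d, bvecs, tol=0.01):
    """Hypothetical sketch: split a 4D series into b=0 and DWI means using bvecs.

    data_4d: array of shape (nx, ny, nz, nt)
    bvecs:   array of shape (nt, 3) or (3, nt), one gradient direction per volume
    """
    bvecs = np.asarray(bvecs, dtype=float)
    if bvecs.shape[0] == 3 and bvecs.shape[1] != 3:
        bvecs = bvecs.T  # accept the (3, nt) layout as well
    norms = np.linalg.norm(bvecs, axis=1)
    idx_b0 = np.where(norms < tol)[0]    # (near-)zero gradient -> b=0
    idx_dwi = np.where(norms >= tol)[0]  # everything else -> diffusion-weighted
    b0_mean = data_4d[..., idx_b0].mean(axis=-1)
    dwi_mean = data_4d[..., idx_dwi].mean(axis=-1)  # 4D -> 3D
    return idx_b0.tolist(), idx_dwi.tolist(), b0_mean, dwi_mean


# Synthetic series with the same b=0 layout as the log: 35 volumes, b=0 at 0 and 31..34
nt = 35
bvecs = np.tile([1.0, 0.0, 0.0], (nt, 1))
bvecs[[0, 31, 32, 33, 34]] = 0.0
data = np.where(np.linalg.norm(bvecs, axis=1) == 0, 100.0, 40.0) * np.ones((4, 3, 2, nt))

idx_b0, idx_dwi, b0_mean, dwi_mean = separate_b0_and_dwi(data, bvecs)
print(idx_b0)
print(len(idx_dwi), dwi_mean.shape)
```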



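Before the centerline model runs, the image is resampled with `resample_nib(..., new_size=[0.5, 0.5, im.dim[6]], new_size_type='mm')`. This means 0.5 mm in-plane voxels with the original slice thickness kept (`im.dim[6]` is the z pixel size in SCT's `(nx, ny, nz, nt, px, py, pz, pt)` dimension tuple). The grid arithmetic can be sketched as below. The rounding rule is an assumption, the image size is made up, and `resampled_shape` is a hypothetical helper, so the exact shape SCT produces may differ by a voxel.

```python
def resampled_shape(shape, pixdim, new_pixdim):
    """Approximate grid size after resampling to new voxel sizes (in mm).

    The field of view (n * pixdim) is preserved, so n_new = n * p_old / p_new.
    Rounding to the nearest integer is an assumption made for this sketch.
    """
    return tuple(int(round(n * p / p_new)) for n, p, p_new in zip(shape, pixdim, new_pixdim))


# Example: a hypothetical 60 x 60 x 52 image with 1.0 x 1.0 x 2.5 mm voxels
shape, pixdim = (60, 60, 52), (1.0, 1.0, 2.5)
new_pixdim = (0.5, 0.5, pixdim[2])  # 0.5 mm in-plane, slice thickness unchanged
print(resampled_shape(shape, pixdim, new_pixdim))
```

At 0.5 mm per voxel, the 80 × 80 patch listed in `dct_patch_ctr` spans 40 × 40 mm in-plane.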

### sct_deepseg_gm models

In [4]:
import os

from spinalcordtoolbox.scripts import sct_deepseg_gm
from spinalcordtoolbox.utils import __data_dir__

for contrast in ["t2s"]:
    for model in ["large", "challenge"]:
        path_data = os.path.join(__data_dir__, "sct_example_data", contrast)
        path_in = os.path.join(path_data, f"{contrast}.nii.gz")
        path_out = os.path.join(path_data, f"{contrast}_seg_gm_{model}_{suffix}.nii.gz")
        sct_deepseg_gm.main(['-i', path_in, '-m', model, '-o', path_out])

### sct_deepseg_lesion models

In [6]:
import os

import numpy as np
import nibabel as nib

from spinalcordtoolbox.scripts import sct_deepseg_lesion
from spinalcordtoolbox.deepseg_lesion.core import segment_3d
from spinalcordtoolbox.image import Image
from spinalcordtoolbox.utils import __sct_dir__, __data_dir__

for contrast in ["t2", "t2_ax", "t2s"]:
    if contrast.startswith("t2s"):
        path_data = os.path.join(__data_dir__, "sct_example_data", "t2s")
    else:
        path_data = os.path.join(__data_dir__, "sct_example_data", "t2")
    path_in = os.path.join(path_data, f"{contrast}_lesion.nii.gz")
    path_out = os.path.join(path_data, f"{contrast}_seg_lesion_{suffix}.nii.gz")
    
    # create fake lesion data
    data = np.zeros((48, 48, 96))
    xx, yy = np.mgrid[:48, :48]
    circle = (xx - 24) ** 2 + (yy - 24) ** 2
    for zz in range(data.shape[2]):
        data[:,:,zz] += np.logical_and(circle < 400, circle >= 200) * 2400 # CSF
        data[:,:,zz] += (circle < 200) * 500 # SC
    data[16:22, 16:22, 64:90] = 1000 # fake lesion
    affine = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())
    img.save(path_in)
    
    # Segment fake lesion data
    model_path = os.path.join(__sct_dir__, 'data', 'deepseg_lesion_models', f'{contrast}_lesion.onnx')
    seg = segment_3d(model_path, contrast, img.copy())
    seg.save(path_out)
    
    # NB: We can't compare the output of the full sct_deepseg_lesion unless we
    #     create realistic lesion data, as the full pipeline has additional steps (e.g. resampling)
    # sct_deepseg_lesion.main(["-i", path_in, "-c", contrast, "-ofolder", path_data])
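The synthetic volume built in the cell above is a stack of identical axial slices. Each slice has a cord disk (squared distance from the center below 200, i.e. radius ≈ 14.1 voxels, intensity 500) surrounded by a CSF ring out to radius 20 (intensity 2400). A 6 × 6 × 26 voxel block is then overwritten with 1000 as the "lesion". The sketch below rebuilds the same array with NumPy alone and checks that the lesion block lies entirely inside the cord. `make_phantom` is a hypothetical helper, and the checks are illustrative rather than part of the notebook.

```python
import numpy as np


def make_phantom(shape=(48, 48, 96)):
    """Rebuild the notebook's synthetic volume: CSF ring, cord disk, cubic lesion."""
    data = np.zeros(shape)
    xx, yy = np.mgrid[:shape[0], :shape[1]]
    circle = (xx - 24) ** 2 + (yy - 24) ** 2            # squared distance from center
    csf = np.logical_and(circle < 400, circle >= 200)   # ring, radius ~14.1 to 20
    cord = circle < 200                                 # disk, radius ~14.1
    data += (csf * 2400)[:, :, np.newaxis]              # same slice repeated along z
    data += (cord * 500)[:, :, np.newaxis]
    data[16:22, 16:22, 64:90] = 1000                    # overwrite (not add) the lesion block
    return data, cord


data, cord = make_phantom()
lesion = data == 1000
print(int(lesion.sum()))                  # number of lesion voxels
print(bool(cord[16:22, 16:22].all()))     # is the lesion footprint inside the cord?
print(sorted(np.unique(data).tolist()))   # intensity levels present
```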



### Comparing Keras-generated files to ONNX-generated files

1. Run the above cells on the `master` branch (with Keras/TensorFlow installed), then copy the generated files to a safe location.
2. Re-install and re-run the notebook on the `jn/3735-replace-tensorflow-with-onnxruntime` branch.
3. Copy all of the generated files from both runs so that they exist in the same `data/sct_example_data` folder.
4. Finally, change `if False:` to `if True:` and run the cell below to compare the files from both runs.
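The comparison cell tests for bit-for-bit equality with `(data_onnx == data_keras).all()`. That is the strictest possible check. If it fails for some file, a graded comparison helps separate a genuine discrepancy from floating-point noise between the two runtimes. The sketch below is one way to do that with NumPy. It reports exact equality, `np.allclose` with an assumed absolute tolerance, the maximum absolute difference, and, for binary masks, the Dice overlap. `compare_outputs`, `dice`, and the default `atol` are hypothetical choices, not part of SCT.

```python
import numpy as np


def dice(a, b):
    """Dice coefficient between two binary masks (1.0 when both are empty)."""
    a, b = np.asarray(a, dtype=bool), np.asarray(b, dtype=bool)
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0
    return 2.0 * np.logical_and(a, b).sum() / total


def compare_outputs(data_onnx, data_keras, atol=1e-5):
    """Summarize how closely two model outputs agree."""
    data_onnx = np.asarray(data_onnx, dtype=float)
    data_keras = np.asarray(data_keras, dtype=float)
    if data_onnx.shape != data_keras.shape:
        return {"same_shape": False}
    report = {
        "same_shape": True,
        "exact": bool(np.array_equal(data_onnx, data_keras)),
        "allclose": bool(np.allclose(data_onnx, data_keras, rtol=0.0, atol=atol)),
        "max_abs_diff": float(np.max(np.abs(data_onnx - data_keras))) if data_onnx.size else 0.0,
    }
    # Only compute Dice when both arrays look like binary segmentations
    if set(np.unique(data_onnx)) <= {0.0, 1.0} and set(np.unique(data_keras)) <= {0.0, 1.0}:
        report["dice"] = float(dice(data_onnx, data_keras))
    return report


seg_a = np.zeros((4, 4))
seg_a[1:3, 1:3] = 1
seg_b = seg_a.copy()
seg_b[0, 0] = 1                      # one extra voxel in the second mask
heat_a = np.linspace(0, 1, 16).reshape(4, 4)

print(compare_outputs(seg_a, seg_a))
print(compare_outputs(seg_a, seg_b))
print(compare_outputs(heat_a, heat_a + 1e-7))
```

For segmentations, Dice is usually the more informative number; for heatmaps, `max_abs_diff` shows how far apart the runtimes really are.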

In [13]:
import os

from spinalcordtoolbox.image import Image
from spinalcordtoolbox.utils import __data_dir__

if False:
    path_data = os.path.join(__data_dir__, "sct_example_data")
    data_files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(path_data) for f in filenames]
    for filepath in data_files:
        # Start from each ONNX output and load its Keras counterpart
        if "_onnx" not in filepath:
            continue
        data_onnx = Image(filepath).data
        data_keras = Image(filepath.replace("_onnx.nii.gz", "_keras.nii.gz")).data
        print(f"Checking {filepath}... {(data_onnx == data_keras).all()}")

Checking c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\dmri\dwi_ctr_heatmap_onnx.nii.gz... True
Checking c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\dmri\dwi_seg_2d_onnx.nii.gz... True
Checking c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\t1\t1_ctr_heatmap_onnx.nii.gz... True
Checking c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\t1\t1_seg_2d_onnx.nii.gz... True
Checking c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\t1\t1_seg_3d_onnx.nii.gz... True
Checking c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\t2\t2_ax_seg_lesion_onnx.nii.gz... True
Checking c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\t2\t2_ctr_heatmap_onnx.nii.gz... True
Checking c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\t2\t2_seg_2d_onnx.nii.gz... True
Checking c:\users\joshua\repos\spinalcordtoolbox\data\sct_example_data\t2\t2_seg_3d_onnx.nii.gz... True
Checking c:\users\joshua\repos\spina