<a href="https://colab.research.google.com/github/stratis-forge/radiomics-workflows/blob/main/demo_OMT_distance_between_dose_distributions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction
This notebook demonstrates how to compute the optimal mass transport (OMT) distance between two radiotherapy dose distributions using CERR.

## Requirements
* GNU Octave with `statistics`, `io`, and `image` packages
* CERR
* Python packages for visualization (matplotlib, ipywidgets, scikit-image, NumPy)

Note: Installing these tools may incur a one-time extra runtime overhead.

## I/O
In this example, we use a sample plan stored in CERR's native `planC` format (originally imported from DICOM) and compute OMT distances between its dose distributions.

----


## Install dependencies


### Install the latest compiled Octave build



In [1]:
# %%capture
# ! apt-get update
# ! apt-get install libgraphicsmagick++1-dev libsuitesparse-dev libqrupdate1 \
# libfftw3-3 gnuplot zsh openjdk-8-jdk

In [2]:
# Download latest compiled octave package 
# def get_octave(root_path):
#   os.chdir(root_path)
#   with urllib.request.urlopen("https://api.github.com/repos/cerr/octave-colab/releases/latest") as url:
#       data = json.loads(url.read().decode())
#   fname = data['assets'][0]['name']
#   requrl = data['assets'][0]['browser_download_url']
#   urllib.request.urlretrieve(requrl, fname)
#   # Unzip, identify octave folder name
#   !tar xf {fname}
#   top_folder = !tar tf {fname} | head -1
#   octave_folder = top_folder[0][:-1]
#   octave_path = os.path.join(root_path,octave_folder)
#   return octave_path

In [3]:
# Point oct2py at the Octave installation.
# No %%capture on this cell: it would also swallow the traceback if the cell
# fails, silently leaving OCTAVE_EXECUTABLE unset for the rest of the notebook.
import os

# On Colab, set USE_COLAB_OCTAVE = True and uncomment the get_octave cell
# above to download a prebuilt Octave into /content. Otherwise point
# OCTAVE_PATH at the root of an existing install (the folder containing
# bin/octave-cli); the '/usr' fallback assumes a system-wide install -- adjust
# as needed.
USE_COLAB_OCTAVE = False
if USE_COLAB_OCTAVE:
    import urllib.request, json
    octave_path = get_octave('/content')
else:
    octave_path = os.environ.get('OCTAVE_PATH', '/usr')

octave_bin = os.path.join(octave_path, 'bin')
os.environ['OCTAVE_EXECUTABLE'] = os.path.join(octave_bin, 'octave-cli')
# Prepend Octave's bin dir; skip an empty PATH so we never add a trailing
# separator (which some shells treat as "current directory").
os.environ['PATH'] = os.pathsep.join(
    p for p in (octave_bin, os.environ.get('PATH', '')) if p)

### Install Python-Octave bridge

In [4]:
%%capture
# ! pip3 install octave_kernel
# ! pip3 install oct2py==5.3.0

### Download CERR



In [5]:
%%capture
# !git clone --single-branch --branch octave_dev https://www.github.com/cerr/CERR.git
# import os   
# currDir = os.getcwd()
# os.chdir("/content/CERR")
# !git checkout c2b65179da40622bc7b095f679edd17f5ebc681e
# os.chdir(currDir)

## Sample OMT distance calculations

### Load required Octave packages

In [6]:
%load_ext oct2py.ipython
from oct2py import octave

In [7]:
%%capture
%%octave
# Load the Octave Forge packages listed under Requirements (used by CERR)
pkg load image
pkg load io
pkg load statistics

In [8]:
#Load sample data
octave.addpath(octave.genpath('/content/CERR'))

sampleData = '/content/CERR/Unit_Testing/data_for_cerr_tests/CERR_plans/' + \
             'lung_ex1_20may03.mat.bz2'

%octave_push sampleData

planC = octave.loadPlanC(sampleData,octave.tempdir());
planC = octave.updatePlanFields(planC);
planC = octave.quality_assure_planC(sampleData,planC);

**Example 1. Sanity check (distance between identical distributions).** The distance between a dose distribution and itself should be zero, up to solver tolerance.

In [9]:
# Compare dose 1 with itself inside structure 3 (GTV): expect a value near 0.
doseNum1 = 1
doseNum2 = 1
structNum = 3         # GTV
gamma = 0.1
downsampleIndex = 3

# gamma and downsampleIndex are forwarded unchanged to calcOMTDoseDistance.
dist = octave.calcOMTDoseDistance(doseNum1, doseNum2, structNum,
                                  gamma, downsampleIndex, planC)
print('OMT distance =', dist)

Downsampling...
3
done downsampling.
[KKT error,dual error,cost,1/(||lambda||_infy+1),step_size,lspar3]=
[0.001034, 0.128902, 0.000002, 0.998856, 0.000001, 0.750000]
Elapsed time is 6.92625 seconds.
layer 1 Computed cost is: 0.000002, iteration number=11
[KKT error,dual error,cost,1/(||lambda||_infy+1),step_size,lspar3]=
[0.110454, 7.735743, 0.000027, 0.951295, 0.000001, 1.125000]
Elapsed time is 103.029 seconds.
layer 2 Computed cost is: 0.000027, iteration number=6
OMT distance = 2.6897258431798253e-05


**Example 2. Distance between two dose distributions in region of interest**




In [10]:
# Compare dose 1 against dose 2 within the same structure (GTV).
doseNum1 = 1
doseNum2 = 2
structNum = 3         # GTV
gamma = 0.1
downsampleIndex = 3

# Same solver settings as Example 1 so the two results are comparable.
dist = octave.calcOMTDoseDistance(doseNum1, doseNum2, structNum,
                                  gamma, downsampleIndex, planC)
print('OMT distance =', dist)

Downsampling...
3
done downsampling.
[KKT error,dual error,cost,1/(||lambda||_infy+1),step_size,lspar3]=
[0.396779, 0.523316, 0.010361, 0.837719, 0.000001, 0.333333]
Elapsed time is 6.64365 seconds.
layer 1 Computed cost is: 0.010361, iteration number=9
[KKT error,dual error,cost,1/(||lambda||_infy+1),step_size,lspar3]=
[0.009116, 1.054535, 0.010947, 0.967930, 0.000001, 0.222222]
Elapsed time is 129.6 seconds.
layer 2 Computed cost is: 0.010947, iteration number=10
OMT distance = 0.010946641919002586


## Display ROI and dose distributions

In [11]:
# NOTE(review): re-imports the same `octave` object already imported earlier;
# harmless, kept so this section can be run on its own.
from oct2py import octave
# Copy the Python-side plan, structure/dose indices and distance into the
# Octave workspace for the %%octave cell that follows.
%octave_push planC structNum doseNum1 doseNum2 dist

In [12]:
%%octave 

addpath(genpath('/content/CERR'));

# Get scan array.
# Cast to double BEFORE removing CTOffset: if getScanArray returns an integer
# array (e.g. uint16 -- TODO confirm for this plan), Octave integer arithmetic
# saturates, so subtracting first would clamp every voxel below the offset
# (all negative HU such as lung and air) to 0.
indexS = planC{end};
scanNum = getStructureAssociatedScan(structNum,planC);
scan3M = getScanArray(scanNum,planC);
CToffset = planC{indexS.scan}(1).scanInfo(1).CTOffset;
scan3M = double(scan3M) - CToffset;

# Get dose arrays on CT grid
dose1M = getDoseOnCT(doseNum1, scanNum, 'normal', planC);
dose2M = getDoseOnCT(doseNum2, scanNum, 'normal', planC);

# Crop to the slice range covered by the structure (only the slice bounds of
# the bounding box are needed)
mask3M = getStrMask(structNum,planC);
[~,~,~,~,sMin,sMax] = compute_boundingbox(mask3M);
scan3M = scan3M(:,:,sMin:sMax);
mask3M = mask3M(:,:,sMin:sMax);
dose1M = dose1M(:,:,sMin:sMax);
dose2M = dose2M(:,:,sMin:sMax);

# Both doses, in compared order, packed for transfer to Python
doseC = {dose1M,dose2M};

In [13]:
# Bring the cropped scan, mask, both doses and the indices back into Python.
%octave_pull doseNum1 doseNum2 structNum doseC scan3M mask3M dist

In [14]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import ipywidgets as widgets
import numpy as np
from matplotlib import cm
from IPython.display import clear_output
from functools import partial
from skimage import measure

# Clear this cell's previous output (deferred until new output is produced).
clear_output(wait=True)    

def window_image(image, window_center, window_width):
    """Clamp intensities to a display window.

    Values below ``window_center - window_width // 2`` are raised to that
    bound and values above ``window_center + window_width // 2`` are lowered
    to it. The input is not modified; a clamped copy is returned.
    """
    half = window_width // 2
    lower = window_center - half
    upper = window_center + half

    clamped = image.copy()
    too_low = clamped < lower
    clamped[too_low] = lower
    too_high = clamped > upper
    clamped[too_high] = upper
    return clamped

def show_roi(ax,scanM,maskM):
    """Draw a scan slice in grayscale with the ROI outline in red.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    scanM : 2-D array
        Slice to display (the caller passes it already windowed).
    maskM : 2-D array
        ROI mask for the same slice; its 0.5 iso-contour is outlined.
    """
    # Keep imshow's default pixel-centred extent (row 0 at the top, y axis
    # pointing down) so the (row, col) points from find_contours line up with
    # the image. A fixed extent such as (0, 511, 0, 511) draws image row 0 at
    # y=511 while contour row 0 lands at y=0, mirroring the outline vertically;
    # it also assumed a 512x512 slice.
    ax.imshow(scanM, cmap=plt.cm.gray, interpolation='none')

    for contour in measure.find_contours(maskM, 0.5):
        ax.plot(contour[:, 1], contour[:, 0], linewidth=2, color='red')

    ax.set_xticks([])
    ax.set_yticks([])
    ax.title.set_text('Region of interest')

def show_dose_overlay(scan3M, mask3M, doseNum1, doseNum2, dist, slcNum,
                      dose_clim=(0, 90)):
    """Plot the ROI and both dose distributions for one slice, side by side.

    Draws three panels: the windowed scan slice with the ROI outline, then the
    same slice overlaid with the first and the second dose. Both overlays use
    the same colour limits, so their colours are directly comparable and the
    single colour bar (Gy) applies to both. The OMT distance is written under
    the figure.

    The dose arrays come from the notebook-global ``doseC`` pulled from
    Octave, which holds ``{dose1M, dose2M}`` in that order, cropped to the
    same slices as ``scan3M``. It is assumed to arrive as a 1x2 object array,
    so ``doseC[0][k]`` is the k-th dose -- TODO confirm for the oct2py version
    in use.

    Parameters
    ----------
    scan3M : ndarray
        Cropped scan volume indexed ``[row, col, slice]``.
    mask3M : ndarray
        ROI mask indexed like ``scan3M``.
    doseNum1, doseNum2 : int
        Plan dose numbers that were compared; used for the panel titles.
    dist : float
        OMT distance between the two doses (shown with 4 decimals).
    slcNum : int
        1-based slice number within the cropped volume.
    dose_clim : (float, float) or None, optional
        ``(vmin, vmax)`` colour limits shared by both dose overlays; default
        ``(0, 90)``. ``None`` lets each overlay autoscale independently.
    """
    clear_output(wait=True)
    print('Slice '+str(slcNum))

    # Display settings: scan window, and a jet colour map whose below-range
    # values are rendered fully transparent.
    window_center = 0
    window_width = 300
    jet = plt.cm.jet
    jet_map = [jet(i) for i in range(jet.N)]
    dose_cmap = jet.from_list('Custom cmap', jet_map, jet.N)
    dose_cmap.set_under('k', alpha=0)
    vmin, vmax = dose_clim if dose_clim is not None else (None, None)

    fig, ax = plt.subplots(1, 3)
    fig.set_size_inches(20, 12)

    # Slice numbers are 1-based; numpy indexing is 0-based.
    slc_idx = slcNum - 1
    windowed_img = window_image(scan3M[:, :, slc_idx],
                                window_center, window_width)

    # Panel 1: region of interest.
    show_roi(ax[0], windowed_img, mask3M[:, :, slc_idx])

    def _overlay(panel_ax, doseM, title):
        # Scan slice with one dose slice blended on top; the dose image is
        # returned so it can drive the colour bar.
        panel_ax.imshow(windowed_img, cmap=plt.cm.gray, interpolation='none')
        dose_img = panel_ax.imshow(doseM[:, :, slc_idx], cmap=dose_cmap,
                                   alpha=.4, interpolation='none',
                                   vmin=vmin, vmax=vmax)
        panel_ax.set_xticks([])
        panel_ax.set_yticks([])
        panel_ax.title.set_text(title)
        return dose_img

    # Panels 2-3: doseC is ordered {dose1, dose2}, so index by position.
    # Indexing by plan dose number only works by coincidence for some pairs
    # and otherwise swaps the panels or raises IndexError.
    _overlay(ax[1], doseC[0][0], 'Dose ' + str(doseNum1))
    d2 = _overlay(ax[2], doseC[0][1], 'Dose ' + str(doseNum2))

    cax = fig.add_axes([0.95, 0.32, 0.03, 0.36])
    clb = fig.colorbar(d2, cax=cax)
    clb.ax.set_title('Gy')

    fig.subplots_adjust(wspace=0.3)
    txt = 'OMT distance = {:.4f}'.format(dist)
    fig.text(.66, .25, txt, fontsize=14, fontweight='bold', ha='center')

    plt.show()

# Slider over the cropped slices (1-based, as show_dose_overlay expects).
# The range comes from the data: a fixed max of 20 raises IndexError when the
# ROI spans fewer slices and makes slices beyond 20 unreachable.
num_slices = scan3M.shape[2]
slice_slider = widgets.IntSlider(value=1, min=1, max=num_slices, step=1)
outputSlc = widgets.Output()
display(slice_slider, outputSlc) 

def update_display(change):
    """Redraw the dose overlay for the slice selected on the slider.

    ``change`` is the ipywidgets change notification; ``change['new']`` holds
    the newly selected (1-based) slice number.
    """
    # Only reads notebook globals, so no ``global`` declaration is needed.
    with outputSlc:
        show_dose_overlay(scan3M, mask3M, int(doseNum1), int(doseNum2),
                          dist, change['new'])

slice_slider.observe(update_display, names='value')

# observe() only fires on changes, so draw the initial slice explicitly.
update_display({'new': slice_slider.value})

IntSlider(value=1, max=20, min=1)

Output()