In [None]:
from itertools import chain
from pathlib import Path
from timeit import default_timer

import matlab.engine

from skrt import Patient
from skrt.registration import get_default_pfiles_dir, Registration, set_elastix_dir

# Define paths to patient data.
# Each dataset consists of the CT image used in treatment planning,
# and the associated structure set.
# In each structure set, IMNs are outlined on the patient's left or right only.
# The side on which the IMNs are outlined is indicated in the name of the
# directory containing the patient directories.
data_dir = Path("~/data/casscade").expanduser()
# Patient directories: the non-hidden entries (names not starting with ".")
# of every directory matching "casscade*", sorted for a reproducible order.
paths = sorted(data_dir.glob("casscade*/[!.]*"))

# Set path to Elastix installation.
elastix_dir = Path("~/sw/elastix-5.0.1").expanduser()
set_elastix_dir(elastix_dir)

# Obtain path to default Elastix parameter files.
pfiles_dir = get_default_pfiles_dir()

# Indicate whether to check data content (enables the data-check cell).
check_data = False

In [None]:
# Load data: create one Patient per selected path, grouping patients by the
# side on which IMNs are outlined.
patients = {"left": [], "right": []}
ids = []  # when non-empty, only directories with these names are loaded
for path in paths[:4]:  # only the first four paths are considered
    if not ids or path.name in ids:
        # The side is the second "_"-separated field of the name of the
        # directory that contains the patient directory.
        side = path.parent.name.split("_")[1]
        patient = Patient(path, unsorted_dicom=True)
        patients[side].append(patient)
        print(f"({side}) {patient.id} - initialisation; "
              f"{patient._init_time:.2f} s; ")

In [None]:
# For each patient, check that there is exactly one image and at least one
# structure set, then print the names of all ROIs and the names of the OARs
# to be considered in auto-segmentation.
if check_data:
    # Substrings that mark an ROI name as an OAR of interest.
    oar_keys = ["heart", "Heart", "inm", "imn", "INM", "IMN"]
    all_oar_names = set()
    for side in patients:
        print(f"\n{side}: {len(patients[side])}")
        for patient in patients[side]:

            print(f"\n({side}) {patient.id} - study timestamp: "
                  f"{patient.studies[0].timestamp}")
            images = patient.combined_objs("image_types")
            assert 1 == len(images), (
                f"{patient.id}: expected 1 image, found {len(images)}")

            structure_sets = images[0].structure_sets
            assert 1 <= len(structure_sets), (
                f"{patient.id}: no structure sets found")
            print(f"    structure sets: {len(structure_sets)}")

            for ss in structure_sets:
                roi_names = sorted(ss.get_roi_names())
                oar_names = [roi_name for roi_name in roi_names
                             if any(key in roi_name for key in oar_keys)]
                print(f"    {ss.path}")
                print(f"    {roi_names}")
                print(f"    {oar_names}")
                all_oar_names.update(oar_names)

    all_oar_names = sorted(all_oar_names)
    # Printed unconditionally: this whole cell only runs when check_data is
    # True (the former `verbose` guard referred to an undefined name).
    print(f"\n{all_oar_names}")

In [None]:
# Register the CT image of one patient (moving) to that of another (fixed),
# both taken from the group with IMNs outlined on the same side.
side = "left"
idx1 = 0  # index of the patient providing the fixed image
idx2 = 3  # index of the patient providing the moving image
p1 = patients[side][idx1]
p2 = patients[side][idx2]

# ROIs to keep, as {standard name: [names used in the structure sets]}.
roi_names = {
    "heart": ["heart", "Heart"],
    "imn": ["CTV IMN", "CTV INM", "CTVn_IMN", "IMN", "ctv imn"],
}
# Band definition passed to skrt's apply_selective_banding.
# NOTE(review): presumably {new value: (lower, upper)} - confirm semantics.
bands = {-1024: (None, 80)}


def prepare_ct(patient, ss_name):
    """Return a patient's original CT image, a working copy, and its ROIs.

    The working copy is a clone of the patient's first CT image, with the
    selective banding `bands` applied. The copy's first structure set is
    filtered to the ROIs of `roi_names` (keeping renamed ROIs only, without
    copying ROI data), given the name `ss_name`, and set as the copy's only
    structure set.

    Returns (original CT image, working copy, filtered structure set).
    """
    ct_original = patient.combined_objs("ct_images")[0]
    ct = ct_original.clone()
    ct.apply_selective_banding(bands)
    ss = ct.structure_sets[0].filtered_copy(
        names=roi_names, keep_renamed_only=True, copy_roi_data=False)
    ss.name = ss_name
    ct.structure_sets = [ss]
    return ct_original, ct, ss


ct1_original, ct1, ss1 = prepare_ct(p1, "ss1")
ct2_original, ct2, ss2 = prepare_ct(p2, "ss2")

# Crop the fixed image to the extents of the alignment ROI, with a buffer of
# 10 (presumably mm). The extents are extended by 50 at the lower end of axis 1
# and by 100 at both ends of axis 2.
# NOTE(review): axes 1 and 2 are presumably y and z.
roi_to_align = "heart"
roi_extents = ss1[roi_to_align].get_extents(buffer=10)
roi_extents[1][0] -= 50
roi_extents[2][0] -= 100
roi_extents[2][1] += 100
print(roi_extents)
ct1.crop(*roi_extents)

# Crop the moving image to the size of the cropped fixed image, using the
# alignment ROI to position the crop. Then view the fixed image together with
# the moving image, the ROIs of both structure sets, and the alignment-ROI
# masks.
ct2.crop_to_image(ct1, alignment=roi_to_align)
ct1.view(images=ct2, rois=[ss1, ss2],
         mask=[ss[roi_to_align].get_mask_image() for ss in [ss1, ss2]])

# Define registration strategy: initial alignment on the chosen ROI, then the
# steps listed in `pfiles`. Other parameter files in pfiles_dir, such as
# MI_Rigid.txt and MI_Affine.txt, may be added as steps. Masks of the
# alignment ROI may be passed as fixed_mask / moving_mask.
reg = Registration(
    Path(f"results/{p1.id}_{p2.id}"),
    fixed=ct1,
    moving=ct2,
    initial_alignment=roi_to_align,
    pfiles={
        "translation": pfiles_dir / "MI_Translation.txt",
        "bspline": pfiles_dir / "MI_BSpline30.txt",
    },
    overwrite=True,
    capture_output=True,
    keep_tmp_dir=True,
)

In [None]:
# Perform registration, running the steps defined when `reg` was created.
# NOTE(review): output presumably goes to results/<p1.id>_<p2.id> - confirm.
reg.register()

In [None]:
# Show results at each step of registration: for every step, print the step
# identifier and then view that step's result.
for step_name in reg.steps:
    print(step_name)
    reg.view_result(step_name)

In [None]:
# For every registration step, map each structure set into the other image's
# frame (see the views below):
# - ss1 (fixed image) by transforming contour points, for display on the
#   moving patient's CT;
# - ss2 (moving image) by transforming masks, for display on the fixed
#   patient's CT, followed by reset_contours(most_points=True).
ss1_transformed = {}
ss2_transformed = {}
for step in reg.steps:
    moved_ss1 = reg.transform(ss1, step=step, transform_points=True)
    ss1_transformed[step] = moved_ss1
    moved_ss1.name = "ss1_points_transformed"

    moved_ss2 = reg.transform(ss2, step=step)
    ss2_transformed[step] = moved_ss2
    moved_ss2.name = "ss2_masks_transformed"
    moved_ss2.reset_contours(most_points=True)

In [None]:
# Final registration step: show the moving patient's original CT with its own
# ROIs (ss2) and the fixed-image ROIs transformed by this step.
step = reg.steps[-1]
rois_on_moving = ss2 + ss1_transformed[step]
ct2_original.view(rois=rois_on_moving)

In [None]:
# Show the fixed patient's original CT with its own ROIs (ss1) and the
# moving-image ROIs transformed by the final registration step. The step is
# looked up here, so this cell doesn't rely on `step` from an earlier cell.
ct1_original.view(figsize=15, rois=ss1 + ss2_transformed[reg.steps[-1]])