# Import Libraries

In [None]:
# install required libraries
# both packages are installed straight from GitHub, pinned to their `google-colab` branch
!pip install git+https://github.com/AAMIASoftwares-research/HCATNetwork.git@google-colab
!pip install git+https://github.com/AAMIASoftwares-research/DatasetUtilities.git@google-colab

In [None]:
# import general libraries
import os
import sys
import subprocess
from IPython.display import FileLink
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from typing import List
# import data libraries
import hcatnetwork
import HearticDatasetManager

# Define Global Variables and Utilities

In [None]:
# Set folder paths
# root folders of the two datasets; image and graph file names are joined onto these by the loaders
# NOTE(review): the /kaggle/input prefix presumably refers to a Kaggle-attached dataset — adjust when running elsewhere
ASOCA_FOLDER = "/kaggle/input/heart-data/ASOCA/ASOCA"
CAT08_FOLDER = "/kaggle/input/heart-data/CAT08/CAT08"

In [None]:
# global variables
# number of patients iterated over when loading each dataset
NUM_PATIENTS_ASOCA = 40
NUM_PATIENTS_CAT08 = 8
# side length of each sampled cube, in millimetres
CUBE_SIDE_MM = 12
# nominal spacing used to derive the number of samples per cube side
CUBE_ISOTROPIC_SPACING_MM = 0.5
# samples per cube side: 12 * (1/0.5) = 24 with the values above (int() truncates towards zero)
CUBE_SIDE_N_SAMPLES = int(CUBE_SIDE_MM * (1/CUBE_ISOTROPIC_SPACING_MM))
# cubes extracted per patient: half by random sampling along the graph, half by local sampling around the ostia
NUM_CUBES_PATIENT = 100

In [None]:
# define workspace utility functions

# call this function when in need to free RAM space
def check_variables(namespace=None):
    """
    Check the memory usage of variables in the workspace.
    Print variables and their memory sizes in descending order.

    Parameters:
        namespace (dict, optional): Mapping of variable names to objects to inspect.
            Defaults to the globals of the module (notebook) in which this function
            is defined.

    Note:
        Sizes come from sys.getsizeof, which is shallow: objects referenced by a
        container are not included in the container's reported size.
    """
    # locals() would only expose this function's own (empty) frame, so the
    # workspace has to be read from the module globals instead
    if namespace is None:
        namespace = globals()
    # get the memory size of each variable, skipping dunder names;
    # iterate over a snapshot so the mapping cannot change size mid-iteration
    variable_sizes = {
        name: sys.getsizeof(value)
        for name, value in list(namespace.items())
        if not name.startswith('__')
    }
    # sort the variables based on their memory size
    sorted_variables = sorted(variable_sizes.items(), key=lambda entry: entry[1], reverse=True)
    # print the variables and their memory sizes in descending order
    for var, size in sorted_variables:
        print(f"{var}: {size} bytes")

# save anything via pickle
def save(item, name: str, path="/kaggle/working/"):
    """
    Save an item using pickle.

    Parameters:
        item: The item to be saved.
        name (str): The name of the file.
        path (str): The path where the file will be saved (default: "/kaggle/working/").
            A trailing path separator is optional.
    """
    # os.path.join inserts the separator only when it is missing, so both
    # "/some/dir" and "/some/dir/" resolve to the same file
    item_file = os.path.join(path, name)
    with open(item_file, 'wb') as file:
        pickle.dump(item, file)

# download item as zip
def download_file(source_path: str, download_file_name: str, output_path="/kaggle/working/"):
    """
    Create a zip file from the specified source path and provide a download link.

    Parameters:
        source_path (str): The path to the source file or directory to be zipped.
            A relative path is resolved against output_path, because the zip
            command runs from there.
        download_file_name (str): The name of the zip file and download link.
        output_path (str): The output path for the zip file (default: "/kaggle/working/").

    Raises:
        RuntimeError: If the zip command is unavailable or exits with a non-zero status.
    """
    # save the current working directory
    current_working_directory = os.getcwd()
    os.chdir(output_path)

    try:
        zip_name = f"{download_file_name}.zip"
        # pass the arguments as a list (no shell) so paths containing spaces or
        # shell metacharacters are handed to zip verbatim
        command = ["zip", "-r", zip_name, source_path]
        try:
            result = subprocess.run(command, capture_output=True, text=True)
        except FileNotFoundError as error:
            # without a shell a missing executable raises instead of returning
            # a non-zero status; keep reporting it as RuntimeError
            raise RuntimeError(f"Unable to run zip command! Error: {error}") from error
        if result.returncode != 0:
            raise RuntimeError(f"Unable to run zip command! Error: {result.stderr}")

        # NOTE(review): display is not imported in this file; presumably it is the
        # builtin injected by the notebook environment
        display(FileLink(zip_name))
    finally:
        # restore the original working directory
        os.chdir(current_working_directory)

In [None]:
# define graphs utility functions 

# compute maximum, minimum and average radius for graphs nodes
def compute_r_stats(graph: "hcatnetwork.graph.graph.SimpleCenterlineGraph"):
    """
    Given a SimpleCenterlineGraph, calculate statistics on node radii.

    Parameters:
        graph (SimpleCenterlineGraph): The graph containing node radii
            (read from the "r" attribute of every node).

    Returns:
        tuple: A tuple containing the maximum, minimum, and average radius.

    Raises:
        ValueError: If the graph has no nodes.
    """
    radii = [graph.nodes[node_id]["r"] for node_id in graph.nodes]
    if not radii:
        raise ValueError("Cannot compute radius statistics on a graph without nodes")

    # min and max are taken independently over all radii: no assumed upper
    # bound on the minimum, and a new maximum can never hide a new minimum
    maxR = max(radii)
    minR = min(radii)
    avgR = sum(radii) / len(radii)

    return maxR, minR, avgR

# save for each pair of Ostia the corresponding coordinates and radius
def get_ostia_data(graph: hcatnetwork.graph.graph.SimpleCenterlineGraph):
    """
    Collect position and radius of the two coronary ostia of a SimpleCenterlineGraph.

    Parameters:
        graph (SimpleCenterlineGraph): The graph containing coronary Ostia nodes.

    Returns:
        tuple: (ids, ostium_1, ostium_2), where ids is whatever
        graph.get_coronary_ostia_node_id() returns and each ostium entry is an array
        built from the node's x, y, z and its radius wrapped in a one-element array.
    """
    def describe(node_id):
        # the radius is wrapped in a length-1 array before stacking
        # NOTE(review): presumably because x, y, z are stored as length-1 arrays — confirm
        node = graph.nodes[node_id]
        return np.array([node['x'], node['y'], node['z'],
                         np.array([node['r']])])

    ids = graph.get_coronary_ostia_node_id()
    first_id, second_id = ids[0], ids[1]
    return ids, describe(first_id), describe(second_id)

# generates coordinate matrices from x, y, z sampled points. Flattening and transposing are then performed to generate (Nx3) arrays
def get_cube_sample_points(center: np.ndarray, side_mm: float, n_samples_per_side: int):
    """
    Build the sample grid of a cube centered at the given point.

    Parameters:
        center (np.ndarray): The center coordinates of the cube (x, y, z).
        side_mm (float): The side length of the cube in millimeters.
        n_samples_per_side (int): The number of sample points per side.

    Returns:
        np.ndarray: A (Nx3) array containing the sampled points, one point per row.
    """
    half_side = side_mm/2
    # one evenly spaced axis per coordinate, spanning center -/+ half the side
    axes = [
        np.linspace(center[axis] - half_side, center[axis] + half_side, n_samples_per_side)
        for axis in range(3)
    ]
    # meshgrid is used with its default indexing, as in the original sampling order
    grid = np.meshgrid(*axes)
    return np.array(grid).reshape(3, -1).T

# returns samples as a (NxNxN) array
def cube_samples_to_array(samples: np.ndarray, n_samples_per_side: int) -> np.ndarray:
    """
    Reshape the sampled values of a cube into a 3D numpy array.

    Parameters:
        samples (np.ndarray): The sampled values (n_samples_per_side**3 elements).
        n_samples_per_side (int): The number of samples per side.

    Returns:
        np.ndarray: A (NxNxN) array representing the cube.
    """
    cube_shape = (n_samples_per_side,) * 3
    return samples.reshape(cube_shape)

# returns directly the cube as numpy (NxNxN) array from center point in RAS coordinates
def get_input_data_from_vertex_ras_position(
        image: HearticDatasetManager.cat08.Cat08ImageCT|HearticDatasetManager.asoca.AsocaImageCT,
        position: np.ndarray,
        side_mm: float,
        n_samples_per_side: int,
        affine=np.eye(4)
    ) -> np.ndarray:
    """Extract a cube of image samples centered on a position given in RAS coordinates.

    Parameters
    ----------
    image : HearticDatasetManager.cat08.Cat08ImageCT | HearticDatasetManager.asoca.AsocaImageCT
        The image from which to extract the data.
    position : numpy.ndarray
        The position of the cube center in RAS coordinates system.
    side_mm : float
        The side of the cube in mm.
    n_samples_per_side : int
        The number of samples per side.
    affine : np.ndarray, optional
        Affine transformation applied to the cube sample points before sampling the image.
        Defaults to np.eye(4) (identity); passing None also falls back to the identity.
        Useful for data augmentation: a rotation such as the one produced by
        HearticDatasetManager.affine.get_affine_3d_rotation_around_vector() can be passed here.

    Returns
    -------
    numpy.ndarray
        The sampled cube, reshaped to (n_samples_per_side, n_samples_per_side, n_samples_per_side).
    """
    # None is treated exactly like the identity transformation
    transform = np.eye(4) if affine is None else affine
    # build the sample grid, then move it with the affine
    sample_points = HearticDatasetManager.affine.apply_affine_3d(
        transform,
        get_cube_sample_points(position, side_mm, n_samples_per_side)
    )
    # linearly interpolate the image at the (transformed) sample points
    intensities = image.sample(sample_points, interpolation="linear").T
    return cube_samples_to_array(intensities, n_samples_per_side)

In [None]:
# define functions to load ASOCA and CAT08 data

# ASOCA
def load_ASOCA_data():
    """
    Load all patients' images and graphs for ASOCA dataset.

    The first half of the patients is read from the "Normal" entries of the
    dataset dictionaries and the second half from the "Diseased" entries.
    Graph node coordinates are converted in place to the RAS system using the
    image's affine_centerlines2ras transformation.

    Returns:
        tuple: A tuple containing lists of ASOCA images and graphs.
    """
    images = []
    graphs = []

    # a single integer half is used both to pick the category and to offset the
    # index, so the two can never disagree
    half = NUM_PATIENTS_ASOCA // 2

    for i in tqdm(range(NUM_PATIENTS_ASOCA), desc="Loading ASOCA data"):
        # select category and index within that category
        if i < half:
            category, index = "Normal", i
        else:
            category, index = "Diseased", i - half

        # Load image and graph
        image_file = os.path.join(
            ASOCA_FOLDER,
            HearticDatasetManager.asoca.DATASET_ASOCA_IMAGES_DICT[category][index]
        )
        graph_file = os.path.join(
            ASOCA_FOLDER,
            HearticDatasetManager.asoca.DATASET_ASOCA_GRAPHS_RESAMPLED_05MM_DICT[category][index]
        )

        image = HearticDatasetManager.asoca.AsocaImageCT(image_file)
        graph = hcatnetwork.io.load_graph(
            graph_file,
            output_type=hcatnetwork.graph.SimpleCenterlineGraph
        )
        # Convert graph coordinates to RAS
        for node_id in graph.nodes:
            node = graph.nodes[node_id]
            old_coords = np.array([node["x"], node["y"], node["z"]])
            new_coords = HearticDatasetManager.affine.apply_affine_3d(image.affine_centerlines2ras, old_coords)
            node["x"] = new_coords[0]
            node["y"] = new_coords[1]
            node["z"] = new_coords[2]

        images.append(image)
        graphs.append(graph)

    return images, graphs

# CAT08
def load_CAT08_data():
    """
    Load all patients' images and graphs for CAT08 dataset.

    Graph node coordinates are converted in place to the RAS system using the
    image's affine_centerlines2ras transformation.

    Returns:
        tuple: A tuple containing lists of CAT08 images and graphs.
    """
    images = []
    graphs = []

    for i in tqdm(range(NUM_PATIENTS_CAT08), desc="Loading CAT08 data"):
        # Load image and graph
        image_file = os.path.join(
            CAT08_FOLDER,
            HearticDatasetManager.cat08.DATASET_CAT08_IMAGES[i]
        )
        graph_file = os.path.join(
            CAT08_FOLDER,
            HearticDatasetManager.cat08.DATASET_CAT08_GRAPHS_RESAMPLED_05MM[i]
        )

        image = HearticDatasetManager.cat08.Cat08ImageCT(image_file)
        graph = hcatnetwork.io.load_graph(
            graph_file,
            output_type=hcatnetwork.graph.SimpleCenterlineGraph
        )

        # Convert graph coordinates to RAS
        for node_id in graph.nodes:
            node = graph.nodes[node_id]
            old_coords = np.array([node["x"], node["y"], node["z"]])
            new_coords = HearticDatasetManager.affine.apply_affine_3d(image.affine_centerlines2ras, old_coords)
            node["x"] = new_coords[0]
            node["y"] = new_coords[1]
            node["z"] = new_coords[2]

        images.append(image)
        graphs.append(graph)

    return images, graphs

In [None]:
# define functions to extract volume patches from each patient's CT scan and the corresponding labels

# returns the data representing the input to the network
def get_labeled_patches(
    image: HearticDatasetManager.cat08.Cat08ImageCT|HearticDatasetManager.asoca.AsocaImageCT,
    graph: hcatnetwork.graph.graph.SimpleCenterlineGraph,
    rotation_random_sampling = None,
    rotation_local_sampling = None
):

    """
    The function takes as input a single image-graph pair and provides the extracted patches and labels by following these steps:

        1. Firstly it samples a point along the graph and applies a random translation to it.
        2. Then utility functions are used to sample a cube centered in such point.
        3. If the point lies close enough to one of the two Ostia, then label = 1, 0 otherwise.
        4. Finally, steps 1-2 are repeated by local sampling around the Ostium to ensure class 1 is represented
           (these cubes are always labeled 1).

    NUM_CUBES_PATIENT // 2 cubes are produced by each of the two sampling stages.

    Parameters:
        image (HearticDatasetManager.cat08.Cat08ImageCT | HearticDatasetManager.asoca.AsocaImageCT):
            The CT scan image data.
        graph (hcatnetwork.graph.graph.SimpleCenterlineGraph):
            The corresponding centerline graph.
        rotation_random_sampling (optional): If truthy, each randomly sampled cube is rotated by a random
            angle (0-90 degrees) around a random axis through its center. Default is None (no rotation).
        rotation_local_sampling (optional): Same as above, for the cubes sampled around the Ostia.
            Default is None (no rotation).

    Returns:
        tuple: A tuple containing numpy arrays of extracted cubes and corresponding labels.
    """

    # only the average radius is needed: it scales the random translation
    _, _, avgR = compute_r_stats(graph)
    # retrieve ostia
    ids, coord_ostium_1, coord_ostium_2 = get_ostia_data(graph)

    def sample_cube(node_id, max_translation, rotate):
        """Sample one cube around a randomly translated copy of the node position.

        Returns the cube and the translated center position. The order of the
        random draws matches the sequence r, theta, phi, rotation axis, rotation
        angle, so seeded runs stay reproducible.
        """
        # get the position of the node in RAS
        node_position = np.array([graph.nodes[node_id]["x"], graph.nodes[node_id]["y"], graph.nodes[node_id]["z"]])
        # select parameters for random translation
        r = np.random.uniform(-max_translation, max_translation)  # signed translation length
        theta = np.random.uniform(0, 2*np.pi)  # xy plane angle
        phi = np.random.uniform(0, np.pi)  # z to xy plane angle
        # apply the translation to the selected point
        # NOTE(review): the (3,1) reshape presumes node coordinates are stored as length-1 arrays — confirm
        node_position += np.array([r*np.sin(phi)*np.cos(theta), r*np.sin(phi)*np.sin(theta), r*np.cos(phi)]).reshape(3,1)
        if rotate:
            # define rotation vector
            vector_axis_of_rotation = np.array([np.random.uniform(-1, 1) , np.random.uniform(-1, 1), np.random.uniform(-1, 1)])
            transformation_to_apply = HearticDatasetManager.affine.get_affine_3d_rotation_around_vector(
                vector=vector_axis_of_rotation,
                vector_source=node_position.reshape(3,1), # the center of the cube
                rotation=np.random.choice(range(90+1)), # max degree of rotation +1
                rotation_units="deg"
            )
        else:
            transformation_to_apply = None
        # sample
        cube_array = get_input_data_from_vertex_ras_position(image,node_position,
                                                             CUBE_SIDE_MM,CUBE_SIDE_N_SAMPLES,
                                                             affine=transformation_to_apply)
        return cube_array, node_position

    cubes = []
    labels = []

    # firstly random sample along the graph
    for _ in tqdm(range(NUM_CUBES_PATIENT // 2), desc="Random Sampling"):
        # choose a random node of the graph
        node_id = np.random.choice(list(graph.nodes.keys()))
        cube_array, node_position = sample_cube(node_id, avgR*1.5, rotation_random_sampling)
        # compute distances from ostia to assign labels
        dist1 = np.linalg.norm(coord_ostium_1[:3] - node_position)
        dist2 = np.linalg.norm(coord_ostium_2[:3] - node_position)
        # positive when the cube center lies within 1.2 radii of either ostium
        label = 1 if (dist1<=1.2*coord_ostium_1[-1] or dist2<=1.2*coord_ostium_2[-1]) else 0

        cubes.append(cube_array)
        labels.append(label)

    # then sample locally from the ostia applying augmentation via rotation if selected
    for _ in tqdm(range(NUM_CUBES_PATIENT // 2), desc="Local Sampling"):
        # select one of the two ostia
        node_id = np.random.choice(ids)
        # the translation is bounded by 1.2 ostium radii, i.e. the same threshold
        # used above for a positive label, hence the fixed label 1
        cube_array, _ = sample_cube(node_id, graph.nodes[node_id]["r"]*1.2, rotation_local_sampling)

        cubes.append(cube_array)
        labels.append(1)

    cubes = np.array(cubes)
    labels = np.array(labels)

    return cubes, labels

In [None]:
# define patient class able to store cubes and corresponding labels
class Patient:
    """Container pairing the cubes extracted for one patient with their labels."""

    def __init__(self, cubes=None, labels=None):
        """
        Create a Patient, optionally pre-filled with data.

        Parameters:
            cubes (List): List of cube data; a fresh empty list when omitted.
            labels (List): List of corresponding labels; a fresh empty list when omitted.
        """
        # None (not a shared mutable default) marks "not provided"
        if cubes is None:
            cubes = []
        if labels is None:
            labels = []
        self.cubes = cubes
        self.labels = labels

In [None]:
# define function to compute and visualize label distribution
def check_label_distribution(patients: List[Patient]):
    """
    Print and plot how many samples carry label 0 and label 1 across the given patients.

    Parameters:
        patients (List[Patient]): List of Patient instances whose labels are pooled together.
    """
    # pool the labels of every patient into one flat list
    all_labels = []
    for patient in patients:
        all_labels.extend(patient.labels)

    # count and report each of the two classes
    label_counts = {value: all_labels.count(value) for value in (0, 1)}
    for label, count in label_counts.items():
        print(f"Label {label}: {count} occurrences")

    # histogram with one bar per class
    plt.hist(all_labels, bins=[0, 1, 2], align='left', rwidth=0.8)
    plt.xlabel('Label')
    plt.ylabel('Count')
    plt.title('Label Distribution')
    plt.xticks([0, 1])
    plt.show()

# Data Loading and Volume Extraction

In [None]:
# load data
# each loader returns (images, graphs) lists, one entry per patient, with graph coordinates already converted to RAS
print("--- Data Loading: ---")
images_asoca, graphs_asoca = load_ASOCA_data()
images_cat08, graphs_cat08 = load_CAT08_data()

In [None]:
# perform train-validation-test splitting
# import train-test-split utility
from sklearn.model_selection import train_test_split

# split patients to ensure that within the same set the data is of the same patient
def split_patients(patients: List[Patient], validation_size=0.1, test_size=0.1, random_seed=42):
    """
    Partition Patient instances into train, validation and test groups, keeping each patient whole.

    The split is done in two stages: the test group is carved out of the full list first,
    then the validation group is carved out of what remains.

    Parameters:
        patients (list): List of Patient instances.
        validation_size: Forwarded as `test_size` to the second train_test_split call, so it
            is relative to the patients left after removing the test group.
        test_size: Forwarded as `test_size` to the first train_test_split call.
        random_seed (int or None): Forwarded as `random_state` to both calls, for reproducibility.

    Returns:
        tuple: (patients_train, patients_val, patients_test), each a list of Patient instances.
    """
    # stage 1: hold out the test patients
    remaining_patients, patients_test = train_test_split(
        patients, test_size=test_size, random_state=random_seed
    )
    # stage 2: split what is left into train and validation
    patients_train, patients_val = train_test_split(
        remaining_patients, test_size=validation_size, random_state=random_seed
    )
    return patients_train, patients_val, patients_test

# extract X (cubes) and y (labels) for each set
def get_train_val_test(patients_train: List[Patient], patients_val: List[Patient], patients_test: List[Patient]):
    """
    Stack the cubes (X) and labels (y) of each patient group into numpy arrays.

    Parameters:
        patients_train (list): List of train set Patient instances.
        patients_val (list): List of validation set Patient instances.
        patients_test (list): List of test set Patient instances.

    Returns:
        tuple: Six numpy arrays (X_train, y_train, X_val, y_val, X_test, y_test). Each array keeps
               one leading entry per patient; no flattening across patients is done here.
    """
    def to_arrays(group):
        # one array entry per patient, for cubes and labels respectively
        cubes = np.array([member.cubes for member in group])
        labels = np.array([member.labels for member in group])
        return cubes, labels

    X_train, y_train = to_arrays(patients_train)
    X_val, y_val = to_arrays(patients_val)
    X_test, y_test = to_arrays(patients_test)

    return X_train, y_train, X_val, y_val, X_test, y_test

In [None]:
# create empty Patient instances
# these are placeholders: their cubes/labels are filled in later by the volume extraction cell
patient_instances = []
for i in range(NUM_PATIENTS_ASOCA+NUM_PATIENTS_CAT08):
    patient_instance = Patient()
    patient_instances.append(patient_instance)
# split patients into train - validation - test
# both sizes are integers here and are forwarded unchanged to train_test_split
# NOTE(review): presumably interpreted as absolute patient counts (5 validation, 8 test) — confirm against scikit-learn docs
validation_size = 5
test_size = NUM_PATIENTS_CAT08
patients_train, patients_val, patients_test = split_patients(patient_instances, validation_size=validation_size, test_size=test_size)

In [None]:
# extract volumes: ASOCA is used to retrieve train and validation samples, while CAT08 is used entirely for testing 

# shuffle indexes for ASOCA volumes extraction
# the seed is fixed so the same ASOCA patients end up in train/validation on every run
common_seed = 42
asoca_indexes = np.arange(NUM_PATIENTS_ASOCA)
np.random.seed(common_seed)
np.random.shuffle(asoca_indexes)

print("--- ASOCA Volumes Extraction: ---")
# training patients: all shuffled ASOCA indexes except the last `validation_size`, with rotation augmentation enabled
# NOTE(review): assumes patients_train holds at least NUM_PATIENTS_ASOCA-validation_size placeholders — confirm
for i,patient_id in enumerate(asoca_indexes[:NUM_PATIENTS_ASOCA-validation_size]):
    # call the function to get labeled patches for the current patient
    print(f"-> Patient {patient_id +1}:")
    patients_train[i].cubes, patients_train[i].labels = get_labeled_patches(images_asoca[patient_id],
                                                                            graphs_asoca[patient_id],
                                                                            rotation_random_sampling=True,
                                                                            rotation_local_sampling=True
                                                                           )

# validation patients: the last `validation_size` shuffled ASOCA indexes, without rotation
for i,patient_id in enumerate(asoca_indexes[NUM_PATIENTS_ASOCA-validation_size:]):
    # call the function to get labeled patches for the current patient
    print(f"-> Patient {patient_id +1}:")
    patients_val[i].cubes, patients_val[i].labels = get_labeled_patches(images_asoca[patient_id],
                                                                        graphs_asoca[patient_id], 
                                                                        rotation_random_sampling=False,
                                                                        rotation_local_sampling=False
                                                                       )

print("--- CAT08 Volumes Extraction: ---")
# iterate over each CAT08 patient (test set, no rotation)
for patient_id in range(NUM_PATIENTS_CAT08):
    # call the function to get labeled patches for the current patient
    print(f"-> Patient {patient_id +1}:")
    patients_test[patient_id].cubes, patients_test[patient_id].labels = get_labeled_patches(images_cat08[patient_id],
                                                                                            graphs_cat08[patient_id], 
                                                                                            rotation_random_sampling=False,
                                                                                            rotation_local_sampling=False
                                                                                           )

In [None]:
# check class proportion in train set (prints the 0/1 counts and shows a histogram)
check_label_distribution(patients_train)

In [None]:
# check class proportion in validation set (prints the 0/1 counts and shows a histogram)
check_label_distribution(patients_val)

In [None]:
# check class proportion in test set (prints the 0/1 counts and shows a histogram)
check_label_distribution(patients_test)

In [None]:
# delete data to save RAM space
# the images and graphs are no longer needed once the cubes have been extracted into the Patient instances
del images_asoca, graphs_asoca
del images_cat08, graphs_cat08

# Data Preparation

In [None]:
# apply clipping and standardization to the cubes
def preprocess_cubes(cubes: np.array, lower_clip=-800, upper_clip=1000):
    """
    Clip and standardize every cube of a set, one cube at a time.

    Parameters:
    - cubes (np.array): An array of 3D cubes with shape (num_cubes, cube_size, cube_size, cube_size).
    - lower_clip (float): Values below this bound are set to it.
    - upper_clip (float): Values above this bound are set to it.

    Returns:
    - np.array: The preprocessed cubes, same shape as the input.

    Note: The input 'cubes' array is left untouched; a new array is returned.
    """
    def standardize(volume):
        # clip first, so mean and std are computed on the clipped intensities
        clipped = np.clip(volume, lower_clip, upper_clip)
        centered = clipped - np.mean(clipped)
        spread = np.std(clipped)
        # a constant cube has zero std: only center it, to avoid dividing by zero
        if spread != 0:
            return centered / spread
        return centered

    return np.array([standardize(volume) for volume in cubes])

In [None]:
# get split data
# at this point every array still has one leading entry per patient
X_train, y_train, X_val, y_val, X_test, y_test = get_train_val_test(patients_train, patients_val, patients_test)

# reshape data: merge the patient and cube axes so that each row is a single cube / label
# NOTE(review): 24 is hard-coded; it equals CUBE_SIDE_N_SAMPLES with the current global values
input_shape = (24,24,24)
X_train = X_train.reshape((-1,) + input_shape)
X_val = X_val.reshape((-1,) + input_shape)
X_test = X_test.reshape((-1,) + input_shape)

y_train = y_train.reshape(-1)
y_val = y_val.reshape(-1)
y_test = y_test.reshape(-1)

# shuffle data
# a single index permutation per set is applied to both X and y, so cubes stay aligned with their labels
common_seed = 42
np.random.seed(common_seed)

train_indices = np.arange(len(X_train))
np.random.shuffle(train_indices)
X_train = X_train[train_indices]
y_train = y_train[train_indices]

val_indices = np.arange(len(X_val))
np.random.shuffle(val_indices)
X_val = X_val[val_indices]
y_val = y_val[val_indices]

test_indices = np.arange(len(X_test))
np.random.shuffle(test_indices)
X_test = X_test[test_indices]
y_test = y_test[test_indices]

In [None]:
# check for data leaks
def check_data_leaks(X_train: np.ndarray, X_val: np.ndarray, X_test: np.ndarray):
    """
    Check for data leaks by looking for identical cubes between and within the
    training, validation, and test sets.

    Every pair of equal cubes is reported with a printed message. Within a single
    set, a duplicated pair is reported twice, once as (i, j) and once as (j, i).

    Cubes are first bucketed by a cheap fingerprint, so only cubes sharing a
    fingerprint are compared element-wise; this avoids comparing every cube
    against every other cube.

    Parameters:
        X_train (np.ndarray): Cubes in the training set.
        X_val (np.ndarray): Cubes in the validation set.
        X_test (np.ndarray): Cubes in the test set.
    """
    def fingerprint(cube):
        # casting to float64 makes equal values of different dtypes hash alike,
        # and adding 0.0 folds -0.0 into 0.0 (they compare equal element-wise)
        canonical = np.ascontiguousarray(cube, dtype=np.float64) + 0.0
        return canonical.shape, hash(canonical.tobytes())

    def index_by_fingerprint(keys):
        # fingerprint -> ascending list of positions sharing it
        table = {}
        for position, key in enumerate(keys):
            table.setdefault(key, []).append(position)
        return table

    def report(first, first_keys, second, second_table, message, skip_same_index=False):
        for i, cube in enumerate(first):
            for j in second_table.get(first_keys[i], ()):
                if skip_same_index and i == j:
                    continue
                # the fingerprint only selects candidates; equality is confirmed here
                if np.array_equal(cube, second[j]):
                    print(f"{message} {i,j}")

    train_keys = [fingerprint(cube) for cube in X_train]
    val_keys = [fingerprint(cube) for cube in X_val]
    test_keys = [fingerprint(cube) for cube in X_test]

    train_table = index_by_fingerprint(train_keys)
    val_table = index_by_fingerprint(val_keys)
    test_table = index_by_fingerprint(test_keys)

    # check for duplicates in train and validation set
    report(X_train, train_keys, X_val, val_table, "Copy found at (train,val)")
    # check for duplicates in train and test set
    report(X_train, train_keys, X_test, test_table, "Copy found at (train,test)")
    # check for duplicates in validation and test set
    report(X_val, val_keys, X_test, test_table, "Copy found at (val,test)")
    # check for duplicates in train set
    report(X_train, train_keys, X_train, train_table, "Copy found in Train set", skip_same_index=True)
    # check for duplicates in validation set
    report(X_val, val_keys, X_val, val_table, "Copy found in Validation set", skip_same_index=True)
    # check for duplicates in test set
    report(X_test, test_keys, X_test, test_table, "Copy found in Test set", skip_same_index=True)

In [None]:
# check for data leaks (run on the raw cubes, before preprocessing; prints one line per duplicate found)
check_data_leaks(X_train, X_val, X_test)

In [None]:
# apply preprocessing
# each set is clipped and standardized cube by cube, using the default clip bounds of preprocess_cubes
X_train = preprocess_cubes(X_train)
X_val = preprocess_cubes(X_val)
X_test = preprocess_cubes(X_test)

In [None]:
# delete variable to save RAM space
# NOTE(review): the Patient objects are still referenced by patients_train / patients_val / patients_test,
# so removing this list alone presumably frees little memory — confirm whether those lists should be deleted too
del patient_instances

# Data Visualization

In [None]:
# plot sample cube
def plot_sampled_cube(cube_set: np.ndarray, edge_size: int, label_set = None, index = None):
    """
    Show one cube of a set as a 3D scatter plot, colored by voxel value.

    Parameters:
        cube_set (np.ndarray): The set of cubes to choose from.
        edge_size (int): The edge size of the cube.
        label_set (Optional[np.ndarray]): The corresponding labels; when given, the
            label of the plotted cube is shown as the title.
        index (Optional[int]): The index of the cube to plot. If not provided, a random cube is selected.

    Returns:
        None: Displays the 3D scatter plot.
    """
    # fall back to a random cube when no index is requested
    if index is None:
        index = np.random.choice(range(len(cube_set)))

    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')

    # one grid coordinate per voxel, flattened to match the flattened voxel values
    ticks = np.arange(edge_size)
    grid_x, grid_y, grid_z = np.meshgrid(ticks, ticks, ticks)
    voxel_values = cube_set[index].flatten()
    axes.scatter(grid_x.flatten(), grid_y.flatten(), grid_z.flatten(), c=voxel_values, cmap='gray')

    # show label if required
    if label_set is not None:
        axes.set_title(f"label: {label_set[index]}")

    plt.show()

In [None]:
# visualize some cubes from each set
# no index is passed, so one randomly chosen cube per set is shown, titled with its label
plot_sampled_cube(X_train, edge_size=24, label_set=y_train)
plot_sampled_cube(X_val, edge_size=24, label_set=y_val)
plot_sampled_cube(X_test, edge_size=24, label_set=y_test)

# Convolutional Neural Network

In [None]:
# import tensorflow framework
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow.keras.layers as tfkl

## FastConv3DNet V1 Architecture

In [None]:
# build the neural network layer by layer
def build_FastConv3DNet_V1(input_shape: tuple):
    """
    Assemble and compile the FastConv3DNet_V1 binary classifier.

    The network is four Conv3D -> BatchNormalization -> ReLU stages (3x3x3 kernels,
    "valid" padding, growing dilation), followed by global average pooling, a
    32-unit dense layer with dropout and a single sigmoid output.

    Parameters:
        input_shape (tuple): The input shape of the 3D data (e.g., (24, 24, 24, 1)).

    Returns:
        tfk.Model: Compiled CNN model for binary classification.
    """
    # (filters, dilation) of each convolutional stage, in network order
    conv_stages = ((8, 1), (16, 1), (16, 2), (32, 4))

    net_input = tfkl.Input(shape=input_shape, name='Input')

    features = net_input
    for n_filters, dilation in conv_stages:
        features = tfkl.Conv3D(
            filters=n_filters,
            kernel_size=3,
            strides=(1, 1, 1),
            padding="valid",
            dilation_rate=(dilation, dilation, dilation),
        )(features)
        features = tfkl.BatchNormalization()(features)
        features = tfkl.ReLU()(features)

    # classification head: pooled descriptor -> dense -> dropout -> probability
    pooled = tfkl.GlobalAveragePooling3D(name='GAP')(features)
    hidden = tfkl.Dense(units=32, activation='relu', name='dense')(pooled)
    hidden = tfkl.Dropout(0.5)(hidden)
    probability = tfkl.Dense(units=1, activation='sigmoid', name='Output')(hidden)

    classifier = tfk.Model(inputs=net_input, outputs=probability, name='FastConv3DNet_V1')
    classifier.compile(
        loss=tfk.losses.BinaryCrossentropy(),
        optimizer=tfk.optimizers.Adam(1e-3),
        metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
    )
    return classifier

In [None]:
# build FastConv3DNet_V1 model for single-channel 24x24x24 cubes
model_fastconv3dnet_v1 = build_FastConv3DNet_V1(input_shape=(24,24,24,1))
# print the layer-by-layer summary of the model
model_fastconv3dnet_v1.summary()

In [None]:
# visualize the model as a layer graph annotated with tensor shapes and layer names
tf.keras.utils.plot_model(model_fastconv3dnet_v1, show_shapes=True, show_layer_names=True)

## Squeeze and Excitation Block

In [None]:
# define the squeeze and excitation block
def se_block(input_tensor, ratio=16):
    """
    Creates a squeeze and excitation block for channels-last 3D feature maps.

    The input is globally average-pooled to one value per channel, passed through a
    bottleneck Dense(relu) -> Dense(sigmoid) pair, and the resulting per-channel
    weights are multiplied back onto the input.

    Parameters:
    input_tensor (tensor): The input tensor to the block; its last axis is taken as the channel axis.
    ratio (int): The ratio for dimensionality reduction in the squeeze step.

    Returns:
    output_tensor (tensor): The output tensor after the squeeze and excitation.

    Raises:
    ValueError: If the channel dimension of input_tensor is not statically known.
    """
    channels = input_tensor.shape[-1]
    if channels is None:
        raise ValueError("se_block needs an input with a known number of channels")

    # never let the bottleneck collapse to 0 units: with the default ratio=16 any
    # input with fewer than 16 channels would otherwise request Dense(0)
    bottleneck_units = max(1, channels // ratio)

    weights = tfkl.GlobalAveragePooling3D()(input_tensor)
    # reshape so the per-channel weights broadcast over the three spatial axes
    weights = tfkl.Reshape((1, 1, 1, channels))(weights)
    weights = tfkl.Dense(bottleneck_units, activation='relu')(weights)
    weights = tfkl.Dense(channels, activation='sigmoid')(weights)

    return tfkl.multiply([input_tensor, weights])

## FastConv3DNet V2 Architecture

In [None]:
# build the neural network layer by layer
def build_FastConv3DNet_V2(input_shape: tuple):
    """
    Assemble and compile the FastConv3DNet_V2 binary classifier.

    Same layout as a plain stack of four Conv3D -> BatchNormalization -> ReLU stages,
    but every stage is followed by a squeeze and excitation block (se_block).
    The head is global average pooling, a 32-unit dense layer with dropout and a
    single sigmoid output.

    Parameters:
        input_shape (tuple): The input shape of the 3D data (e.g., (24, 24, 24, 1)).

    Returns:
        tfk.Model: Compiled CNN model for binary classification.
    """
    # (filters, dilation, squeeze-excitation ratio) of each stage, in network order
    stages = ((8, 1, 2), (8, 1, 2), (16, 2, 4), (32, 4, 8))

    net_input = tfkl.Input(shape=input_shape, name='Input')

    features = net_input
    for n_filters, dilation, se_ratio in stages:
        features = tfkl.Conv3D(
            filters=n_filters,
            kernel_size=3,
            strides=(1, 1, 1),
            padding="valid",
            dilation_rate=(dilation, dilation, dilation),
        )(features)
        features = tfkl.BatchNormalization()(features)
        features = tfkl.ReLU()(features)
        # channel re-weighting after every convolutional stage
        features = se_block(features, ratio=se_ratio)

    # classification head
    pooled = tfkl.GlobalAveragePooling3D()(features)
    hidden = tfkl.Dense(units=32, activation='relu')(pooled)
    hidden = tfkl.Dropout(0.5)(hidden)
    probability = tfkl.Dense(units=1, activation='sigmoid', name='Output')(hidden)

    classifier = tfk.Model(inputs=net_input, outputs=probability, name='FastConv3DNet_V2')
    classifier.compile(
        loss=tfk.losses.BinaryCrossentropy(),
        optimizer=tfk.optimizers.Adam(1e-3),
        metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
    )
    return classifier

In [None]:
# build FastConv3DNet_V2 model for single-channel 24x24x24 cubes
model_fastconv3dnet_v2 = build_FastConv3DNet_V2(input_shape=(24,24,24,1))
# print the layer-by-layer summary of the model
model_fastconv3dnet_v2.summary()

In [None]:
# visualize the model as a layer graph annotated with tensor shapes and layer names
tf.keras.utils.plot_model(model_fastconv3dnet_v2, show_shapes=True, show_layer_names=True)

## Depthwise Separable Convolution

In [None]:
# define DepthwiseSeparableConv3D layer 
class DepthwiseSeparableConv3D(tf.keras.layers.Layer):
    """
    Depthwise separable 3D convolution for channels-last inputs.

    A grouped Conv3D with one group per input channel (the depthwise step) is
    followed by a 1x1x1 Conv3D that mixes the channels (the pointwise step).
    Neither convolution uses a bias.

    Parameters:
        kernel_size: Kernel size of the depthwise convolution.
        strides: Strides of the depthwise convolution.
        padding: Padding mode passed to both convolutions.
        dilation_rate: Dilation rate of the depthwise convolution.
        filters (Optional[int]): Number of output channels of the pointwise step.
            None (default) keeps the number of input channels.
        **kwargs: Forwarded to tf.keras.layers.Layer (e.g. name).
    """

    def __init__(self, kernel_size, strides, padding, dilation_rate, filters=None, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.dilation_rate = dilation_rate
        self.filters = filters

    def build(self, input_shape):
        in_channels = input_shape[-1]
        if in_channels is None:
            raise ValueError("DepthwiseSeparableConv3D needs a known number of input channels")
        out_channels = in_channels if self.filters is None else self.filters

        self.depthwise_conv = tf.keras.layers.Conv3D(
            filters=in_channels,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            dilation_rate=self.dilation_rate,
            groups=in_channels,  # Depthwise convolution
            use_bias=False
        )
        self.pointwise_conv = tf.keras.layers.Conv3D(
            filters=out_channels,
            kernel_size=(1, 1, 1),
            strides=(1, 1, 1),
            padding=self.padding,
            use_bias=False
        )
        super().build(input_shape)

    def call(self, inputs):
        x = self.depthwise_conv(inputs)
        return self.pointwise_conv(x)

    def get_config(self):
        # expose the constructor arguments so a saved model containing this layer
        # can be rebuilt (the constructor has required arguments)
        config = super().get_config()
        config.update({
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "dilation_rate": self.dilation_rate,
            "filters": self.filters,
        })
        return config

## FasterConv3DNet V1 Architecture

In [None]:
# build the neural network layer by layer
def build_FasterConv3DNet_V1(input_shape: tuple):
    """
    Assemble and compile the FasterConv3DNet_V1 binary classifier.

    A 4-filter Conv3D stage is followed by three DepthwiseSeparableConv3D stages
    with dilation 1, 2 and 4; every stage is followed by BatchNormalization and
    ReLU. The head is global average pooling, a 2-unit dense layer and a single
    sigmoid output.

    Parameters:
        input_shape (tuple): The input shape of the 3D data (e.g., (24, 24, 24, 1)).

    Returns:
        tfk.Model: Compiled CNN model for binary classification.
    """
    # input + first (regular) convolutional stage
    layer_stack = [
        tfkl.Input(shape=input_shape, name='Input'),
        tfkl.Conv3D(filters=4, kernel_size=3, strides=(1, 1, 1), padding="valid", dilation_rate=(1, 1, 1)),
        tfkl.BatchNormalization(),
        tfkl.ReLU(),
    ]

    # three depthwise separable stages with growing dilation
    for dilation in (1, 2, 4):
        layer_stack += [
            DepthwiseSeparableConv3D(
                kernel_size=3,
                strides=(1, 1, 1),
                padding="valid",
                dilation_rate=(dilation, dilation, dilation),
            ),
            tfkl.BatchNormalization(),
            tfkl.ReLU(),
        ]

    # classification head
    layer_stack += [
        tfkl.GlobalAveragePooling3D(name='GAP'),
        tfkl.Dense(units=2, activation='relu', name='dense'),
        tfkl.Dense(units=1, activation='sigmoid', name='Output'),
    ]

    network = tf.keras.models.Sequential(layer_stack, name="FasterConv3DNet_V1")
    network.compile(
        loss=tfk.losses.BinaryCrossentropy(),
        optimizer=tfk.optimizers.Adam(1e-3),
        metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
    )
    return network

In [None]:
# build FasterConv3DNet_V1 model for single-channel 24x24x24 cubes
model_fasterconv3dnet_v1 = build_FasterConv3DNet_V1(input_shape=(24,24,24,1))
# print the layer-by-layer summary of the model
model_fasterconv3dnet_v1.summary()

In [None]:
# visualize the model as a layer graph annotated with tensor shapes and layer names
tf.keras.utils.plot_model(model_fasterconv3dnet_v1, show_shapes=True, show_layer_names=True)

## FasterConv3DNet V2 Architecture

In [None]:
# redefine the squeeze and excitation block as class
class SEBlock(tfkl.Layer):
    """
    Squeeze and excitation block as a Keras layer, usable inside a Sequential model.

    The input (channels-last, three spatial axes) is globally average-pooled to one
    value per channel, passed through Dense(relu) -> Dense(sigmoid), and the
    resulting per-channel weights rescale the input.

    Parameters:
        ratio (int): Reduction ratio of the bottleneck dense layer.
        **kwargs: Forwarded to tfkl.Layer (e.g. name).
    """

    def __init__(self, ratio=2, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        filters = input_shape[-1]
        if filters is None:
            raise ValueError("SEBlock needs a known number of input channels")
        self.squeeze = tfkl.GlobalAveragePooling3D()
        self.excitation = tf.keras.Sequential([
            # keep at least one unit: filters // ratio is 0 when ratio > filters
            tfkl.Dense(max(1, filters // self.ratio), activation='relu'),
            tfkl.Dense(filters, activation='sigmoid')
        ])
        super().build(input_shape)

    def call(self, inputs):
        x = self.squeeze(inputs)
        # (batch, channels) -> (batch, 1, 1, 1, channels) so the weights broadcast
        # over the three spatial axes
        x = x[:, tf.newaxis, tf.newaxis, tf.newaxis, :]
        scale = self.excitation(x)
        return inputs * scale

    def get_config(self):
        # expose the constructor argument so a saved model can rebuild the layer
        config = super().get_config()
        config.update({"ratio": self.ratio})
        return config

# build the neural network layer by layer
def build_FasterConv3DNet_V2(input_shape: tuple):
    """
    Assemble and compile the FasterConv3DNet_V2 binary classifier.

    A 4-filter Conv3D stage is followed by three DepthwiseSeparableConv3D stages
    with dilation 1, 2 and 4. Every stage is BatchNormalization -> ReLU -> SEBlock.
    The head is global average pooling, a 2-unit dense layer and a single sigmoid
    output.

    Parameters:
        input_shape (tuple): The input shape of the 3D data (e.g., (24, 24, 24, 1)).

    Returns:
        tfk.Model: Compiled CNN model for binary classification.
    """
    # input + first (regular) convolutional stage with its squeeze-excitation block
    layer_stack = [
        tfkl.Input(shape=input_shape, name='Input'),
        tfkl.Conv3D(filters=4, kernel_size=3, strides=(1, 1, 1), padding="valid", dilation_rate=(1, 1, 1)),
        tfkl.BatchNormalization(),
        tfkl.ReLU(),
        SEBlock(ratio=2),
    ]

    # three depthwise separable stages with growing dilation
    for dilation in (1, 2, 4):
        layer_stack += [
            DepthwiseSeparableConv3D(
                kernel_size=3,
                strides=(1, 1, 1),
                padding="valid",
                dilation_rate=(dilation, dilation, dilation),
            ),
            tfkl.BatchNormalization(),
            tfkl.ReLU(),
            SEBlock(ratio=2),
        ]

    # classification head
    layer_stack += [
        tfkl.GlobalAveragePooling3D(name='GAP'),
        tfkl.Dense(units=2, activation='relu', name='dense'),
        tfkl.Dense(units=1, activation='sigmoid', name='Output'),
    ]

    network = tf.keras.models.Sequential(layer_stack, name="FasterConv3DNet_V2")
    network.compile(
        loss=tfk.losses.BinaryCrossentropy(),
        optimizer=tfk.optimizers.Adam(1e-3),
        metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
    )
    return network

In [None]:
# build FasterConv3DNet_V2 model for single-channel 24x24x24 cubes
model_fasterconv3dnet_v2 = build_FasterConv3DNet_V2(input_shape=(24,24,24,1))
# print the layer-by-layer summary of the model
model_fasterconv3dnet_v2.summary()

In [None]:
# visualize the model as a layer graph annotated with tensor shapes and layer names
tf.keras.utils.plot_model(model_fasterconv3dnet_v2, show_shapes=True, show_layer_names=True)

## Model Training

In [None]:
# create callbacks

# early stopping: watch validation accuracy (higher is better), allow 20 epochs
# without improvement, then stop and restore the weights of the best epoch
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=20,
    mode='max',
    restore_best_weights=True,
)

# learning rate scheduler: after 5 epochs without validation accuracy improvement
# multiply the learning rate by 0.9, never going below 1e-6
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.9,
    patience=5,
    verbose=0,
    mode='max',
    min_lr=1e-6,
)

# the same callback objects are passed to every model's fit call
callbacks = [early_stopping, lr_scheduler]

In [None]:
# Train the models

# training budget shared by every architecture; early stopping may end a run
# before the epoch limit is reached
epochs = 200
batch_size = 64

# models and model_histories are kept in the same order so that they can be
# zipped together when plotting
models = [model_fastconv3dnet_v1, model_fastconv3dnet_v2, model_fasterconv3dnet_v1, model_fasterconv3dnet_v2]
model_histories = []

for model in models:
    print(f"-> Training {model.name}")
    # fit on the training split and validate on the validation split after each epoch
    model_history = model.fit(
        x = X_train,
        y= y_train,
        batch_size=batch_size,
        validation_data=(X_val,y_val),
        epochs=epochs, 
        callbacks=callbacks,
    )
    model_histories.append(model_history)

## Model Assessment

In [None]:
# show learning curves
def plot_learning_curves(model_history: tf.keras.callbacks.History, monitor: str = 'val_accuracy', mode: str = 'max'):
    """
    Plot learning curves for accuracy, loss, and learning rate.

    The "best epoch" marker is placed on the epoch where the monitored metric is
    best. The defaults match the EarlyStopping callback used for training
    (monitor='val_accuracy', mode='max'), so the marker shows the epoch whose
    weights were restored.

    Parameters:
        model_history (tf.keras.callbacks.History): History object obtained during model training.
        monitor (str): Key of model_history.history used to pick the best epoch.
        mode (str): 'max' if larger values of the monitored metric are better, anything else for smaller.
    """
    history = model_history.history
    pick_best = np.argmax if mode == 'max' else np.argmin
    best_epoch = int(pick_best(history[monitor]))

    def _finish_axes(title, ylabel):
        # decorations shared by the three figures
        plt.axvline(x=best_epoch, label='Best epoch', alpha=.3, ls='--', color='#5a9aa5')
        plt.title(title)
        plt.xlabel('epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(alpha=.3)

    # show accuracy curve
    plt.figure(figsize=(20,5))
    plt.plot(history['accuracy'], label='Accuracy [train]', alpha=.8, color='#ff7f0e')
    plt.plot(history['val_accuracy'], label='Accuracy [val]', alpha=.9, color='#5a9aa5')
    _finish_axes('Accuracy', 'Accuracy')

    # show loss curve
    plt.figure(figsize=(20,5))
    plt.plot(history['loss'], label='Training Loss', alpha=.8, color='#ff7f0e')
    plt.plot(history['val_loss'], label='Validation Loss', alpha=.9, color='#5a9aa5')
    _finish_axes('Loss', 'Cross Entropy')

    # show learning rate curve; the key is 'lr' or 'learning_rate' depending on the
    # Keras version, and it is absent when no callback logged the learning rate
    learning_rates = history.get('lr', history.get('learning_rate'))
    if learning_rates is not None:
        plt.figure(figsize=(18,3))
        plt.plot(learning_rates, label='Learning Rate', alpha=.8, color='#ff7f0e')
        _finish_axes('Learning Rate', 'Learning rate')

    plt.show()

In [None]:
# visualize learning history
# model_histories was filled in the same order as models, so zip pairs each
# history with the model that produced it
for model_history, model in zip(model_histories, models):
    print(f"-> Learning curves for {model.name}")
    plot_learning_curves(model_history)

In [None]:
# define utilities to assess model performance
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import seaborn as sns

def _to_class_labels(true, predictions, output_activation):
    """Convert raw model outputs (and one-hot targets for softmax) into class labels without touching the inputs."""
    if output_activation == "sigmoid":
        # p < 0.5 -> 0, otherwise 1; np.where builds a new array, so the caller's
        # probabilities are left untouched
        predictions = np.where(np.asarray(predictions) < 0.5, 0, 1).ravel()
    elif output_activation == "softmax":
        true = np.argmax(true, axis=-1)
        predictions = np.argmax(predictions, axis=-1)
    return true, predictions


def plot_confusion_matrix(true: np.ndarray, predictions: np.ndarray, output_activation: str):
    """
    Plot the confusion matrix.

    Parameters:
        true (np.ndarray): True labels.
        predictions (np.ndarray): Raw model outputs (probabilities); they are not modified.
        output_activation (str): Activation function used in the output layer of the model.
            "sigmoid" thresholds at 0.5, "softmax" takes the argmax of both true and
            predictions; any other value uses the arrays as they are.
    """
    true, predictions = _to_class_labels(true, predictions, output_activation)

    # compute confusion matrix
    cm = confusion_matrix(true, predictions)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='g', cmap='Blues', cbar=False)

    # add labels and title
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')

    # show the plot
    plt.show()


def get_metrics(true: np.ndarray, predictions: np.ndarray, output_activation: str):
    """
    Display classification metrics (accuracy, precision, recall, F1).

    Parameters:
        true (np.ndarray): True labels.
        predictions (np.ndarray): Raw model outputs (probabilities); they are not modified.
        output_activation (str): Activation function used in the output layer of the model.
            "sigmoid" thresholds at 0.5, "softmax" takes the argmax of both true and
            predictions; any other value uses the arrays as they are.
    """
    # display the shape of the predictions
    print("Predictions Shape:", predictions.shape)

    true, predictions = _to_class_labels(true, predictions, output_activation)

    # Compute classification metrics
    accuracy = accuracy_score(true, predictions)
    precision = precision_score(true, predictions)
    recall = recall_score(true, predictions)
    f1 = f1_score(true, predictions)

    # Display the computed metrics; the built-in round works whether sklearn
    # returns a numpy scalar or a plain Python float
    print('Accuracy:', round(float(accuracy), 4))
    print('Precision:', round(float(precision), 4))
    print('Recall:', round(float(recall), 4))
    print('F1:', round(float(f1), 4))

In [None]:
# predict labels for the entire test set
# estimate the time required for inference
import time 
output_activation = "sigmoid"
n_timing_runs = 100
for model in models:
    print(f"-> Model: {model.name}")
    # untimed first call: it provides the predictions used for the metrics and keeps
    # one-off setup cost (e.g. graph tracing) out of the average
    predictions = model.predict(X_test, verbose=0)
    times = []
    for _ in range(n_timing_runs):
        # perf_counter is monotonic and meant for measuring short durations
        start_time = time.perf_counter()
        model.predict(X_test, verbose=0)
        times.append(time.perf_counter() - start_time)
    print(f"Average inference time for {len(X_test)} samples: {np.mean(times)} seconds")
    # get performance metrics 
    get_metrics(y_test, predictions, output_activation)
    plot_confusion_matrix(y_test, predictions, output_activation)

## Prediction Visualization

In [None]:
# modify get_labels to return also node ids
def _random_spherical_offset(max_radius):
    """Draw a random 3D translation vector of length at most max_radius, as a (3,) array."""
    r = np.random.uniform(-max_radius, max_radius)  # translation length
    theta = np.random.uniform(0, 2*np.pi)  # xy plane angle
    phi = np.random.uniform(0, np.pi)  # z to xy plane angle
    return np.array([r*np.sin(phi)*np.cos(theta), r*np.sin(phi)*np.sin(theta), r*np.cos(phi)])


def _sample_cube_at(image, center, apply_rotation):
    """Sample one cube centred on center, optionally rotated by a random angle around a random axis."""
    if apply_rotation:
        # define rotation vector
        vector_axis_of_rotation = np.array([np.random.uniform(-1, 1), np.random.uniform(-1, 1), np.random.uniform(-1, 1)])
        transformation_to_apply = HearticDatasetManager.affine.get_affine_3d_rotation_around_vector(
            vector=vector_axis_of_rotation,
            vector_source=center.reshape(3,1), # the center of the cube
            rotation=np.random.choice(range(90+1)), # max degree of rotation +1
            rotation_units="deg"
        )
    else:
        transformation_to_apply = None
    return get_input_data_from_vertex_ras_position(image, center,
                                                   CUBE_SIDE_MM, CUBE_SIDE_N_SAMPLES,
                                                   affine=transformation_to_apply)


def get_labeled_patches_new(
    image: HearticDatasetManager.cat08.Cat08ImageCT|HearticDatasetManager.asoca.AsocaImageCT,
    graph: hcatnetwork.graph.graph.SimpleCenterlineGraph,
    rotation_random_sampling = None,
    rotation_local_sampling = None
):

    """ 
    The function takes as input a single image-graph pair and provides the extracted patches, labels and nodes by following these steps:

        1. Firstly it samples a node along the graph and applies a random translation to its position.
        2. Then utility functions are used to sample a cube centered in the translated point.
        3. If the point lies close enough to one of the two Ostia, then label = 1, 0 otherwise.
        4. Finally, steps 1-2 are repeated by local sampling around the Ostium to ensure class 1 is represented
           (these cubes are always labelled 1).

    NUM_CUBES_PATIENT // 2 cubes are produced by each of the two sampling phases.

    Parameters:
        image (HearticDatasetManager.cat08.Cat08ImageCT | HearticDatasetManager.asoca.AsocaImageCT):
            The CT scan image data.
        graph (hcatnetwork.graph.graph.SimpleCenterlineGraph):
            The corresponding centerline graph.
        rotation_random_sampling (optional): If truthy, cubes of the random sampling phase are randomly rotated. Default is None.
        rotation_local_sampling (optional): If truthy, cubes of the local sampling phase are randomly rotated. Default is None.

    Returns:
        tuple: A tuple containing numpy arrays of extracted cubes, corresponding labels and nodes.
            nodes holds, for each cube, the (x, y, z) position of the graph node the cube was
            sampled around (before the random translation).
    """

    # compute radius statistics
    maxR, minR, avgR = compute_r_stats(graph)
    # retrieve ostia
    ids, coord_ostium_1, coord_ostium_2 = get_ostia_data(graph)

    cubes = []
    labels = []
    nodes = []

    # firstly random sample along the graph
    node_ids = list(graph.nodes.keys())
    for _ in tqdm(range(NUM_CUBES_PATIENT // 2), desc="Random Sampling"):
        # choose a random node of the graph
        node_id = np.random.choice(node_ids)
        # node position as stored in the graph
        # NOTE(review): the sampler expects RAS coordinates - confirm the graph is already in RAS
        node_position = np.array([graph.nodes[node_id]["x"], graph.nodes[node_id]["y"], graph.nodes[node_id]["z"]], dtype=float)
        nodes.append(node_position) # save the node position
        # translated point; a new (3,) array, so the saved node position is not modified
        # (an in-place += with a (3,1) offset cannot broadcast into a (3,) array)
        cube_center = node_position + _random_spherical_offset(avgR*1.5)
        # sample
        cube_array = _sample_cube_at(image, cube_center, rotation_random_sampling)
        # compute distances from ostia to assign labels
        dist1 = np.linalg.norm(coord_ostium_1[:3] - cube_center)
        dist2 = np.linalg.norm(coord_ostium_2[:3] - cube_center)
        label = 1 if (dist1<=1.2*coord_ostium_1[-1] or dist2<=1.2*coord_ostium_2[-1]) else 0

        cubes.append(cube_array)
        labels.append(label)

    # then sample locally from the ostia applying augmentation via rotation if selected
    for _ in tqdm(range(NUM_CUBES_PATIENT // 2), desc="Local Sampling"):
        # select one of the two ostia
        node_id = np.random.choice(ids)
        node_position = np.array([graph.nodes[node_id]["x"], graph.nodes[node_id]["y"], graph.nodes[node_id]["z"]], dtype=float)
        nodes.append(node_position) # save the node position
        # translate by at most 1.2 ostium radii, i.e. within the distance that is labelled 1
        cube_center = node_position + _random_spherical_offset(graph.nodes[node_id]["r"]*1.2)
        # sample
        cube_array = _sample_cube_at(image, cube_center, rotation_local_sampling)

        cubes.append(cube_array)
        labels.append(1)

    cubes = np.array(cubes)
    labels = np.array(labels)
    nodes = np.array(nodes)

    return cubes, labels, nodes

In [None]:
# visualize predictions for patient 6 in CAT08 dataset
# NOTE(review): patient_id is used below as an index into the CAT08 file lists;
# confirm that index 6 selects the intended patient
patient_id = 6
# define an empty patient instance
# (patient6 is not used by the cells shown below)
patient6 = Patient()
# build the paths of the patient image and of its centerline graph
image_file = os.path.join(CAT08_FOLDER, HearticDatasetManager.cat08.DATASET_CAT08_IMAGES[patient_id])

graph_file = os.path.join(CAT08_FOLDER, HearticDatasetManager.cat08.DATASET_CAT08_GRAPHS_RESAMPLED_05MM[patient_id])

# load the CT image and the centerline graph from disk
image = HearticDatasetManager.cat08.Cat08ImageCT(image_file)
graph = hcatnetwork.io.load_graph(graph_file, output_type=hcatnetwork.graph.SimpleCenterlineGraph)

In [None]:
# resample the graph with 0.1 mm between nodes
resampled_graph = graph.resample(mm_between_nodes=0.1)
# draw the resampled centerlines in 2D as a visual check
hcatnetwork.draw.draw_simple_centerlines_graph_2d(resampled_graph)

In [None]:
# save Ostia coordinates and convert to RAS
ids = resampled_graph.get_coronary_ostia_node_id()

# The ostia ids come from resampled_graph, so the node attributes are read from that
# same graph: node ids of the original graph are not guaranteed to match after resampling.
# This cell must run before the coordinates of resampled_graph are converted to RAS.
_ostium_keys = ("x", "y", "z", "r")
coord_ostium_1 = np.array([resampled_graph.nodes[ids[0]][key] for key in _ostium_keys], dtype=float)
coord_ostium_2 = np.array([resampled_graph.nodes[ids[1]][key] for key in _ostium_keys], dtype=float)

# only the position (first three entries) goes through the affine; the radius is left as is
coord_ostium_1_ras = HearticDatasetManager.affine.apply_affine_3d(image.affine_centerlines2ras, coord_ostium_1[0:3])
coord_ostium_2_ras = HearticDatasetManager.affine.apply_affine_3d(image.affine_centerlines2ras, coord_ostium_2[0:3])

In [None]:
# convert all graph coordinates to RAS
# every node's (x, y, z) is pushed through the centerlines-to-RAS affine and written
# back in place, so running this cell twice would apply the affine twice
for node_id in resampled_graph.nodes:
    old_coords = np.array([resampled_graph.nodes[node_id][axis] for axis in ("x", "y", "z")])
    new_coords = HearticDatasetManager.affine.apply_affine_3d(image.affine_centerlines2ras, old_coords)
    for position, axis in enumerate(("x", "y", "z")):
        resampled_graph.nodes[node_id][axis] = new_coords[position]

In [None]:
# get distance from closest ostium for each node
# each entry of sorted_nodes is [x, y, z, distance to the closest ostium]
sorted_nodes = []
for node_id in resampled_graph.nodes():
    node_position = np.array([resampled_graph.nodes[node_id]["x"], resampled_graph.nodes[node_id]["y"], resampled_graph.nodes[node_id]["z"]])
    # Euclidean distance to each ostium (both already in RAS)
    dist1 = sum((coord_ostium_1_ras[axis] - node_position[axis])**2 for axis in range(3))**(0.5)
    dist2 = sum((coord_ostium_2_ras[axis] - node_position[axis])**2 for axis in range(3))**(0.5)
    # keep the smaller of the two distances
    dist = dist2 if dist1 > dist2 else dist1
    node_info = np.array([resampled_graph.nodes[node_id]["x"], resampled_graph.nodes[node_id]["y"], resampled_graph.nodes[node_id]["z"], dist])
    sorted_nodes.append(node_info)

# sort the nodes based on distance from the ostia (closest first)
sorted_nodes = sorted(sorted_nodes, key=lambda entry: entry[-1])

In [None]:
# build the volumes
# evaluate the 1500 nodes closest to the ostia, or all of them if the graph is smaller
num_nodes = min(1500, len(sorted_nodes))

# ostium positions in RAS as flat (3,) vectors, for per-ostium distances
ostium_1_position = np.asarray(coord_ostium_1_ras, dtype=float).ravel()[:3]
ostium_2_position = np.asarray(coord_ostium_2_ras, dtype=float).ravel()[:3]

cubes = []
labels = []
for i in range(num_nodes):
    node_position = np.array([sorted_nodes[i][0], sorted_nodes[i][1], sorted_nodes[i][2]])
    cube_array = get_input_data_from_vertex_ras_position(image, node_position,
                                                         CUBE_SIDE_MM, CUBE_SIDE_N_SAMPLES)

    # same rule used to label the training patches: a node is positive when it lies within
    # 1.2 radii of an ostium, each distance being compared with that ostium's own radius
    dist1 = np.linalg.norm(ostium_1_position - node_position)
    dist2 = np.linalg.norm(ostium_2_position - node_position)
    label = 1 if (dist1 <= 1.2*coord_ostium_1[-1] or dist2 <= 1.2*coord_ostium_2[-1]) else 0
    cubes.append(cube_array)
    labels.append(label)

cubes = np.array(cubes)
labels = np.array(labels)

# apply preprocessing
cubes_preprocessed = preprocess_cubes(cubes)

In [None]:
# get predictions
# one array of raw model outputs per model, in the same order as models
sorted_predictions = [model.predict(cubes_preprocessed) for model in models]

In [None]:
import matplotlib.pyplot as plt
# plot the vessels and predicted labels
def plot_vessels(graph: hcatnetwork.graph.graph.SimpleCenterlineGraph, 
                 image: HearticDatasetManager.cat08.Cat08ImageCT|HearticDatasetManager.asoca.AsocaImageCT, 
                 coord_ostium_1_ras: np.ndarray, 
                 coord_ostium_1: np.ndarray, 
                 coord_ostium_2_ras: np.ndarray, 
                 coord_ostium_2: np.ndarray, 
                 sorted_predictions: List[int], 
                 labels: List[int], 
                 sorted_nodes: List[np.ndarray]):
    """
    Plot the vessel tree in 3D and mark correctly and wrongly predicted nodes.

    Parameters:
    - graph: Centerline graph; its node coordinates are converted to RAS with image.affine_centerlines2ras.
    - image: Image data containing affine transformation information.
    - coord_ostium_1_ras, coord_ostium_2_ras: Coordinates of the two ostia in RAS space (plotted).
    - coord_ostium_1, coord_ostium_2: Ostia coordinates before the RAS conversion. Kept for
      interface compatibility; they are not plotted.
    - sorted_predictions: Predicted labels, one per node, in the same order as sorted_nodes.
    - labels: True labels, in the same order as sorted_nodes.
    - sorted_nodes: Node entries whose first three values are the x, y, z coordinates.
      Only as many nodes as there are labels and predictions are considered.

    Returns:
    - mispredicted_nodes_indexes: List of indexes corresponding to mispredicted nodes for further analysis.
    """
    # Extract vessel coordinates from the graph
    x_vessels = []
    y_vessels = []
    z_vessels = []
    for node_id in graph.nodes:
        coords = np.array([graph.nodes[node_id]["x"], graph.nodes[node_id]["y"], graph.nodes[node_id]["z"]])
        new_coords = HearticDatasetManager.affine.apply_affine_3d(image.affine_centerlines2ras, coords)
        x_vessels.append(new_coords[0])
        y_vessels.append(new_coords[1])
        z_vessels.append(new_coords[2])

    # Create a 3D scatter plot
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Plot vessels in blue
    ax.scatter(x_vessels, y_vessels, z_vessels, c='blue', marker='o', s=1, label="Artery")

    # Plot ostium coordinates in green; all three components come from the RAS
    # coordinates so the markers live in the same space as the vessels
    ax.scatter(coord_ostium_1_ras[0], coord_ostium_1_ras[1], coord_ostium_1_ras[2], c='green', marker='o', s=100, label="Ostium")
    ax.scatter(coord_ostium_2_ras[0], coord_ostium_2_ras[1], coord_ostium_2_ras[2], c='green', marker='o', s=100)

    # Split the nodes into correctly and wrongly predicted ones
    correct_points = []
    mispredicted_points = []
    mispredicted_nodes_indexes = []
    for i, (true_label, predicted_label, node) in enumerate(zip(labels, sorted_predictions, sorted_nodes)):
        point = (node[0], node[1], node[2])
        if true_label == predicted_label:
            correct_points.append(point)
        else:
            mispredicted_points.append(point)
            mispredicted_nodes_indexes.append(i)

    # One scatter call per group: each group gets exactly one legend entry
    for points, col, lab in ((correct_points, "green", "Correct labels"),
                             (mispredicted_points, "red", "Mispredicted labels")):
        if not points:
            continue
        xs, ys, zs = zip(*points)
        ax.scatter(xs, ys, zs, c=col, label=lab, marker='o', s=20)

    # Set axis labels and plot title
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    ax.set_zlabel('Z-axis')
    ax.set_title('3D Scatter Plot of Patient 6 from the CAT08 dataset')

    # Legend entries come from the labels attached to the artists above
    ax.legend(loc="best")

    # Display the plot
    plt.show()

    # Return indexes of mispredicted nodes for further analysis
    return mispredicted_nodes_indexes

In [None]:
# compute the distance from the ostium of the mispredicted labels
mispredicted_nodes_indexes_dict = {}
# sorted_predictions holds one array of raw outputs per model, in the same order as models
for model, model_predictions in zip(models, sorted_predictions):
    # turn this model's sigmoid outputs into 0/1 labels (p < 0.5 -> 0, otherwise 1)
    predicted_labels = np.where(np.asarray(model_predictions).ravel() < 0.5, 0, 1)
    mispredicted_nodes_indexes = plot_vessels(graph, image, 
                                              coord_ostium_1_ras, coord_ostium_1,
                                              coord_ostium_2_ras, coord_ostium_2, 
                                              predicted_labels, labels, sorted_nodes)
    mispredicted_nodes_indexes_dict[model] = mispredicted_nodes_indexes

# the 4th entry of each node is its distance from the closest ostium
mispredicted_distances_dict = {
    model: [np.around(sorted_nodes[i][3], decimals=1) for i in mispredicted_nodes_indexes]
    for model, mispredicted_nodes_indexes in mispredicted_nodes_indexes_dict.items()
}

In [None]:
# plot the mispredictions histogram for each model
for model, mispredicted_distances in mispredicted_distances_dict.items():
    # only distances between 2.5 and 4.5 mm are binned
    plt.hist(mispredicted_distances, bins='auto', range=(2.5, 4.5), align='left', edgecolor='black')

    # Set labels and title; show the model's name rather than the repr of the model object
    model_name = getattr(model, "name", model)
    plt.xlabel('Distance from the ostium [mm]')
    plt.ylabel('Mispredictions')
    plt.title(f'Frequency of mispredictions for {model_name}')

    # Show the plot
    plt.show()