In [None]:
import SimpleITK as sitk
import os
import matplotlib.pyplot as plt
from platipy.imaging import ImageVisualiser
import pydicom
from tqdm import tqdm
import pandas as pd
import numpy as np
from moosez import moose

def read_one_series(directory):
    """
    Load the DICOM series stored in ``directory`` as one SimpleITK image.

    GDCM is asked for the ordered list of slice files belonging to the
    series in the folder. Those files are then read together as a single
    volume. The folder is expected to hold one complete series.

    Parameters
    ----------
    directory : str
        Folder containing the DICOM slice files.

    Returns
    -------
    image : SimpleITK.Image
        The volume assembled from the series.
    """
    series_reader = sitk.ImageSeriesReader()
    series_reader.SetFileNames(series_reader.GetGDCMSeriesFileNames(directory))
    return series_reader.Execute()

def get_time_metadata(dicom_dir):
    """
    Get the frame timing of the DICOM series stored in ``dicom_dir``.

    The first DICOM file in the folder is read header-only. Files count as
    DICOM when their name ends in ``.dcm`` or ``.ima``, case-insensitively,
    and "first" means first in sorted name order, so the choice is
    deterministic. Its Frame Reference Time (0054,1300) and Actual Frame
    Duration (0018,1242) are converted from milliseconds to minutes. All
    files of one series are assumed to share these values.

    Parameters
    ----------
    dicom_dir : str
        Path to the directory containing the DICOM series.

    Returns
    -------
    tuple of (float, float) or None
        ``(Ti, dt)``: the frame reference time and the frame duration, both in
        minutes. ``None`` (after printing a message) if the directory contains
        no DICOM files.

    Raises
    ------
    AttributeError
        If the header lacks FrameReferenceTime or ActualFrameDuration.
    """
    # Match extensions case-insensitively: Siemens exports commonly use ".IMA".
    dicom_files = sorted(
        os.path.join(dicom_dir, name)
        for name in os.listdir(dicom_dir)
        if name.lower().endswith(('.dcm', '.ima'))
        and os.path.isfile(os.path.join(dicom_dir, name))
    )

    if not dicom_files:
        print("No DICOM files found in the specified directory")
        return None

    # Only header tags are needed, so skip reading the pixel data.
    ds = pydicom.dcmread(dicom_files[0], stop_before_pixels=True)

    # Both tags are stored in milliseconds; convert to minutes.
    frame_time_min = float(ds.FrameReferenceTime) / 1000 / 60
    frame_duration_min = float(ds.ActualFrameDuration) / 1000 / 60
    return frame_time_min, frame_duration_min
    
def process_dicom_series(main_folder):
    """
    Processes DICOM series stored in subfolders of the given folder.

    Each subfolder is assumed to contain one DICOM series. For each series
    the frame time and frame duration are read with ``get_time_metadata``,
    and the image with ``read_one_series``. Subfolders that contain no
    DICOM files are skipped with a message instead of aborting the run. The
    results are sorted by frame time in ascending order.

    Parameters
    ----------
    main_folder : str
        Path to the folder containing subfolders with DICOM series.

    Returns
    -------
    tuple
        ``(sitk_list, Ti_list, dt_list)``: the SimpleITK images, the frame
        times (min) and the frame durations (min), all in the same order,
        sorted by frame time.
    """
    series_info = []

    # Sorted so that processing order (and tie-breaking in the stable sort
    # below) is deterministic; os.listdir order is arbitrary.
    subfolders = sorted(
        os.path.join(main_folder, d)
        for d in os.listdir(main_folder)
        if os.path.isdir(os.path.join(main_folder, d))
    )

    for subfolder in tqdm(subfolders):
        timing = get_time_metadata(subfolder)
        if timing is None:
            # No DICOM files here: unpacking None would raise an opaque TypeError.
            print(f"Skipping {subfolder}: no DICOM series found")
            continue
        Ti, dt = timing
        image_data = read_one_series(subfolder)
        series_info.append((Ti, dt, image_data))

    # Sort the collected data by Ti in ascending order.
    series_info.sort(key=lambda info: info[0])

    # Unpack the sorted results into separate lists.
    sorted_Ti_list = [info[0] for info in series_info]
    sorted_dt_list = [info[1] for info in series_info]
    sitk_list = [info[2] for info in series_info]

    return sitk_list, sorted_Ti_list, sorted_dt_list

def get_IDIF(segmented, sitk_list, Ti_list, dt_list, which=6):
    """
    Get the IDIF (Image Derived Input Function) from the segmented image.

    For each PET image in ``sitk_list``, compute the mean voxel value inside
    the ROI, i.e. the voxels of ``segmented`` equal to ``which``. The results
    are returned in a DataFrame together with the frame times and durations.
    They are also shown as a scatter plot of mean activity versus time.

    The segmented image and the PET images are assumed to be aligned on the
    same voxel grid (``sitk.Mask`` is applied between them).

    Parameters
    ----------
    segmented : SimpleITK.Image
        Label image defining the ROI.
    sitk_list : list
        List of SimpleITK PET images, one per frame.
    Ti_list : list
        Frame times (min), same order as ``sitk_list``.
    dt_list : list
        Frame durations (min), same order as ``sitk_list``.
    which : int, optional
        The label corresponding to the ROI in the segmented image.
        The default is 6, which corresponds to the aorta.
        Alternatively, you can use 4, which corresponds to the left ventricle.

    Returns
    -------
    df : pandas.DataFrame
        Columns 'Frame duration (min)', 'Time (min)' and 'AIF (Bq/mL)',
        one row per frame.

    Raises
    ------
    ValueError
        If the three input lists differ in length, or if label ``which`` has
        no voxels in ``segmented`` (the mean would be a division by zero).
    """
    # Fail before the per-frame loop rather than when building the DataFrame.
    if not len(sitk_list) == len(Ti_list) == len(dt_list):
        raise ValueError(
            "sitk_list, Ti_list and dt_list must have the same length, got "
            f"{len(sitk_list)}, {len(Ti_list)} and {len(dt_list)}"
        )

    roi_mask = (segmented == which)
    num_of_roi_voxels = np.sum(sitk.GetArrayFromImage(roi_mask))
    if num_of_roi_voxels == 0:
        raise ValueError(
            f"Label {which} is not present in the segmentation; cannot compute the IDIF"
        )

    mean_activities = []
    for pet_image in tqdm(sitk_list):
        # Voxels outside the ROI are set to 0, so the sum over the whole array
        # equals the sum over the ROI voxels.
        activity = sitk.GetArrayFromImage(sitk.Mask(pet_image, roi_mask))
        # Accumulate in float64 so large ROIs of float32 data lose no precision.
        mean_activities.append(np.sum(activity, dtype=np.float64) / num_of_roi_voxels)

    data = {'Frame duration (min)': dt_list, 'Time (min)': Ti_list, 'AIF (Bq/mL)': mean_activities}
    df = pd.DataFrame(data)
    df.plot(x = 'Time (min)', y = 'AIF (Bq/mL)', kind = 'scatter')
    plt.show()

    return df

# Function to extract radiopharmaceutical and patient information from PET DICOM
def extract_pet_metadata(dicom_dir):
    """
    Extract patient, study and radiopharmaceutical metadata from a PET export.

    The header of the first DICOM file in the first subfolder of
    ``dicom_dir`` is read; both "first"s are in sorted name order. The
    available fields are collected into a dictionary, which is also shown
    as a table. In Jupyter the table is shown with ``display``; elsewhere
    it is printed.

    Parameters
    ----------
    dicom_dir : str
        Path to the directory whose subfolders contain the PET DICOM files.

    Returns
    -------
    metadata : dict or None
        A dictionary containing the extracted metadata. ``None`` (after
        printing a message) if there is no subfolder or no DICOM file in it.
    """
    # Extract metadata from the first PET subfolder (sorted for determinism).
    subfolders = sorted(
        os.path.join(dicom_dir, d)
        for d in os.listdir(dicom_dir)
        if os.path.isdir(os.path.join(dicom_dir, d))
    )
    if not subfolders:
        print("No subfolders found")
        return None
    first_subfolder = subfolders[0]

    # Find all DICOM files in the subfolder (extension match is case-insensitive).
    dicom_files = sorted(
        os.path.join(first_subfolder, f)
        for f in os.listdir(first_subfolder)
        if f.lower().endswith(('.dcm', '.ima'))
        and os.path.isfile(os.path.join(first_subfolder, f))
    )

    if not dicom_files:
        print("No DICOM files found")
        return None

    # Only header tags are needed, so skip reading the pixel data.
    ds = pydicom.dcmread(dicom_files[0], stop_before_pixels=True)

    metadata = {}

    # Patient info
    if hasattr(ds, 'PatientName'):
        metadata['Patient Name'] = str(ds.PatientName)
    if hasattr(ds, 'PatientID'):
        metadata['Patient ID'] = ds.PatientID
    if hasattr(ds, 'PatientBirthDate'):
        metadata['Patient Birth Date'] = ds.PatientBirthDate
    if hasattr(ds, 'PatientSex'):
        metadata['Patient Sex'] = ds.PatientSex
    if hasattr(ds, 'PatientWeight'):
        metadata['Patient Weight'] = f"{ds.PatientWeight} kg"

    # Study info
    if hasattr(ds, 'StudyDescription'):
        metadata['Study Description'] = ds.StudyDescription
    if hasattr(ds, 'StudyDate'):
        metadata['Study Date'] = ds.StudyDate

    # Radiopharmaceutical info from the first item of the sequence
    if hasattr(ds, 'RadiopharmaceuticalInformationSequence'):
        try:
            rpis = ds.RadiopharmaceuticalInformationSequence[0]
        except IndexError:
            # Sequence present but empty.
            print("Error extracting radiopharmaceutical sequence")
        else:
            if hasattr(rpis, 'RadiopharmaceuticalStartTime'):
                metadata['Radiopharmaceutical Start Time'] = rpis.RadiopharmaceuticalStartTime
            if hasattr(rpis, 'RadionuclideTotalDose'):
                metadata['Radionuclide Total Dose'] = f"{rpis.RadionuclideTotalDose} Bq"
            if hasattr(rpis, 'RadionuclideHalfLife'):
                metadata['Radionuclide Half Life'] = f"{rpis.RadionuclideHalfLife} seconds"
            if hasattr(rpis, 'RadiopharmaceuticalVolume'):
                metadata['Radiopharmaceutical Volume'] = f"{rpis.RadiopharmaceuticalVolume} ml"
            if hasattr(rpis, 'Radiopharmaceutical'):
                metadata['Radiopharmaceutical'] = rpis.Radiopharmaceutical

    # Units
    if hasattr(ds, 'Units'):
        metadata['Units'] = ds.Units

    df = pd.DataFrame(list(metadata.items()), columns=['Parameter', 'Value'])
    try:
        display(df)  # injected by IPython/Jupyter
    except NameError:
        # Running outside a notebook: fall back to plain text.
        print(df.to_string(index=False))

    return metadata

In [None]:
## Process PET DICOM Series
## Each subfolder of pet_path is read as one DICOM series (presumably one dynamic frame each).

pet_path = "PET_DICOM" ## path to the folder whose subfolders contain the PET DICOM series
sitk_list, Ti_list, dt_list = process_dicom_series(pet_path) ## images, frame times (min) and frame durations (min), sorted by frame time
print(sitk_list[0].GetSize()) ## print the size of the first (earliest) image
pet_metadata = extract_pet_metadata(pet_path) ## header metadata from the first file of the first subfolder (also shown as a table)
vis = ImageVisualiser(
    image = sitk_list[-1], ## the last image after sorting, i.e. the latest frame
    cut = (220, 220, 322), ## slice indices to display; adjust this to the center of your image (TODO confirm axis order in platipy)
    figure_size_in = 6,
    colormap = plt.cm.jet
)
fig = vis.show()

In [None]:
## Process CT DICOM Series
## CT is resampled onto the PET grid and segmented in later cells; ct_path is reused for the series IDs.

ct_path = "AC_CT_SN_3_0_BR38_HD_FOV_0002" ## path to the folder containing the CT DICOM files (one series)
CT = read_one_series(ct_path) ## read the CT series into a single SimpleITK volume
print(CT.GetSize()) ## print the size of the image as (x, y, z) voxel counts
vis = ImageVisualiser(
    image = CT,
    cut = (256, 256, 265), ## slice indices to display; adjust this to the center of your image
    figure_size_in = 6
)

fig = vis.show()

In [None]:
## Resample CT to PET space
## This is necessary for the registration/mask step

resampled_ct = sitk.Resample(CT,
                             sitk_list[-1].GetSize(),      # Use PET as the reference image
                             sitk.Transform(),             # Identity transform (since already registered)
                             sitk.sitkNearestNeighbor,     # Nearest neighbor interpolation for masks
                             sitk_list[-1].GetOrigin(),    # PET origin
                             sitk_list[-1].GetSpacing(),   # PET spacing
                             sitk_list[-1].GetDirection(), # PET direction
                             0,                            # Default pixel value
                             CT.GetPixelID())              # Preserve the mask's pixel type

%matplotlib inline

# visualise the resampled CT with the PET overlay
vis = ImageVisualiser(
    image = resampled_ct,
    cut = (220, 220, 220), # the (axial, coronal, sagittal) slice location
    # colormap = plt.cm.jet,
    figure_size_in = 6
)

vis.add_scalar_overlay(
    scalar_image = sitk_list[-1],
    name = "PET activity concentration [Bq/mL]",
    colormap = plt.cm.inferno,
    alpha = 0.5,
    max_value = 10000
)

fig = vis.show()

In [None]:
## Segmentation with Moose

accelerator = 'mps' ## Use 'mps' for MacOS, 'cuda' for NVIDIA GPUs, or 'cpu' for CPU only
input_file = resampled_ct
output_path = "CT_Seg" ## path to the folder where the segmentation will be saved
models = ['clin_ct_cardiac'] ## list of models to use for segmentation
## See MOOSE GitHub for other models: https://github.com/LalithShiyam/MOOSE
segmented = moose(input_file, models, output_path, accelerator)[0][0]
series_IDs = sitk.ImageSeriesReader.GetGDCMSeriesIDs(ct_path)
for i in range(len(series_IDs)):
    sitk.WriteImage(segmented, os.path.join(output_path, series_IDs[i] + ".dcm"))

# ## Use the following code if you want to load the saved segmentation
# segmented = read_one_series("CT_Seg")
# extractor3d = sitk.ExtractImageFilter()
# extractor3d.SetSize([segmented.GetSize()[0], segmented.GetSize()[1], segmented.GetSize()[2], 0])
# extractor3d.SetIndex([0, 0, 0, 0])
# segmented = extractor3d.Execute(segmented)

%matplotlib inline
## Visualise the segmentation
vis = ImageVisualiser(resampled_ct, cut = (220, 220, 322), figure_size_in = 8)
vis.add_contour(
  contour = {
    "Heart Myocardium": segmented == 1,
    "Left Atrium": segmented == 2,
    "Right Atrium": segmented == 3,
    "Left Ventricle": segmented == 4,
    "Right Ventricle": segmented == 5,
    "Aorta": segmented == 6,
    "Left Iliac Artery": segmented == 7,
    "Right Iliac Artery": segmented == 8,
    "Left Iliac Vein": segmented == 9,
    "Right Iliac Vein": segmented == 10,
    "Inferior Vena Cava": segmented == 11,
    "Portal Splenic Vein": segmented == 12,
    "Pulmonary Artery": segmented == 13
  },
  colormap = plt.cm.Spectral,
)
vis.set_limits_from_label(segmented, expansion = 30)
fig = vis.show()

In [None]:
which = 6 ## 6 for Aorta, 4 for Left Ventricle. Used for IDIF extraction
## Extract the IDIF from the segmented image (pass `which` so changing it above takes effect)
IDIF = get_IDIF(segmented, sitk_list, Ti_list, dt_list, which=which)
IDIF.to_csv("IDIF.csv", index = False) ## save the IDIF to a CSV file
IDIF