In [None]:
# Import the necessary modules
from sirf.STIR import (ImageData, AcquisitionData,
                       SPECTUBMatrix, AcquisitionModelUsingMatrix,
                       MessageRedirector)
from src.simind import *
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import pandas as pd
import os
import numpy as np

In [None]:
# Redirect STIR's console messages through SIRF's MessageRedirector.
# NOTE(review): `msg` is not referenced again in this notebook; the binding is
# presumably kept so the redirector object stays alive for the session —
# confirm before removing or renaming it.
msg = MessageRedirector()

# Hardcoded arguments that were originally passed via argparse
total_activity = 258.423  # in MBq
time_per_projection = 43  # in seconds
photon_multiplier = 0.001
photopeak_energy = 208  # keV
window_lower = 187.56  # keV
window_upper = 229.24  # keV
source_type = 'lu177'
collimator = 'G8-MEGP'
kev_per_channel = 10.0
max_energy = 498.3  # presumably keV, matching kev_per_channel — confirm
mu_map_path = 'data/Lu177/registered_CTAC.hv'
image_path = 'data/Lu177/osem_reconstruction_postfilter_555.hv'
measured_data_path = 'data/Lu177/SPECTCT_NEMA_128_EM001_DS_en_1_Lu177_EM.hdr'
# NOTE(review): absolute, user-specific path (unlike the relative paths above).
# This name holds a file path here and is rebound to the loaded AcquisitionData
# in the loading cell.
measured_additive = '/home/sam/working/STIR_users_MIC2023/data/Lu177/STIR_TEW.hs'
output_dir = 'simind_output'
output_prefix = 'output'
input_smc_file_path = 'input/input.smc'
scoring_routine = 1
collimator_routine = 0
photon_direction = 3
crystal_thickness = 7.25  # mm
crystal_half_length_radius = 393.6 / 2  # mm
crystal_half_width = 511.7 / 2  # mm
flag_11 = True

# Number of energy-spectrum channels. Floor division on floats yields a float
# (49.0 here), so cast to int to get a proper channel count.
num_energy_spectra_channels = int(max_energy // kev_per_channel)

In [None]:
# Define the acquisition model function
def get_acquisition_model(measured_data, additive_data, image, mu_map_stir,
                          collimator_sigma_0_mm=1.81534,
                          collimator_slope_mm=0.02148,
                          full_3d=False):
    """Build and set up a SPECTUB-based STIR acquisition model.

    Parameters
    ----------
    measured_data : AcquisitionData
        Template for the projection-space geometry (passed to ``set_up``).
    additive_data : AcquisitionData or None
        Additive term for the model (in this notebook, the measured TEW
        estimate). ``None`` skips setting an additive term.
    image : ImageData
        Template for the image-space geometry (passed to ``set_up``).
    mu_map_stir : ImageData
        Attenuation map in the orientation used for STIR.
    collimator_sigma_0_mm, collimator_slope_mm, full_3d :
        Forwarded to ``SPECTUBMatrix.set_resolution_model``; the defaults are
        the values this notebook was originally written with.

    Returns
    -------
    AcquisitionModelUsingMatrix
        The set-up model. NOTE: per the SIRF documentation, ``forward`` on this
        model includes the additive term when one is set; ``direct`` gives the
        linear part only.
    """
    acq_matrix = SPECTUBMatrix()
    acq_matrix.set_attenuation_image(mu_map_stir)
    # Keep all views cached: trades memory for not recomputing matrix rows.
    acq_matrix.set_keep_all_views_in_cache(True)
    acq_matrix.set_resolution_model(collimator_sigma_0_mm, collimator_slope_mm, full_3d)

    acq_model = AcquisitionModelUsingMatrix(acq_matrix)
    acq_model.set_up(measured_data, image)

    # SIRF's AcquisitionModel API is `set_additive_term`; the previous call to a
    # non-existent `set_additive` always failed, so the model silently had no
    # additive term. Setting it remains best-effort, as before.
    if additive_data is not None:
        try:
            acq_model.set_additive_term(additive_data)
        except Exception as e:
            print(e)
            print("Could not set additive data")

    return acq_model

In [None]:
# Load image, mu_map, measured data, and measured additive data
image = ImageData(image_path)
mu_map = ImageData(mu_map_path)
measured_data = AcquisitionData(measured_data_path)
# `measured_additive` starts out as a file path (config cell) and is rebound here
# to the loaded data. Only load while it is still a path, so re-running this cell
# does not hand an AcquisitionData object to the constructor instead of the file.
if isinstance(measured_additive, str):
    measured_additive = AcquisitionData(measured_additive)

# Copy of the attenuation map flipped along array axis 2 for the STIR model;
# the SIMIND simulator (later cell) receives the unflipped `mu_map`.
mu_map_stir = mu_map.clone()
mu_map_stir.fill(np.flip(mu_map.as_array(), axis=2))

In [None]:
# Work from the project root: relative paths used below (input_smc_file_path,
# output_dir) resolve against this directory from here on.
# NOTE(review): absolute, user-specific path — consider a config-cell constant.
# Data files were already loaded (earlier cell) relative to the starting cwd.
os.chdir("/home/sam/working/STIR_users_MIC2023")

simulator = SimindSimulator(
    template_smc_file_path=input_smc_file_path,
    output_dir=output_dir,
    output_prefix=output_prefix,
    source=image,
    mu_map=mu_map,
    template_sinogram=measured_data,
)

simulator.add_comment("Demonstration of SIMIND simulation")
simulator.set_windows(window_lower, window_upper, 0)

# SIMIND indices applied before flag 11; crystal lengths converted mm -> cm.
indices_before_flag = (
    ("photon_energy", photopeak_energy),
    ("scoring_routine", scoring_routine),
    ("collimator_routine", collimator_routine),
    ("photon_direction", photon_direction),
    ("source_activity", total_activity * time_per_projection),
    ("crystal_thickness", crystal_thickness / 10),
    ("crystal_half_length_radius", crystal_half_length_radius / 10),
    ("crystal_half_width", crystal_half_width / 10),
)
for index_name, index_value in indices_before_flag:
    simulator.add_index(index_name, index_value)

simulator.config.set_flag(11, flag_11)

# Photon-path step = smallest voxel dimension, mm -> cm.
smallest_voxel_cm = min(image.voxel_sizes()) / 10
indices_after_flag = (
    ("step_size_photon_path_simulation", smallest_voxel_cm),
    ("energy_resolution", 9.5),  # percent
    ("intrinsic_resolution", 0.31),  # cm
)
for index_name, index_value in indices_after_flag:
    simulator.add_index(index_name, index_value)

# Runtime switches passed to SIMIND.
runtime_switches = (("CC", collimator), ("NN", photon_multiplier), ("FI", source_type))
for switch_name, switch_value in runtime_switches:
    simulator.add_runtime_switch(switch_name, switch_value)
simulator.add_runtime_switch("FI", source_type)

In [None]:
# Run the SIMIND simulation configured in the previous cell. Results are read
# back in the next cell via `get_output_total()` / `get_output_scatter()`.
simulator.run_simulation()

In [None]:
# Read the simulated projections back from SIMIND; the "true" component is
# defined here as total minus scatter.
simind_total, simind_scatter = simulator.get_output_total(), simulator.get_output_scatter()
simind_true = simind_total - simind_scatter

In [None]:
# Build the STIR (SPECTUB) model for the measured geometry and forward-project
# the reconstructed image through it.
acq_model = get_acquisition_model(
    measured_data=measured_data,
    additive_data=measured_additive,
    image=image,
    mu_map_stir=mu_map_stir,
)
stir_forward_projection = acq_model.forward(image)

In [None]:
# Summed counts for each projection set: SIMIND results first, then measured/STIR.
# NOTE(review): per the SIRF docs, `forward()` includes the model's additive term
# when one is set, so "stir true" may include it — confirm the label.
simulated_sets = (
    ("simind total counts", simind_total),
    ("simind true counts", simind_true),
    ("simind scatter counts", simind_scatter),
)
reference_sets = (
    ("measured total counts", measured_data),
    ("stir true counts", stir_forward_projection),
    ("measured additive counts", measured_additive),
)
for label, projections in simulated_sets:
    print(f"{label}: {projections.sum()}")
print("\n")
for label, projections in reference_sets:
    print(f"{label}: {projections.sum()}")

In [None]:
# Projection sets to compare, as (SIRF object, label) pairs, then as numpy arrays.
plot_sources = [
    (simind_total, "simind total"),
    (simind_true, "simind true"),
    (simind_scatter, "simind scatter"),
    (measured_data, "measured"),
    (stir_forward_projection, "stir forward"),
    (measured_additive, "measured additive"),
]
data_list = [(data.as_array(), title) for data, title in plot_sources]

# Index into the second array axis shown as images, and the row of that 2-D
# slice drawn as a line profile.
axial_slice = 66
profile_row = 60

# Shared colour scale so all panels are directly comparable.
vmax = max(data[0, axial_slice].max() for data, _ in data_list)

# Define consistent font size and colormap
font_size = 14
colormap = 'viridis'

# Figure layout: image row, colorbar row, line-profile row.
fig = plt.figure(figsize=(len(data_list) * 4, 7 * 2))
gs = GridSpec(3, len(data_list), height_ratios=[2, 0.15, 3])

ax_images = [fig.add_subplot(gs[0, i]) for i in range(len(data_list))]
for ax, (data, title) in zip(ax_images, data_list):
    im = ax.imshow(data[0, axial_slice], vmin=0, vmax=vmax, cmap=colormap)
    ax.set_title(f"{title}: {np.trunc(data.sum())} ", fontsize=font_size)
    ax.axis('off')

# Single colorbar for all panels (they share vmin/vmax). `pad` is only used when
# matplotlib creates the colorbar axes itself, so it is not passed with `cax`.
cbar_ax = fig.add_subplot(gs[1, :])
fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
cbar_ax.set_xlabel('Counts', fontsize=font_size)
cbar_ax.xaxis.set_label_position('top')

# Line profile: one row of each displayed 2-D slice, coloured from the same colormap.
ax_line = fig.add_subplot(gs[2, :])
colours = plt.get_cmap(colormap)(np.linspace(0, 1, len(data_list)))
for colour, (data, title) in zip(colours, data_list):
    ax_line.plot(data[0, axial_slice][profile_row], linewidth=2, color=colour,
                 linestyle='-', label=title)

# The profile is a row of the image above, so the x axis runs over its columns.
# NOTE(review): for SIRF AcquisitionData the last axis is presumably the
# tangential (detector-bin) position, not the projection angle — confirm.
ax_line.set_xlabel('Tangential position (bin)', fontsize=font_size)
ax_line.set_ylabel('Intensity', fontsize=font_size)
ax_line.set_title('Profile Through Sinogram', fontsize=font_size + 2)
ax_line.grid(True, which='both', linestyle='--', linewidth=0.5)
ax_line.legend(loc='upper left', fontsize=font_size)
ax_line.set_xlim(0, max(data.shape[-1] for data, _ in data_list));

In [None]:


# Adjust layout and save the comparison figure.
# Use the explicit `fig` handle from the plotting cell: with Jupyter's inline
# backend, figures are closed at the end of the cell that created them, so
# `plt.tight_layout()` / `plt.savefig()` here would act on a new, empty figure.
fig.tight_layout()
os.makedirs(output_dir, exist_ok=True)
fig.savefig(os.path.join(output_dir, "comparison_axial.png"))
plt.close(fig)