In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import os
import sys
from pathlib import Path

# ============================================================================
# CONFIGURATION - CHOOSE YOUR ISOTOPE
# ============================================================================
ISOTOPE = "lu177"  # Options: "tc99m" or "lu177"
# ============================================================================

# Fail fast on a typo: later cells branch on ISOTOPE and would otherwise die
# with an unrelated NameError (e.g. `matrix_size` undefined) before reaching
# any explicit check.
SUPPORTED_ISOTOPES = ("tc99m", "lu177")
if ISOTOPE not in SUPPORTED_ISOTOPES:
    raise ValueError(f"Unknown isotope: {ISOTOPE}. Choose 'tc99m' or 'lu177'")

# STIR version whose config/doc directories are looked up under
# $CONDA_PREFIX/share (change this when the conda STIR package is upgraded).
STIR_VERSION = "6.3"

# Set STIR config directory from CONDA_PREFIX (same as activation script does)
if 'CONDA_PREFIX' in os.environ:
    os.environ['STIR_CONFIG_DIR'] = os.path.join(
        os.environ['CONDA_PREFIX'], 'share', f'STIR-{STIR_VERSION}', 'config'
    )
    os.environ['STIR_DOC_DIR'] = os.path.join(
        os.environ['CONDA_PREFIX'], 'share', 'doc', f'STIR-{STIR_VERSION}'
    )

# Set SIMIND environment variables.
# Look for a `simind` directory next to the notebook's parent directory, then
# one level further up.
notebook_dir = Path.cwd()
simind_dir = notebook_dir.parent / 'simind'
if not simind_dir.exists():
    # Try one level up
    simind_dir = notebook_dir.parent.parent / 'simind'
if simind_dir.exists():
    # Keep the trailing '/' on SMC_DIR (presumably required by SIMIND — TODO confirm).
    os.environ['SMC_DIR'] = str(simind_dir / 'smc_dir') + '/'
    # Prepend simind to PATH unless it is already one of the entries.
    # Compare whole entries (a substring test gives false positives) and use
    # os.pathsep instead of a hard-coded ':' so this also works on Windows.
    current_path = os.environ.get('PATH', '')
    if str(simind_dir) not in current_path.split(os.pathsep):
        # Avoid leaving an empty trailing entry (== current directory) when
        # PATH was empty.
        os.environ['PATH'] = (
            f"{simind_dir}{os.pathsep}{current_path}" if current_path else str(simind_dir)
        )
    print(f"SIMIND configured: {simind_dir}")
else:
    print(f"Warning: SIMIND directory not found at {simind_dir}")
    print("Please ensure SIMIND is installed and accessible in your PATH")

# plotting settings
plt.ion() # interactive 'on' such that plots appear during loops

# Import packages
import stir
from sirf_simind_connection import SimindSimulator, SimulationConfig, configs, utils
from sirf_simind_connection.core.components import ScoringRoutine
from sirf_simind_connection.utils import get_array
from sirf_simind_connection.utils.stir_utils import create_stir_image
from sirf_simind_connection.backends import create_acquisition_data
import phantomgen as phantom
import numpy as np
import shutil
from stir_simind_utils import (
    DEW_scatter_correction, TEW_scatter_correction,
    reconstruct_with_osem, compare_reconstructions,
    add_poisson_noise
)

# Banner announcing which isotope configuration this run uses.
banner_rule = '=' * 60
print(f"\n{banner_rule}")
print(f"Configuration: {ISOTOPE.upper()}")
print(f"{banner_rule}\n")

In [None]:
# ============================================================================
# Phantom Generation
# ============================================================================

# Reconstruction/phantom grid per isotope: (matrix_size, voxel_size_mm).
# Both grids cover the same field of view (256 * 2.2077 == 128 * 4.4154 mm);
# Tc-99m is sampled twice as finely.
PHANTOM_GRIDS = {
    "tc99m": ((256, 256, 256), (2.2077, 2.2077, 2.2077)),
    "lu177": ((128, 128, 128), (4.4154, 4.4154, 4.4154)),
}
# Validate before touching the filesystem; an unknown isotope previously fell
# through the if/elif and failed later with a NameError on `matrix_size`.
if ISOTOPE not in PHANTOM_GRIDS:
    raise ValueError(f"Unknown isotope: {ISOTOPE}. Choose 'tc99m' or 'lu177'")

# Create output directory
output_dir = Path(f"output/{ISOTOPE}")
output_dir.mkdir(parents=True, exist_ok=True)

print("Creating phantom and attenuation map...")
matrix_size, voxel_size = PHANTOM_GRIDS[ISOTOPE]

# Define EARL NEMA phantom parameters for each isotope
earl_nema_dict_lu177 = {
    "mu_values": {
        "perspex_mu_value": 0.15,
        "fill_mu_value": 0.14,
        "lung_mu_value": 0.043
    },
    "activity_concentration_background": 0.0,
    "include_lung_insert": False,
    "sphere_dict": {
        "ring_R": 57,
        "ring_z": -37,
        "spheres": {
            "diametre_mm": [13, 17, 22, 28, 37, 60],
            "angle_loc": [270, 150, 30, 90, 330, 210],
            "act_conc_MBq_ml": [1.611, 1.611, 1.611, 1.611, 1.611, 1.611],
        }
    },
    "center_offset_mm": (37.0, 55.0, -4.0), # these have been tested
}

earl_nema_dict_tc99m = {
    "mu_values": {
        "perspex_mu_value": 0.175,
        "fill_mu_value": 0.154,
        "lung_mu_value": 0.046
    },
    "activity_concentration_background": 0.0,
    "include_lung_insert": False,
    "sphere_dict": {
        "ring_R": 57,
        "ring_z": -37,
        "spheres": {
            "diametre_mm": [13, 17, 22, 28, 37, 60],
            "angle_loc": [270, 150, 30, 90, 330, 210],
            "act_conc_MBq_ml": [1.013, 1.013, 1.013, 1.013, 1.013, 1.013],
        }
    },
    "center_offset_mm": (37.0, 50.0, -4.0), # TODO: these have not been tested
}

# Select phantom parameters based on isotope
phantom_dict = {"lu177": earl_nema_dict_lu177, "tc99m": earl_nema_dict_tc99m}[ISOTOPE]

# Generate phantom using phantomgen package
nema_act_arr, nema_ctac_arr = phantom.create_nema(
    matrix_size=matrix_size,
    voxel_size_mm=voxel_size,
    nema_dict=phantom_dict,
    supersample=4
)

# Convert to STIR image format
nema_act_image = create_stir_image(nema_act_arr.shape, voxel_size)
nema_ctac_image = create_stir_image(nema_ctac_arr.shape, voxel_size)

nema_act_image.fill(nema_act_arr)
nema_ctac_image.fill(nema_ctac_arr)

print(f"Phantom created for {ISOTOPE.upper()}")
print(f"  Matrix size: {matrix_size}")
print(f"  Voxel size: {voxel_size} mm")
# NOTE(review): the sphere parameters are concentrations (`act_conc_MBq_ml`);
# if `nema_act_arr` stores concentration per voxel, this sum is only MBq after
# scaling by the voxel volume in ml — TODO confirm with phantomgen.
print(f"  Total activity: {nema_act_arr.sum():.2f} MBq")

In [None]:
# Visualise the same central slice of the activity phantom and the mu-map.
# The index is derived from the grid: a fixed index such as 64 is only the
# centre of the 128^3 Lu-177 grid, not of the 256^3 Tc-99m grid.
# NOTE(review): assumes the first array axis is the slice axis — TODO confirm.
mid_slice = matrix_size[0] // 2

fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))
phantom_panels = [
    (nema_act_arr, 'hot', 'Activity'),
    (nema_ctac_arr, 'gray', 'Attenuation (mu-map)'),
]
for ax, (volume, cmap, title) in zip(axes, phantom_panels):
    image = ax.imshow(volume[mid_slice], cmap=cmap)
    fig.colorbar(image, ax=ax)
    ax.set_title(f'{title}, slice {mid_slice}')
    # remove ticks
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
# ============================================================================
# SIMIND Simulation Setup
# ============================================================================

print("Setting up SIMIND simulator...")

# Guard clause: only these two isotopes have a scanner config and window set.
if ISOTOPE not in ("lu177", "tc99m"):
    raise ValueError(f"Unknown isotope: {ISOTOPE}. Choose 'tc99m' or 'lu177'")

# Config file, output prefix and the SIMIND isotope switch all follow the key.
config_file = f"Discovery670_{ISOTOPE}.yaml"
output_prefix = f"nema_{ISOTOPE}"
source_type = ISOTOPE

# Energy windows as (lower, upper) keV pairs. The *_window_idx values are
# 1-based positions in this list.
if ISOTOPE == "lu177":
    # Lu-177: photopeak + 2 scatter windows for TEW
    energy_windows = [
        (187.2, 228.8),    # 1: photopeak
        (156.4, 183.6),    # 2: lower scatter
        (229.36, 258.64),  # 3: upper scatter
    ]
    photopeak_energy = 208  # keV
    time_per_proj = 43  # seconds
    photopeak_window_idx, lower_scatter_window_idx, upper_scatter_window_idx = 1, 2, 3
else:
    # Tc-99m: photopeak + 1 lower scatter window for DEW
    energy_windows = [
        (114.0, 126.0),    # 1: lower scatter
        (126.45, 154.55),  # 2: photopeak
    ]
    photopeak_energy = 140  # keV
    time_per_proj = 30  # seconds
    photopeak_window_idx, lower_scatter_window_idx = 2, 1

window_lower = [low for low, _ in energy_windows]
window_upper = [high for _, high in energy_windows]

# Load configuration and build the simulator (SIRF-SIMIND-Connection)
config = SimulationConfig(config_file)
simulator = SimindSimulator(
    config_source=config,
    output_dir=output_dir,
    output_prefix=output_prefix,
    photon_multiplier=1,
    scoring_routine=ScoringRoutine.SCATTWIN,
)

# Source, attenuation map, isotope and collimator
simulator.set_source(nema_act_image)
simulator.set_mu_map(nema_ctac_image)
simulator.add_runtime_switch("FI", source_type)
colls = "GI-MEGP"
simulator.add_runtime_switch("CC", colls)

# Energy windows, all with scatter order 0
scatter_orders = [0] * len(window_lower)
simulator.set_energy_windows(
    lower_bounds=window_lower,
    upper_bounds=window_upper,
    scatter_orders=scatter_orders,
)

# Index-based SIMIND settings: 1 = photon energy, 29 = number of projections,
# 25 = source activity (total activity x time per projection).
simulator.add_config_value(1, photopeak_energy)
num_projections = 120
simulator.add_config_value(29, num_projections)
total_activity = nema_act_arr.sum()
source_activity = total_activity * time_per_proj
simulator.add_config_value(25, source_activity)

print(f"Simulation configured for {ISOTOPE.upper()}:")
print(f"  Config file: {config_file}")
print(f"  Photopeak energy: {photopeak_energy} keV")
print(f"  Energy windows: {len(window_lower)}")
print(f"  Total activity: {total_activity:.2f} MBq")
print(f"  Source activity: {source_activity:.2f} MBq*s")
print(f"  Projections: {num_projections}")

print("\nRunning simulation (this may take a few minutes)...")
simulator.run_simulation()
print("Simulation complete!")

In [None]:
# ============================================================================
# Scatter Correction
# ============================================================================
# The same energy-window method is applied to measured and simulated data:
# TEW for Lu-177 (photopeak + lower + upper scatter windows) and DEW for
# Tc-99m (photopeak + lower scatter window).

if ISOTOPE not in ("lu177", "tc99m"):
    raise ValueError(f"Unknown isotope: {ISOTOPE}. Choose 'tc99m' or 'lu177'")

# Load measured data for comparison
if ISOTOPE == "lu177":
    measured_data = create_acquisition_data('measured_data/lu177/EARL_NEMA_128_EM_en_1_Lu177_EM.hdr')
    measured_lower_scatter = create_acquisition_data('measured_data/lu177/EARL_NEMA_128_SC1_en_1_Lu177_SC.hdr')
    measured_upper_scatter = create_acquisition_data('measured_data/lu177/EARL_NEMA_128_SC2_en_1_Lu177_SC.hdr')
    measured_corrected = TEW_scatter_correction(measured_data, measured_lower_scatter, measured_upper_scatter)
else:
    measured_data = create_acquisition_data('measured_data/tc99m/earl_tc99m_em_en_1_Tc99m_EM.hdr')
    measured_lower_scatter = create_acquisition_data('measured_data/tc99m/earl_tc99m_sc_en_1_Tc99m_SC.hdr')
    measured_corrected = DEW_scatter_correction(measured_data, measured_lower_scatter)

# Noise-free simulated photopeak (total) and its scatter component.
simind_peak = simulator.get_total_output(window=photopeak_window_idx)
simind_scatter = simulator.get_scatter_output(window=photopeak_window_idx)

# One noisy realisation of the photopeak, reused by the scatter correction
# below so the corrected data correspond to a single simulated acquisition.
# NOTE(review): add_poisson_noise is not seeded here, so results vary run to run.
simind_peak_noisy = add_poisson_noise(simind_peak)

# Ground-truth primary counts: subtract scatter from the NOISE-FREE photopeak
# and apply Poisson noise once. Subtracting two independently noised images
# gives variance peak + scatter (not primary) and can produce negative counts.
simind_unscattered = add_poisson_noise(simind_peak - simind_scatter)

# Apply appropriate scatter correction based on isotope
if ISOTOPE == "lu177":
    # TEW (Triple Energy Window) correction for Lu-177
    simind_lower_scatter = add_poisson_noise(simulator.get_total_output(window=lower_scatter_window_idx))
    simind_upper_scatter = add_poisson_noise(simulator.get_total_output(window=upper_scatter_window_idx))
    simind_corrected = TEW_scatter_correction(
        simind_peak_noisy,
        simind_lower_scatter, simind_upper_scatter
    )
    correction_method = "TEW"
else:
    # DEW (Dual Energy Window) correction for Tc-99m
    simind_lower_scatter = add_poisson_noise(simulator.get_total_output(window=lower_scatter_window_idx))
    simind_corrected = DEW_scatter_correction(
        simind_peak_noisy,
        simind_lower_scatter
    )
    correction_method = "DEW"

print(f"\n{correction_method} scatter correction applied successfully!")

In [None]:
# Visualize one projection of each dataset side by side.
# NOTE(review): index 60 is taken along the third array axis, presumably the
# view axis (half of the 120 simulated projections) — TODO confirm.
proj_view_idx = 60

projection_panels = [
    (measured_corrected, 'Measured Corrected'),
    (simind_corrected, f'SIMIND {correction_method} Corrected'),
    (simind_unscattered, 'SIMIND Unscattered (Ground Truth)'),
]

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
for panel_ax, (projdata, panel_title) in zip(ax, projection_panels):
    panel_ax.imshow(get_array(projdata)[0, :, proj_view_idx], cmap='gray')
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# OSEM Reconstructions
# ============================================================================

print("Running OSEM reconstructions...")

recon_dir = output_dir / "recon"
recon_dir.mkdir(parents=True, exist_ok=True)

par_template = Path("par_files/recon_OSEM.par")

# (display label, projection data, file prefix) for each reconstruction.
recon_jobs = [
    ("SIMIND Corrected", simind_corrected, "simind_corrected"),
    ("SIMIND Unscattered", simind_unscattered, "simind_unscattered"),
    ("Measured Corrected", measured_corrected, "measured_corrected"),
]

# label -> reconstructed image, filled in job order.
reconstructions = {}
for label, dataset, prefix in recon_jobs:
    print(f"  - {label}")
    # Each dataset is written to an Interfile header that the reconstruction reads back.
    projdata_path = recon_dir / f"{prefix}.hs"
    dataset.write(str(projdata_path))
    # 24 sub-iterations over 4 subsets, i.e. 6 full OSEM iterations.
    reconstructions[label] = reconstruct_with_osem(
        input_file=str(projdata_path),
        output_prefix=prefix,
        par_file_template=str(par_template),
        initial_image_template=nema_act_image.native_object,
        attenuation_image=nema_ctac_image.native_object,
        num_subsets=4,
        num_subiterations=24,
        temp_dir=str(recon_dir / f"{prefix}_temp"),
    )

print("All reconstructions complete!\n")


In [None]:
# Advanced comparison: line profiles through the spheres.
# `np` comes from the imports in the first cell; it is not re-imported here.

if len(reconstructions) > 0:
    print("Generating line profiles through spheres...")

    # Slice and row of the horizontal profile, hand-picked per isotope (they
    # are NOT the image centre). `slice_idx` is also used by the comparison
    # cell that follows.
    if ISOTOPE == "tc99m":
        slice_idx = 110
        row_idx = 140
    else:
        slice_idx = 65
        row_idx = 82

    # Convert each reconstruction once; all plots and statistics below reuse these.
    # NOTE(review): indexing assumes (slice, row, column) ordering — TODO confirm.
    recon_arrays = {name: img.as_array() for name, img in reconstructions.items()}
    n_recons = len(recon_arrays)

    # Top row: one panel per reconstruction with the profile line drawn on it.
    # Drawing every image into a single axes only shows an alpha blend in
    # which the individual reconstructions cannot be distinguished.
    # Bottom row: all profiles on one shared axis spanning the full width.
    fig = plt.figure(figsize=(15, 10))
    grid = fig.add_gridspec(2, n_recons)
    for col, (name, arr) in enumerate(recon_arrays.items()):
        ax = fig.add_subplot(grid[0, col])
        ax.imshow(arr[slice_idx, :, :], cmap='hot')
        ax.axhline(y=row_idx, color='cyan', linestyle='--', linewidth=2, label='profile line')
        ax.set_title(f'{name} - Slice {slice_idx}')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.legend(loc='upper right', fontsize=8)

    ax_profile = fig.add_subplot(grid[1, :])
    for name, arr in recon_arrays.items():
        ax_profile.plot(arr[slice_idx, row_idx, :], label=name, linewidth=2)

    ax_profile.set_xlabel('Pixel Index', fontsize=12)
    ax_profile.set_ylabel('Intensity', fontsize=12)
    ax_profile.set_title(f'Horizontal Profile at Row {row_idx}, Slice {slice_idx}', fontsize=14, fontweight='bold')
    ax_profile.legend(fontsize=10)
    ax_profile.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_dir / f"{ISOTOPE}_line_profiles.png")
    plt.show()

    # Pairwise voxel-wise differences between reconstructions.
    # NOTE(review): no intensity normalisation between measured and simulated
    # data is visible here, so these numbers mix scale and shape differences.
    if n_recons >= 2:
        print("\nPairwise Differences:")
        print("-" * 60)
        names = list(recon_arrays)
        for i, name_a in enumerate(names):
            for name_b in names[i + 1:]:
                diff = recon_arrays[name_a] - recon_arrays[name_b]
                rmse = np.sqrt(np.mean(diff ** 2))
                max_diff = np.abs(diff).max()
                print(f"{name_a} vs {name_b}:")
                print(f"  RMSE: {rmse:.4f}")
                print(f"  Max absolute difference: {max_diff:.4f}")
                print()
else:
    print("No reconstructions available for profile analysis!")

In [None]:
# Compare the reconstructed images
print("Comparing reconstructed images...")

if not reconstructions:
    print("No reconstructions available to compare!")
else:
    # Side-by-side view at `slice_idx`, the slice chosen in the line-profile
    # cell above (not necessarily the middle slice).
    fig, axes = compare_reconstructions(reconstructions, slice_idx=slice_idx, cmap='hot')
    plt.savefig(output_dir / f"{ISOTOPE}_reconstruction_comparison.png")
    plt.show()

    # Per-reconstruction summary statistics, one block per image.
    print("\nReconstruction Statistics:")
    print("-" * 60)
    for name, img in reconstructions.items():
        values = img.as_array()
        summary_lines = (
            f"{name}:",
            f"  Shape: {values.shape}",
            f"  Min: {values.min():.4f}, Max: {values.max():.4f}",
            f"  Mean: {values.mean():.4f}, Std: {values.std():.4f}",
            "",
        )
        print("\n".join(summary_lines))

In [None]:
# End of analysis: the comparison figures above are saved under `output_dir`.