# 3D Multimodal Registration Plan (CT ⟷ Ultrasound, focused on CFRP)

This document is a **step-by-step plan** to register 3D **X-ray Computed Tomography (CT)** volumes with **Ultrasound (US)** volumes using **Mutual Information (MI/NMI)** and complementary metrics. It’s written to be pasted as the first markdown cell of a Jupyter notebook and used as a practical checklist. The emphasis is **CFRP** parts (anisotropy, porosity, layer interfaces), but most steps generalize.

---

## 0) Quick Summary (what we’ll do)

1. **Unify geometry**: resample both volumes to a common isotropic spacing and common orientation.
2. **Denoise & normalize**: reduce speckle in US, window/clip CT; normalize intensities.
3. **Mask ROI**: focus the metric on material (exclude air/background and obvious artifacts).
4. **Initialize**: align centers/axes or use known calibration/fiducials.
5. **Multi-resolution search**: coarse → fine pyramid; sample a small % of voxels per level.
6. **Optimize**: rigid first (6-DOF), then (only if needed) affine or light B-spline.
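7. **Metric**: use **Mattes MI** (mutual information estimated with Parzen-windowed histograms) as the workhorse, or **NMI** where the toolkit provides it; optionally blend with gradient-based terms or **MIND** descriptors.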
8. **Validate**: visual overlays + quantitative metrics (TRE, Dice/IoU of ROIs, defect overlap).
9. **Refine & save**: if needed, small local deformable step; save transforms & resampled outputs.

---

## 1) Inputs & Assumptions

- **Inputs**:  
  - `CT`: 3D volume (e.g., NIfTI, MHD, DICOM converted), floating or integer.  
  - `US`: 3D volume (e.g., from a tracked probe or 3D sweep), floating preferred.  
- **Assumptions**:  
  - Both datasets image the **same part** with **overlap**.  
  - You know approximate voxel spacing (mm) for both.  
  - US may have **speckle**, shadows, gain variations, angle-dependent artifacts.  
  - CT has **porosity/voids**, layer interfaces, and high SNR.

---

## 2) Coordinate System & Geometry

**Goal**: Bring both into a common space to make the metric meaningful.

- **Resampling to isotropic spacing**:
  - Pick a **target spacing** that balances detail and runtime (e.g., **0.4–0.8 mm** for coarse levels; optionally refine to **0.2–0.4 mm** at the finest level).
  - Prefer **linear** interpolation for intensity images during registration; **nearest** for masks/labels.
- **Orientation**:
  - Reorient to a common anatomical/part-centric frame (e.g., LPS/RAS or principal axes of the part).
- **Cropping**:
  - Crop to a **tight bounding box** around the part to reduce computation.

**Options**:
- Use ITK/SimpleITK/ANTs/Elastix for resampling with explicit spacing/origin/direction handling.
- If DICOM, convert to a single 3D volume first (pay attention to slice spacing vs. thickness).
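A minimal sketch of the extent-preserving grid computation that isotropic resampling relies on (NumPy only). The function name `isotropic_grid` and the example volume sizes are hypothetical; sizes and spacings are in (x, y, z) order, as SimpleITK reports them:

```python
import numpy as np

def isotropic_grid(size_xyz, spacing_mm, target_mm):
    """Voxel count per axis that keeps the physical extent (size * spacing)
    when resampling to `target_mm` (scalar or per-axis), all in (x, y, z) order."""
    size = np.asarray(size_xyz, dtype=float)
    spacing = np.asarray(spacing_mm, dtype=float)
    target = np.broadcast_to(np.asarray(target_mm, dtype=float), size.shape)
    extent_mm = size * spacing                      # physical field of view per axis
    new_size = np.maximum(np.round(extent_mm / target), 1).astype(int)
    return tuple(int(n) for n in new_size), tuple(float(e) for e in extent_mm)

# Hypothetical volumes: a 0.025 mm CT stack and an anisotropic 1.0 x 1.0 x 0.016 mm US volume
ct_size, ct_extent = isotropic_grid((800, 800, 480), (0.025, 0.025, 0.025), 0.8)
us_size, us_extent = isotropic_grid((40, 60, 1500), (1.0, 1.0, 0.016), 0.8)
print(ct_size, us_size)   # (25, 25, 15) (50, 75, 30)
```

This is the same arithmetic the notebook's `resample_to_spacing` performs with `np.round(original_size * (original_spacing / target_spacing))`.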

---

## 3) Preprocessing (per modality)

**CT**:
- **Window/clip** intensities to reduce outliers (e.g., **p2–p98 percentiles**), then scale to **[0,1]** or **[0,255]**.
- Optional **mild smoothing** (Gaussian σ ~ 0.5–1.0 vox) if very noisy.

**US**:
- **Speckle reduction**: median filter (3×3×3 to 5×5×5 voxel kernel), anisotropic diffusion (a few iterations), or a light Gaussian.
- **Log compression** (if raw amplitude): `I_log = log(1 + α * I)` → improves dynamic range.
- Normalize to **[0,1]** after filtering/compression.

**Why**: Stabilizes the histogram for MI and reduces local extrema caused by noise/speckle.
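The two normalizations above can be sketched in NumPy as follows. This is an illustrative version that assumes the US input is an amplitude-like array (e.g. |RF| or an envelope); the function names are hypothetical and differ slightly from the notebook's SimpleITK helpers (here the amplitude is scaled to [0, 1] before `log(1 + α·I)`, so dividing by `log(1 + α)` keeps the output in [0, 1]).

```python
import numpy as np

def normalize_ct_np(arr, p_low=2.0, p_high=98.0):
    """Clip CT to the [p_low, p_high] percentile window and rescale to [0, 1]."""
    arr = np.asarray(arr, dtype=np.float64)
    lo, hi = np.percentile(arr, [p_low, p_high])
    if hi <= lo:                                   # flat image: nothing to normalize
        return np.zeros(arr.shape, dtype=np.float32)
    return ((np.clip(arr, lo, hi) - lo) / (hi - lo)).astype(np.float32)

def log_compress_us_np(arr, alpha=2.0):
    """Scale US amplitude to [0, 1], then apply log(1 + alpha * I) / log(1 + alpha)."""
    a = np.abs(np.asarray(arr, dtype=np.float64))
    peak = a.max()
    if peak > 0:
        a = a / peak
    if alpha > 0:
        a = np.log1p(alpha * a) / np.log1p(alpha)  # lifts weak echoes, keeps 1 -> 1
    return a.astype(np.float32)

rng = np.random.default_rng(0)
ct_n = normalize_ct_np(rng.normal(1000.0, 200.0, size=(16, 16, 16)))
us_n = log_compress_us_np(np.array([0.0, 0.1, 1.0]))
```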

---

## 4) ROI Masking (focus the metric)

**Idea**: MI can be dominated by background (air). Restrict to **material ROI**.

- **CT-derived mask** (preferred):
  - Threshold + morphological cleanup to segment the laminate/part.
- **US mask**:
  - Remove obvious **shadow cones** or saturated regions if known.
- **Final mask for metric**:
  - Use **intersection** of CT-mask (warped to US space during iterations) and a **confidence mask** in US (optional).
  - Alternatively, use **gradient-magnitude threshold** to emphasize interfaces.

**Tips**:
- If you can’t maintain a dynamically warped CT-mask, start with a **generous mask** that still excludes clear background.
- Avoid masking so aggressively that you drop all informative structures.
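A NumPy sketch of the CT-derived mask idea: an Otsu threshold (the notebook itself uses `sitk.OtsuThreshold` plus closing and largest-component selection) intersected with an optional US confidence mask. The names `otsu_threshold` and `roi_mask` are illustrative assumptions, and bright voxels are assumed to be material.

```python
import numpy as np

def otsu_threshold(arr, bins=256):
    """Otsu threshold: choose the cut that maximizes between-class variance."""
    hist, edges = np.histogram(np.ravel(arr), bins=bins)
    centers = 0.5 * (edges[:-1] + edges[1:])
    w0 = np.cumsum(hist).astype(float)            # voxel count at or below each bin
    w1 = w0[-1] - w0                              # voxel count above
    m0 = np.cumsum(hist * centers)                # cumulative intensity sum
    mu0 = m0 / np.maximum(w0, 1)
    mu1 = (m0[-1] - m0) / np.maximum(w1, 1)
    between = w0 * w1 * (mu0 - mu1) ** 2          # unnormalized between-class variance
    return float(edges[int(np.argmax(between)) + 1])

def roi_mask(ct, us_confidence=None):
    """Material mask from CT (bright = material), optionally ANDed with a US confidence mask."""
    mask = np.asarray(ct) > otsu_threshold(ct)
    if us_confidence is not None:
        mask &= np.asarray(us_confidence, dtype=bool)
    return mask
```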

---

## 5) Initialization (coarse alignment)

**Goal**: Start close enough so the optimizer can converge.

- **Centering & PCA**:
  - Align **centers of mass** of masks and **principal axes** (PCA) to get rough rotation/translation.
- **Fiducials / Probe calibration** (if available):
  - Use tracker-to-image calibration or landmark pairs for an initial rigid transform.
- **Manual rough guess**:
  - Use a quick interactive tool to place US relative to CT.

**Rule of thumb**: If initial misalignment > ~10–15° or > 10–20 mm, add a **coarser pyramid level** and/or do a **limited grid search** around the guess (small ± ranges) **before** continuous optimization.
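A sketch of the centers-plus-PCA initialization on binary masks, assuming the masks are NumPy arrays indexed (z, y, x) (as returned by `sitk.GetArrayFromImage`), spacing given in (x, y, z) mm, and identity direction matrices. The returned `R, t` map fixed points to moving points (x_m = R x_f + t), the same direction ITK transforms use. Principal axes are only defined up to sign, so a PCA start can be off by 180° about an axis; trying the sign-flipped candidates (or the limited grid search above) is a cheap safeguard. The function names are hypothetical.

```python
import numpy as np

def mask_centroid_and_axes(mask, spacing_xyz=(1.0, 1.0, 1.0)):
    """Physical centroid (mm, x/y/z) and principal axes (columns, major first) of a mask."""
    zyx = np.argwhere(np.asarray(mask) > 0).astype(float)
    pts = zyx[:, ::-1] * np.asarray(spacing_xyz, dtype=float)  # (z, y, x) index -> (x, y, z) mm
    c = pts.mean(axis=0)
    evals, evecs = np.linalg.eigh(np.cov((pts - c).T))         # eigenvalues in ascending order
    axes = evecs[:, ::-1].copy()                               # major, middle, minor
    if np.linalg.det(axes) < 0:                                # enforce a right-handed frame
        axes[:, 2] *= -1
    return c, axes

def initial_rigid_from_masks(fixed_mask, moving_mask, spacing_fixed, spacing_moving):
    """R, t such that R @ c_fixed + t == c_moving and R maps fixed axes onto moving axes."""
    cf, af = mask_centroid_and_axes(fixed_mask, spacing_fixed)
    cm, am = mask_centroid_and_axes(moving_mask, spacing_moving)
    R = am @ af.T
    t = cm - R @ cf
    return R, t
```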

---

## 6) Multi-Resolution Pyramid (coarse→fine)

**Why**: Makes MI cheaper to evaluate and smoother (fewer local optima) at low resolution; details are refined at the finer levels.

- **Shrink factors**: e.g., **[8, 4, 2, 1]** (or [4,2,1] for smaller volumes).
- **Smoothing per level**: Gaussian sigmas like **[3, 2, 1, 0]** (in mm if the toolkit supports physical units, otherwise in voxels).
- **Voxel sampling**:
  - Random subset of voxels for the metric (**2–5%** at coarse, increase to **5–10%** fine).
  - Ensure sampling is **inside the ROI mask**.

**Tip**: Keep a **fixed random seed** for reproducibility, or use stochastic resampling per iteration for robustness.
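In the notebook, sampling is handled by SimpleITK (`SetMetricSamplingStrategy(reg.RANDOM)`, `SetMetricSamplingPercentage(sampling, seed)` and a fixed mask). The NumPy sketch below only illustrates the mechanics, reproducible sampling restricted to the ROI; `sample_in_mask` is a hypothetical helper:

```python
import numpy as np

def sample_in_mask(mask, fraction, seed=42):
    """Reproducible random subset (without replacement) of voxel indices inside `mask`."""
    idx = np.flatnonzero(np.asarray(mask) > 0)          # flat indices of ROI voxels
    n = min(idx.size, max(1, int(round(fraction * idx.size))))
    rng = np.random.default_rng(seed)                   # fixed seed -> same samples every run
    chosen = rng.choice(idx, size=n, replace=False)
    return np.unravel_index(np.sort(chosen), np.shape(mask))

mask = np.zeros((32, 32, 32), dtype=bool)
mask[8:24, 8:24, 8:24] = True                           # 4096 ROI voxels
z, y, x = sample_in_mask(mask, 0.05, seed=42)
print(len(z), mask[z, y, x].all())                      # 205 True
```

For a per-level schedule such as 3% → 5% → 8%, call it once per pyramid level on that level's (shrunk) mask.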

---

## 7) Transformation Models (in order)

1. **Rigid (6-DOF)**: translations + rotations.  
   - Almost always **first**. Suitable if US is not significantly deformed.
2. **Similarity** (optional): rigid + uniform scale.  
   - If there’s mild scale mismatch.
3. **Affine** (12-DOF): adds anisotropic scaling and shear; use with caution (it can absorb intensity-to-geometry bias).
4. **Deformable (lightweight)**: **B-spline** with coarse control point spacing (e.g., **30–60 mm**).  
   - Only after a solid rigid fit; use **regularization** to prevent overfitting to speckle.

**CFRP advice**: Many setups are close to rigid. Try to **avoid** deformable unless probe pressure clearly distorts US.
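To make the degrees of freedom concrete, here is a NumPy sketch that builds a 4×4 rigid transform rotating about a center (as ITK's centered transforms do) and counts parameters per model. The Euler convention used here (R = Rz·Ry·Rx) is one common choice and an assumption; ITK's `Euler3DTransform` has its own angle order, so do not mix parameters between the two. The B-spline count assumes cubic splines with `mesh + 3` control points per axis and 3 displacement components each.

```python
import numpy as np

def euler_matrix(rx, ry, rz):
    """Rotation matrix R = Rz @ Ry @ Rx (angles in radians)."""
    cx, sx, cy, sy = np.cos(rx), np.sin(rx), np.cos(ry), np.sin(ry)
    cz, sz = np.cos(rz), np.sin(rz)
    Rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    Ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    Rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    return Rz @ Ry @ Rx

def rigid_about_center(angles, translation, center=(0.0, 0.0, 0.0)):
    """4x4 homogeneous rigid transform: x' = R (x - c) + c + t."""
    R = euler_matrix(*angles)
    c = np.asarray(center, dtype=float)
    T = np.eye(4)
    T[:3, :3] = R
    T[:3, 3] = c + np.asarray(translation, dtype=float) - R @ c
    return T

def bspline_param_count(extent_mm, grid_mm, order=3):
    """Cubic B-spline parameters: (mesh + order) control points per axis, 3 components each."""
    mesh = np.maximum(np.round(np.asarray(extent_mm, dtype=float) / grid_mm), 1).astype(int)
    return int(np.prod(mesh + order) * 3)

DOF = {"rigid": 6, "similarity": 7, "affine": 12}
```

With a 50 mm grid over a 100 mm part, the B-spline already has 375 parameters versus 6 for rigid, which is why regularization and a coarse grid matter.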

---

## 8) Similarity Metrics

**Primary**: **Mattes Mutual Information** (MI estimated with Parzen-windowed histograms), or its normalized variant **NMI** where the toolkit offers it:
- **Histogram bins**: **32–64** typically best; too many bins can overfit noise.
- **Parzen smoothing**: stabilizes derivatives; standard in Mattes MI.

**Secondary / hybrid options**:
- **Gradient-based** (e.g., normalized cross-correlation of gradient magnitude/orientation): adds edge agreement.
- **MIND** (Modality-Independent Neighborhood Descriptor): robust structural descriptor for multi-modal.
- **LC²** or local cross-correlation variants: good when local contrast patterns align.

**Blending**:
- Combine **NMI + λ·GradNCC** or **NMI + λ·MIND-NCC**. Start with small λ (e.g., 0.1–0.3) and tune.
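A self-contained NumPy sketch of Studholme-style NMI, NMI = (H(A) + H(B)) / H(A, B), computed from a plain joint histogram (no Parzen windowing, so it is not identical to Mattes MI), plus a gradient-magnitude NCC term and the blended score. Higher is better for both terms; negate the score if your optimizer minimizes. The function names and the λ default are illustrative.

```python
import numpy as np

def _entropy(p):
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))

def nmi(a, b, bins=48, mask=None):
    """Normalized MI in [1, 2]: 1 for independent images, 2 for a one-to-one intensity mapping."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    if mask is not None:
        a, b = a[mask], b[mask]
    joint, _, _ = np.histogram2d(a.ravel(), b.ravel(), bins=bins)
    pab = joint / joint.sum()
    h_ab = _entropy(pab)
    if h_ab == 0:
        raise ValueError("NMI is undefined for constant images")
    return (_entropy(pab.sum(axis=1)) + _entropy(pab.sum(axis=0))) / h_ab

def grad_mag(img):
    """Gradient magnitude from central differences (np.gradient)."""
    return np.sqrt(sum(g ** 2 for g in np.gradient(np.asarray(img, dtype=float))))

def grad_ncc(a, b, mask=None):
    """Normalized cross-correlation of gradient magnitudes (edge agreement)."""
    ga, gb = grad_mag(a), grad_mag(b)
    if mask is not None:
        ga, gb = ga[mask], gb[mask]
    ga = ga.ravel() - ga.mean()
    gb = gb.ravel() - gb.mean()
    denom = np.sqrt(np.sum(ga ** 2) * np.sum(gb ** 2))
    return float(np.sum(ga * gb) / denom) if denom > 0 else 0.0

def hybrid_score(a, b, lam=0.2, bins=48, mask=None):
    """NMI + lambda * GradNCC (higher is better)."""
    return nmi(a, b, bins, mask) + lam * grad_ncc(a, b, mask)
```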

---

## 9) Interpolation & Extrapolation

- **Interpolator**: **Linear** for intensity images during optimization; **BSpline** (order 3) is fine when resampling final outputs.
- **Extrapolation outside bounds**: clamp or use a defined background value (0). Avoid introducing artificial edges at borders.
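A small NumPy sketch of what "linear interpolation with a defined background value" means in practice: trilinear sampling at continuous voxel coordinates, with points outside the volume returning `default` instead of being extrapolated. This mirrors the role of the default pixel value in `sitk.Resample`, but the function itself is a hypothetical illustration (each volume axis must have at least 2 voxels).

```python
import numpy as np

def trilinear(vol, z, y, x, default=0.0):
    """Sample vol (indexed z, y, x) at continuous voxel coordinates; outside -> default."""
    vol = np.asarray(vol, dtype=float)
    z, y, x = (np.atleast_1d(np.asarray(v, dtype=float)) for v in (z, y, x))
    nz, ny, nx = vol.shape
    out = np.full(z.shape, default, dtype=float)
    inside = ((z >= 0) & (z <= nz - 1) & (y >= 0) & (y <= ny - 1)
              & (x >= 0) & (x <= nx - 1))
    zi, yi, xi = z[inside], y[inside], x[inside]
    # Lower corner of the enclosing cell (clipped so the upper border stays valid)
    z0 = np.clip(np.floor(zi).astype(int), 0, nz - 2)
    y0 = np.clip(np.floor(yi).astype(int), 0, ny - 2)
    x0 = np.clip(np.floor(xi).astype(int), 0, nx - 2)
    dz, dy, dx = zi - z0, yi - y0, xi - x0
    acc = np.zeros(zi.shape)
    for oz, wz in ((0, 1 - dz), (1, dz)):              # weighted sum over the 8 cell corners
        for oy, wy in ((0, 1 - dy), (1, dy)):
            for ox, wx in ((0, 1 - dx), (1, dx)):
                acc += wz * wy * wx * vol[z0 + oz, y0 + oy, x0 + ox]
    out[inside] = acc
    return out
```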

---

## 10) Optimization Strategy

**Coarse levels (robust search)**:
- **Powell / Nelder-Mead / CMA-ES** (derivative-free) with bounded parameter ranges **or**
- **RegularStepGradientDescent** with cautious learning rate and sampling.

**Fine levels (fast convergence)**:
- **LBFGS / Conjugate Gradient / RegularStepGradientDescent**.

**General settings**:
- **Max iterations per level**: 150–300 (stop earlier if metric change < 1e-6 over 10 iters).
- **Learning rate**: start around **1–4** (physical shift scaling on), **min step** ~ **1e-3**.
- **Multi-start**: optionally run **3–5** starts with small random jitters if capture range is tight.

**Caching**:
- If toolkit supports it, cache pyramid images and masks to avoid recomputation.
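To make the learning-rate and min-step settings less abstract, here is a simplified NumPy sketch of the regular-step idea: move a fixed step length along the normalized gradient, shrink the step by a relaxation factor whenever the gradient reverses direction, and stop once the step falls below `min_step`. It leaves out parameter scales and other details of ITK's `RegularStepGradientDescent`, and is demonstrated on a toy quadratic rather than a registration metric.

```python
import numpy as np

def regular_step_descent(grad_fn, x0, learning_rate=2.0, min_step=1e-3,
                         max_iters=250, relaxation=0.5, grad_tol=1e-4):
    """Minimize with fixed-length steps along -grad; relax the step on gradient reversal."""
    x = np.asarray(x0, dtype=float).copy()
    step, prev_g, iters = float(learning_rate), None, 0
    while iters < max_iters:
        g = np.asarray(grad_fn(x), dtype=float)
        g_norm = np.linalg.norm(g)
        if g_norm < grad_tol:
            break                                  # flat enough: converged
        if prev_g is not None and np.dot(g, prev_g) < 0:
            step *= relaxation                     # overshot the minimum: shorten the step
        if step < min_step:
            break                                  # step collapsed: converged
        x -= step * g / g_norm
        prev_g = g
        iters += 1
    return x, iters, step

center = np.array([3.0, -1.5, 0.5])
x_opt, n_iter, final_step = regular_step_descent(lambda x: 2.0 * (x - center), np.zeros(3))
```

The relaxation on gradient reversal is what makes an early step collapse at a coarse level a useful symptom: it means the optimizer keeps overshooting, which is why the safeguards below suggest relaxing the learning rate or broadening sampling.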

---

## 11) Stopping Criteria & Safeguards

- Stop when **Δ(metric) < ε** (e.g., 1e-6) for **N** consecutive iterations (e.g., 10–20) or max iters reached.
- Monitor **step size**; if it collapses too early at coarse level, relax learning rate or broaden sampling.
- If metric **oscillates**: increase smoothing, reduce bins, or enlarge sampling %.
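The stopping rule and the oscillation check can be tracked with a few lines of Python, for example from a `sitk.sitkIterationEvent` callback that reads `reg.GetMetricValue()` (as the notebook's verbose mode does). `PlateauStop` and `oscillation_ratio` are hypothetical helpers; SimpleITK's optimizers apply their own convergence tests, so this is a monitoring aid rather than a replacement.

```python
import numpy as np

class PlateauStop:
    """Signal convergence when |metric change| < eps for `patience` consecutive iterations."""

    def __init__(self, eps=1e-6, patience=10):
        self.eps, self.patience = eps, patience
        self.prev, self.count = None, 0

    def update(self, value):
        if self.prev is not None and abs(value - self.prev) < self.eps:
            self.count += 1
        else:
            self.count = 0
        self.prev = value
        return self.count >= self.patience

def oscillation_ratio(values):
    """Fraction of consecutive metric changes that flip sign (close to 1 means oscillation)."""
    d = np.diff(np.asarray(values, dtype=float))
    d = d[d != 0]
    if d.size < 2:
        return 0.0
    return float(np.mean(np.sign(d[1:]) != np.sign(d[:-1])))
```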

---

## 12) Validation & Quality Checks

**Qualitative**:
- Orthogonal slice overlays (CT base + US overlay).
- Checkerboard and edge overlay views.
- 3D rendering with iso-surfaces (e.g., CT part boundary vs. US echo edges).

**Quantitative**:
- **TRE** (Target Registration Error): if you have corresponding landmarks (porosity centroids, fiducials).
- **Dice / IoU**: between masks/ROIs (e.g., laminate region).
- **Defect overlap**: CT pores vs. US echo clusters; measure distance/overlap.

**Accept/Reject criteria**:
- Define thresholds for TRE, Dice, or defect alignment distances tailored to your part tolerances.
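A NumPy sketch of the quantitative checks: per-landmark TRE for a 4×4 fixed→moving matrix, Dice and IoU for binary masks, and a simple accept/reject rule. The function names and the threshold arguments are illustrative; the actual tolerances must come from your part specification.

```python
import numpy as np

def tre_mm(fixed_pts, moving_pts, T):
    """Per-landmark TRE (mm): || T(p_fixed) - p_moving || with T a 4x4 fixed->moving matrix."""
    fp = np.asarray(fixed_pts, dtype=float)
    mp = np.asarray(moving_pts, dtype=float)
    mapped = (np.c_[fp, np.ones(len(fp))] @ np.asarray(T, dtype=float).T)[:, :3]
    return np.linalg.norm(mapped - mp, axis=1)

def dice_iou(a, b):
    """Dice and IoU of two binary masks (both 1.0 if both masks are empty)."""
    a, b = np.asarray(a, dtype=bool), np.asarray(b, dtype=bool)
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    total = a.sum() + b.sum()
    dice = 2.0 * inter / total if total else 1.0
    iou = inter / union if union else 1.0
    return float(dice), float(iou)

def accept(tre, dice, max_mean_tre_mm, min_dice):
    """Accept/reject against part-specific tolerances (thresholds are yours to choose)."""
    return bool(np.mean(tre) <= max_mean_tre_mm and dice >= min_dice)
```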

---

## 13) CFRP-Specific Tips

- **Use a CT-derived mask** to exclude air and fixtures; MI is much more stable then.
- **Edge emphasis helps**: compute gradient magnitude (Sobel/Scharr) and either:
  - Build a **composite image** (e.g., concatenate intensity + gradient as channels in a custom metric), or
  - Weight the sampling by gradient magnitude (higher chance to pick informative voxels).
- **Anisotropic artifacts** in US (fiber direction):
  - Try mild **anisotropic smoothing** aligned with the laminate plane.
  - Avoid over-weighting regions prone to **shadowing** or **saturation**.
- **Probe pressure**:
  - If present, consider a **small B-spline** refinement (large grid spacing, strong regularization).
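A NumPy sketch of gradient-weighted sampling: voxels are drawn with probability proportional to gradient magnitude plus a small floor, so interfaces are favored without ignoring flat regions entirely. It uses central differences (`np.gradient`) rather than Sobel/Scharr, and `gradient_weighted_sample` and its `floor` parameter are illustrative assumptions. The notebook's `reg.RANDOM` sampling does not weight by gradient, so using these samples inside a registration would need a custom metric loop.

```python
import numpy as np

def gradient_weighted_sample(img, n, mask=None, floor=0.05, seed=42):
    """Draw n distinct voxels with probability ~ |grad I| + floor * max|grad I| (optionally masked)."""
    g = np.sqrt(sum(d ** 2 for d in np.gradient(np.asarray(img, dtype=float))))
    w = g + floor * (g.max() if g.max() > 0 else 1.0)   # floor keeps some weight on flat regions
    if mask is not None:
        w = w * (np.asarray(mask) > 0)
    p = (w / w.sum()).ravel()
    rng = np.random.default_rng(seed)
    flat = rng.choice(p.size, size=n, replace=False, p=p)
    return np.unravel_index(flat, np.shape(img))

step = np.zeros((16, 16, 16))
step[:, :, 8:] = 1.0                                    # synthetic interface at x = 8
z, y, x = gradient_weighted_sample(step, 200)
edge_fraction = np.isin(x, (7, 8)).mean()               # uniform sampling would give 2/16
```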

---

## 14) Performance Considerations

- **Downsample aggressively** at coarse levels; it dominates runtime.
- **Random voxel sampling** (2–5%) drastically reduces metric cost while keeping signal.
- **Parallelization**: use multi-threaded ITK/ANTs/Elastix builds. GPU frameworks exist (e.g., PyTorch-based) but require more engineering.
- **I/O**: keep volumes in memory between levels; avoid repeated disk reads.
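One cheap way to downsample aggressively while limiting aliasing is block averaging by integer factors, e.g. a factor of 8 for the CT, or (4, 1, 1) in array (z, y, x) order for the fine US axis, assuming that axis is z as in `US_NATIVE_SPACING_MM`. The NumPy sketch below is illustrative (the name is hypothetical); trailing voxels that do not fill a whole block are cropped. A factor of 8 per axis cuts the voxel count by 512×.

```python
import numpy as np

def block_mean_downsample(vol, factors):
    """Average non-overlapping blocks of size `factors` (scalar or per axis, array order)."""
    vol = np.asarray(vol, dtype=float)
    f = np.broadcast_to(np.asarray(factors, dtype=int), (vol.ndim,))
    crop = tuple(slice(0, (s // k) * k) for s, k in zip(vol.shape, f))
    v = vol[crop]
    split_shape = [n for s, k in zip(v.shape, f) for n in (s // k, k)]
    return v.reshape(split_shape).mean(axis=tuple(range(1, 2 * vol.ndim, 2)))

vol = np.arange(4 * 6 * 8, dtype=float).reshape(4, 6, 8)
small = block_mean_downsample(vol, 2)
print(small.shape, small[0, 0, 0])                      # (2, 3, 4) 28.5
```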

---

## 15) Practical Parameter Starters

- **Pyramid**: shrink `[8,4,2,1]`, smoothing sigmas `[3,2,1,0]` (in **mm** if possible).
- **Sampling**: 3% (coarse) → 5% (mid) → 8% (fine) **inside ROI**.
- **NMI bins**: 48 (try 32–64).
- **Rigid first**; only then consider **affine** or **B-spline** (grid 40–60 mm).
- **Stopping**: ΔNMI < 1e-6 for 10 iters or max 250 iters/level.
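The starters can be kept as one dictionary at the top of the notebook and sanity-checked before a run. The checks below (equal list lengths, coarse-to-fine ordering, final level at full resolution, sampling fractions in (0, 1], bins in the suggested 32 to 64 range) are a hypothetical helper that encodes this plan's rules of thumb, not a toolkit requirement. Mapping the three suggested sampling percentages onto four levels is also an assumption.

```python
STARTER_PARAMS = {
    "shrink": [8, 4, 2, 1],
    "sigmas_mm": [3, 2, 1, 0],
    "sampling": [0.03, 0.05, 0.05, 0.08],   # 3% coarse -> 5% mid -> 8% fine
    "bins": 48,
    "max_iters": 250,
    "delta_metric_eps": 1e-6,
    "patience": 10,
    "bspline_grid_mm": 50.0,
}

def check_params(p):
    """Return a list of human-readable problems (empty list = looks sane)."""
    problems = []
    n = len(p["shrink"])
    if not (len(p["sigmas_mm"]) == n == len(p["sampling"])):
        problems.append("shrink, sigmas_mm and sampling need one entry per level")
    if list(p["shrink"]) != sorted(p["shrink"], reverse=True):
        problems.append("shrink factors should go coarse -> fine (non-increasing)")
    if p["shrink"][-1] != 1:
        problems.append("last level should run at full resolution (shrink 1)")
    if not all(0 < s <= 1 for s in p["sampling"]):
        problems.append("sampling fractions must lie in (0, 1]")
    if not 32 <= p["bins"] <= 64:
        problems.append("bins outside the suggested 32-64 range")
    return problems
```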

---

## 16) Output Artifacts to Save

- Final **transform(s)**:
  - Rigid (and affine/B-spline if used), in a standard format (e.g., ITK transform file).
- **Resampled US→CT** (or CT→US) for visualization and downstream analysis.
- **Logs**:
  - Per-level metric values, iteration counts, parameter steps.
- **QC images**:
  - Before/after overlays, checkerboards, edge overlays.
- **Masks** used during optimization for reproducibility.
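The transform itself is best saved in ITK's own format (`sitk.WriteTransform(transform, path)`), but a small JSON sidecar with parameters, per-level logs and a 4×4 matrix is handy for quick inspection and reproducibility. The sketch below uses only the standard library and NumPy; the file name, keys, and the dummy values in the demo are assumptions. Parameters must be JSON-serializable (convert `Path` objects with `str()`).

```python
import json
import tempfile
import time
from pathlib import Path

import numpy as np

def save_run_log(out_dir, params, level_logs, matrix_4x4, name="run_log.json"):
    """Write parameters, per-level logs and a fixed->moving 4x4 matrix to JSON; return the path."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    record = {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "params": params,
        "levels": level_logs,
        "matrix_fixed_to_moving": np.asarray(matrix_4x4, dtype=float).tolist(),
    }
    path = out_dir / name
    path.write_text(json.dumps(record, indent=2))
    return path

# Demo with dummy values in a temporary folder
with tempfile.TemporaryDirectory() as tmp:
    p = save_run_log(tmp, {"bins": 48, "shrink": [8, 4, 2, 1]},
                     [{"level": 0, "iterations": 0, "final_metric": None}], np.eye(4))
    loaded = json.loads(p.read_text())
```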

---

## 17) Troubleshooting Guide

- **Metric peaks but misaligned visually**:
  - Reduce bins; add smoothing; ensure ROI mask is sane; try hybrid metric with gradients.
- **Optimizer diverges or stalls**:
  - Lower learning rate; increase smoothing; broaden pyramid; improve initialization.
- **Overfitting (deformable too “wavy”)**:
  - Increase regularization; coarsen B-spline grid; revert to affine/rigid.
- **Speckle dominates**:
  - Stronger US denoise; use gradient-based terms; increase sampling percentage but keep ROI-focused.
- **Background leakage**:
  - Tighten masks; ensure background values are consistent and excluded.

---

## 18) Minimal “Recipe” You’ll Implement (high-level)

1. **Load** CT & US → **resample** to isotropic spacing & common orientation.  
2. **Preprocess**: CT clip+normalize; US denoise+log+normalize.  
3. **Build ROI masks** (CT-based; optional US confidence).  
4. **Init**: centers + PCA (or fiducials/calibration).  
5. **Register** with a **multi-resolution pyramid** using **Mattes MI**:  
   - Rigid at all levels → (optional) Affine → (optional) B-spline small.  
   - Random voxel sampling within ROI; linear interpolation.  
6. **Validate** visually + quantitatively (TRE, Dice, defect overlap).  
7. **Refine** if needed; **save** transforms, resampled volumes, and QC artifacts.

---

### Notes for the Notebook
- Keep **parameters at the top** (spacing, bins, sampling %) so you can rerun quickly.
- Implement **utility functions** for: resample, normalize, mask create, visualize overlays, compute metrics, and run the multi-level loop.
- For **reproducibility**, fix random seeds and export a **run log** with all chosen parameters.

---

In [1]:
# Environment & dependency check
import importlib, sys, platform

MIN_PY = (3, 10)
if sys.version_info < MIN_PY:
    raise RuntimeError(
        f"Python >= {MIN_PY[0]}.{MIN_PY[1]} is required (found {sys.version.split()[0]}). "
        "Upgrade Python or change type hints (Image|None -> typing.Optional)."
    )

def check(mod_name, import_name=None):
    name = import_name or mod_name
    try:
        m = importlib.import_module(name)
        ver = getattr(m, "__version__", "OK")
        print(f"{mod_name:<15} {ver}")
        return True
    except Exception as e:
        print(f"{mod_name:<15} NOT INSTALLED -> {e}")
        return False

print("Python  :", sys.version.split()[0])
print("Platform:", platform.platform())

required = ["numpy", "scipy", "matplotlib", "SimpleITK"]
missing = [m for m in required if not check(m)]

if missing:
    print("\n⚠️ Missing dependencies:", ", ".join(missing))
    print("Install with pip:\n  pip install " + " ".join(missing))
    print("or with conda:\n  conda install -c conda-forge " + " ".join(missing))
    raise SystemExit(1)

# Extra info: SimpleITK build/version
try:
    import SimpleITK as sitk
    try:
        ver = f"{sitk.Version_MajorVersion()}.{sitk.Version_MinorVersion()}.{getattr(sitk,'Version_PatchVersion', lambda: 'x')()}"
    except Exception:
        ver = "unknown"
    print("SimpleITK build:", ver)
except Exception:
    pass


Python  : 3.13.5
Platform: Windows-11-10.0.26100-SP0
numpy           2.3.1
scipy           1.16.0
matplotlib      3.10.5
SimpleITK       2.5.2
SimpleITK build: 2.5.2


In [2]:
# Parameters (edit these for your dataset & preferences)

from pathlib import Path

# --- Input paths (use raw strings on Windows to avoid backslash escapes) ---
CT_PATH = Path(r"J:\Dev\xct_measurements\aligned_90rotright_reslicetop_JI_7.tif")   # 3D CT stack (multi-page TIFF)
US_PATH = Path(r"J:\Dev\CompSTLar\tiff_results\rf_abs.tif")           # 3D US volume (multi-page TIFF)

# Optional: precomputed masks (if available). If None, masks will be created from images.
CT_MASK_PATH = None                     # Path(r"J:\...\ct_mask.nii.gz")
US_MASK_PATH = None                     # Path(r"J:\...\us_mask.nii.gz")

# --- Outputs ---
OUT_DIR = Path(r"J:\Dev\CompSTLar\reg_results")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Native physical spacing (CRITICAL) ---
# Set these to the true voxel sizes in millimeters (x, y, z).
# UT is anisotropic: 1.0 x 1.0 mm laterally and 0.016 mm along z (double-check this value;
# for sampled RF data the axial step is c / (2 * f_s)).
US_NATIVE_SPACING_MM = (1.0, 1.0, 0.016)

# TODO: replace with your real CT voxel size (for TIFF stacks metadata is often missing).
CT_NATIVE_SPACING_MM = (0.025, 0.025, 0.025)

# --- Registration geometry strategy (two-stage) ---
# Stage 1 (coarse): make both isotropic for a robust initial alignment.
COARSE_SPACING_MM = (0.8, 0.8, 0.8)

# Stage 2 (fine): keep US at its native anisotropy; make CT reasonably fine isotropic.
# Note: we will NOT resample US in the fine stage; US stays as acquired.
CT_FINE_SPACING_MM   = (0.5, 0.5, 0.5)
US_KEEP_NATIVE_FOR_FINE = True

# Final export choice: resample CT onto the native US grid (US remains untouched).
EXPORT_CT_INTO_US_SPACE = True

# (Deprecated in this workflow) Single-target spacing:
# Keeping this here commented just to avoid confusion in later cells; we won't use it anymore.
# TARGET_SPACING_MM = (0.5, 0.5, 0.5)

# Optional cropping margin around masks (in mm). Set >0 to auto-crop after mask creation.
CROP_MARGIN_MM = 0.0

# --- Preprocessing ---
CT_CLIP_P = (2.0, 98.0)                # percentile window for CT
US_LOG_ALPHA = 2.0                     # log compression factor for US (0 disables)
US_SMOOTH_SIGMA = 0.8                  # Gaussian smoothing sigma (in voxels) for US denoising

# --- Registration (rigid) ---
MI_BINS = 48
SAMPLING_PERCENT = 0.05                # 3–8% typical
PYR_SHRINK = [8, 4, 2, 1]              # multi-resolution pyramid shrink factors
PYR_SIGMAS = [3, 2, 1, 0]              # smoothing (in mm) per level

MAX_ITERS = 250
LEARNING_RATE = 2.0
MIN_STEP = 1e-3
RANDOM_SEED = 42

# --- Optional deformable refinement (B-spline) ---
USE_BSPLINE = False                    # set True if probe pressure causes local distortions
BSPLINE_GRID_SPACING_MM = 50.0         # coarse grid spacing (mm)
BSPLINE_MAX_ITERS = 150


In [3]:
import numpy as np
import SimpleITK as sitk
from pathlib import Path
from typing import Optional, Tuple

# ---------- I/O helpers ----------

def sitk_from_numpy(
    arr: np.ndarray,
    spacing_mm: Tuple[float, float, float] = (1.0, 1.0, 1.0),
    origin: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    direction: Optional[Tuple[float, ...]] = None,
    pixel_type=sitk.sitkFloat32
) -> sitk.Image:
    """
    Create a SimpleITK 3D image from a NumPy array.
    Expects arr with shape (Z, Y, X). If your array is (X, Y, Z), transpose before calling.
    """
    if arr.ndim != 3:
        raise ValueError(f"Expected a 3D array, got shape {arr.shape}")
    img = sitk.GetImageFromArray(arr)  # NumPy (Z,Y,X) -> ITK image
    img = sitk.Cast(img, pixel_type)
    img.SetSpacing(tuple(spacing_mm))
    img.SetOrigin(tuple(origin))
    if direction is None:
        # Identity direction (3x3 matrix flattened row-major)
        direction = (1.0, 0.0, 0.0,
                     0.0, 1.0, 0.0,
                     0.0, 0.0, 1.0)
    img.SetDirection(direction)
    return img

def read_image_any(
    path: Path,
    spacing_mm: Optional[Tuple[float, float, float]] = None,
    pixel_type=sitk.sitkFloat32,
    numpy_axis_order: str = "zyx"
) -> sitk.Image:
    """
    Read 3D image from various formats:
    - .tif/.tiff: multi-page TIFF stack (XCT).
    - .npy: NumPy array (expects 3D; assumes (Z,Y,X) by default).
    - others: delegated to SimpleITK.ReadImage.

    spacing_mm: if provided, overrides header spacing (recommended for TIFF/NPY).
    numpy_axis_order: "zyx" if array is (Z,Y,X). Use "xyz" if your array is (X,Y,Z).
    """
    ext = path.suffix.lower()
    if ext == ".npy":
        arr = np.load(str(path))
        if arr.ndim != 3:
            raise ValueError(f"NumPy file must be 3D, got shape {arr.shape}")
        if numpy_axis_order.lower() == "xyz":
            # Convert (X,Y,Z) -> (Z,Y,X)
            arr = np.transpose(arr, (2, 1, 0))
        img = sitk_from_numpy(arr, spacing_mm=spacing_mm or (1.0, 1.0, 1.0), pixel_type=pixel_type)
        return img

    # For TIFF and other image formats supported by SimpleITK:
    img = sitk.ReadImage(str(path))
    if pixel_type is not None and img.GetPixelID() != pixel_type:
        img = sitk.Cast(img, pixel_type)

    # Override spacing if provided (TIFF stacks often lack reliable spacing)
    if spacing_mm is not None:
        img.SetSpacing(tuple(spacing_mm))
    return img

def write_image(img: sitk.Image, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    sitk.WriteImage(img, str(path), useCompression=True)

# ---------- Stats & normalization ----------

def np_percentile_image(img: sitk.Image, p_low: float, p_high: float):
    arr = sitk.GetArrayFromImage(img).astype(np.float32)
    lo, hi = np.percentile(arr, [p_low, p_high])
    return float(lo), float(hi)

def normalize_to_unit(img: sitk.Image, lo: float, hi: float) -> sitk.Image:
    """Map [lo, hi] -> [0,1], clamp outside."""
    img_f = sitk.Cast(img, sitk.sitkFloat32)
    img_f = sitk.Clamp(img_f, lowerBound=lo, upperBound=hi)
    return sitk.RescaleIntensity(img_f, 0.0, 1.0)

def normalize_ct(img: sitk.Image, clip_p=(2.0, 98.0)) -> sitk.Image:
    lo, hi = np_percentile_image(img, clip_p[0], clip_p[1])
    return normalize_to_unit(img, lo, hi)

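def normalize_us(img: sitk.Image, log_alpha=2.0, smooth_sigma=0.8) -> sitk.Image:
    """Denoise (Gaussian, sigma in voxels), percentile-normalize, then log-compress US."""
    us = sitk.Cast(img, sitk.sitkFloat32)
    if smooth_sigma and smooth_sigma > 0:
        # DiscreteGaussian takes a *variance* (in mm^2 by default). Convert the voxel
        # sigma to a variance and disable image spacing so it is applied in voxels.
        gauss = sitk.DiscreteGaussianImageFilter()
        gauss.SetVariance(float(smooth_sigma) ** 2)
        gauss.SetUseImageSpacing(False)
        us = gauss.Execute(us)
    # Simple log compression after percentile normalization
    lo, hi = np_percentile_image(us, 1.0, 99.0)
    us = normalize_to_unit(us, lo, hi)
    if log_alpha and log_alpha > 0:
        arr = sitk.GetArrayFromImage(us)
        arr = np.log1p(log_alpha * arr)
        arr /= (arr.max() if arr.max() > 0 else 1.0)
        us = sitk.GetImageFromArray(arr.astype(np.float32))
        # Preserve geometry
        us.CopyInformation(img)
    return us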

# ---------- Resampling ----------

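def resample_to_spacing(
    img: sitk.Image,
    target_spacing_xyz: Tuple[float, float, float],
    interpolator=sitk.sitkLinear,
    default_value=0.0,
    antialias: bool = True,
) -> sitk.Image:
    """
    Resample to an arbitrary target spacing (isotropic or anisotropic).
    Preserves origin/direction and computes the new size to preserve physical extent.
    When downsampling intensity images (antialias=True, non-nearest interpolator),
    each coarsened axis is Gaussian-smoothed first (sigma = 0.5 * target spacing, in mm)
    to limit aliasing, e.g. for the 32x CT reduction from 0.025 mm to 0.8 mm.
    """
    original_spacing = np.array(img.GetSpacing(), dtype=float)
    original_size = np.array(img.GetSize(), dtype=int)
    target_spacing = np.array(target_spacing_xyz, dtype=float)

    new_size = np.maximum(np.round(original_size * (original_spacing / target_spacing)), 1).astype(int).tolist()

    src = img
    if antialias and interpolator != sitk.sitkNearestNeighbor:
        for axis in range(img.GetDimension()):
            # Smooth only axes that get coarser (RecursiveGaussian needs >= 4 voxels along the axis)
            if target_spacing[axis] > original_spacing[axis] and img.GetSize()[axis] >= 4:
                gauss = sitk.RecursiveGaussianImageFilter()
                gauss.SetDirection(axis)
                gauss.SetSigma(0.5 * float(target_spacing[axis]))  # physical units (mm)
                src = gauss.Execute(src)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(img)
    resampler.SetOutputSpacing(tuple(target_spacing))
    resampler.SetSize([int(x) for x in new_size])
    resampler.SetInterpolator(interpolator)
    resampler.SetOutputOrigin(img.GetOrigin())
    resampler.SetOutputDirection(img.GetDirection())
    resampler.SetDefaultPixelValue(default_value)
    return resampler.Execute(src)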

def resample_isotropic(img: sitk.Image, target_spacing_xyz: Tuple[float, float, float], interpolator=sitk.sitkLinear, default_value=0.0) -> sitk.Image:
    """Alias kept for compatibility; delegates to resample_to_spacing."""
    return resample_to_spacing(img, target_spacing_xyz, interpolator=interpolator, default_value=default_value)

# ---------- Masks & metrics ----------

def otsu_mask(img: sitk.Image, closing_radius=2, keep_largest=True) -> sitk.Image:
    """
    Basic Otsu threshold + morphological closing + keep largest component.
    NOTE: SimpleITK expects a vector radius for closing; we pass [r, r, r].
    """
    # Light smoothing to stabilize Otsu on noisy data
    smooth = sitk.CurvatureFlow(image1=img, timeStep=0.125, numberOfIterations=3)

    # Otsu -> binary {0,1}
    mask = sitk.OtsuThreshold(smooth, 0, 1, 200)

    # Closing with per-axis radius (in voxels)
    if closing_radius and closing_radius > 0:
        r = int(closing_radius)
        mask = sitk.BinaryMorphologicalClosing(mask, [r, r, r])  # <-- FIX: pass a vector radius

    # Keep largest connected component (optional but helps with background)
    if keep_largest:
        cc = sitk.ConnectedComponent(mask)
        stats = sitk.LabelShapeStatisticsImageFilter()
        stats.Execute(cc)
        if stats.GetNumberOfLabels() > 0:
            # Choose the label with largest physical size
            areas = [(lab, stats.GetPhysicalSize(lab)) for lab in stats.GetLabels()]
            largest = max(areas, key=lambda x: x[1])[0]
            mask = sitk.BinaryThreshold(cc, lowerThreshold=largest, upperThreshold=largest, insideValue=1, outsideValue=0)

    mask = sitk.Cast(mask, sitk.sitkUInt8)
    mask.CopyInformation(img)
    return mask


def apply_mask(img: sitk.Image, mask: sitk.Image, outside_value=0.0) -> sitk.Image:
    return sitk.Mask(img, sitk.Cast(mask, sitk.sitkUInt8), outsideValue=outside_value)

def init_rigid_by_moments(fixed: sitk.Image, moving: sitk.Image) -> sitk.Transform:
    return sitk.CenteredTransformInitializer(
        fixed, moving, sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.MOMENTS
    )

def transform_and_resample(
    moving: sitk.Image,
    fixed: sitk.Image,
    transform: sitk.Transform,
    interp=sitk.sitkLinear,
    default_value=0.0
) -> sitk.Image:
    """Resample 'moving' into 'fixed' geometry using the provided transform."""
    return sitk.Resample(moving, fixed, transform, interp, default_value, moving.GetPixelID())

def dice_coefficient(mask_a: sitk.Image, mask_b: sitk.Image) -> float:
    """Dice for binary masks {0,1}."""
    a = sitk.Cast(mask_a > 0, sitk.sitkUInt8)
    b = sitk.Cast(mask_b > 0, sitk.sitkUInt8)
    inter = sitk.GetArrayFromImage(a & b).sum()
    vol_a = sitk.GetArrayFromImage(a).sum()
    vol_b = sitk.GetArrayFromImage(b).sum()
    return (2.0 * inter) / (vol_a + vol_b + 1e-8)

# ---------- Quick visualization ----------

def show_overlay_slices(fixed: sitk.Image, moving: sitk.Image, title="Overlay", num_slices=6, axis=0):
    """
    Show paired slices: fixed in gray, moving overlaid with transparency.
    axis: 0=z, 1=y, 2=x (arrays are indexed in (Z,Y,X) order).
    """
    import matplotlib.pyplot as plt
    f_arr = sitk.GetArrayFromImage(fixed)
    m_arr = sitk.GetArrayFromImage(moving)
    if axis not in (0, 1, 2):
        axis = 0
    size = f_arr.shape[axis]
    indices = np.linspace(int(size * 0.1), int(size * 0.9), num_slices).astype(int)
    plt.figure(figsize=(12, 2 * num_slices))
    for i, idx in enumerate(indices, 1):
        if axis == 0:
            f_slice = f_arr[idx, :, :]
            m_slice = m_arr[idx, :, :]
        elif axis == 1:
            f_slice = f_arr[:, idx, :]
            m_slice = m_arr[:, idx, :]
        else:
            f_slice = f_arr[:, :, idx]
            m_slice = m_arr[:, :, idx]
        plt.subplot(num_slices, 2, 2 * i - 1)
        plt.imshow(f_slice, cmap="gray")
        plt.title(f"{title} - Fixed (slice {idx})")
        plt.axis("off")
        plt.subplot(num_slices, 2, 2 * i)
        plt.imshow(f_slice, cmap="gray")
        plt.imshow(m_slice, alpha=0.5)  # semi-transparent overlay
        plt.title("Overlay: Moving on Fixed")
        plt.axis("off")
    plt.tight_layout()
    plt.show()


In [4]:
import SimpleITK as sitk
from typing import Optional, Sequence, Tuple

def _ensure_mask(mask: Optional[sitk.Image]) -> Optional[sitk.Image]:
    if mask is None:
        return None
    return sitk.Cast(mask, sitk.sitkUInt8)

def register_rigid_mi(
    fixed: sitk.Image,
    moving: sitk.Image,
    fixed_mask: Optional[sitk.Image] = None,
    moving_mask: Optional[sitk.Image] = None,
    bins: int = 48,
    sampling: float = 0.05,
    shrink: Sequence[int] = (8, 4, 2, 1),
    sigmas_mm: Sequence[float] = (3, 2, 1, 0),
    max_iters: int = 250,
    lr: float = 2.0,
    min_step: float = 1e-3,
    seed: int = 42,
    initial_transform: Optional[sitk.Transform] = None,
    use_physical_sigmas: bool = True,
    verbose: bool = False,
) -> Tuple[sitk.Transform, sitk.ImageRegistrationMethod]:
    """
    Rigid (6-DOF) registration using Mattes Mutual Information.
    - Supports ROI masks (fixed/moving).
    - Multi-resolution pyramid (shrink/sigmas).
    - Optional 'initial_transform' to start from a prior solution (recommended for fine stage).
    """
    if len(shrink) != len(sigmas_mm):
        raise ValueError(f"'shrink' and 'sigmas_mm' must have same length, got {len(shrink)} vs {len(sigmas_mm)}")

    reg = sitk.ImageRegistrationMethod()
    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=bins)
    reg.SetMetricSamplingStrategy(reg.RANDOM)
    reg.SetMetricSamplingPercentage(sampling, seed)

    fmask = _ensure_mask(fixed_mask)
    mmask = _ensure_mask(moving_mask)
    if fmask is not None:
        reg.SetMetricFixedMask(fmask)
    if mmask is not None:
        reg.SetMetricMovingMask(mmask)

    reg.SetInterpolator(sitk.sitkLinear)

    # Optimizer: Regular step gradient descent (robust and simple to tune)
    reg.SetOptimizerAsRegularStepGradientDescent(
        learningRate=lr,
        minStep=min_step,
        numberOfIterations=max_iters,
        relaxationFactor=0.5
    )
    reg.SetOptimizerScalesFromPhysicalShift()

    # Multi-resolution
    reg.SetShrinkFactorsPerLevel(list(shrink))
    reg.SetSmoothingSigmasPerLevel(list(sigmas_mm))
    if use_physical_sigmas:
        reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    else:
        reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOff()

    # Initialization
    if initial_transform is not None:
        reg.SetInitialTransform(initial_transform, inPlace=False)
    else:
        init_tx = sitk.CenteredTransformInitializer(
            fixed, moving, sitk.Euler3DTransform(),
            sitk.CenteredTransformInitializerFilter.MOMENTS
        )
        reg.SetInitialTransform(init_tx, inPlace=False)

    # Verbose callbacks
    if verbose:
        def _on_level():
            lvl = reg.GetCurrentLevel() if hasattr(reg, "GetCurrentLevel") else "?"
            print(f"\n-- Entering pyramid level {lvl} --")
        def _on_iter():
            try:
                it = reg.GetOptimizerIteration()
                mv = reg.GetMetricValue()
                lr_now = reg.GetOptimizerLearningRate()
                print(f"Iter {it:4d} | metric {mv:.6f} | lr {lr_now:.4g}")
            except Exception:
                pass
        reg.AddCommand(sitk.sitkMultiResolutionIterationEvent, _on_level)
        reg.AddCommand(sitk.sitkIterationEvent, _on_iter)

    final_tx = reg.Execute(fixed, moving)
    return final_tx, reg


def register_bspline_refine(
    fixed: sitk.Image,
    moving: sitk.Image,
    initial_tx: sitk.Transform,
    fixed_mask: Optional[sitk.Image] = None,
    moving_mask: Optional[sitk.Image] = None,
    bins: int = 48,
    sampling: float = 0.05,
    grid_spacing_mm: float = 50.0,
    shrink: Sequence[int] = (4, 2, 1),
    sigmas_mm: Sequence[float] = (2, 1, 0),
    max_iters: int = 150,
    seed: int = 42,
    use_physical_sigmas: bool = True,
    verbose: bool = False,
) -> Tuple[sitk.Transform, sitk.ImageRegistrationMethod]:
    """
    Lightweight deformable refinement (B-spline) on top of an existing transform (usually rigid).
    Use a coarse grid (e.g., 40–60 mm) and strong regularization-friendly optimizer (LBFGSB).
    """
    if len(shrink) != len(sigmas_mm):
        raise ValueError(f"'shrink' and 'sigmas_mm' must have same length, got {len(shrink)} vs {len(sigmas_mm)}")

    # Initialize a coarse B-spline over the fixed domain
    grid_physical_spacing = [float(grid_spacing_mm)] * 3
    image_physical_size = [sz * sp for sz, sp in zip(fixed.GetSize(), fixed.GetSpacing())]
    mesh_size = [max(int(round(ps / gps)) - 1, 1) for ps, gps in zip(image_physical_size, grid_physical_spacing)]
    bspline_tx = sitk.BSplineTransformInitializer(fixed, mesh_size)

    reg = sitk.ImageRegistrationMethod()
    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=bins)
    reg.SetMetricSamplingStrategy(reg.RANDOM)
    reg.SetMetricSamplingPercentage(sampling, seed)

    fmask = _ensure_mask(fixed_mask)
    mmask = _ensure_mask(moving_mask)
    if fmask is not None:
        reg.SetMetricFixedMask(fmask)
    if mmask is not None:
        reg.SetMetricMovingMask(mmask)

    reg.SetInterpolator(sitk.sitkLinear)

    # Optimizer: LBFGSB (bounded quasi-Newton). The except branch only catches errors raised
    # while configuring LBFGSB; failures during Execute() are not redirected to it.
    try:
        reg.SetOptimizerAsLBFGSB(
            gradientConvergenceTolerance=1e-5,
            numberOfIterations=max_iters,
            maximumNumberOfCorrections=5,
            maximumNumberOfFunctionEvaluations=1000
        )
    except Exception:
        reg.SetOptimizerAsRegularStepGradientDescent(
            learningRate=1.0, minStep=1e-3,
            numberOfIterations=max_iters, relaxationFactor=0.5
        )
        # Physical-shift scales only matter for the gradient-descent fallback;
        # ITK's LBFGSB optimizer does not support parameter scales.
        reg.SetOptimizerScalesFromPhysicalShift()

    # Multi-resolution
    reg.SetShrinkFactorsPerLevel(list(shrink))
    reg.SetSmoothingSigmasPerLevel(list(sigmas_mm))
    if use_physical_sigmas:
        reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    else:
        reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOff()

    # Compose: keep the rigid as moving initial transform, optimize only the B-spline on top
    reg.SetMovingInitialTransform(initial_tx)
    reg.SetInitialTransform(bspline_tx, inPlace=False)

    if verbose:
        def _on_level():
            lvl = reg.GetCurrentLevel() if hasattr(reg, "GetCurrentLevel") else "?"
            print(f"\n-- Entering B-spline level {lvl} --")
        def _on_iter():
            try:
                it = reg.GetOptimizerIteration()
                mv = reg.GetMetricValue()
                print(f"Iter {it:4d} | metric {mv:.6f}")
            except Exception:
                pass
        reg.AddCommand(sitk.sitkMultiResolutionIterationEvent, _on_level)
        reg.AddCommand(sitk.sitkIterationEvent, _on_iter)

    out_bspline = reg.Execute(fixed, moving)

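    # Return a composite: initial (rigid) + bspline. ITK applies the most recently added
    # transform first, so points pass through the B-spline and then the rigid transform.
    composite = sitk.CompositeTransform(initial_tx)
    composite.AddTransform(out_bspline)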
    return composite, reg
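`register_bspline_refine()` sizes the mesh as `round(physical_extent / grid_spacing) - 1` per axis, clamped to at least 1. The sketch below repeats that formula in plain Python so you can see, before running the optimizer, how many mesh cells a given `grid_spacing_mm` produces on a thin CFRP plate. The helper names `bspline_mesh_size` and `control_points` are hypothetical. Only the physical extent matters, so the raw CT size printed by the preprocessing cell gives essentially the same answer as the fine CT. For a cubic B-spline, ITK stores `mesh_size + 3` control points per axis.

```python
def bspline_mesh_size(size_vox, spacing_mm, grid_spacing_mm):
    """Mesh cells per axis, using the same formula as register_bspline_refine()."""
    mesh = []
    for n, sp in zip(size_vox, spacing_mm):
        extent_mm = n * sp  # physical extent along this axis
        cells = int(round(extent_mm / grid_spacing_mm)) - 1
        mesh.append(max(cells, 1))  # never fewer than one mesh cell
    return mesh


def control_points(mesh, spline_order=3):
    """Control points per axis: mesh size + spline order (cubic by default)."""
    return [m + spline_order for m in mesh]


# CT raw geometry printed by the preprocessing cell (0.025 mm isotropic)
ct_size = (1618, 3282, 197)
ct_spacing = (0.025, 0.025, 0.025)
for grid_mm in (50.0, 10.0):
    mesh = bspline_mesh_size(ct_size, ct_spacing, grid_mm)
    print(f"grid {grid_mm} mm -> mesh {mesh}, control points {control_points(mesh)}")
# grid 50.0 mm -> mesh [1, 1, 1], control points [4, 4, 4]
# grid 10.0 mm -> mesh [3, 7, 1], control points [6, 10, 4]
```

On this roughly 40 × 82 × 5 mm part, a 40–60 mm grid collapses to a single cell per axis, so the deformable step is very stiff. Along z the mesh stays at one cell unless the grid spacing drops below about 2 mm.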


In [5]:
# Load & prepare volumes (two-stage: coarse isotropic, fine with native US)
# All outputs saved as .tif stacks.

# 1) Read raw images with native spacing
ct_raw = read_image_any(CT_PATH, spacing_mm=CT_NATIVE_SPACING_MM, pixel_type=sitk.sitkFloat32)
us_raw = read_image_any(US_PATH, spacing_mm=US_NATIVE_SPACING_MM, pixel_type=sitk.sitkFloat32)  # , numpy_axis_order="zyx"

print("CT raw   -> size:", ct_raw.GetSize(), "spacing (mm):", ct_raw.GetSpacing())
print("US raw   -> size:", us_raw.GetSize(), "spacing (mm):", us_raw.GetSpacing())

# 2) Stage 1 (COARSE): resample both to the same isotropic spacing for robust global alignment
ct_coarse = resample_isotropic(ct_raw, COARSE_SPACING_MM, interpolator=sitk.sitkBSpline)
us_coarse = resample_isotropic(us_raw, COARSE_SPACING_MM, interpolator=sitk.sitkBSpline)

print("CT coarse -> size:", ct_coarse.GetSize(), "spacing (mm):", ct_coarse.GetSpacing())
print("US coarse -> size:", us_coarse.GetSize(), "spacing (mm):", us_coarse.GetSpacing())

# Normalize for metric stability (percentile windowing, denoise+log for US)
ct_coarse_n = normalize_ct(ct_coarse, clip_p=CT_CLIP_P)
us_coarse_n = normalize_us(us_coarse, log_alpha=US_LOG_ALPHA, smooth_sigma=US_SMOOTH_SIGMA)

# Coarse masks (CT from Otsu; US light threshold)
if CT_MASK_PATH is not None and CT_MASK_PATH.exists():
    ct_mask_coarse = read_image_any(CT_MASK_PATH, pixel_type=sitk.sitkUInt8)
else:
    ct_mask_coarse = otsu_mask(ct_coarse_n, closing_radius=2, keep_largest=True)

if US_MASK_PATH is not None and US_MASK_PATH.exists():
    us_mask_coarse = read_image_any(US_MASK_PATH, pixel_type=sitk.sitkUInt8)
else:
    lo_c, hi_c = np_percentile_image(us_coarse_n, 5.0, 99.0)
    us_mask_coarse = sitk.Cast(us_coarse_n > max(lo_c, 0.05), sitk.sitkUInt8)

# Save coarse artifacts as TIFF
write_image(sitk.Cast(ct_coarse_n, sitk.sitkFloat32), OUT_DIR / "ct_coarse_norm.tif")
write_image(sitk.Cast(us_coarse_n, sitk.sitkFloat32), OUT_DIR / "us_coarse_norm.tif")
write_image(sitk.Cast(ct_mask_coarse, sitk.sitkUInt8), OUT_DIR / "ct_mask_coarse.tif")
write_image(sitk.Cast(us_mask_coarse, sitk.sitkUInt8), OUT_DIR / "us_mask_coarse.tif")

# 3) Stage 2 (FINE): keep US at native spacing; make CT isotropic fine
ct_fine = resample_isotropic(ct_raw, CT_FINE_SPACING_MM, interpolator=sitk.sitkBSpline)
ct_fine_n = normalize_ct(ct_fine, clip_p=CT_CLIP_P)

if US_KEEP_NATIVE_FOR_FINE:
    us_fine = us_raw  # leave US untouched: native 1.0 x 1.0 x 0.016 mm
else:
    US_TARGET_SPACING_FINE = (0.8, 0.8, 0.3)  # example only
    us_pref = sitk.SmoothingRecursiveGaussian(us_raw, sigma=0.3)  # anti-alias in mm
    us_fine = resample_to_spacing(us_pref, US_TARGET_SPACING_FINE, interpolator=sitk.sitkBSpline)

us_fine_n = normalize_us(us_fine, log_alpha=US_LOG_ALPHA, smooth_sigma=US_SMOOTH_SIGMA)

# Fine masks
ct_mask_fine = otsu_mask(ct_fine_n, closing_radius=2, keep_largest=True)
lo_f, hi_f = np_percentile_image(us_fine_n, 5.0, 99.0)
us_mask_fine = sitk.Cast(us_fine_n > max(lo_f, 0.05), sitk.sitkUInt8)

# Save fine artifacts as TIFF
write_image(sitk.Cast(ct_fine_n, sitk.sitkFloat32), OUT_DIR / "ct_fine_norm.tif")
write_image(sitk.Cast(us_fine_n, sitk.sitkFloat32), OUT_DIR / "us_fine_norm.tif")
write_image(sitk.Cast(ct_mask_fine, sitk.sitkUInt8), OUT_DIR / "ct_mask_fine.tif")
write_image(sitk.Cast(us_mask_fine, sitk.sitkUInt8), OUT_DIR / "us_mask_fine.tif")

print("Preprocessing complete. Saved coarse & fine normalized images and masks as .tif in:", OUT_DIR)

# (Optional) Save sidecar JSON with spacing/origin/direction to avoid metadata loss in some TIFF viewers
try:
    import json
    meta = {
        "ct_raw": dict(spacing=ct_raw.GetSpacing(), origin=ct_raw.GetOrigin(), direction=ct_raw.GetDirection()),
        "us_raw": dict(spacing=us_raw.GetSpacing(), origin=us_raw.GetOrigin(), direction=us_raw.GetDirection()),
        "ct_coarse": dict(spacing=ct_coarse.GetSpacing()), "us_coarse": dict(spacing=us_coarse.GetSpacing()),
        "ct_fine": dict(spacing=ct_fine.GetSpacing()), "us_fine": dict(spacing=us_fine.GetSpacing())
    }
    with open(OUT_DIR / "spacing_metadata.json", "w") as f:
        json.dump(meta, f, indent=2)
    print("Saved spacing metadata JSON to help external tools interpret voxel sizes.")
except Exception as e:
    print("Could not save spacing metadata JSON:", e)


CT raw   -> size: (1618, 3282, 197) spacing (mm): (0.025, 0.025, 0.025)
US raw   -> size: (65, 109, 500) spacing (mm): (1.0, 1.0, 0.016)
CT coarse -> size: (51, 103, 6) spacing (mm): (0.8, 0.8, 0.8)
US coarse -> size: (81, 136, 10) spacing (mm): (0.8, 0.8, 0.8)
Preprocessing complete. Saved coarse & fine normalized images and masks as .tif in: J:\Dev\CompSTLar\reg_results
Saved spacing metadata JSON to help external tools interpret voxel sizes.
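The coarse sizes printed above follow from keeping the physical extent constant: new size ≈ size × spacing / new spacing, rounded to the nearest integer. The sketch assumes `resample_isotropic()` rounds this way (its body is not shown in this section); `isotropic_size` is a hypothetical helper. It reproduces the printed numbers and makes the thin axis visible: the CT part is only about 4.9 mm thick, so at 0.8 mm it keeps just 6 slices.

```python
def isotropic_size(size_vox, spacing_mm, new_spacing_mm):
    """Voxel count per axis after resampling to an isotropic spacing (same physical extent)."""
    return tuple(int(round(n * sp / new_spacing_mm)) for n, sp in zip(size_vox, spacing_mm))


ct_raw_size, ct_raw_spacing = (1618, 3282, 197), (0.025, 0.025, 0.025)
us_raw_size, us_raw_spacing = (65, 109, 500), (1.0, 1.0, 0.016)

print("CT coarse:", isotropic_size(ct_raw_size, ct_raw_spacing, 0.8))  # (51, 103, 6)
print("US coarse:", isotropic_size(us_raw_size, us_raw_spacing, 0.8))  # (81, 136, 10)

# Physical thickness along z explains why the coarse CT has so few slices
print("CT thickness (mm):", ct_raw_size[2] * ct_raw_spacing[2])  # about 4.9
print("US thickness (mm):", us_raw_size[2] * us_raw_spacing[2])  # about 8.0
```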


In [None]:
# --- Rigid registration (two-stage) and export CT -> native US grid ---
# Assumes you already have: ct_coarse_n, us_coarse_n, ct_mask_coarse, us_mask_coarse,
#                           ct_fine_n,   us_fine_n,   ct_mask_fine,   us_mask_fine,
#                           ct_raw, us_raw, OUT_DIR, and register_rigid_mi() defined.

import SimpleITK as sitk
from pathlib import Path

# Helper: save transforms robustly (handles CompositeTransform by falling back to .h5)
def save_transform_safely(tx: sitk.Transform, out_base: Path):
    """
    Save a SimpleITK transform robustly:
    - Try .tfm (text). If it fails (e.g., CompositeTransform), fallback to .h5.
    - If it's a Composite with a single sub-transform, save that sub-transform to .tfm for compatibility.
    """
    out_base = Path(out_base)
    tfm_path = out_base.with_suffix(".tfm")
    h5_path  = out_base.with_suffix(".h5")

    name = tx.GetName()
    print(f"Saving transform: {name}")

    # If composite with a single child, save child as .tfm
    try:
        if "Composite" in name and hasattr(tx, "GetNumberOfTransforms"):
            n = tx.GetNumberOfTransforms()
            print(f"Composite contains {n} sub-transform(s).")
            if n == 1:
                sub = tx.GetNthTransform(0)
                print(f"Saving the single sub-transform ({sub.GetName()}) as .tfm -> {tfm_path}")
                sitk.WriteTransform(sub, str(tfm_path))
                return
    except Exception as e:
        print("Could not inspect composite structure:", e)

    # Try .tfm directly
    try:
        sitk.WriteTransform(tx, str(tfm_path))
        print(f"Saved .tfm -> {tfm_path}")
        return
    except Exception as e:
        print(f".tfm save failed ({e}). Falling back to .h5 ...")

    # Fallback: .h5 supports CompositeTransform
    sitk.WriteTransform(tx, str(h5_path))
    print(f"Saved .h5 -> {h5_path}")


# ---------------- Stage 1: coarse isotropic rigid (robust global alignment) ----------------
# Initialization: align geometric image centers (intensity-independent, so safe for multi-modal data) to start with substantial overlap
init_tx_geom = sitk.CenteredTransformInitializer(
    us_coarse_n, ct_coarse_n, sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.GEOMETRY
)

rigid_tx_stage1, reg1 = register_rigid_mi(
    fixed=us_coarse_n, moving=ct_coarse_n,
    fixed_mask=us_mask_coarse, moving_mask=ct_mask_coarse,
    bins=MI_BINS, sampling=SAMPLING_PERCENT,
    shrink=PYR_SHRINK, sigmas_mm=PYR_SIGMAS,
    max_iters=MAX_ITERS, lr=LEARNING_RATE, min_step=MIN_STEP,
    seed=RANDOM_SEED,
    initial_transform=init_tx_geom,
    verbose=True
)
print("Stage 1 (coarse) done.")
print("  NMI (coarse):", reg1.GetMetricValue())
save_transform_safely(rigid_tx_stage1, OUT_DIR / "rigid_stage1")

# Optional QC: resample CT (coarse) onto US (coarse) for a quick visual check
ct_on_us_coarse = transform_and_resample(
    moving=ct_coarse_n, fixed=us_coarse_n, transform=rigid_tx_stage1,
    interp=sitk.sitkLinear, default_value=0.0
)
write_image(sitk.Cast(ct_on_us_coarse, sitk.sitkFloat32), OUT_DIR / "ct_on_us_coarse.tif")
show_overlay_slices(us_coarse_n, ct_on_us_coarse, title="Coarse rigid QC (US space)", num_slices=6, axis=0)


# ---------------- Stage 2: fine rigid (US fixed native anisotropic, CT fine isotropic) — pre-warp CT to US space, then refine delta ----------------
# Pre-warp CT fine image and mask into US fine geometry using Stage 1, so Stage 2 starts from the coarse alignment
ct_fine_on_us_stage1 = transform_and_resample(
    moving=ct_fine_n, fixed=us_fine_n, transform=rigid_tx_stage1,
    interp=sitk.sitkLinear, default_value=0.0
)
ct_mask_fine_on_us_stage1 = transform_and_resample(
    moving=ct_mask_fine, fixed=us_mask_fine, transform=rigid_tx_stage1,
    interp=sitk.sitkNearestNeighbor, default_value=0
)
ct_mask_fine_on_us_stage1 = sitk.Cast(ct_mask_fine_on_us_stage1 > 0, sitk.sitkUInt8)

# Build overlap fixed mask in US space (intersection of US mask and prewarped CT mask)
overlap_mask = sitk.Cast((us_mask_fine > 0) & (ct_mask_fine_on_us_stage1 > 0), sitk.sitkUInt8)
try:
    overlap_mask = sitk.BinaryDilate(overlap_mask, [2, 2, 2])
except Exception:
    pass

# Identity delta in US space (CT already expressed in US geometry)
delta_tx = sitk.Euler3DTransform()

reg2 = sitk.ImageRegistrationMethod()
reg2.SetMetricAsMattesMutualInformation(numberOfHistogramBins=MI_BINS)
reg2.SetMetricSamplingStrategy(reg2.RANDOM)
reg2.SetMetricSamplingPercentage(max(SAMPLING_PERCENT, 0.15), RANDOM_SEED)
reg2.SetMetricFixedMask(overlap_mask)
reg2.SetInterpolator(sitk.sitkLinear)
reg2.SetOptimizerAsRegularStepGradientDescent(
    learningRate=LEARNING_RATE,
    minStep=MIN_STEP,
    numberOfIterations=MAX_ITERS,
    relaxationFactor=0.5
)
reg2.SetOptimizerScalesFromPhysicalShift()
reg2.SetShrinkFactorsPerLevel([8, 4, 2, 1])
reg2.SetSmoothingSigmasPerLevel([3, 2, 1, 0])
reg2.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Optimize delta only (CT already expressed in US geometry)
reg2.SetInitialTransform(delta_tx, inPlace=False)

delta_opt = reg2.Execute(us_fine_n, ct_fine_on_us_stage1)

# Compose the fine-stage transform used to resample CT onto the US grid (maps US points
# to CT points). ITK composites apply the most recently added transform first, so the
# delta (estimated in US space) is applied before Stage 1.
tx_ct_to_us_fine = sitk.CompositeTransform(rigid_tx_stage1)
tx_ct_to_us_fine.AddTransform(delta_opt)

# Downstream cells (B-spline refinement, validation) use rigid_tx_stage2 to resample US
# onto the CT grid, which needs the opposite mapping (CT points -> US points).
rigid_tx_stage2 = tx_ct_to_us_fine.GetInverse()
print("  NMI (fine):", reg2.GetMetricValue())
save_transform_safely(rigid_tx_stage2, OUT_DIR / "rigid_stage2")

# Optional QC: resample US (fine) onto CT (fine) for visual check in CT space
us_on_ct_fine = transform_and_resample(
    moving=us_fine_n, fixed=ct_fine_n, transform=rigid_tx_stage2,
    interp=sitk.sitkLinear, default_value=0.0
)
write_image(sitk.Cast(us_on_ct_fine, sitk.sitkFloat32), OUT_DIR / "us_on_ct_fine.tif")
show_overlay_slices(ct_fine_n, us_on_ct_fine, title="Fine rigid QC (CT space)", num_slices=6, axis=0)

# Optional QC: resample CT (fine) onto US (fine) for visual check in US space
ct_on_us_fine = transform_and_resample(
    moving=ct_fine_n, fixed=us_fine_n, transform=tx_ct_to_us_fine,
    interp=sitk.sitkLinear, default_value=0.0
)
write_image(sitk.Cast(ct_on_us_fine, sitk.sitkFloat32), OUT_DIR / "ct_on_us_fine.tif")
show_overlay_slices(us_fine_n, ct_on_us_fine, title="Fine rigid QC (US space)", num_slices=6, axis=0)


# ---------------- Final export: CT -> native US grid (US unchanged) ----------------
# We registered with fixed=US and moving=CT, so tx_ct_to_us_fine already maps US points
# to CT points, which is what resampling CT onto the US grid requires.
tx_ct_to_us = tx_ct_to_us_fine

# Save the inverse too (often useful to reuse later)
save_transform_safely(tx_ct_to_us, OUT_DIR / "rigid_stage2_inverse")

# Option A (visualization): normalized CT -> US space
ct_in_us_vis = transform_and_resample(
    moving=ct_fine_n, fixed=us_fine_n, transform=tx_ct_to_us,
    interp=sitk.sitkLinear, default_value=0.0
)
write_image(sitk.Cast(ct_in_us_vis, sitk.sitkFloat32), OUT_DIR / "ct_in_us_space_VIS.tif")

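# Option B (analysis / preserve original CT units): raw CT -> US space.
# Note: this maps 0.025 mm CT voxels onto the 1.0 mm in-plane US grid without prior smoothing, so fine features such as pores may alias.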
ct_in_us_raw = transform_and_resample(
    moving=ct_raw, fixed=us_fine_n, transform=tx_ct_to_us,
    interp=sitk.sitkBSpline, default_value=0.0
)
write_image(ct_in_us_raw, OUT_DIR / "ct_in_us_space_RAW.tif")

print("Export complete. CT has been resampled into native US space. Files saved in:", OUT_DIR)

# Quick overlay QC in US space (Fixed=US native, Moving=CT->US)
show_overlay_slices(us_fine_n, ct_in_us_vis, title="US fixed, CT->US moving", num_slices=6, axis=0)



-- Entering pyramid level 0 --
Iter    0 | metric -1.098609 | lr 2


RuntimeError: Exception thrown in SimpleITK ImageRegistrationMethod_Execute: D:\bld\libsimpleitk_1750106088442\_h_env\Library\include\ITK-5.4\itkMattesMutualInformationImageToImageMetricv4.hxx:311:
ITK ERROR: MattesMutualInformationImageToImageMetricv4(000001EDEDB326E0): All samples map outside moving image buffer. The images do not sufficiently overlap. They need to be initialized to have more overlap before this metric will work. For instance, you can align the image centers by translation.
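The error above means that, after an optimizer update, no sampled fixed point mapped inside the moving image. A cheap pre-flight check is to compare the physical bounding boxes of the two volumes for a few candidate translations. The sketch assumes identity direction cosines and a pure translation (both simplifications), and the zero origins in the example are made up for illustration; `physical_bounds` and `overlap_fraction` are hypothetical helpers. For real data you would read the values with `GetOrigin()`, `GetSpacing()` and `GetSize()`.

```python
def physical_bounds(origin, spacing, size):
    """Box spanned by the voxel edges, assuming identity direction cosines."""
    lo = [o - 0.5 * s for o, s in zip(origin, spacing)]
    hi = [o + (n - 0.5) * s for o, s, n in zip(origin, spacing, size)]
    return lo, hi


def overlap_fraction(fixed_box, moving_box, shift=(0.0, 0.0, 0.0)):
    """Fraction of the fixed box whose points land inside the moving box.

    ITK maps fixed points into the moving image; with a pure translation a fixed
    point p is sampled at p + shift, so it is usable when
    moving_lo - shift <= p <= moving_hi - shift.
    """
    (f_lo, f_hi), (m_lo, m_hi) = fixed_box, moving_box
    inter = 1.0
    fixed_vol = 1.0
    for a0, a1, b0, b1, d in zip(f_lo, f_hi, m_lo, m_hi, shift):
        inter *= max(min(a1, b1 - d) - max(a0, b0 - d), 0.0)
        fixed_vol *= a1 - a0
    return inter / fixed_vol


# Coarse grids printed earlier; both origins set to zero purely for illustration
us_box = physical_bounds((0.0, 0.0, 0.0), (0.8, 0.8, 0.8), (81, 136, 10))
ct_box = physical_bounds((0.0, 0.0, 0.0), (0.8, 0.8, 0.8), (51, 103, 6))

for dz in (0.0, 2.0, 5.0):
    frac = overlap_fraction(us_box, ct_box, (0.0, 0.0, dz))
    print(f"shift z = {dz} mm -> overlap {frac:.3f}")
```

Because the coarse CT is under 5 mm thick, a translation of a few millimetres along z is enough to remove all overlap in this toy setup, which is the situation the Mattes error reports.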


In [None]:
# --- Optional B-spline refinement on top of the fine rigid result ---
# Requires:
#   - ct_fine_n, us_fine_n, ct_mask_fine, us_mask_fine
#   - rigid_tx_stage2 (fine rigid transform)
#   - register_bspline_refine(), save_transform_safely(), transform_and_resample(), write_image()

if USE_BSPLINE:
    print("Starting B-spline refinement on top of the fine rigid transform...")

    bspline_tx, bspline_reg = register_bspline_refine(
        fixed=ct_fine_n,                 # CT (fine, isotropic)
        moving=us_fine_n,                # US (native anisotropic)
        initial_tx=rigid_tx_stage2,      # start from the fine rigid
        fixed_mask=ct_mask_fine,
        moving_mask=us_mask_fine,
        bins=MI_BINS,
        sampling=SAMPLING_PERCENT,
        grid_spacing_mm=BSPLINE_GRID_SPACING_MM,  # e.g., 40–60 mm is typical
        shrink=(4, 2, 1),                # slightly tighter pyramid
        sigmas_mm=(2, 1, 0),
        max_iters=BSPLINE_MAX_ITERS,
        seed=RANDOM_SEED,
        use_physical_sigmas=True,
        verbose=True
    )

    print("B-spline refinement done.")
    print("  NMI (deformable):", bspline_reg.GetMetricValue())

    # Save transform robustly (CompositeTransform -> .h5 fallback)
    save_transform_safely(bspline_tx, OUT_DIR / "composite_rigid_bspline")

    # QC in CT space: resample US onto CT using the composite (rigid + bspline)
    us_bspline_on_ct = transform_and_resample(
        moving=us_fine_n, fixed=ct_fine_n, transform=bspline_tx,
        interp=sitk.sitkLinear, default_value=0.0
    )
    write_image(sitk.Cast(us_bspline_on_ct, sitk.sitkFloat32), OUT_DIR / "us_bspline_on_ct.tif")
    show_overlay_slices(ct_fine_n, us_bspline_on_ct, title="B-spline QC (CT space)", num_slices=6, axis=0)

    # Attempt CT -> US export for the deformable result (usually NOT available analytically)
    try:
        tx_ct_to_us_bs = bspline_tx.GetInverse()  # many B-spline transforms don't support analytic inverse
        ct_in_us_vis_bs = transform_and_resample(
            moving=ct_fine_n, fixed=us_fine_n, transform=tx_ct_to_us_bs,
            interp=sitk.sitkLinear, default_value=0.0
        )
        write_image(sitk.Cast(ct_in_us_vis_bs, sitk.sitkFloat32), OUT_DIR / "ct_in_us_space_VIS_bspline.tif")

        ct_in_us_raw_bs = transform_and_resample(
            moving=ct_raw, fixed=us_fine_n, transform=tx_ct_to_us_bs,
            interp=sitk.sitkBSpline, default_value=0.0
        )
        write_image(ct_in_us_raw_bs, OUT_DIR / "ct_in_us_space_RAW_bspline.tif")
        print("Deformable CT->US export complete (inverse available).")
    except Exception as e:
        print(f"WARNING: Could not compute inverse for the composite B-spline transform ({e}).")
        print("         Skipping deformable CT->US export. Use the rigid result for CT->US,")
        print("         or export a displacement field and invert it with a dedicated method (advanced).")
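A rigid map x ↦ Rx + t has the closed-form inverse x ↦ Rᵀx − Rᵀt, which is why the rigid CT->US export always works while a B-spline displacement usually has no analytic inverse. The sketch uses plain NumPy 4×4 homogeneous matrices instead of SimpleITK objects (a simplification that ignores ITK's center-of-rotation parameterization) to show two things: a composite built as `[stage1, delta]` applies the delta first, matching ITK's rule that the last-added transform is applied first, and inverting a composite reverses the order. `rigid_z`, `rigid_inverse` and the numeric values are hypothetical.

```python
import numpy as np


def rigid_z(angle_deg, t_mm):
    """4x4 homogeneous matrix: rotation about z by angle_deg, then translation t_mm."""
    a = np.deg2rad(angle_deg)
    m = np.eye(4)
    m[:3, :3] = [[np.cos(a), -np.sin(a), 0.0],
                 [np.sin(a), np.cos(a), 0.0],
                 [0.0, 0.0, 1.0]]
    m[:3, 3] = t_mm
    return m


def rigid_inverse(m):
    """Closed-form rigid inverse: rotation R^T, translation -R^T t."""
    r, t = m[:3, :3], m[:3, 3]
    inv = np.eye(4)
    inv[:3, :3] = r.T
    inv[:3, 3] = -r.T @ t
    return inv


# Hypothetical values standing in for the coarse result and the fine correction
stage1 = rigid_z(10.0, [2.0, -1.0, 0.5])
delta = rigid_z(-1.5, [0.1, 0.0, -0.05])

# Composite [stage1, delta] maps p -> stage1(delta(p)): the delta is applied first
composite = stage1 @ delta

# Inverting the composite reverses the order: undo stage1 first, then the delta
composite_inv = rigid_inverse(delta) @ rigid_inverse(stage1)

p = np.array([5.0, 3.0, 1.0, 1.0])  # a point in homogeneous coordinates (mm)
print(np.allclose(composite_inv @ (composite @ p), p))        # True
print(np.allclose(composite_inv, np.linalg.inv(composite)))  # True
print(np.allclose(stage1 @ delta, delta @ stage1))           # False: order matters
```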


In [None]:
# --- Quantitative validation: Dice between masks in CT space and (if possible) in US space ---

import SimpleITK as sitk
from pathlib import Path

def load_transform_if_needed(preferred_bases):
    """
    Try to load a transform from a list of base paths (without extension).
    Checks .h5 first (supports composite), then .tfm.
    Returns the first transform found, or raises FileNotFoundError.
    """
    for base in preferred_bases:
        base = Path(base)
        h5 = base.with_suffix(".h5")
        tfm = base.with_suffix(".tfm")
        if h5.exists():
            return sitk.ReadTransform(str(h5))
        if tfm.exists():
            return sitk.ReadTransform(str(tfm))
    raise FileNotFoundError(f"No transform file found for bases: {preferred_bases}")

# 1) Choose the transform US->CT to apply (prefer B-spline if enabled and available, else rigid fine)
tx_us_to_ct = None
if 'bspline_tx' in globals() and USE_BSPLINE:
    tx_us_to_ct = bspline_tx
else:
    # if not in memory (e.g., after restart), try to load from disk
    try:
        preferred = [OUT_DIR / "composite_rigid_bspline"] if USE_BSPLINE else [OUT_DIR / "rigid_stage2"]
        tx_us_to_ct = load_transform_if_needed(preferred)
    except Exception:
        # fallback to in-memory rigid if available
        if 'rigid_tx_stage2' in globals():
            tx_us_to_ct = rigid_tx_stage2
        else:
            raise RuntimeError("No suitable transform available for validation.")

# 2) Dice in CT space: bring US mask -> CT space and compare to CT mask (fine)
us_mask_in_ct = transform_and_resample(
    moving=us_mask_fine,  # binary mask in US space (fine)
    fixed=ct_mask_fine,   # target geometry = CT fine
    transform=tx_us_to_ct,
    interp=sitk.sitkNearestNeighbor,
    default_value=0
)
us_mask_in_ct = sitk.Cast(us_mask_in_ct > 0, sitk.sitkUInt8)

dice_ct_space = dice_coefficient(ct_mask_fine, us_mask_in_ct)
print(f"Dice (CT space): CT_mask vs (US_mask -> CT) = {dice_ct_space:.4f}")
write_image(us_mask_in_ct, OUT_DIR / "us_mask_in_ct.tif")

# 3) Dice in US space (optional): requires inverse (CT->US)
dice_us_space = None
tx_ct_to_us = None

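# Try in-memory inverse first (fails for composites that contain a B-spline)
try:
    tx_ct_to_us = tx_us_to_ct.GetInverse()
except Exception:
    # Fall back to the saved rigid inverse; this evaluates the rigid result, not the deformable one
    try:
        tx_ct_to_us = load_transform_if_needed([OUT_DIR / "rigid_stage2_inverse"])
        print("Note: using the rigid inverse for Dice in US space (deformable inverse unavailable).")
    except Exception:
        tx_ct_to_us = None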

if tx_ct_to_us is not None:
    ct_mask_in_us = transform_and_resample(
        moving=ct_mask_fine,  # CT mask in CT space
        fixed=us_mask_fine,   # target geometry = US fine (native)
        transform=tx_ct_to_us,
        interp=sitk.sitkNearestNeighbor,
        default_value=0
    )
    ct_mask_in_us = sitk.Cast(ct_mask_in_us > 0, sitk.sitkUInt8)
    dice_us_space = dice_coefficient(ct_mask_in_us, us_mask_fine)
    print(f"Dice (US space): (CT_mask -> US) vs US_mask = {dice_us_space:.4f}")
    write_image(ct_mask_in_us, OUT_DIR / "ct_mask_in_us.tif")
else:
    print("Note: Inverse transform not available (likely due to B-spline). Skipping Dice in US space.")

# 4) Optional: quick overlay QC (comment out if not needed)
# show_overlay_slices(ct_fine_n, sitk.Cast(us_mask_in_ct*255, sitk.sitkFloat32), title="US mask -> CT (overlay)", num_slices=6, axis=0)
# if tx_ct_to_us is not None:
#     show_overlay_slices(us_fine_n, sitk.Cast(ct_mask_in_us*255, sitk.sitkFloat32), title="CT mask -> US (overlay)", num_slices=6, axis=0)


Dice (CT space): CT_mask vs (US_mask -> CT) = 0.8606
Dice (US space): (CT_mask -> US) vs US_mask = 0.0829
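`dice_coefficient()` is defined earlier in the notebook and not shown here. For reference, the Dice score of two binary masks A and B on the same grid is 2|A ∩ B| / (|A| + |B|). The NumPy sketch below (the name `dice_np` is hypothetical) shows that a mask covering only part of the other one already lowers Dice, here to about 0.67 for half coverage. Each score above is computed on a different grid, so some difference is expected, but a gap as large as 0.86 vs 0.08 is worth checking against the transform direction used in each resampling.

```python
import numpy as np


def dice_np(a, b):
    """Dice overlap of two binary masks on the same voxel grid."""
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    denom = a.sum() + b.sum()
    if denom == 0:
        return 1.0  # both masks empty: treated as perfect agreement (a convention)
    return 2.0 * np.logical_and(a, b).sum() / denom


# Toy example: a mask that fills the grid vs. one covering only half of it
full = np.ones((10, 10, 10), dtype=bool)
half = np.zeros_like(full)
half[:5] = True
print(dice_np(full, full))  # 1.0
print(dice_np(full, half))  # 2*500 / (1000 + 500), about 0.667
```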


## Tuning Notes & Next Steps

- If the rigid result is **visually off**:
  - Increase smoothing at coarse levels (`PYR_SIGMAS`), reduce `MI_BINS` (e.g., 32), or increase `SAMPLING_PERCENT`.
  - Verify masks (save and inspect `ct_mask_coarse.tif`, `us_mask_coarse.tif`, `ct_mask_fine.tif`, `us_mask_fine.tif` in `OUT_DIR`). Background leaking into the masks will hurt NMI.
  - Add a coarser shrink level (e.g., `[16,8,4,2,1]`) if initial offset is large.
- If the optimizer **stalls or oscillates**:
  - Try a lower `LEARNING_RATE` (1.0) or increase `MIN_STEP` slightly (2e-3).
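  - If physical-unit smoothing behaves unexpectedly on the strongly anisotropic native US grid (1.0 × 1.0 × 0.016 mm), try specifying `PYR_SIGMAS` in voxels instead (`SmoothingSigmasAreSpecifiedInPhysicalUnitsOff()`, or `use_physical_sigmas=False` for the B-spline step).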
- If **speckle dominates** in US:
  - Increase `US_SMOOTH_SIGMA` or try anisotropic diffusion (SimpleITK `CurvatureAnisotropicDiffusion`).
  - Blend a gradient-based term (requires a custom metric; not included here for simplicity).
- If probe pressure causes **local distortions**:
  - Enable `USE_BSPLINE=True` with a large grid spacing (40–60 mm). The coarse control-point grid is what keeps the deformation smooth; LBFGSB is simply the optimizer used to fit it.
- Always **save** transforms and resampled outputs in `OUT_DIR` for reproducibility and downstream analysis.
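Before adding or keeping a shrink level, check how many voxels each axis retains at each level. The sketch approximates a level's size as `max(size // factor, 1)`; this is an assumption about how the pyramid shrinks images, not a documented SimpleITK formula. `pyramid_sizes`, `thin_levels` and the `min_voxels` threshold are hypothetical. On the coarse CT printed earlier (51 × 103 × 6), factors of 8 and 4 already reduce z to a single voxel.

```python
def pyramid_sizes(size_vox, shrink_factors):
    """Approximate image size at each pyramid level (floor division, at least one voxel)."""
    return [tuple(max(n // f, 1) for n in size_vox) for f in shrink_factors]


def thin_levels(size_vox, shrink_factors, min_voxels=4):
    """(shrink factor, size) for levels where any axis keeps fewer than min_voxels voxels."""
    sizes = pyramid_sizes(size_vox, shrink_factors)
    return [(f, s) for f, s in zip(shrink_factors, sizes) if min(s) < min_voxels]


ct_coarse_size = (51, 103, 6)
print(pyramid_sizes(ct_coarse_size, [8, 4, 2, 1]))
# [(6, 12, 1), (12, 25, 1), (25, 51, 3), (51, 103, 6)]
print(thin_levels(ct_coarse_size, [8, 4, 2, 1]))
# [(8, (6, 12, 1)), (4, (12, 25, 1)), (2, (25, 51, 3))]
```

For thin CFRP plates, running this check before adding a coarser level such as `[16, 8, 4, 2, 1]` shows whether the extra level would carry any information along the thickness direction.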