In [1]:
# !pip install SimpleITK opencv-python -q

In [2]:
import math
import re
from pathlib import Path
from typing import List, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import SimpleITK as sitk

# Root of the project data; relative, so it resolves against the notebook's working directory.
BASE_DIR = Path("./project_data")
# Directory of reference atlas slices; passed as `atlas_dir` to find_most_similar_slices.
RAW_ATLAS_DIR = BASE_DIR / "raw_atlas_slices"
# Synthetic test set: test images are read from its "images" subfolder, and the
# ground-truth table and match_results.csv live directly inside it.
SYNTHETIC_TEST_DIR = BASE_DIR / "test_synthetic"
# Ground-truth table; the batch-evaluation cell reads its "synthetic_file" and
# "source_file" columns. Loaded eagerly, so this cell fails if the CSV is missing.
gt_df = pd.read_csv(SYNTHETIC_TEST_DIR / "ground_truth.csv")




### 🔍 Function for Finding the Most Similar Brain Slice (Mutual Information Registration)


In [3]:
def find_most_similar_slices(query_slice, atlas_dir, seed=None):
    """Find the atlas slice most similar to ``query_slice`` using mutual information.

    Every image found recursively under ``atlas_dir`` is used as the fixed image.
    The query is registered onto it with a 2-D similarity transform driven by
    Mattes mutual information, the query is resampled into the atlas frame, and
    the aligned pair is scored with normalized mutual information (NMI) over the
    pixels that are non-zero in both images. The candidate with the highest NMI
    is returned.

    Parameters
    ----------
    query_slice : np.ndarray
        Grayscale image to look up (the registration uses a 2-D transform, so a
        2-D array is expected). Non-uint8 input is cast with
        ``astype(np.uint8)`` without rescaling, so values should already lie
        in 0..255.
    atlas_dir : str or pathlib.Path
        Directory searched recursively for .png/.jpg/.jpeg/.tif/.tiff/.bmp files.
    seed : int, optional
        Seed for the random perturbation of the initial rotation. With the
        default ``None`` the perturbation is drawn from NumPy's global random
        state (so ``np.random.seed`` still controls it).

    Returns
    -------
    dict
        ``{"nmi", "path", "fixed", "moved"}`` describing the best candidate:
        its score, its file path, the downsampled atlas image and the query
        resampled onto it. An empty dict is returned when no candidate could be
        read and registered, so callers must check before indexing.

    Raises
    ------
    RuntimeError
        If no file with a supported image extension exists under ``atlas_dir``.
    """

    # Adjustable parameters
    DOWNSAMPLE_FACTOR = 2       # Downsample to accelerate (>1 means smaller)
    HIST_BINS = 64              # Number of bins for the mutual information histogram
    SAMPLE_PERCENT = 0.2        # Sampling ratio for mutual information (used during registration)
    MAX_ITER = 150              # Maximum number of iterations per slice
    ROT_RANGE_DEG = 20          # Half-width (degrees) of the random initial-rotation perturbation
    MIN_OVERLAP_PIXELS = 50     # Minimum joint-foreground pixels required to trust an NMI score
    IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp")

    # Discover candidates first so that a bad directory fails fast.
    # Sorting makes the iteration order (and therefore tie-breaking between
    # equal scores) independent of the filesystem's enumeration order.
    atlas_dir = Path(atlas_dir)
    atlas_paths: List[Path] = sorted(
        p for p in atlas_dir.glob("**/*") if p.suffix.lower() in IMAGE_EXTS
    )
    if not atlas_paths:
        raise RuntimeError(f"No image files found in {atlas_dir}.")

    # Source of randomness for the initial-rotation perturbation.
    rng = np.random if seed is None else np.random.default_rng(seed)

    # Ensure query_slice is uint8 (plain cast, no rescaling)
    query_np = query_slice.copy()
    if query_np.dtype != np.uint8:
        query_np = query_np.astype(np.uint8)

    # Shrink an image by an integer factor; returns the input unchanged for down <= 1
    def _downsample(img_np, down, interpolation):
        if down > 1:
            return cv2.resize(
                img_np,
                (img_np.shape[1] // down, img_np.shape[0] // down),
                interpolation=interpolation
            )
        return img_np

    # Convert numpy grayscale image to SimpleITK image and optionally downsample
    def to_sitk_gray(img_np, down=1):
        small = _downsample(img_np, down, cv2.INTER_AREA)
        return sitk.GetImageFromArray(small.astype(np.float32))

    # Generate binary mask (ignore pure black background)
    def make_mask(img_np, down=1, thresh=0, min_pixels=50):
        # Nearest-neighbour keeps the mask aligned with hard foreground edges;
        # ">thresh" with thresh=0 preserves weak gray values as foreground.
        small = _downsample(img_np, down, cv2.INTER_NEAREST)
        mask = (small > thresh).astype(np.uint8)
        if mask.sum() < min_pixels:
            return None  # Too few valid pixels: caller registers without a mask
        return sitk.GetImageFromArray(mask)

    # Compute Normalized Mutual Information (NMI)
    def normalized_mutual_information(x, y, mask=None, bins=64, min_pixels=1):
        x = x.astype(np.float32).ravel()
        y = y.astype(np.float32).ravel()
        if mask is not None:
            m = mask.astype(bool).ravel()
            # With only a handful of samples the joint histogram is so sparse
            # that the ratio below drifts towards its maximum of 2 regardless of
            # the real similarity, so undersized overlaps are rejected outright.
            if m.sum() < min_pixels:
                return -np.inf
            x = x[m]
            y = y[m]

        # Adaptive intensity range: clip both images to their 1st-99th percentiles
        x_min, x_max = np.percentile(x, [1, 99])
        y_min, y_max = np.percentile(y, [1, 99])
        x = np.clip(x, x_min, x_max)
        y = np.clip(y, y_min, y_max)

        # Joint histogram -> joint and marginal probabilities
        H, _, _ = np.histogram2d(x, y, bins=bins)
        Pxy = H / np.maximum(H.sum(), 1.0)
        Px = Pxy.sum(axis=1, keepdims=True)
        Py = Pxy.sum(axis=0, keepdims=True)

        # Avoid log(0)
        eps = 1e-12
        Hx = -np.sum(Px * np.log(Px + eps))
        Hy = -np.sum(Py * np.log(Py + eps))
        Hxy = -np.sum(Pxy * np.log(Pxy + eps))

        # Studholme et al. overlap-invariant NMI, (H(X) + H(Y)) / H(X, Y);
        # higher = more similar
        return (Hx + Hy) / max(Hxy, eps)

    # Perform registration and compute NMI score.
    # The return annotation is a string so it is not evaluated at definition time.
    def register_and_score(fixed_np, moving_np, hist_bins=HIST_BINS,
                           sample_percent=SAMPLE_PERCENT, max_iter=MAX_ITER,
                           down=DOWNSAMPLE_FACTOR) -> "Tuple[float, sitk.Transform, np.ndarray, np.ndarray]":

        fixed_img = to_sitk_gray(fixed_np, down=down)
        moving_img = to_sitk_gray(moving_np, down=down)

        fixed_mask = make_mask(fixed_np, down=down, thresh=0)
        moving_mask = make_mask(moving_np, down=down, thresh=0)

        def _run(mask_fixed, mask_moving, sampling='RANDOM'):
            reg = sitk.ImageRegistrationMethod()
            reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=hist_bins)

            if sampling == 'RANDOM':
                reg.SetMetricSamplingStrategy(reg.RANDOM)
                reg.SetMetricSamplingPercentage(sample_percent)
            else:
                reg.SetMetricSamplingStrategy(reg.NONE)

            if mask_fixed is not None:
                reg.SetMetricFixedMask(mask_fixed)
            if mask_moving is not None:
                reg.SetMetricMovingMask(mask_moving)

            # Three-level coarse-to-fine pyramid
            reg.SetInterpolator(sitk.sitkLinear)
            reg.SetShrinkFactorsPerLevel([4, 2, 1])
            reg.SetSmoothingSigmasPerLevel([2, 1, 0])
            reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

            # Initialize transform (similarity), centred on the image geometry
            initial_tx = sitk.CenteredTransformInitializer(
                fixed_img,
                moving_img,
                sitk.Similarity2DTransform(),
                sitk.CenteredTransformInitializerFilter.GEOMETRY
            )

            # Randomly perturb the starting rotation within +/- ROT_RANGE_DEG
            if ROT_RANGE_DEG > 0:
                angle_rad = np.deg2rad(rng.uniform(-ROT_RANGE_DEG, ROT_RANGE_DEG))
                initial_tx.SetAngle(initial_tx.GetAngle() + float(angle_rad))

            reg.SetOptimizerAsRegularStepGradientDescent(
                learningRate=2.0,
                minStep=1e-3,
                numberOfIterations=max_iter,
                relaxationFactor=0.5,
                gradientMagnitudeTolerance=1e-6
            )
            reg.SetOptimizerScalesFromPhysicalShift()
            reg.SetInitialTransform(initial_tx, inPlace=False)
            final_tx = reg.Execute(fixed_img, moving_img)
            moved = sitk.Resample(moving_img, fixed_img, final_tx, sitk.sitkLinear, 0.0, sitk.sitkFloat32)
            return final_tx, moved

        # Best-effort cascade; the error of the last attempt propagates if all fail:
        #   1) masks + random sampling, 2) no masks, 3) no masks + full-image sampling
        attempts = (
            (fixed_mask, moving_mask, 'RANDOM'),
            (None, None, 'RANDOM'),
            (None, None, 'NONE'),
        )
        last_error = None
        for mask_f, mask_m, sampling in attempts:
            try:
                final_tx, moved = _run(mask_f, mask_m, sampling=sampling)
                break
            except Exception as err:
                last_error = err
        else:
            raise last_error

        moved_np = sitk.GetArrayFromImage(moved)
        fixed_np_ds = sitk.GetArrayFromImage(fixed_img)

        # Score only where both images have foreground after alignment
        joint_mask = ((moved_np > 0) & (fixed_np_ds > 0)).astype(np.uint8)
        nmi = normalized_mutual_information(
            fixed_np_ds, moved_np, mask=joint_mask, bins=hist_bins,
            min_pixels=MIN_OVERLAP_PIXELS
        )
        return nmi, final_tx, moved_np, fixed_np_ds

    best_result = {}

    print(f"Number of candidate slices: {len(atlas_paths)}")
    for i, atlas_path in enumerate(atlas_paths, 1):
        # Read candidate atlas slice; unreadable files are silently ignored
        atlas_np = cv2.imread(str(atlas_path), cv2.IMREAD_GRAYSCALE)
        if atlas_np is None:
            continue

        # Resize to roughly match the query slice size (improves convergence);
        # only done when the sizes differ by more than a factor of two
        scale = min(query_np.shape[0] / atlas_np.shape[0], query_np.shape[1] / atlas_np.shape[1])
        if scale < 0.5 or scale > 2.0:
            new_wh = (
                max(8, int(atlas_np.shape[1] * scale)),
                max(8, int(atlas_np.shape[0] * scale))
            )
            atlas_resized = cv2.resize(
                atlas_np, new_wh,
                interpolation=cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
            )
        else:
            atlas_resized = atlas_np

        try:
            nmi, _, moved_np, fixed_np_ds = register_and_score(atlas_resized, query_np)
        except Exception as e:
            # A candidate that cannot be registered is reported and skipped
            print(f"[{i}/{len(atlas_paths)}] Skipped {atlas_path.name}: {e}")
            continue

        if len(best_result) == 0 or nmi > best_result["nmi"]:
            best_result = {
                "nmi": nmi,
                "path": atlas_path,
                "fixed": fixed_np_ds,
                "moved": moved_np
            }

    # Return best matched slice and its visualization info ({} if none succeeded)
    return best_result


### 🧩 Batch Mutual Information Matching and Slice-Matching Accuracy Evaluation (Mutual Information Retrieval)


In [4]:
# --- 3. Batch execution of Mutual Information matching ---
# For every row of the ground-truth table: load the synthetic test image, retrieve
# its best-matching atlas slice, and count the prediction as correct when it comes
# from the same slice series as the ground truth and its slice index is within
# MAX_INDEX_ERROR of the true index. Per-sample results go to match_results.csv.

RANDOM_SEED = 42            # Seeds NumPy's global RNG; registration draws a random initial rotation
MAX_INDEX_ERROR = 4         # Largest |true index - predicted index| still counted as correct
DARK_PIXEL_LEVEL = 15       # Gray values below this are treated as background
MAX_DARK_FRACTION = 0.90    # Images with more background than this are skipped
RESULT_COLUMNS = ["synthetic_file", "best_match", "gt_source_file", "correct", "nmi"]


def _parse_slice_name(filename):
    """Split e.g. 'coronal_slice_012.png' into ('coronal_slice_', 12).

    The index is the first run of digits in the file stem and the prefix is
    everything before it. Returns None when the stem contains no digits.
    """
    stem = Path(filename).stem
    digits = re.search(r'\d+', stem)
    if digits is None:
        return None
    return stem[:digits.start()], int(digits.group())


np.random.seed(RANDOM_SEED)

results = []
num_correct, num_total, num_skipped = 0, 0, 0

# Enumerate rows explicitly so the progress counter does not depend on the
# DataFrame index being a 0-based range.
for sample_no, row in enumerate(gt_df.to_dict("records"), start=1):

    print(f"\n=== Processing test image {sample_no}/{len(gt_df)} ===")

    # Synthetic test image and the filename of the atlas slice it was generated from
    synthetic_path = SYNTHETIC_TEST_DIR / "images" / row["synthetic_file"]
    gt_name = Path(row["source_file"]).name        # Use only filename for reliable comparison

    synthetic_img = cv2.imread(str(synthetic_path), cv2.IMREAD_GRAYSCALE)
    if synthetic_img is None:
        print(f"⚠️ Cannot read {synthetic_path.name}, skipping.")
        num_skipped += 1
        continue
    if np.mean(synthetic_img < DARK_PIXEL_LEVEL) > MAX_DARK_FRACTION:
        print(f"⚠️ {synthetic_path.name} is almost entirely background, skipping.")
        num_skipped += 1
        continue

    # Run matching function; an empty dict means no atlas slice could be registered
    best_vis = find_most_similar_slices(synthetic_img, RAW_ATLAS_DIR)
    if not best_vis:
        print(f"⚠️ No atlas slice could be registered to {synthetic_path.name}, skipping.")
        num_skipped += 1
        continue

    pred_name = best_vis["path"].name
    gt_parsed = _parse_slice_name(gt_name)
    pred_parsed = _parse_slice_name(pred_name)
    if gt_parsed is None or pred_parsed is None:
        print(f"⚠️ No slice index in '{gt_name}' or '{pred_name}', skipping.")
        num_skipped += 1
        continue
    gt_prefix, gt_index = gt_parsed
    pred_prefix, pred_index = pred_parsed

    # The atlas mixes several slice series (coronal_/horizontal_/sagittal_) whose
    # indices overlap, so a match must agree on the series as well as the index.
    # NOTE(review): assumes ground-truth and atlas filenames share one naming
    # scheme - confirm against ground_truth.csv.
    same_series = gt_prefix.lower() == pred_prefix.lower()
    is_correct = same_series and abs(gt_index - pred_index) <= MAX_INDEX_ERROR

    # Statistics
    num_total += 1
    if is_correct:
        num_correct += 1

    # Record results
    results.append({
        "synthetic_file": row["synthetic_file"],
        "best_match": pred_name,
        "gt_source_file": gt_name,
        "correct": bool(is_correct),
        "nmi": float(best_vis["nmi"]),
    })

    print(f"Prediction: {pred_name} | Ground Truth: {gt_name}")
    print(f"current Accuracy: {num_correct}/{num_total} = {num_correct / num_total:.3f}")

# Print final accuracy
if num_total > 0:
    print(f"\nFinal Accuracy: {num_correct}/{num_total} = {num_correct / num_total:.3f}")
else:
    print("\nNo valid samples were processed.")

# --- Save results ---
# Explicit columns keep the CSV header stable even when nothing was evaluated.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS)
results_df.to_csv(SYNTHETIC_TEST_DIR / "match_results.csv", index=False)
print(f"\n✅ Finished: {num_total} evaluated, {num_skipped} skipped. Results saved to match_results.csv.")

# Legacy names of the counters, kept in case later cells still read them.
x, y = num_correct, num_total



=== Processing test image 1/720 ===
Number of candidate slices: 218
[1/218] Skipped coronal_slice_000.png: Exception thrown in SimpleITK ImageRegistrationMethod_Execute: D:\a\SimpleITK\SimpleITK\bld\ITK-prefix\include\ITK-5.4\itkImageBase.hxx:79:
ITK ERROR: Image(0000019040F2AF40): Zero-valued spacing is not supported and may result in undefined behavior.
Refusing to change spacing from [1, 1] to [0, 3.83338]
[133/218] Skipped horizontal_slice_000.png: Exception thrown in SimpleITK ImageRegistrationMethod_Execute: D:\a\SimpleITK\SimpleITK\bld\ITK-prefix\include\ITK-5.4\itkImageBase.hxx:79:
ITK ERROR: Image(00000190123CC240): Zero-valued spacing is not supported and may result in undefined behavior.
Refusing to change spacing from [1, 1] to [0, 3.83338]
[173/218] Skipped sagittal_slice_000.png: Exception thrown in SimpleITK ImageRegistrationMethod_Execute: D:\a\SimpleITK\SimpleITK\bld\ITK-prefix\include\ITK-5.4\itkImageBase.hxx:79:
ITK ERROR: Image(00000190414A5B30): Zero-valued spacin