# QC: Registration Module Validation

This notebook validates the `starfinder.registration` module for phase-correlation-based image registration.

In [None]:
import sys
# Make the in-repo package importable without installing it.
# NOTE(review): this path is resolved against the kernel's working directory, which is
# assumed to be this notebook's directory — adjust if the kernel is started elsewhere.
sys.path.insert(0, "../src/python")

import json
import numpy as np
from pathlib import Path

# Registration API under test (NumPy implementation plus the scikit-image variant).
from starfinder.registration import (
    phase_correlate,
    apply_shift,
    register_volume,
    phase_correlate_skimage,
)
# Timing/accuracy harness used in Section 4.
from starfinder.registration.benchmark import benchmark_registration, print_benchmark_table
# Synthetic spot volumes and on-disk image loading used by the checks below.
from starfinder.testdata import create_test_volume
from starfinder.io import load_image_stacks

print("Imports successful")

## 1. Test Known Shift Recovery

Create a test volume, apply a known shift, and verify the recovered shift matches.

In [None]:
# Reference volume: 30 synthetic spots on a flat background, seeded for reproducibility.
fixed = create_test_volume(
    shape=(10, 128, 128),
    n_spots=30,
    spot_intensity=200,
    background=20,
    seed=42,
)

for label, value in (
    ("shape", fixed.shape),
    ("dtype", fixed.dtype),
    ("range", f"[{fixed.min()}, {fixed.max()}]"),
):
    print(f"Fixed volume {label}: {value}")

# Misalign by a known integer offset; np.roll wraps voxels around the borders.
known_shift = (2, -5, 8)  # (dz, dy, dx)
moving = np.roll(fixed, shift=known_shift, axis=(0, 1, 2))
print(f"\nApplied known shift: {known_shift}")

# Estimate the offset with phase correlation and compare it with the ground truth.
recovered_shift = phase_correlate(fixed, moving)
print(f"Recovered shift: {recovered_shift}")

deltas = [got - want for got, want in zip(recovered_shift, known_shift)]
shift_error = np.sqrt(sum(delta ** 2 for delta in deltas))
print(f"\nShift error (L2): {shift_error:.4f}")

assert shift_error < 0.5, f"Shift error too large: {shift_error}"
print("PASS: Known shift recovered correctly")

## 2. Test apply_shift() Edge Zeroing

Apply a shift to a constant volume and verify that edges are zeroed out (not wrapped).

In [None]:
# A constant volume makes the check unambiguous: after a zero-filling shift every voxel
# is either a copied 100 or a vacated 0, so any wrapped-around voxel would show up as a
# non-zero value inside an edge slab that should be empty.
constant_volume = np.ones((10, 64, 64), dtype=np.float32) * 100


def _vacated_slice(s):
    """Slice along one axis that a zero-filling integer shift by ``s`` must leave empty.

    A positive shift vacates the first ``s`` entries (leading edge), a negative shift
    vacates the last ``|s|`` entries (trailing edge), and a zero shift vacates nothing.
    """
    s = int(s)
    if s > 0:
        return slice(None, s)
    if s < 0:
        return slice(s, None)
    return slice(0, 0)


def _retained_slice(s):
    """Slice along one axis that lies inside the copied region, with a 1-voxel safety margin."""
    s = int(s)
    if s > 0:
        return slice(s + 1, None)
    if s < 0:
        return slice(None, s - 1)
    return slice(None)


def _check_edges_zeroed(shifted, shift):
    """Assert that ``shifted`` zeroed exactly the edge slabs vacated by integer ``shift``.

    Every axis is checked over the full vacated width (``|shift|`` voxels), so that
    positive and negative shifts are verified symmetrically. The retained interior must
    still carry the original signal.

    Raises:
        AssertionError: if an edge slab is not zero or the interior has lost its signal.
    """
    for axis, (axis_name, s) in enumerate(zip("ZYX", shift)):
        if int(s) == 0:
            continue
        edge = "leading" if s > 0 else "trailing"
        index = [slice(None)] * shifted.ndim
        index[axis] = _vacated_slice(s)
        assert np.allclose(shifted[tuple(index)], 0), f"{axis_name} {edge} edge should be zero"

    interior = shifted[tuple(_retained_slice(s) for s in shift)]
    assert interior.mean() > 50, "Interior should be non-zero"


# Positive shift: content moves forward, so the leading edges must be zero.
shift_positive = (2, 5, 3)
shifted_pos = apply_shift(constant_volume, shift_positive)

print(f"Shift: {shift_positive}")
print(f"Result shape: {shifted_pos.shape}")

_check_edges_zeroed(shifted_pos, shift_positive)
print("PASS: Positive shift - leading edges zeroed correctly")

# Negative shift: content moves backward, so the trailing |shift| voxels must be zero.
shift_negative = (-3, -4, -6)
shifted_neg = apply_shift(constant_volume, shift_negative)

print(f"\nShift: {shift_negative}")

_check_edges_zeroed(shifted_neg, shift_negative)
print("PASS: Negative shift - trailing edges zeroed correctly")

## 3. Multi-Channel Registration

Test `register_volume()` with (Z, Y, X, C) input.

In [None]:
# Build a (Z, Y, X, C) volume: one shared spot pattern plus independent per-channel
# noise, so every channel carries the same structure that will be misaligned.
n_channels = 4
shape_3d = (10, 128, 128)

base = create_test_volume(
    shape=shape_3d,
    n_spots=25,
    spot_intensity=200,
    background=20,
    seed=123,
)

rng = np.random.default_rng(42)


def _add_channel_noise(volume, generator, high=20):
    """Return ``volume`` plus uniform integer noise drawn from ``[0, high)``.

    The noise is drawn exactly as ``generator.integers(0, high, shape, dtype=np.uint8)``.
    Integer volumes are summed in int64 and clipped to the dtype's range before being
    cast back, so bright voxels saturate instead of wrapping around. Plain
    ``uint8 + uint8`` arithmetic would turn 250 + 10 into 4.
    """
    noise = generator.integers(0, high, volume.shape, dtype=np.uint8)
    if np.issubdtype(volume.dtype, np.integer):
        limits = np.iinfo(volume.dtype)
        widened = volume.astype(np.int64) + noise
        return np.clip(widened, limits.min, limits.max).astype(volume.dtype)
    return volume + noise


multi_channel = np.stack(
    [_add_channel_noise(base, rng) for _ in range(n_channels)],
    axis=-1,
)

print(f"Multi-channel volume shape: {multi_channel.shape}")
assert multi_channel.shape == (*shape_3d, n_channels)

# The shift is estimated from the noise-free base volume, not from a noisy channel.
ref_image = base
known_shift = (-3, 4, -2)
mov_image = np.roll(base, known_shift, axis=(0, 1, 2))

# Apply the same spatial shift to every channel to simulate uniformly misaligned data.
# Rolling only axes 0-2 of the 4-D array leaves the channel axis untouched.
shifted_multi = np.roll(multi_channel, known_shift, axis=(0, 1, 2))

print(f"Shifted multi-channel shape: {shifted_multi.shape}")

# Register all channels using the shift detected between ref_image and mov_image.
registered, detected_shift = register_volume(shifted_multi, ref_image, mov_image)

print(f"\nKnown shift: {known_shift}")
print(f"Detected shift: {detected_shift}")
print(f"Registered shape: {registered.shape}")

shift_error = np.sqrt(sum((d - k) ** 2 for d, k in zip(detected_shift, known_shift)))
print(f"Shift error: {shift_error:.4f}")

assert shift_error < 0.5, f"Shift error too large: {shift_error}"
assert registered.shape == multi_channel.shape, "Output shape should match input"

print("\nPASS: Multi-channel registration works correctly")

## 4. Backend Comparison

Compare NumPy vs scikit-image backends using `benchmark_registration()`.

In [None]:
# Maximum acceptable L2 shift error (voxels) for each backend.
SHIFT_ERROR_TOL = 0.5

# Benchmark both backends on tiny -> medium volumes. Each result exposes its backend
# name (`method`) and its shift-recovery error (`metrics["shift_error"]`).
results = benchmark_registration(
    sizes=[
        (5, 64, 64),      # tiny
        (10, 128, 128),   # small
        (20, 256, 256),   # medium
    ],
    methods=["numpy", "skimage"],
    n_runs=3,
    seed=42,
)

print_benchmark_table(results)

# Without this guard, an empty result list would let the check below pass vacuously.
assert results, "benchmark_registration returned no results"

# Report every offending result at once. `not (x < tol)` also flags NaN errors.
too_large = [
    f"{r.method} has large shift error: {r.metrics['shift_error']}"
    for r in results
    if not (r.metrics["shift_error"] < SHIFT_ERROR_TOL)
]
assert not too_large, "; ".join(too_large)

print("PASS: Both backends recover shifts accurately")

## 5. napari Before/After Overlay

Visualize the registration result in napari. The fixed volume (green) and the registered volume (red) are blended additively, so well-aligned spots appear yellow.

In [None]:
# Rebuild the Section 1 fixed/moving pair so this cell runs on its own.
fixed = create_test_volume(
    shape=(10, 128, 128),
    n_spots=30,
    spot_intensity=200,
    background=20,
    seed=42,
)

known_shift = (2, -5, 8)
moving = np.roll(fixed, known_shift, axis=(0, 1, 2))

# The sign of the correction follows from two earlier checks in this notebook:
# - phase_correlate(fixed, moving) reports the offset of `moving` relative to `fixed`
#   (Section 1 asserts it recovers +known_shift).
# - apply_shift moves content by +shift (Section 2 asserts that a positive shift
#   vacates the leading edge).
# Undoing the offset therefore needs the negated shift; applying +detected_shift would
# double the misalignment.
detected_shift = phase_correlate(fixed, moving)
correction = tuple(-s for s in detected_shift)
registered = apply_shift(moving, correction)

print(f"Known shift: {known_shift}")
print(f"Detected shift: {detected_shift}")

# Quantitative alignment check on the interior, away from zero-filled/wrapped edges.
margin = [int(np.ceil(abs(s))) + 1 for s in detected_shift]
core = tuple(slice(m, dim - m) for m, dim in zip(margin, fixed.shape))
alignment = np.corrcoef(
    np.asarray(fixed[core], dtype=np.float64).ravel(),
    np.asarray(registered[core], dtype=np.float64).ravel(),
)[0, 1]
print(f"Interior correlation (fixed vs registered): {alignment:.4f}")
assert alignment > 0.9, f"Registered volume does not overlay fixed volume (corr={alignment:.3f})"

# The whole block is kept in the try: napari.Viewer() can itself raise ImportError
# (e.g. missing Qt bindings), and that case should also degrade gracefully.
try:
    import napari

    viewer = napari.Viewer()
    viewer.add_image(fixed, name="fixed", colormap="green", blending="additive")
    viewer.add_image(registered, name="registered", colormap="red", blending="additive")
    print("napari viewer opened. Yellow = good alignment (green + red)")

except ImportError:
    print("napari not installed. Skipping visualization.")
    # Quoted so the extras brackets are not glob-expanded by shells such as zsh.
    print('To install: pip install "napari[all]"')

## 6. Test with Synthetic Dataset Shifts

Load `ground_truth.json` from the mini dataset, register each round against round 1 (using channel 0), and compare the detected shifts with the expected ones.

In [None]:
# Expected per-round shifts shipped with the synthetic mini dataset.
mini_path = Path("fixtures/synthetic/mini")
fov_path = mini_path / "FOV_001"
with open(mini_path / "ground_truth.json") as f:
    ground_truth = json.load(f)

print("Ground truth shifts:")
expected_shifts = ground_truth["fovs"]["FOV_001"]["shifts"]
for round_name, shift in expected_shifts.items():
    print(f"  {round_name}: {shift}")

reference_round = "round1"
channel_order = ["ch00", "ch01", "ch02", "ch03"]
# Maximum acceptable L2 error (voxels) between detected and expected shift.
max_error_px = 1.0

ref_stack, _ = load_image_stacks(
    fov_path / reference_round,
    channel_order=channel_order,
)

print(f"\nReference stack shape: {ref_stack.shape}")

# Channel 0 of the reference round is the registration target.
ref_ch0 = ref_stack[:, :, :, 0]

print("\nRegistration results:")
print("-" * 60)

failed_rounds = []
# Register every round listed in the ground truth (the reference itself is skipped);
# round names double as directory names under the FOV.
for round_name, expected_raw in expected_shifts.items():
    if round_name == reference_round:
        continue

    mov_stack, _ = load_image_stacks(
        fov_path / round_name,
        channel_order=channel_order,
    )
    mov_ch0 = mov_stack[:, :, :, 0]

    detected = phase_correlate(ref_ch0, mov_ch0)
    expected = tuple(expected_raw)

    error = np.sqrt(sum((d - e) ** 2 for d, e in zip(detected, expected)))
    passed = error < max_error_px  # NaN errors compare False and count as failures
    status = "PASS" if passed else "FAIL"
    if not passed:
        failed_rounds.append(round_name)

    print(f"{round_name}:")
    print(f"  Expected: {expected}")
    # Round rather than truncate: int() would display e.g. -2.9 as -2.
    print(f"  Detected: {tuple(round(float(d), 2) for d in detected)}")
    print(f"  Error: {error:.2f} [{status}]")

print("-" * 60)
print("\nNote: Some error is expected due to synthetic data generation.")

assert not failed_rounds, (
    f"Shift error >= {max_error_px} px for rounds: {failed_rounds}"
)

## Summary

**Validation Checklist:**

- [ ] `phase_correlate()` recovers known shifts accurately
- [ ] `apply_shift()` zeros out edges (no wrap-around artifacts)
- [ ] `register_volume()` handles multi-channel (Z, Y, X, C) input
- [ ] NumPy and scikit-image backends produce consistent results
- [ ] napari overlay shows good alignment (yellow = green + red)
- [ ] Registration works with synthetic dataset ground truth shifts