# 5-Parameter Metalens Inversion Inference

This notebook loads a trained model and runs inference with visualizations:
- **Input Phase Map**: The phase map fed to the model
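- **Predicted Phase Map**: Regenerated from the predicted parameters on the same window as the input (the true `xc`, `yc` are reused so the two maps are directly comparable)
- **Error Map**: Wrapped absolute phase difference between the input and predicted maps (0 to π rad)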

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Make the parent of the current working directory importable so the
# `src` and `data` packages imported below can be resolved.
PROJECT_ROOT = Path(os.getcwd()).parent
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    # Prepend so the project's packages are found first on the import path.
    sys.path.insert(0, _project_root_str)

from src.models.factory import get_model
from src.inversion.forward_model import compute_hyperbolic_phase, wrap_phase, get_2channel_representation
from src.utils.normalization import ParameterNormalizer
from data.loaders.simulation import generate_single_sample

# Pick the compute backend: CUDA when available, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

## 1. Select and Load Model

In [None]:
# Enumerate the experiment directories and flag which ones contain a best-model checkpoint
outputs_dir = PROJECT_ROOT / "outputs_2"
experiments = sorted(entry.name for entry in outputs_dir.iterdir() if entry.is_dir())

print("Available experiments:")
for idx, exp_name in enumerate(experiments):
    has_checkpoint = (outputs_dir / exp_name / "checkpoints" / "best_model.pth").exists()
    marker = "✓" if has_checkpoint else "✗"
    print(f"  {idx}: [{marker}] {exp_name}")

In [None]:
# === SELECT YOUR EXPERIMENT HERE ===
EXPERIMENT_INDEX = 0  # Change this to select a different experiment
# ===================================

experiment_name = experiments[EXPERIMENT_INDEX]
checkpoint_path = outputs_dir / experiment_name / "checkpoints" / "best_model.pth"

# Fail early with an actionable message: the experiment listing marks entries
# without a checkpoint with ✗, and the default index may point at one of them.
if not checkpoint_path.is_file():
    raise FileNotFoundError(
        f"No checkpoint found at {checkpoint_path}. "
        "Set EXPERIMENT_INDEX to an experiment marked ✓ in the listing."
    )

print(f"Loading: {experiment_name}")
# SECURITY: weights_only=False performs a full unpickle, which can execute
# arbitrary code embedded in the file. Only load checkpoints you produced or trust.
# NOTE(review): presumably required because the checkpoint also stores the
# 'config' object next to the weights -- confirm before tightening this.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

# The training configuration is stored alongside the weights in the checkpoint
config = checkpoint["config"]
print(f"Config keys: {list(config.keys())}")

In [None]:
# Build the parameter normalizer from the ranges stored in the checkpoint config;
# any range missing from the config falls back to the default given here.
ranges = dict(
    xc=tuple(config.get('xc_range', [-500.0, 500.0])),
    yc=tuple(config.get('yc_range', [-500.0, 500.0])),
    fov=tuple(config.get('fov_range', [1.0, 20.0])),
    wavelength=tuple(config.get('wavelength_range', [0.4, 0.7])),
    focal_length=tuple(config.get('focal_length_range', [10.0, 100.0])),
)
normalizer = ParameterNormalizer(ranges)
print(f"Parameter ranges: {ranges}")

# Flatten the config for the model factory: entries of a nested 'model' section
# (when present) override top-level keys, and 'type' stands in for a missing 'name'.
model_config = config.copy()
if 'model' in config:
    model_config.update(config['model'])
if 'type' in model_config and 'name' not in model_config:
    model_config['name'] = model_config['type']

# Rebuild the network, restore the trained weights, and put it in eval mode
model = get_model(model_config)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# Settings reused by the inference helpers defined later in the notebook
resolution = config.get('resolution', 256)
window_size = config.get('window_size', 100.0)
output_dim = model_config.get('output_dim', 5)

print(f"Model: {model_config.get('name', 'unknown')} | Output dim: {output_dim} | Resolution: {resolution}")

## 2. Define Inference Helper Functions

In [None]:
def generate_phase_map(xc, yc, fov, wavelength, focal_length, resolution=256, window_size=100.0):
    """
    Build the wrapped phase over a square window centred on (xc, yc).

    The window spans `window_size` along each axis and is sampled on a
    `resolution` x `resolution` grid. `fov` is forwarded to the forward
    model as `theta`.

    Returns: (phase_np, input_2ch_np) -- the wrapped phase and the output of
    `get_2channel_representation` applied to it.
    """
    half_span = window_size / 2
    axis_x = np.linspace(xc - half_span, xc + half_span, resolution, dtype=np.float32)
    axis_y = np.linspace(yc - half_span, yc + half_span, resolution, dtype=np.float32)
    grid_x, grid_y = np.meshgrid(axis_x, axis_y)

    # Unwrapped hyperbolic phase first, then wrap it
    wrapped = wrap_phase(
        compute_hyperbolic_phase(grid_x, grid_y, focal_length, wavelength, theta=fov)
    )
    return wrapped, get_2channel_representation(wrapped)

def reconstruct_phase_from_2ch(input_2ch):
    """
    Recover the phase angle from a 2-channel (cos, sin) array.

    input_2ch: channels-last (H, W, 2) or channels-first (2, H, W).
    The layout is treated as channels-last whenever the final axis has
    length 2; otherwise it is treated as channels-first.

    Returns: array of arctan2(sin, cos) values.
    """
    channels_last = input_2ch.shape[-1] == 2
    if channels_last:
        cos_phi, sin_phi = input_2ch[:, :, 0], input_2ch[:, :, 1]
    else:
        cos_phi, sin_phi = input_2ch[0], input_2ch[1]
    return np.arctan2(sin_phi, cos_phi)

def run_inference(xc, yc, fov, wavelength, focal_length):
    """
    Run the full inference pipeline for one set of ground-truth parameters.

    A phase map is generated from the given parameters, passed through the
    notebook-level `model`, and the denormalized prediction is used to
    regenerate a phase map on the same window as the input (the true
    xc, yc are reused for the window position).

    Relies on the notebook globals `model`, `normalizer`, `device`,
    `resolution`, `window_size` and `output_dim`.

    Returns: (input_phase, pred_phase, pred_params, true_params)
    """
    true_params = np.array([xc, yc, fov, wavelength, focal_length], dtype=np.float32)

    # Generate input
    input_phase, input_2ch = generate_phase_map(
        xc, yc, fov, wavelength, focal_length,
        resolution=resolution, window_size=window_size
    )

    # Match the input dtype to the model weights. The forward model's output
    # dtype is not guaranteed here; a float64 array would mismatch float32
    # weights and cannot be moved to MPS. Falls back to float32 if the model
    # exposes no parameters.
    first_param = next(model.parameters(), None)
    target_dtype = first_param.dtype if first_param is not None else torch.float32

    # Prepare tensor: (H, W, 2) -> (1, 2, H, W); cast on CPU before the device transfer
    input_tensor = (
        torch.from_numpy(input_2ch)
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(target_dtype)
        .to(device)
    )

    # Inference
    with torch.no_grad():
        pred_norm = model(input_tensor)
        pred_params = normalizer.denormalize_tensor(pred_norm).cpu().numpy()[0]

    # If model outputs 3 params, fill in fixed wavelength/focal_length
    if output_dim == 3:
        pred_params = np.array([pred_params[0], pred_params[1], pred_params[2], wavelength, focal_length])

    # Generate predicted phase map
    p_xc, p_yc, p_fov, p_wl, p_fl = pred_params
    pred_phase, _ = generate_phase_map(
        xc, yc, p_fov, p_wl, p_fl,  # Use original xc, yc for same window
        resolution=resolution, window_size=window_size
    )

    return input_phase, pred_phase, pred_params, true_params

print("Inference functions defined.")

## 3. Single Sample Inference with Visualization

In [None]:
# === DEFINE GROUND TRUTH PARAMETERS ===
# These five values are passed to run_inference() to generate the input phase map.
# NOTE(review): values outside `ranges` are presumably outside what the model
# was trained on -- confirm before interpreting results for such inputs.
XC_GT = 100.0         # X centre of the sampled window (micrometers)
YC_GT = -50.0         # Y centre of the sampled window (micrometers)
FOV_GT = 10.0         # Field of View / Incident Angle (degrees); passed to the forward model as theta
WAVELENGTH_GT = 0.55  # Wavelength (micrometers)
FOCAL_LENGTH_GT = 50.0  # Focal Length (micrometers)
# ========================================

In [None]:
# Run the model on the ground-truth parameters
input_phase, pred_phase, pred_params, true_params = run_inference(
    XC_GT, YC_GT, FOV_GT, WAVELENGTH_GT, FOCAL_LENGTH_GT
)

# Per-pixel phase error, taking the shorter way around the 2*pi circle
abs_diff = np.abs(input_phase - pred_phase)
error_map = np.minimum(abs_diff, 2*np.pi - abs_diff)

# Tabulate true vs. predicted value for every parameter
param_names = ['xc', 'yc', 'fov', 'wavelength', 'focal_length']
print("Parameter Comparison:")
print(f"{'Param':<15} {'True':>10} {'Predicted':>10} {'Error':>10}")
print("-" * 45)
for name, true_val, pred_val in zip(param_names, true_params, pred_params):
    err = abs(pred_val - true_val)
    print(f"{name:<15} {true_val:>10.4f} {pred_val:>10.4f} {err:>10.4f}")

In [None]:
# 3-Panel Visualization: input phase, predicted phase, and wrapped phase error
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Physical extent of the sampled window. generate_phase_map() samples
# [xc - window_size/2, xc + window_size/2] (and likewise in y), so the axes
# must be offset by the window centre rather than centred on the origin.
half_window = window_size / 2
extent = [
    XC_GT - half_window, XC_GT + half_window,
    YC_GT - half_window, YC_GT + half_window,
]
phase_cmap = 'twilight'

# Panel 1: Input Phase
im1 = axes[0].imshow(input_phase, extent=extent, origin='lower', cmap=phase_cmap, vmin=-np.pi, vmax=np.pi)
axes[0].set_title(f"Input Phase\n(fov={FOV_GT}°, λ={WAVELENGTH_GT*1000:.0f}nm, f={FOCAL_LENGTH_GT}μm)")
plt.colorbar(im1, ax=axes[0], label='Phase (rad)')

# Panel 2: Predicted Phase (drawn on the same window as the input)
p_xc, p_yc, p_fov, p_wl, p_fl = pred_params
im2 = axes[1].imshow(pred_phase, extent=extent, origin='lower', cmap=phase_cmap, vmin=-np.pi, vmax=np.pi)
axes[1].set_title(f"Predicted Phase\n(fov={p_fov:.1f}°, λ={p_wl*1000:.0f}nm, f={p_fl:.1f}μm)")
plt.colorbar(im2, ax=axes[1], label='Phase (rad)')

# Panel 3: Error Map (wrapped absolute difference, so it lies in [0, pi])
im3 = axes[2].imshow(error_map, extent=extent, origin='lower', cmap='hot', vmin=0, vmax=np.pi)
axes[2].set_title(f"Error Map\nMean: {np.mean(error_map):.4f} rad")
plt.colorbar(im3, ax=axes[2], label='|Δφ| (rad)')

# All three panels share the same physical axes
for ax in axes:
    ax.set_xlabel("x (μm)")
    ax.set_ylabel("y (μm)")

plt.tight_layout()
plt.show()

## 4. Batch Evaluation with Scatter Plots

In [None]:
# Evaluate the model on randomly drawn parameter sets
NUM_SAMPLES = 50

np.random.seed(42)  # fixed seed so the same samples are drawn on every run
sampled_keys = ('xc', 'yc', 'fov', 'wavelength', 'focal_length')
true_rows, pred_rows = [], []

for sample_idx in range(NUM_SAMPLES):
    # One uniform draw per parameter from `ranges`, in run_inference() argument order
    sample = [np.random.uniform(*ranges[key]) for key in sampled_keys]

    _, _, pred, true = run_inference(*sample)
    true_rows.append(true)
    pred_rows.append(pred)

# Stack into (NUM_SAMPLES, n_params) arrays for the scatter plots
all_true = np.array(true_rows)
all_pred = np.array(pred_rows)
print(f"Evaluated {NUM_SAMPLES} samples.")

In [None]:
# True-vs-predicted scatter plot, one panel per parameter
fig, axes = plt.subplots(1, 5, figsize=(20, 4))
param_names = ['xc', 'yc', 'fov', 'wavelength', 'focal_length']

for col, (ax, name) in enumerate(zip(axes, param_names)):
    true_vals, pred_vals = all_true[:, col], all_pred[:, col]

    # Summary statistics shown in the panel title
    corr = np.corrcoef(true_vals, pred_vals)[0, 1]
    mse = np.mean((true_vals - pred_vals)**2)

    ax.scatter(true_vals, pred_vals, alpha=0.6, s=20)

    # Identity line spanning the joint range of true and predicted values
    lo = min(true_vals.min(), pred_vals.min())
    hi = max(true_vals.max(), pred_vals.max())
    ax.plot([lo, hi], [lo, hi], 'r--', label='Ideal')

    ax.set_xlabel(f'True {name}')
    ax.set_ylabel(f'Predicted {name}')
    ax.set_title(f'{name}\nR={corr:.3f}, MSE={mse:.2e}')
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

plt.suptitle('True vs Predicted Parameter Scatter Plots', fontsize=14)
plt.tight_layout()
plt.show()