# 3-Parameter Model Inference

This notebook demonstrates how to load a trained 3-parameter inversion model (predicting `xc`, `yc`, `fov`) and run inference on synthetic data generated by the physical forward model.

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt

# Make the project root importable so `src.*` resolves.
# Assumes the notebook is run from a direct subdirectory of the project (e.g. notebooks/) — TODO confirm.
# Insert at the front rather than appending: `src` is a generic package name, and an
# unrelated `src` found earlier on sys.path would otherwise shadow this project's package.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.models.hybrid import SpectralResNet
from src.inversion.forward_model import compute_hyperbolic_phase, wrap_phase, get_2channel_representation
from src.utils.normalization import ParameterNormalizer

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 1. Load Model and Configuration

In [None]:
# Path to the checkpoint.
# We default to one of the smoke-test checkpoints, which contains full metadata.
# Update this path to point at your specific model checkpoint.
CHECKPOINT_PATH = "../outputs_2/test_smoke_v2/checkpoints/best_model.pth"

# Fallback normalization ranges for the three predicted parameters. They are used only
# for ranges the checkpoint's config does not provide.
# NOTE(review): these must match the ranges used at training time — confirm.
DEFAULT_RANGES = {
    'xc': (-500.0, 500.0),
    'yc': (-500.0, 500.0),
    'fov': (10.0, 80.0),
}

# Fail loudly. Printing and continuing would only defer the failure to an obscure
# NameError on `model` / `normalizer` in the cells below.
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")

# NOTE(review): on PyTorch >= 2.6, `torch.load` defaults to weights_only=True. If the saved
# config holds non-primitive objects, loading will fail. Pass weights_only=False only for
# checkpoints you trust, because it unpickles arbitrary objects.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
print("Checkpoint loaded successfully.")

# Extract config and state dict. `or {}` also covers a checkpoint saved with config=None.
config = checkpoint.get("config") or {}
state_dict = checkpoint["model_state_dict"]

print("Config keys:", config.keys())

# Build the normalization ranges in canonical parameter order:
# - prefer each `<param>_range` saved in the training config;
# - otherwise fall back to DEFAULT_RANGES for the predicted parameters.
# Filling per parameter (rather than only when *every* range is missing) prevents a
# partially populated config from silently dropping a predicted parameter.
ranges = {}
defaulted = []
for param in ['xc', 'yc', 'fov', 'wavelength', 'focal_length']:
    key = f"{param}_range"
    if key in config:
        ranges[param] = config[key]
    elif param in DEFAULT_RANGES:
        ranges[param] = DEFAULT_RANGES[param]
        defaulted.append(param)

if defaulted:
    print(f"Warning: No ranges found in config for {defaulted}. Using defaults.")

normalizer = ParameterNormalizer(ranges)
print("Normalizer initialized with means:", normalizer.means)

# Initialize the model: a SpectralResNet with a 2-channel (cos, sin) input and 3 outputs.
# NOTE(review): these hyperparameters are hard-coded and must match the checkpoint.
# load_state_dict is strict by default, so a mismatch raises here rather than later.
model = SpectralResNet(in_channels=2, modes=16, output_dim=3)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
print("Model initialized and loaded.")

## 2. Generate Synthetic Input

We use the physical forward model to generate a wrapped phase map for a given set of `(xc, yc, fov)` parameters.  
`focal_length` and `wavelength` are held fixed, because this model predicts only `xc`, `yc` and `fov`.

In [None]:
def generate_prediction_sample(xc, yc, fov, resolution=1024, focal_length=100.0, wavelength=0.532, window_size=100.0):
    """
    Synthesize a wrapped-phase input for (xc, yc, fov) and run the loaded model on it.

    Uses the notebook-level `model`, `normalizer` and `device` defined in the loading cell.

    Args:
        xc, yc: Centre of the square sampling window (same units as `window_size`).
        fov: Forwarded to `compute_hyperbolic_phase` as `theta`.
        resolution: Number of samples along each axis of the square grid.
        focal_length, wavelength: Fixed forward-model parameters (not predicted).
        window_size: Side length of the sampling window.

    Returns:
        (inp_np, pred_dict):
            inp_np: The 2-channel array from `get_2channel_representation`, laid out
                (H, W, 2).
            pred_dict: Maps 'xc', 'yc' and 'fov' to denormalized predictions.

    Raises:
        ValueError: If the model does not return exactly one value per predicted parameter.
        KeyError: If the normalizer has no statistics for a predicted parameter.
    """
    # Assumed model output order, matching the training targets — TODO confirm against training config.
    param_order = ('xc', 'yc', 'fov')

    # 1. Build the sampling grid. The window is centred on (xc, yc), mirroring the
    # training simulator (per the original note on data/loaders/simulation.py).
    # NOTE(review): re-check this if the simulator changes.
    half = window_size / 2.0
    x_coords = np.linspace(xc - half, xc + half, resolution, dtype=np.float32)
    y_coords = np.linspace(yc - half, yc + half, resolution, dtype=np.float32)
    X_grid, Y_grid = np.meshgrid(x_coords, y_coords)

    phi_unwrapped = compute_hyperbolic_phase(X_grid, Y_grid, focal_length, wavelength, theta=fov)
    phi_wrapped = wrap_phase(phi_unwrapped)

    # 2-channel (cos, sin) encoding of the wrapped phase, laid out (H, W, 2).
    inp_np = get_2channel_representation(phi_wrapped)

    # (H, W, 2) -> (1, 2, H, W).
    # `.float()` pins the dtype to the model's default float32 weights. The forward model
    # may return float64 (e.g. under NumPy 2 promotion rules, float32 grids combined with
    # float64 scalars), which would otherwise raise a dtype mismatch inside the model.
    inp_tensor = torch.from_numpy(inp_np).permute(2, 0, 1).unsqueeze(0).float().to(device)

    # 2. Run inference.
    with torch.no_grad():
        pred_normalized = model(inp_tensor)

    # 3. Denormalize the single sample's outputs back to physical values.
    pred_vals = pred_normalized.cpu().numpy()[0]
    if pred_vals.shape[0] != len(param_order):
        raise ValueError(
            f"Model returned {pred_vals.shape[0]} outputs; expected {len(param_order)} "
            f"({', '.join(param_order)})."
        )

    # Assumes the normalizer's forward map is (x - mean) / std — TODO confirm against ParameterNormalizer.
    pred_dict = {}
    for i, p in enumerate(param_order):
        # Fail at the cause instead of silently dropping the parameter. A silent drop
        # would only surface later as a KeyError in the demo cell.
        if p not in normalizer.means:
            raise KeyError(
                f"Normalizer has no statistics for '{p}'; check the ranges passed to ParameterNormalizer."
            )
        pred_dict[p] = pred_vals[i] * normalizer.stds[p] + normalizer.means[p]

    return inp_np, pred_dict


print("Inference function defined.")

## 3. Run Inference Demo

In [None]:
# Ground-truth parameters for the synthetic sample.
xc_gt = 50.0
yc_gt = -120.0
fov_gt = 45.0
ground_truth = {'xc': xc_gt, 'yc': yc_gt, 'fov': fov_gt}

# Run
inp_image, preds = generate_prediction_sample(xc_gt, yc_gt, fov_gt)

print("Ground Truth:")
for k, v in ground_truth.items():
    print(f"  {k + ':':<5}{v}")

print("\nPredictions:")
for k, v in preds.items():
    print(f"  {k}: {v:.4f}")

print("\nErrors:")
for k, v in ground_truth.items():
    print(f"  {k + '_err:':<9}{abs(preds[k] - v):.4f}")

# Visualize.
# The cos/sin channels take values in [-1, 1]. A diverging map with fixed limits keeps
# -1 and +1 visually distinct; the cyclic 'twilight' map renders both ends alike.
# 'twilight' is used instead for the wrapped phase itself, which *is* cyclic.
cos_channel = inp_image[:, :, 0]
sin_channel = inp_image[:, :, 1]
wrapped_phase = np.arctan2(sin_channel, cos_channel)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (cos_channel, "Input Phase (Cos)", "RdBu_r", -1.0, 1.0),
    (sin_channel, "Input Phase (Sin)", "RdBu_r", -1.0, 1.0),
    (wrapped_phase, "Wrapped Phase (atan2(sin, cos))", "twilight", -np.pi, np.pi),
]
for ax, (image, title, cmap, vmin, vmax) in zip(axes, panels):
    im = ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set(title=title, xlabel="column (x index)", ylabel="row (y index)")
    fig.colorbar(im, ax=ax)
fig.tight_layout()
plt.show()