# Deformetrica Parameter Optimization

This notebook helps inform the choice of Deformetrica parameters by evaluating different parameter settings for the mapping. The user should choose the optimal parameters based on:

- Reconstruction error 
- Computation time 
- Visual inspection of the mapping in the optimization cohort

In [None]:
import os
import numpy as np
import pandas as pd
import deformetrica as dfca
import matplotlib.pyplot as plt
import seaborn as sns
import glob

from mesh_utils import load_vtk_polydata_mesh, calculate_distance_mesh

### 1. Setup paths and cohort

In [None]:
# Locations of the input cohort, the template mesh and the grid-search output.
data_dir = "optimization_cohort/"
template_file = "template_remeshed.vtk"
results_dir = "optimization_runs"
os.makedirs(results_dir, exist_ok=True)

# Every .vtk file found directly inside the cohort folder, in sorted path order.
selected_meshes = [
    os.path.join(data_dir, mesh_name)
    for mesh_name in os.listdir(data_dir)
    if mesh_name.endswith(".vtk")
]
selected_meshes.sort()

### 2. Build dataset and template specifications

In [None]:
# Identifier of the deformable object. The dataset entries and the template
# entry must be keyed by the same id (the original cell used 'heart' for the
# dataset but 'biv' for the template).
object_id = 'biv'

# Dataset: all cohort meshes are grouped in a single inner list, with one
# integer "visit age" (0..N-1) and one "sub_<i>" id per mesh.
# NOTE(review): the evaluation step later in this notebook derives subject ids
# from the original mesh file names - confirm that "sub_<i>" ids are intended.
dataset_specifications = {
    'dataset_filenames': [[{object_id: f} for f in selected_meshes]],
    'visit_ages': [list(range(len(selected_meshes)))],
    'subject_ids': [[f"sub_{i}" for i in range(len(selected_meshes))]]
}

# Template: a single surface mesh compared to the data with a varifold attachment.
template_specifications = {
    object_id: {
        'deformable_object_type': 'SurfaceMesh',
        'noise_std': 0.1,
        'filename': template_file,
        'attachment_type': 'Varifold'
    }
}

### 3. Define parameter grid

In [None]:
# Candidate values for the two tuned parameters: 4, 6, ..., 18 for each.
cp_spacings = np.arange(4, 20, 2)
kernel_widths = np.arange(4, 20, 2)

# With 'ij' indexing, rows follow cp_spacings and columns follow kernel_widths.
CP_grid, KW_grid = np.meshgrid(cp_spacings, kernel_widths, indexing='ij')

# Row-major flattening walks the grids in the same order as a nested
# (row, column) loop, giving every (cp_spacing, kernel_width) combination.
param_grid = list(zip(CP_grid.ravel(), KW_grid.ravel()))

### 4. Run the models from the grid search

In [None]:
# Run one Deformetrica estimation per (control-point spacing, kernel width) pair.
# Every run writes into its own "cp<spacing>_kw<width>" sub-folder of results_dir.
#
# NOTE(review): the evaluation step further down looks for "<run>/data/*.vtk"
# and "<run>/output/DeterministicAtlas__Reconstruction__...", whereas this cell
# calls estimate_geodesic_regression with output_dir set to "<run>" itself.
# Confirm that the estimator and folder layout used here are the ones the
# evaluation expects.
for cp_spacing, kernel_width in param_grid:
    run_name = f"cp{cp_spacing}_kw{kernel_width}"
    output_dir = os.path.join(results_dir, run_name)
    os.makedirs(output_dir, exist_ok=True)

    # Status objects handed to the callback during this run (reset for every run).
    iteration_logs = []

    def estimator_callback(status):
        """Record the status passed by the estimator and return True.

        NOTE(review): True is presumably the "keep optimizing" signal -
        confirm against the Deformetrica estimator API.
        """
        iteration_logs.append(status)
        return True

    estimator_options = {
        'optimization_method_type': 'GradientAscent',
        'max_iterations': 100,
        'convergence_tolerance': 1e-5,
        'initial_step_size': 0.01,
        'callback': estimator_callback
    }

    model_options = {
        'deformation_kernel_type': 'torch',
        # Grid values are NumPy integers; hand plain Python floats to the model.
        'deformation_kernel_width': float(kernel_width),
        # Previously the control-point spacing only appeared in the folder name,
        # so runs that differed only in cp_spacing used identical settings.
        'initial_cp_spacing': float(cp_spacing),
        'smoothing_kernel_width': 15.0,  # fixed value
        'use_sobolev_gradient': True,
        'gpu_mode': dfca.GpuMode.NONE,
        'dtype': 'float32',
        'dense_mode': True
    }

    deformetrica = dfca.Deformetrica(output_dir=output_dir, verbosity='WARNING')
    deformetrica.estimate_geodesic_regression(
        template_specifications, dataset_specifications,
        estimator_options=estimator_options,
        model_options=model_options
    )

### 5. Evaluate reconstruction error for each parameter combination

In [None]:
# Folder holding the grid-search runs, and the object id used in the
# reconstruction file names.
data_folder = "optimization_runs"
domain = "biv"

# Per-run accumulators filled in by the evaluation loop.
cp_spacing, kernel_widths, registration_errors = [], [], []

# Entries of the results folder in sorted order.
list_dir = os.listdir(data_folder)
list_dir.sort()

In [None]:
# Start from empty accumulators so that re-running this cell does not append
# duplicate entries to the results of a previous execution.
cp_spacing = []
kernel_widths = []
registration_errors = []
evaluated_dirs = []

for dir_name in list_dir:
    # Run folders are named "cp<spacing>_kw<width>"; anything else is skipped.
    try:
        name_parts = dir_name.split('_')
        cp = float(name_parts[0].replace('cp', ''))
        k = float(name_parts[1].replace('kw', ''))
    except (IndexError, ValueError):
        print(f"Skipping unrecognized folder format: {dir_name}")
        continue

    evaluated_dirs.append(dir_name)
    cp_spacing.append(cp)
    kernel_widths.append(k)

    input_dir = os.path.join(data_folder, dir_name)
    original_mesh_files = glob.glob(os.path.join(input_dir, "data", "*.vtk"))
    recon_pattern = f"output/DeterministicAtlas__Reconstruction__{domain}__subject_*.vtk"
    reconstructed_mesh_files = glob.glob(os.path.join(input_dir, recon_pattern))

    # Subject id = file name without its extension (splitext keeps any dots
    # that are part of the id itself).
    subject_ids_original = [
        os.path.splitext(os.path.basename(f))[0] for f in original_mesh_files
    ]
    subject_ids_reconstructed = [
        os.path.splitext(os.path.basename(f))[0].split("subject_")[-1]
        for f in reconstructed_mesh_files
    ]

    # Nothing to compare: record NaN explicitly instead of averaging an empty list.
    if not original_mesh_files:
        print(f"No original meshes found for CP: {cp}, K: {k}")
        registration_errors.append(np.nan)
        continue

    if sorted(subject_ids_original) != sorted(subject_ids_reconstructed):
        print(f"Mismatch in subjects for CP: {cp}, K: {k}")
        registration_errors.append(np.nan)
        continue

    distances = []
    for orig_file, subj_id in zip(original_mesh_files, subject_ids_original):
        recon_file = os.path.join(input_dir, "output", f"DeterministicAtlas__Reconstruction__{domain}__subject_{subj_id}.vtk")
        if not os.path.exists(recon_file):
            print(f"Missing reconstruction for: {subj_id}")
            continue
        original = load_vtk_polydata_mesh(orig_file)
        reconstructed = load_vtk_polydata_mesh(recon_file)
        distances.append(calculate_distance_mesh(original, reconstructed))

    # Mean distance over the subjects of this run; NaN if none could be compared.
    mean_dist = np.mean(distances) if distances else np.nan
    registration_errors.append(mean_dist)
    print(f"CP: {cp}, KW: {k}, Error: {mean_dist:.4f}")

# Keep list_dir aligned with the accumulators: the results table built
# afterwards pairs list_dir with these lists, which fails when skipped
# folders are still part of it.
list_dir = evaluated_dirs

### 6. Visualize results as a color-coded scatter plot

In [None]:
# One row per evaluated run folder.
df = pd.DataFrame({
    'directory': list_dir,
    'kernel_width': kernel_widths,
    'cp_spacing': cp_spacing,
    'registration_error': registration_errors
})

# Rows whose error is NaN are overlaid with a red cross.
mismatch_idx = df.index[df['registration_error'].isna()]
mismatched = df.loc[mismatch_idx]

fig, ax = plt.subplots()
sc = ax.scatter(
    df['kernel_width'],
    df['cp_spacing'],
    c=df['registration_error'],
    cmap='viridis',
    marker='o',
)
ax.scatter(
    mismatched['kernel_width'],
    mismatched['cp_spacing'],
    c='red',
    marker='x',
    label='Mismatch',
)

ax.set_xlabel('Kernel Width')
ax.set_ylabel('CP Spacing')
ax.set_title('Registration Error')
fig.colorbar(sc, ax=ax, label='Error')
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

Before settling on the optimal parameters, the user is advised to visually inspect the output of the mappings with the lowest reconstruction error.

In [None]:
# Show the ten parameter combinations with the smallest reconstruction error.
min_rows = df.nsmallest(n=10, columns='registration_error')
print(min_rows)