### 3D Metric Embedding Visualisation

This notebook builds an approximately isometric 3D embedding of the learned 2D metric tensor on the sphere. The approach:

1. **Create a unified sphere mesh** from both stereographic patches (North and South hemispheres)
   - Structured disc meshes are generated for each patch and mapped to 3D via stereographic projection
   - A convex hull triangulation joins the two hemispheres into a single mesh

2. **Compute metric-weighted edge lengths** using the learned metric tensor $g_{ij}$
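   - For same-patch edges: $\ell = \int_0^1 \sqrt{\vec{v}^{\,T} g(\vec{x}(t))\, \vec{v}} \, dt$ with $\vec{x}(t) = \vec{p}_1 + t\,\vec{v}$ and $\vec{v} = \vec{p}_2 - \vec{p}_1$ in patch coordinates, evaluated with the trapezoidal rule
   - For cross-patch edges: the great-circle arc between the endpoints is split into short segments; each segment's chord is measured with the learned metric evaluated at its midpoint, in the north patch if the midpoint has $z \ge 0$ and in the south patch otherwise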

3. **Build geodesic distance matrix** via shortest paths on the mesh graph

4. **Multi-Dimensional Scaling (MDS)** embeds the distance matrix into 3D Euclidean space
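   - The embedding is rescaled to unit mean radius and rotated so that the first vertex lies on the $+x$ axis and the second in the $xy$-plane (no Procrustes alignment to the reference sphere is performed)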

5. **Compute scalar curvature** $R$ at each vertex from the learned conformal factor (only when `include_curvature_plots = True`)

6. **Visualise** the resulting surface:
   - Plot 1: 3D mesh embedding of the learned metric (the reference round sphere can be plotted separately with the last, currently disabled, cell)
   - Plot 2: Embedding coloured by scalar curvature $R$

**Output:** Shows how the learned metric deforms the sphere geometry. When `save_figures = True`, figures are saved to `visualisations/<run folder>/`, e.g. `embedding_epoch_<epoch>_K_<R_kind>.pdf`.

In [None]:
# Import libraries and relevant functions
import os
import re
import glob
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import tensorflow as tf
tf.keras.backend.set_floatx('float64')
from scipy.spatial import Delaunay, ConvexHull
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.csgraph import shortest_path
from sklearn.manifold import MDS
import warnings

# Import relevant functions for new model format
from data.dataset import list_saved_models
from data.prescribers import build_prescriber
from network.global_conformal_model import GlobalConformalModel
from geometry.ball import patch_xy_to_xyz, xyz_to_patch_xy

# List the saved training-run folders; choose one in the next cell via `model_index`.
# A bare assignment displays nothing, so print the index -> folder mapping explicitly.
saved_models = list_saved_models()
for run_idx, run_name in enumerate(saved_models):
    print(f"[{run_idx}] {run_name}")


In [None]:
# Select which checkpoint folder to use: an index into `saved_models`, or a folder name (str)
model_index = 0  # Change this to select a different run folder
selected_folder = model_index if isinstance(model_index, str) else saved_models[model_index]
# Root directory that contains one sub-folder per saved run
checkpoints_dir = "checkpoints"
selected_folder_path = os.path.join(checkpoints_dir, selected_folder)
if not os.path.isdir(selected_folder_path):
    # Fail loudly: otherwise the glob below silently finds nothing and no plots are produced.
    raise FileNotFoundError(
        f"Checkpoint folder not found: {selected_folder_path!r} "
        "(check `checkpoints_dir` and `model_index`)"
    )

# Epoch numbers to visualize (list for multiple, single value, or "all")
# Examples:
#   epoch_numbers = [50, 75, 94]  # specific epochs
#   epoch_numbers = 94            # single epoch
#   epoch_numbers = "all"         # all available epochs
#   epoch_numbers = "final"       # only final_model.keras
epoch_numbers = "all"

# Find all epoch checkpoint .keras files in the selected folder
all_model_files = glob.glob(os.path.join(selected_folder_path, "model_epoch*.keras"))
def extract_epoch_num(filename):
    """Parse the epoch number out of a ``model_epoch<N>.keras`` path.

    Returns:
        The integer N, or None when the path does not contain that pattern.
    """
    match = re.search(r"model_epoch(\d+)\.keras", filename)
    if match is None:
        return None
    return int(match.group(1))
# (epoch, path) pairs for every parsable checkpoint, sorted by epoch
model_files = [(extract_epoch_num(f), f) for f in all_model_files]
model_files = sorted((e, f) for e, f in model_files if e is not None)
available_epochs = {e for e, _ in model_files}

# Check for final_model.keras
final_model_path = os.path.join(selected_folder_path, "final_model.keras")
has_final_model = os.path.exists(final_model_path)

# Filter by epoch_numbers
if epoch_numbers == "final":
    # Only load final model
    if has_final_model:
        selected_models = [("final", final_model_path)]
    else:
        print("Warning: final_model.keras not found, falling back to latest epoch model")
        selected_models = model_files[-1:]  # empty when there are no epoch checkpoints
elif epoch_numbers == "all":
    # All epoch checkpoints; final_model.keras is used only when no epoch checkpoints exist
    if model_files or not has_final_model:
        selected_models = model_files
    else:
        selected_models = [("final", final_model_path)]
elif isinstance(epoch_numbers, (int, list, tuple)):
    requested = [epoch_numbers] if isinstance(epoch_numbers, int) else list(epoch_numbers)
    selected_models = [mf for mf in model_files if mf[0] in requested]
    # Report requested epochs that have no checkpoint instead of dropping them silently
    missing = [e for e in requested if e not in available_epochs]
    if missing:
        print(f"Warning: no checkpoint found for epoch(s) {missing}; "
              f"available: {sorted(available_epochs)}")
else:
    print(f"Warning: unrecognised epoch_numbers={epoch_numbers!r}; no models selected")
    selected_models = []

if not selected_models:
    print("Warning: no models selected - later cells will have nothing to embed.")
print(f"Selected models (epochs): {[e for e, _ in selected_models]}")

# Load all selected models (same order as selected_models)
custom_objects = {'GlobalConformalModel': GlobalConformalModel}
loaded_models = [
    tf.keras.models.load_model(model_path, custom_objects=custom_objects)
    for _, model_path in selected_models
]

In [None]:
# Helper functions for mesh creation

def create_disc_mesh(n_radial=20, n_angular=40, radius=0.95):
    """Build a ring-structured point set on a disc and Delaunay-triangulate it.

    The origin comes first. Ring ``k`` (k = 1..n_radial) then sits at radius
    ``radius * k / n_radial`` and carries ``max(8, int(n_angular * r / radius))``
    equally spaced points starting at angle 0.

    Returns:
        (vertices, simplices): an (N, 2) array of points and the Delaunay triangles.
    """
    points = [(0.0, 0.0)]
    for ring in range(1, n_radial + 1):
        ring_radius = radius * (ring / n_radial)
        ring_count = max(8, int(n_angular * ring_radius / radius))
        angles = (2 * np.pi * k / ring_count for k in range(ring_count))
        points.extend((ring_radius * np.cos(a), ring_radius * np.sin(a)) for a in angles)
    vertices = np.array(points)
    return vertices, Delaunay(vertices).simplices

def disc_to_sphere(disc_coords, patch_idx=0):
    """Lift 2D patch (disc) coordinates onto the 3D sphere.

    Thin wrapper around ``patch_xy_to_xyz``: converts the input to a float64 tensor,
    maps it in patch ``patch_idx`` and returns the result as a numpy array.
    """
    coords_tensor = tf.convert_to_tensor(disc_coords, dtype=tf.float64)
    xyz_tensor = patch_xy_to_xyz(coords_tensor, patch_idx=patch_idx)
    return xyz_tensor.numpy()

def create_full_sphere_mesh(n_radial=15, n_angular=32, overlap_radius=0.85):
    """
    Build one triangle mesh covering the whole sphere from both patches.

    The same ring-structured disc is generated for patch 0 (north) and patch 1
    (south) and lifted to 3D with ``disc_to_sphere``. The two point sets are stacked
    (north first, then south) and triangulated via the convex hull of the combined cloud.

    Returns:
        vertices_3d, faces, vertices_disc_north, vertices_disc_south,
        north_indices, south_indices
    """
    patches = []
    for patch_idx in (0, 1):  # 0 = north, 1 = south
        disc_points, _ = create_disc_mesh(n_radial, n_angular, radius=overlap_radius)
        patches.append((disc_points, disc_to_sphere(disc_points, patch_idx=patch_idx)))
    (disc_north, xyz_north), (disc_south, xyz_south) = patches

    vertices_3d = np.concatenate([xyz_north, xyz_south], axis=0)
    # North vertices occupy 0..n_north-1, south vertices follow.
    north_indices = np.arange(len(xyz_north))
    south_indices = np.arange(len(xyz_north), len(xyz_north) + len(xyz_south))

    # All points lie on the sphere, so the convex hull facets triangulate its surface.
    faces = ConvexHull(vertices_3d).simplices

    return vertices_3d, faces, disc_north, disc_south, north_indices, south_indices

In [None]:
# Helper functions for distance computation

def get_disc_coords_for_vertex(vertices_disc_north, vertices_disc_south,
                                north_indices, vertex_idx):
    """Return ``(disc_coords, patch_idx)`` for a vertex of the combined mesh.

    Vertices are numbered north patch first (0 .. n_north-1), then south patch. So the
    patch is decided by a constant-time range check against the number of north
    vertices instead of an O(n) membership scan of ``north_indices``.

    Args:
        vertices_disc_north, vertices_disc_south: (N, 2) patch coordinates per patch.
        north_indices: indices of the north vertices. They are assumed to be
            0..n_north-1; the south offset below already relies on this.
        vertex_idx: index into the combined vertex array.

    Returns:
        ``(coords, 0)`` for north vertices, ``(coords, 1)`` for south vertices.
    """
    n_north = len(north_indices)
    if 0 <= vertex_idx < n_north:
        return vertices_disc_north[vertex_idx], 0
    return vertices_disc_south[vertex_idx - n_north], 1

def slerp(p0, p1, t):
    """Spherical linear interpolation on the unit sphere.

    Returns the point a fraction ``t`` of the way along the great-circle arc from
    ``p0`` to ``p1``. If the angle between them is below 1e-10 rad, ``p0`` itself is
    returned.
    """
    angle = np.arccos(np.clip(np.dot(p0, p1), -1.0, 1.0))
    if angle < 1e-10:
        # Endpoints coincide numerically; the arc is degenerate.
        return p0
    w_start = np.sin((1 - t) * angle)
    w_end = np.sin(t * angle)
    return (w_start * p0 + w_end * p1) / np.sin(angle)

def _evaluate_metric(loaded_model, disc_points, patch_idx):
    """Evaluate the learned metric g at 2D points of patch ``patch_idx``; returns (N, 2, 2).

    The model takes (batch, num_patches, 2) input. The same coordinates are placed in
    both patch slots, and only the ``patch_idx`` slice of the metric output is kept.
    """
    disc_points = np.asarray(disc_points, dtype=np.float64).reshape(-1, 2)
    both_patches = np.stack([disc_points, disc_points], axis=1)  # (N, 2, 2)
    batch_dict = {loaded_model.patch_coords_key:
                  tf.convert_to_tensor(both_patches, dtype=tf.float64)}
    output = loaded_model(batch_dict, training=False)
    return output[loaded_model.conformal_metric_key][:, patch_idx, :, :].numpy()


def _metric_norms(g, vectors):
    """Return sqrt(v^T g v) per metric/vector pair, with v^T g v floored at 1e-12.

    g: (N, 2, 2) metrics; vectors: one (2,) vector shared by all metrics, or (N, 2).
    """
    vectors = np.broadcast_to(np.asarray(vectors, dtype=np.float64), (len(g), 2))
    quad = np.einsum('ni,nij,nj->n', vectors, g, vectors)
    return np.sqrt(np.maximum(quad, 1e-12))


def _xyz_to_patch(xyz_points, patch_idx):
    """Map (M, 3) sphere points to (M, 2) coordinates of patch ``patch_idx``.

    NOTE(review): assumes ``xyz_to_patch_xy`` maps rows independently (it is called
    with a leading batch dimension here; the previous code passed one row at a time).
    """
    xyz = np.asarray(xyz_points, dtype=np.float64).reshape(-1, 3)
    return xyz_to_patch_xy(tf.convert_to_tensor(xyz, dtype=tf.float64),
                           patch_idx=patch_idx).numpy()


def compute_edge_length(v1_idx, v2_idx, vertices_3d,
                        vertices_disc_north, vertices_disc_south,
                        north_indices, south_indices,
                        loaded_model, n_subdivisions=5):
    """Metric-weighted length of the mesh edge between ``v1_idx`` and ``v2_idx``.

    Same-patch edges: the straight segment x(t) = p1 + t*v (v = p2 - p1) in patch
    coordinates is sampled at ``n_subdivisions + 1`` nodes, and
    integral_0^1 sqrt(v^T g(x(t)) v) dt is computed with the trapezoidal rule.

    Cross-patch edges: the great-circle arc between the (normalised) 3D endpoints is
    split into ``n_subdivisions`` segments. A segment is measured in patch 0 if its
    midpoint has z >= 0, otherwise in patch 1. Its length is sqrt(d^T g(mid) d), where
    d is the segment chord in that patch's coordinates. Metric evaluations are batched
    per patch (at most two model calls per edge).

    ``south_indices`` is accepted for interface compatibility but is not used.

    Returns:
        The edge length as a float (0.0 for coincident endpoints).
    """
    disc1, patch1 = get_disc_coords_for_vertex(vertices_disc_north, vertices_disc_south, north_indices, v1_idx)
    disc2, patch2 = get_disc_coords_for_vertex(vertices_disc_north, vertices_disc_south, north_indices, v2_idx)

    if patch1 == patch2:
        direction = np.asarray(disc2 - disc1, dtype=np.float64).ravel()
        # Handle degenerate case of identical vertices
        if np.linalg.norm(direction) < 1e-14:
            return 0.0
        t_vals = np.linspace(0, 1, n_subdivisions + 1)
        points = disc1 + t_vals[:, None] * direction  # (n_subdivisions + 1, 2)
        integrand = _metric_norms(_evaluate_metric(loaded_model, points, patch1), direction)
        # Trapezoidal rule over the t grid
        avg = 0.5 * (integrand[:-1] + integrand[1:])
        return float(np.sum(np.diff(t_vals) * avg))

    # Cross-patch: walk the great circle between the endpoints.
    xyz1 = vertices_3d[v1_idx] / np.linalg.norm(vertices_3d[v1_idx])
    xyz2 = vertices_3d[v2_idx] / np.linalg.norm(vertices_3d[v2_idx])
    t_vals = np.linspace(0, 1, n_subdivisions + 1)
    t_mids = 0.5 * (t_vals[:-1] + t_vals[1:])
    nodes = np.array([slerp(xyz1, xyz2, t) for t in t_vals]).reshape(-1, 3)
    mids = np.array([slerp(xyz1, xyz2, t) for t in t_mids]).reshape(-1, 3)

    # Patch per segment: north (0) if its midpoint has z >= 0, else south (1).
    segment_patch = np.where(mids[:, 2] >= 0, 0, 1)
    segment_lengths = np.zeros(len(mids))
    for patch_idx in (0, 1):
        seg = np.flatnonzero(segment_patch == patch_idx)
        if seg.size == 0:
            continue
        chords = _xyz_to_patch(nodes[seg + 1], patch_idx) - _xyz_to_patch(nodes[seg], patch_idx)
        # (Near-)zero chords contribute nothing, matching the per-segment skip.
        keep = ~(np.linalg.norm(chords, axis=1) < 1e-14)
        if not keep.any():
            continue
        g = _evaluate_metric(loaded_model, _xyz_to_patch(mids[seg[keep]], patch_idx), patch_idx)
        segment_lengths[seg[keep]] = _metric_norms(g, chords[keep])
    return float(np.sum(segment_lengths))

def build_distance_matrix(vertices_3d, faces, vertices_disc_north, vertices_disc_south,
                          north_indices, south_indices, loaded_model, n_subdivisions=4):
    """Approximate all-pairs geodesic distances under the learned metric.

    Each unique mesh edge gets a metric-weighted length from ``compute_edge_length``.
    Dijkstra shortest paths on the resulting weighted, undirected edge graph then
    approximate the geodesic distance between every pair of vertices.

    Returns:
        (n_vertices, n_vertices) dense array of distances; ``inf`` between
        disconnected vertices.
    """
    n_vertices = len(vertices_3d)
    faces = np.asarray(faces, dtype=np.intp).reshape(-1, 3)
    # Each triangle contributes edges (0,1), (1,2), (2,0). Store them as (low, high)
    # and de-duplicate; np.unique also yields a deterministic, sorted edge order.
    edge_pairs = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
    edges = np.unique(np.sort(edge_pairs, axis=1), axis=0)

    print(f"Computing {len(edges)} edge lengths...")
    lengths = np.empty(len(edges), dtype=np.float64)
    for idx, (i, j) in enumerate(edges):
        if idx % 200 == 0 and idx > 0:
            print(f"  ...{idx}/{len(edges)}")
        lengths[idx] = compute_edge_length(int(i), int(j), vertices_3d,
                                           vertices_disc_north, vertices_disc_south,
                                           north_indices, south_indices,
                                           loaded_model, n_subdivisions)

    # Build the graph in one go from the edge list (COO -> CSR). Element-wise
    # lil_matrix assignment of 0.0 stores nothing, which would drop zero-length
    # edges. Here they stay as explicit zeros, which csgraph treats as real edges.
    # Only (low, high) entries are stored; directed=False traverses them both ways.
    graph = csr_matrix((lengths, (edges[:, 0], edges[:, 1])), shape=(n_vertices, n_vertices))
    print("Computing shortest paths...")
    return shortest_path(graph, method='D', directed=False)

In [None]:
# Helper functions for MDS embedding and orientation

def embed_with_mds(distance_matrix, n_components=3, max_iter=500, random_state=42):
    """Embed a precomputed distance matrix into Euclidean space with metric MDS.

    Args:
        distance_matrix: (N, N) pairwise distances. Non-finite entries (e.g. ``inf``
            between disconnected mesh components) are replaced by twice the largest
            finite distance so MDS can run; a warning with the count is printed.
        n_components: target dimension (3 for the visualisation).
        max_iter: maximum number of SMACOF iterations.
        random_state: seed for the MDS initialisation (fixed for reproducible plots).

    Returns:
        (N, n_components) array of embedded coordinates.

    Raises:
        ValueError: if the matrix contains no finite distances at all.
    """
    distance_matrix = np.array(distance_matrix, dtype=np.float64)  # copy; caller's array untouched
    finite = np.isfinite(distance_matrix)
    if not finite.any():
        raise ValueError("distance_matrix has no finite entries; cannot run MDS.")
    if not finite.all():
        fill_value = np.max(distance_matrix[finite]) * 2
        n_bad = int(np.count_nonzero(~finite))
        print(f"  Warning: {n_bad} non-finite distances (disconnected vertices?) "
              f"replaced by {fill_value:.4g}")
        distance_matrix[~finite] = fill_value

    mds = MDS(n_components=n_components, dissimilarity='precomputed',
              random_state=random_state, max_iter=max_iter, normalized_stress='auto')
    with warnings.catch_warnings():
        # Silence sklearn convergence/deprecation chatter inside the notebook output.
        warnings.simplefilter("ignore")
        embedding = mds.fit_transform(distance_matrix)
    print(f"MDS stress: {mds.stress_:.4f}")
    return embedding

def _rotation_to_positive_x(point):
    """Return a 3x3 rotation R with R @ point on the positive x-axis (identity if point ~ 0)."""
    length = np.linalg.norm(point)
    if length < 1e-10:
        return np.eye(3)  # direction undefined
    unit = point / length
    target = np.array([1.0, 0.0, 0.0])
    axis = np.cross(unit, target)
    axis_norm = np.linalg.norm(axis)
    cos_angle = np.dot(unit, target)
    if axis_norm > 1e-10:
        axis = axis / axis_norm
        angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
        # Rodrigues rotation formula
        K = np.array([[0, -axis[2], axis[1]],
                      [axis[2], 0, -axis[0]],
                      [-axis[1], axis[0], 0]])
        return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)
    if cos_angle < 0:
        # Point lies on the negative x-axis: the cross product vanishes, so rotate
        # by pi about the z-axis instead.
        return np.diag([-1.0, -1.0, 1.0])
    return np.eye(3)  # already on the positive x-axis


def standardize_embedding_orientation(embedding):
    """
    Put an embedding into a canonical scale and orientation using its first two points.

    - Scales so the mean distance of the points from the origin is 1 (skipped if that
      mean is ~0).
    - Rotates (proper rotation) so the first point lies on the positive x-axis.
    - Rotates about the x-axis so the second point lies in the xy-plane (z = 0, y >= 0).

    Pairwise distances are preserved up to the global scale factor. Rotation steps
    whose reference direction is undefined are skipped: first point at the origin,
    second point on the x-axis, or fewer than two points.

    Args:
        embedding: (N, 3) array of points.

    Returns:
        (N, 3) array: the scaled and rotated embedding (a new array).
    """
    embedding = np.asarray(embedding, dtype=np.float64)

    # Scale so average distance from origin is 1 (approximates constant volume)
    avg_norm = np.mean(np.linalg.norm(embedding, axis=1)) if len(embedding) else 0.0
    if avg_norm > 1e-10:
        embedding_scaled = embedding * (1.0 / avg_norm)
    else:
        embedding_scaled = embedding.copy()
    if len(embedding_scaled) == 0:
        return embedding_scaled

    # Rotate so the first point is on the positive x-axis
    R1 = _rotation_to_positive_x(embedding_scaled[0])
    embedding_rot1 = embedding_scaled @ R1.T
    if len(embedding_rot1) < 2:
        return embedding_rot1

    # Rotate about the x-axis by theta so the second point gets z' = 0:
    # z' = y*sin(theta) + z*cos(theta) = 0  ->  theta = atan2(-z, y), which also gives y' >= 0.
    y2, z2 = embedding_rot1[1, 1], embedding_rot1[1, 2]
    if abs(y2) > 1e-10 or abs(z2) > 1e-10:
        theta = np.arctan2(-z2, y2)
        R2 = np.array([[1, 0, 0],
                       [0, np.cos(theta), -np.sin(theta)],
                       [0, np.sin(theta), np.cos(theta)]])
    else:
        R2 = np.eye(3)  # second point on the x-axis: nothing to align

    return embedding_rot1 @ R2.T

In [None]:
# Options for visualization
save_figures = True           # Set to True to save figures to 'visualisations/' directory
include_curvature_plots = False   # Set to True to generate 'coloured by R' plots (adds compute time)

# Mesh parameters: rings per patch, points on the outermost ring, disc radius in patch coords
n_radial = 12
n_angular = 28
overlap_radius = 0.88

# Step 1: Build the shared sphere mesh (both patches, convex-hull triangulation)
print("Creating unified sphere mesh...")
(vertices_3d, faces,
 vertices_disc_north, vertices_disc_south,
 north_indices, south_indices) = create_full_sphere_mesh(
    n_radial=n_radial, n_angular=n_angular, overlap_radius=overlap_radius)
print(f"  {len(vertices_3d)} vertices, {len(faces)} faces")

# Show the disc-coordinate vertices of each patch (the points that get embedded)
num_patches = 1 if len(vertices_disc_south) == 0 else 2
patch_panels = [('North Patch Mesh Vertices', vertices_disc_north)]
if num_patches > 1:
    patch_panels.append(('South Patch Mesh Vertices', vertices_disc_south))

fig, axes = plt.subplots(1, num_patches, figsize=(5*num_patches, 5))
axes = np.atleast_1d(axes)
for panel_ax, (panel_title, panel_points) in zip(axes, patch_panels):
    panel_ax.set_title(panel_title)
    panel_ax.scatter(panel_points[:, 0], panel_points[:, 1], alpha=0.5, s=10)
    panel_ax.set_xlim(-1, 1)
    panel_ax.set_ylim(-1, 1)
    panel_ax.set_aspect('equal')
plt.tight_layout()
plt.show()

# Step 2+: For each loaded model, compute distances, embed, and visualize

def _predict_scalar_curvature(model, disc_coords, patch_idx):
    """Predict the scalar curvature R at 2D points of one patch.

    The model takes (batch, num_patches, 2) input. The same coordinates are fed to both
    patch slots, and the conformal factor u and its Laplacian for ``patch_idx`` are
    combined as R = exp(-2u) * (2 - delta_u).
    NOTE(review): for g = e^{2u} g_round on the unit sphere the textbook formula is
    R = e^{-2u} (2 - 2 Δ_round u); confirm the convention of ``laplace_beltrami_key``.
    """
    coords_both = np.stack([disc_coords, disc_coords], axis=1)  # (N, 2, 2)
    batch_dict = {model.patch_coords_key: tf.convert_to_tensor(coords_both, dtype=tf.float64)}
    output = model(batch_dict, training=False)
    u = output[model.conformal_factor_key][:, patch_idx].numpy().squeeze()
    delta_u = output[model.laplace_beltrami_key][:, patch_idx].numpy().squeeze()
    return np.exp(-2.0 * u) * (2.0 - delta_u)


if not loaded_models:
    print("No models loaded - nothing to embed (check `model_index` / `epoch_numbers`).")

for model_idx, loaded_model in enumerate(loaded_models):
    epoch_label = selected_models[model_idx][0]
    print(f"\n{'='*60}")
    print(f"Processing model: epoch {epoch_label}")
    print(f"{'='*60}\n")

    # Extract config for this model
    cfg = loaded_model.get_config()['cfg']

    # Step 2: Compute geodesic distance matrix for this model
    geodesic_distances = build_distance_matrix(
        vertices_3d, faces, vertices_disc_north, vertices_disc_south,
        north_indices, south_indices, loaded_model, n_subdivisions=4
    )

    # Step 3: MDS embedding and orientation standardization
    print("Running MDS embedding...")
    embedding_3d = embed_with_mds(geodesic_distances, n_components=3)
    embedding_3d = standardize_embedding_orientation(embedding_3d)

    # Step 4: Scalar curvature at mesh vertices (only when requested; adds model calls)
    if include_curvature_plots:
        R_at_vertices = np.zeros(len(vertices_3d))
        R_at_vertices[north_indices] = _predict_scalar_curvature(loaded_model, vertices_disc_north, 0)
        if num_patches > 1:
            R_at_vertices[south_indices] = _predict_scalar_curvature(loaded_model, vertices_disc_south, 1)

    # Step 5: Setup visualization parameters
    R_kind = cfg['data'].get('prescribed_R', 'unknown')
    title_suffix = f"R_kind={R_kind}, epoch={epoch_label}"
    save_dir = os.path.join("visualisations", selected_folder)
    if save_figures:
        os.makedirs(save_dir, exist_ok=True)

    # Step 6: Plot learned metric embedding
    fig = plt.figure(figsize=(8, 7))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(embedding_3d[:, 0], embedding_3d[:, 1], embedding_3d[:, 2],
                    triangles=faces, cmap='viridis', alpha=0.7, edgecolor='k', linewidth=0.15)
    #ax.set_title(f'Learned Metric Embedding\n({title_suffix})')
    ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
    ax.set_box_aspect([1,1,1])
    plt.tight_layout()
    if save_figures:
        fig.savefig(os.path.join(save_dir, f"embedding_epoch_{epoch_label}_K_{R_kind}.pdf"), bbox_inches='tight')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across epochs

    # Step 7: Embedding coloured by scalar curvature R (previously computed but never shown)
    if include_curvature_plots:
        fig_R = plt.figure(figsize=(8, 7))
        ax_R = fig_R.add_subplot(111, projection='3d')
        sc = ax_R.scatter(embedding_3d[:, 0], embedding_3d[:, 1], embedding_3d[:, 2],
                          c=R_at_vertices, cmap='coolwarm', s=30, edgecolor='k', linewidth=0.1)
        ax_R.plot_trisurf(embedding_3d[:, 0], embedding_3d[:, 1], embedding_3d[:, 2],
                          triangles=faces, alpha=0.3, color='gray', edgecolor='none')
        ax_R.set_title(f'Learned Metric Embedding coloured by R\n({title_suffix})')
        ax_R.set_xlabel('X'); ax_R.set_ylabel('Y'); ax_R.set_zlabel('Z')
        ax_R.set_box_aspect([1,1,1])
        fig_R.colorbar(sc, ax=ax_R, shrink=0.5, label='Scalar Curvature R')
        plt.tight_layout()
        if save_figures:
            fig_R.savefig(os.path.join(save_dir, f"embedding_R_epoch_{epoch_label}_K_{R_kind}.pdf"),
                          bbox_inches='tight')
        plt.show()
        plt.close(fig_R)


In [None]:
# NOTE: this cell is disabled. Its code is wrapped in a triple-quoted string, so running
# the cell only evaluates (and, as the cell's last expression, displays) that string.
# To enable the animation, remove the surrounding quotes. It needs the third-party
# `imageio` package and re-runs build_distance_matrix + MDS for every loaded model.
'''
# Create animation from all epochs
import imageio.v2 as imageio
from io import BytesIO

# Ensure model_epochs is defined as the list of epochs for loaded_models
model_epochs = [e for e, _ in selected_models]

def create_embedding_animation(output_filename='metric_evolution_embedding.gif', fps=2):
    """
    Create animated GIF showing metric embedding evolution across epochs.
    
    Args:
        output_filename: Output file path (supports .gif, .mp4)
        fps: Frames per second
    """
    frames = []
    
    # First pass: determine global axis limits across all embeddings
    print("Computing global axis limits...")
    all_embeddings = []
    
    for model_idx, (loaded_model, epoch_label) in enumerate(zip(loaded_models, model_epochs)):
        print(f"  Processing epoch {epoch_label} for limits...")
        geodesic_distances = build_distance_matrix(
            vertices_3d, faces, vertices_disc_north, vertices_disc_south,
            north_indices, south_indices, loaded_model, n_subdivisions=4
        )
        embedding_3d = embed_with_mds(geodesic_distances, n_components=3)
        embedding_3d = standardize_embedding_orientation(embedding_3d)
        all_embeddings.append(embedding_3d)
    
    # Compute global limits
    all_embeddings_array = np.vstack(all_embeddings)
    x_min, x_max = all_embeddings_array[:, 0].min(), all_embeddings_array[:, 0].max()
    y_min, y_max = all_embeddings_array[:, 1].min(), all_embeddings_array[:, 1].max()
    z_min, z_max = all_embeddings_array[:, 2].min(), all_embeddings_array[:, 2].max()
    
    # Add 5% padding
    x_range = x_max - x_min
    y_range = y_max - y_min
    z_range = z_max - z_min
    padding = 0.05
    
    x_min -= x_range * padding
    x_max += x_range * padding
    y_min -= y_range * padding
    y_max += y_range * padding
    z_min -= z_range * padding
    z_max += z_range * padding
    
    print(f"\nGlobal axis limits:")
    print(f"  X: [{x_min:.3f}, {x_max:.3f}]")
    print(f"  Y: [{y_min:.3f}, {y_max:.3f}]")
    print(f"  Z: [{z_min:.3f}, {z_max:.3f}]")
    
    # Second pass: create animation frames
    print("\nRendering animation frames...")
    for model_idx, (loaded_model, epoch_label) in enumerate(zip(loaded_models, model_epochs)):
        print(f"  Frame {model_idx+1}/{len(loaded_models)}: epoch {epoch_label}")
        
        embedding_3d = all_embeddings[model_idx]
        
        # Setup title
        try:
            model_cfg = loaded_model.get_config()['cfg']
            R_kind = model_cfg['data'].get('prescribed_R', 'unknown')
        except Exception:
            R_kind = 'unknown'
        
        if R_kind == "sph_harm":
            ms_values = getattr(loaded_model, 'hp', {}).get("ms", []) if hasattr(loaded_model, 'hp') else []
            title_suffix = f"R_kind={R_kind}, ms={ms_values}"
        else:
            title_suffix = f"R_kind={R_kind}"
        
        if epoch_label != "seed":
            title_suffix = f"Epoch {epoch_label}, {title_suffix}"
        
        # Create figure
        fig = plt.figure(figsize=(8, 7), dpi=100)
        ax = fig.add_subplot(111, projection='3d')
        
        ax.plot_trisurf(embedding_3d[:, 0], embedding_3d[:, 1], embedding_3d[:, 2],
                      triangles=faces, cmap='viridis', alpha=0.7, edgecolor='k', linewidth=0.15)
        ax.set_title(f'Learned Metric Embedding\n({title_suffix})')
        
        ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
        
        # Set consistent axis limits and view angle for all frames
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_zlim(z_min, z_max)
        ax.set_box_aspect([x_range, y_range, z_range])
        ax.view_init(elev=20, azim=45)
        
        plt.tight_layout()
        
        # Convert to image
        buf = BytesIO()
        plt.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        frames.append(imageio.imread(buf))
        plt.close(fig)
    
    # Save animation
    if save_figures:
        save_dir = os.path.join("visualisations", selected_folder)
        os.makedirs(save_dir, exist_ok=True)
        output_path = os.path.join(save_dir, output_filename)
    else:
        output_path = output_filename
    
    imageio.mimsave(output_path, frames, fps=fps, loop=0)
    print(f"\nAnimation saved to: {output_path}")
    print(f"Total frames: {len(frames)}, Duration: {len(frames)/fps:.1f}s")

# Create embedding animation
create_embedding_animation('metric_evolution_embedding.gif', fps=2)
'''

In [None]:
# NOTE: this cell is disabled. Its code is wrapped in a string literal; remove the quotes
# to enable it. It uses `vertices_3d` and `faces` from the mesh cell and, when
# `include_curvature_plots` is True, `R_at_vertices` from the embedding loop. That loop
# reassigns `R_at_vertices` per model, so it holds the values for the last model processed.
'''
# Plot Reference Round Sphere Only
# Run this cell after running the embedding cell above to plot the reference sphere separately.
save_reference_figures = False  # Set to True to save figures

# Reference sphere plots always save to visualisations/ directly (not a subfolder)
if save_reference_figures:
    save_dir = "visualisations"
    os.makedirs(save_dir, exist_ok=True)

# Plot 1: Reference sphere mesh
fig = plt.figure(figsize=(8, 7))
ax = fig.add_subplot(111, projection='3d')
ax.plot_trisurf(vertices_3d[:, 0], vertices_3d[:, 1], vertices_3d[:, 2],
                triangles=faces, cmap='Blues', alpha=0.7, edgecolor='k', linewidth=0.15)
ax.set_title('Reference Sphere (Round Metric)')
ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
ax.set_box_aspect([1,1,1])
plt.tight_layout()
if save_reference_figures:
    fig.savefig(os.path.join(save_dir, "reference_sphere.pdf"), bbox_inches='tight')
plt.show()

# Plot 2: Reference sphere coloured by learned R (only if curvature plots enabled)
if include_curvature_plots:
    fig2 = plt.figure(figsize=(8, 7))
    ax = fig2.add_subplot(111, projection='3d')
    sc = ax.scatter(vertices_3d[:, 0], vertices_3d[:, 1], vertices_3d[:, 2],
                    c=R_at_vertices, cmap='coolwarm', s=30, edgecolor='k', linewidth=0.1)
    ax.plot_trisurf(vertices_3d[:, 0], vertices_3d[:, 1], vertices_3d[:, 2],
                    triangles=faces, alpha=0.3, color='gray', edgecolor='none')
    ax.set_title('Round Sphere '
    '(coloured by R)')
    ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
    ax.set_box_aspect([1,1,1])
    plt.colorbar(sc, ax=ax, shrink=0.5, label='Scalar Curvature R')
    plt.tight_layout()
    if save_reference_figures:
        fig2.savefig(os.path.join(save_dir, "reference_sphere_curvature.pdf"), bbox_inches='tight')
    plt.show()
    '''