In [None]:
from typing import Callable

import keras
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from network.global_conformal_model import GlobalConformalModel
from data.prescribers import build_prescriber
from data.samplers import StereoSampler
from geometry.integration import integrate_gauss_legendre, integrate_monte_carlo

In [None]:
# Path to the trained checkpoint to inspect -- edit this to look at a different run.
model_checkpoint_path = "checkpoints/run_11/final_model.keras"

try:
    # Plain load: succeeds when Keras can resolve GlobalConformalModel by itself.
    model = keras.models.load_model(model_checkpoint_path)
except TypeError:
    # Keras raises TypeError when it cannot locate a custom class by name;
    # retry with the class handed over explicitly.
    model = keras.models.load_model(
        model_checkpoint_path,
        custom_objects={"GlobalConformalModel": GlobalConformalModel},
    )

# Training configuration stored with the model.
cfg = model.get_config()["cfg"]

# Build the curvature prescriber. `prescribed_R` is either a bare kind string, or a
# mapping holding the "kind" plus the "kwargs" forwarded to build_prescriber.
prescribed_R_cfg = cfg["data"]["prescribed_R"]
if isinstance(prescribed_R_cfg, str):
    prescriber_name = prescribed_R_cfg
    prescriber_kwargs = {}
else:
    prescriber_name = prescribed_R_cfg["kind"]
    prescriber_kwargs = prescribed_R_cfg["kwargs"]
prescriber = build_prescriber(prescriber_name, **prescriber_kwargs)

In [None]:
def plot_samples_patch(num_patches: int, num_samples: int, radial_offset: float):
    """Draw test samples with StereoSampler and scatter them in each patch's chart coordinates.

    One 2D subplot is drawn per patch, showing coordinates (x_1, x_2) on [-1, 1] x [-1, 1].

    Args:
        num_patches: number of patches, forwarded to StereoSampler.
        num_samples: number of samples, forwarded to StereoSampler.
        radial_offset: forwarded to StereoSampler.

    Returns:
        The sampler's output pair (samples_patch, samples_xyz). samples_patch is indexed as
        [sample, patch, coord] here; samples_xyz is presumably [sample, patch, xyz] -- TODO confirm.
    """
    sampler = StereoSampler(num_patches, num_samples, radial_offset)
    samples_patch, samples_xyz = sampler()

    # squeeze=False always yields a (1, num_patches) axes array, even for a single patch
    fig, axes = plt.subplots(1, num_patches, figsize=(5 * num_patches, 5), squeeze=False)

    for patch_idx, ax in enumerate(axes[0]):
        ax.scatter(samples_patch[:, patch_idx, 0], samples_patch[:, patch_idx, 1], alpha=0.1)
        ax.set_title(f"Patch {patch_idx + 1}")
        ax.set_xlabel(r"$x_1$")
        ax.set_ylabel(r"$x_2$")
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_aspect("equal")

    fig.tight_layout()
    # show explicitly, like the other plotting helpers in this notebook
    plt.show()

    return samples_patch, samples_xyz

test_samples_patch, test_samples_xyz = plot_samples_patch(
    num_patches=cfg["data"]["num_patches"],
    num_samples=cfg["data"]["num_samples"],
    radial_offset=cfg["data"]["radial_offset"],
)

In [None]:
def plot_samples_xyz(samples_xyz: tf.Tensor):
    """Show the sampled points in 3D Cartesian space, one colour (and legend entry) per patch.

    Args:
        samples_xyz: Cartesian sample positions indexed as [sample, patch, xyz].
    """
    _, n_patches, _ = samples_xyz.shape
    plural = "es" if n_patches > 1 else ""

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection="3d")

    for p in range(n_patches):
        xs, ys, zs = (samples_xyz[:, p, k] for k in range(3))
        ax.scatter(xs, ys, zs, label=f"Patch {p + 1}", s=1, alpha=0.3)

    ax.set(xlabel="X", ylabel="Y", zlabel="Z")
    ax.set_title(f"Test Samples on S² ({n_patches} patch{plural})")
    ax.set_box_aspect((1, 1, 1))
    # a legend only makes sense when there is more than one patch colour
    if n_patches > 1:
        ax.legend()

    plt.tight_layout()
    plt.show()

plot_samples_xyz(samples_xyz=test_samples_xyz)

In [None]:
# Forward pass: evaluate the trained model on the patch-coordinate samples in inference mode.
data_dict = {model.patch_coords_key: test_samples_patch}
output_dict = model(data_dict, training=False)

# Unpack the metric, the conformal factor u and its Laplace-Beltrami value as NumPy arrays.
g, u, delta_u = (
    output_dict[output_key].numpy()
    for output_key in (
        model.conformal_metric_key,
        model.conformal_factor_key,
        model.laplace_beltrami_key,
    )
)

# Predicted scalar curvature of g = e^{2u} g_0 via the 2D conformal-change formula
#   R_g = e^{-2u} (R_0 - 2 Δ_0 u),  with R_0 = 2 for the unit round sphere.
# NOTE(review): assumes delta_u is Δ_0 u w.r.t. the round metric -- confirm.
R_g = (2.0 - 2.0 * delta_u) * np.exp(-2.0 * u)

# Prescribed (target) curvature at the same points, evaluated from their Cartesian positions.
target_R = prescriber(test_samples_xyz).numpy()

In [None]:
def plot_features_patch(
    samples_patch: tf.Tensor,
    g: np.ndarray,
    R_g: np.ndarray,
    target_R: np.ndarray,
    u: np.ndarray,
    *,
    metric_components: tuple[int, int] = (0, 0),
    elevation_angle: float = 30.0,
    azimuthal_angle: float = 45.0,
):
    """Plot model outputs over each patch's chart coordinates (x_1, x_2), one column per patch.

    Row 1 shows the metric component g_ij. Row 2 shows the predicted scalar curvature R(g),
    with the target R overlaid as a translucent red surface. Row 3 shows the conformal
    factor u. In every panel the plotted value is used both as the height and as the colour.

    Args:
        samples_patch: chart coordinates indexed as [sample, patch, coord]; coords 0 and 1 are plotted.
        g: metric indexed as [sample, patch, i, j]. Must support ``.min()``/``.max()``
            (e.g. a NumPy array).
        R_g: predicted curvature indexed as [sample, patch, 0].
        target_R: prescribed curvature indexed as [sample, patch, 0].
        u: conformal factor indexed as [sample, patch, 0].
        metric_components: (i, j) of the metric component shown in row 1.
        elevation_angle: camera elevation passed to ``view_init`` for every panel.
        azimuthal_angle: camera azimuth passed to ``view_init`` for every panel.
    """
    _, num_patches, _ = samples_patch.shape
    metric_i, metric_j = metric_components
    # Colour-scale only the plotted component. The full tensor's range also covers the
    # other components (e.g. zero off-diagonals) and would wash out the colour map.
    g_ij = g[..., metric_i, metric_j]

    def _padded_range(values) -> tuple[float, float]:
        # widen by 1e-6 so vmin < vmax even when the values are constant
        return float(values.min()) - 1e-6, float(values.max()) + 1e-6

    g_range = _padded_range(g_ij)
    # predicted R is coloured on the target's scale so the two can be compared directly
    R_range = _padded_range(target_R)
    u_range = _padded_range(u)

    # squeeze=False always yields a (3, num_patches) axes array, even for a single patch
    fig, axes = plt.subplots(
        3, num_patches, figsize=(5 * num_patches, 15), subplot_kw={"projection": "3d"}, squeeze=False
    )

    def _scatter_panel(ax, x1, x2, values, value_range, title, **scatter_kwargs):
        """Scatter `values` at (x1, x2) as height and colour; add title, labels, colourbar, camera."""
        vmin, vmax = value_range
        scatter = ax.scatter(x1, x2, values, c=values, cmap="viridis", vmin=vmin, vmax=vmax, **scatter_kwargs)
        ax.set_title(title)
        ax.set_xlabel(r"$x_1$")
        ax.set_ylabel(r"$x_2$")
        fig.colorbar(scatter, ax=ax, shrink=0.6)
        ax.view_init(elev=elevation_angle, azim=azimuthal_angle)

    for patch_idx in range(num_patches):
        x1 = samples_patch[:, patch_idx, 0]
        x2 = samples_patch[:, patch_idx, 1]
        patch_label = f"Patch {patch_idx + 1}"

        # row 1: metric component g_ij
        _scatter_panel(
            axes[0, patch_idx], x1, x2, g_ij[:, patch_idx], g_range,
            rf"$g_{{{metric_i}{metric_j}}}$ ({patch_label})",
        )

        # row 2: predicted scalar curvature R(g) against the prescribed target R
        _scatter_panel(
            axes[1, patch_idx], x1, x2, R_g[:, patch_idx, 0], R_range,
            f"Scalar Curvature R(g) ({patch_label})",
            label="Predicted R(g)",
        )
        axes[1, patch_idx].plot_trisurf(
            x1,
            x2,
            target_R[:, patch_idx, 0],
            color="red",
            alpha=0.3,
            label="Target R",
        )

        # row 3: conformal factor u
        _scatter_panel(
            axes[2, patch_idx], x1, x2, u[:, patch_idx, 0], u_range,
            f"Conformal Factor u ({patch_label})",
        )

    plt.tight_layout()
    plt.show()

plot_features_patch(samples_patch=test_samples_patch, g=g, R_g=R_g, target_R=target_R, u=u)

In [None]:
def plot_features_xyz(
    samples_xyz: tf.Tensor,
    g: np.ndarray,
    R_g: np.ndarray,
    u: np.ndarray,
    target_R: np.ndarray,
    prescriber_name: str,
    *,
    metric_components: tuple[int, int] = (0, 0),
    elevation_angle: float = 30.0,
    azimuthal_angle: float = 45.0,
):
    """Plot model outputs at the samples' 3D Cartesian positions as a 2x2 grid of scatter plots.

    The panels show the metric component g_ij, the conformal factor u, the predicted scalar
    curvature R(g) and the prescribed target R. Points from all patches share one panel and
    are coloured by value.

    Args:
        samples_xyz: Cartesian positions indexed as [sample, patch, xyz].
        g: metric indexed as [sample, patch, i, j]; must support ``.min()``/``.max()`` (e.g. NumPy).
        R_g: predicted curvature indexed as [sample, patch, 0].
        u: conformal factor indexed as [sample, patch, 0].
        target_R: prescribed curvature indexed as [sample, patch, 0].
        prescriber_name: prescriber kind, shown in every panel title.
        metric_components: (i, j) of the metric component shown in the first panel.
        elevation_angle: camera elevation, applied to the two curvature panels only.
        azimuthal_angle: camera azimuth, applied to the two curvature panels only.
    """
    _, num_patches, _ = samples_xyz.shape
    metric_i, metric_j = metric_components
    # Colour-scale only the plotted component, not the whole metric tensor.
    g_ij = g[..., metric_i, metric_j]

    def _padded_range(values) -> tuple[float, float]:
        # widen by 1e-6 so vmin < vmax even when the values are constant
        return float(values.min()) - 1e-6, float(values.max()) + 1e-6

    # predicted and target R share the target's colour scale so they can be compared
    R_range = _padded_range(target_R)
    # (title, values indexed [sample, patch], colour range, apply camera angles?)
    panels = [
        (f"$g_{{{metric_i}{metric_j}}}$ (kind={prescriber_name})", g_ij, _padded_range(g_ij), False),
        (f"Conformal Factor u (kind={prescriber_name})", u[..., 0], _padded_range(u), False),
        (f"Scalar Curvature R(g) (kind={prescriber_name})", R_g[..., 0], R_range, True),
        (f"Target Curvature R (kind={prescriber_name})", target_R[..., 0], R_range, True),
    ]

    fig = plt.figure(figsize=(18, 12))
    for panel_idx, (title, values, (vmin, vmax), set_camera) in enumerate(panels, start=1):
        ax = fig.add_subplot(2, 2, panel_idx, projection="3d")
        for patch_idx in range(num_patches):
            sc = ax.scatter(
                samples_xyz[:, patch_idx, 0],
                samples_xyz[:, patch_idx, 1],
                samples_xyz[:, patch_idx, 2],
                c=values[:, patch_idx],
                cmap="viridis",
                s=10,
                vmin=vmin,
                vmax=vmax,
                depthshade=False,
            )
        ax.set_title(title)
        # all patches share vmin/vmax, so the last scatter's colourbar is valid for the panel
        fig.colorbar(sc, ax=ax, shrink=0.6)
        ax.set_box_aspect((1, 1, 1))
        if set_camera:
            ax.view_init(elev=elevation_angle, azim=azimuthal_angle)

    plt.tight_layout()
    plt.show()

plot_features_xyz(samples_xyz=test_samples_xyz, g=g, R_g=R_g, u=u, target_R=target_R, prescriber_name=prescriber_name)

In [None]:
def check_gauss_bonnet(u: Callable, prescriber: Callable):
    """Numerically check Gauss-Bonnet for g = e^{2u} g_0 on S^2 using the prescribed curvature.

    The check integrates (R / 2) * e^{2u} over the sphere. R comes from `prescriber`, not the
    model's predicted R_g, and R / 2 is the Gaussian curvature K. Gauss-Bonnet requires
    ∫_{S^2} K dA_g = 2π χ(S^2) = 4π, so the relative error measures how consistent `u` is
    with the prescribed curvature. Gauss-Legendre and Monte-Carlo estimates are both printed.

    NOTE(review): the integrand below does not include sin(theta); presumably
    `integrate_gauss_legendre` / `integrate_monte_carlo` supply the round-sphere area
    element -- confirm.

    Args:
        u: callable mapping an (N, 3) tensor of Cartesian points to N conformal-factor values.
        prescriber: callable mapping an (N, 3) tensor of Cartesian points to N curvature values.

    Returns:
        (gl_estimate, mc_estimate) as Python floats.
    """
    def integrand(theta: tf.Tensor, phi: tf.Tensor):
        # u and prescriber take Cartesian coordinates, so map (theta, phi) onto the unit sphere.
        # Assumes theta/phi carry a trailing axis of size 1 (the final reshape relies on it).
        x = tf.sin(theta) * tf.cos(phi)
        y = tf.sin(theta) * tf.sin(phi)
        z = tf.cos(theta)
        xyz = tf.concat((x, y, z), axis=-1)
        grid_shape = tf.shape(xyz)[:-1]
        # flatten the evaluation grid into a batch of points
        xyz = tf.reshape(xyz, [-1, 3])
        # evaluate conformal factor u
        u_vals = tf.reshape(u(xyz), [-1])
        # evaluate prescribed curvature R
        R_vals = tf.reshape(prescriber(xyz), [-1])
        # Gaussian curvature K = R / 2, times the conformal area factor e^{2u}
        out = (R_vals / 2.0) * tf.exp(2.0 * u_vals)
        # restore the grid shape with a trailing axis of size 1
        return tf.reshape(out, tf.concat((grid_shape, [1]), axis=0))

    def _as_float(value) -> float:
        # the integrators may return a tf.Tensor or NumPy scalar; unwrap for clean printing
        return float(np.asarray(value).item())

    exact = 4.0 * np.pi  # 2π χ(S^2) with χ(S^2) = 2
    gl_estimate = _as_float(integrate_gauss_legendre(f=integrand))
    mc_estimate = _as_float(integrate_monte_carlo(f=integrand))

    gl_rel_error = abs(gl_estimate - exact) / exact
    mc_rel_error = abs(mc_estimate - exact) / exact

    print(f"[Gauss-Legendre] ∫_S^2 K dA_g ~= {gl_estimate:.6f} | Relative error: {gl_rel_error:.2%}")
    print(f"[Monte-Carlo] ∫_S^2 K dA_g ~= {mc_estimate:.6f} | Relative error: {mc_rel_error:.2%}")
    return gl_estimate, mc_estimate

gb_gl_estimate, gb_mc_estimate = check_gauss_bonnet(u=model.u_model, prescriber=prescriber)