This notebook consists of:

0) Imports. Loading AE weights. Plot of the encoded latent space and a demo computing a single geodesic.
1) Grid of geodesics in latent space using a Schauder basis for approximation

The geodesics are computed by solving a geodesic boundary value problem (b.v.p.) with the help of an $L^2$ approximation.

Let $(M, g)$ be a geodesically complete Riemannian manifold. Geodesics connecting two points are found as minimizers of the energy functional $E$. The energy functional $E$ for a curve $\gamma : [0,1] \to M$ is given by:
\begin{align*}
E[\gamma] = \int_0^1 g_{\gamma(t)}(\dot{\gamma}(t), \dot{\gamma}(t)) \, dt 
\end{align*}

The problem of finding a geodesic between two points $p, q \in M$ can be formulated as:
\begin{equation}
%\label{eq:geodesic_via_energy}
\begin{aligned}
\gamma &= \arg\min_{\gamma} E[\gamma] \,, \\
\text{subject to:} \quad
\gamma(0) &= p \,, \\
\gamma(1) &= q \,. \\
\end{aligned}
\end{equation}

Technically, each coordinate of the curve is approximated by a linear combination of Schauder basis functions, and the coefficients of this combination are obtained by solving an optimization problem.

The latent space of the AE is topologically a $2$-dimensional torus $\mathcal{T}^2$, i.e., it can be considered as a periodic box $[-\pi, \pi]^2$. We define a Riemannian metric on the latent space as the pull-back of the Euclidean metric of the output space $\mathbb{R}^D$ by the decoder function $\Psi$ of the AE, obtaining $g = (\nabla \Psi)^{*} \nabla \Psi$, where $\nabla \Psi$ is the Jacobian of the decoder and $(\cdot)^{*}$ denotes its transpose.


# 0. Imports. Loading AE weights.

In [None]:
import torch
import numpy as np
import ricci_regularization
import matplotlib.pyplot as plt
import yaml, os
import math

In [None]:
# When True, the cells below create the output directory and write every figure to disk as a PDF.
violent_saving = True

# The pretrained AE is identified by its setting number; its YAML config lives in ../experiments/.
setting_number = 1
pretrained_AE_setting_name = f'MNIST_Setting_{setting_number}'
Path_AE_config = f'../experiments/{pretrained_AE_setting_name}_config.yaml'

# NOTE(review): yaml.FullLoader is used on a local config file; prefer yaml.safe_load
# if this config could ever come from an untrusted source.
with open(Path_AE_config, 'r') as config_file:
    yaml_config = yaml.load(config_file, Loader=yaml.FullLoader)

In [None]:
# Load data loaders based on the YAML configuration.
# The returned mapping is stored as `dataloaders`; naming it `dict` (as before) shadowed
# the Python builtin for the rest of the kernel session.
dataloaders = ricci_regularization.DataLoaders.get_dataloaders(
    dataset_config=yaml_config["dataset"],
    data_loader_config=yaml_config["data_loader_settings"],
    dtype=torch.float32
)
train_loader = dataloaders["train_loader"]
test_loader = dataloaders["test_loader"]
# .get() yields None if get_dataloaders does not return a 'test_dataset' entry.
test_dataset = dataloaders.get("test_dataset")

print("Data loaders created successfully.")

# Load the pretrained (tuned) autoencoder described by the config, plus the path of its weights.
torus_ae, Path_ae_weights = ricci_regularization.DataLoaders.get_tuned_nn(config=yaml_config)

print("AE weights loaded successfully.")
print("AE weights loaded from", Path_ae_weights)

In [None]:
experiment_name = yaml_config["experiment"]["name"]

# All figures of this notebook are written to ../experiments/<experiment name>.
Path_pictures = "../experiments/" + experiment_name
if violent_saving:
    if os.path.exists(Path_pictures):
        print(f"Directory already exists: {Path_pictures}")
    else:
        # os.makedirs also creates missing parent directories, where os.mkdir
        # would raise FileNotFoundError.
        os.makedirs(Path_pictures, exist_ok=True)
        print(f"Created directory: {Path_pictures}")

dataset_name = yaml_config["dataset"]["name"]
# D is the dimension of the dataset (the AE's input dimension)
D = yaml_config["architecture"]["input_dim"]

# k is the number of classes; selected_labels lists the labels present in the data.
# dataset_name is compared as a string (comparing it to a list would never match).
if dataset_name == "MNIST01":
    selected_labels = yaml_config["dataset"]["selected_labels"]
    k = len(selected_labels)
elif dataset_name == "MNIST":
    selected_labels = np.arange(10)
    k = 10
elif dataset_name == "Synthetic":
    k = yaml_config["dataset"]["k"]

print("Experiment name:", experiment_name)
print("Plots will be saved at:", Path_pictures)

In [None]:
# Encoding for the scatter plot runs on the CPU.
torus_ae.cpu()

# Scatter plot of the latent codes that ricci_regularization.point_plot produces
# for batch 0 of the test loader (latent coordinates are not rescaled).
fig = ricci_regularization.point_plot(
    encoder=torus_ae.encoder_to_lifting,
    data_loader=test_loader,
    batch_idx=0,
    config=yaml_config,
    show_title=True,
    colormap='jet',
    normalize_to_unit_square=False,
    s=40,
    draw_grid=False,
    figsize=(9, 9),
)
fig.show()

# Saving is best-effort: a failure is reported but does not stop the notebook.
if violent_saving:
    latent_plot_path = f"{Path_pictures}/latent_space_point_plot_{experiment_name}.pdf"
    try:
        fig.savefig(latent_plot_path, bbox_inches="tight", format="pdf")
    except Exception as save_error:
        print(f"Could not save figure via Matplotlib: {save_error}")


In [None]:
# Demo: a single geodesic connecting two points, obtained by solving the geodesic
# boundary value problem with a Schauder-basis parametrization (energy minimization).

# The demo optimization runs on the CPU.
optimization_device = "cpu"

# Endpoints of the boundary value problem, chosen inside [-pi, pi]^2.
p0 = torch.tensor([0.0, -2.0])
p1 = torch.tensor([0.0, 3.0])
torus_ae.to(optimization_device)

# Solver hyper-parameters (also reused by the grid computation below).
n_max = 7         # depth of the Schauder basis
step_count = 100  # number of interpolation points on each geodesic
num_epochs = 500
geodesic_solver = ricci_regularization.Schauder.NumericalGeodesics(n_max, step_count)
optimizer_info = {
    "name": "Adam",        # optimizer name, given as a string
    "args": {"lr": 0.01},  # optimizer keyword arguments (e.g. "betas" could be added here)
}

lin_curve, geod_curve = geodesic_solver.computeGeodesicInterpolation(
    generator=torus_ae.decoder_torus,
    optimizer_info=optimizer_info,
    m1=p0,
    m2=p1,
    epochs=num_epochs,
    display_info="geodesic optimization",
    device=optimization_device,
)

# Plot the optimized geodesic together with its two endpoints.
demo_fig, demo_ax = plt.subplots(figsize=(7, 7))
demo_ax.plot(geod_curve[:, 0], geod_curve[:, 1], c="blue", linewidth=2, label="Geodesic")
demo_ax.scatter([p0[0]], [p0[1]], c="green", marker="o", s=80, label="start p0")
demo_ax.scatter([p1[0]], [p1[1]], c="red", marker="*", s=120, label="end p1")
demo_ax.set_xlim(-math.pi, math.pi)
demo_ax.set_ylim(-math.pi, math.pi)
demo_ax.legend(loc="upper left")
if violent_saving:
    demo_fig.savefig(f"{Path_pictures}/geodesic_bvp_single_{experiment_name}.pdf", bbox_inches="tight", format="pdf")
plt.show()


# 1. Grid of geodesics via Schauder

In [None]:
# --- Grid parameters ---
N = 10  # number of vertical grid geodesics (points sampled on the bottom/top borders)
K = 10  # number of horizontal grid geodesics (points sampled on the left/right borders)
side_size = 4  # side length of the square [-side_size/2, side_size/2]^2 covered by the grid
# NOTE: geodesic_solver is not rebuilt in this cell, so this value only matters
# if the solver is re-created with it.
step_count = 100

# Use the GPU when available; a hard-coded "cuda" would crash on CPU-only machines.
optimization_device = "cuda" if torch.cuda.is_available() else "cpu"

# Move model to optimization device
torus_ae.to(optimization_device)

# --- 1. Compute border geodesics (bounding box) ---
# Corner points of the square
half_side = side_size / 2
bottom_left  = torch.tensor([-half_side, -half_side])
bottom_right = torch.tensor([ half_side, -half_side])
top_left     = torch.tensor([-half_side,  half_side])
top_right    = torch.tensor([ half_side,  half_side])

# (start corner, end corner) for each border: bottom/top run left -> right,
# left/right run bottom -> top.
# NOTE(review): corners stay on the CPU and no `device=` is passed; presumably the
# batch solver moves them to the generator's device - TODO confirm.
border_endpoints = {
    "bottom": (bottom_left, bottom_right),
    "top": (top_left, top_right),
    "left": (bottom_left, top_left),
    "right": (bottom_right, top_right),
}

border_geodesics = {}
for border_name, (corner_start, corner_end) in border_endpoints.items():
    _, border_geodesics[border_name] = geodesic_solver.computeGeodesicInterpolationBatch(
        generator=torus_ae.decoder_torus,
        optimizer_info=optimizer_info,
        m1_batch=corner_start,
        m2_batch=corner_end,
        epochs=num_epochs,
        display_info=f"{border_name} border",
    )

In [None]:
def sample_curve(curve, num_points):
    """
    Sample points along a discretized curve at evenly spaced fractional indices,
    using linear interpolation between neighbouring curve points.

    Parameters:
    - curve (torch.Tensor or array-like): shape (N, d), a curve given by N points in
      d dimensions. Non-tensor input is converted with `torch.as_tensor`.
    - num_points (int): Number of points to sample along the curve.

    Returns:
    - sampled (torch.Tensor): shape (num_points, d), on the same device as `curve`.
      The first and last samples coincide with the first and last curve points.

    Raises:
    - ValueError: if `curve` has no points but `num_points > 0`.

    How it works:
    1. `idxs` holds `num_points` evenly spaced fractional indices from 0 to len(curve)-1.
    2. `idxs_floor` / `idxs_ceil` are the floor / ceil of each index (they coincide
       when the index is an integer).
    3. `alpha` is the fractional part of each index.
    4. sampled = (1 - alpha) * curve[idxs_floor] + alpha * curve[idxs_ceil]

    Notes:
    - The spacing is uniform in the point index, not in arc length; it matches arc
      length only when the input points are themselves equally spaced.
    - Works for curves in any dimension d.
    """
    curve = torch.as_tensor(curve)
    if len(curve) == 0 and num_points > 0:
        raise ValueError("sample_curve: cannot sample points from an empty curve")
    # Create the indices on the curve's device; CPU weights combined with a GPU
    # curve would raise a device-mismatch error.
    idxs = torch.linspace(0, len(curve) - 1, num_points, device=curve.device)
    idxs_floor = idxs.floor().long()
    idxs_ceil = idxs.ceil().long()
    alpha = (idxs - idxs_floor).unsqueeze(1)  # (num_points, 1), broadcasts over d
    sampled = (1 - alpha) * curve[idxs_floor] + alpha * curve[idxs_ceil]
    return sampled

# Endpoints of the interior grid lines: N evenly spaced points on the bottom/top
# borders (vertical lines) and K on the left/right borders (horizontal lines).
bottom_points, top_points = (sample_curve(border_geodesics[side], N) for side in ("bottom", "top"))
left_points, right_points = (sample_curve(border_geodesics[side], K) for side in ("left", "right"))

In [None]:
# --- 3./4. Interior grid geodesics ---
# Each interior line solves a geodesic b.v.p. between matching sample points on
# opposite borders. All of them share the generator, optimizer and epoch budget.
grid_solver_kwargs = {
    "generator": torus_ae.decoder_torus,
    "optimizer_info": optimizer_info,
    "epochs": num_epochs,
}

# Vertical geodesics (bottom -> top); bottom_points / top_points are expected to have
# shape (N, 2), giving vertical_geodesics of shape (N, step_count, 2).
_, vertical_geodesics = geodesic_solver.computeGeodesicInterpolationBatch(
    m1_batch=bottom_points,
    m2_batch=top_points,
    display_info="vertical geodesics",
    **grid_solver_kwargs,
)

# Horizontal geodesics (left -> right); left_points / right_points are expected to have
# shape (K, 2), giving horizontal_geodesics of shape (K, step_count, 2).
_, horizontal_geodesics = geodesic_solver.computeGeodesicInterpolationBatch(
    m1_batch=left_points,
    m2_batch=right_points,
    display_info="horizontal geodesics",
    **grid_solver_kwargs,
)

In [None]:
# Final figure: latent codes of test batch 0, overlaid with the geodesic grid.
# The encoding used for the background scatter runs on the CPU.
torus_ae.to("cpu")

# background latent scatter
fig = ricci_regularization.point_plot(
    encoder=torus_ae.encoder_to_lifting,
    data_loader=test_loader,
    batch_idx=0,
    config=yaml_config,
    show_title=False,
    colormap='jet',
    normalize_to_unit_square=False,
    s=6,
    draw_grid=False,
    figsize=(9, 9),
)
ax = fig.axes[0]


def _as_plot_array(curve):
    """Return `curve` as a NumPy array that matplotlib can plot.

    Tensors are detached and copied to the CPU first, so curves still living on the
    optimization device or carrying autograd history can be plotted as well.
    """
    if isinstance(curve, torch.Tensor):
        return curve.detach().cpu().numpy()
    return np.asarray(curve)


# Border geodesics (bounding box): thicker lines drawn above the interior grid.
for border_name in ("bottom", "top", "left", "right"):
    border_curve = _as_plot_array(border_geodesics[border_name])
    ax.plot(border_curve[:, 0], border_curve[:, 1], c="black", linewidth=2.0, alpha=0.8, zorder=10)

# Interior grid: vertical geodesics first, then horizontal ones, in the same style.
for geodesic_family in (vertical_geodesics, horizontal_geodesics):
    for curve in geodesic_family:
        curve = _as_plot_array(curve)
        ax.plot(curve[:, 0], curve[:, 1], c="black", linewidth=1.2, alpha=0.8, zorder=5)

ax.set_xlim(-math.pi, math.pi)
ax.set_ylim(-math.pi, math.pi)
ax.set_aspect('equal', adjustable='box')
fig.tight_layout()

if violent_saving:
    fig.savefig(f"{Path_pictures}/grid_geodesics_Schauder_{experiment_name}.pdf", bbox_inches="tight", format="pdf")

plt.show()