In [1]:
# imports
import matplotlib.pyplot as plt
import numpy as np
import torch
%matplotlib widget


In [2]:
from utils.visualisation import plot_contour, plot_contour_3d, plot_sphere, map_to_sphere_np
from utils.energy import compute_geodesic

In [3]:
def metric_riemann(x):
    """Return the 2x2 metric tensor evaluated at a single point.

    Args:
        x: A point with exactly two components, unpacked as ``(theta, phi)``.

    Returns:
        A ``torch.Tensor`` of shape ``(2, 2)``. In its current placeholder
        form this is the identity (Euclidean) metric regardless of ``x``.
    """
    # Unpacking also rejects inputs that do not have exactly two components.
    theta, phi = x
    ### CHANGE THE EUCLIDEAN METRIC WITH THE RIEMANNIAN METRIC
    first_row = [1, 0]  # Euclidean
    second_row = [0, 1]  # Euclidean
    return torch.Tensor([first_row, second_row])

def norm_riemann(vv, x, metric_fn=metric_riemann):
    """Compute the norm of each vector in ``vv`` under a position-dependent metric.

    ``metric_fn`` is evaluated once per entry of ``x``; the resulting matrices
    are stacked and contracted with ``vv`` on both sides, i.e. for every batch
    index ``b`` the value ``sqrt(vv[b]^T G[b] vv[b])`` is returned.

    Args:
        vv: Batch of vectors, indexed as ``(b, i)`` by the contraction.
        x: Iterable of points, one per vector, each passed to ``metric_fn``.
        metric_fn: Callable mapping one point to its metric matrix.

    Returns:
        A 1-D tensor with one norm per batch entry.
    """
    per_point = [metric_fn(point) for point in x]
    g = torch.stack(per_point)  # (n, 2, 2)
    squared = torch.einsum("bi,bij,bj->b", vv, g, vv)
    return torch.sqrt(squared)


In [None]:
# Visualisation in the 2D plane (theta, phi) Theta: [-pi, pi], phi: [-np.pi/2, np.pi/2]
# Number of samples per curve; shared by the geodesic solver and the
# straight-line baseline so both curves have the same resolution.
N_POINTS = 100

# Array of (start, end) pairs, each point given as (theta, phi).
endpoints = np.array(
    [[[-3 * np.pi / 8, -3 * np.pi / 8], [3 * np.pi / 8, 3 * np.pi / 8]]]
)

print("Computing geodesics riemann...")
# Keep one result per endpoint pair instead of overwriting a single variable,
# so every pair is plotted when more than one is supplied.
curves_riemann, lines_straight = [], []
for pair_start, pair_end in endpoints:
    geodesic = compute_geodesic(pair_start, pair_end, norm_riemann, n_points=N_POINTS)
    curves_riemann.append(np.stack([c.detach().numpy() for c in geodesic]))
    # straight line for comparison
    lines_straight.append(np.linspace(pair_start, pair_end, N_POINTS))

# The 3D cell further down reads `start`, `end`, `curver` and `liner`; bind
# them explicitly to the last pair rather than relying on loop variables
# that happen to survive the loop.
start, end = endpoints[-1]
curver, liner = curves_riemann[-1], lines_straight[-1]

print("Plotting in 2d...")
fig = plt.figure(figsize=plt.figaspect(0.5))
ax = fig.add_subplot(1, 2, 1)
xlim, ylim = [-np.pi / 2, np.pi / 2], [-np.pi / 2, np.pi / 2]
# indicatrices
ax = plot_contour(ax, norm_riemann, xlim, ylim, grid_pts=[8, 8], color="purple")
for (pair_start, pair_end), geodesic_pts, line_pts in zip(
    endpoints, curves_riemann, lines_straight
):
    # start and end points
    ax.scatter(*pair_start, color="orange")
    ax.scatter(*pair_end, color="orange")
    # geodesic and straight line
    ax.plot(geodesic_pts[:, 0], geodesic_pts[:, 1], color="purple", zorder=10)
    ax.plot(line_pts[:, 0], line_pts[:, 1], color="black", zorder=10)
ax.set_title("Geodesic with the pullback metric")
# First coordinate is theta, second is phi (same order as in metric_riemann).
ax.set_xlabel("theta")
ax.set_ylabel("phi")
ax.set_xlim(xlim)
ax.set_ylim(ylim)
plt.show()

Computing geodesics riemann...
Iteration 0, energy: 0.11215460300445557
Iteration 200, energy: 0.11215460300445557
Converged at iteration 201
Plotting in 2d...


In [None]:
# map points and geodesics to the sphere in 3D


def plot_geodesic_3d(ax, curve, color="r"):
    """Draw a curve on a 3D axis after mapping its points with ``map_to_sphere_np``.

    Args:
        ax: Axis object whose ``plot`` method accepts x, y and z coordinates.
        curve: Iterable of points; each one is passed to ``map_to_sphere_np``.
        color: Line colour forwarded to ``ax.plot``.

    Returns:
        The same ``ax`` that was passed in.
    """
    embedded = np.stack(list(map(map_to_sphere_np, curve)))
    xs, ys, zs = embedded[:, 0], embedded[:, 1], embedded[:, 2]
    ax.plot(xs, ys, zs, color=color)
    return ax


print("Plotting in 3d...")
fig = plt.figure(figsize=plt.figaspect(0.5))
ax = fig.add_subplot(1, 2, 1, projection="3d")
plot_sphere(ax)

# indicatrices
ax = plot_contour_3d(
    ax, norm_riemann, xlim, ylim, grid_pts=[8, 8], color="purple", alpha=0.5
)

# start and end points (defined in the 2D cell), mapped with map_to_sphere_np
for endpoint in (start, end):
    ax.scatter(*map_to_sphere_np(endpoint), color="orange")

# geodesic (purple) and straight line (black)
for path_pts, path_color in ((curver, "purple"), (liner, "black")):
    plot_geodesic_3d(ax, path_pts, color=path_color)

# remove axis and grid
ax.set_axis_off()
ax.set_box_aspect((1, 1, 1))
