NB! The Stochman package is required.
Install it with: pip install stochman

The latent space of the AE is topologically a $ 2 $-dimensional torus $\mathcal{T}^2$, i.e., it can be considered as a periodic box $[-\pi, \pi]^2$. We define a Riemannian metric on the latent space as the pull-back of the Euclidean metric in the output space $\mathbb{R}^D$ by the decoder function $\Psi$ of the AE:
\begin{equation}
    g = \nabla \Psi^* \nabla \Psi \ ,
\end{equation}

Let $( M, g )$ be a Riemannian manifold.

In this notebook we consider two different ways of finding geodesics:

1) Geodesic shooting.

Given a point $ p\in M $ and a tangent vector $v\in T_p M$, the geodesic starting at the point $p$ with initial velocity $v$ is the curve $\gamma : [0,1] \to M $ such that:
\begin{equation}
%\label{eq:geodesic_eq}
\begin{aligned}
    \gamma(0) &= p \ , \\
    \gamma'(0) &= v \ , \\
    \nabla_{\dot{\gamma}} \dot{\gamma} &= 0 \ .
\end{aligned}
\end{equation}
where $ \nabla $ is the Levi-Civita connection associated with $ g $.

In local coordinates $ (x^1, x^2, \ldots, x^n) $, the geodesic equation is:
\begin{align}
    \frac{d^2 x^i}{dt^2} + \Gamma^i_{jk} \frac{dx^j}{dt} \frac{dx^k}{dt} = 0 \ ,
\end{align}
where $ \Gamma^i_{jk} $ are the Christoffel symbols of $ g $, and summation over the repeated indices $ j, k $ is implied.

2) Geodesic boundary value problem (b.v.p.).

Given points $ p, q \in M $, find a curve $\gamma : [0,1] \to M $ such that:
\begin{equation}
%\label{eq:geodesic_eq}
\begin{aligned}
    \gamma(0) &= p \ , \\
    \gamma(1) &= q \ , \\
    \nabla_{\dot{\gamma}} \dot{\gamma} &= 0 \ .
\end{aligned}
\end{equation}
where $ \nabla $ is the Levi-Civita connection associated with $ g $.

The length functional $ L $ for a curve $ \gamma $ is given by:
\begin{align*}
    L[\gamma] = \int_0^1 \sqrt{g_{\gamma(t)}(\dot{\gamma}(t), \dot{\gamma}(t))} \, dt
\end{align*}

The energy functional $E$ for a curve $\gamma$ is given by:
\begin{align*}
E[\gamma] = \int_0^1 g_{\gamma(t)}(\dot{\gamma}(t), \dot{\gamma}(t)) \, dt
\end{align*}

Geodesics are the critical points of the energy functional $ E $. Curves minimizing $ E $ are geodesics traversed at constant speed, and they also minimize the length functional $ L $ (the converse holds after reparametrization to constant speed). In the Stochman package, geodesics connecting two points are found as minimizers of the energy functional. Technically, they are approximated by cubic splines, obtained by solving an optimization problem over the spline coefficients.



# Geodesic shooting in the torus latent space on a local chart

In [None]:
from tqdm.notebook import tqdm
import torch
import ricci_regularization
import matplotlib.pyplot as plt
import matplotlib

# Path to the JSON config of the trained torus autoencoder experiment.
# (Plain string: the path contains no interpolated values, so no f-string.)
experiment_json = '../../experiments/MNIST01_torus_AEexp7.json'
# Load the tuned network and its data loaders; the returned dict is read via
# the keys "tuned_neural_network" and "test_loader".
mydict = ricci_regularization.get_dataloaders_tuned_nn(Path_experiment_json=experiment_json)

In [None]:
# Pull out the trained autoencoder and the test data loader from the experiment dict.
torus_ae, test_loader = mydict["tuned_neural_network"], mydict["test_loader"]

In [None]:
D = 784  # flattened input size (784 = 28 * 28, presumably MNIST images)
torus_ae.cpu()
colorlist = []
enc_list = []
feature_space_encoding_list = []
input_dataset_list = []
recon_dataset_list = []
# NOTE(review): the model is not switched to eval mode here; if it contains
# dropout / batch-norm layers, consider torus_ae.eval() -- confirm.
# Inference only: disabling autograd avoids keeping the computation graph of
# every test batch alive inside the lists accumulated below.
with torch.no_grad():
    for (data, labels) in tqdm(test_loader, position=0):
        flat_data = data.view(-1, D)  # encoders take flattened inputs
        input_dataset_list.append(data)
        recon_dataset_list.append(torus_ae(data)[0])
        feature_space_encoding_list.append(torus_ae.encoder_torus(flat_data))
        enc_list.append(torus_ae.encoder_to_lifting(flat_data))
        colorlist.append(labels)

In [None]:
# Merge the per-batch lists into single tensors covering the whole test set.
input_dataset = torch.cat(input_dataset_list)
recon_dataset = torch.cat(recon_dataset_list)
feature_space_encoding = torch.cat(feature_space_encoding_list)
encoded_points = torch.cat(enc_list)
# Gradient-free copies used for plotting: latent coordinates and labels.
encoded_points_no_grad = encoded_points.detach()
color_array = torch.cat(colorlist).detach()

In [None]:
# Latent encodings of the test set, colored by their labels.
label_cmap = ricci_regularization.discrete_cmap(2, "jet")
plt.scatter(
    encoded_points_no_grad[:, 0],
    encoded_points_no_grad[:, 1],
    c=color_array,
    cmap=label_cmap,
)
plt.show()

In [None]:
def geod(x, dxdt):
    """Right-hand side of the geodesic equation as a first-order system.

    Parameters
    ----------
    x : torch.Tensor
        Current point in latent coordinates (1-D tensor).
    dxdt : torch.Tensor
        Current velocity at ``x`` (same shape as ``x``).

    Returns
    -------
    (dudt, dvdt) : tuple of torch.Tensor
        ``dudt`` is the velocity ``dxdt`` itself; ``dvdt`` is the acceleration
        ``-Gamma^l_{ij} v^i v^j`` built from the Christoffel symbols returned by
        ``ricci_regularization.Ch_jacfwd`` for ``torus_ae.decoder_torus`` at ``x``
        (presumably those of the decoder's pull-back metric).
    """
    dudt = dxdt
    # Assumes Ch_at_u is indexed [l, i, j] = Gamma^l_{ij} -- TODO confirm.
    Ch_at_u = ricci_regularization.Ch_jacfwd(x, function=torus_ae.decoder_torus, device=torch.device("cpu"))
    # Contract both lower indices with the velocity. The latent dimension is
    # taken from the inputs instead of being hard-coded to 2.
    dvdt = -torch.einsum("lij,i,j->l", Ch_at_u, dxdt, dxdt)
    return dudt, dvdt

def rungekutta1(f, initial_point, initial_speed, t, args=()):
    """Integrate ``x'' = F(x, x')`` with the first-order Runge-Kutta (forward Euler) scheme.

    Parameters
    ----------
    f : callable
        ``f(x, dxdt, *args) -> (dudt, dvdt)``: derivative of position and of velocity.
    initial_point, initial_speed : torch.Tensor
        1-D initial position and velocity.
    t : sequence of times
        Time grid; ``len(t)`` samples are produced.
    args : tuple
        Extra positional arguments forwarded to ``f``.

    Returns
    -------
    (positions, speeds) : tuple of torch.Tensor
        Arrays of shape ``(len(t), len(initial_point))`` and
        ``(len(t), len(initial_speed))`` holding the trajectory.
    """
    num_steps = len(t)
    positions = torch.zeros(num_steps, len(initial_point))
    speeds = torch.zeros(num_steps, len(initial_speed))
    positions[0] = initial_point
    speeds[0] = initial_speed
    for step, (t_now, t_next) in enumerate(zip(t[:-1], t[1:])):
        dt = t_next - t_now
        dpos, dspeed = f(positions[step], speeds[step], *args)
        positions[step + 1] = positions[step] + dt * dpos
        speeds[step + 1] = speeds[step] + dt * dspeed
    return positions, speeds

# Example

In [None]:
# Initial condition for the shooting example: a point in the latent chart and
# an initial velocity along the first coordinate axis.
starting_point, tangent_vector = torch.tensor([-2.1, 0.0]), torch.tensor([1.0, -0.0])

In [None]:
# Evaluate the geodesic vector field once at the initial condition (sanity check).
geod(x=starting_point, dxdt=tangent_vector)

In [None]:
num_approximation_points = 101  # number of time samples: resolution of the approximation
max_parameter_value = 3  # final integration time: how far to go
time_array = torch.linspace(0, max_parameter_value, num_approximation_points)

# Shoot the geodesic forward from (starting_point, tangent_vector).
points, velocities = rungekutta1(
    geod,
    initial_point=starting_point,
    initial_speed=tangent_vector,
    t=time_array,
)
points_no_grad = points.detach()
end_point, end_speed = points_no_grad[-1], velocities[-1].detach()

# Shoot back from the end point with reversed velocity; an accurate
# integration should retrace the forward curve.
reverse_points, velocities = rungekutta1(
    geod, initial_point=end_point, initial_speed=-end_speed, t=time_array
)
reverse_points_no_grad = reverse_points.detach()

## Geodesics

In [None]:
# Forward geodesic (green) and the geodesic shot back from its end point (orange).
torch.set_printoptions(precision=2)  # keeps the tensors printed in the title short
plt.title(
    f"Geodesic and its reverse: \nstart_point:{starting_point},speed:{tangent_vector}, "
    f"\nend_point:{end_point},end_speed:{end_speed}"
)
for curve_points, curve_color in ((points_no_grad, "green"), (reverse_points_no_grad, "orange")):
    plt.plot(curve_points[:, 0], curve_points[:, 1], c=curve_color)

plt.show()

## Geodesics and data in the latent space

In [None]:
# Forward (green) and reverse (orange) geodesics, with the encoded test data
# drawn on the background.
torch.set_printoptions(precision=2)  # keeps the tensors printed in the title short
plt.title(
    f"Geodesic and its reverse: \nstart_point:{starting_point},speed:{tangent_vector}, "
    f"\nend_point:{end_point},end_speed:{end_speed}"
)
for curve_points, curve_color in ((points_no_grad, "green"), (reverse_points_no_grad, "orange")):
    plt.plot(curve_points[:, 0], curve_points[:, 1], c=curve_color)
plt.scatter(
    encoded_points_no_grad[:, 0],
    encoded_points_no_grad[:, 1],
    c=color_array,
    cmap=ricci_regularization.discrete_cmap(2, "jet"),
)
plt.show()

# Shooting several geodesics

In [None]:
def geod_vect(x, dxdt):
    """Batched right-hand side of the geodesic equation as a first-order system.

    Parameters
    ----------
    x : torch.Tensor
        Points in latent coordinates, shape ``(n, d)``.
    dxdt : torch.Tensor
        Velocities at those points, shape ``(n, d)``.

    Returns
    -------
    (dudt, dvdt) : tuple of torch.Tensor
        ``dudt`` is ``dxdt`` itself; ``dvdt[k, l] = -Gamma^l_{ij}(x_k) v_k^i v_k^j``
        with the Christoffel symbols returned by
        ``ricci_regularization.Ch_jacfwd_vmap`` for ``torus_ae.decoder_torus``.
    """
    dudt = dxdt
    # Assumes Ch_at_u is indexed [n, l, i, j] = Gamma^l_{ij} at point n -- TODO confirm.
    Ch_at_u = ricci_regularization.Ch_jacfwd_vmap(x, function=torus_ae.decoder_torus, device=torch.device("cpu"))
    # Per-point contraction of both lower indices with the velocity. The latent
    # dimension is taken from the inputs instead of being hard-coded to 2.
    dvdt = -torch.einsum("nlij,ni,nj->nl", Ch_at_u, dxdt, dxdt)
    return dudt, dvdt

def rungekutta_vect(f, initial_point_array, initial_speed_array, t, args=()):
    """Forward-Euler (first-order Runge-Kutta) integration of a batch of second-order ODEs.

    Parameters
    ----------
    f : callable
        ``f(x, dxdt, *args) -> (dudt, dvdt)`` evaluated on the whole batch at once.
    initial_point_array, initial_speed_array : torch.Tensor
        Initial positions and velocities, e.g. ``(num_geodesics, dimension)``.
    t : sequence of times
        Time grid; ``len(t)`` samples are produced.
    args : tuple
        Extra positional arguments forwarded to ``f``.

    Returns
    -------
    (pos, vel) : tuple of torch.Tensor
        Trajectories with a leading time axis, e.g. positions of shape
        ``[num_grid_points, num_geodesics, dimension]``.
    """
    steps = len(t)
    pos = torch.zeros(steps, *initial_point_array.shape)
    vel = torch.zeros(steps, *initial_speed_array.shape)
    pos[0], vel[0] = initial_point_array, initial_speed_array
    for k in range(1, steps):
        h = t[k] - t[k - 1]
        vel_term, acc_term = f(pos[k - 1], vel[k - 1], *args)
        pos[k] = pos[k - 1] + h * vel_term
        vel[k] = vel[k - 1] + h * acc_term
    return pos, vel

In [None]:
from torch.nn.functional import normalize

num_approximation_points = 101  # number of time samples: resolution of the approximation
max_parameter_value = 3  # final integration time: how far to go
time_array = torch.linspace(0, max_parameter_value, num_approximation_points)

num_geodesics = 100

# Every geodesic is shot from the same point.
starting_points = torch.tensor([-2., 0.]).repeat(num_geodesics, 1)
maxtangent = 2  # max slope of geodesics
# Initial directions (1, k), with slopes k evenly spread over [-maxtangent, maxtangent].
slopes = torch.linspace(-maxtangent, maxtangent, num_geodesics)
starting_speeds = torch.stack((torch.ones(num_geodesics), slopes), dim=1)
starting_speeds = normalize(starting_speeds)  # give every initial speed unit Euclidean norm

# Resulting trajectories are indexed [time sample, geodesic, coordinate].
geodesics2plot, _ = rungekutta_vect(
    f=geod_vect,
    initial_point_array=starting_points,
    initial_speed_array=starting_speeds,
    t=time_array,
)
geodesics2plot = geodesics2plot.detach()

In [None]:
# Scalar curvature at every sample of every shot geodesic, evaluated on the
# flattened list of points.
curvature_flat = ricci_regularization.Sc_jacfwd_vmap(
    geodesics2plot.reshape(-1, 2),
    function=torus_ae.decoder_torus,
    device=torch.device("cpu"),
)
# Restore the [time sample, geodesic] layout of geodesics2plot.
scalar_curvature_on_geodesics = curvature_flat.reshape(num_approximation_points, num_geodesics).detach()

In [None]:
# One color normalization shared by all geodesics. A separate SymLogNorm per
# scatter call would autoscale to each curve's own value range, making colors
# incomparable across curves, and plt.colorbar() (which uses the last scatter)
# would then describe only the last geodesic drawn.
finite_curvature = scalar_curvature_on_geodesics[torch.isfinite(scalar_curvature_on_geodesics)]
curvature_norm = matplotlib.colors.SymLogNorm(
    linthresh=1e-2,
    vmin=float(finite_curvature.min()),
    vmax=float(finite_curvature.max()),
)
# Encoded test data on the background.
plt.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=color_array, cmap=ricci_regularization.discrete_cmap(2, "jet"))
for i in range(num_geodesics):
    # geodesics2plot is indexed [time sample, geodesic, coordinate].
    plt.scatter(geodesics2plot[:, i, 0], geodesics2plot[:, i, 1], c=scalar_curvature_on_geodesics[:, i], cmap="viridis", norm=curvature_norm)
    plt.plot(geodesics2plot[:, i, 0], geodesics2plot[:, i, 1], c="black")
plt.colorbar(label="scalar curvature along geodesics")
plt.show()

# Geodesic bvp with Stochman

In [None]:
import torch
from stochman.manifold import EmbeddedManifold

In [None]:
class Autoencoder(EmbeddedManifold):
    """Latent space of the torus autoencoder seen as a Stochman embedded manifold.

    Stochman finds geodesics by minimizing the curve energy measured in the
    embedding (output) space, i.e. directly on the decoded curve. The pull-back
    metric therefore does not have to be assembled during the optimization,
    which keeps the geodesic solver fast.
    """

    def embed(self, c, jacobian = False):
        """Decode latent points with ``torus_ae.decoder_torus``.

        Parameters
        ----------
        c : torch.Tensor
            Latent points; the last dimension holds the latent coordinates.
        jacobian : bool, default False
            If True, also return the Jacobian of the decoder at every point.

        Returns
        -------
        torch.Tensor or (torch.Tensor, torch.Tensor)
            The decoded points and, when ``jacobian`` is True, the Jacobian of
            shape ``(*c.shape[:-1], D, d)`` with ``d = c.shape[-1]``.
        """
        embedded = torus_ae.decoder_torus(c)
        if not jacobian:
            return embedded
        # Stochman's EmbeddedManifold.metric unpacks (embedded, J) from
        # embed(..., jacobian=True), so the flag must not be silently ignored.
        from torch.func import jacfwd, vmap

        latent_dim = c.shape[-1]
        flat_points = c.reshape(-1, latent_dim)
        # Assumes decoder_torus accepts a single unbatched latent point, as the
        # Christoffel-symbol helpers in this notebook rely on -- TODO confirm.
        jac = vmap(jacfwd(torus_ae.decoder_torus))(flat_points)
        jac = jac.reshape(*c.shape[:-1], -1, latent_dim)
        return embedded, jac

In [None]:
# Latent manifold whose geometry is induced by the decoder.
model = Autoencoder()
torch.manual_seed(0)
# Endpoints of the boundary value problem in the latent chart.
p0, p1 = torch.tensor([-2., 0.]), torch.tensor([2., 0.])
# Solve the b.v.p.; the returned curve c is evaluated as c(t) with t a torch.tensor.
c, _ = model.connecting_geodesic(p0, p1)

In [None]:
t = torch.linspace(start=0.0, end=1.0, steps=100)  # curve parameter values in [0, 1]

In [None]:
# Sample the b.v.p. geodesic at the parameter values t; gradients are not needed for plotting.
points_on_geodesic = c(t)
points_on_geodesic = points_on_geodesic.detach()

In [None]:
plt.title("Geodesic bvp")
# Encoded test data on the background, the b.v.p. geodesic on top.
data_cmap = ricci_regularization.discrete_cmap(2, "jet")
plt.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=color_array, cmap=data_cmap)
plt.plot(points_on_geodesic[:, 0], points_on_geodesic[:, 1], c="orange")
plt.show()

# Geodesic bvp for several geodesics

In [None]:
num_geodesics = 10

x_left = -2.0
x_right = 2.0
y_left = 0.
y_right = -2.

y_size = 2  # max shift of geodesics
# Each geodesic i starts at (x_left, y_left + s_i) and ends at (x_right, y_right + s_i),
# with shifts s_i evenly spread over [0, y_size].
shifts = torch.linspace(0, y_size, num_geodesics)
starting_points = torch.stack((torch.full((num_geodesics,), x_left), y_left + shifts), dim=1)
end_points = torch.stack((torch.full((num_geodesics,), x_right), y_right + shifts), dim=1)

In [None]:
# Solve all boundary value problems in one batch.
c, success = model.connecting_geodesic(starting_points, end_points)
# The success flag may be a Python bool or a (possibly per-geodesic) tensor,
# depending on the Stochman version; `.item()` alone fails on a bool or on a
# multi-element tensor, so reduce it to a single bool first.
print("Success:", torch.as_tensor(success).all().item())

In [None]:
num_approximation_points = 101  # samples taken along each b.v.p. geodesic
t = torch.linspace(0, 1, steps=num_approximation_points)

In [None]:
# Sample every b.v.p. geodesic at the parameter values t (presumably indexed
# [geodesic, sample, coordinate], as the plotting code reads it).
geodesics2plot_bvp = c(t)
geodesics2plot_bvp = geodesics2plot_bvp.detach()

In [None]:
# Scalar curvature at every sample of every b.v.p. geodesic, evaluated on the
# flattened list of points.
curvature_flat_bvp = ricci_regularization.Sc_jacfwd_vmap(
    geodesics2plot_bvp.reshape(-1, 2),
    function=torus_ae.decoder_torus,
    device=torch.device("cpu"),
)
# Restore the [geodesic, sample] layout.
scalar_curvature_on_geodesics_bvp = curvature_flat_bvp.reshape(num_geodesics, num_approximation_points).detach()

In [None]:
# One color normalization shared by all geodesics. A separate SymLogNorm per
# scatter call would autoscale to each curve's own value range, making colors
# incomparable across curves, and plt.colorbar() (which uses the last scatter)
# would then describe only the last geodesic drawn.
finite_curvature_bvp = scalar_curvature_on_geodesics_bvp[torch.isfinite(scalar_curvature_on_geodesics_bvp)]
curvature_norm_bvp = matplotlib.colors.SymLogNorm(
    linthresh=1e-2,
    vmin=float(finite_curvature_bvp.min()),
    vmax=float(finite_curvature_bvp.max()),
)
# Encoded test data on the background.
plt.scatter(encoded_points_no_grad[:, 0], encoded_points_no_grad[:, 1], c=color_array, cmap=ricci_regularization.discrete_cmap(2, "jet"))
for i in range(num_geodesics):
    # geodesics2plot_bvp is indexed [geodesic, sample, coordinate].
    plt.scatter(geodesics2plot_bvp[i, :, 0], geodesics2plot_bvp[i, :, 1], c=scalar_curvature_on_geodesics_bvp[i, :], cmap="viridis", norm=curvature_norm_bvp)
    plt.plot(geodesics2plot_bvp[i, :, 0], geodesics2plot_bvp[i, :, 1], c="orange")
plt.colorbar(label="scalar curvature along geodesics")
plt.show()
