# I. Imports, the pretrained decoder, and helper functions for grids and plotting (skip reading this)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
device = torch.device("cpu")
import torch
import torch.func as TF
from functorch import jacrev,jacfwd
import matplotlib.pyplot as plt
import timeit
import functools

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder mapping latent codes to 1x28x28 images in [0, 1].

    Args:
        encoded_space_dim: dimension of the latent space (input features).
        fc2_input_dim: accepted for interface compatibility; none of the
            layers below use it.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()
        # Fully connected stage: latent code -> 32 * 3 * 3 features.
        lin_layers = [
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(inplace=True),
        ]
        self.decoder_lin = nn.Sequential(*lin_layers)

        # Reshape the flat features into 32 channels of 3x3 maps.
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        # Transposed convolutions upsample 3x3 -> 7x7 -> 14x14 -> 28x28.
        # Layer order (and thus state_dict keys) must match the saved weights.
        conv_layers = [
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
        ]
        self.decoder_conv = nn.Sequential(*conv_layers)

    def forward(self, x):
        """Decode a batch of latent codes into sigmoid-activated images."""
        feature_maps = self.unflatten(self.decoder_lin(x))
        return torch.sigmoid(self.decoder_conv(feature_maps))
    
decoder = Decoder(encoded_space_dim=2, fc2_input_dim=128)

# Send to device
decoder.to(device)

# Load the parameters of the trained decoder (trained without a curvature term in the loss).
# map_location lets weights that were saved on a GPU load onto `device` (CPU here);
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state_dict needs.
PATH_dec = '../nn_weights/decoder_conv_autoenc.pt'
decoder.load_state_dict(torch.load(PATH_dec, map_location=device, weights_only=True))

# Switch to eval mode so the BatchNorm layers use their stored running statistics
decoder.eval()

In [None]:
def make_grid(numsteps, xshift=0.0, yshift=0.0):
    """Return a (numsteps**2, 2) tensor of (x, y) points on a square grid.

    Both axes span linspace(-1.5, 1.5, numsteps), shifted by xshift and
    yshift respectively. Points start at the bottom-left corner and x varies
    fastest: row r * numsteps + k holds (xs[k], ys[r]).
    """
    base = torch.linspace(-1.5, 1.5, steps=numsteps)
    xs = base + xshift
    ys = base + yshift
    # 'ij' indexing: first mesh index follows ys, second follows xs,
    # so flattening row-major makes x the fastest-varying coordinate.
    y_mesh, x_mesh = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack((x_mesh.reshape(-1), y_mesh.reshape(-1)), dim=1)

In [None]:
def draw_frob_norm_tensor_on_grid(plot_name, tensor_on_grid, numsteps=100, xshift=0.0, yshift=0.0):
    """Show a heatmap of the pointwise Frobenius norm of a tensor field on a grid.

    Args:
        plot_name: title of the plot.
        tensor_on_grid: tensor of shape (numsteps**2, d1, d2). Rows are assumed
            to follow make_grid's ordering (x varies fastest, then y) and
            make_grid's coordinate range [-1.5, 1.5] (+ shift).
        numsteps: number of grid points per axis (must be >= 3).
        xshift, yshift: the shifts that were passed to make_grid; only used
            to label the axes.

    The outermost ring of grid points is cropped before plotting.

    Returns:
        The matplotlib.pyplot module (kept for backward compatibility).

    Raises:
        ValueError: if numsteps < 3 (nothing is left after cropping).
    """
    if numsteps < 3:
        raise ValueError(f"numsteps must be at least 3, got {numsteps}")

    # Rows index y, columns index x.
    frob_norm = tensor_on_grid.norm(dim=(1, 2)).view(numsteps, numsteps)
    frob_norm = frob_norm[1:-1, 1:-1].detach().cpu()

    fig, ax = plt.subplots()
    im = ax.imshow(frob_norm, origin="lower")
    fig.colorbar(im)

    # After cropping, pixel k corresponds to grid point k + 1 of
    # linspace(-1.5, 1.5, numsteps), so the labels are offset by one grid step.
    step = 3.0 / (numsteps - 1)
    tick_fractions = np.linspace(0, 1, num=11)
    x_ticks = (frob_norm.shape[1] - 1) * tick_fractions  # columns -> x
    y_ticks = (frob_norm.shape[0] - 1) * tick_fractions  # rows -> y
    ax.set_xticks(x_ticks, labels=(-1.5 + (x_ticks + 1) * step + xshift).round(1))
    ax.set_yticks(y_ticks, labels=(-1.5 + (y_ticks + 1) * step + yshift).round(1))
    ax.set_xlabel("x coordinate")
    ax.set_ylabel("y coordinate")
    ax.axis('scaled')

    ax.set_title(plot_name)
    fig.tight_layout()
    plt.show()
    return plt

# II. Tensors computed with higher order derivatives using jacfwd

In [None]:
def metric_jacfwd(u, function = decoder, latent_space_dim=2):
    """Pull-back metric g = J^T J of `function` at the latent point `u`.

    Args:
        u: one latent point; it is reshaped to (-1, latent_space_dim).
            NOTE(review): with several points the Jacobians w.r.t. all of them
            would be stacked row-wise, so pass a single point (vmap does this).
        function: map to pull back; the default is the module-level `decoder`
            as bound when this function was defined.
        latent_space_dim: number of latent coordinates.

    Returns:
        (latent_space_dim, latent_space_dim) tensor J^T J, where J is the
        Jacobian of `function` with all output dimensions flattened into rows.
    """
    u = u.reshape(-1, latent_space_dim)
    # torch.func.jacfwd is the supported replacement for functorch.jacfwd.
    jac = TF.jacfwd(function)(u)
    jac = jac.reshape(-1, latent_space_dim)
    return jac.T @ jac

metric_jacfwd_vmap = TF.vmap(metric_jacfwd)

In [None]:
# The variable wrt which
# the derivative is computed is the last index
def metric_der_jacfwd(u, function = decoder):
    """Derivative of the pull-back metric: dg[i, j, k] = d g_ij / d u_k.

    Args:
        u: one latent point (see metric_jacfwd).
        function: map whose pull-back metric is differentiated; defaults to
            the module-level `decoder` bound at definition time.

    Returns:
        Tensor of derivatives with the differentiation index last.
    """
    metric = functools.partial(metric_jacfwd, function=function)
    # torch.func.jacfwd is the supported replacement for functorch.jacfwd.
    dg = TF.jacfwd(metric)(u)
    # squeeze() removes size-1 dimensions that jacfwd adds when u carries
    # them (e.g. u of shape (1, 2) gives (2, 2, 1, 2)).
    # NOTE(review): it would also drop genuine size-1 dims (latent dim 1).
    return dg.squeeze()

In [None]:
def Ch_jacfwd(u, function = decoder):
    """Christoffel symbols of the second kind of the pull-back metric at u.

    Returns Ch with Ch[i, k, l] = Gamma^i_{kl}
        = 1/2 * g^{im} (d_l g_{mk} + d_k g_{ml} - d_m g_{kl}),
    where dg[a, b, c] = d g_ab / d u_c comes from metric_der_jacfwd.
    The metric must be invertible at u.
    """
    g_inv = torch.inverse(metric_jacfwd(u, function))
    dg = metric_der_jacfwd(u, function)
    term_l = torch.einsum('im,mkl->ikl', g_inv, dg)  # g^{im} d_l g_{mk}
    term_k = torch.einsum('im,mlk->ikl', g_inv, dg)  # g^{im} d_k g_{ml}
    term_m = torch.einsum('im,klm->ikl', g_inv, dg)  # g^{im} d_m g_{kl}
    return 0.5 * (term_l + term_k - term_m)
Ch_jacfwd_vmap = TF.vmap(Ch_jacfwd)

In [None]:
def Ch_der_jacfwd(u, function = decoder):
    """Derivative of the Christoffel symbols at u.

    Returns dCh with dCh[i, k, l, m] = d Gamma^i_{kl} / d u_m (the
    differentiation index is last). Size-1 dimensions are squeezed away.
    """
    Ch = functools.partial(Ch_jacfwd, function=function)
    # torch.func.jacfwd is the supported replacement for functorch.jacfwd.
    dCh = TF.jacfwd(Ch)(u).squeeze()
    return dCh
Ch_der_jacfwd_vmap = TF.vmap(Ch_der_jacfwd)

In [None]:
# Riemann curvature tensor (3,1)
def Riem_jacfwd(u, function = decoder):
    """Riemann curvature tensor R^i_{jkl} of the pull-back metric at u.

    R^i_{jkl} = d_k Gamma^i_{lj} - d_l Gamma^i_{kj}
                + Gamma^i_{kp} Gamma^p_{lj} - Gamma^i_{lp} Gamma^p_{kj},
    returned as a tensor indexed [i, j, k, l].
    """
    Gamma = Ch_jacfwd(u, function)
    dGamma = Ch_der_jacfwd(u, function)  # dGamma[i, a, b, c] = d_c Gamma^i_{ab}

    # permute(0, 2, 3, 1): out[i, j, k, l] = dGamma[i, l, j, k] = d_k Gamma^i_{lj}
    # permute(0, 2, 1, 3): out[i, j, k, l] = dGamma[i, k, j, l] = d_l Gamma^i_{kj}
    derivative_part = dGamma.permute(0, 2, 3, 1) - dGamma.permute(0, 2, 1, 3)
    quadratic_part = (torch.einsum("ikp,plj->ijkl", Gamma, Gamma)
                      - torch.einsum("ilp,pkj->ijkl", Gamma, Gamma))
    return derivative_part + quadratic_part

In [None]:
def Ric_jacfwd(u, function = decoder):
    """Ricci tensor Ric_ab = R^c_{acb}: the Riemann tensor's upper index
    contracted with its third index."""
    return torch.einsum("cacb->ab", Riem_jacfwd(u, function))
Ric_jacfwd_vmap = TF.vmap(Ric_jacfwd)

In [None]:
# demo: Ricci tensor of the decoder's pull-back metric at 3 random latent points
Ric_jacfwd_vmap(torch.rand((3, 2)))

# III. Ground truth check

### Sphere

In [None]:
def my_fun_sphere(u, ambient_dim=784):
    """Map spherical coordinates onto the unit 2-sphere, zero-padded to ambient_dim.

    Args:
        u: tensor with two elements (it is flattened first); u[0] is the polar
            angle and u[1] the azimuthal angle.
        ambient_dim: length of the returned vector (default 784 = 28 * 28,
            matching the decoder's output size). Must be >= 3.

    Returns:
        1-D tensor (sin u0 cos u1, sin u0 sin u1, cos u0, 0, ..., 0) of length
        ambient_dim. Its pull-back metric is diag(1, sin^2 u0).

    Raises:
        ValueError: if ambient_dim < 3.
    """
    if ambient_dim < 3:
        raise ValueError(f"ambient_dim must be at least 3, got {ambient_dim}")
    u = u.flatten()
    theta, phi = u[0], u[1]
    xyz = torch.stack((torch.sin(theta) * torch.cos(phi),
                       torch.sin(theta) * torch.sin(phi),
                       torch.cos(theta)))
    # Zero padding up to the ambient dimension, matching u's dtype/device.
    padding = torch.zeros(ambient_dim - 3, dtype=xyz.dtype, device=xyz.device)
    return torch.cat((xyz, padding))

In [None]:
# Motivating demo: for the unit 2-sphere the Ricci tensor should equal the metric
torch.manual_seed(10)
test_batch = torch.rand(3, 2)
print("metric:\n",
      metric_jacfwd_vmap(test_batch, function=my_fun_sphere))
print("Ricci tensor:\n",
      Ric_jacfwd_vmap(test_batch, function=my_fun_sphere))

In [None]:
# Histogram of errors. This is done in order to verify the Ricci
# tensor computation: Ric = k*g with k = n-1 for the unit
# n-dimensional sphere S^n. Thus for n = 2, Ric = g.
torch.manual_seed(10)

test_batch = torch.rand(1000,2)
test_metric_array = metric_jacfwd_vmap(test_batch,
                                       function=my_fun_sphere)
test_Ric_array = Ric_jacfwd_vmap(test_batch,
                                       function=my_fun_sphere)

# here we check if g = Ric
absolute_error = (test_metric_array - test_Ric_array).norm(dim=(1,2))
relative_error = 100*absolute_error/(test_metric_array.norm(dim=(1,2)))

# Label the axes so the units (percent) are visible; show() makes the
# figure render outside inline notebook backends too.
plt.hist(relative_error.detach().cpu().numpy(), bins=10, density=False)
plt.xlabel("relative error ||g - Ric|| / ||g||, %")
plt.ylabel("number of points")
plt.title("Sphere: relative error between the metric and the Ricci tensor")
plt.show()

In [None]:
# Compare Frobenius-norm heatmaps of the metric and of the Ricci tensor.
# For the unit 2-sphere Ric = g, so the two pictures should coincide.
numsteps = 100
tgrid = make_grid(numsteps)
Ric_on_grid = Ric_jacfwd_vmap(tgrid, function=my_fun_sphere)
metric_on_grid = metric_jacfwd_vmap(tgrid, function=my_fun_sphere)

for title, field_on_grid in (('Frobenius norm of the metric', metric_on_grid),
                             ('Frobenius norm of the Ricci tensor', Ric_on_grid)):
    draw_frob_norm_tensor_on_grid(plot_name=title,
                                  tensor_on_grid=field_on_grid,
                                  numsteps=numsteps)

### Lobachevsky plane

In [None]:
# Partial embedding (for y > c) of the Lobachevsky plane into R^3
# (formally here it is R^ambient_dim, 784 by default, padded with zeros)
# ds^2 = 1/y^2(dx^2 + dy^2)
# http://www.antoinebourget.org/maths/2018/08/08/embedding-hyperbolic-plane.html
def my_fun_lobachevsky(u, c=0.01, ambient_dim=784):
    """Map a point (x, y) of the upper half-plane onto the pseudosphere.

    With t = arccosh(y / c), the first three output coordinates are
        (t - tanh t,  sech t * cos(x / c),  sech t * sin(x / c)).
    Their induced metric is exactly (dx^2 + dy^2) / y^2, i.e. constant
    curvature -1. Note that sech t = c / y because cosh t = y / c.

    Args:
        u: tensor with two elements (x, y); it is flattened first. Requires
            y > c; for y < c arccosh is NaN.
        c: scale of the map; the image is periodic in x with period 2*pi*c.
        ambient_dim: length of the returned vector (default 784). Must be >= 3.

    Returns:
        1-D tensor of length ambient_dim whose entries after the third are zero.

    Raises:
        ValueError: if ambient_dim < 3.
    """
    if ambient_dim < 3:
        raise ValueError(f"ambient_dim must be at least 3, got {ambient_dim}")
    u = u.flatten()
    x, y = u[0], u[1]
    t = torch.acosh(y / c)
    # The tractrix radius is sech(t) = c / y. Using 1/sinh(t) instead only
    # approximates the hyperbolic metric (dx^2 term becomes dx^2 / (y^2 - c^2)).
    sech_t = c / y
    x0 = t - torch.tanh(t)
    x1 = sech_t * torch.cos(x / c)
    x2 = sech_t * torch.sin(x / c)
    xyz = torch.stack((x0, x1, x2))
    padding = torch.zeros(ambient_dim - 3, dtype=xyz.dtype, device=xyz.device)
    return torch.cat((xyz, padding))
    

In [None]:
# Motivating demo: for the Lobachevsky plane the Ricci tensor should equal minus the metric
torch.manual_seed(10)
test_batch = torch.rand(3, 2)
print("metric:\n",
      metric_jacfwd_vmap(test_batch, function=my_fun_lobachevsky))
print("Ricci tensor:\n",
      Ric_jacfwd_vmap(test_batch, function=my_fun_lobachevsky))

In [None]:
# Histogram of errors. This is done in order to verify the Ricci
# tensor computation: Ric = k*g with k = -1 for the
# Lobachevsky plane. Thus Ric = -g.
torch.manual_seed(10)

test_batch = torch.rand(1000,2) + 0.2
# The shift keeps y well above c (0.01 by default):
# my_fun_lobachevsky is only defined for y > c.

test_metric_array = metric_jacfwd_vmap(test_batch,
                                       function=my_fun_lobachevsky)
test_Ric_array = Ric_jacfwd_vmap(test_batch,
                                       function=my_fun_lobachevsky)

# here we check if g = - Ric
absolute_error = (test_metric_array + test_Ric_array).norm(dim=(1,2))
relative_error = 100*absolute_error/(test_metric_array.norm(dim=(1,2)))

# Label the axes so the units (percent) are visible; show() makes the
# figure render outside inline notebook backends too.
plt.hist(relative_error.detach().cpu().numpy(), bins=10, density=False)
plt.xlabel("relative error ||g + Ric|| / ||g||, %")
plt.ylabel("number of points")
plt.title("Lobachevsky plane: relative error between the metric and -Ric")
plt.show()

In [None]:
# Compare Frobenius-norm heatmaps of the metric and of the Ricci tensor.
# For the Lobachevsky plane Ric = -g, so the two pictures should coincide.
# yshift=1.7 moves the grid to y in [0.2, 3.2], away from y <= c.
numsteps = 100
tgrid = make_grid(numsteps, xshift=0.0, yshift=1.7)

lobachevsky_metric_on_grid = metric_jacfwd_vmap(tgrid, function=my_fun_lobachevsky)
lobachevsky_Ric_on_grid = Ric_jacfwd_vmap(tgrid, function=my_fun_lobachevsky)

for title, field_on_grid in (
        ('Lobachevsky plane: Frobenius norm of the metric', lobachevsky_metric_on_grid),
        ('Lobachevsky plane: Frobenius norm of the Ricci tensor', lobachevsky_Ric_on_grid)):
    draw_frob_norm_tensor_on_grid(plot_name=title,
                                  tensor_on_grid=field_on_grid,
                                  numsteps=numsteps, xshift=0.0, yshift=1.7)


# IV. The Ricci tensor for the metric given by the pullback of the decoder

In [None]:
# Ricci tensor of the metric pulled back by the trained decoder, on the latent grid.
# NOTE: this took around 17 secs in the original run (hardware dependent).
numsteps = 100
grid = make_grid(numsteps)
Decoder_Ric_on_grid = Ric_jacfwd_vmap(grid, function=decoder)
draw_frob_norm_tensor_on_grid(
    plot_name='Latent space: Frobenius norm of the Ricci tensor',
    tensor_on_grid=Decoder_Ric_on_grid,
    numsteps=numsteps,
)