# Optimizing scalar curvature computation via jacfwd

In [None]:
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import torch
import ricci_regularization
from torch.func import vmap, jacfwd
import functools

In [None]:
# Seed torch's global RNG before the model is built and any points are sampled.
torch.manual_seed(0)

Path_pictures = "../../experiments"
dtype = torch.float32
d = 2  # passed to the autoencoder as z_dim

# Autoencoder from the project package; sizes are given as keyword arguments.
torus_ae = ricci_regularization.Architectures.TorusAE(
    x_dim=784, h_dim1=512, h_dim2=256, z_dim=d, dtype=dtype
)

In [None]:
# The autoencoder's `decoder_torus` attribute; passed as `function` to every
# curvature routine in the cells below.
decoder = torus_ae.decoder_torus

# curvature computation breakdown

In [None]:
# Number of sample points for the breakdown cells.
N = 1000
# N points with 2 coordinates each, drawn uniformly from [0, 1).
points = torch.rand((N, 2))

In [None]:
# Library reference: Sc_jacfwd_vmap on the sampled points with the decoder as
# `function` (presumably the batched scalar curvature - going by the name).
ricci_regularization.Sc_jacfwd_vmap(points, function = decoder)

In [None]:
# Library reference: Ch_der_jacfwd_vmap on the same points (presumably the
# derivatives of the Christoffel symbols - going by the name).
ricci_regularization.Ch_der_jacfwd_vmap(points, function = decoder)

In [None]:
# Library reference: Ch_jacfwd_vmap on the same points (presumably the
# Christoffel symbols - going by the name).
ricci_regularization.Ch_jacfwd_vmap(points, function = decoder)

In [None]:
# Library reference: metric_jacfwd_vmap on the same points (presumably the
# metric induced by the decoder - going by the name).
ricci_regularization.metric_jacfwd_vmap(points, function = decoder)

# A faster way? getting rid of recursive hell

In [None]:
# inspiration: computing Ch and Ch_der simultaneously using the has_aux flag
def foo(x):
    """Toy example of the ``has_aux`` calling convention.

    Returns a pair ``(output, aux)``: the first entry is what ``jacfwd``
    differentiates, the second is handed back untouched. Here ``aux`` holds
    the same Christoffel tensor plus a dummy constant tensor.
    """
    # Same call as functools.partial(ricci_regularization.Ch_jacfwd, function=decoder)(x).
    christoffel = ricci_regularization.Ch_jacfwd(x, function=decoder)
    dummy = torch.tensor([5.])
    aux = (christoffel, dummy)
    return christoffel, aux

In [None]:
# jacfwd(..., has_aux=True) returns (Jacobian of foo's first output, foo's aux
# output); vmap applies that to each row of `points`.
jacobian_f, f_x = vmap(jacfwd(foo, has_aux=True))( points )

In [None]:
# Inspect the batched shapes: the Jacobian first, then the two aux entries
# returned by foo (the differentiated output itself and the dummy tensor).
print(jacobian_f.shape)
print(f_x[0].shape)
print(f_x[1].shape)

In [None]:
# real thing
def metric_jacfwd(u, function, latent_space_dim=2):
    """Return J^T J, where J is the forward-mode Jacobian of ``function`` at ``u``.

    ``u`` must be a single point (not a batch): it is reshaped to
    ``(-1, latent_space_dim)`` before being fed to ``function``, and the
    Jacobian is flattened to ``(-1, latent_space_dim)``, which is only a
    valid (outputs x latent coordinates) matrix for one point.
    The result has shape ``(latent_space_dim, latent_space_dim)``.
    """
    point = u.reshape(-1, latent_space_dim)
    # One row per output component, one column per latent coordinate.
    jacobian = torch.func.jacfwd(function)(point).reshape(-1, latent_space_dim)
    return jacobian.T @ jacobian

def aux_func_metric(x, function):
    """Return the metric at ``x`` twice, as ``(output, aux)``.

    Helper for ``jacfwd(..., has_aux=True)``: the first copy is the one that
    gets differentiated, the second is passed through as the auxiliary
    output, so the metric and its derivative come out of a single call.
    """
    metric = metric_jacfwd(x, function=function)
    return metric, metric

def Ch_g_g_inv_jacfwd(u, function, eps=0.0):
    """Christoffel symbols, metric and inverse metric at a single point.

    Not vectorized: assumes ``u`` is one 1-D latent point, so that the
    derivative of the metric is a 3-index tensor (the einsum patterns below
    require that). Wrap in ``vmap`` for a batch.

    Args:
        u: the point at which to evaluate.
        function: the map whose induced metric is used (forwarded to
            ``aux_func_metric``).
        eps: multiple of the identity added to the metric before inversion.

    Returns:
        ``(Ch, g, g_inv)`` where ``Ch[i, k, l]`` is built as
        ``0.5 * g_inv[i, m] * (dg[m, k, l] + dg[m, l, k] - dg[k, l, m])``.
    """
    metric_with_aux = functools.partial(aux_func_metric, function=function)
    # has_aux=True: a single call yields both the derivative dg and the metric g.
    # jacfwd appends the input dimension last, i.e. dg[a, b, c] = d g[a, b] / d u[c].
    dg, g = jacfwd(metric_with_aux, has_aux=True)(u)

    # Build the regulariser with the metric's own dtype and device so adding
    # it never changes the dtype of g (torch.eye defaults to float32).
    identity = torch.eye(g.shape[0], dtype=g.dtype, device=g.device)
    g_inv = torch.inverse(g + eps * identity)

    Ch = 0.5 * (
        torch.einsum('im,mkl->ikl', g_inv, dg)
        + torch.einsum('im,mlk->ikl', g_inv, dg)
        - torch.einsum('im,klm->ikl', g_inv, dg)
    )
    return Ch, g, g_inv

def aux_func(x, function, eps=0.0):
    """Return ``(Ch, (Ch, g, g_inv))`` for use with ``jacfwd(..., has_aux=True)``.

    The leading ``Ch`` is the output that gets differentiated; the trailing
    tuple is carried along unchanged as the auxiliary output.
    """
    outputs = Ch_g_g_inv_jacfwd(x, function=function, eps=eps)
    christoffel = outputs[0]
    return christoffel, outputs
# usage example (the aux output carries three tensors: Ch, g, g_inv):
#dCh, (Ch, g, g_inv) = vmap(jacfwd(functools.partial( aux_func, function=decoder, eps=0. ),
#                            has_aux=True))( points )

def Sc_jacfwd(u, function, eps=0.0):
    """Scalar curvature at a single point, via one forward-mode pass.

    Not vectorized: assumes ``u`` is one 1-D latent point; wrap in ``vmap``
    for a batch. ``eps`` is forwarded to the metric inversion.
    """
    christoffel_with_aux = functools.partial(aux_func, function=function, eps=eps)
    # One jacfwd call gives the derivatives of the Christoffel symbols and,
    # through the aux output, the symbols, the metric and its inverse.
    # jacfwd appends the input dimension last: dCh[i, k, l, m] = d Ch[i, k, l] / d u[m].
    dCh, (Ch, _g, g_inv) = jacfwd(christoffel_with_aux, has_aux=True)(u)

    # Riemann[i, j, k, l] = dCh[i, l, j, k] - dCh[i, k, j, l]
    #                       + Ch[i, k, p] Ch[p, l, j] - Ch[i, l, p] Ch[p, k, j]
    derivative_part = (
        torch.einsum("iljk->ijkl", dCh) - torch.einsum("ikjl->ijkl", dCh)
    )
    quadratic_part = (
        torch.einsum("ikp,plj->ijkl", Ch, Ch) - torch.einsum("ilp,pkj->ijkl", Ch, Ch)
    )
    Riemann = derivative_part + quadratic_part

    # Contract first and third indices, then take the trace against g_inv.
    Ricci = torch.einsum("cacb->ab", Riemann)
    return torch.einsum('ab,ab', g_inv, Ricci)

# vectorization: map Sc_jacfwd over the leading (batch) dimension of its
# positional argument; keyword arguments such as `function` and `eps` are
# passed through to every call unmapped.
Sc_jacfwd_vmap = torch.func.vmap(Sc_jacfwd)

In [None]:
# Library curvature loss evaluated on the current points with the decoder as
# `function`; the result is not stored.
ricci_regularization.curvature_loss_jacfwd(points, function=decoder)

In [None]:
# Larger sample for comparing the local and the library implementations.
N = 5000
# N points with 2 coordinates each, drawn uniformly from [0, 1).
points = torch.rand((N, 2))

In [None]:
# Batched scalar curvature from the notebook-local implementation. The
# vectorised function is bound to the name Sc_jacfwd_vmap in this notebook;
# the previous name Sc_jacfwd_new_vmap is never defined and raised NameError.
Sc_jacfwd_vmap(points, function=decoder)

In [None]:
# Library counterpart on the same (larger) batch, for comparison with the
# notebook-local version.
ricci_regularization.Sc_jacfwd_vmap(points, function = decoder)