NB! this is an old notebook!

This notebook explores scalar curvature computation with `jacfwd`, PyTorch's forward-mode automatic differentiation Jacobian (see `torch.func.jacfwd` at https://pytorch.org/functorch/stable/generated/functorch.jacfwd.html).

The AE consists of the encoder $\Phi$ and the decoder $\Psi$.
The latent space of the AE is $R^d$. We define a Riemannian metric in a local chart of the latent space as the pull-back of the Euclidean metric of the output space $R^D$ by the decoder function $\Psi$ of the AE:
\begin{equation*}
    g = (\nabla \Psi)^{\top} \nabla \Psi \ , \qquad g_{ij} = \sum_{a=1}^{D} \partial_i \Psi^a \, \partial_j \Psi^a \ .
\end{equation*}

In detail this notebook consists of:

0) Loading the weights of the decoder $\Psi$ of a pre-trained convolutional AE (the decoder architecture is defined in this notebook).
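1) Auxiliary tensors involving higher-order derivatives are computed with jacfwd: the metric $g$ and its derivatives, the Christoffel symbols, the Riemann tensor $R^{i}_{jkl}$ and the Ricci tensor $R_{ij}$ (the scalar curvature is computed by the `ricci_regularization` package in Section VIII).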
2) The computations are checked against ground truth on Einstein metrics: the sphere and the Lobachevsky (hyperbolic) plane.
3) The computation time of the scalar curvature is measured.

# I. Imports and some functions for plotting (Skip reading this)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
device = torch.device("cpu")
import torch
import torch.func as TF
from functorch import jacrev,jacfwd
import matplotlib.pyplot as plt
import timeit
import functools

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder Psi mapping latent codes to 1x28x28 images.

    Layers: latent -> 128 -> 32*3*3 (fully connected, ReLU), unflatten to
    (32, 3, 3), then three transposed convolutions upsampling
    3x3 -> 7x7 -> 14x14 -> 28x28, followed by a sigmoid.

    Submodule names and their positions inside the ``nn.Sequential``
    containers determine the ``state_dict`` keys, so they must stay stable
    for saved weights to load.

    Args:
        encoded_space_dim: dimension of the latent space.
        fc2_input_dim: accepted but not used by this class.
    """

    def __init__(self, encoded_space_dim,fc2_input_dim):
        super().__init__()
        # Fully connected part: latent -> 128 -> 32 * 3 * 3 features.
        self.decoder_lin = nn.Sequential(
            nn.Linear(in_features=encoded_space_dim, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=3 * 3 * 32),
            nn.ReLU(inplace=True),
        )

        # (batch, 288) -> (batch, 32, 3, 3)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        # Transposed convolutions: 3x3 -> 7x7 -> 14x14 -> 28x28.
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3,
                               stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
        )

    def forward(self, x):
        """Decode latent codes ``x`` of shape (batch, encoded_space_dim) into images in (0, 1)."""
        features = self.unflatten(self.decoder_lin(x))
        return torch.sigmoid(self.decoder_conv(features))
    
decoder = Decoder(encoded_space_dim = 2,fc2_input_dim=128)

# Send to device
decoder.to(device)

# Load the parameters of the decoder (trained without curvature in the loss).
# map_location remaps tensors saved on another device (e.g. CUDA) onto
# `device`; without it, a GPU-saved checkpoint fails to load on a CPU machine.
PATH_dec = '../nn_weights/decoder_conv_autoenc.pt'
decoder.load_state_dict(torch.load(PATH_dec, map_location=device))

# Switch to eval mode (BatchNorm uses running statistics)
decoder.eval()

In [None]:
def make_grid(numsteps, xshift = 0.0, yshift = 0.0):
    """Return the points of a uniform numsteps x numsteps grid as a (numsteps**2, 2) tensor.

    Both coordinates span [-1.5, 1.5], shifted by ``xshift`` and ``yshift``
    respectively. Points start at the bottom-left corner and x increases
    fastest: row ``i * numsteps + j`` is ``(xs[j], ys[i])``.
    """
    xs = torch.linspace(-1.5, 1.5, steps=numsteps) + xshift
    ys = torch.linspace(-1.5, 1.5, steps=numsteps) + yshift
    # 'xy' indexing: grid_x[i, j] = xs[j], grid_y[i, j] = ys[i]
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing="xy")
    return torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)

In [None]:
def draw_frob_norm_tensor_on_grid(plot_name,tensor_on_grid, numsteps = 100,xshift = 0.0, yshift = 0.0):
    """Show a heatmap of the Frobenius norm of a tensor field sampled on a grid.

    Args:
        plot_name: title of the figure.
        tensor_on_grid: tensor of shape (numsteps**2, n, m), one matrix per
            grid point in the order produced by ``make_grid``.
        numsteps: linear size of the grid.
        xshift, yshift: shifts applied to the [-1.5, 1.5] tick labels.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    norms = tensor_on_grid.norm(dim=(1, 2)).view(numsteps, numsteps)
    # Drop the outermost ring of grid points before plotting.
    norms = norms[1:-1, 1:-1].detach()

    fig, ax = plt.subplots()
    image = ax.imshow(norms, origin="lower")
    ax.figure.colorbar(image)

    tick_fractions = np.linspace(0, 1, num=11)
    tick_values = np.linspace(-1.5, 1.5, num=11)
    # NOTE(review): the ticks span the cropped image but carry labels for the
    # full [-1.5, 1.5] range, so labels can be off by up to one grid step.
    ax.set_xticks((norms.shape[0] - 1) * tick_fractions,
                  labels=(tick_values + xshift).round(1))
    ax.set_yticks((norms.shape[1] - 1) * tick_fractions,
                  labels=(tick_values + yshift).round(1))
    plt.xlabel("x coordinate")
    plt.ylabel("y coordinate")
    plt.axis('scaled')

    ax.set_title(plot_name)
    fig.tight_layout()
    plt.show()
    return plt

# II. Tensors computed with higher order derivatives using jacfwd

In [None]:
def metric_jacfwd(u, function = decoder, latent_space_dim=None):
    """Pull-back metric g = J^T J of ``function`` at a single latent point ``u``.

    Args:
        u: one latent point, e.g. of shape (d,) or (1, d).
        function: map from the latent space to the output space
            (default: ``decoder``).
        latent_space_dim: latent dimension d. ``None`` (default) infers it
            from ``u.shape[-1]``. The Christoffel/Riemann helpers call this
            function without passing the argument, so inferring it lets
            them work for latent spaces of any dimension.

    Returns:
        Tensor of shape (d, d) with ``g[i, j] = sum_a dF^a/du_i * dF^a/du_j``.
    """
    if latent_space_dim is None:
        latent_space_dim = u.shape[-1]
    u = u.reshape(-1, latent_space_dim)
    # jacfwd returns shape (*output_shape, *input_shape). Collapse everything
    # except the trailing latent axis into rows, so column i holds dF/du_i.
    jac = jacfwd(function)(u).reshape(-1, latent_space_dim)
    metric = jac.T @ jac
    return metric

metric_jacfwd_vmap = TF.vmap(metric_jacfwd)

In [None]:
# The variable w.r.t. which the derivative is taken is the LAST index:
# dg[i, j, k] = d g_{ij} / d u^k
def metric_der_jacfwd (u, function = decoder):
    """First derivatives of the pull-back metric at a single latent point ``u``.

    Args:
        u: one latent point, of shape (d,) or (1, d).
        function: map whose Jacobian defines the metric (default: ``decoder``).

    Returns:
        Tensor of shape (d, d, d) with ``dg[i, j, k] = d g_{ij} / d u^k``.
    """
    d = u.shape[-1]
    metric = functools.partial(metric_jacfwd, function=function)
    # jacfwd returns shape (d, d, *u.shape). Reshape explicitly to (d, d, d)
    # rather than calling .squeeze(), which would also collapse the genuine
    # size-1 axes when d == 1.
    dg = jacfwd(metric)(u).reshape(d, d, d)
    return dg

$\Gamma^\rho_{\mu\nu} = \frac{1}{2} g^{\rho\sigma}\left(\partial_\mu g_{\sigma\nu} + \partial_\nu g_{\mu\sigma} - \partial_\sigma g_{\mu\nu}\right)$


In [None]:
def Ch_jacfwd (u, function = decoder):
    """Christoffel symbols of the second kind of the pull-back metric at ``u``.

    Implements
        Gamma^i_{kl} = 1/2 g^{im} (d_l g_{mk} + d_k g_{ml} - d_m g_{kl}),
    where ``dg[i, j, k] = d_k g_{ij}`` comes from ``metric_der_jacfwd``.

    Args:
        u: one latent point.
        function: map whose Jacobian defines the metric (default: ``decoder``).

    Returns:
        Tensor of shape (d, d, d) with ``Ch[i, k, l] = Gamma^i_{kl}``.
    """
    g = metric_jacfwd(u,function)
    dg = metric_der_jacfwd(u,function)
    d = g.shape[-1]
    # Twice the Christoffel symbols of the first kind, laid out as [m, k, l]:
    #   d_l g_{mk} + d_k g_{ml} - d_m g_{kl}
    Ch_first_kind = (dg
                     + torch.einsum('mlk->mkl', dg)
                     - torch.einsum('klm->mkl', dg))
    # Raise the index m by solving g X = Ch_first_kind instead of multiplying
    # by an explicit inverse. This is mathematically the same and numerically
    # better conditioned (see the torch.linalg.inv documentation).
    Ch = 0.5 * torch.linalg.solve(g, Ch_first_kind.reshape(d, d * d)).reshape(d, d, d)
    return Ch
Ch_jacfwd_vmap = TF.vmap(Ch_jacfwd)

In [None]:
def Ch_der_jacfwd (u, function = decoder):
    """First derivatives of the Christoffel symbols at a single latent point ``u``.

    Args:
        u: one latent point, of shape (d,) or (1, d).
        function: map whose Jacobian defines the metric (default: ``decoder``).

    Returns:
        Tensor of shape (d, d, d, d) with ``dCh[i, k, l, m] = d_m Gamma^i_{kl}``
        (derivative index last, as produced by ``jacfwd``).
    """
    d = u.shape[-1]
    Ch = functools.partial(Ch_jacfwd, function=function)
    # Explicit reshape instead of .squeeze(), which would also collapse
    # genuine size-1 axes when d == 1.
    dCh = jacfwd(Ch)(u).reshape(d, d, d, d)
    return dCh
Ch_der_jacfwd_vmap = TF.vmap(Ch_der_jacfwd)

$R^{\rho}_{\sigma\mu\nu} = \partial_{\mu}\Gamma^{\rho}_{\nu\sigma} - \partial_{\nu}\Gamma^{\rho}_{\mu\sigma} + \Gamma^{\rho}_{\mu\lambda}\Gamma^{\lambda}_{\nu\sigma} - \Gamma^{\rho}_{\nu\lambda}\Gamma^{\lambda}_{\mu\sigma}$


In [None]:
# Riemann curvature tensor of type (1,3)
def Riem_jacfwd(u, function = decoder):
    """Riemann curvature tensor of the pull-back metric at a single latent point.

    Implements
        R^r_{s m n} = d_m G^r_{n s} - d_n G^r_{m s} + G^r_{m p} G^p_{n s} - G^r_{n p} G^p_{m s}
    with the result indexed as ``Riem[i, j, k, l] = R^i_{j k l}``.
    """
    Gamma = Ch_jacfwd(u, function)
    dGamma = Ch_der_jacfwd(u, function)  # dGamma[i, k, l, m] = d_m Gamma^i_{kl}

    derivative_part = (torch.einsum("iljk->ijkl", dGamma)
                       - torch.einsum("ikjl->ijkl", dGamma))
    quadratic_part = (torch.einsum("ikp,plj->ijkl", Gamma, Gamma)
                      - torch.einsum("ilp,pkj->ijkl", Gamma, Gamma))
    return derivative_part + quadratic_part

$R_{\mu\nu} = R^{\rho}_{\mu\rho\nu}$

In [None]:
def Ric_jacfwd(u, function = decoder):
    """Ricci tensor R_{ab} = R^c_{a c b} at a single latent point.

    The first and third indices of the Riemann tensor are contracted.
    Returns a (d, d) tensor.
    """
    riemann_tensor = Riem_jacfwd(u, function)
    return torch.einsum("cacb->ab", riemann_tensor)
Ric_jacfwd_vmap = TF.vmap(Ric_jacfwd)

In [None]:
# Demo: Ricci tensors of the decoder pull-back metric at 3 random latent points
demo_points = torch.rand(3, 2)
Ric_jacfwd_vmap(demo_points)

# III. Ground truth check

### Sphere

In [None]:
# Metric-generating embedding of the unit sphere S^2 into R^3, zero-padded to
# 784 entries. Spherical coordinates u = (\theta, \phi) induce
# ds^2 = (d\theta)^2 + sin^2(\theta)*(d\phi)^2
def my_fun_sphere(u):
    """Map ``u = (theta, phi)`` to (sin t cos p, sin t sin p, cos t, 0, ..., 0).

    Returns a 1-D tensor of length 784 (3 coordinates followed by 781 zeros).
    """
    flat = u.flatten()
    theta = flat[0]
    phi = flat[1]
    sin_theta = torch.sin(theta)
    point = torch.stack((sin_theta * torch.cos(phi),
                         sin_theta * torch.sin(phi),
                         torch.cos(theta)))
    padding = torch.zeros(781)
    return torch.cat((point, padding))

In [None]:
# Exact metric of the unit sphere in coordinates (\theta, \phi). It is also the
# exact Ricci tensor, since Ric = g on S^2.
# ds^2 = (d\theta)^2 + sin^2(\theta)*(d\phi)^2
def metric_sphere(u):
    """Return diag(1, sin^2(theta)) as a float64 (2, 2) tensor for u = (theta, phi)."""
    theta = u.flatten()[0]
    e_theta = torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.float64)
    e_phi = torch.tensor([[0.0, 0.0], [0.0, 1.0]], dtype=torch.float64)
    return e_theta + e_phi * torch.sin(theta) ** 2
metric_sphere_vmap = TF.vmap(metric_sphere)

In [None]:
# Motivating demo: jacfwd metric and Ricci tensor of the unit sphere at two
# random points, compared with the exact metric (on S^2, Ric = g).
torch.manual_seed(10)
test_batch = torch.rand(2,2,dtype=float)
sphere_kwargs = dict(function=my_fun_sphere)
metric_jacfwd_example = metric_jacfwd_vmap(test_batch, **sphere_kwargs)
Ricci_jacfwd_example = Ric_jacfwd_vmap(test_batch, **sphere_kwargs)
metric_exact_example = metric_sphere_vmap(test_batch)

for label, value in (
    ("metric via jacfwd:\n", metric_jacfwd_example),
    ("Ricci via jacfwd:\n", Ricci_jacfwd_example),
    ("metric via exact formula:\n", metric_exact_example),
    ("Absolute error for metric:\n", metric_jacfwd_example - metric_exact_example),
    ("Absolute error for Ricci:\n", Ricci_jacfwd_example - metric_exact_example),
):
    print(label, value)

In [None]:
# Histogram of errors: verifies the jacfwd metric of the 2-dimensional unit
# sphere against the exact formula.
torch.manual_seed(10)

# theta is shifted by 0.1, keeping it away from 0, where sin(theta) = 0 and
# the metric degenerates.
test_batch = torch.rand(1000,2,dtype=float)+torch.tensor([0.1,0.0])

test_metric_jacfwd_array = metric_jacfwd_vmap(test_batch, function=my_fun_sphere)
test_metric_exact_array = metric_sphere_vmap(test_batch)

# Frobenius norm of (jacfwd metric - exact metric) at each point
absolute_error = (test_metric_jacfwd_array - test_metric_exact_array).norm(dim=(1,2))

array_of_determinants = torch.det(test_metric_exact_array)
print("array of metric determinants:\n mean", array_of_determinants.mean(),
      "\n max", array_of_determinants.max(),
      "\n min", array_of_determinants.min())

print("Shapes:")
print(absolute_error.shape)
print("Absolute error:")
print(absolute_error.mean(), absolute_error.max())

plt.hist(absolute_error, bins=10, density=False, stacked=True)
plt.title("Histogram of frobenius norm of absolute error for metric tensor for a sphere")

In [None]:
# Histogram of errors: verifies the jacfwd Ricci tensor. For the
# n-dimensional unit sphere S^n, Ric = (n-1) g, so for n = 2 Ric = g and the
# exact metric is the ground truth for the Ricci tensor.
torch.manual_seed(10)

test_batch = torch.rand(1000,2, dtype=torch.float64)+torch.tensor([0.1,0.0])

test_Ricci_jacfwd_array = Ric_jacfwd_vmap(test_batch, function=my_fun_sphere)
test_metric_exact_array = metric_sphere_vmap(test_batch)

# Frobenius norm of (jacfwd Ricci - exact metric); it should vanish since Ric = g
absolute_error = (test_Ricci_jacfwd_array - test_metric_exact_array).norm(dim=(1,2))

array_of_determinants = torch.det(test_metric_exact_array)
print("array of metric determinants:\n mean", array_of_determinants.mean(),
      "\n max", array_of_determinants.max(),
      "\n min", array_of_determinants.min())

print("Shapes:")
print(absolute_error.shape)
print("Absolute error:")
print(absolute_error.mean(), absolute_error.max())

plt.hist(absolute_error, bins=10, density=False, stacked=True)
plt.title("Histogram of frobenius norm of absolute error for Ricci tensor for a sphere")

### Lobachevsky plane

In [None]:
# Partial isometric embedding (valid for y > c) of the Lobachevsky plane,
# upper half-plane model, into R^3 as a pseudosphere (tractrix surface of
# revolution). Pulling back the Euclidean metric gives
# ds^2 = 1/y^2(dx^2 + dy^2)
# http://www.antoinebourget.org/maths/2018/08/08/embedding-hyperbolic-plane.html
def my_fun_lobachevsky(u, c = torch.tensor(0.01,dtype=torch.float64)):
    """Map ``u = (x, y)`` with ``y > c`` onto the pseudosphere in R^3.

    With ``t = acosh(y / c)`` the image is
        (t - tanh t, cos(x / c) / cosh t, sin(x / c) / cosh t).
    The radius ``1 / cosh t = c / y`` makes the pulled-back metric exactly
    ``(dx^2 + dy^2) / y^2``.

    Args:
        u: tensor holding the two coordinates ``(x, y)`` (flattened).
        c: scale of the pseudosphere; ``x / c`` is used as the rotation angle.

    Returns:
        1-D tensor of length 3.
    """
    u = u.flatten()
    x = u[0]
    y = u[1]
    t = torch.acosh(y / c)  # NaN for y < c: outside the embedded region
    # The radial profile must be sech(t) = c / y. Using 1 / sinh(t) instead
    # yields g_xx = 1 / (y^2 - c^2) rather than 1 / y^2.
    radius = 1 / torch.cosh(t)
    x0 = t - torch.tanh(t)
    x1 = radius * torch.cos(x / c)
    x2 = radius * torch.sin(x / c)
    return torch.stack((x0, x1, x2))

In [None]:
# Exact metric of the Lobachevsky plane (upper half-plane model). The exact
# Ricci tensor is its negative, since Ric = -g here.
# ds^2 = \frac{dx^2 + dy^2}{y^2}
def metric_lobachevsky(u):
    """Return (1 / y^2) * I_2 as a float64 (2, 2) tensor for u = (x, y)."""
    y = u.flatten()[1]
    inv_y_squared = 1 / y**2
    return torch.eye(2, dtype=torch.float64) * inv_y_squared
metric_lobachevsky_vmap = TF.vmap(metric_lobachevsky)

In [None]:
# Motivating demo for the Lobachevsky plane: jacfwd metric and Ricci tensor at
# two random points with y in [0.5, 1.5), compared with the exact metric.
# Since Ric = -g here, the Ricci "error" is Ric + g.
torch.set_printoptions(precision=16)

torch.manual_seed(10)
test_batch = torch.rand(2,2,dtype=float)+torch.tensor([0.0,0.5],dtype=torch.float64)

lobachevsky_kwargs = dict(function=my_fun_lobachevsky)
metric_jacfwd_example = metric_jacfwd_vmap(test_batch, **lobachevsky_kwargs)
Ricci_jacfwd_example = Ric_jacfwd_vmap(test_batch, **lobachevsky_kwargs)
metric_exact_example = metric_lobachevsky_vmap(test_batch)

for label, value in (
    ("metric via jacfwd:\n", metric_jacfwd_example),
    ("Ricci via jacfwd:\n", Ricci_jacfwd_example),
    ("metric via exact formula:\n", metric_exact_example),
    ("Absolute error for metric:\n", metric_jacfwd_example - metric_exact_example),
    ("Absolute error for Ricci:\n", Ricci_jacfwd_example + metric_exact_example),
):
    print(label, value)

In [None]:
# Histogram of errors: verifies the jacfwd metric of the hyperbolic plane
# against the exact formula, for y in [1, 2).
torch.manual_seed(10)

test_batch = torch.rand(1000,2,dtype=float)+torch.tensor([0.0,1.0],dtype=torch.float64)

test_metric_jacfwd_array = metric_jacfwd_vmap(test_batch, function=my_fun_lobachevsky)
test_metric_exact_array = metric_lobachevsky_vmap(test_batch)

# Frobenius norm of (jacfwd metric - exact metric) at each point
absolute_error = (test_metric_jacfwd_array - test_metric_exact_array).norm(dim=(1,2))

array_of_determinants = torch.det(test_metric_exact_array)
array_of_frobenius_norm = torch.norm(test_metric_exact_array, dim=(1,2))

print("array of metric determinants:\n mean", array_of_determinants.mean(),
      "\n max", array_of_determinants.max(),
      "\n min", array_of_determinants.min())

print("array of metric frobenius norms:\n mean", array_of_frobenius_norm.mean(),
      "\n max", array_of_frobenius_norm.max(),
      "\n min", array_of_frobenius_norm.min())

print("Shapes:")
print(absolute_error.shape)
print("Absolute error:")
print(absolute_error.mean(), absolute_error.max())

plt.hist(absolute_error, bins=10, density=False, stacked=True)
plt.title("Histogram of frobenius norm of absolute error for metric tensor for the hyperbolic plane")

In [None]:
# Histogram of errors: verifies the jacfwd Ricci tensor of the 2-dimensional
# hyperbolic plane, where Ric = -g (so Ric + g should vanish).
torch.manual_seed(10)

test_batch = torch.rand(1000,2, dtype=torch.float64)+torch.tensor([0.0,1.0],dtype=torch.float64)

test_Ricci_jacfwd_array = Ric_jacfwd_vmap(test_batch, function=my_fun_lobachevsky)
test_metric_exact_array = metric_lobachevsky_vmap(test_batch)

# Frobenius norm of (jacfwd Ricci + exact metric) at each point
absolute_error = (test_Ricci_jacfwd_array + test_metric_exact_array).norm(dim=(1,2))

array_of_determinants = torch.det(test_metric_exact_array)
array_of_frobenius_norm = torch.norm(test_metric_exact_array, dim=(1,2))

print("array of metric determinants:\n mean", array_of_determinants.mean(),
      "\n max", array_of_determinants.max(),
      "\n min", array_of_determinants.min())

print("array of metric frobenius norms:\n mean", array_of_frobenius_norm.mean(),
      "\n max", array_of_frobenius_norm.max(),
      "\n min", array_of_frobenius_norm.min())

print("Shapes:")
print(absolute_error.shape)
print("Absolute error:")
print(absolute_error.mean(), absolute_error.max())

plt.hist(absolute_error, bins=10, density=False, stacked=True)
plt.title("Histogram of frobenius norm of absolute error for Ricci tensor for the hyperbolic plane")

# IV. Comparing metric and Ricci tensors for Einstein metrics

### Sphere

In [None]:
# Motivating demo: on the unit sphere S^2 the Ricci tensor should equal the metric.
torch.manual_seed(10)
test_batch = torch.rand(3,2)
sphere_metric = metric_jacfwd_vmap(test_batch, function=my_fun_sphere)
print("metric:\n", sphere_metric)
print("Ricci tensor:\n", Ric_jacfwd_vmap(test_batch, function=my_fun_sphere))

In [None]:
# Histogram of |g - Ric| for the unit sphere, with both tensors computed by
# jacfwd. On S^n, Ric = (n-1) g, so on S^2 the two should coincide.
torch.manual_seed(10)

test_batch = torch.rand(1000,2, dtype=torch.float64)+0.1

test_metric_array = metric_jacfwd_vmap(test_batch, function=my_fun_sphere)
test_Ric_array = Ric_jacfwd_vmap(test_batch, function=my_fun_sphere)

# Deviation from g = Ric: absolute, and relative to |g| in percent
absolute_error = (test_metric_array - test_Ric_array).norm(dim=(1,2))
relative_error = 100*absolute_error/(test_metric_array.norm(dim=(1,2)))

print("Shapes:")
print(absolute_error.shape)
print("Absolute error:")
print(absolute_error.mean(), absolute_error.max())
print("Relative error:")
print(relative_error.mean(), relative_error.max())

plt.hist(absolute_error, bins=10, density=False, stacked=True)

In [None]:
# Frobenius-norm heatmaps of the metric and of the Ricci tensor on a grid.
# On the unit sphere Ric = g, so the two heatmaps should coincide.
numsteps = 100
tgrid = make_grid(numsteps)
Ric_on_grid = Ric_jacfwd_vmap(tgrid, function=my_fun_sphere)
metric_on_grid = metric_jacfwd_vmap(tgrid, function=my_fun_sphere)

draw_frob_norm_tensor_on_grid(plot_name='Frobenius norm of the metric',
                              tensor_on_grid=metric_on_grid,
                              numsteps=numsteps)
draw_frob_norm_tensor_on_grid(plot_name='Frobenius norm of the Ricci tensor',
                              tensor_on_grid=Ric_on_grid,
                              numsteps=numsteps)

### Lobachevsky plane

In [None]:
# Motivating demo: on the Lobachevsky plane the Ricci tensor should equal -g.
# y is shifted by 0.5 (as in the other Lobachevsky cells) because
# my_fun_lobachevsky is only defined for y > c = 0.01: acosh(y / c) is NaN
# below that, and torch.rand alone can sample such y.
torch.manual_seed(10)
test_batch = torch.rand(3,2) + torch.tensor([0.0, 0.5])
print("metric:\n", metric_jacfwd_vmap(test_batch,
                                      function=my_fun_lobachevsky))
print("Ricci tensor:\n", Ric_jacfwd_vmap(test_batch,
                                      function=my_fun_lobachevsky))

In [None]:
# Histogram of |g + Ric| for the Lobachevsky plane, with both tensors computed
# by jacfwd. Here Ric = k*g with k = -1, i.e. Ric = -g.
torch.manual_seed(10)

# y in [0.5, 10.5): the half-plane model requires y > 0
test_batch = 10*torch.rand(1000,2,dtype=torch.float64) + 0.5

test_metric_array = metric_jacfwd_vmap(test_batch, function=my_fun_lobachevsky)
test_Ric_array = Ric_jacfwd_vmap(test_batch, function=my_fun_lobachevsky)

# Deviation from g = -Ric: absolute, and relative to |g| in percent
absolute_error = (test_metric_array + test_Ric_array).norm(dim=(1,2))
relative_error = 100*absolute_error/(test_metric_array.norm(dim=(1,2)))

print("Shapes:")
print(absolute_error.shape)
print("Absolute error:")
print(absolute_error.mean(), absolute_error.max())
print("Relative error:")
print(relative_error.mean(), relative_error.max())

plt.hist(absolute_error, bins=10, density=False, stacked=True)

In [None]:
# Frobenius-norm heatmaps of the metric and of the Ricci tensor on a grid with
# y in [0.2, 3.2]. Since Ric = -g on the Lobachevsky plane, the norms coincide.
numsteps = 100
lobachevsky_shift = dict(xshift=0.0, yshift=1.7)
tgrid = make_grid(numsteps, **lobachevsky_shift)

lobachevsky_metric_on_grid = metric_jacfwd_vmap(tgrid, function=my_fun_lobachevsky)
lobachevsky_Ric_on_grid = Ric_jacfwd_vmap(tgrid, function=my_fun_lobachevsky)

draw_frob_norm_tensor_on_grid(plot_name='Lobachevsky plane: Frobenius norm of the metric',
                              tensor_on_grid=lobachevsky_metric_on_grid,
                              numsteps=numsteps, **lobachevsky_shift)
draw_frob_norm_tensor_on_grid(plot_name='Lobachevsky plane: Frobenius norm of the Ricci tensor',
                              tensor_on_grid=lobachevsky_Ric_on_grid,
                              numsteps=numsteps, **lobachevsky_shift)


# V. The Ricci tensor for the metric given by the pullback of the decoder

In [None]:
# Ricci tensor of the decoder pull-back metric on a 100 x 100 latent grid
# (this took around 17 s in the original run).
numsteps = 100
grid = make_grid(numsteps)
Decoder_Ric_on_grid = Ric_jacfwd_vmap(grid, function=decoder)
draw_frob_norm_tensor_on_grid(
    plot_name='Latent space: Frobenius norm of the Ricci tensor',
    tensor_on_grid=Decoder_Ric_on_grid,
    numsteps=numsteps,
)

# VI. Convergence of simple numerical approximations: $(1+x/n)^n \to e^x$ and Newton's method for $\sqrt{x}$

In [None]:
# Compare e^x with the limit approximation (1 + x/n)^n for x = 5, n = 10^6.
x = 5.0
n = 1000000
y = np.exp(x)        # ground truth
z = (1 + x / n) ** n  # approximation
for value in (y, z):
    print(value)
print("Error: ", y - z)


In [None]:
# Newton-Raphson iteration z <- (z + x / z) / 2 for sqrt(x), started at z = x.
# Each of the 7 iterates is printed together with its error.
y = np.sqrt(x)
print("Ground truth ")
print(y)
z = x
for _ in range(7):
    print("")
    z = 0.5 * (z + x / z)
    print(z)
    print("Error: ", y - z)


# VII. Understanding how precision works

In [None]:
# Compare a forward finite difference and jacfwd with the exact derivative of
# sin (i.e. cos) at 10 random float32 points in [0, 1).
test_tensor = torch.rand(10)
h = 0.0001


def cos_fd(x):
    """Forward finite-difference approximation of d/dx sin(x) with step h."""
    return (torch.sin(x + h) - torch.sin(x)) / h


cos_jacfwd = jacfwd(torch.sin)
cos_jacfwd_vmap = TF.vmap(cos_jacfwd)

result_jacfwd = cos_jacfwd_vmap(test_tensor)
result_fd = cos_fd(test_tensor)
result_exact = torch.cos(test_tensor)

# Errors are reported as magnitudes. Signed errors can be negative, and then
# .max() would report the smallest deviation instead of the largest.
absolute_error = (result_fd - result_exact).abs()
relative_error = 100 * absolute_error / result_exact.abs()

print("Absolute error fd vs exact:")
print(absolute_error.mean(), absolute_error.max())
print("Relative error fd vs exact:")
print(relative_error.mean(), relative_error.max())


absolute_error = (result_jacfwd - result_exact).abs()
relative_error = 100 * absolute_error / result_exact.abs()

print("Absolute error jacfwd vs exact:")
print(absolute_error.mean(), absolute_error.max())
print("Relative error jacfwd vs exact:")
print(relative_error.mean(), relative_error.max())

In [None]:
# A float64 tensor holding 1e-6; torch.tensor([1e-6]) without dtype would be
# float32, the default floating-point dtype.
ex = torch.tensor([1e-6], dtype=torch.float64)

In [None]:
# Print ex, then its value with 20 decimals to show the digits of the stored
# floating-point number. Uncomment the next line to divide ex by 10 on each re-run.
#ex*=0.1
value_of_ex = float(ex)
print(ex)
print("{:.20f}".format(value_of_ex))

# VIII. Timing the scalar curvature computation

In [None]:
import sys
# Make the parent directory importable; ricci_regularization is expected one level up.
sys.path += ['../']
import ricci_regularization as RR

In [None]:
import timeit

# Total wall-clock time of `times_to_repeat` evaluations of the scalar
# curvature (RR.Sc_jacfwd_vmap) of the decoder metric on a
# grid_linear_size x grid_linear_size latent grid.
times_to_repeat = 10
grid_linear_size = 100
time = timeit.timeit(
    stmt="RR.Sc_jacfwd_vmap(RR.make_grid(grid_linear_size),function=decoder)",
    number=times_to_repeat,
    globals=globals(),
)

In [None]:
# Average time per evaluation
seconds_per_run = time/times_to_repeat
print("It takes", seconds_per_run, "seconds to compute")
print("Scalar curvature for a grid of linear size:", grid_linear_size)