NB! This is a very old notebook.

The latent space of the AE is $R^d$. We define a Riemannian metric in a local chart of the latent space as the pull-back of the Euclidean metric in the output space $R^D$ by the decoder function $\Psi$ of the AE:
\begin{equation*}
    g = (\nabla \Psi)^{*} \, \nabla \Psi \ .
\end{equation*}

Here we measure the time needed to compute the scalar curvature (https://en.wikipedia.org/wiki/Scalar_curvature) of this metric for different latent space dimensions $d$.

One can switch between two curvature computation modes:
1) PyTorch reverse-mode automatic differentiation: torch.func.jacrev.
See https://pytorch.org/docs/stable/generated/torch.func.jacrev.html
2) PyTorch forward-mode automatic differentiation: torch.func.jacfwd.
See https://pytorch.org/docs/stable/generated/torch.func.jacfwd.html


In [None]:
# Minimal imports
import math
import sys
import timeit

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn

# adding path to the set generating package
sys.path.append('../') # have to go 1 level up
# The benchmark cells call ricci_regularization.Sc_jacrev / Sc_jacfwd, so the
# package must be imported once the parent directory is on sys.path.
import ricci_regularization

In [None]:
class Decoder(nn.Module):
    """Fully connected decoder used as the map whose Jacobian defines the metric.

    Layer widths are ``hidden_dim -> 128 -> 256 -> 512 -> output_dim``. The
    activation is applied after each of the first three linear layers; the
    last layer is left linear.

    Args:
        hidden_dim: size of the input (latent) vector.
        output_dim: size of the decoded output vector.
        activation: element-wise callable applied between the linear layers.
            Defaults to ``torch.sin``, which was previously hard-coded.
            NOTE(review): the curvature computation differentiates the decoder
            several times, so a piecewise-linear activation such as ReLU
            presumably yields degenerate results -- confirm before switching.
    """
    def __init__(self, hidden_dim, output_dim, activation=torch.sin):
        super().__init__()
        # Attribute names are kept so that existing state_dict keys stay valid.
        self.linear1 = nn.Linear(hidden_dim, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.linear4 = nn.Linear(512, output_dim)
        self.activation = activation

    def forward(self, x):
        """Decode ``x`` (last dimension ``hidden_dim``) into ``output_dim`` features."""
        y = x
        for layer in (self.linear1, self.linear2, self.linear3):
            y = self.activation(layer(y))
        # No activation on the output layer.
        return self.linear4(y)

In [None]:
# Disabled smoke test, kept inside a bare string so that none of it executes:
# it would build a decoder for d = 6, D = 784 and call
# ricci_regularization.Sc_jacrev at one random latent point.
# Being the last expression of the cell, the string itself is echoed as output.
"""
D = 784
d = 6
decoder = Decoder(d,D)
x = torch.rand(d)
ricci_regularization.Sc_jacrev(x, function = decoder)
"""

In [None]:
# Benchmark of ricci_regularization.Sc_jacrev (the jacrev-based curvature mode)
# for small latent dimensions.
times_to_repeat = 20   # number of evaluations averaged per latent dimension
jacrev_timer = []      # mean seconds per call, in the order of hidden_dim_array
D = 32*32              # decoder output dimension

hidden_dim_array = np.array([2,3,4,5,6])

for d in hidden_dim_array:
    x = torch.rand(d)          # random latent point at which curvature is evaluated
    decoder = Decoder(d,D)     # fresh, untrained decoder for this latent dimension
    # A callable is timed instead of a source string resolved through globals(),
    # so the measurement no longer depends on notebook-global name lookup.
    total_seconds = timeit.timeit(
        lambda: ricci_regularization.Sc_jacrev(x, function=decoder),
        number=times_to_repeat,
    )
    jacrev_timer.append(total_seconds/times_to_repeat)
print("torch.func.jacrev time:",jacrev_timer)

In [None]:
import json

In [None]:
# jacfwd timing for d = 2,  3,  4,  5,  6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54,
#       58, 62, 66, 70, 74, 78, 82, 86, 90, 94, 100
#with open("jacfwd_timing", "w") as fp:
#    json.dump(new_jacfw_timer.tolist(), fp)

In [None]:
# this is jacrev_timing for d = 2, 3, 4, 5, 6
#with open("jacrev_timing", "r") as fp:
#    b = json.load(fp)
#b

In [None]:
# Benchmark of ricci_regularization.Sc_jacfwd (the jacfwd-based curvature mode)
# over a wide range of latent dimensions.
times_to_repeat = 10   # number of evaluations averaged per latent dimension
jacfwd_timer = []      # mean seconds per call, in the order of hidden_dim_array
# NOTE: D (decoder output dimension) is not set here; the value assigned in the
# earlier jacrev benchmark cell is reused.

hidden_dim_array = np.array([ 2,  3,  4,  5,  6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54,
       58, 62, 66, 70, 74, 78, 82, 86, 90, 94, 100])

for d in hidden_dim_array:
    x = torch.rand(d)          # random latent point at which curvature is evaluated
    decoder = Decoder(d,D)     # fresh, untrained decoder for this latent dimension
    # A callable is timed instead of a source string resolved through globals(),
    # so the measurement no longer depends on notebook-global name lookup.
    total_seconds = timeit.timeit(
        lambda: ricci_regularization.Sc_jacfwd(x, function=decoder),
        number=times_to_repeat,
    )
    jacfwd_timer.append(total_seconds/times_to_repeat)
print("jacfwd time:",jacfwd_timer)

In [None]:
# Plot the mean evaluation time per latent dimension for both Jacobian modes.
plt.rcParams.update({'font.size': 24}) # makes all fonts on the plot be 24
plt.figure(figsize=(9,9),dpi=300)
# Both curves are drawn against the sample index i, i.e. the i-th timing is
# assumed to belong to hidden_dim_array[i]; the jacrev list is shorter because
# it was only measured for the first few dimensions.
plt.semilogy(jacfwd_timer,label="jacfwd", marker='o')
plt.semilogy(jacrev_timer,label="torch.func.jacrev", marker='o')
# Tick labels are read from hidden_dim_array instead of being typed in by hand,
# so they cannot drift out of sync with the benchmarked dimensions. Positions
# beyond the end of the array are dropped.
tick_positions = np.array([0,4,10,15,20,27])
tick_positions = tick_positions[tick_positions < len(hidden_dim_array)]
plt.xticks(tick_positions,labels=np.asarray(hidden_dim_array)[tick_positions].tolist())
plt.title(f"Scalar curvature evaluation time for D={D}.")
plt.xlabel("Latent space dimension d")
# semilogy shows the times themselves on a logarithmic axis, not their logarithm.
plt.ylabel("Time in seconds (log scale)")
plt.legend(loc="lower right")
#plt.savefig(f'jacrev_jacfwd_time_D={D}.pdf',bbox_inches='tight',format='pdf')
plt.show()