In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
device = torch.device("cpu")
import torch
import torch.func as TF
import torch.autograd.functional as AF
import torch.autograd as TA
from functorch import jacrev,jacfwd
import matplotlib.pyplot as plt
import timeit

# I. Tutorial

Introducing the decoder used in the convolutional autoencoder (ConvAE).

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder mapping latent vectors to single-channel images.

    A fully connected stem expands each latent vector into a 32x3x3 feature
    map; three transposed convolutions then upsample it to 1x28x28 and a final
    sigmoid squashes every pixel into [0, 1].

    Args:
        encoded_space_dim: length of the latent vectors passed to ``forward``.
        fc2_input_dim: accepted for signature compatibility; not used.
    """

    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()
        # Attribute names and layer order define the state_dict keys the saved
        # weights are loaded into, so they must stay exactly as they are.
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(inplace=True),
        )
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))
        self.decoder_conv = nn.Sequential(
            # 3x3 -> 7x7
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
        )

    def forward(self, x):
        """Decode a batch ``x`` of shape (N, encoded_space_dim) into (N, 1, 28, 28)."""
        feature_map = self.unflatten(self.decoder_lin(x))
        return torch.sigmoid(self.decoder_conv(feature_map))
    
decoder = Decoder(encoded_space_dim = 2,fc2_input_dim=128)
# Send to device
decoder.to(device) 

# Load the parameters of the decoder trained without a curvature term in the loss.
# map_location remaps tensors that were saved on another device (e.g. a GPU)
# onto `device`, so the checkpoint also loads on CPU-only machines.
# NOTE: torch.load unpickles the file - only load trusted checkpoints
# (newer PyTorch versions also accept weights_only=True).
PATH_dec = '../nn_weights/decoder_conv_autoenc.pt'
decoder.load_state_dict(torch.load(PATH_dec, map_location=device))

# Switch to eval mode: BatchNorm layers use their stored running statistics.
decoder.eval()

## vmap

In [None]:
# vmap lifts torch.dot ([D], [D] -> []) to a batched version
# ([N, D], [N, D] -> [N]) by mapping over the leading dimension.
batched_dot = torch.func.vmap(torch.dot)
x = torch.randn(2, 5)
y = torch.randn(2, 5)
a = batched_dot(x, y)

In [None]:
# Batched dot products computed with vmap
print(a)

In [None]:
# Reference result without vmap: elementwise product summed over the feature axis.
b = (x * y).sum(dim=1)
print(b)

In [None]:
# Exact bitwise equality is too strict for float results produced by different
# kernels; compare up to floating-point tolerance instead.
torch.allclose(a, b)

In [None]:
# Elementwise difference between the vmap and the reference results
a-b

## Another example of vmap

In [None]:
def f(x):
    """Square the input elementwise (works for Python numbers and tensors)."""
    return x**2

In [None]:
# Plain call on a Python scalar
f(5)

In [None]:
# vmap maps f over the leading dimension of its input
f_vectorized = TF.vmap(f)

In [None]:
# Each of the 3 rows is squared independently; the result has shape (3, 3)
f_vectorized(torch.rand(3,3))

# II. vmap for computing the metric using finite differences

In [None]:
# Uniform grid on the latent space (here d = 2). The bounds could be chosen from
# the 3-sigma rule of the latent distribution; fixed bounds of +-1.5 are used by default.
numsteps =10

def make_grid(numsteps, x_range=(-1.5, 1.5), y_range=(-1.5, 1.5)):
    """Return a flattened ``numsteps x numsteps`` uniform grid of 2-D points.

    Points start at the bottom-left corner and the x coordinate increases
    fastest, i.e. row ``k`` is ``(xs[k % numsteps], ys[k // numsteps])``.

    Args:
        numsteps: number of points along each axis.
        x_range: ``(min, max)`` bounds of the x axis (default ``(-1.5, 1.5)``).
        y_range: ``(min, max)`` bounds of the y axis (default ``(-1.5, 1.5)``).

    Returns:
        Tensor of shape ``(numsteps**2, 2)`` holding ``(x, y)`` pairs.
    """
    xs = torch.linspace(x_range[0], x_range[1], steps=numsteps)
    ys = torch.linspace(y_range[0], y_range[1], steps=numsteps)
    # cartesian_prod(ys, xs) yields (y, x) pairs with x varying fastest;
    # swapping the two columns turns them into (x, y) pairs.
    return torch.cartesian_prod(ys, xs).flip(1)

grid = make_grid(numsteps)
numsteps = int(np.sqrt(grid.shape[0]))

# Grid spacing along each axis; used below as the default finite-difference step.
hx = float(abs((grid[numsteps**2 - 1] - grid[0])[0]))/(numsteps - 1)
hy = float(abs((grid[numsteps**2 - 1] - grid[0])[1]))/(numsteps - 1)

### Computing metric at one point with finite differences

In [None]:
# numerical metric computation
def metric_num_smart(u, hx=hx, hy=hy):
    """Decoder metric g = J^T J at one latent point via central differences.

    Args:
        u: a single latent point with 2 coordinates; anything that reshapes to
            ``(1, 2)`` (e.g. shape ``(2,)`` as seen inside ``vmap``).
        hx, hy: finite-difference steps along x and y (default: the grid spacing).

    Returns:
        2x2 tensor of inner products of the partial derivatives
        ``dpsi/dx`` and ``dpsi/dy``, summed over all decoder output entries.
    """
    u = u.reshape(-1,2)
    if u.shape[0] != 1:
        # Flattening several points' derivatives into one column would silently mix
        # them; keep the single-point contract explicit.
        raise RuntimeError("metric_num_smart expects a single latent point")

    # Offsets live on u's device/dtype so this also works when device is not the CPU.
    dx = torch.tensor([[hx, 0.0]], dtype=u.dtype, device=u.device)
    dy = torch.tensor([[0.0, hy]], dtype=u.dtype, device=u.device)
    dpsi_over_dx = (decoder(u + dx) - decoder(u - dx))/(2*hx)
    dpsi_over_dy = (decoder(u + dy) - decoder(u - dy))/(2*hy)

    # One column per partial derivative; -1 instead of a hard-coded 784 (= 28*28)
    # keeps this valid for any decoder output size.
    dpsi = torch.cat((dpsi_over_dx.reshape(-1, 1), dpsi_over_dy.reshape(-1, 1)), dim=-1)
    return dpsi.T @ dpsi

In [None]:
# numerical metric computation
def metric_num(u, hx=hx, hy=hy):
    """Decoder metric components at a latent point via central differences.

    Computes the same quantity as ``metric_num_smart`` but returns the four
    components as separate scalars, which ``vmap`` turns into a tuple of four
    batched tensors.

    Args:
        u: latent point(s) with 2 coordinates (reshaped to ``(-1, 2)``).
        hx, hy: finite-difference steps along x and y (default: the grid spacing).

    Returns:
        Tuple ``(g11, g12, g21, g22)`` of 0-d tensors; ``g21`` is ``g12``.
    """
    u = u.reshape(-1,2)
    # Offsets live on u's device/dtype so this also works when device is not the CPU.
    dx = torch.tensor([[hx, 0.0]], dtype=u.dtype, device=u.device)
    dy = torch.tensor([[0.0, hy]], dtype=u.dtype, device=u.device)
    dpsi_over_dx = torch.flatten((decoder(u + dx) - decoder(u - dx))/(2*hx))
    dpsi_over_dy = torch.flatten((decoder(u + dy) - decoder(u - dy))/(2*hy))

    g11 = torch.dot(dpsi_over_dx, dpsi_over_dx)
    g12 = torch.dot(dpsi_over_dx, dpsi_over_dy)
    g22 = torch.dot(dpsi_over_dy, dpsi_over_dy)
    return g11, g12, g12, g22

In [None]:
# Compare both finite-difference implementations at the latent origin.
# NOTE(review): u is float32 (PyTorch's default dtype); with a step of 1e-5 the
# central difference divides a tiny output difference by 2e-5, so round-off may
# dominate. Both calls use the same step, so they stay comparable to each other.
h = 1e-5 # finite-difference step
#u = torch.tensor([[0.,0.],[0.,0.],[0.,0.]]) # point
u = torch.tensor([0.,0.])
with torch.no_grad():
    print(metric_num(u,h,h))  # tuple (g11, g12, g21, g22)
    print(metric_num_smart(u,h,h))  # the same values as a 2x2 matrix

### Vectorized numerical computation with vmap

In [None]:
# vmap over the leading dim; returns a tuple of 4 batched tensors (one per component)
metric_num_vectorized = TF.vmap(metric_num)

In [None]:
# Finite-difference metric at N random latent points drawn uniformly from [0, 1)^2.
N = 10_000  # 100 * 100 points
with torch.no_grad():
    mertic_numerical_list = metric_num_vectorized(torch.rand(N, 2))

In [None]:
# rearrange the array of metrics
def turn_metric_to_tensor(mertic_numerical_list):
    """Assemble per-component metric tensors into an ``(N, 2, 2)`` batch of matrices.

    Args:
        mertic_numerical_list: sequence ``(g11, g12, g21, g22)`` of tensors of
            shape ``(N,)`` (as returned by ``vmap(metric_num)``), or of 0-d
            tensors for a single point.

    Returns:
        Tensor of shape ``(N, 2, 2)`` whose entry ``[n, i, j]`` comes from
        component ``2 * j + i``, i.e. ``[[g11, g21], [g12, g22]]``. For a
        symmetric metric (``g12 == g21``) this is the usual matrix layout.
    """
    g11, g12, g21, g22 = mertic_numerical_list
    # Build the matrices explicitly instead of reshaping to (2, 2, N) and calling
    # `.T` on a 3-D tensor, which PyTorch deprecates for ndim != 2.
    return torch.stack((g11, g21, g12, g22), dim=-1).reshape(-1, 2, 2)

In [None]:
# vmap over the leading dim; returns an (N, 2, 2) tensor of metric matrices
metric_num_smart_vectorized = TF.vmap(metric_num_smart)

In [None]:
# Same experiment with the matrix-valued implementation (fresh random points).
N = 10_000  # 100 * 100 points
with torch.no_grad():
    mertic_numerical_smart_list = metric_num_smart_vectorized(torch.rand(N, 2))

In [None]:
# g11 component for all N points from vmap(metric_num)
# NOTE(review): these points differ from the ones used for the smart version below.
mertic_numerical_list[0]

In [None]:
# Full 2x2 metric of the first random point from vmap(metric_num_smart)
mertic_numerical_smart_list[0]

### Benchmarking finite differences on a grid: torch.vmap vs. torch.roll

In [None]:
#metric on a grid

def g(grid):
    """Finite-difference decoder metric on a flattened uniform grid using ``torch.roll``.

    The whole grid is decoded once. Neighbours are obtained by rolling the
    decoded batch, so each derivative is a central difference whose step is
    the grid spacing.

    Args:
        grid: ``(numsteps**2, 2)`` tensor ordered like ``make_grid`` output
            (x varies fastest).

    Returns:
        ``(numsteps**2, 4)`` tensor whose rows are ``(g11, g12, g21, g22)``.

    Note:
        ``torch.roll`` wraps around, so values on the grid border use
        neighbours from the opposite side or an adjacent grid line and are
        not meaningful.
    """
    numsteps = int(np.sqrt(grid.shape[0]))

    hx = float(abs((grid[numsteps**2 - 1] - grid[0])[0]))/(numsteps - 1)
    hy = float(abs((grid[numsteps**2 - 1] - grid[0])[1]))/(numsteps - 1)

    psi = decoder(grid.to(device))
    # The x-neighbour is one entry away in the flattened batch; the y-neighbour is one grid line away.
    dpsidx = (psi.roll(-1, 0) - psi.roll(1, 0))/(2*hx)
    dpsidy = (psi.roll(-numsteps, 0) - psi.roll(numsteps, 0))/(2*hy)

    # Sum over every non-batch dimension of the decoder output.
    reduce_dims = tuple(range(1, psi.dim()))
    g11 = (dpsidx*dpsidx).sum(reduce_dims)
    g12 = (dpsidx*dpsidy).sum(reduce_dims)
    g22 = (dpsidy*dpsidy).sum(reduce_dims)
    return torch.stack((g11, g12, g12, g22), dim=1)

In [None]:
# Metric on the whole grid with torch.roll finite differences, as (numsteps**2, 2, 2).
tgrid = make_grid(numsteps)

with torch.no_grad():
    metric_torchroll = g(tgrid).view(numsteps * numsteps, 2, 2)

In [None]:
# (numsteps**2, 2)
tgrid.shape

In [None]:
# (numsteps**2, 2, 2)
metric_torchroll.shape

In [None]:
# Same grid metric with the point-wise finite differences vectorized by vmap.
with torch.no_grad():
    metric_torchvmap_list = metric_num_vectorized(tgrid)
    metric_torchvmap = turn_metric_to_tensor(metric_torchvmap_list)

In [None]:
# Expected False: torch.roll wraps around, so the border values differ
torch.equal(metric_torchvmap,metric_torchroll) #errors on the border for torch.roll 

In [None]:
# Largest absolute deviation between the two methods. A signed max would miss
# points where the torch.roll value is the smaller one.
# Expected to be large: torch.roll wraps around on the grid border.
torch.max(torch.abs(metric_torchroll - metric_torchvmap))

In [None]:
# Drop the outermost ring of grid points, where torch.roll wraps around and
# the finite differences are therefore wrong.
metric_torchvmap_no_border = metric_torchvmap.view(numsteps, numsteps, 2, 2)[1:-1, 1:-1]
metric_torchroll_no_border = metric_torchroll.view(numsteps, numsteps, 2, 2)[1:-1, 1:-1]

In [None]:
# Maximum absolute (L-infinity) deviation on the interior of the grid.
# A signed max is neither an L1 nor a max error, so take the absolute value first.
print("Max abs error:", float(torch.max(torch.abs(metric_torchroll_no_border - metric_torchvmap_no_border))))

In [None]:
# Mean squared deviation between the two methods on the interior of the grid.
residual = (metric_torchroll_no_border - metric_torchvmap_no_border).flatten()
print("MSE:", float((residual**2).sum() / residual.numel()))

In [None]:
# Time 100 repetitions of each method on a 10 x 10 grid.
numsteps = 10
tgrid = make_grid(numsteps=numsteps)

with_torchroll_timer = timeit.timeit(stmt="g(tgrid)", number=100, globals=globals())
with_vmap_timer = timeit.timeit(stmt="metric_num_smart_vectorized(tgrid)", number=100, globals=globals())

print("using torch.roll:", with_torchroll_timer)
print("using vmap:", with_vmap_timer)

In [None]:
# Grid sizes 10, 20, ..., 300 (30 values)
numstep_array = np.linspace(10,300,30).astype(int)
numstep_array

In [None]:
# One timing run per grid size for each method.
computation_time_roll, computation_time_vmap = [], []

for i in numstep_array:
    numsteps = i
    tgrid = make_grid(numsteps=numsteps)

    with_torchroll_timer = timeit.timeit(stmt="g(tgrid)", number=1, globals=globals())
    with_vmap_timer = timeit.timeit(stmt="metric_num_vectorized(tgrid)", number=1, globals=globals())

    computation_time_roll.append(with_torchroll_timer)
    computation_time_vmap.append(with_vmap_timer)

In [None]:
# Seconds per grid size for the vmap finite differences
computation_time_vmap

In [None]:
# Every third grid size, used as x ticks below
numstep_array[::3]

In [None]:
# Plot each timing against its actual grid size so the x ticks line up with the
# data. The data were plotted at indices 0..29 while ticks sat at numsteps/10 = 1..30,
# which shifted every label by one point.
plt.plot(numstep_array, computation_time_roll, label="torch.roll")
plt.plot(numstep_array, computation_time_vmap, label="vmap")
plt.xticks(numstep_array[::3])
plt.title("Comparison of torch.roll and vmap performance")
plt.xlabel("Linear size of the grid")
plt.ylabel("Time in seconds")
plt.legend()
plt.show()

# III. Selecting a method for automatic differentiation (autograd.grad, autograd.functional.jacobian, jacrev, jacfwd)

Conclusion: use jacfwd + vmap.

In [None]:
# NOTE: `input` shadows the Python builtin of the same name in this notebook
input = torch.tensor([3.,5.],requires_grad=True)

In [None]:
# Elementwise square; the result carries a grad_fn because input requires grad
f(input)

In [None]:
# Gradient of the first output x0**2 w.r.t. input: (2*x0, 0) = (6, 0)
torch.autograd.grad(f(input)[0],input)

In [None]:
def g(tensor):
    return torch.sum(tensor*tensor)

In [None]:
torch.autograd.grad(g(input),input)

### autograd.functional.jacobian

In [None]:
# Jacobian of sin at a scalar point: cos(0) = 1
input = torch.tensor(0.)

AF.jacobian(torch.sin, input)

In [None]:
# A 1-D input of 10 values
input = torch.ones(10)
input.shape


In [None]:
# Elementwise function on 10 inputs: the Jacobian is (10, 10)
AF.jacobian(torch.sin, input).shape

In [None]:
# Interpret the 10 values as 5 latent points and take the full decoder Jacobian:
# every output entry of every point w.r.t. every input coordinate.
input = input.reshape(-1, 2)
AF.jacobian(decoder, input).shape

### Vectorizing autograd with vmap

In [None]:
# Fixed seed so the random latent points are reproducible
torch.manual_seed(0)
input = torch.rand(10,2).requires_grad_(True)

In [None]:
# One decoded image per latent point
decoder(input).shape

In [None]:
# Gradient of a single output pixel of sample 0 w.r.t. all 10 latent points.
# In eval mode BatchNorm uses running stats, so only row 0 is expected to be non-zero.
TA.grad(decoder(input)[0,0,0,0],input)

### Vectorizing autograd.functional.jacobian with vmap is unsupported

In [None]:
def decoder_auto_jacobian(input):
    """Decoder Jacobian at ``input`` via ``torch.autograd.functional.jacobian``.

    ``input`` is reshaped to ``(-1, 2)`` latent points. This works when called
    directly but cannot be wrapped in ``torch.func.vmap``.
    """
    input = input.reshape(-1,2)
    # (A stray `decoder(input).shape` here ran a full forward pass and discarded it.)
    return AF.jacobian(decoder, input)

In [None]:
# Direct call on a single latent point works
decoder_auto_jacobian(torch.rand(1,2)).shape

In [None]:
# Wrapping succeeds; the failure only shows up when the vmapped function is called
decoder_auto_jacobian_vectorized = TF.vmap(decoder_auto_jacobian)

In [None]:
# vmap cannot trace through torch.autograd.functional.jacobian. The call is expected
# to fail, so show the error instead of aborting "Restart & Run All".
try:
    decoder_auto_jacobian_vectorized(torch.rand(10,2))
except Exception as err:
    print(f"{type(err).__name__}: {err}")

### jacrev and jacfwd

In [None]:
# jacrev of an elementwise function is diagonal: d sin(x_i) / d x_j = cos(x_i) * delta_ij
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.cos(x).diag()
assert torch.allclose(jacobian, expected)
print(jacobian, jacobian.shape, sep="\n")

### jacrev+vmap

In [None]:
def decoder_auto_jacrev(input):
    """Reverse-mode Jacobian of the decoder at ``input`` reshaped to ``(-1, 2)`` points."""
    input = input.reshape(-1,2)
    # (A stray `decoder(input).shape` here ran a full forward pass and discarded it.)
    return jacrev(decoder)(input)

In [None]:
# Direct call on 5 latent points
decoder_auto_jacrev(torch.rand(5,2)).shape #it works

In [None]:
# Unlike autograd.functional.jacobian, torch.func transforms compose with vmap
decoder_auto_jacrev_vectorized = TF.vmap(decoder_auto_jacrev)

In [None]:
# Per-point Jacobians for 10 latent points
decoder_auto_jacrev_vectorized(torch.rand(10,2)).shape # it works!!!

In [None]:
# What about bigger stuff??
N = 1000
# NOTE: the quoted timing was measured on the author's machine and will vary.
decoder_auto_jacrev_vectorized(torch.rand(N,2)).shape # it works in 7.2 secs!!!

### jacfwd + vmap

In [None]:
def decoder_auto_jacfwd(input):
    """Forward-mode Jacobian of the decoder at ``input`` reshaped to ``(-1, 2)`` points."""
    return jacfwd(decoder)(input.reshape(-1, 2))

decoder_auto_jacfwd_vectorized = TF.vmap(decoder_auto_jacfwd)
decoder_auto_jacfwd_vectorized(torch.rand(10, 2)).shape  # it works!!!

In [None]:
# What about bigger stuff??
N = 10000
# NOTE: the quoted timing was measured on the author's machine and will vary.
decoder_auto_jacfwd_vectorized(torch.rand(N,2)).shape # it works in 0.1 secs!!!

In [None]:
# Conclusion: Reda was right - jacfwd + vmap is the approach we need.

# IV. Metric using jacfwd + vmap

In [None]:
# Evaluate at the origin of the latent space (a single point).
input = torch.zeros(1, 2)

In [None]:
def metric_jacfwd(input):
    """Decoder metric g = J^T J at a latent point using forward-mode AD.

    ``input`` is reshaped to ``(-1, 2)``. The Jacobian is flattened to
    ``(-1, 2)`` (one column per latent coordinate), so the result is 2x2.
    """
    jac = jacfwd(decoder)(input.reshape(-1, 2)).reshape(-1, 2)
    return jac.T @ jac

## IV.1. Comparing with finite differences

In [None]:
# Exact (AD) metric at the origin
metric_jacfwd(input)

In [None]:
# Finite-difference metric at the same point for comparison with the jacfwd result.
# NOTE(review): a moderate step is used here, presumably to limit float32 round-off.
precision = 1e-2
with torch.no_grad():
    result = metric_num(input, precision, precision)
# torch.stack keeps dtype/device; torch.tensor on a tuple of tensors copies them through Python.
torch.stack(result).view(2, 2)

## IV.2. Metric with jacfwd + vmap

In [None]:
# Batched version: (N, 2) latent points -> (N, 2, 2) metrics
metric_jacfwd_vectorized = TF.vmap(metric_jacfwd)

In [None]:
# Grid sizes 10, 60, 110, 160, 210
numstep_array = np.linspace(10,210,num=5).astype(int)
numstep_array

In [None]:

# One timing run per grid size: vmap finite differences vs. jacfwd + vmap.
computation_time_vmap, computation_time_jacfwd = [], []

for i in numstep_array:
    numsteps = i
    tgrid = make_grid(numsteps=numsteps)

    with_vmap_timer = timeit.timeit(stmt="metric_num_vectorized(tgrid)", number=1, globals=globals())
    with_jacfwd_timer = timeit.timeit(stmt="metric_jacfwd_vectorized(tgrid)", number=1, globals=globals())

    computation_time_vmap.append(with_vmap_timer)
    computation_time_jacfwd.append(with_jacfwd_timer)

In [None]:
# Plot against the actual grid sizes; hand-computed tick positions such as
# (numstep_array-10)/50 silently break whenever numstep_array changes.
plt.plot(numstep_array, computation_time_jacfwd, label="vmap+jacfwd")
plt.plot(numstep_array, computation_time_vmap, label="Finite differences using vmap")
plt.xticks(numstep_array)
plt.title("Comparison of jacfwd and numerical vmap performance")
plt.xlabel("Linear size of the grid")
plt.ylabel("Time in seconds")
plt.legend()
plt.show()

### Why jacrev is not an option

In [None]:
def metric_jacrev(input):
    """Decoder metric g = J^T J at a latent point using reverse-mode AD.

    Same contract as ``metric_jacfwd``; only the differentiation mode differs.
    """
    jac = jacrev(decoder)(input.reshape(-1, 2)).reshape(-1, 2)
    return jac.T @ jac

metric_jacrev_vectorized = TF.vmap(metric_jacrev)

In [None]:
# Disabled: vmap + jacrev on 1600 points (a 40 x 40 grid) was reported as not
# computable (presumably time/memory limits) - re-enable to check on your machine.
# metric_jacrev_vectorized(torch.rand(1600,2)).shape # this is not computable

In [None]:
# Small grid sizes 10, 14, 18, 22, 26, 30 (jacrev does not scale further)
numstep_array = np.linspace(10,30,num=6).astype(int)
numstep_array

In [None]:
# One timing run per grid size: jacrev + vmap vs. jacfwd + vmap.
computation_time_jacrev, computation_time_jacfwd = [], []

for i in numstep_array:
    numsteps = i
    tgrid = make_grid(numsteps=numsteps)

    with_jacrev_timer = timeit.timeit(stmt="metric_jacrev_vectorized(tgrid)", number=1, globals=globals())
    with_jacfwd_timer = timeit.timeit(stmt="metric_jacfwd_vectorized(tgrid)", number=1, globals=globals())

    computation_time_jacrev.append(with_jacrev_timer)
    computation_time_jacfwd.append(with_jacfwd_timer)

In [None]:
# Plot against the actual grid sizes; hand-computed tick positions such as
# (numstep_array-10)/4 silently break whenever numstep_array changes.
plt.plot(numstep_array, computation_time_jacfwd, label="vmap+jacfwd")
plt.plot(numstep_array, computation_time_jacrev, label="vmap+jacrev")
plt.xticks(numstep_array)
plt.title("Comparison of jacfwd and jacrev")
plt.xlabel("Linear size of the grid")
plt.ylabel("Time in seconds")
plt.legend()
plt.show()

# V. Frobenius norm of the metric

In [None]:
# NOTE(review): the grid is rebuilt in the next cell, so this cell does not
# affect the results below.
numsteps = 200
tgrid = make_grid(numsteps)

## V.1. With jacfwd+vmap

In [None]:
# Same resolution as the finite-difference grid in V.2 (numsteps = 200). The
# relative-error plot subtracts the two Frobenius-norm maps elementwise, so
# both must have the same shape.
numsteps = 200
tgrid = make_grid(numsteps)

In [None]:
# Metric matrix at every grid point, shape (numsteps**2, 2, 2); detach drops the
# autograd graph w.r.t. the decoder weights
metric_on_grid_jacfwd = metric_jacfwd_vectorized(tgrid).detach()

In [None]:
xs = torch.linspace(-1.5, 1.5, steps=numsteps)
ys = torch.linspace(-1.5, 1.5, steps=numsteps)

# Frobenius norm of every metric matrix, laid out on the grid without its border.
# make_grid orders points with x varying fastest, so view(numsteps, numsteps) is indexed [y, x].
# NOTE(review): contourf expects Z[y, x]; the transpose below therefore shows the
# map mirrored about the diagonal - confirm whether that is intended.
Newfrob1 = metric_on_grid_jacfwd.norm(dim=(1, 2)).view(numsteps, numsteps)[1:-1, 1:-1].transpose(0, 1)

# Heat map of the Frobenius norm
h = plt.contourf(xs[1:-1], ys[1:-1], Newfrob1)
plt.title('Heatmap of the Frobenius norm of the metric')
plt.xlabel("x coordinate")
plt.ylabel("y coordinate")
plt.axis('scaled')
plt.colorbar(label="Frobenius norm of the metric")
plt.show()

## V.2. With f.d. + torch.roll

In [None]:
# 200 x 200 grid for the torch.roll finite differences
numsteps = 200
tgrid = make_grid(numsteps)

In [None]:
# Finite-difference metric on the grid with torch.roll, as (numsteps**2, 2, 2).
# Requires `g` to still refer to the torch.roll grid metric `g(grid)` from section II.
with torch.no_grad():
    metric_torchroll = g(tgrid)
    metric_torchroll = metric_torchroll.view(numsteps*numsteps, 2, 2)

In [None]:
xs = torch.linspace(-1.5, 1.5, steps=numsteps)
ys = torch.linspace(-1.5, 1.5, steps=numsteps)

# Frobenius norm on the grid without its border (laid out like the jacfwd map above,
# including its transpose, so the two maps can be compared elementwise).
Newfrob2 = metric_torchroll.norm(dim=(1, 2)).view(numsteps, numsteps)[1:-1, 1:-1].transpose(0, 1)

# Heat map of the Frobenius norm.
# NOTE(review): values are scaled by 1e4 for plotting, but the colorbar label does
# not mention the factor - confirm the intended units.
h = plt.contourf(xs[1:-1], ys[1:-1], 1e+4*Newfrob2)
plt.title('Heatmap of the Frobenius norm of the metric')
plt.xlabel("x coordinate")
plt.ylabel("y coordinate")
plt.axis('scaled')
plt.colorbar(label="Frobenius norm of the metric")
plt.show()

### plotting the difference

In [None]:
# Heat map of the relative error (in percent) of the torch.roll Frobenius norm,
# taking the jacfwd result as the reference.
h = plt.contourf(xs[1:-1], ys[1:-1], 100*abs(Newfrob1-Newfrob2)/Newfrob1)
plt.title('Heatmap of relative error')
plt.xlabel("x coordinate")
plt.ylabel("y coordinate")
plt.axis('scaled')
plt.colorbar(label="Relative error for the Frobenius norm of the metric")
plt.show()

In [None]:
# Observed in the author's run: the relative error is about 10% on a 200 x 200 grid

In [None]:
# (numsteps**2, 2, 2) for the grid used in V.1
metric_on_grid_jacfwd.shape

# VI. Higher-order derivatives using torch.func (jacrev)

In [None]:
# Warm-up: the Jacobian of elementwise sin is diag(cos(x))
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.cos(x).diag()
assert torch.allclose(jacobian, expected)
print(jacobian, jacobian.shape, sep="\n")

In [None]:
# Jacobian of the elementwise square: diag(2x)
jacfun = jacrev(torch.square)

In [None]:
# Jacobian of the Jacobian, i.e. the second-derivative tensor
second_der = jacrev(jacfun)

In [None]:
# Shape (2, 2, 2): entry [i, j, k] = 2 if i == j == k else 0
second_der(torch.tensor([1.,3.]))

In [None]:
# Scalar input: d(x**2)/dx at x = 0 is 0
jacfun(torch.tensor(0.))