In [None]:
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd 
import random 
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

data_dir = 'dataset'

# Seed every RNG used in this notebook so that the train/val split, the shuffled
# test batch used for the latent statistics (which set every grid bound below),
# the random latent samples and the random geodesic starts are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=True)
test_dataset  = torchvision.datasets.MNIST(data_dir, train=False, download=True)

# Both splits only convert images to float tensors; no augmentation.
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset.transform = train_transform
test_dataset.transform = test_transform

m=len(train_dataset)

# 80/20 train/validation split. The train size is the remainder of the
# validation size, so the lengths always sum to len(train_dataset)
# (random_split raises otherwise).
n_val = int(m*0.2)
train_data, val_data = random_split(train_dataset, [m - n_val, n_val])
batch_size=256

train_loader = DataLoader(train_data, batch_size=batch_size)
valid_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size,shuffle=True)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to a latent code.

    For 1x28x28 MNIST inputs (as loaded above) the feature maps are
    1x28x28 -> 8x14x14 -> 16x7x7 -> 32x3x3, which is flattened to
    3*3*32 = 288 features and projected to ``encoded_space_dim``.

    Args:
        encoded_space_dim: size of the latent code returned by ``forward``.
        fc2_input_dim: accepted but unused (kept so existing calls still work).

    NOTE(review): the attribute names and layer order define the ``state_dict``
    keys the saved weights are loaded into, so they must not be changed.
    """
    
    def __init__(self, encoded_space_dim,fc2_input_dim):
        super().__init__()
        
        ### Convolutional section: three stride-2 convolutions
        self.encoder_cnn = nn.Sequential(
            # 1x28x28 -> 8x14x14
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            # 8x14x14 -> 16x7x7
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            # 16x7x7 -> 32x3x3 (no padding)
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )
        
        ### Flatten layer: (batch, 32, 3, 3) -> (batch, 288)
        self.flatten = nn.Flatten(start_dim=1)
        ### Linear section: 288 -> 128 -> encoded_space_dim
        self.encoder_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32, 128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )
        
    def forward(self, x):
        """Encode an image batch into (batch, encoded_space_dim) latent codes.

        The input must be sized so that the CNN output is 32x3x3
        (e.g. 1x28x28 MNIST images).
        """
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        x = self.encoder_lin(x)
        return x
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent code back to an image batch.

    Mirrors ``Encoder``: latent -> 128 -> 3*3*32 features, reshaped to 32x3x3
    and upsampled by transposed convolutions 32x3x3 -> 16x7x7 -> 8x14x14 ->
    1x28x28. A final sigmoid maps every pixel into (0, 1).

    Args:
        encoded_space_dim: size of the latent code accepted by ``forward``.
        fc2_input_dim: accepted but unused (kept so existing calls still work).

    NOTE(review): the attribute names and layer order define the ``state_dict``
    keys the saved weights are loaded into, so they must not be changed.
    """
    
    def __init__(self, encoded_space_dim,fc2_input_dim):
        super().__init__()
        # Linear section: encoded_space_dim -> 128 -> 3*3*32
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(True)
        )

        # (batch, 288) -> (batch, 32, 3, 3)
        self.unflatten = nn.Unflatten(dim=1, 
        unflattened_size=(32, 3, 3))

        self.decoder_conv = nn.Sequential(
            # 32x3x3 -> 16x7x7
            nn.ConvTranspose2d(32, 16, 3, 
            stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            # 16x7x7 -> 8x14x14
            nn.ConvTranspose2d(16, 8, 3, stride=2, 
            padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            # 8x14x14 -> 1x28x28
            nn.ConvTranspose2d(8, 1, 3, stride=2, 
            padding=1, output_padding=1)
        )
        
    def forward(self, x):
        """Decode latent codes of shape (batch, encoded_space_dim) into images in (0, 1)."""
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)
        return x

In [None]:
### Initialize the two networks
d = 2  # latent dimension; the 2-D plots and grids below rely on d == 2

# Device: CPU is forced; the automatic GPU selection is kept for reference.
#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")
print(f'Selected device: {device}')

# Encoder and decoder share the same constructor arguments.
net_kwargs = dict(encoded_space_dim=d, fc2_input_dim=128)
encoder = Encoder(**net_kwargs)
decoder = Decoder(**net_kwargs)

# Place both networks on the selected device.
encoder.to(device)
decoder.to(device)

In [None]:
# Load the pretrained weights.
# map_location=device lets checkpoints that were saved on a GPU load on this
# CPU-only run.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.

#without curvature in Loss func
PATH_enc = '../nn_weights/encoder_conv_autoenc.pt'
PATH_dec = '../nn_weights/decoder_conv_autoenc.pt'

#with curvature in Loss func
#PATH_enc = 'encoder_convAE_curv_0.1.pt'
#PATH_dec = 'decoder_convAE_curv_0.1.pt'

#with curvature in Loss func
#PATH_enc = 'encoder_curw_w=0.001_2epoch.pt'
#PATH_dec = 'decoder_curw_w=0.001_2epoch.pt'

encoder.load_state_dict(torch.load(PATH_enc, map_location=device))
encoder.eval()
decoder.load_state_dict(torch.load(PATH_dec, map_location=device))
decoder.eval()

In [None]:
# Generate samples from the latent code and compute the latent mean and std
def show_image(img):
    """Display a CHW image tensor (e.g. the output of make_grid) with matplotlib.

    The tensor is detached and copied to the CPU first, so tensors that require
    grad or live on a GPU can be passed directly.
    """
    npimg = img.detach().cpu().numpy()
    # matplotlib expects height x width x channels; torch stores channels first
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Inference mode for both networks (BatchNorm then uses its running statistics).
encoder.eval()
decoder.eval()

with torch.no_grad():
    # Latent mean and population standard deviation, estimated from one batch
    # of test images; they set the bounds of the latent grids further down.
    images, labels = next(iter(test_loader))
    images = images.to(device)
    latent = encoder(images).cpu()

    mean = latent.mean(dim=0)
    print(mean)
    centered = latent - mean
    std = centered.pow(2).mean(dim=0).sqrt()
    print(std)

    # Draw 128 latent vectors from N(mean, diag(std**2)) and decode them.
    latent = (torch.randn(128, d) * std + mean).to(device)
    img_recon = decoder(latent).cpu()

    # To inspect the samples:
    # fig, ax = plt.subplots(figsize=(20, 8.5))
    # show_image(torchvision.utils.make_grid(img_recon[:100], 10, 5))
    # plt.show()

Point plot

In [None]:
from tqdm import tqdm

# Encode the whole test set in batches (one forward pass per batch instead of
# one per image). shuffle=False keeps the rows in dataset order. In eval mode
# BatchNorm uses its running statistics, so batching does not change the codes
# (up to floating-point rounding).
encoder.eval()
encoded_samples = []
with torch.no_grad():
    for imgs, labels_batch in tqdm(DataLoader(test_dataset, batch_size=batch_size, shuffle=False)):
        codes = encoder(imgs.to(device)).cpu().numpy()
        for code, label in zip(codes, labels_batch.tolist()):
            # One row per image: its latent coordinates plus the digit label.
            encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(code)}
            encoded_sample['label'] = label
            encoded_samples.append(encoded_sample)
encoded_samples = pd.DataFrame(encoded_samples)
encoded_samples.head()

In [None]:
import plotly.express as px

# Test-set latent codes in the 2-D latent plane, coloured by digit label.
latent_scatter = px.scatter(
    encoded_samples,
    x='Enc. Variable 0',
    y='Enc. Variable 1',
    color=encoded_samples['label'].astype(str),
    opacity=0.7,
)
latent_scatter

Manifold plot

In [None]:
# Uniform grid on the 2-D latent space (d == 2) covering mean +/- 2 std in each
# coordinate (2 sigmas rather than the 3-sigma rule).
numsteps = 10
xs = torch.linspace(mean[0]-2*std[0], mean[0]+2*std[0], steps = numsteps)
ys = torch.linspace(mean[1]-2*std[1], mean[1]+2*std[1], steps = numsteps)
uniform_grid = torch.cartesian_prod(xs,ys)

# True manifold plot. make_grid lays images out row by row, left to right and
# top to bottom, so picture row r must hold the points (xs[c], y_r) with y_r
# decreasing down the page. Pair the reversed ys with xs, then swap the two
# columns to get (x, y). Negating the coordinates instead would only produce
# the right y values when mean[1] == 0.
truegrid = torch.cartesian_prod(ys.flip(0), xs)
latent = truegrid.roll(1,1)
latent = latent.to(device)
with torch.no_grad():
    img_recon = decoder(latent)
img_recon = img_recon.cpu()

fig, ax  = plt.subplots(figsize=(20, 8.5))
img_grid = torchvision.utils.make_grid(img_recon[:100],10,5)
show_image(img_grid.detach())
plt.show()
print(latent.shape)
print(img_recon.shape)

# Fast way to compute metric on a grid over the latent space (torch.roll)

In [None]:
# Evaluation window for the curvature computations: a numsteps x numsteps grid
# centred on the latent mean, spanning mean +/- 2 std / zoom in each coordinate.
# zoom = 1 is the full 2-sigma window; zoom > 1 zooms in around the mean
# (the window stays centred on the mean for every zoom).
numsteps = 100
zoom = 1

# Centralized and scaled evaluation
xs = torch.linspace(mean[0]-2*std[0]/zoom, mean[0]+2*std[0]/zoom, steps = numsteps)
ys = torch.linspace(mean[1]-2*std[1]/zoom, mean[1]+2*std[1]/zoom, steps = numsteps)

#fixed location of latent space evaluation
#xs = torch.linspace(-1.5, 1.5, steps = numsteps)/zoom
#ys = torch.linspace(-1.5, 1.5, steps = numsteps)/zoom

In [None]:
#alt grid
#numsteps = 10

#xs = torch.linspace(1.2-0.3, 1.2+0.3, steps = numsteps)
#ys = torch.linspace(0.6-0.3, 0.6+0.3, steps = numsteps)

In [None]:
# Flattened grid starting at the bottom-left corner with x varying fastest:
# row k = i*numsteps + j holds (xs[j], ys[i]).
# cartesian_prod(ys, xs) yields (y, x) pairs; flipping the two columns gives (x, y).
tgrid = torch.cartesian_prod(ys, xs).flip(1)

In [None]:
#metric on a grid
def g(grid):
    """Pull-back metric of the decoder on a regular latent grid.

    ``grid`` must be ordered like ``tgrid`` (x varies fastest, numsteps**2
    points). The decoder output psi is differentiated with periodic central
    differences (torch.roll), and the metric at each point is
    g_ab = <d_a psi, d_b psi>, summed over every output dimension except the
    batch dimension.

    Returns a tensor of shape (numsteps**2, 2, 2). Points on the grid border
    see wrapped-around neighbours and should be discarded.
    """
    side = int(np.sqrt(grid.shape[0]))
    span = grid[side**2 - 1] - grid[0]
    step_x = float(abs(span[0])) / (side - 1)
    step_y = float(abs(span[1])) / (side - 1)

    psi = decoder(grid.to(device))
    # neighbours: 1 row apart in x, `side` rows apart in y
    d_x = (psi.roll(-1, 0) - psi.roll(1, 0)) / (2 * step_x)
    d_y = (psi.roll(-side, 0) - psi.roll(side, 0)) / (2 * step_y)

    pixel_dims = (1, 2, 3)
    g_xx = (d_x * d_x).sum(pixel_dims)
    g_xy = (d_x * d_y).sum(pixel_dims)
    g_yy = (d_y * d_y).sum(pixel_dims)
    return torch.stack((g_xx, g_xy, g_xy, g_yy), dim=1).view(side * side, 2, 2)

In [None]:
# Analytic reference metrics and their derivatives, evaluated on a batch of points.
# Active: round sphere of radius R in (phi, theta) coordinates,
#     g = R**2 * diag(cos(theta)**2, 1).
# Kept as comments: the hyperbolic half-plane metric diag(1/theta**2, 1/theta**2).
R = 10
def specific_metric (u): 
    """Sphere metric at each row of ``u`` (shape (n, 2), columns phi, theta).

    Returns a (n, 2, 2) tensor equal to R**2 * [[cos(theta)**2, 0], [0, 1]].
    """
    theta = u[:, 1]
    count = u.shape[0]
    zeros = torch.zeros(count)
    components = (torch.cos(theta)**2, zeros, zeros, torch.ones(count))
    # hyperbolic half-plane instead:
    # components = (1/theta**2, zeros, zeros, 1/theta**2)
    return (R**2) * torch.stack(components, dim=1).view(count, 2, 2)
def specific_metric_der (u): 
    """Coordinate derivatives of ``specific_metric`` at each row of ``u``.

    Returns dg of shape (n, 2, 2, 2) with dg[:, a, i, j] = d g_ij / d u^a
    (a = 0: phi, a = 1: theta). For the sphere only
    d g_11 / d theta = -R**2 * sin(2 theta) is non-zero.
    """
    theta = u[:, 1]
    count = u.shape[0]
    zeros = torch.zeros(count)
    # d/dphi: the sphere metric does not depend on phi
    dg_dphi = torch.stack((zeros, zeros, zeros, zeros), dim=1).view(count, 2, 2)
    # d/dtheta
    dg_dtheta = torch.stack((-R**2*torch.sin(2*theta), zeros, zeros, zeros), dim=1).view(count, 2, 2)
    # hyperbolic half-plane instead:
    # dg_dtheta components = (-2/theta**3, 0, 0, -2/theta**3)
    # NOTE(review): unlike specific_metric, this alternative is not scaled by R**2.
    return torch.stack((dg_dphi, dg_dtheta), dim=1)

In [None]:
# Decoder-induced metric on every grid point; finite differences need no autograd.
# Analytic alternative for testing: metric = specific_metric(tgrid)
with torch.no_grad():
    metric = g(grid=tgrid)

## Heatmap of the Frobenius norm of the metric

In [None]:
# Frobenius norm of the metric on the grid, with the outer ring dropped (the
# central differences in g wrap around there). tgrid has x varying fastest, so
# view(numsteps, numsteps) is already indexed [y index, x index], which is the
# (ny, nx) layout plt.contourf expects. Transposing it would mirror the heatmap
# across the diagonal.
Newfrob = metric.norm(dim=(1,2)).view(numsteps,numsteps)
Newfrob = Newfrob[1:-1,1:-1]

In [None]:
# Heat map of the Frobenius norm (interior grid points only)
fig, ax = plt.subplots()
h = ax.contourf(xs[1:-1], ys[1:-1], Newfrob)
ax.set_title('Heatmap of the Frobenius norm of the metric')
ax.set_xlabel( "x coordinate")
ax.set_ylabel( "y coordinate")
ax.axis('scaled')
fig.colorbar(h, ax=ax, label="Frobenius norm of the metric")
# ax.set_xlim(-1.5 + mean[0], 1.5 + mean[0])
# ax.set_ylim(-1.5 + mean[1], 1.5 + mean[1])
plt.show()

### Derivatives of the metric and Christoffel symbols

In [None]:
# Simultaneous differentiation of all grid points at once with torch.roll.
def _central_diff(values, shift, step):
    """Periodic central difference along dim 0: (v[k+shift] - v[k-shift]) / (2*step)."""
    forward = values.roll(-shift, 0)
    backward = values.roll(shift, 0)
    return (forward - backward) / (2 * step)

def diff_by_x(tensor, numsteps, h):
    """d/dx on a flattened grid with x varying fastest (x neighbours are 1 row apart).

    ``numsteps`` is unused; it is accepted for symmetry with ``diff_by_y``.
    Rows on the grid border see wrapped-around neighbours.
    """
    return _central_diff(tensor, 1, h)
def diff_by_y(tensor, numsteps, h):
    """d/dy on a flattened grid with x varying fastest (y neighbours are numsteps rows apart)."""
    return _central_diff(tensor, numsteps, h)
    

In [None]:
#derivatives of the metric on a grid
def dg_grid (grid): #dg
    """Coordinate derivatives of the decoder pull-back metric on a regular grid.

    ``grid`` must be ordered like ``tgrid`` (x varies fastest, numsteps**2 points).
    With psi = decoder(grid) and g_ij = <d_i psi, d_j psi> (summed over pixels),
    the product rule gives d_k g_ij = <d_k d_i psi, d_j psi> + <d_i psi, d_k d_j psi>.
    All derivatives are periodic central differences (torch.roll), so the two
    outermost rings of grid points are not meaningful.

    Returns a tensor of shape (numsteps**2, 2, 2, 2) with
    out[:, k, i, j] = d g_ij / d u^k (k = 0: x, k = 1: y).
    """
    numsteps = int(np.sqrt(grid.shape[0]))

    hx = float(abs((grid[numsteps**2 - 1] - grid[0])[0]))/(numsteps - 1)
    hy = float(abs((grid[numsteps**2 - 1] - grid[0])[1]))/(numsteps - 1)

    psi = decoder(grid.to(device))

    dpsidx = diff_by_x(psi, numsteps, hx)
    # The y derivative must use diff_by_y: y neighbours are numsteps rows apart.
    dpsidy = diff_by_y(psi, numsteps, hy)
    dpsidx_second = diff_by_x(dpsidx, numsteps, hx)
    dpsidx_dy = diff_by_y(dpsidx, numsteps, hy)   # mixed derivative d2psi/dxdy
    dpsidy_second = diff_by_y(dpsidy, numsteps, hy)

    def pixel_dot(a, b):
        # inner product over all non-batch dimensions
        return (a * b).sum((1, 2, 3))

    dg12_dx = pixel_dot(dpsidx_second, dpsidy) + pixel_dot(dpsidx, dpsidx_dy)
    dg12_dy = pixel_dot(dpsidy_second, dpsidx) + pixel_dot(dpsidy, dpsidx_dy)
    components = (
        # d/dx of g11, g12, g21, g22
        2 * pixel_dot(dpsidx, dpsidx_second), dg12_dx, dg12_dx, 2 * pixel_dot(dpsidy, dpsidx_dy),
        # d/dy of g11, g12, g21, g22
        2 * pixel_dot(dpsidx, dpsidx_dy), dg12_dy, dg12_dy, 2 * pixel_dot(dpsidy, dpsidy_second),
    )
    # (N, 8) -> (N, 2, 2, 2): component 4k + 2i + j holds d_k g_ij
    return torch.stack(components, dim=1).view(numsteps*numsteps, 2, 2, 2)

In [None]:
# Metric derivatives on every grid point; finite differences need no autograd.
# Analytic alternative for testing: metric_der = specific_metric_der(tgrid)
with torch.no_grad():
    metric_der = dg_grid(grid=tgrid)


In [None]:
metric_der.size()  # expected (numsteps**2, 2, 2, 2)

In [None]:
# torch.inverse works batch-wise: compare inverting the first grid matrix alone
# with the first entry of the batched inverse, so all grid matrices can be
# inverted simultaneously.
torch.inverse(metric[0]).equal(torch.inverse(metric)[0])

In [None]:
# Inverse metric g^ij at every grid point (batched inversion).
metric_inv = metric.inverse()

In [None]:
#Christoffel symbols on a grid
def Ch_grid(grid, g_inv=None, dg=None):
    """Christoffel symbols of the second kind at every grid point.

    Ch[:, l, i, j] = Gamma^l_ij = 1/2 g^lk (d_i g_kj + d_j g_ik - d_k g_ij)

    Args:
        grid: (n, 2) points; only n = grid.shape[0] is used.
        g_inv: (n, 2, 2) inverse metric. Defaults to the notebook-level ``metric_inv``.
        dg: (n, 2, 2, 2) metric derivatives with dg[:, k, i, j] = d_k g_ij.
            Defaults to the notebook-level ``metric_der``.

    Returns:
        (n, 2, 2, 2) tensor with the dtype and device of ``g_inv``.
    """
    # Passing g_inv / dg explicitly avoids depending on hidden notebook state;
    # the defaults reproduce the original behaviour.
    if g_inv is None:
        g_inv = metric_inv
    if dg is None:
        dg = metric_der
    n = grid.shape[0]
    Ch = torch.zeros((n, 2,2,2), dtype=g_inv.dtype, device=g_inv.device)
    for i in range(2):
        for j in range(2):
            for l in range(2):
                for k in range(2):
                    #Ch^l_ij
                    Ch[:,l,i,j] += 0.5 * g_inv[:,l,k] * (dg[:,i,k,j] + dg[:,j,i,k] - dg[:,k,i,j])
    return Ch

In [None]:
# Shape check of the Christoffel symbols on the grid: expected (numsteps**2, 2, 2, 2)
Ch_grid(tgrid).size()

Derivatives of Christoffel symbols on a grid

In [None]:
#derivatives of Christoffel symbols on a grid
def Ch_der_grid(grid):
    """Coordinate derivatives of the Christoffel symbols on a regular grid.

    ``grid`` must be ordered like ``tgrid`` (x varies fastest). Uses periodic
    central differences, so border points are not meaningful.

    Returns (n, 2, 2, 2, 2) with out[:, l, i, j, k] = d_k Gamma^l_ij
    (k = 0: x, k = 1: y).
    """
    numsteps = int(np.sqrt(grid.shape[0]))
    hx = float(abs((grid[numsteps**2 - 1] - grid[0])[0]))/(numsteps - 1)
    hy = float(abs((grid[numsteps**2 - 1] - grid[0])[1]))/(numsteps - 1)

    # Compute the Christoffel symbols once and differentiate them in both directions.
    Ch = Ch_grid(grid)
    Chdx = diff_by_x(Ch, numsteps, hx)
    Chdy = diff_by_y(Ch, numsteps, hy)
    # Put the derivative direction on the last axis.
    return torch.stack((Chdx, Chdy), dim=-1)



In [None]:
Ch_der_grid(tgrid).size()  # expected (numsteps**2, 2, 2, 2, 2)

In [None]:
# Riemann curvature tensor (3,1)
def Riem(grid):
    """Riemann curvature tensor R^i_jkl at every grid point.

    R^i_jkl = d_k Gamma^i_lj - d_l Gamma^i_kj
              + Gamma^i_kp Gamma^p_lj - Gamma^i_lp Gamma^p_kj

    Returns (n, 2, 2, 2, 2) indexed [:, i, j, k, l].
    """
    n = grid.shape[0]

    # The Christoffel symbols and their derivatives do not depend on the loop
    # indices, so evaluate each once.
    Ch = Ch_grid(grid)
    dCh = Ch_der_grid(grid)   # dCh[:, l, i, j, k] = d_k Gamma^l_ij

    riem = torch.zeros(n, 2, 2, 2, 2)
    for i in range(2):
        for j in range(2):
            for k in range(2):
                for l in range(2):
                    riem[:, i, j, k, l] = dCh[:, i, l, j, k] - dCh[:, i, k, j, l]
                    for p in range(2):
                        riem[:, i, j, k, l] += (Ch[:, i, k, p]*Ch[:, p, l, j] - Ch[:, i, l, p]*Ch[:, p, k, j])
    return riem



In [None]:
Riem(tgrid).size()  # expected (numsteps**2, 2, 2, 2, 2)

In [None]:
# Skew-symmetry check R^i_jkl = -R^i_jlk (the Riemann tensor is computed once, compared exactly)
riem_check = Riem(tgrid)
torch.equal(riem_check[:,0,0,0,1], - riem_check[:,0,0,1,0])


In [None]:
# Ricci curvature tensor via Riemann
# R_ab = Riem^c_acb
def Ric(grid):
    """Ricci tensor Ric[:, a, b] = sum_c Riem[:, c, a, c, b] at every grid point.

    Returns a tensor of shape (n, 2, 2).
    """
    n = grid.shape[0]
    riem = Riem(grid)  # computed once per call, not once per (a, b, c)
    Ric = torch.zeros(n, 2, 2)
    for a in range(2):
        for b in range(2):
            for c in range(2):
                Ric[:, a, b] += riem[:, c, a, c, b]
    return Ric

In [None]:
Ric(tgrid).size()  # expected (numsteps**2, 2, 2)

In [None]:
# Scalar curvature via Riemann and Ricci
# R_ab = Riem^c_acb
# R = g^ij * R_ij
def Sc(grid):
    """Scalar curvature R = g^ij R_ij at every grid point.

    Uses the notebook-level inverse metric ``metric_inv`` (assumed to be
    evaluated on the same grid) and the Ricci tensor from ``Ric``.
    Returns a tensor of shape (n,).
    """
    # Reuse Ric instead of repeating its contraction here. The elementwise product
    # summed over both matrix indices is exactly the contraction g^ij R_ij.
    return torch.sum(metric_inv * Ric(grid), (1, 2))

In [None]:
Scalar_curvature_grid = Sc(grid=tgrid)  # one scalar curvature value per grid point

# Scalar curvature heatmap

In [None]:
# Scalar curvature as a [y index, x index] image, without the two-point border
# that the nested central differences corrupt. tgrid has x varying fastest, so
# view(numsteps, numsteps) already has the (ny, nx) layout plt.contourf expects.
# No transpose is needed; transposing would mirror the heatmap across the diagonal.
Scalar_curv = Scalar_curvature_grid.view(numsteps,numsteps)
#Scalar_curv_check = Scalar_curv[30:-30,30:-30]
Scalar_curv = Scalar_curv[2:-2,2:-2]

In [None]:
# Heat map of the scalar curvature (interior grid points only)
fig, ax = plt.subplots()
h = ax.contourf(xs[2:-2], ys[2:-2], Scalar_curv)
# h = ax.contourf(xs[30:-30], ys[30:-30], Scalar_curv_check)
ax.set_title('Heat map of the Scalar curvature ')
ax.set_xlabel( "x coordinate")
ax.set_ylabel( "y coordinate")
ax.axis('scaled')
# ax.set_xlim(-1.5, 1.5)
# ax.set_ylim(-1.5, 1.5)

fig.colorbar(h, ax=ax)
plt.show()

Simplified energy functional computation: $F_{new}(g) = \int_{M}  R^{2} d\mu$

In [None]:
# F_new ~ integral of R^2 dmu, approximated by a Riemann sum of
# sqrt(det g) * R**2 * hx * hy over the interior grid points.
# Both factors come from untransposed [y index, x index] views of grid-ordered
# data, so each product pairs values at the same grid point.
metric_no_border = metric.reshape(numsteps, numsteps,2,2)[2:-2,2:-2]
det_metric_no_border = torch.det(metric_no_border)
det_sqrt = torch.sqrt(det_metric_no_border)
scalar_curv_no_border = Scalar_curvature_grid.reshape(numsteps, numsteps)[2:-2,2:-2]
grid = tgrid
hx = float(abs((grid[numsteps**2 - 1] - grid[0])[0]))/(numsteps - 1)
hy = float(abs((grid[numsteps**2 - 1] - grid[0])[1]))/(numsteps - 1)

F_new = (det_sqrt*torch.square(scalar_curv_no_border)*hx*hy).sum()

print(F_new)


# Geodesics

In [None]:
# Used to make a piecewise-constant metric out of its values on the grid.
def find_nearest_index (grid, u):
    """Flat index of the grid point nearest to ``u`` on a grid ordered like tgrid.

    Each coordinate is handled separately: take the first row whose coordinate
    is closest to u (torch.min returns the first minimum). Because x varies
    fastest, the first closest-x row is the x index j, and the first closest-y
    row is i * numsteps. Their sum is the flat index i * numsteps + j.
    (A faster alternative for a known regular grid is to compute the index
    directly from (u - grid[0]) / step.)
    """
    distances = (grid - u).abs()
    per_axis_rows = distances.min(dim=0).indices
    return int(per_axis_rows.sum())

In [None]:
#computing geodesics...
# y = [u , v]
# v := dot(u)
# dot(v)^l = - Ch^l_ij * v^i * v^j
def geod(y, t):
    """Right-hand side of the geodesic equation for a single state (NumPy version).

    Args:
        y: NumPy array [u_x, u_y, v_x, v_y] with v = du/dt.
        t: unused (the system is autonomous); kept for ODE-solver signatures.

    Returns:
        NumPy array dy/dt = [v, -Gamma^l_ij v^i v^j].

    NOTE(review): this single-point version is redefined by the vectorised
    ``geod`` further down, which is the one the solver cells use. No standalone
    ``Ch`` helper exists in this notebook, so the Christoffel symbols are taken
    from ``Ch_vect`` on ``tgrid`` (both defined in other cells).
    """
    u = torch.from_numpy(y[0:2])
    v = y[2:]
    # Christoffel symbols at u, looked up once (Ch_vect expects a batch of points).
    Gamma = Ch_vect(tgrid, u.unsqueeze(0))[0].numpy()
    dvdt = np.zeros(2)
    for l in range(2):
        for i in range(2):
            for j in range(2):
                dvdt[l] -= Gamma[l,i,j] * v[i] * v[j]
    return np.concatenate((v, dvdt))

## Vectorized computation of geodesics (still with a Python loop in find_nearest_indices)

In [None]:
# Batched nearest-grid-point lookup (still loops over the points in Python).
def find_nearest_indices (grid, u):
    """Apply ``find_nearest_index`` to every row of ``u``.

    Returns an int64 tensor of shape (n,), suitable for torch.index_select.
    """
    nearest = [find_nearest_index(grid, point) for point in u]
    return torch.tensor(nearest, dtype=torch.int64)

In [None]:
find_nearest_index(grid=tgrid, u=torch.tensor([0.5, 0.3]))  # flat grid index nearest to (0.5, 0.3)

In [None]:
# Grid point nearest to (0.5, 0.3), looked up rather than indexed by a fixed number.
# NOTE(review): the hard-coded 5563 looks like the index printed by the cell above
# for one particular grid; it changes with numsteps, mean and std.
tgrid[find_nearest_index(tgrid, torch.tensor([0.5, 0.3]))]

In [None]:
# Piecewise-constant inverse metric at arbitrary points.
def g_inv_vect (grid, u): #inverse metric
    """Inverse metric at each row of ``u``, using the metric of the nearest grid point.

    Reads the notebook-level ``metric`` (evaluated on ``grid``) and inverts the
    selected matrices. Returns a (n, 2, 2) tensor.
    """
    nearest = find_nearest_indices(grid, u)
    return metric.index_select(0, nearest).inverse()

In [None]:
#g_inv_vect(tgrid, check)

In [None]:
# Piecewise-constant metric derivatives at arbitrary points.
def dg_vect (grid, u): #dg
    """Metric derivatives at each row of ``u``, taken from the nearest grid point.

    Reads the notebook-level ``metric_der``; returns (n, 2, 2, 2) with
    out[:, k, i, j] = d_k g_ij.
    """
    return metric_der.index_select(0, find_nearest_indices(grid, u))

In [None]:
#dg_vect(tgrid, check)

In [None]:
#Christoffel symbols at a vector of n points. u has shape (n, 2)
def Ch_vect(grid, u):
    """Piecewise-constant Christoffel symbols at arbitrary points ``u``.

    Each point uses the inverse metric and metric derivatives of its nearest
    grid point (``g_inv_vect`` / ``dg_vect``):
    Ch[:, l, i, j] = Gamma^l_ij = 1/2 g^lk (d_i g_kj + d_j g_ik - d_k g_ij).
    Returns a (n, 2, 2, 2) tensor.
    """
    n = u.shape[0]
    # Look both quantities up once; each lookup loops over all points in Python.
    g_inv = g_inv_vect(grid, u)
    dg = dg_vect(grid, u)
    Ch = torch.zeros((n,2,2,2))
    for i in range(2):
        for j in range(2):
            for l in range(2):
                for k in range(2):
                    Ch[:,l,i,j] += 0.5 * g_inv[:,l,k] * (dg[:,i,k,j] + dg[:,j,i,k] - dg[:,k,i,j]) #Ch^l_ij
    return Ch

In [None]:
#Ch_vect(tgrid, check)

In [None]:
#Ch(check[1])
# just to check that there is no mistake in the vectorized version Ch_vect

In [None]:
#  Ch_vect still exploits the loop in find_indices
#Ch_vect(tgrid,tgrid)

In [None]:
#computing geodesics...
# y has shape (number of points, 4): columns u_x, u_y, v_x, v_y
# v := dot(u)
# dot(v)^l = - Ch^l_ij * v^i * v^j
def geod(y, t):
    """Vectorised right-hand side of the geodesic equation.

    Args:
        y: (n, 4) tensor of states [u_x, u_y, v_x, v_y] with v = du/dt.
        t: unused (the system is autonomous); kept for the solver signature.

    Returns:
        (n, 4) tensor dy/dt = [v, -Gamma^l_ij v^i v^j]. The Christoffel symbols
        are looked up piecewise-constantly on ``tgrid`` via ``Ch_vect``.
    """
    u = y[:, 0:2]
    v = y[:, 2:]
    # One Christoffel lookup per call; it does not depend on l, i, j.
    Gamma = Ch_vect(tgrid, u)
    dvdt = torch.zeros(y.shape[0], 2)
    for l in range(2):
        for i in range(2):
            for j in range(2):
                dvdt[:, l] -= Gamma[:, l, i, j] * v[:, i] * v[:, j]
    return torch.cat((v, dvdt), dim=1)

In [None]:
def rungekutta_new(f, y0, t, args=(), verbose=False):
    """Integrate dy/dt = f(y, t, *args) on the time grid ``t``.

    NOTE(review): despite the name this is the first-order (forward Euler)
    scheme y[i+1] = y[i] + (t[i+1] - t[i]) * f(y[i], t[i]), not RK4.

    Args:
        f: right-hand side taking a (batch, state_dim) tensor and a time.
        y0: (batch, state_dim) tensor of initial conditions
            (batch = number of trajectories; state_dim = 4 for geodesics).
        t: 1-D sequence of time points; t[0] is the time of y0.
        args: extra positional arguments forwarded to ``f``.
        verbose: if True, print each state before it is advanced.

    Returns:
        (len(t), batch, state_dim) tensor with the same dtype as ``y0``, y[0] == y0.
    """
    nt = len(t)  # number of time points
    y = torch.zeros((nt, y0.shape[0], y0.shape[1]), dtype=y0.dtype)
    y[0] = y0
    for i in range(nt - 1):
        y[i + 1] = y[i] + (t[i + 1] - t[i]) * f(y[i], t[i], *args)
        if verbose:
            print(y[i])
    return y

In [None]:
# m geodesics starting at random points with a common initial velocity.
# Each state is [u_x, u_y, v_x, v_y] = [rand, 0, 0, 1]: a random x on the line
# y = 0 and unit initial velocity along y.
m = 10
tail = torch.tensor([0.00, 0.00, 1.00])   # [u_y, v_x, v_y]
v = tail.repeat(m, 1)
u = torch.rand(m, 1)                       # u_x

RandStartComSpeed = torch.cat([u, v], dim=1)
RandStartComSpeed

In [None]:
t = torch.linspace(0, 1, steps = 21)  # 21 equally spaced time points on [0, 1]
sol3 = rungekutta_new(geod, RandStartComSpeed, t=t)

In [None]:
fig, ax = plt.subplots()
ax.plot(sol3[:15, :, 0], sol3[:15, :, 1])  # geodesics are shortened by step 15 because of border effects
ax.set_title( "Plots of geodesics with rnd ititial point and common initial speed")
ax.set_xlabel( "x coordinate")
ax.set_ylabel( "y coordinate")
ax.grid()

# Scalar curvature and geodesics on one plot

In [None]:
# m geodesics starting at evenly spaced points on the line y = 0
# (x from 0.01 to 1.51), all with the same unit initial velocity along y.
m = 15 #number of geodesics
v = torch.tensor([0.00, 0.00,1.00])   # [u_y, v_x, v_y]
v = v.repeat(m,1)
# reshape with m (not a hard-coded 15) so that changing m keeps working
u = torch.linspace(0.01,1.51,steps=m).reshape(m,1)

RandStartComSpeed2 = torch.cat((u,v),1)
RandStartComSpeed2

In [None]:
t = torch.linspace(0, 1, steps = 41)  # 41 equally spaced time points on [0, 1]
sol4 = rungekutta_new(geod, RandStartComSpeed2, t=t)

In [None]:
# Scalar curvature heat map with the geodesics overlaid
fig, ax = plt.subplots()
h = ax.contourf(xs[2:-2], ys[2:-2], Scalar_curv)
ax.plot(sol4[:30, :, 0], sol4[:30, :, 1])  # geodesics are shortened by step 30 because of border effects
ax.set_title('Scalar curvature and geodesics')
ax.set_xlabel( "x coordinate")
ax.set_ylabel( "y coordinate")
ax.axis('scaled')
ax.set_xlim(0,1.75)
ax.set_ylim(0,1.25)
fig.colorbar(h, ax=ax, label="Scalar curvature")
plt.show()

In [None]:
# Scalar curvature at the grid point nearest to (0.5, 0.3); the index is looked up
# instead of hard-coded. NOTE(review): 5563 appears to be the index found for
# (0.5, 0.3) on one particular grid; it changes with numsteps, mean and std.
Scalar_curvature_grid[find_nearest_index(tgrid, torch.tensor([0.5, 0.3]))]