This notebook computes the Fréchet mean quickly by merging the computation of the geodesics and of the mean itself into a single optimization problem, using the `Octopus` class.

# Data loading

In [None]:
from tqdm.notebook import tqdm
import torch
import numpy as np
import ricci_regularization
import matplotlib.pyplot as plt
import matplotlib
import stochman
from stochman.manifold import EmbeddedManifold
from stochman.curves import CubicSpline

# Flag for writing figures to disk; in the visible cells it is only
# referenced by commented-out savefig calls.
violent_saving = True

# Alternative experiment (no curv_pen):
#   '../experiments/MNIST_torus_AEexp34.json'
experiment_json = '../experiments/MNIST01_torus_AEexp7.json'

# Returns a dict that the next cell reads for the tuned network, the test
# loader and the JSON config.
mydict = ricci_regularization.get_dataloaders_tuned_nn(
    Path_experiment_json=experiment_json,
)

In [None]:
# Unpack the pieces of the experiment bundle that the rest of the notebook uses.
torus_ae = mydict["tuned_neural_network"]  # autoencoder used for encoding/decoding below
test_loader = mydict["test_loader"]
# NOTE: the name is misspelled ("cofig") but later cells reference it as-is.
json_cofig = mydict["json_config"]
Path_pictures = json_cofig["Path_pictures"]
exp_number = json_cofig["experiment_number"]
curv_w = json_cofig["losses"]["curv_w"]  # presumably the curvature-loss weight — confirm

In [None]:
# Size of a flattened input sample (used by the `view(-1, D)` calls below).
D = 784
# Number of clusters/classes from the config; later used for the colour map
# and as the number of K-means centroids.
k = json_cofig["dataset"]["parameters"]["k"]
torus_ae.cpu()
colorlist = []
enc_list = []
feature_space_encoding_list = []
input_dataset_list = []
recon_dataset_list = []
# Inference only: without no_grad every batch would keep its autograd graph
# alive inside these lists for the whole data set.
with torch.no_grad():
    # Swap in `train_loader` here to encode the training split instead.
    for (data, labels) in tqdm( test_loader, position=0 ):
        flat_data = data.view(-1, D)  # flatten once, reused by both encoders
        input_dataset_list.append(data)
        recon_dataset_list.append(torus_ae(data)[0])
        feature_space_encoding_list.append(torus_ae.encoder_torus(flat_data))
        enc_list.append(torus_ae.encoder2lifting(flat_data))
        colorlist.append(labels)

In [None]:
#x = torch.cat(zlist)
#enc = circle2anglevectorized(x).detach()
# Stitch the per-batch lists collected above into single tensors.
input_dataset = torch.cat(input_dataset_list)
recon_dataset = torch.cat(recon_dataset_list)
encoded_points = torch.cat(enc_list)
feature_space_encoding = torch.cat(feature_space_encoding_list)
# Detached copy used for plotting and as input to the K-means cells below.
encoded_points_no_grad = encoded_points.detach()
# Labels of the encoded samples, used to colour the latent-space scatter plot.
color_array = torch.cat(colorlist).detach()

In [None]:
plt.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1],c = color_array,cmap=ricci_regularization.discrete_cmap(k,"jet"))
plt.show()

In [None]:
from stochman.manifold import EmbeddedManifold
# Geodesics are computed by minimizing the curve "energy" in the embedding of the manifold,
# so there is no need to compute the pullback metric, which keeps the algorithm fast.
class Autoencoder(EmbeddedManifold):
    """Embedded manifold whose embedding map is the decoder of the global ``torus_ae``."""

    def embed(self, c, jacobian = False):
        # ``jacobian`` is accepted but not used: only the decoded points are returned.
        decoded = torus_ae.decoder_torus(c)
        return decoded
#selected_labels = json_cofig["dataset"]["selected_labels"]
manifold = Autoencoder()

# Octopus

In [None]:
from stochman.curves import BasicCurve
from typing import Optional, Tuple
from torch import nn
from abc import ABC

class Octopus(BasicCurve):
    """Batch of cubic-spline curves ("legs") with a pinned start and a free end.

    Each curve is the straight line from ``begin`` to ``end`` plus a piecewise
    cubic correction. The correction is constrained to vanish at ``t = 0`` and
    to be C0/C1/C2-continuous at the interior knots, but it is *not*
    constrained at ``t = 1``; the curve's end point can therefore move during
    optimisation, which is what lets several legs be pulled to a common point.

    Args:
        begin: start points, one row per curve.
        end: initial end points, one row per curve.
        num_nodes: number of spline nodes (``num_nodes - 1`` cubic segments).
        requires_grad: whether ``params`` is registered as a trainable
            parameter or as a plain buffer.
        basis: optional precomputed null-space basis; computed when ``None``.
        params: optional initial coordinates in that basis; zeros when
            ``None``. Must match ``basis.shape[1]`` (``num_nodes + 1`` columns
            for the basis computed here).

    NOTE(review): ``begin``, ``end``, ``_num_nodes``, ``_requires_grad`` and
    ``device`` are expected to be provided by ``BasicCurve`` — confirm against
    the installed stochman version.
    """

    def __init__(
        self,
        begin: torch.Tensor,
        end: torch.Tensor,
        num_nodes: int = 5,
        requires_grad: bool = True,
        basis: Optional[torch.Tensor] = None,
        params: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__(begin, end, num_nodes, requires_grad, basis=basis, params=params)

    def _init_params(self, basis, params) -> None:
        """Register the spline basis and the coordinates of each curve in it."""
        if basis is None:
            basis = self._compute_basis(num_edges=self._num_nodes - 1)
        self.register_buffer("basis", basis)

        if params is None:
            # Zero coordinates: every curve starts as the straight line.
            params = torch.zeros(
                self.begin.shape[0], self.basis.shape[1], self.begin.shape[1], dtype=self.begin.dtype
            )
        else:
            # A 2-D tensor is treated as the parameters of a single curve.
            params = params.unsqueeze(0) if params.ndim == 2 else params

        if self._requires_grad:
            self.register_parameter("params", nn.Parameter(params))
        else:
            self.register_buffer("params", params)

    def _compute_basis(self, num_edges) -> torch.Tensor:
        """Return a basis of the spline coefficients satisfying the constraints.

        The constraints are: value 0 at ``t = 0`` and continuity of the value,
        first and second derivative at every interior knot. There is no
        constraint at ``t = 1``. The constraint matrix is also stored on
        ``self.constraints``.

        Returns:
            Tensor of shape ``(4 * num_edges, null_space_dim)``.
        """
        dtype = self.begin.dtype
        num_coeffs = 4 * num_edges

        def continuity_rows(monomials) -> torch.Tensor:
            # One row per interior knot: (left segment) - (right segment) = 0.
            rows = torch.zeros(num_edges - 1, num_coeffs, dtype=dtype)
            for i, knot in enumerate(knots):
                si = 4 * i  # start index of the left segment's coefficients
                fill = torch.tensor(monomials(knot), dtype=dtype)
                rows[i, si : si + 4] = fill
                rows[i, si + 4 : si + 8] = -fill
            return rows

        with torch.no_grad():
            knots = torch.linspace(0, 1, num_edges + 1, dtype=dtype)[1:-1].tolist()

            # Row 0 pins the value at t = 0. Row 1 is intentionally left at
            # zero (the usual end-point constraint is dropped) and is kept only
            # so that `self.constraints` retains its shape.
            end_points = torch.zeros(2, num_coeffs, dtype=dtype)
            end_points[0, 0] = 1.0

            # Coefficient order is c0 + c1*t + c2*t^2 + c3*t^3 (see
            # `_eval_polynomials`), so the rows are the monomials and their
            # first/second derivatives evaluated at the knot.
            zeroth = continuity_rows(lambda s: [1.0, s, s ** 2, s ** 3])
            first = continuity_rows(lambda s: [0.0, 1.0, 2.0 * s, 3.0 * s ** 2])
            second = continuity_rows(lambda s: [0.0, 0.0, 2.0, 6.0 * s])

            constraints = torch.cat((end_points, zeroth, first, second))
            self.constraints = constraints

            # The null space forms the basis. Because of the all-zero row the
            # matrix is rank deficient, so cut at the numerical rank instead
            # of at the number of singular values.
            _, S, Vh = torch.linalg.svd(constraints, full_matrices=True)
            tol = S.max() * max(constraints.shape) * torch.finfo(S.dtype).eps
            rank = int((S > tol).sum())
            basis = Vh.T[:, rank:].contiguous()  # (num_coeffs)x(null_space_dim)

            return basis

    def _get_coeffs(self) -> torch.Tensor:
        """Expand ``params`` into per-segment cubic coefficients, Bx(num_edges)x4xD."""
        coeffs = (
            self.basis.unsqueeze(0).expand(self.params.shape[0], -1, -1).bmm(self.params)
        )  # Bx(num_coeffs)xD
        B, num_coeffs, D = coeffs.shape
        degree = 4
        num_edges = num_coeffs // degree
        coeffs = coeffs.view(B, num_edges, degree, D)  # Bx(num_edges)x4xD
        return coeffs

    def _eval_polynomials(self, t: torch.Tensor, coeffs: torch.Tensor) -> torch.Tensor:
        """Evaluate the piecewise cubic correction at times ``t`` (Bx|t|) -> Bx|t|xD."""
        # each row of coeffs should be of the form c0, c1, c2, ... representing polynomials
        # of the form c0 + c1*t + c2*t^2 + ...
        # coeffs: Bx(num_edges)x(degree)xD
        B, num_edges, degree, D = coeffs.shape
        # Segment index of every time value; t = 1 falls into the last segment.
        idx = torch.floor(t * num_edges).clamp(min=0, max=num_edges - 1).long()  # Bx|t|
        power = torch.arange(0.0, degree, dtype=t.dtype, device=self.device).view(1, 1, -1)
        tpow = t.view(B, -1, 1).pow(power)  # Bx|t|x(degree)
        # Pick, for every (curve, time) pair, the coefficients of its segment.
        batch_idx = torch.arange(B, device=idx.device).unsqueeze(1)  # Bx1
        coeffs_idx = coeffs[batch_idx, idx]  # Bx|t|x(degree)xD
        retval = torch.sum(tpow.unsqueeze(-1) * coeffs_idx, dim=2)  # Bx|t|xD
        return retval

    def _eval_straight_line(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the linear interpolation between ``begin`` and ``end`` at ``t``."""
        B, T = t.shape
        tt = t.view(B, T, 1)  # Bx|t|x1
        retval = (1 - tt).bmm(self.begin.unsqueeze(1)) + tt.bmm(self.end.unsqueeze(1))  # Bx|t|xD
        return retval

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate all curves at ``t``.

        Args:
            t: either a 1-D grid shared by all curves or a Bx|t| tensor.

        Returns:
            Bx|t|xD tensor; for a 1-D ``t`` and a single curve the batch
            dimension is squeezed away (|t|xD).
        """
        coeffs = self._get_coeffs()  # Bx(num_edges)x4xD
        no_batch = t.ndim == 1
        if no_batch:
            t = t.expand(coeffs.shape[0], -1)  # Bx|t|
        retval = self._eval_polynomials(t, coeffs)  # Bx|t|xD
        retval += self._eval_straight_line(t)
        if no_batch and retval.shape[0] == 1:
            retval.squeeze_(0)  # |t|xD
        return retval

In [None]:
n = 3
p0 = torch.rand(n,2)
p1 = torch.rand(n,2)


In [None]:
octopus = Octopus(p0,p1)

In [None]:
octopus._parameters['params'].shape

In [None]:
octopus._parameters['params']

In [None]:
octopus.plot()
plt.scatter(p0[:,0],p0[:,1],marker="*",c="orange")
plt.scatter(p1[:,0],p1[:,1],marker="*",c="orange")

In [None]:
def frechet_mean(octopus, manifold, optimizer=torch.optim.Adam, max_iter=150, eval_grid=20, lr=1e-2):
    """Jointly optimise all legs of ``octopus`` so that they end in one point.

    The loss is the mean curve energy of the legs plus a heavily weighted
    penalty on the squared Euclidean distance between the end points of
    consecutive legs, so the common end point approximates the Frechet mean
    of the start points.

    Args:
        octopus: module that is callable on a 1-D time grid, exposes
            ``begin``, ``device`` and ``parameters()``. Its output is assumed
            to be (legs, grid, dim), as produced by ``Octopus.forward``.
        manifold: object providing ``curve_energy(curves)``.
        optimizer: optimiser class, called as ``optimizer(params, lr=lr)``.
        max_iter: maximum number of optimiser steps.
        eval_grid: number of time samples on [0, 1].
        lr: learning rate.

    Returns:
        ``max_grad < 1e-3`` as a boolean tensor, where ``max_grad`` is the
        largest absolute gradient entry after the last step (infinity when no
        step was taken).
    """
    alpha = torch.linspace(0, 1, eval_grid, dtype=octopus.begin.dtype, device=octopus.device)
    opt = optimizer(octopus.parameters(), lr=lr)
    lambda_centers = 1e3  # huge weight on the "same end point" penalty
    thresh = 1e-3

    def closure():
        opt.zero_grad()
        legs = octopus(alpha)  # evaluate once, reuse for both loss terms
        loss_dist2center = manifold.curve_energy(legs).mean()
        # Euclidean distances between the END points (last grid sample) of
        # consecutive legs; we want them to be zero (shared end point).
        leg_ends = legs[:, -1]
        loss_centers = (leg_ends[:-1] - leg_ends[1:]).square().sum()
        loss = lambda_centers * loss_centers + loss_dist2center
        loss.backward()
        return loss

    # Defined up front so the function also works for max_iter == 0.
    max_grad = torch.tensor(float("inf"))
    for _ in range(max_iter):
        opt.step(closure=closure)
        # Parameters that did not take part in the loss have no gradient.
        grads = (p.grad for p in octopus.parameters() if p.grad is not None)
        max_grad = max((g.abs().max() for g in grads), default=max_grad)
        if max_grad < thresh:
            break
    print(max_grad)
    return max_grad < thresh

In [None]:
frechet_mean(octopus,manifold)

In [None]:
octopus.plot()
plt.scatter(p0[:,0],p0[:,1],marker="*",c="orange")
#plt.scatter(p1[:,0],p1[:,1],marker="*",c="orange")

In [None]:
# Evaluation grid on [0, 1]; later cells reuse `t`.
t = torch.linspace(start=0, end=1, steps=20)
# Last grid sample of the first leg, detached from autograd.
FM = octopus(t)[0, -1].detach()
print("Frechet mean:\n", FM)

In [None]:
octopus(t)[0]

In [None]:
c,_ = manifold.connecting_geodesic(p0[0],FM)
c(t)

Adding optimization params to the model

In [None]:
octopus.constraints.shape

In [None]:
optimizer = torch.optim.Adam(params=octopus.parameters())


In [None]:
def closure():
    """Zero the gradients, compute the mean curve energy of `octopus` on `t`, and backpropagate."""
    optimizer.zero_grad()
    loss = manifold.curve_energy(octopus(t)).mean()
    loss.backward()
    return loss
# Single optimiser step; the optimiser evaluates the closure itself.
optimizer.step(closure=closure)

In [None]:
"""
for param in octopus.parameters():
    print(param.grad)
"""

In [None]:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    """Toy module: a 10 -> 1 linear layer plus one extra learnable (1, 1) offset."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=10, out_features=1)

        # Assigning an nn.Parameter as an attribute registers it under that name.
        self.custom_param = nn.Parameter(torch.randn(1, 1))

    def forward(self, x):
        # Linear output shifted by the custom parameter.
        shifted = self.fc(x) + self.custom_param
        return shifted

# Create an instance of the model
model = MyModel()

# Accessing registered parameters
for name, param in model.named_parameters():
    print(name, param.size())


# Optimization loop example

# Named `example_input` rather than `input` so the builtin is not shadowed
# for the rest of the notebook session.
example_input = torch.rand(10)
# NOTE: this rebinds the notebook-wide name `optimizer` used by earlier cells.
optimizer = torch.optim.Adam(params=model.parameters())
num_steps = 50
for _ in range(num_steps):
    optimizer.zero_grad()
    loss = model(example_input).norm()
    loss.backward()
    optimizer.step()
    #print(loss)

# Register an additional parameter after training; it was never part of a
# backward pass, so its gradient printed below is None.
new_param = nn.Parameter(torch.randn(7,7))
model.register_parameter("hahaha_param", new_param)
# Accessing registered parameters
for name, param in model.named_parameters():
    print(name, param.grad)

# Riemannian K-means

# 1. Initial guess via Euclidean K-means using Geomstats

In [None]:
#this adds an environmental variable
#%env GEOMSTATS_BACKEND=pytorch

import geomstats.backend as gs
import geomstats.visualization as visualization
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.kmeans import RiemannianKMeans

In [None]:
circumference1 = Hypersphere(dim=1)
circumference2 = Hypersphere(dim=1)

# Building torus as a product $\mathcal{T} = \mathcal{S}^1 \times \mathcal{S}^1$ 

In [None]:
from geomstats.geometry.product_manifold import ProductManifold
torus = ProductManifold((circumference1,circumference2))

Loading saved points and labels and plotting ground truth labels

In [None]:
encoded_angles = encoded_points_no_grad #torch.load("encoded_angles.pt")
gt_labels = color_array #torch.load("labels.pt")
# Convert gt_labels into an integer 0/1 array: shift the smallest label to 0
# and scale the largest to 1. (Only meaningful for two distinct labels.)
shifted_labels = gt_labels - gt_labels.min()
label_span = shifted_labels.max()
if label_span > 0:  # all labels identical: nothing to scale, avoid 0/0
    shifted_labels = shifted_labels / label_span
# The cast applies to the whole rescaled tensor, not just to the divisor.
gt_labels = shifted_labels.to(torch.int).numpy()

Putting MNIST data on torus

In [None]:
# Convert each latent coordinate to extrinsic coordinates on its circle factor.
# NOTE(review): the `.reshape(2,-1).T` assumes a particular output layout of
# `intrinsic_to_extrinsic_coords` for 1-D input — confirm against the geomstats
# version in use that rows end up as one point per sample.
circ_1_coordinates = torus.factors[0].intrinsic_to_extrinsic_coords(encoded_angles[:,0]).reshape(2,-1).T
circ_2_coordinates = torus.factors[1].intrinsic_to_extrinsic_coords(encoded_angles[:,1]).reshape(2,-1).T
#print("1st", circ_1_coordinates)
#print("2nd", circ_2_coordinates)

In [None]:
MNIST_data_on_torus_4d = np.concatenate((circ_1_coordinates,circ_2_coordinates),axis = 1).reshape(-1,2,2) # cos\phi, sin \phi, cos \psi, sin \psi

In [None]:
kmeans = RiemannianKMeans(torus, k, tol=1e-3)
kmeans.fit(MNIST_data_on_torus_4d)
kmeans_latent_space_euclidean_labels = kmeans.labels_
cluster_centers = kmeans.centroids_# kmeans.cluster_centers_

In [None]:
# NOTE(review): `cluster_centers[:][1]` is the same as `cluster_centers[1]`
# (the leading `[:]` is a no-op), i.e. the second centroid. If the second
# circle factor of every centroid was meant, that would be `[:, 1]` — confirm.
torus.factors[0].extrinsic_to_angle(cluster_centers[:][1])

In [None]:
cluster_centers_in_local_chart = Hypersphere(dim=1).extrinsic_to_intrinsic_coords(cluster_centers).squeeze()

In [None]:
fig,(ax1,ax2) = plt.subplots(2,1,figsize=(8, 12))

# Top panel: latent points coloured by their K-means cluster, centroids as stars.
# (Handles are not called p1/p2 so the point tensor `p1` defined earlier in
# the notebook is not overwritten.)
kmeans_scatter = ax1.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=kmeans_latent_space_euclidean_labels, marker='o', edgecolor='none', cmap=ricci_regularization.discrete_cmap(k, 'jet'))
# Attach each colorbar explicitly to its own axes.
plt.colorbar(kmeans_scatter, ax=ax1, ticks=range(k))
ax1.title.set_text("Latent space colored by K-means on Torus with Euclidean metric")
ax1.grid(True)
ax1.scatter(cluster_centers_in_local_chart[:,0],cluster_centers_in_local_chart[:,1],marker = '*',s=150,c ="orange")

# 1 where the cluster label differs from the ground truth. Cluster ids are
# arbitrary, so if fewer than half differ the array is inverted so that
# 1 means "agrees"; otherwise the labels are taken to be swapped and a
# difference already marks agreement.
correcltly_detected_labels = abs(kmeans_latent_space_euclidean_labels - gt_labels)
if correcltly_detected_labels.sum() < len(gt_labels)//2:
    correcltly_detected_labels = np.logical_not(correcltly_detected_labels)

# Bottom panel: same points coloured by correct / incorrect assignment.
# `plt.get_cmap` is used because `plt.cm.get_cmap` no longer exists in
# recent matplotlib releases.
correctness_scatter = ax2.scatter(encoded_points_no_grad[:,0],encoded_points_no_grad[:,1], c=correcltly_detected_labels, marker='o', edgecolor='none', cmap=plt.get_cmap("viridis", k))
cbar = plt.colorbar(correctness_scatter, ax=ax2, ticks=[0.25,0.75])
cbar.ax.set_yticklabels(["incorrect","correct"])
#if violent_saving == True:
#    plt.savefig(f"{Path_pictures}/Kmeans_latent_space.pdf",format="pdf")

# K-means in input data space

In [None]:
# How many points to keep from each K-means cluster.
num_points_in_clusters = 3

clusters = []
clusters_initial_labels = []
for i in range(k):
    # Indices of the samples assigned to cluster i, computed once and reused.
    members = np.where(kmeans_latent_space_euclidean_labels == i)
    current_cluster = encoded_points_no_grad[members][:num_points_in_clusters]
    current_cluster_initial_labels = torch.tensor(gt_labels[members][:num_points_in_clusters])
    clusters.append(current_cluster)
    clusters_initial_labels.append(current_cluster_initial_labels)

In [None]:
import matplotlib.colors as mcolors
colors = list(mcolors.TABLEAU_COLORS.keys())
for i in range(k):
    plt.scatter(clusters[i][:,0],clusters[i][:,1],c = colors[i%len(colors)],s=30)
    plt.scatter(cluster_centers_in_local_chart[i,0],cluster_centers_in_local_chart[i,1],marker = '*',s=250,c =colors[i%len(colors)],edgecolors="black")
plt.xlim(-torch.pi,torch.pi)
plt.ylim(-torch.pi,torch.pi)
plt.show()

In [None]:
# multioctopus parameters setting
N = num_points_in_clusters*k # # all points 

points = (torch.rand(N,2)-1/2)*2*torch.pi #random points

#points = torch.cat(clusters)
#cluster_centers_in_local_chart.repeat(k*num_points_in_clusters,1).T.shape
init_centers = torch.tensor(cluster_centers_in_local_chart.T.repeat(N,1).T)
#print(init_centers)
# b1,b1,...,b2,b2,...

In [None]:
multi_octopus = Octopus(points.repeat(k,1),init_centers)
weights = nn.Parameter((1/k)*torch.ones(N,k))
multi_octopus.register_parameter("leg_weights",weights)

In [None]:
multi_octopus.plot()
plt.scatter(cluster_centers_in_local_chart[:,0],cluster_centers_in_local_chart[:,1],marker = '*',s=250,c = "orange",edgecolors="black",zorder=10)
plt.xlim(-torch.pi,torch.pi)
plt.ylim(-torch.pi,torch.pi)
plt.show()

In [None]:
# loss centers
#multi_octopus(t)[:N,-1] - multi_octopus(t)[:N,-1].roll(shifts=1,dims=0)

In [None]:
t = torch.linspace(0,1,20)
#t

In [None]:
# Energy of every leg, one value per (centre, point) pair.
energies = manifold.curve_energy(multi_octopus(t),reduction=None)
print(energies)
# Legs are ordered centre by centre (N legs per centre), so reshape with the
# number of centres k rather than a hard-coded 2: row j, column i is the
# energy of the curve connecting p_j and b_i.
energies = energies.reshape(k,N).T
print(weights)
print(energies)
(weights*energies).mean()

In [None]:
"""
with torch.no_grad():
    lambda_centers = 2*energies.mean()
"""
lambda_centers = 1e4 # huge weight
print(lambda_centers)
#optimizer choice

#opt = torch.optim.Adam(multi_octopus.parameters(), lr=0.2e-2)   
opt = torch.optim.SGD(multi_octopus.parameters(), lr=1e-2)

In [None]:
weights.grad

In [None]:
# diagnostic
opt.zero_grad()

energies = manifold.curve_energy(multi_octopus(t),reduction=None)
energies = energies.reshape(2,N).T # in i-th column j-th row is the energy of curve
# connecting p_j and b_i 
loss_dist2center = (weights*energies).sum()

#loss_dist2center = manifold.curve_energy(multi_octopus(t)).mean()

# these are euclidean distanses between points which are the starting points
# we want it to be zero! (same starting point)
#loss_centers = (multi_octopus(t)[:-2,-1] - multi_octopus(t)[2:,-1]).norm().square() 
loss_centers = 0.
for s in range(k):
    # just compute values in the end points!!
    loss_centers += (multi_octopus(t)[s*N:(s+1)*N,-1] - multi_octopus(t)[s*N:(s+1)*N,-1].roll(shifts=1,dims=0)).square().mean()
# these are endpoints of pathes to same baricenters

loss = lambda_centers*loss_centers + loss_dist2center

loss.backward()
print(f"loss:{loss.item():.3f}, loss_centers:{loss_centers.item():.3f}, loss_dist2center:{loss_dist2center.item():.3f}")

torch.nn.utils.clip_grad_norm_(multi_octopus.parameters(), 1e+1)
torch.nn.utils.clip_grad_norm_(weights, 1e-1) #clip weights gradients harder

opt.step()

# weights clamp and renormalization

with torch.no_grad():
    weights.clamp_(min=0.,max=1.)
    normalized_weights = torch.nn.functional.normalize(weights, p=1, dim=1)

    # if weights become (0,0) make them (1/2,1/2)
    normalized_weights += (1 - normalized_weights.norm(p=1,dim=1)).repeat(2,1).T
    normalized_weights = torch.nn.functional.normalize(normalized_weights, p=1, dim=1)

    weights.copy_(normalized_weights)
    
    # weights = torch.nn.functional.normalize(weights,p=1,dim=1) 
    
    # this should not be used as it creats a new tensor and kills 
    # previous grad tracking as it is used under torch.no_grad()
print("weights:\n",weights)
print("weights gradients:\n",weights.grad)
print("energies:\n",energies)
multi_octopus.plot()
plt.show()

#training_loop(1)

In [None]:
for param in multi_octopus.parameters():
    print("lala",param.grad)

In [None]:
import matplotlib.pyplot as plt

def training_loop(num_epochs: int,d=2):
    """Run the joint optimisation of curves and soft assignments.

    Uses the notebook globals ``opt``, ``manifold``, ``multi_octopus``,
    ``weights``, ``t``, ``k``, ``N`` and ``lambda_centers``. Every epoch:

    1. computes the energy of every leg and weights it by ``weights``;
    2. penalises the spread of the end points of legs that belong to the same
       centre;
    3. takes one optimiser step and projects ``weights`` back onto rows that
       are non-negative and sum to one.

    Finally plots the loss and the summed gradient norms.

    Args:
        num_epochs: number of optimisation steps.
        d: kept for backward compatibility and no longer used; the number of
            weight columns is taken from ``weights`` itself.

    Returns:
        The loss tensor of the last epoch, or ``None`` if ``num_epochs`` is 0.
    """
    gradient_norms = []
    loss_list = []
    loss = None  # stays None when no epoch is run

    for epoch in range(num_epochs):
        opt.zero_grad()

        # Compute energies: row j, column i is the energy of the leg from
        # point j to centre i.
        energies = manifold.curve_energy(multi_octopus(t), reduction=None)
        energies = energies.reshape(k, N).T
        loss_dist2center = (weights * energies).norm()/(N*k) #was sqrt and grad was exploding!

        # Compute loss for centers: compare every leg end with another leg end
        # of the same centre. The shift changes with the epoch so different
        # pairs are compared over time, but it is never a multiple of N — an
        # identity roll would make the penalty vanish for that epoch.
        loss_centers = 0.0
        multi_octopus_leg_ends = multi_octopus(torch.ones(1))  # evaluate at t = 1 once
        shift = 1 + epoch % max(N - 1, 1)
        for s in range(k):
            leg_ends = multi_octopus_leg_ends[s * N:(s + 1) * N]
            rolled_leg_ends = leg_ends.roll(shifts=shift, dims=0)
            loss_centers += (leg_ends - rolled_leg_ends).square().mean()

        # Total loss
        loss = lambda_centers * loss_centers + loss_dist2center

        # Backpropagation
        loss.backward()

        # Collect gradient norms (parameters without a gradient are skipped)
        grad_norm = sum(
            param.grad.norm().item()
            for param in multi_octopus.parameters()
            if param.grad is not None
        )
        gradient_norms.append(grad_norm)
        loss_list.append(loss.item())

        # Gradient clipping
        # The algorithm results are very sensitive to clipping parameters
        #torch.nn.utils.clip_grad_norm_(multi_octopus.parameters(), 1e1)
        torch.nn.utils.clip_grad_norm_(weights, 1e0)

        # Optimization step
        opt.step()

        # Weights clamp and renormalization
        with torch.no_grad():
            weights.clamp_(min=0.0, max=1.0)
            normalized_weights = torch.nn.functional.normalize(weights, p=1, dim=1)

            # Rows clamped to all zeros still have L1 norm 0: add 1 to each of
            # their entries and renormalise so they become uniform. Rows that
            # already sum to 1 get 0 added and are unchanged.
            normalized_weights += (1 - normalized_weights.norm(p=1,dim=1)).unsqueeze(1)
            normalized_weights = torch.nn.functional.normalize(normalized_weights, p=1, dim=1)

            # In-place copy keeps `weights` the same registered parameter.
            weights.copy_(normalized_weights)

        # Logging: end point of the first leg of every centre (values from
        # before this epoch's optimiser step).
        print("frechet means:", *(multi_octopus_leg_ends[s * N].detach() for s in range(k)))
        print(f"Epoch {epoch+1}/{num_epochs} - loss: {loss.item():.3f}, loss_centers: {loss_centers.item():.3f},loss_dist2center: {loss_dist2center.item():.3f},grad_norm: {grad_norm:.3f}")

    # Plot loss and gradient norms
    plt.figure(figsize=(10, 6))
    plt.plot(loss_list, label='loss')
    plt.plot(gradient_norms, label='Gradient Norms')
    plt.xlabel('Iteration')
    plt.ylabel('losses')
    plt.title('Gradient Norms During Training')
    plt.legend()
    plt.grid(True)
    plt.show()

    return loss


In [None]:
weights.grad

In [None]:
multi_octopus = Octopus(points.repeat(k,1),init_centers)
#multi_octopus = CubicSpline(points.repeat(k,1),init_centers) # fixed means
weights = nn.Parameter((1/k)*torch.ones(N,k))
multi_octopus.register_parameter("leg_weights",weights)

lambda_centers = 1e3 # huge weight
print(lambda_centers)
#optimizer choice

#opt = torch.optim.Adam(multi_octopus.parameters(), lr=0.2e-2)   


In [None]:
opt = torch.optim.SGD(multi_octopus.parameters(), lr=1e-2)
loss = training_loop(num_epochs=50)
multi_octopus.plot()

In [None]:
weights.grad

In [None]:
weights

In [None]:
multi_octopus.plot()
plt.show()

In [None]:
for name,param in multi_octopus.named_parameters():
    print(name,param)