In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import embedders

In [58]:
import embedders.gaussian_mixture


pm = embedders.manifolds.ProductManifold(signature=[(1, 2), (-1, 2)])
X, y = embedders.gaussian_mixture.gaussian_mixture(pm, num_clusters=32, num_classes=8)



In [None]:
# Classify with basic NN
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

import embedders.tree_new


class Net(nn.Module):
    def __init__(self, pm):
        super(Net, self).__init__()
        self.pm = pm
        self.fc1 = nn.Linear(6, 6)
        self.fc2 = nn.Linear(6, 6)
        self.fc3 = nn.Linear(6, 8)

    def forward(self, x):
        # x = torch.relu(self.fc1(x))
        # x = self.fc2(x)
        # return x
        x = self.pm.logmap(x)
        x = torch.relu(self.fc1(x))
        # x = self.pm.expmap(x)
        # x = self.pm.logmap(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net(pm)
print(net)

opt = torch.optim.Adam(net.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

for i in range(1_000):
    opt.zero_grad()
    y_pred = net(torch.tensor(X_train).float())
    loss = loss_fn(y_pred, torch.tensor(y_train).long())
    loss.backward()
    opt.step()
    if i % 100 == 0:
        print(loss.item())

y_pred = net(torch.tensor(X_test).float())
print("NN acc:", (y_pred.argmax(1) == torch.tensor(y_test)).float().mean().item())

pdt = embedders.tree_new.ProductSpaceDT(pm)
pdt.fit(X_train, y_train)
print("DT acc:", (pdt.predict(X_test) == y_test).float().mean().item())

Net(
  (fc1): Linear(in_features=6, out_features=6, bias=True)
  (fc2): Linear(in_features=6, out_features=6, bias=True)
  (fc3): Linear(in_features=6, out_features=8, bias=True)
)
2.0826311111450195
1.646939992904663
1.5910574197769165
1.563310146331787


  y_pred = net(torch.tensor(X_train).float())
  loss = loss_fn(y_pred, torch.tensor(y_train).long())


1.5462636947631836
1.5382691621780396
1.5338881015777588
1.5291764736175537
1.5273035764694214
1.5245232582092285
NN acc: 0.3199999928474426
DT acc: 0.38999998569488525


  y_pred = net(torch.tensor(X_test).float())
  print("NN acc:", (y_pred.argmax(1) == torch.tensor(y_test)).float().mean().item())


In [1]:
%load_ext autoreload 
%autoreload 2

import embedders

pm = embedders.manifolds.ProductManifold(signature=[(1, 2), (-1, 2)])
pm.dim

4

In [2]:
pm.ambient_dim

6

In [32]:
from embedders.neural import TangentMLPClassifier
from sklearn.model_selection import train_test_split

pm = embedders.manifolds.ProductManifold(signature=[(1, 2), (-1, 2)])
X, y = embedders.gaussian_mixture.gaussian_mixture(pm, num_clusters=32, num_classes=8)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

tmc = TangentMLPClassifier(pm, input_dim=6, hidden_dims=[6, 6], lr=0.01)
tmc.fit(X_train, y_train)
print("TMC acc:", (tmc.predict(X_test) == y_test).float().mean().item())

Epoch 0, Loss: nan
Epoch 100, Loss: nan
Epoch 200, Loss: nan
Epoch 300, Loss: nan
Epoch 400, Loss: nan
Epoch 500, Loss: nan
Epoch 600, Loss: nan
Epoch 700, Loss: nan
Epoch 800, Loss: nan
Epoch 900, Loss: nan
TMC acc: 0.14499999582767487


In [43]:
# super basic tangent plane mlp
import torch
import torch.nn as nn

X_train_tangent = pm.logmap(X_train)

net = nn.Sequential(
    nn.Linear(6, 6),
    nn.ReLU(),
    nn.Linear(6, 6),
    nn.ReLU(),
    nn.Linear(6, 8),
)

opt = torch.optim.Adam(net.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for i in range(10_000):
    opt.zero_grad()
    y_pred = net(torch.tensor(X_train_tangent).float())
    loss = loss_fn(y_pred, y_train)
    loss.backward()
    opt.step()
    if i % 1_000 == 0:
        print(loss.item())

X_test_tangent = pm.logmap(X_test)
y_pred = net(torch.tensor(X_test_tangent).float())
print("NN acc:", (y_pred.argmax(1) == y_test).float().mean().item())

  y_pred = net(torch.tensor(X_train_tangent).float())


2.171628475189209
1.4711434841156006
1.4593839645385742
1.4583150148391724
1.4568194150924683
1.4565598964691162
1.4560192823410034
1.4564324617385864
1.4585633277893066
1.456484317779541
NN acc: 0.3449999988079071


  y_pred = net(torch.tensor(X_test_tangent).float())


In [71]:
# Tangent plane MLP with no bias
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(6, 8, bias=False),
)

opt = torch.optim.Adam(net.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for i in range(10_000):
    opt.zero_grad()
    y_pred = net(torch.tensor(X_train).float())
    loss = loss_fn(y_pred, y_train)
    loss.backward()
    opt.step()
    if i % 1_000 == 0:
        print(loss.item())

y_pred = net(torch.tensor(X_test).float())
print("NN acc:", (y_pred.argmax(1) == y_test).float().mean().item())

  y_pred = net(torch.tensor(X_train).float())


2.2725815773010254
1.7822155952453613
1.7822155952453613
1.7822155952453613
1.7822155952453613
1.7822157144546509
1.7822155952453613
1.7822155952453613
1.7822157144546509
1.7822158336639404
NN acc: 0.3700000047683716


  y_pred = net(torch.tensor(X_test).float())


In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
import numpy as np

from tqdm.notebook import tqdm


# Simple GNN model
class SimpleGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_weight=None):
        x = torch.relu(self.conv1(x, edge_index, edge_weight))
        x = torch.relu(self.conv2(x, edge_index, edge_weight))
        return self.conv3(x, edge_index, edge_weight)


# Create edges for a subset of nodes
def get_subset_edges(dist_matrix, node_indices):
    # Get submatrix of distances
    sub_dist = dist_matrix[node_indices][:, node_indices]

    # Create edges based on threshold
    threshold = sub_dist.mean()
    edges = (sub_dist < threshold).nonzero().t()

    return edges


def get_dense_edges(dist_matrix, node_indices):
    # Get submatrix of distances
    sub_dist = dist_matrix[node_indices][:, node_indices]

    # Create dense edges (all-to-all connections)
    n = len(node_indices)
    rows = torch.arange(n).repeat_interleave(n)
    cols = torch.arange(n).repeat(n)
    edge_index = torch.stack([rows, cols])

    # Get corresponding distances as edge weights
    edge_weights = sub_dist.flatten()

    # Convert distances to weights (you can modify this function)
    edge_weights = torch.exp(-edge_weights)  # Gaussian kernel
    # Alternative weightings:
    # edge_weights = 1 / (edge_weights + 1e-6)  # Inverse distance
    # edge_weights = torch.softmax(-edge_weights, dim=0)  # Softmax of negative distances

    return edge_index, edge_weights


# Setup
dist_matrix = pm.pdist(X).detach()
X_tangent = pm.logmap(X).detach()
train_idx, test_idx = train_test_split(np.arange(len(X)), test_size=0.2)

# Model, optimizer, loss
model = SimpleGNN(in_channels=6, hidden_channels=6, out_channels=8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Get edges for training set
# train_edges = get_subset_edges(dist_matrix, train_idx)
train_edges, train_weights = get_dense_edges(dist_matrix, train_idx)

# Move to Mac
model = model.to("mps")
X_tangent = X_tangent.to("mps")
y = y.to("mps")
train_edges = train_edges.to("mps")
train_weights = train_weights.to("mps")


# Training loop
my_tqdm = tqdm(range(10_000))
for i in my_tqdm:
    model.train()
    optimizer.zero_grad()

    # Only use training data
    X_train = X_tangent[train_idx]
    y_train = y[train_idx]

    # y_pred = model(X_train, train_edges)
    y_pred = model(X_train, train_edges, train_weights)
    loss = loss_fn(y_pred, y_train)
    loss.backward()
    optimizer.step()

    if i % 10 == 0:
        my_tqdm.set_postfix({"loss": f"{loss.item():.4f}"})

# Evaluate
model.eval()
with torch.no_grad():
    # Get edges for test set
    # test_edges = get_subset_edges(dist_matrix, test_idx)
    test_edges, test_weights = get_dense_edges(dist_matrix, test_idx)

    test_edges = test_edges.to("mps")
    test_weights = test_weights.to("mps")

    # Make predictions on test set
    X_test = X_tangent[test_idx]
    y_test = y[test_idx]

    # y_pred = model(X_test, test_edges)
    y_pred = model(X_test_tangent, test_edges, test_weights)
    acc = (y_pred.argmax(1) == y_test).float().mean().item()
print(f"Test accuracy: {acc:.4f}")

  0%|          | 0/10000 [00:00<?, ?it/s]

Test accuracy: 0.1700


In [None]:
# Centroid-based model
import geoopt

N_CLASSES = 8
centroids = torch.vstack([pm.sample() for _ in range(N_CLASSES)])

# make centroids a manifold parameter
centroids = geoopt.ManifoldParameter(centroids, manifold=pm)

# Define model: take distance to centroids as logits
class CentroidMLR(nn.Module):
    def __init__(self, pm, centroids):
        super().__init__()
        self.pm = pm
        self.centroids = centroids
        self.weights = nn.Parameter(torch.randn(N_CLASSES, 1))

    def forward(self, x):
        # p(y | h) = softmax(W h)
    


In [7]:
%load_ext autoreload
%autoreload 2

# Verify tangent MLP works
import embedders
import pandas as pd

pm = embedders.manifolds.ProductManifold(signature=[(1, 2), (-1, 2)])
X, y = embedders.gaussian_mixture.gaussian_mixture(pm, num_clusters=32, num_classes=8)
scores = embedders.benchmarks.benchmark(
    X, y, pm, models=["product_dt", "tangent_mlp", "ambient_mlp", "tangent_gnn", "ambient_gnn"]
)
pd.DataFrame(scores)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload




Unnamed: 0,product_dt,tangent_mlp,ambient_mlp,tangent_gnn,ambient_gnn
accuracy,0.3,0.275,0.265,0.22,0.22
f1-micro,0.3,0.275,0.265,0.22,0.22
time,0.024543,0.233128,0.255595,25.343984,23.289283
