In [1]:
# Re-import edited modules (e.g. the local embedders package) automatically before each cell runs
%load_ext autoreload
%autoreload 2

In [2]:
import embedders
import networkx as nx
import torch

In [4]:
# NOTE(review): scratch load of the "football" graph. Every benchmark cell below
# reassigns `adj` from its own dataset before using it, so this value is never consumed.
_, _, adj = embedders.dataloaders.load("football")

Top CC has 35 nodes; original graph has 35 nodes.


In [3]:
# Make link prediction dataset


def make_link_prediction_dataset(X_embed, pm, adj, add_dists=True):
    """Build a pairwise link-prediction dataset from node embeddings.

    Every ordered node pair ``(i, j)`` (self-pairs included) becomes one row
    ``[X_embed[i], X_embed[j]]``, optionally followed by the distance
    ``pm.pdist(X_embed)[i, j]``. Its label is ``adj[i, j]``. Rows are in
    row-major order (row ``i * n + j``), matching ``adj.flatten()``.

    Args:
        X_embed: 2-D tensor with one embedding row per node (``n`` rows).
        pm: Manifold the embeddings live on. ``pm.pdist`` supplies the distance
            feature and ``pm.signature`` is duplicated for the pair features.
        adj: ``n x n`` adjacency matrix (torch tensor or array-like).
        add_dists: If True, append the pairwise distance as one extra column.

    Returns:
        ``(X, y, new_pm)``:
            ``X``: tensor with ``n * n`` rows.
            ``y``: flattened adjacency labels.
            ``new_pm``: a ProductManifold whose signature is ``pm.signature``
                twice, plus a ``(0, 1)`` component when ``add_dists`` is set.

    NOTE(review): both (i, j) and (j, i), plus all self-pairs, are emitted. A
    random train/test split over these rows can put a pair's mirror image in
    the training set.
    """
    n_nodes = X_embed.shape[0]

    # Vectorized equivalent of a double loop over (i, j):
    # the left half repeats each node n times (i varies slowest),
    # and the right half tiles all nodes (j varies fastest).
    left = X_embed.repeat_interleave(n_nodes, dim=0)
    right = X_embed.repeat(n_nodes, 1)
    X = torch.cat([left, right], dim=1)

    # Add distances
    if add_dists:
        dists = pm.pdist(X_embed)
        X = torch.cat([X, dists.flatten().unsqueeze(1)], dim=1)

    # as_tensor accepts numpy arrays and tensors alike, without warning on tensor input.
    # clone() keeps y independent of adj, as the copying torch.tensor() did.
    y = torch.as_tensor(adj).flatten().clone()

    # Signature for the concatenated features. list() also accepts tuple signatures.
    new_sig = list(pm.signature) + list(pm.signature)
    if add_dists:
        new_sig.append((0, 1))
    new_pm = embedders.manifolds.ProductManifold(signature=new_sig)

    return X, y, new_pm

In [4]:
# Shared imports and experiment configuration for the benchmark cells below
import embedders
from embedders.manifolds import ProductManifold
from embedders.tree_new import ProductSpaceDT
from embedders.dataloaders import load
from embedders.coordinate_learning import train_coords
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from tqdm.notebook import tqdm

# from sklearn.model_selection import train_test_split
from scipy.stats import wilcoxon
import numpy as np
import pandas as pd
import torch

# Experiment configuration read by the benchmark cells below.
USE_SPECIAL_DIMS = False  # forwarded to benchmark(use_special_dims=...)
USE_DISTS = True  # append the pairwise manifold distance as a link-prediction feature
SIGNATURE = [(1, 2), (0, 2), (-1, 2)]  # product-manifold signature used for every embedding
TEST_SIZE = 0.2  # NOTE(review): not passed to benchmark() in the visible cells
TOTAL_ITERATIONS = 3000  # coordinate-learning steps: 10% burn-in, 90% training
MAX_DEPTH = None  # tree depth limit forwarded to benchmark(); None presumably means unlimited
N_TRIALS = 100  # independently trained embeddings per dataset

# Seed the global RNGs so the reported means survive Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [13]:
# Karate club: link-prediction benchmark over N_TRIALS independently trained embeddings

dists, labels, adj = embedders.dataloaders.load("karate_club")
dists = dists / dists.max()  # rescale so the largest target distance is 1

# Failed trials (exceptions, NaN coordinates) are retried.
# Give up rather than loop forever if failures keep happening.
max_failed_trials = N_TRIALS
n_failed_trials = 0
trial_results = []
my_tqdm = tqdm(total=N_TRIALS)
try:
    while len(trial_results) < N_TRIALS:
        pm = ProductManifold(signature=SIGNATURE)
        try:
            X_embed, losses = train_coords(
                pm,
                dists,
                burn_in_iterations=int(0.1 * TOTAL_ITERATIONS),
                training_iterations=int(0.9 * TOTAL_ITERATIONS),
                scale_factor_learning_rate=0.02,
            )
            if torch.isnan(X_embed).any():
                raise ValueError("embedding contains NaN coordinates")

            X, y, pm_new = make_link_prediction_dataset(X_embed, pm, adj, add_dists=USE_DISTS)

            res = embedders.benchmarks.benchmark(
                X, y, pm_new, max_depth=MAX_DEPTH, task="classification", use_special_dims=USE_SPECIAL_DIMS
            )
            res["d_avg"] = embedders.metrics.d_avg(pm.pdist(X_embed), dists).item()
            trial_results.append(res)
            my_tqdm.update(1)

        except Exception as e:
            n_failed_trials += 1
            print(f"Trial failed ({type(e).__name__}): {e}")
            if n_failed_trials >= max_failed_trials:
                raise RuntimeError(
                    f"{n_failed_trials} failed trials with only {len(trial_results)}/{N_TRIALS} successes; aborting"
                ) from e
finally:
    my_tqdm.close()

# Print results. For each model: mean +/- standard error, then every model it
# beats under a paired Wilcoxon signed-rank test, Bonferroni-corrected.
results = pd.DataFrame(trial_results)
model_cols = [col for col in results.columns if col not in ["model", "d_avg"]]
alpha = 0.05 / max(1, len(model_cols) - 1)  # one comparison per competing model
n_results = len(results)
for col in model_cols:
    r = results[col]
    print(f"{col}: {r.mean():.4f} +/- {r.std() / np.sqrt(n_results):.4f}", end=" ")

    for col2 in model_cols:
        # Skip self-comparison and identical columns: all-zero differences make the test degenerate
        if col2 == col or (results[col] == results[col2]).all():
            continue
        p = wilcoxon(results[col], results[col2]).pvalue
        if p < alpha and results[col].mean() > results[col2].mean():
            print(f"> {col2}", end=" ")

    print()
print(f"d_avg: {results['d_avg'].mean():.4f} +/- {results['d_avg'].std() / np.sqrt(n_results):.4f}")

# Save results
results.to_csv("../data/graph_benchmarks/karate_club_link.tsv", index=False, sep="\t")

sklearn_dt: 0.9494 +/- 0.0017 > sklearn_rf > product_rf > tangent_rf > knn > ps_perceptron 
sklearn_rf: 0.9375 +/- 0.0018 > product_rf > knn > ps_perceptron 
product_dt: 0.9498 +/- 0.0015 > sklearn_rf > product_rf > tangent_rf > knn > ps_perceptron 
product_rf: 0.9234 +/- 0.0022 > knn > ps_perceptron 
tangent_dt: 0.9499 +/- 0.0016 > sklearn_rf > product_rf > tangent_rf > knn > ps_perceptron 
tangent_rf: 0.9394 +/- 0.0016 > product_rf > knn > ps_perceptron 
knn: 0.8827 +/- 0.0023 > ps_perceptron 
ps_perceptron: 0.7989 +/- 0.0138 
d_avg: 0.1046 +/- 0.0001


In [15]:
# lesmis: link-prediction benchmark over N_TRIALS independently trained embeddings

dists, labels, adj = embedders.dataloaders.load("lesmis")
dists = dists / dists.max()  # rescale so the largest target distance is 1

# Failed trials (exceptions, NaN coordinates) are retried.
# Give up rather than loop forever if failures keep happening.
max_failed_trials = N_TRIALS
n_failed_trials = 0
trial_results = []
my_tqdm = tqdm(total=N_TRIALS)
try:
    while len(trial_results) < N_TRIALS:
        pm = ProductManifold(signature=SIGNATURE)
        try:
            X_embed, losses = train_coords(
                pm,
                dists,
                burn_in_iterations=int(0.1 * TOTAL_ITERATIONS),
                training_iterations=int(0.9 * TOTAL_ITERATIONS),
                scale_factor_learning_rate=0.02,
            )
            if torch.isnan(X_embed).any():
                raise ValueError("embedding contains NaN coordinates")

            X, y, pm_new = make_link_prediction_dataset(X_embed, pm, adj, add_dists=USE_DISTS)

            res = embedders.benchmarks.benchmark(
                X, y, pm_new, max_depth=MAX_DEPTH, task="classification", use_special_dims=USE_SPECIAL_DIMS
            )
            res["d_avg"] = embedders.metrics.d_avg(pm.pdist(X_embed), dists).item()
            trial_results.append(res)
            my_tqdm.update(1)

        except Exception as e:
            n_failed_trials += 1
            print(f"Trial failed ({type(e).__name__}): {e}")
            if n_failed_trials >= max_failed_trials:
                raise RuntimeError(
                    f"{n_failed_trials} failed trials with only {len(trial_results)}/{N_TRIALS} successes; aborting"
                ) from e
finally:
    my_tqdm.close()

# Print results. For each model: mean +/- standard error, then every model it
# beats under a paired Wilcoxon signed-rank test, Bonferroni-corrected.
results = pd.DataFrame(trial_results)
model_cols = [col for col in results.columns if col not in ["model", "d_avg"]]
alpha = 0.05 / max(1, len(model_cols) - 1)  # one comparison per competing model
n_results = len(results)
for col in model_cols:
    r = results[col]
    print(f"{col}: {r.mean():.4f} +/- {r.std() / np.sqrt(n_results):.4f}", end=" ")

    for col2 in model_cols:
        # Skip self-comparison and identical columns: all-zero differences make the test degenerate
        if col2 == col or (results[col] == results[col2]).all():
            continue
        p = wilcoxon(results[col], results[col2]).pvalue
        if p < alpha and results[col].mean() > results[col2].mean():
            print(f"> {col2}", end=" ")

    print()
print(f"d_avg: {results['d_avg'].mean():.4f} +/- {results['d_avg'].std() / np.sqrt(n_results):.4f}")

# Save results
results.to_csv("../data/graph_benchmarks/lesmis_link.tsv", index=False, sep="\t")

Top CC has 77 nodes; original graph has 77 nodes.


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

sklearn_dt: 0.9729 +/- 0.0006 > product_dt > product_rf > knn > ps_perceptron 
sklearn_rf: 0.9754 +/- 0.0005 > sklearn_dt > product_dt > product_rf > tangent_dt > knn > ps_perceptron 
product_dt: 0.9714 +/- 0.0007 > product_rf > knn > ps_perceptron 
product_rf: 0.9632 +/- 0.0008 > knn > ps_perceptron 
tangent_dt: 0.9726 +/- 0.0005 > product_rf > knn > ps_perceptron 
tangent_rf: 0.9748 +/- 0.0006 > sklearn_dt > product_dt > product_rf > tangent_dt > knn > ps_perceptron 
knn: 0.9558 +/- 0.0008 > ps_perceptron 
ps_perceptron: 0.9125 +/- 0.0008 
d_avg: 0.0953 +/- 0.0002


In [5]:
# adjnoun: link-prediction benchmark over N_TRIALS independently trained embeddings

dists, labels, adj = embedders.dataloaders.load("adjnoun")
dists = dists / dists.max()  # rescale so the largest target distance is 1

# Failed trials (exceptions, NaN coordinates) are retried.
# Give up rather than loop forever if failures keep happening.
max_failed_trials = N_TRIALS
n_failed_trials = 0
trial_results = []
my_tqdm = tqdm(total=N_TRIALS)
try:
    while len(trial_results) < N_TRIALS:
        pm = ProductManifold(signature=SIGNATURE)
        try:
            X_embed, losses = train_coords(
                pm,
                dists,
                burn_in_iterations=int(0.1 * TOTAL_ITERATIONS),
                training_iterations=int(0.9 * TOTAL_ITERATIONS),
                scale_factor_learning_rate=0.02,
            )
            if torch.isnan(X_embed).any():
                raise ValueError("embedding contains NaN coordinates")

            X, y, pm_new = make_link_prediction_dataset(X_embed, pm, adj, add_dists=USE_DISTS)

            res = embedders.benchmarks.benchmark(
                X, y, pm_new, max_depth=MAX_DEPTH, task="classification", use_special_dims=USE_SPECIAL_DIMS
            )
            res["d_avg"] = embedders.metrics.d_avg(pm.pdist(X_embed), dists).item()
            trial_results.append(res)
            my_tqdm.update(1)

        except Exception as e:
            n_failed_trials += 1
            print(f"Trial failed ({type(e).__name__}): {e}")
            if n_failed_trials >= max_failed_trials:
                raise RuntimeError(
                    f"{n_failed_trials} failed trials with only {len(trial_results)}/{N_TRIALS} successes; aborting"
                ) from e
finally:
    my_tqdm.close()

# Print results. For each model: mean +/- standard error, then every model it
# beats under a paired Wilcoxon signed-rank test, Bonferroni-corrected.
results = pd.DataFrame(trial_results)
model_cols = [col for col in results.columns if col not in ["model", "d_avg"]]
alpha = 0.05 / max(1, len(model_cols) - 1)  # one comparison per competing model
n_results = len(results)
for col in model_cols:
    r = results[col]
    print(f"{col}: {r.mean():.4f} +/- {r.std() / np.sqrt(n_results):.4f}", end=" ")

    for col2 in model_cols:
        # Skip self-comparison and identical columns: all-zero differences make the test degenerate
        if col2 == col or (results[col] == results[col2]).all():
            continue
        p = wilcoxon(results[col], results[col2]).pvalue
        if p < alpha and results[col].mean() > results[col2].mean():
            print(f"> {col2}", end=" ")

    print()
print(f"d_avg: {results['d_avg'].mean():.4f} +/- {results['d_avg'].std() / np.sqrt(n_results):.4f}")

# Save results
results.to_csv("../data/graph_benchmarks/adjnoun_link.tsv", index=False, sep="\t")

Top CC has 112 nodes; original graph has 112 nodes.


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

KeyboardInterrupt: 