In [1]:
%load_ext autoreload
%autoreload 2

In [16]:
import os

import networkx as nx
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from geomstats.datasets.prepare_graph_data import HyperbolicEmbedding, Graph

from sklearn.model_selection import KFold

from src.hyperdt.tree import DecisionTreeClassifier, HyperbolicDecisionTreeClassifier

In [17]:
# Training epochs handed to HyperbolicEmbedding(max_epochs=...) in the embedding cell
EPOCHS: int = 2

In [18]:
# Copied from 17_graphs_2.ipynb

def load_graph(graph_dir, graph_type="directed", edge_type="unweighted", add_isolates=False, top_cc=False):
    """Load a labelled graph from TSV files and build NetworkX and geomstats views of it.

    Files read from ``graph_dir``:
      - ``adjacency.tsv``: edge list; columns 0 and 1 are (source, target) node ids.
      - ``labels.tsv``: column 0 holds one label per node.
      - ``names_labels.tsv`` (optional): column 0 holds label names.
      - ``names.tsv`` (optional): column 0 holds node names.

    Because geomstats' ``Graph`` only loads from files, a dense adjacency matrix and a
    labels file aligned with it are written to the same path with ``/raw/`` replaced
    by ``/interim/`` (the directory is created if needed).

    Args:
        graph_dir: Directory holding the raw TSV files.
        graph_type: ``"directed"`` builds an ``nx.DiGraph``; any other value an ``nx.Graph``.
        edge_type: Currently unused; edges are always read unweighted.
        add_isolates: If True, also add every node id in ``names.index`` (nodes that
            appear in no edge).
        top_cc: If True, keep only the largest connected component (weakly
            connected components for directed graphs).

    Returns:
        dict with keys ``labels`` (list of ``(node, label)`` pairs), ``label_names``,
        ``names`` (node list), ``networkx_graph``, ``geomstats_graph`` and
        ``distances`` (Floyd-Warshall matrix in ``networkx_graph.nodes()`` order).
    """
    # Specify paths
    interim_dir = graph_dir.replace("/raw/", "/interim/")
    adjacency_path = f"{graph_dir}/adjacency.tsv"
    labels_path = f"{graph_dir}/labels.tsv"
    label_names_path = f"{graph_dir}/names_labels.tsv"
    names_path = f"{graph_dir}/names.tsv"
    dense_adjacency_path = f"{interim_dir}/adjacency_dense.tsv"
    dense_labels_path = f"{interim_dir}/labels_dense.tsv"

    # Adjacency matrix: (out_node, in_node)
    adjacency = pd.read_table(adjacency_path, header=None, usecols=[0, 1])

    # Labels: (label, )
    labels = pd.read_table(labels_path, header=None, usecols=[0])[0]

    # Label name: (label_name, )
    if os.path.exists(label_names_path):
        label_names = pd.read_table(label_names_path, header=None, usecols=[0])[0]
    else:
        # `labels` is already a Series; take its distinct values directly.
        # (No reset_index: that would turn this into a DataFrame.)
        label_names = pd.Series(labels.unique())

    # Node name: (node_name, )
    if os.path.exists(names_path):
        names = pd.read_table(names_path, header=None, usecols=[0])[0]
    else:
        # Must stay a Series: zip(names.index, names) below has to yield names,
        # not DataFrame column headers.
        names = pd.Series(np.arange(len(labels)))

    # Networkx object
    base_graph = nx.DiGraph if graph_type == "directed" else nx.Graph
    networkx_graph = nx.from_pandas_edgelist(adjacency, source=0, target=1, create_using=base_graph)
    if add_isolates:
        networkx_graph.add_nodes_from(names.index)

    # Add names and labels to nodes.
    # NOTE(review): assumes node ids in adjacency.tsv are row positions in labels.tsv / names.tsv.
    nx.set_node_attributes(networkx_graph, dict(zip(names.index, labels)), "label")
    networkx_graph = nx.relabel_nodes(networkx_graph, dict(zip(names.index, names)))

    # Keep only the largest component; connected_components is undefined for DiGraphs
    if top_cc:
        if networkx_graph.is_directed():
            components = nx.weakly_connected_components(networkx_graph)
        else:
            components = nx.connected_components(networkx_graph)
        # .copy() detaches the result from the frozen subgraph view
        networkx_graph = networkx_graph.subgraph(max(components, key=len)).copy()

    # Pairwise distances
    distances = nx.floyd_warshall_numpy(networkx_graph)

    # Geomstats object. Row i of the dense adjacency is the i-th node of
    # networkx_graph.nodes(), so the labels file must be written in that same order
    # (the raw labels.tsv is in original id order and covers all nodes).
    os.makedirs(interim_dir, exist_ok=True)
    dense_adjacency = nx.to_numpy_array(networkx_graph)
    np.savetxt(dense_adjacency_path, dense_adjacency, fmt="%d", delimiter="\t")
    with open(dense_labels_path, "w") as dense_labels_file:
        dense_labels_file.writelines(f"{label}\n" for _, label in networkx_graph.nodes(data="label"))
    geomstats_graph = Graph(graph_matrix_path=dense_adjacency_path, labels_path=dense_labels_path)

    return {
        "labels": list(networkx_graph.nodes(data="label")),
        "label_names": list(label_names),
        "names": list(networkx_graph.nodes()),
        "networkx_graph": networkx_graph,
        "geomstats_graph": geomstats_graph,
        "distances": distances,
    }


# Political-blogs graph: undirected, restricted to its largest connected component
POLBLOGS_DIR = "data/raw/polblogs"
polblogs = load_graph(POLBLOGS_DIR, graph_type="undirected", top_cc=True)

In [20]:
# Embeddings at different dimensions.
# Every result is kept in `polblogs_embeddings` (keyed by dimension) instead of being
# overwritten on each iteration; `polblogs_gs_embed` still ends up holding the last one.
EMBED_DIMS = [2, 4, 8, 16]
N_NEGATIVE = 20
N_CONTEXT = 20
LEARNING_RATE = 1e-2

polblogs_embeddings = {}
for embed_dim in EMBED_DIMS:
    hyp_emb = HyperbolicEmbedding(
        dim=embed_dim,
        n_negative=N_NEGATIVE,
        n_context=N_CONTEXT,
        lr=LEARNING_RATE,
        max_epochs=EPOCHS,
    )
    polblogs_gs_embed = hyp_emb.embed(polblogs["geomstats_graph"])
    polblogs_embeddings[embed_dim] = polblogs_gs_embed
    print(f"Embedding shape: {polblogs_gs_embed.shape}")


INFO: Number of edges: 1222
INFO: Mean vertices by edges: 27.357610474631752
INFO: iteration 0 loss_value 1.940278
