### import the necessary modules

In [13]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier

from gtda.homology import VietorisRipsPersistence
from gtda.diagrams import PersistenceEntropy, PersistenceImage
import gtda
print(gtda.__version__)

from torch_geometric.datasets import TUDataset

import networkx as nx

0.6.2


### step 1: load the dataset

In [14]:
dataset = TUDataset(root='data/TUDataset', name='MUTAG')

# extracting the graphs from the dataset
graphs = [data for data in dataset]

### step 2: node filtration values

- we assign a metric to track each graph, here, we deploy degree
- setting up a threshold t and then progressively increasing it. As we do this, we only read/learn from the graphs with degree <= t, and therefore, allow more graphs as t is made to increase.
- learning and keeping a track of these graphs progressively is done in the next step - persistent homology

In [15]:
graphs = [nx.Graph() for _ in dataset]

for i, g in enumerate(graphs):
    g.add_edges_from(dataset[i].edge_index.t().tolist())

adj_matrices = [nx.to_numpy_array(g) for g in graphs]

# degree_filtration = dict(adj_matrices.degree())

### step 3: persistent homology pipeline

In [16]:
VR = VietorisRipsPersistence(homology_dimensions=[0, 1])
diagrams = VR.fit_transform(adj_matrices)

  check_point_clouds(X, accept_sparse=True,
  X = check_point_clouds(X, accept_sparse=True,


### step 4: vectorize persistence diagrams

In [17]:
persistence_image = PersistenceImage(n_bins=100, n_jobs=None)
X = persistence_image.fit_transform(diagrams)

### step 5: train the model

In [18]:
y = np.array([d.y.item() for d in dataset])   # shape (188,)
X_train, X_test, y_train, y_test = train_test_split(X.reshape(len(X), -1), y, test_size=0.2)

clf = RandomForestClassifier(n_estimators = 100)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

### step 6: analyze accuracy

In [20]:
print("accuracy: ", accuracy_score(y_test, y_pred)*100, "%")

accuracy:  89.47368421052632 %
