In [1]:
from dgl.data import CoraGraphDataset

In [2]:
dataset = CoraGraphDataset()
# Only the first graph of the dataset is used; fetch it once and pull out the
# node ids and the per-node "label" field as NumPy arrays.
graph = dataset[0]
nodes = graph.nodes().numpy()
labels = graph.ndata["label"].numpy()

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [3]:
from graphflex import GraphFlex
from graphflex.connectors.dgl import DGLConnector
from graphflex.functions.postprocessing.filter import NonUniqueFeatureFilter
from graphflex.functions.feature import MeanStdFeature

# Expose the DGL dataset to GraphFlex through its DGL connector, then configure
# the node-feature extractor (MeanStdFeature) and the post-processing step
# (NonUniqueFeatureFilter). Exact feature semantics: see graphflex docs.
dgl_connect = DGLConnector(dataset)
node_feature_fn = MeanStdFeature()
feature_filter = NonUniqueFeatureFilter()
gflex = GraphFlex(connector=dgl_connect, node_feature=node_feature_fn, post_processor=feature_filter)

In [4]:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression

# Two-stage model: GraphFlex feature extraction followed by logistic regression.
# The step names below are the prefixes used in GridSearchCV parameter keys
# (e.g. 'graphflex__max_depth', 'logreg__C').
steps = [
    ('graphflex', gflex),
    ('logreg', LogisticRegression()),
]
pipe = Pipeline(steps)

In [5]:
from sklearn.model_selection import GridSearchCV
# Hyper-parameter grid for GridSearchCV. Keys follow the Pipeline convention
# '<step name>__<parameter>', so each entry is routed to the matching step.
param_grid = {
    'logreg__C': [0.01, 0.1, 1, 10],    # inverse regularization strength
    'graphflex__max_depth': [1, 2],     # presumably graph neighbourhood depth — TODO confirm in graphflex docs
    # To also search an L1 penalty, add 'logreg__penalty': ['l1', 'l2'] together with
    # 'logreg__solver': ['saga']; the default lbfgs solver does not support 'l1'.
}
# NOTE: if the selected C lands on an edge of the grid, consider widening its range.

# 3-fold CV (stratified, since the pipeline ends in a classifier), scored by
# accuracy, with up to 4 candidate fits running in parallel.
grid = GridSearchCV(pipe, param_grid, cv=3, scoring='accuracy', n_jobs=4, verbose=4)

In [6]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report


# Hold out a class-stratified 20% of the nodes as a final test set; the other
# 80% is what GridSearchCV cross-validates on.
(train_nodes, test_nodes,
 train_labels, test_labels) = train_test_split(
    nodes,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

grid.fit(train_nodes, train_labels)

# best_score_ is the mean cross-validated accuracy of the winning candidate,
# not a test-set score.
print("Best Params:", grid.best_params_)
print("Best Score:", grid.best_score_)

# With the default refit=True, predict() uses the best candidate refit on all
# training nodes; evaluate it on the untouched hold-out nodes.
y_pred = grid.predict(test_nodes)
print(classification_report(test_labels, y_pred))

Fitting 3 folds for each of 8 candidates, totalling 24 fits
Best Params: {'graphflex__max_depth': 2, 'logreg__C': 10}
Best Score: 0.8559556786703602
              precision    recall  f1-score   support

           0       0.73      0.77      0.75        70
           1       0.86      0.84      0.85        43
           2       0.92      0.92      0.92        84
           3       0.82      0.91      0.86       164
           4       0.95      0.86      0.90        85
           5       0.96      0.80      0.87        60
           6       0.91      0.83      0.87        36

    accuracy                           0.86       542
   macro avg       0.88      0.85      0.86       542
weighted avg       0.87      0.86      0.86       542

