# GraphNN Classification


In [1]:
from photonai.base import Hyperpipe, PipelineElement
from photonai_graph.GraphUtilities import get_random_connectivity_data, get_random_labels
from sklearn.model_selection import KFold

Make random matrices to simulate connectivity matrices

In [2]:
X = get_random_connectivity_data(number_of_nodes=50, number_of_individuals=100)
y = get_random_labels(l_type="classification", number_of_labels=100)

Design your Pipeline
We add a simple GraphConstructor and a GCN Classifier

In [3]:
my_pipe = Hyperpipe('basic_gembedding_pipe',
                    inner_cv=KFold(n_splits=5),
                    outer_cv=KFold(n_splits=5),
                    optimizer='sk_opt',
                    optimizer_params={'n_configurations': 25},
                    metrics=['accuracy', 'balanced_accuracy', 'recall', 'precision'],
                    best_config_metric='mean_absolute_error')

my_pipe.add(PipelineElement('GraphConstructorThreshold', threshold=0.95))

my_pipe.add(PipelineElement('GCNClassifier'))

Using backend: pytorch


Finally we simply fit the hyperpipe to our data.

In [4]:
my_pipe.fit(X, y)

PHOTONAI ANALYSIS: basic_gembedding_pipe

*****************************************************************************************************
Outer Cross validation Fold 1
*****************************************************************************************************
Did not find any hyperparameter to convert into skopt space.


100%|██████████| 200/200 [00:06<00:00, 29.92it/s]
100%|██████████| 200/200 [00:06<00:00, 30.38it/s]
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.60it/s]
100%|██████████| 200/200 [00:06<00:00, 30.45it/s]
100%|██████████| 200/200 [00:06<00:00, 30.53it/s]

-----------------------------------------------------------------------------------------------------
BEST_CONFIG 
-----------------------------------------------------------------------------------------------------
{}
-----------------------------------------------------------------------------------------------------
VALIDATION PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5375      |      0.4250      |
|  balanced_accuracy  |       0.5189      |      0.4698      |
|        recall       |       0.7690      |      0.7111      |
|      precision      |       0.6263      |      0.3659      |
| mean_absolute_error |       0.4625      |      0.5750      |
+---------------------+-------------------+------


100%|██████████| 200/200 [00:09<00:00, 20.96it/s]


-----------------------------------------------------------------------------------------------------
TEST PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5375      |      0.4000      |
|  balanced_accuracy  |       0.5338      |      0.5238      |
|        recall       |       0.6829      |      0.8333      |
|      precision      |       0.5385      |      0.3125      |
| mean_absolute_error |       0.4625      |      0.6000      |
+---------------------+-------------------+------------------+

*****************************************************************************************************
Outer Cross validation Fold 2
*******************************************************************************

100%|██████████| 200/200 [00:06<00:00, 30.17it/s]
100%|██████████| 200/200 [00:06<00:00, 30.67it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 29.80it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.60it/s]
100%|██████████| 200/200 [00:06<00:00, 29.91it/s]

-----------------------------------------------------------------------------------------------------
BEST_CONFIG 
-----------------------------------------------------------------------------------------------------
{}
-----------------------------------------------------------------------------------------------------
VALIDATION PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5437      |      0.4875      |
|  balanced_accuracy  |       0.5094      |      0.4685      |
|        recall       |       0.1163      |      0.0400      |
|      precision      |       0.2125      |      0.0667      |
| mean_absolute_error |       0.4562      |      0.5125      |
+---------------------+-------------------+------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:09<00:00, 21.02it/s]
  _warn_prf(average, modifier, msg_start, len(result))


-----------------------------------------------------------------------------------------------------
TEST PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5375      |      0.5000      |
|  balanced_accuracy  |       0.5000      |      0.5000      |
|        recall       |       0.0000      |      0.0000      |
|      precision      |       0.0000      |      0.0000      |
| mean_absolute_error |       0.4625      |      0.5000      |
+---------------------+-------------------+------------------+

*****************************************************************************************************
Outer Cross validation Fold 3
*******************************************************************************

  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.17it/s]
100%|██████████| 200/200 [00:06<00:00, 30.37it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.30it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.48it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.19it/s]

-----------------------------------------------------------------------------------------------------
BEST_CONFIG 
-----------------------------------------------------------------------------------------------------
{}
-----------------------------------------------------------------------------------------------------
VALIDATION PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5500      |      0.4875      |
|  balanced_accuracy  |       0.5125      |      0.5073      |
|        recall       |       0.1500      |      0.1600      |
|      precision      |       0.1091      |      0.0667      |
| mean_absolute_error |       0.4500      |      0.5125      |
+---------------------+-------------------+------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:09<00:00, 21.12it/s]
  _warn_prf(average, modifier, msg_start, len(result))


-----------------------------------------------------------------------------------------------------
TEST PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5375      |      0.5000      |
|  balanced_accuracy  |       0.5000      |      0.5000      |
|        recall       |       0.0000      |      0.0000      |
|      precision      |       0.0000      |      0.0000      |
| mean_absolute_error |       0.4625      |      0.5000      |
+---------------------+-------------------+------------------+

*****************************************************************************************************
Outer Cross validation Fold 4
*******************************************************************************

  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 29.83it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.35it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.37it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.45it/s]
100%|██████████| 200/200 [00:06<00:00, 30.42it/s]

-----------------------------------------------------------------------------------------------------
BEST_CONFIG 
-----------------------------------------------------------------------------------------------------
{}
-----------------------------------------------------------------------------------------------------
VALIDATION PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5750      |      0.4750      |
|  balanced_accuracy  |       0.5278      |      0.4564      |
|        recall       |       0.1161      |      0.0400      |
|      precision      |       0.1286      |      0.0250      |
| mean_absolute_error |       0.4250      |      0.5250      |
+---------------------+-------------------+------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:09<00:00, 20.80it/s]
  _warn_prf(average, modifier, msg_start, len(result))


-----------------------------------------------------------------------------------------------------
TEST PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5500      |      0.4500      |
|  balanced_accuracy  |       0.5000      |      0.5000      |
|        recall       |       0.0000      |      0.0000      |
|      precision      |       0.0000      |      0.0000      |
| mean_absolute_error |       0.4500      |      0.5500      |
+---------------------+-------------------+------------------+

*****************************************************************************************************
Outer Cross validation Fold 5
*******************************************************************************

  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.44it/s]
100%|██████████| 200/200 [00:06<00:00, 29.48it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.44it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.51it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:06<00:00, 30.36it/s]
  _warn_prf(average, modifier, msg_start, len(result))


-----------------------------------------------------------------------------------------------------
BEST_CONFIG 
-----------------------------------------------------------------------------------------------------
{}
-----------------------------------------------------------------------------------------------------
VALIDATION PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5469      |      0.4625      |
|  balanced_accuracy  |       0.5094      |      0.5000      |
|        recall       |       0.1938      |      0.2000      |
|      precision      |       0.1051      |      0.0625      |
| mean_absolute_error |       0.4531      |      0.5375      |
+---------------------+-------------------+------

  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:09<00:00, 20.87it/s]
  _warn_prf(average, modifier, msg_start, len(result))


-----------------------------------------------------------------------------------------------------
TEST PERFORMANCE
-----------------------------------------------------------------------------------------------------
+---------------------+-------------------+------------------+
|        METRIC       | PERFORMANCE TRAIN | PERFORMANCE TEST |
+---------------------+-------------------+------------------+
|       accuracy      |       0.5375      |      0.5000      |
|  balanced_accuracy  |       0.5000      |      0.5000      |
|        recall       |       0.0000      |      0.0000      |
|      precision      |       0.0000      |      0.0000      |
| mean_absolute_error |       0.4625      |      0.5000      |
+---------------------+-------------------+------------------+
*****************************************************************************************************
Finished all outer fold computations.


  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 200/200 [00:11<00:00, 17.25it/s]

*****************************************************************************************************

Project Folder: /Users/jan/PycharmProjects/photonai_graph/examples/gcn_examples/basic_gembedding_pipe_results_2022-03-29_22-10-54,
Computation Time: 2022-03-29 22:10:54.879967 - 2022-03-29 22:14:31.327652
Duration: 0:03:36.447685
Optimized for: mean_absolute_error
Hyperparameter Optimizer: sk_opt

+-------------------+--+
| PERFORMANCE DUMMY |  |
+-------------------+--+
+-------------------+--+

+---------------------+---------------+--------------+-----------+----------+
|     Metric Name     | Training Mean | Training Std | Test Mean | Test Std |
+---------------------+---------------+--------------+-----------+----------+
|       accuracy      |      0.54     |    0.005     |    0.47   |   0.04   |
|  balanced_accuracy  |    0.506754   |   0.013508   |  0.504762 | 0.009524 |
|        recall       |    0.136585   |   0.273171   |  0.166667 | 0.333333 |
|      precision      |    0.




Hyperpipe(name='basic_gembedding_pipe')