# Graph Learning

## Lab 7: Graph Neural Networks

In this lab, you will learn to classify nodes using a graph neural network (GNN).

## Import

In [None]:
import numpy as np
from scipy import sparse
import matplotlib.pyplot as plt

In [None]:
from sknetwork.classification import get_accuracy_score
from sknetwork.data import load_netset
from sknetwork.embedding import Spectral
from sknetwork.gnn import GNNClassifier
from sknetwork.utils import directed2undirected
from sklearn import metrics

from sklearn.manifold import TSNE

## Data

We will work on the following datasets (see the [NetSet](https://netset.telecom-paris.fr/) collection for details):
* Cora (directed graph + bipartite graph)
* WikiVitals (directed graph + bipartite graph)

Both datasets are graphs with node features (given by the bipartite graph) and ground-truth labels.

In [None]:
cora = load_netset('cora')
wikivitals = load_netset('wikivitals')

In [None]:
def visualize_embedding(embedding, labels, size=(6, 6)):
    """Visualize embedding in 2 dimensions using TSNE."""
    print("Computing TSNE...")
    # project the embedding to 2 dimensions (fixed seed for reproducibility)
    tsne = TSNE(random_state=8).fit_transform(embedding)
    fig, ax = plt.subplots(1, 1, figsize=size)
    # one point per node, colored by label
    ax.scatter(tsne[:, 0], tsne[:, 1], c=labels, s=50, cmap='hsv')
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()

## 1. Cora

We start with the Cora dataset. We check the embedding of the nodes before and after learning, and the impact of the GNN architecture on accuracy.

In [None]:
dataset = cora

In [None]:
adjacency = dataset.adjacency
features = dataset.biadjacency
labels_true = dataset.labels

In [None]:
# we use undirected graphs
adjacency = directed2undirected(adjacency)

## To do

Consider a GNN with a single hidden layer of dimension 16.

* Run a single forward pass on the data, without learning.
* Display the embedding provided by the hidden layer.

In [None]:
hidden_dim = 16

In [None]:
n_labels = len(set(labels_true))

In [None]:
gnn = GNNClassifier(dims=[hidden_dim, n_labels], verbose=True)

In [None]:
gnn

In [None]:
output = gnn.forward(adjacency, features)

In [None]:
embedding = gnn.layers[0].embedding

In [None]:
visualize_embedding(embedding, labels_true)
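
To make explicit what a forward pass without learning computes, here is a minimal NumPy sketch of a single graph-convolution layer with random weights. It assumes mean aggregation over the neighbors and the node itself, followed by a linear map and a ReLU; the exact normalization used by `GNNClassifier` may differ. The function `gnn_layer_forward` and all `toy_*` names are hypothetical and are not part of scikit-network.

```python
import numpy as np

def gnn_layer_forward(adjacency, features, weights, bias):
    """One graph-convolution step: average over the neighborhood, then dense layer + ReLU."""
    n = adjacency.shape[0]
    # Add self-loops so that each node keeps its own features (assumption of this sketch).
    adjacency_loops = adjacency + np.eye(n)
    # Row-normalize: each node takes the mean of its neighborhood.
    degrees = adjacency_loops.sum(axis=1, keepdims=True)
    aggregated = (adjacency_loops / degrees) @ features
    # Linear map followed by ReLU.
    return np.maximum(aggregated @ weights + bias, 0)

rng = np.random.default_rng(0)
# Toy graph: a path on 4 nodes, with 3 features per node.
toy_adjacency = np.array([[0, 1, 0, 0],
                          [1, 0, 1, 0],
                          [0, 1, 0, 1],
                          [0, 0, 1, 0]], dtype=float)
toy_features = rng.random((4, 3))
# Random (untrained) weights, as in a forward pass before learning.
toy_weights = rng.normal(size=(3, 2))
toy_bias = np.zeros(2)

toy_embedding = gnn_layer_forward(toy_adjacency, toy_features, toy_weights, toy_bias)
print(toy_embedding.shape)
```

Even with random weights, the aggregation step mixes the features of neighboring nodes, so the embedding before learning already reflects part of the graph structure.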

## To do

We now train the GNN.

* Train the GNN with a 50% / 50% train / test split.
* Give the accuracy of the classification on the train and test sets.
* Give the total number of parameters.
* Display the embedding provided by the hidden layer.
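
For the number of parameters, a rough count can be sketched as follows, assuming each layer has one weight matrix and one bias vector. The helper `count_gnn_parameters` is hypothetical, and the dimensions below (1000 input features, 7 labels) are illustrative only; on the actual data, the input dimension is `features.shape[1]`.

```python
def count_gnn_parameters(input_dim, dims, use_bias=True):
    """Count weights and biases of stacked layers with output dimensions `dims`."""
    total = 0
    for output_dim in dims:
        # weight matrix of shape (input_dim, output_dim)
        total += input_dim * output_dim
        if use_bias:
            # one bias per output unit
            total += output_dim
        # the output of this layer is the input of the next one
        input_dim = output_dim
    return total

# Illustrative dimensions: 1000 features, hidden layer of 16, 7 labels.
n_params = count_gnn_parameters(input_dim=1000, dims=[16, 7])
print(n_params)
```

The first layer dominates the count whenever the number of features is much larger than the hidden dimension.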

In [None]:
# train / test split
train_ratio = 0.5
labels = labels_true.copy()
train_mask = np.random.random(size=len(labels)) < train_ratio
test_mask = ~train_mask
labels[test_mask] = -1

In [None]:
gnn.fit(adjacency, features, labels)

In [None]:
labels_pred = gnn.predict()
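
The accuracy on each part of the split is the fraction of correct predictions among the nodes selected by the corresponding mask. A minimal NumPy sketch on toy arrays follows; `masked_accuracy` and the `toy_*` arrays are hypothetical names used for illustration only.

```python
import numpy as np

def masked_accuracy(labels_true, labels_pred, mask):
    """Fraction of correct predictions among the nodes selected by the boolean mask."""
    return float(np.mean(labels_true[mask] == labels_pred[mask]))

# Toy example with 6 nodes: the first 3 in the train set, the last 3 in the test set.
toy_true = np.array([0, 1, 1, 0, 2, 2])
toy_pred = np.array([0, 1, 1, 0, 2, 1])
toy_train_mask = np.array([True, True, True, False, False, False])

train_acc = masked_accuracy(toy_true, toy_pred, toy_train_mask)
test_acc = masked_accuracy(toy_true, toy_pred, ~toy_train_mask)
print(f'train accuracy: {train_acc:.4f}')
print(f'test accuracy: {test_acc:.4f}')
```

The same computation applies to `labels_true`, `labels_pred`, `train_mask` and `test_mask` defined above.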

## To do

* Retrain the GNN with an empty graph.
* Compare the accuracy of the classification with that of the previous model.
* Comment on the results. What is the learning model in this case?

In [None]:
empty = sparse.csr_matrix(adjacency.shape)
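
As a hint for the last question, the following NumPy sketch checks what a mean-aggregation layer computes on an empty graph. It assumes that each node still keeps its own features (modeled here by self-loops), which may not match the exact configuration of `GNNClassifier`. All names are hypothetical.

```python
import numpy as np

n_nodes, n_features, n_hidden = 5, 4, 3
rng = np.random.default_rng(0)
toy_features = rng.random((n_nodes, n_features))
toy_weights = rng.normal(size=(n_features, n_hidden))

# Empty graph: no edges, only self-loops (assumption of this sketch).
empty_adjacency = np.zeros((n_nodes, n_nodes))
adjacency_loops = empty_adjacency + np.eye(n_nodes)
propagation = adjacency_loops / adjacency_loops.sum(axis=1, keepdims=True)

# Graph layer on the empty graph vs. plain dense layer on the features.
graph_layer = np.maximum(propagation @ toy_features @ toy_weights, 0)
dense_layer = np.maximum(toy_features @ toy_weights, 0)
print(np.allclose(graph_layer, dense_layer))
```

Under this assumption the propagation matrix is the identity, so each node is processed independently of the others.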

## To do

We now consider a hidden layer of dimension 32.

* Retrain the GNN (with the graph).
* Give the accuracy of the classification and the number of parameters.
* Comment on the results.

## To do

Finally, we use 2 hidden layers, each of dimension 16.

* Retrain the GNN.
* Give the accuracy of the classification and the number of parameters.
* Comment on the results.

## 2. WikiVitals

We now focus on WikiVitals. We use the spectral embedding of the article-word bipartite graph as features.

In [None]:
dataset_WV = wikivitals

In [None]:
adjacency_WV = dataset_WV.adjacency
biadjacency_WV = dataset_WV.biadjacency
names_WV = dataset_WV.names
labels_true_WV = dataset_WV.labels
names_labels_WV = dataset_WV.names_labels

In [None]:
# we consider the graph as undirected
adjacency_WV = directed2undirected(adjacency_WV)

In [None]:
# we use the spectral embedding of the bipartite graph as features
spectral = Spectral(20)
features_WV = spectral.fit_transform(biadjacency_WV)
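
To give an idea of what a spectral embedding of a bipartite graph can look like, here is a minimal NumPy sketch based on the singular value decomposition of the degree-normalized biadjacency matrix. This is a closely related construction, not the implementation of `Spectral`, whose normalization and scaling may differ. The sketch assumes a small dense matrix where every row and every column has positive degree; `bipartite_spectral_embedding` and `toy_biadjacency` are hypothetical names.

```python
import numpy as np

def bipartite_spectral_embedding(biadjacency, n_components):
    """Embed the rows of a bipartite graph using the SVD of the normalized biadjacency matrix."""
    biadjacency = np.asarray(biadjacency, dtype=float)
    # Degrees of rows and columns (assumed positive in this sketch).
    row_degrees = biadjacency.sum(axis=1)
    col_degrees = biadjacency.sum(axis=0)
    # Symmetric degree normalization: D1^{-1/2} B D2^{-1/2}.
    normalized = biadjacency / np.sqrt(row_degrees)[:, None] / np.sqrt(col_degrees)[None, :]
    left, singular_values, _ = np.linalg.svd(normalized, full_matrices=False)
    # Skip the first singular vector, which only encodes the degrees.
    return left[:, 1:n_components + 1], singular_values

# Toy bipartite graph: 4 rows (e.g. articles) and 3 columns (e.g. words).
toy_biadjacency = np.array([[1, 1, 0],
                            [1, 0, 0],
                            [0, 1, 1],
                            [0, 0, 1]])
toy_embedding, toy_singular_values = bipartite_spectral_embedding(toy_biadjacency, n_components=2)
print(toy_embedding.shape)
```

On a connected graph, the largest singular value of the normalized matrix is 1 and the associated vector carries no information beyond the degrees, which is why it is dropped.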

## To do

We consider a GNN with a single hidden layer of dimension 16.

* Train the GNN with a 50% / 50% train / test split.
* Give the accuracy of the classification.
* Display the confusion matrix of the test set.
* For each label, give the 5 articles of the test set classified with the highest confidence.

In [None]:
n_labels_WV = len(set(labels_true_WV))

In [None]:
print(set(labels_true_WV))
print(np.unique(labels_true_WV))

In [None]:
gnn_WV = GNNClassifier(dims=[hidden_dim, n_labels_WV], verbose=True)

In [None]:
# train / test split
train_ratio = 0.5
labels_WV = labels_true_WV.copy()
train_mask = np.random.random(size=len(labels_WV)) < train_ratio
test_mask = ~train_mask
labels_WV[test_mask] = -1

In [None]:
gnn_WV.fit(adjacency_WV, features_WV, labels_WV)

In [None]:
labels_pred_WV = gnn_WV.predict()

train_acc_WV = get_accuracy_score(labels_true_WV[train_mask], labels_pred_WV[train_mask])
test_acc_WV = get_accuracy_score(labels_true_WV[test_mask], labels_pred_WV[test_mask])

print(f'WikiVitals train accuracy: {train_acc_WV:.4f}')
print(f'WikiVitals test accuracy: {test_acc_WV:.4f}')

In [None]:
valid_labels = np.arange(len(names_labels_WV)) 

confusion_matrix = metrics.confusion_matrix(labels_true_WV[test_mask], labels_pred_WV[test_mask], labels=valid_labels)
cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=names_labels_WV)

cm_display.plot()
plt.xticks(rotation=90, ha='center')
plt.show()

In [None]:
probs = gnn_WV.predict_proba()

# indices and predicted probabilities of the test articles
test_indices = np.where(test_mask)[0]
probs_test = probs[test_mask]

for label, name in enumerate(names_labels_WV):
    # test articles predicted with this label
    predicted = labels_pred_WV[test_mask] == label
    if predicted.sum() > 0:
        confidences = probs_test[predicted, label]
        article_indices = test_indices[predicted]
        # sort by decreasing confidence and keep the top 5
        top5_idx = np.argsort(confidences)[-5:][::-1]
        print(f"\nLabel {name} - Top 5 articles:")
        for idx in top5_idx:
            article_name = names_WV[article_indices[idx]]
            confidence = confidences[idx]
            print(f"  {article_name}: {confidence:.3f}")

## To do

Compare the results with those obtained with:
* Heat diffusion on the graph.
* Logistic regression on the features.
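
For the first baseline, the principle of heat diffusion can be sketched in NumPy as follows. Each label is treated in turn: training nodes are kept at temperature 1 if they have this label and 0 otherwise, the other nodes repeatedly take the mean temperature of their neighbors, and each node is finally assigned the label with the highest temperature. This is a simplified sketch on a dense toy graph, not the scikit-network implementation; `diffusion_classify` and the `toy_*` names are hypothetical, and unlabeled nodes are marked with -1 as in the cells above.

```python
import numpy as np

def diffusion_classify(adjacency, labels, n_iter=20):
    """Classify nodes by heat diffusion with fixed temperatures on labeled nodes (label -1 = unknown)."""
    adjacency = np.asarray(adjacency, dtype=float)
    labels = np.asarray(labels)
    seeds = labels >= 0
    classes = np.unique(labels[seeds])
    # Row-stochastic transition matrix (isolated nodes keep degree 1 to avoid division by zero).
    degrees = adjacency.sum(axis=1, keepdims=True)
    degrees[degrees == 0] = 1
    transition = adjacency / degrees
    scores = np.zeros((len(labels), len(classes)))
    for k, label in enumerate(classes):
        # Labeled nodes are hot (1) if they have this label, cold (0) otherwise.
        seed_temperatures = (labels[seeds] == label).astype(float)
        # Unlabeled nodes start at the mean temperature of the labeled nodes.
        temperatures = np.full(len(labels), seed_temperatures.mean())
        temperatures[seeds] = seed_temperatures
        for _ in range(n_iter):
            # Each node takes the mean temperature of its neighbors...
            temperatures = transition @ temperatures
            # ...while labeled nodes are reset to their fixed temperature.
            temperatures[seeds] = seed_temperatures
        scores[:, k] = temperatures
    return classes[scores.argmax(axis=1)]

# Toy graph: a path on 6 nodes, with one labeled node at each end.
toy_adjacency = np.diag(np.ones(5), k=1)
toy_adjacency = toy_adjacency + toy_adjacency.T
toy_labels = np.array([0, -1, -1, -1, -1, 1])
toy_pred = diffusion_classify(toy_adjacency, toy_labels)
print(toy_pred)
```

Unlike the GNN, this baseline uses only the graph and the training labels; the node features play no role.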