In [1]:
import pandas as pd
import torch
import networkx as nx
import torch_geometric.datasets as dataset

import seaborn as sns
import matplotlib.pyplot as plt

from grakel.kernels import WeisfeilerLehman, VertexHistogram
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Metrics
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

# Types
from torch_geometric.datasets import GNNBenchmarkDataset
#from networkx import Graph

In [2]:
# Get dataset
myData = dataset.GNNBenchmarkDataset(
    root="data/homework02",
    name="CSL",
)

In [3]:
# Create a graph list from GNNBenchmarkDataset
graphs = myData.process_CSL()

In [4]:
def createGraph(data: GNNBenchmarkDataset) -> list:  
    """Create a graph in the GraphKitLeaning format
    using a graph from GNNBenchmarkDataset
    
    Input 
    ---
    data: a GNNBenchmarkDataset 

    Output
    ---
    list of [set of edges, dict of node labels, dict of edge labels]
    """ 
    # dictionary that keeps sets of edges
    Graphs = set()
    # dictionary of labels for nodes
    node_labels = dict()
    # dictionary of labels for edges
    edge_labels = dict()
    
    # Create Graph
    u = data.edge_index[0]
    v = data.edge_index[1]
    for index in range(data.num_edges):
        Graphs.add( (int(u[index]), int(v[index])) )

    # Create node label
    for count, index in enumerate(range(data.num_nodes)):
        node_labels[index] = count
    
    # Create edge label
    for index in range(data.num_edges):
        edge_labels[ (int(u[index]), int(v[index])) ] = 0
        edge_labels[ (int(v[index]), int(u[index])) ] = 0

    return [Graphs, node_labels, edge_labels]

In [5]:
def createGraphs(data: list[GNNBenchmarkDataset]) -> list: 
    """Create a list of lists of [set of edges, dict of node labels, dict of edge labels]

    Input
    ----------
    data : list[GNNBenchmarkDataset]

    Output
    -------
    A list graphs divided as [set of edges, dict of node labels, dict of edge labels]
    """
    
    graphs = list(map(createGraph, data))
    return graphs

In [6]:
def getYs(data: list[GNNBenchmarkDataset])-> list:
    """Extract Ys from a list of GNNBenchmarkDataset

    Input
    -----
    data: A list of a GNNBenchmarkDataset 

    Output
    -----
    A list of Ys
    """
    
    y = [ graph.y for graph in data ]
    return y

In [7]:
# Transform graphs from dataset to graphs from GraphKitLearning library
G = createGraphs(graphs)
# Extract Ys from graphs of the dateset
y = getYs(graphs)
# Create a WL extractor
wl_kernel = WeisfeilerLehman(n_iter=5, normalize=False)
# Divide dataset into train_set and test_set
G_train, G_test, y_train, y_test = train_test_split(G, y, test_size=0.3, random_state=42)

In [8]:
# Train model to create a WL kernel
wl_kernel.fit(G_train)

# Create WL kernel matrix using train set
K_train = wl_kernel.transform(G_train)
# Create WL kernel matrix using test set
K_test = wl_kernel.transform(G_test)

In [9]:
# Print WL kernel for the first train graph
K_train[0]

array([246.,  41.,  41., 106.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,
        41.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,
        41.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,  91.,  41.,  41.,
        41.,  41.,  41.,  41.,  41.,  87.,  91.,  41.,  41.,  41.,  41.,
        41.,  41.,  41., 106.,  41.,  41.,  91.,  41.,  41.,  41.,  41.,
        41.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,
        41.,  41.,  41.,  41.,  41.,  41.,  41.,  81.,  41.,  41.,  41.,
        41.,  41.,  41.,  87.,  41.,  41.,  41.,  41.,  41.,  41.,  41.,
        41.,  41.,  41.,  41.,  41.,  41.,  41.,  98.,  41.,  41.,  41.,
        41.,  41.,  41.,  41.,  41.,  41.])

In [10]:
# Create e SVC model
svc = SVC(kernel='precomputed')
# Train SVC model using WL kernel and real classification
svc.fit(K_train, y_train)
# Predict classifications for test set
y_pred = list(svc.predict(K_test))

In [11]:
print("Real classification\n", y_test, sep="")
print("Predict classification\n", y_pred, sep="")
print("Accuracy: %2.2f %%" %(round(accuracy_score(y_test, y_pred)*100)))
print("Confusion Matrix\n", confusion_matrix(y_test, y_pred), sep="")

Real classification
[4, 1, 7, 5, 5, 2, 4, 9, 4, 5, 7, 0, 2, 0, 1, 3, 6, 4, 3, 8, 1, 8, 1, 8, 8, 9, 7, 9, 3, 2, 1, 1, 4, 0, 2, 9, 3, 1, 0, 2, 9, 5, 5, 1, 0]
Predict classification
[4, 1, 7, 5, 5, 2, 4, 9, 4, 5, 7, 0, 2, 0, 1, 3, 6, 4, 3, 8, 1, 8, 1, 8, 8, 9, 7, 9, 3, 2, 1, 1, 4, 0, 2, 9, 3, 1, 0, 2, 9, 5, 5, 1, 0]
Accuracy: 100.00 %
Confusion Matrix
[[5 0 0 0 0 0 0 0 0 0]
 [0 8 0 0 0 0 0 0 0 0]
 [0 0 5 0 0 0 0 0 0 0]
 [0 0 0 4 0 0 0 0 0 0]
 [0 0 0 0 5 0 0 0 0 0]
 [0 0 0 0 0 5 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 0 3 0 0]
 [0 0 0 0 0 0 0 0 4 0]
 [0 0 0 0 0 0 0 0 0 5]]
