## Graph Neural Networks with PyTorch Geometric

Author: [Savannah Thais](https://github.com/savvy379)

The first portion (Cora dataset) is adapted from the [PyTorch Geometric Tutorial](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).

### Introduction
This notebook teaches the reader how to build Graph Neural Networks (GNNs) with PyTorch Geometric (PyG). The first portion walks through a simple GNN architecture applied to the Cora dataset; it is a modified version of the PyG [Tutorial](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html) on node-classifying GNNs. The second portion introduces a GNN-based charged particle tracking pipeline based on ["Charged particle tracking via edge-classifying interaction networks"](https://arxiv.org/abs/2103.16701).

A graph $G$ is a mathematical object consisting of a set of nodes (vertices) $N$ and a set of edges $E$, $G=(N,E)$. Graphs can easily represent a wide range of structured data, including atoms in molecules, users in a social network, cities and roads in a transportation system, players in team sports, objects interacting in a dynamical physical system, detector events, and more. The nodes and edges of the graph can have associated features as defined by the developer; these can include geometric information (e.g. particle hit locations in a detector) and non-geometric information (e.g. particle momenta). Graphs can be directed or undirected. A major advantage of GNNs is that they can handle input data of varying sizes: each graph processed by the network can have a different number of nodes and edges. This makes GNNs well suited to a range of HEP applications, as shown below for (a) particle tracking, (b) calorimeter clustering, (c) event classification, and (d) jet tagging (image from [Graph Neural Networks in Particle Physics](https://iopscience.iop.org/article/10.1088/2632-2153)).

![](https://drive.google.com/uc?id=1dGIKhsfeI7-Ik3pXL8t5Dsb6UyDohMiM)
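
To make the definition $G=(N,E)$ concrete, here is a minimal framework-free sketch of how a small graph might be stored. The feature values and the helper names `to_edge_index` and `neighbors` are invented for illustration. The two-row edge layout mirrors the `edge_index` convention used by PyG (shape `[2, num_edges]`), where an undirected edge is stored as two directed edges.

```python
# Toy undirected graph with 4 nodes and 3 edges: 0-1, 1-2, 1-3.
# One feature row per node; the values are made up for illustration.
node_features = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.5, 0.5],
]

undirected_edges = [(0, 1), (1, 2), (1, 3)]


def to_edge_index(edges):
    """Store each undirected edge as two directed edges in a 2 x num_edges layout."""
    sources, targets = [], []
    for u, v in edges:
        sources += [u, v]
        targets += [v, u]
    return [sources, targets]


def neighbors(edge_index, num_nodes):
    """Return N(v) for every node v: all nodes u with a directed edge u -> v."""
    result = {v: [] for v in range(num_nodes)}
    for u, v in zip(*edge_index):
        result[v].append(u)
    return result


edge_index = to_edge_index(undirected_edges)
nbrs = neighbors(edge_index, len(node_features))

print(edge_index)
print(nbrs)
```

Because the number of nodes and edges is just the length of these lists, nothing in this representation fixes the graph size in advance, which is the property that lets a GNN process graphs of varying sizes.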

In general, GNNs work by leveraging local information across the graph structure to intelligently re-embed the edges and nodes. A commonly used class of GNNs (and the focus of this tutorial) is the Graph Convolutional Network (GCN). GCNs use the same kind of convolution operation as a normal CNN, but apply the convolutions to node neighborhoods rather than to a fixed data tensor (like an image matrix). GCNs are considered a type of "Message Passing Network", in which a 'message' is constructed by combining and transforming information from neighboring nodes; that message is then 'passed' to a target node and used to update its features. In this way the entire graph can be transformed such that each node is updated to include additional useful information. The convolved graph is typically then processed by an additional linear layer that uses node and/or edge features to do classification or regression on individual graph elements, graph substructures, or the graph as a whole. (Note: the same message passing structure can be applied to edges instead of nodes to update edge features as well.)

![](https://drive.google.com/uc?id=1f5MJO9Kw1tWjJMBJBsZ-M4tMFz_dydr2)

In mathematical terms, a single graph convolution layer can be described as: 

![](https://drive.google.com/uc?id=1hj82TSiKTwKTX9IFfpYRG7qygtSAAw6A)

Here, $h_v^0$ is the initial embedding of node $v$ (i.e. the original node features). To compute the updated embedding $h_v^k$ of node $v$, a 'message' is constructed by applying a function $f$ to the average of the current embeddings of all the neighboring nodes $N(v)$ (the nodes connected to $v$ by a graph edge), $\sum_{u \in N(v)}\frac{h_u^{k-1}}{deg(v)}$, and, optionally, to the current embedding of the target node, $h_v^{k-1}$. The result is then passed through a non-linear activation function and used to update the target node. In practice, the function $f$ is approximated by a matrix (convolution) $W_k$, with the added non-linearity of the activation function.
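
The update described above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the PyG implementation: it assumes the layer has the form $h_v^k = \sigma\left(W_k \sum_{u \in N(v)}\frac{h_u^{k-1}}{deg(v)} + B_k h_v^{k-1}\right)$, where $\sigma$ is taken to be ReLU and $B_k$ is a hypothetical second matrix applied to the optional self term. The function name `graph_conv_layer` and all numerical values are invented for the example.

```python
def relu(x):
    """Element-wise ReLU, standing in for the activation function."""
    return [max(0.0, xi) for xi in x]


def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def graph_conv_layer(h, nbrs, W, B):
    """One update: h_v <- relu(W @ mean_{u in N(v)} h_u + B @ h_v)."""
    new_h = []
    for v, h_v in enumerate(h):
        neigh = nbrs[v]
        dim = len(h_v)
        if neigh:
            # Average the current embeddings of the neighbors of v.
            mean = [sum(h[u][i] for u in neigh) / len(neigh) for i in range(dim)]
        else:
            # Isolated node: no message (a choice made for this sketch).
            mean = [0.0] * dim
        message = matvec(W, mean)      # transformed neighborhood average
        self_term = matvec(B, h_v)     # optional contribution of the target node
        new_h.append(relu([m + s for m, s in zip(message, self_term)]))
    return new_h


# Toy graph: node 1 is connected to nodes 0, 2, and 3.
h0 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
nbrs = {0: [1], 1: [0, 2, 3], 2: [1], 3: [1]}

# Assumed weights; in a trained network these would be learned.
W = [[1.0, 0.0], [0.0, 1.0]]
B = [[0.5, 0.0], [0.0, -1.0]]

h1 = graph_conv_layer(h0, nbrs, W, B)
for v, emb in enumerate(h1):
    print(v, emb)
```

Applying `graph_conv_layer` a second time to `h1` would mix in information from nodes two edges away, which is how stacking layers widens the neighborhood each node can see.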

In [1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

# Helper function for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

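def visualize(h, color):
    # Project the embeddings h (one row per node) to 2D with t-SNE,
    # then draw a scatter plot colored by the values in `color`.
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())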

    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

1.12.0+cu113
