<a href="https://colab.research.google.com/github/realfolkcode/GRAFF/blob/main/GRAFF_Tutorial_PyG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tutorial: *GNNs as Gradient Flows*

[Original paper](https://arxiv.org/pdf/2206.10991.pdf) (Di Giovanni, Francesco, et al. "Graph Neural Networks as Gradient Flows: understanding graph convolutions via energy")

[Michael Bronstein's blogpost](https://towardsdatascience.com/graph-neural-networks-as-gradient-flows-4dae41fb2e8a)

In this tutorial, we implement GRAFF (Gradient Flow Framework) using [PyG](https://www.pyg.org/), a library for GNNs built on PyTorch.

$\newcommand{\matr}[1]{\mathbf{#1}}$


---



## A brief theoretical overview

When designing a GNN, one could face the following potential problems:
- **Over-smoothing** of node features as the model depth increases
- Poor performance on **heterophilic** data (i.e., where neighboring nodes tend to be dissimilar)

GRAFF alleviates these issues by viewing GNNs as gradient flows and carefully parametrizing the weight matrix.



---



### ODE perspective

Let us denote the node feature matrix by $\matr{F}$, and the normalized adjacency matrix by $\bar{\matr{A}} := \matr{D}^{-\frac12} \matr{A} \matr{D}^{-\frac12}$, where $\matr{D}$ is the diagonal matrix of node degrees.

GNNs (with residual connections) can be seen as discretizations of differential equations that govern the evolution of node features $\matr{F}$ given a graph $G$ (possibly with the addition of self-loops):

$$\dot{\matr{F}}(t) = \operatorname{GNN}_{\theta(t)}(G, \matr{F}(t)), \quad \matr{F}(0) = \matr{F},$$

where $\operatorname{GNN}_{\theta(t)}$ is a GNN layer at time $t$ with parameters $\theta(t)$. 

For example, a residual GCN corresponds to the following Euler discretization:

$$\matr{F}(t + 1) = \matr{F}(t) + \sigma\left( \bar{\matr{A}} \matr{F}(t) \matr{W}_t \right),$$

where $\sigma$ is a non-linearity, and $\matr{W}_t$ is the channel-mixing matrix of layer $t$.
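As a small dense illustration of the residual GCN step above, the sketch below builds $\bar{\matr{A}}$ for a hypothetical 3-node path graph (with self-loops) and applies two Euler steps, each with its own weight matrix $\matr{W}_t$. We assume $\sigma = \mathrm{ReLU}$, and the weights are square because the residual connection requires matching input and output widths.

```python
import torch


def normalized_adjacency(A, self_loops=True):
    """Dense D^{-1/2} A D^{-1/2} for a symmetric adjacency matrix A."""
    if self_loops:
        A = A + torch.eye(A.size(0), dtype=A.dtype)
    deg = A.sum(dim=1)
    d_inv_sqrt = deg.pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0  # isolated nodes (only possible without self-loops)
    return d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]


def residual_gcn_step(F_t, A_bar, W_t):
    # Euler step with step size 1: F(t+1) = F(t) + sigma(A_bar F(t) W_t), sigma = ReLU here
    return F_t + torch.relu(A_bar @ F_t @ W_t)


torch.manual_seed(0)
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
A_bar = normalized_adjacency(A)
F0 = torch.randn(3, 4)
Ws = [torch.randn(4, 4) for _ in range(2)]  # a separate W_t for every layer

F_t = F0
for W_t in Ws:
    F_t = residual_gcn_step(F_t, A_bar, W_t)
print(F_t.shape)  # torch.Size([3, 4])
```

In a plain GCN, the weights `Ws` differ from layer to layer; this is the point where GRAFF departs from it.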


---



### Energy perspective

Gradient flows are differential equations that evolve in the direction of steepest descent of an energy functional, so the energy decreases along their solutions. If the energy $\mathcal{E}_{\theta}$ is parametrized by a set of parameters $\theta$, its gradient flow is defined as follows:

$$\dot{\matr{F}}(t) = -\nabla \mathcal{E}_{\theta}(\matr{F}(t))$$

Let $\matr{f}_i$ denote the transposed $i$-th row in $\matr{F}$ (such that the feature vector of node $i$ is now a column vector).

GRAFF considers a class of energies that can be decomposed into 
- edge-agnostic components $\mathcal{E}_{\matr{\Omega}}^{\textrm{ext}}$ ("external field")
- pairwise interactions $\mathcal{E}_{\matr{W}}^{\textrm{pair}}$
- the source terms $\mathcal{E}_{\tilde{\matr{W}}}^{\textrm{source}}$

$$\mathcal{E}_{\theta}(\matr{F}) = \underbrace{\frac12 \sum_i \left< \matr{f}_i, \matr{\Omega} \matr{f}_i \right>}_{\mathcal{E}_{\matr{\Omega}}^{\textrm{ext}}} - \underbrace{\frac12 \sum_{i,j} \bar{\matr{A}}_{ij} \left< \matr{f}_i, \matr{W} \matr{f}_j \right>}_{\mathcal{E}_{\matr{W}}^{\textrm{pair}}} + \underbrace{\sum_i \left< \matr{f}_i, \tilde{\matr{W}} \matr{f}_i(0) \right>}_{\mathcal{E}_{\tilde{\matr{W}}}^{\textrm{source}}},$$

where $\matr{\Omega}, \matr{W}, \tilde{\matr{W}}$ are learnable *square* matrices. 

Assuming $\matr{\Omega}$ and $\matr{W}$ are **symmetric**, differentiating the energy with respect to $\matr{F}$ yields the corresponding *gradient flow*:

$$\dot{\matr{F}}(t) = -\matr{F}(t) \matr{\Omega} + \bar{\matr{A}} \matr{F}(t) \matr{W} - \matr{F}(0) \tilde{\matr{W}}$$
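The gradient-flow equation can be sanity-checked numerically: the sketch below writes the energy with dense tensors on a small random graph and compares its negative autograd gradient with the right-hand side above. The graph and matrices are arbitrary assumptions for illustration; $\tilde{\matr{W}}$ is taken symmetric because, in general, the source term contributes $\matr{F}(0)\tilde{\matr{W}}^{\intercal}$ to the gradient, which coincides with $\matr{F}(0)\tilde{\matr{W}}$ only in that case (as in the $\beta \matr{I}$ variant used later).

```python
import torch

torch.manual_seed(0)
dt = torch.float64
n, d = 5, 3

# Random undirected graph with self-loops (assumption for illustration)
A = (torch.rand(n, n, dtype=dt) > 0.5).to(dt)
A = ((A + A.T) > 0).to(dt)
A.fill_diagonal_(1.0)
deg = A.sum(dim=1)
A_bar = A / torch.sqrt(deg[:, None] * deg[None, :])  # D^{-1/2} A D^{-1/2}


def sym(M):
    return (M + M.T) / 2


Omega, W, W_tilde = (sym(torch.randn(d, d, dtype=dt)) for _ in range(3))
F0 = torch.randn(n, d, dtype=dt)
F = torch.randn(n, d, dtype=dt, requires_grad=True)


def energy(F):
    ext = 0.5 * (F * (F @ Omega)).sum()         # 1/2 sum_i <f_i, Omega f_i>
    pair = 0.5 * (A_bar * (F @ W @ F.T)).sum()  # 1/2 sum_ij A_ij <f_i, W f_j>
    source = (F * (F0 @ W_tilde)).sum()         # sum_i <f_i, W_tilde f_i(0)> (W_tilde symmetric)
    return ext - pair + source


(grad,) = torch.autograd.grad(energy(F), F)
with torch.no_grad():
    rhs = -F @ Omega + A_bar @ F @ W - F0 @ W_tilde

print(torch.allclose(-grad, rhs))  # True
```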

Discretizing with the explicit Euler method and step size $\tau$ gives:

$$\matr{F}(t + \tau) = \matr{F}(t) + \tau \left( -\matr{F}(t) \matr{\Omega} + \bar{\matr{A}} \matr{F}(t) \matr{W} - \matr{F}(0) \tilde{\matr{W}} \right) $$

Note how the middle term $\bar{\matr{A}} \matr{F}(t) \matr{W}$ is reminiscent of the GCN dynamics. One important difference is that the learnable parameters $\matr{\Omega}, \matr{W}, \tilde{\matr{W}}$ are *shared* across all layers (time steps).
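To illustrate parameter sharing and the energy view together, the following sketch unrolls the Euler scheme for ten steps with a single set of $\matr{\Omega}, \matr{W}, \tilde{\matr{W}}$ and records the energy after each step. The graph (a 6-cycle with self-loops) and the random matrices are assumptions; we rescale the matrices to spectral norm 1 so that $\tau = 0.1$ is small relative to the curvature of the energy, in which case each explicit Euler step is a gradient-descent step that strictly decreases $\mathcal{E}_{\theta}$.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
n, d, tau, steps = 6, 3, 0.1, 10

# 6-cycle with self-loops; A_bar = D^{-1/2} (A + I) D^{-1/2} has spectrum in [-1, 1]
A = torch.eye(n, dtype=dt)
idx = torch.arange(n)
A[idx, (idx + 1) % n] = 1.0
A[(idx + 1) % n, idx] = 1.0
deg = A.sum(dim=1)
A_bar = A / torch.sqrt(deg[:, None] * deg[None, :])


def sym_unit(M):
    # Symmetrize, then rescale to spectral norm 1 (keeps tau small w.r.t. the curvature)
    S = (M + M.T) / 2
    return S / torch.linalg.matrix_norm(S, ord=2)


# One set of parameters, reused at every step
Omega, W, W_tilde = (sym_unit(torch.randn(d, d, dtype=dt)) for _ in range(3))


def energy(F, F0):
    ext = 0.5 * (F * (F @ Omega)).sum()
    pair = 0.5 * (A_bar * (F @ W @ F.T)).sum()
    source = (F * (F0 @ W_tilde)).sum()
    return (ext - pair + source).item()


F0 = torch.randn(n, d, dtype=dt)
F = F0.clone()
energies = [energy(F, F0)]
for _ in range(steps):
    # Explicit Euler step of the gradient flow, same Omega, W, W_tilde every time
    F = F + tau * (-F @ Omega + A_bar @ F @ W - F0 @ W_tilde)
    energies.append(energy(F, F0))

print(all(b < a for a, b in zip(energies, energies[1:])))  # True
```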



---



### What makes GRAFF great

The energy functional can be rearranged by highlighting the positive and negative eigenvalues of the channel-mixing matrix $\matr{W}$ separately. In short, the authors show that the spectrum of $\matr{W}$ encodes the "mood" of edge-wise interactions in a graph:

- Positive eigenvalues make interactions *attractive* 🤗, i.e., neighboring nodes become *similar* (corresponds to magnifying the *low* frequencies, *homophilic* scenario) 
- Negative eigenvalues make interactions *repulsive* 😒, i.e., neighboring nodes become *dissimilar* (corresponds to magnifying the *high* frequencies, *heterophilic* scenario)

Further, they show that without residual connections, GRAFF is limited to the homophilic setting and is vulnerable to over-smoothing. Hence, residual connections are crucial for good performance in both settings.
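A toy experiment can make the role of the sign of the spectrum, and of the residual connection, tangible. The sketch below is a simplification we chose for illustration: it drops $\matr{\Omega}$, $\tilde{\matr{W}}$ and $\sigma$, uses $\matr{W} = \pm \matr{I}$, and runs on a 6-cycle with self-loops (so $\bar{\matr{A}}$ has eigenvalues $1, \tfrac23, \tfrac23, 0, 0, -\tfrac13$). It tracks the normalized Dirichlet energy $\operatorname{tr}(\matr{F}^{\intercal}(\matr{I} - \bar{\matr{A}})\matr{F}) / \operatorname{tr}(\matr{F}^{\intercal}\matr{F})$, which is small for smooth (low-frequency) features and large when neighbors differ.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
n, d, tau, steps = 6, 2, 0.5, 50

# 6-cycle with self-loops: A_bar has eigenvalues 1, 2/3, 2/3, 0, 0, -1/3
A = torch.eye(n, dtype=dt)
idx = torch.arange(n)
A[idx, (idx + 1) % n] = 1.0
A[(idx + 1) % n, idx] = 1.0
deg = A.sum(dim=1)
A_bar = A / torch.sqrt(deg[:, None] * deg[None, :])
L = torch.eye(n, dtype=dt) - A_bar  # normalized Laplacian


def rayleigh(F):
    # Normalized Dirichlet energy: near 0 for smooth features, larger when neighbors differ
    return (torch.trace(F.T @ L @ F) / torch.trace(F.T @ F)).item()


def run(W, residual):
    F = F0.clone()
    for _ in range(steps):
        update = A_bar @ F @ W
        F = F + tau * update if residual else update
    return rayleigh(F)


F0 = torch.randn(n, d, dtype=dt)
I = torch.eye(d, dtype=dt)
r_init = rayleigh(F0)
r_attract = run(I, residual=True)   # positive eigenvalues, with residual
r_repel = run(-I, residual=True)    # negative eigenvalues, with residual
r_no_res = run(-I, residual=False)  # negative eigenvalues, no residual
print(f"{r_init:.3f} {r_attract:.3f} {r_repel:.3f} {r_no_res:.3f}")
```

With the residual update $\matr{F} \leftarrow \matr{F} + \tau \bar{\matr{A}}\matr{F}\matr{W}$, the attractive choice drives this quantity toward $0$, whereas the repulsive one pushes it toward $\tfrac43$, the largest eigenvalue of $\matr{I} - \bar{\matr{A}}$ on this graph. Without the residual term, $\matr{F} \leftarrow -\bar{\matr{A}}\matr{F}$ still collapses to the low-frequency component, since only the eigenvalue $1$ of $\bar{\matr{A}}$ has magnitude $1$.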

## Practice

### Install and Import

In [1]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

2.0.0+cu118


In [None]:
!pip install torch_geometric
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}.html

In [3]:
import time

import torch
import torch.nn.functional as F
from torch.nn import Linear, Parameter
import torch.nn.utils.parametrize as parametrize
from torch.optim import Adam

import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, homophily

from torch_geometric.datasets import WebKB, Planetoid

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### GCN Implementation

We start by implementing a GCN layer, slightly modifying the example provided in the [Creating Message Passing Layers](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html) tutorial.

Recall that GCN is defined as

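$$\begin{align}
\matr{F}(t + 1) &= \sigma\left( \bar{\matr{A}} \matr{F}(t) \matr{W}_t \right), \\
\matr{f}_i(t + 1) &= \sigma \left( \sum_{j \in \mathcal{N}(i)} \frac{1}{\sqrt{D_{ii} \cdot D_{jj}}} \cdot \matr{W}_t^{\intercal} \matr{f}_j(t) \right),
\end{align}$$

where $\mathcal{N}(i)$ includes node $i$ itself once self-loops are added, and the degrees in $\matr{D}$ count those self-loops.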

The first equation is written in matrix form, whereas the second uses the node-wise *message passing* formulation, which encompasses a broader class of graph convolutions. **PyG** provides the `MessagePassing` class, which we can inherit from to implement our layer.

In the code, `x` denotes the features $\matr{F}(t) \in \mathbb{R}^{n \times in}$. The result of the convolution, $\matr{F}(t+1)$, has dimensions $n \times out$. Note that the layer itself does not apply $\sigma$; the non-linearity is applied in the network.

In [4]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.W = Linear(in_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.W.reset_parameters()

    def forward(self, x, edge_index):
        # Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Linearly transform node feature matrix.
        x = self.W(x)

        # Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        return out

    def message(self, x_j, norm):
        # Normalize node features.
        return norm.view(-1, 1) * x_j
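To see that the message-passing code computes the matrix form, the sketch below re-implements the normalization and aggregation of `GCNConv.forward` in plain PyTorch (`index_add_` standing in for PyG's `'add'` aggregation) and compares the result with the dense product $\bar{\matr{A}} \matr{X} \matr{W}$ on a hypothetical 4-node path graph. It assumes PyG's default `source_to_target` flow, where `edge_index[0]` holds the source $j$ and `edge_index[1]` the target $i$.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
n, d_in, d_out = 4, 3, 2

# Hypothetical path graph 0-1-2-3, each undirected edge stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
x = torch.randn(n, d_in, dtype=dt)
W = torch.randn(d_in, d_out, dtype=dt)

# Add self-loops, as GCNConv does with add_self_loops
loops = torch.arange(n).repeat(2, 1)
row, col = torch.cat([edge_index, loops], dim=1)  # row = source j, col = target i

# Degree of each target node and symmetric normalization of every edge
deg = torch.zeros(n, dtype=dt).index_add_(0, col, torch.ones(col.size(0), dtype=dt))
norm = deg[row].rsqrt() * deg[col].rsqrt()

# Message passing: transform, scale each message x_j, sum the messages at node i
h = x @ W
out_mp = torch.zeros(n, d_out, dtype=dt).index_add_(0, col, norm[:, None] * h[row])

# Dense matrix form: A_bar @ X @ W with A_bar = D^{-1/2} (A + I) D^{-1/2}
A = torch.zeros(n, n, dtype=dt)
A[col, row] = 1.0
A_bar = deg.rsqrt()[:, None] * A * deg.rsqrt()[None, :]
out_dense = A_bar @ x @ W

print(torch.allclose(out_mp, out_dense))  # True
```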

Next, we define our network as a stack of two GCN layers, each followed by a ReLU non-linearity.

In [5]:
class GCNNet(torch.nn.Module):
    def __init__(self, dataset, num_hidden):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, num_hidden)
        self.conv2 = GCNConv(num_hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return F.log_softmax(x, dim=1)

### GRAFF Implementation

Let us summarize the main differences between GRAFF and GCN:

- All the learnable parameters are shared across the layers

- Besides the weight matrix $\matr{W}$ (pairwise interactions), there are also learnable matrices $\matr{\Omega}$ ("external field") and $\tilde{\matr{W}}$ ("source")

- $\matr{\Omega}, \matr{W} \in \mathbb{R}^{d \times d}$ are symmetric

- Each layer takes the initial feature matrix $\matr{F}(0)$ as an additional argument

We can keep the message passing logic of GCN intact. Since all the weights are shared across layers, we do not initialize them inside the layer. Instead, we pass already-initialized modules as the arguments `ext_lin` ($\matr{\Omega}$), `pair_lin` ($\matr{W}$), and `source_lin` ($\tilde{\matr{W}}$). Because the layer owns no weights beyond these shared modules, a single instance can be applied again and again, which is what `GRAFFNet` does below. In this tutorial, we pass the weights as arguments to give a high-level view of how GRAFF is assembled from its three energy terms.

Another difference is that we make the addition of self-loops optional.

In [6]:
class GRAFFConv(MessagePassing):
    def __init__(self, ext_lin, pair_lin, source_lin, self_loops=True):
        super().__init__(aggr='add')
        self.ext_lin = ext_lin
        self.pair_lin = pair_lin
        self.source_lin = source_lin
        self.self_loops = self_loops

    def forward(self, x, edge_index, x0):
        # (Optionally) Add self-loops to the adjacency matrix.
        if self.self_loops:
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Linearly transform node feature matrix.
        out = self.pair_lin(x)

        # Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Start propagating messages.
        out = self.propagate(edge_index, x=out, norm=norm)

        # Add the external and source contributions
        out -= self.ext_lin(x) + self.source_lin(x0)

        return out

    def message(self, x_j, norm):
        # Normalize node features.
        return norm.view(-1, 1) * x_j
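Sharing weights this way relies on standard PyTorch behavior: a module stored as an attribute is registered as a submodule of every layer that holds it, but `Module.parameters()` yields each parameter tensor only once, and autograd sums the gradient contributions from all uses. The toy `SharedStep` class below is a hypothetical stand-in for `GRAFFConv`.

```python
import torch


class SharedStep(torch.nn.Module):
    # Hypothetical stand-in for GRAFFConv: it only stores a reference to a shared module
    def __init__(self, shared_lin):
        super().__init__()
        self.lin = shared_lin

    def forward(self, x):
        return x + self.lin(x)


torch.manual_seed(0)
shared = torch.nn.Linear(8, 8, bias=False)
layers = torch.nn.ModuleList([SharedStep(shared) for _ in range(3)])

# The shared weight is listed once, so an optimizer updates it once per step
n_params = sum(p.numel() for p in layers.parameters())
print(n_params)  # 64, not 3 * 64

# Gradients from all three uses accumulate in the single .grad tensor
out = torch.randn(2, 8)
for layer in layers:
    out = layer(out)
out.sum().backward()
print(shared.weight.grad.shape)  # torch.Size([8, 8])
```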

While there are many variants of gradient flow parametrizations, we focus on the GRAFF variant with a *diagonally dominant* $\matr{W}$:

$$\matr{F}(t + \tau) = \matr{F}(t) + \tau \sigma \left( -\matr{F}(t) \operatorname{diag}(\matr{\omega}) + \bar{\matr{A}} \matr{F}(t) \matr{W} - \beta \matr{F}(0) \right)$$

Here, we have
- $\matr{\Omega} := \operatorname{diag}(\matr{\omega})$, where $\matr{\omega} \in \mathbb{R}^{d}$
- $\matr{W} := \matr{W}^0 + \operatorname{diag}(\matr{w})$, where $\matr{W}^0$ is symmetric with zero diagonal, and $\matr{w}$ is defined by $\matr{w}_{\alpha} = q_{\alpha} \sum_{\gamma} | \matr{W}^0_{\alpha \gamma} | + r_{\alpha}$ with learnable vectors $\matr{q}, \matr{r} \in \mathbb{R}^{d}$
- $\tilde{\matr{W}} := \beta \matr{I}$, where $\beta$ is a scalar, and $\matr{I}$ is the identity matrix of size $d$

Notice that this variant also includes the non-linearity $\sigma$, so, strictly speaking, it is no longer a gradient flow. Nevertheless, the energy still decreases along its solutions, provided that $\sigma$ satisfies $x \sigma(x) \geq 0$.


---



First, let's implement the "external field" layer. A naive way to implement the product $\matr{F} \cdot \operatorname{diag}(\matr{\omega})$ would be to explicitly construct the matrix $\operatorname{diag}(\matr{\omega})$ and multiply $\matr{F}$ by it. However, this is equivalent to multiplying each row of $\matr{F}$ elementwise by the row vector $\matr{\omega}^{\intercal}$, so we can implement it efficiently with broadcasting.
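A quick numerical check of this equivalence, with arbitrary sizes chosen for illustration:

```python
import torch

torch.manual_seed(0)
n, d = 5, 4
F_mat = torch.randn(n, d)
omega = torch.randn(d)

# Naive: materialize the d x d matrix diag(omega) and multiply
naive = F_mat @ torch.diag(omega)
# Broadcasting: omega (shape (d,)) acts as a row vector, so column k is scaled by omega[k]
broadcast = F_mat * omega

print(torch.allclose(naive, broadcast))  # True
```

Broadcasting avoids materializing the $d \times d$ diagonal matrix and costs $O(nd)$ instead of $O(nd^2)$; `External` stores $\matr{\omega}$ with shape `(1, num_features)`, which broadcasts the same way.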

In [7]:
class External(torch.nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty((1, num_features)))
        self.reset_parameters()
    
    def reset_parameters(self):
        torch.nn.init.normal_(self.weight)

    def forward(self, x):
        return x * self.weight

Similarly, we implement the multiplication with $\matr{W}$ as a linear layer which we wrap inside the `Pairwise` class. 

Here, $\matr{W}$ has a special structure which must be preserved during training. To enforce it, we use PyTorch's [parametrization](https://pytorch.org/tutorials/intermediate/parametrizations.html) functionality. The `forward` method of `PairwiseParametrization` maps the unconstrained weight to a symmetric matrix with the diagonally dominant structure defined above.
- Symmetry is imposed by adding the strictly upper-triangular part of the matrix to its transpose, which also leaves the diagonal at zero (this gives $\matr{W}^0$).
- The main diagonal is then constructed from the additional parameters $\matr{q}$ and $\matr{r}$, which we store in the last two columns of the weight. This is why the linear layer in `Pairwise` is declared with `num_hidden + 2` input features, even though, after the parametrization, it maps `num_hidden` features to `num_hidden` features.

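We then attach this parametrization to the linear layer with `parametrize.register_parametrization`. The `unsafe=True` flag is required because the parametrization changes the shape of the weight from $d \times (d+2)$ to $d \times d$, which PyTorch rejects by default. Under the hood, `PairwiseParametrization` is re-evaluated every time `lin.weight` is accessed, i.e., in every forward pass (unless caching is enabled).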

In [8]:
class PairwiseParametrization(torch.nn.Module):
    def forward(self, W):
        # Construct a symmetric matrix with zero diagonal
        W0 = W[:, :-2].triu(1)
        W0 = W0 + W0.T

        # Retrieve the `q` and `r` vectors from the last two columns
        q = W[:, -2]
        r = W[:, -1]
        # Construct the main diagonal
        w_diag = torch.diag(q * torch.sum(torch.abs(W0), 1) + r) 

        return W0 + w_diag


class Pairwise(torch.nn.Module):
    def __init__(self, num_hidden):
        super().__init__()
        # Pay attention to the dimensions
        self.lin = torch.nn.Linear(num_hidden + 2, num_hidden, bias=False)
        # Add parametrization
        parametrize.register_parametrization(self.lin, "weight", PairwiseParametrization(), unsafe=True)
        self.reset_parameters()
    
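    def reset_parameters(self):
        # Re-initialize the unconstrained weight stored by the parametrization.
        # `self.lin.reset_parameters()` would only overwrite a freshly computed copy of
        # the parametrized weight, leaving the trainable parameter unchanged.
        torch.nn.init.kaiming_uniform_(self.lin.parametrizations.weight.original, a=5 ** 0.5)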
    
    def forward(self, x):
        return self.lin(x)
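The following sketch inspects the parametrization on a small layer ($d = 4$, an arbitrary choice), re-declaring `PairwiseParametrization` so that it runs on its own. It checks the shapes of the stored and the constrained weights, the symmetry and the diagonal of $\matr{W}$, and one pitfall worth knowing: `Linear.reset_parameters()` initializes `lin.weight` in place, and for a parametrized layer that attribute is a freshly computed tensor, so the trainable parameter `lin.parametrizations.weight.original` is left untouched. Re-initializing `original` directly does reset the layer.

```python
import math

import torch
import torch.nn.utils.parametrize as parametrize


class PairwiseParametrization(torch.nn.Module):
    # Same construction as the tutorial cell: symmetric W0 plus a diagonal from q and r
    def forward(self, W):
        W0 = W[:, :-2].triu(1)
        W0 = W0 + W0.T
        q, r = W[:, -2], W[:, -1]
        return W0 + torch.diag(q * torch.sum(torch.abs(W0), 1) + r)


torch.manual_seed(0)
d = 4
lin = torch.nn.Linear(d + 2, d, bias=False)
parametrize.register_parametrization(lin, "weight", PairwiseParametrization(), unsafe=True)

raw = lin.parametrizations.weight.original  # trainable parameter, shape (d, d + 2)
W = lin.weight                              # constrained matrix, recomputed on each access
print(raw.shape, W.shape)                   # torch.Size([4, 6]) torch.Size([4, 4])

# Symmetry, and diagonal w_a = q_a * sum_b |W0_ab| + r_a
W0 = W - torch.diag(W.diagonal())
symmetric = torch.allclose(W, W.T)
diag_ok = torch.allclose(W.diagonal(), raw[:, -2] * W0.abs().sum(1) + raw[:, -1])
print(symmetric, diag_ok)                   # True True

# Pitfall: Linear.reset_parameters() writes into a recomputed copy of the weight
before = raw.detach().clone()
lin.reset_parameters()
untouched = torch.equal(before, lin.parametrizations.weight.original)
# Re-initializing the stored parameter itself does change it
torch.nn.init.kaiming_uniform_(lin.parametrizations.weight.original, a=math.sqrt(5))
changed = not torch.equal(before, lin.parametrizations.weight.original)
print(untouched, changed)                   # True True
```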

Finally, the source term multiplies the initial condition $\matr{F}(0)$ by a learnable scalar $\beta$. Although this is trivial, we again implement it as a module, for consistency with the other two terms.

In [9]:
class Source(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(1))
        self.reset_parameters()
    
    def reset_parameters(self):
        torch.nn.init.normal_(self.weight)
    
    def forward(self, x):
        return x * self.weight

Now we can write our GRAFF network 🦒. It follows the architecture of our GCN network wherever possible.

Points to consider:

- The step size $0 < \tau \leq 1$ is fixed and does not have to equal $1$ as in the residual GCN (we pass it as an argument)

- We need to map the raw node features to a $d$-dimensional space (*$d$ is the dimension of the square matrices in our convolution*) with an **encoder** (in this tutorial, a linear layer)

- Similarly, the **decoder** maps the output of the last convolutional layer to a $k$-dimensional space (*$k$ is the number of classes*)

- Don't forget the residual connections!

- We use ReLU as the non-linearity, as it satisfies $x \sigma(x) \geq 0$


In [10]:
class GRAFFNet(torch.nn.Module):
    def __init__(self, dataset, num_hidden, self_loops=True, step_size=1.):
        super().__init__()
        self.step_size = step_size

        # Encoder
        self.enc = torch.nn.Linear(dataset.num_features, num_hidden, bias=False)

        # Initialize the linear layers
        self.ext_lin = External(num_hidden)
        self.pair_lin = Pairwise(num_hidden)
        self.source_lin = Source()

        # Initialize the GRAFF layer
        self.conv = GRAFFConv(self.ext_lin, self.pair_lin, self.source_lin, self_loops=self_loops)

        # Decoder
        self.dec = torch.nn.Linear(num_hidden, dataset.num_classes, bias=False)

        self.reset_parameters()
    
    def reset_parameters(self):
        self.enc.reset_parameters()
        self.ext_lin.reset_parameters()
        self.pair_lin.reset_parameters()
        self.source_lin.reset_parameters()
        self.dec.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # Apply the encoder
        x = self.enc(x)
        # Copy the initial features
        x0 = x.clone()

        # This context manager caches the parametrization to reduce redundant calculations
        with parametrize.cached():
            x = x + self.step_size * F.relu(self.conv(x, edge_index, x0))
            x = x + self.step_size * F.relu(self.conv(x, edge_index, x0))

        # Apply the decoder
        x = self.dec(x)
        return F.log_softmax(x, dim=1)

### Experiments

We consider the task of node classification on a single graph (transductive setting). The datasets are `Texas` (*low homophily*) and `Cora` (*high homophily*). The train/val/test masks are taken from the [Geom-GCN paper](https://arxiv.org/pdf/2002.05287.pdf) (10 random splits).

First, let us examine the datasets.

In [11]:
def summarize_dataset(dataset):
    print(f'Dataset name: {dataset.name}')
    
    runs = dataset[0]['train_mask'].shape[1]
    print(f'Number of splits in dataset: {runs}')

    print(f'Number of classes: {dataset.num_classes}')

    print(f'Number of nodes: {dataset[0].num_nodes}')
    print(f'Number of edges: {dataset[0].num_edges}')

    h = homophily(dataset[0].edge_index, dataset[0].y)
    print(f'Homophily: {h:.3f}')

In [None]:
dataset_texas = WebKB(root='/tmp/Texas', name='Texas')

In [None]:
dataset_cora = Planetoid(root='/tmp/Cora', name='Cora', split='Geom-GCN')

In [14]:
summarize_dataset(dataset_texas)

Dataset name: texas
Number of splits in dataset: 10
Number of classes: 5
Number of nodes: 183
Number of edges: 325
Homophily: 0.108


In [15]:
summarize_dataset(dataset_cora)

Dataset name: Cora
Number of splits in dataset: 10
Number of classes: 7
Number of nodes: 2708
Number of edges: 10556
Homophily: 0.810


An important qualitative property of a graph is homophily. In the [Geom-GCN paper](https://arxiv.org/pdf/2002.05287.pdf), it is defined as the average fraction of neighbors with the same label:

$$h = \frac{1}{|V|} \sum_{v \in V} \frac{\textrm{Number of }v\textrm{'s neighbors who have the same label as }v}{\textrm{Number of }v\textrm{'s neighbors}} $$

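It can be computed with the PyG function `torch_geometric.utils.homophily`. Note that this function computes the *edge* homophily ratio by default (the fraction of edges that connect nodes with the same label), which is what `summarize_dataset` reports; passing `method='node'` gives the node-averaged definition above. Either way, the summaries above show that `Texas` is heterophilic, whereas `Cora` is homophilic.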
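To make the difference between the two definitions concrete, here is a small sketch in plain PyTorch on a hypothetical 4-node path graph. The helper names `edge_homophily` and `node_homophily` are ours; they follow the formulas rather than PyG's implementation, and nodes without neighbors are simply counted as 0 in this sketch.

```python
import torch


def edge_homophily(edge_index, y):
    # Fraction of (directed) edges whose endpoints share a label
    src, dst = edge_index
    return (y[src] == y[dst]).float().mean().item()


def node_homophily(edge_index, y):
    # Per node: fraction of neighbors with the same label, then averaged over nodes
    src, dst = edge_index
    same = (y[src] == y[dst]).float()
    n = y.size(0)
    same_count = torch.zeros(n).index_add_(0, dst, same)
    deg = torch.zeros(n).index_add_(0, dst, torch.ones_like(same))
    # Nodes without neighbors contribute 0 (our convention for this sketch)
    ratio = torch.where(deg > 0, same_count / deg.clamp(min=1), torch.zeros(n))
    return ratio.mean().item()


# Path graph 0-1-2-3 with labels [0, 0, 1, 1], each undirected edge stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
y = torch.tensor([0, 0, 1, 1])

print(edge_homophily(edge_index, y))  # 4 of 6 directed edges are intra-class, about 0.667
print(node_homophily(edge_index, y))  # mean of [1, 1/2, 1/2, 1] = 0.75
```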



---



Now, let us write the training and evaluation pipeline. We run it on each of the 10 splits. Each run uses early stopping monitored on the validation set: training terminates once neither the validation loss nor the validation accuracy has improved for 100 consecutive epochs. For each split, we report the test accuracy at the epoch with the lowest validation loss, and we summarize the results by their mean and standard deviation over the splits. For training, we use the [Adam](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html) optimizer without a learning rate scheduler.

In [16]:
def train(model, optimizer, data, split):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask[:,split]], data.y[data.train_mask[:,split]])
    loss.backward()
    optimizer.step()


@torch.no_grad()
def evaluate(model, data, split):
    model.eval()
    out = model(data)
    outs = {}

    for key in ['train', 'val', 'test']:
        mask = data[f'{key}_mask'][:,split]
        loss = float(F.nll_loss(out[mask], data.y[mask]))
        pred = out[mask].argmax(1)
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        outs[f'{key}_loss'] = loss
        outs[f'{key}_acc'] = acc

    return outs


def run_train(dataset, model, runs, epochs, lr, weight_decay, early_stopping):
    val_losses, accs, durations = [], [], []
    # Each run corresponds to a different train/val/test split
    for run in range(runs):
        data = dataset[0]
        data = data.to(device)
        
        # Suppress warnings regarding masks
        data.train_mask = data.train_mask.bool()
        data.val_mask = data.val_mask.bool()
        data.test_mask = data.test_mask.bool()

        model.to(device).reset_parameters()
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_start = time.perf_counter()

        best_val_loss = float('inf')
        best_val_acc = 0
        cur_step = 0

        test_acc = 0

        for epoch in range(1, epochs + 1):
            # Training and evaluation
            train(model, optimizer, data, split=run)
            eval_info = evaluate(model, data, split=run)
            eval_info['epoch'] = epoch

            # Best test acc logging and early stop logic
            if eval_info['val_loss'] < best_val_loss:
                best_val_loss = eval_info['val_loss']
                test_acc = eval_info['test_acc']
                cur_step = 0
            elif eval_info['val_acc'] > best_val_acc:
                best_val_acc = eval_info['val_acc']
                cur_step = 0
            else:
                cur_step += 1

            if cur_step >= early_stopping:
                print(f'Training terminated on epoch {epoch}')
                break

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_end = time.perf_counter()

        val_losses.append(best_val_loss)
        accs.append(test_acc)
        durations.append(t_end - t_start)
    loss, acc, duration = torch.tensor(val_losses), torch.tensor(accs), torch.tensor(durations)

    print(f'Val Loss: {float(loss.mean()):.4f}, '
          f'Test Accuracy: {float(acc.mean()):.3f} ± {float(acc.std()):.3f}, '
          f'Duration: {float(duration.mean()):.3f}s')
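The early-stopping and model-selection rule in `run_train` can be replayed in isolation. The following pure-Python sketch (the function name and the `history` format are hypothetical) applies the same branches to a toy list of per-epoch metrics: the patience counter is reset whenever the validation loss reaches a new minimum or, failing that, the validation accuracy beats `best_val_acc` (which, as in `run_train`, is only updated in that branch), and the reported test accuracy always comes from the epoch with the lowest validation loss.

```python
def select_by_val_loss(history, patience):
    """Replay run_train's early-stopping rule (hypothetical helper).

    history: non-empty list of dicts with 'val_loss', 'val_acc' and 'test_acc' per epoch.
    Returns (epoch at which training stops, reported test accuracy).
    """
    best_val_loss, best_val_acc = float('inf'), 0.0
    cur_step, test_acc, epoch = 0, 0.0, 0
    for epoch, info in enumerate(history, start=1):
        if info['val_loss'] < best_val_loss:
            # New lowest validation loss: report this epoch's test accuracy
            best_val_loss, test_acc, cur_step = info['val_loss'], info['test_acc'], 0
        elif info['val_acc'] > best_val_acc:
            # Better validation accuracy resets patience, but does not change test_acc
            best_val_acc, cur_step = info['val_acc'], 0
        else:
            cur_step += 1
        if cur_step >= patience:
            break
    return epoch, test_acc


history = [
    {'val_loss': 1.00, 'val_acc': 0.50, 'test_acc': 0.40},
    {'val_loss': 0.80, 'val_acc': 0.60, 'test_acc': 0.50},
    {'val_loss': 0.90, 'val_acc': 0.70, 'test_acc': 0.60},
    {'val_loss': 0.85, 'val_acc': 0.65, 'test_acc': 0.55},
    {'val_loss': 0.95, 'val_acc': 0.60, 'test_acc': 0.58},
    {'val_loss': 0.70, 'val_acc': 0.75, 'test_acc': 0.70},
]
print(select_by_val_loss(history, patience=2))  # (5, 0.5)
```

In this toy history, training stops at epoch 5 and reports 0.5, the test accuracy of epoch 2, even though epoch 3 had a higher test accuracy: selection is driven by the validation loss only.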

In [17]:
runs = dataset_texas[0]['train_mask'].shape[1]
epochs = 5000
num_hidden = 64
lr = 1e-3
weight_decay = 5e-6
early_stopping = 100

In [18]:
model = GCNNet(dataset_texas, num_hidden)

run_train(dataset_texas, model, runs, epochs, lr, weight_decay, early_stopping)

Training terminated on epoch 117
Training terminated on epoch 116
Training terminated on epoch 124
Training terminated on epoch 119
Training terminated on epoch 123
Training terminated on epoch 114
Training terminated on epoch 103
Training terminated on epoch 117
Training terminated on epoch 139
Training terminated on epoch 125
Val Loss: 1.3577, Test Accuracy: 0.543 ± 0.181, Duration: 0.971s


In [19]:
model = GRAFFNet(dataset_texas, num_hidden, self_loops=False, step_size=0.5)

run_train(dataset_texas, model, runs, epochs, lr, weight_decay, early_stopping)

Training terminated on epoch 138
Training terminated on epoch 243
Training terminated on epoch 166
Training terminated on epoch 147
Training terminated on epoch 163
Training terminated on epoch 140
Training terminated on epoch 121
Training terminated on epoch 141
Training terminated on epoch 121
Training terminated on epoch 143
Val Loss: 0.6392, Test Accuracy: 0.749 ± 0.058, Duration: 1.063s


In [20]:
runs = dataset_cora[0]['train_mask'].shape[1]
epochs = 5000
num_hidden = 64
lr = 1e-3
weight_decay = 5e-5
early_stopping = 100

In [21]:
model = GCNNet(dataset_cora, num_hidden)

run_train(dataset_cora, model, runs, epochs, lr, weight_decay, early_stopping)

Training terminated on epoch 276
Training terminated on epoch 270
Training terminated on epoch 277
Training terminated on epoch 271
Training terminated on epoch 314
Training terminated on epoch 279
Training terminated on epoch 290
Training terminated on epoch 299
Training terminated on epoch 265
Training terminated on epoch 281
Val Loss: 0.4231, Test Accuracy: 0.859 ± 0.017, Duration: 1.448s


In [22]:
model = GRAFFNet(dataset_cora, num_hidden, self_loops=False, step_size=0.25)

run_train(dataset_cora, model, runs, epochs, lr, weight_decay, early_stopping)

Training terminated on epoch 548
Training terminated on epoch 180
Training terminated on epoch 163
Training terminated on epoch 154
Training terminated on epoch 153
Training terminated on epoch 139
Training terminated on epoch 133
Training terminated on epoch 132
Training terminated on epoch 125
Training terminated on epoch 124
Val Loss: 0.4185, Test Accuracy: 0.862 ± 0.016, Duration: 1.282s


To conclude, GRAFF performs well in both extreme settings (low and high homophily). On Texas, it substantially outperforms our vanilla GCN (0.749 vs. 0.543 mean test accuracy, with a much smaller standard deviation across splits). On Cora, the two models perform on par (0.862 vs. 0.859).

You can use this notebook to experiment further with hyperparameters such as `step_size` and `num_hidden`. One possible improvement to the current GRAFF implementation is to replace the linear encoder and decoder with MLPs.
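As a starting point, here is a minimal sketch of such an MLP, assuming a two-layer design with ReLU and dropout; the class name, the dropout rate and the use of biases are our choices, not part of the original setup. It exposes `reset_parameters()` so that it can be dropped into `GRAFFNet` without other changes.

```python
import torch
import torch.nn.functional as F


class MLP(torch.nn.Module):
    """Hypothetical two-layer MLP that could replace GRAFFNet's linear encoder/decoder."""

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.5):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, out_dim)
        self.dropout = dropout

    def reset_parameters(self):
        # Same interface as torch.nn.Linear, so GRAFFNet.reset_parameters() keeps working
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin2(x)


# Arbitrary sizes for illustration: 100 raw features -> 64 hidden channels
enc = MLP(100, 64, 64)
enc.reset_parameters()
out = enc(torch.randn(10, 100))
print(out.shape)  # torch.Size([10, 64])
```

Inside `GRAFFNet.__init__`, one would then write `self.enc = MLP(dataset.num_features, num_hidden, num_hidden)` and `self.dec = MLP(num_hidden, num_hidden, dataset.num_classes)`; the rest of the class can stay unchanged.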