In [77]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torchdiffeq
from typing import Callable
from tqdm import tqdm

In [78]:
class PerformanceContainer(object):
    """ Simple data class for metrics logging."""
    def __init__(self, data:dict):
        self.data = data
        
    @staticmethod
    def deep_update(x, y):
        for key in y.keys():
            x.update({key: list(x[key] + y[key])})
        return x
    
def accuracy(y_hat:torch.Tensor, y:torch.Tensor):
    """ Standard percentage accuracy computation """
    preds = torch.max(y_hat, 1)[1]
    return torch.mean((y == preds).float())


In [45]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# seed for repeatability
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

torch.manual_seed(0)
np.random.seed(0)

In [55]:
# Seed for repeatability
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

torch.manual_seed(0)
np.random.seed(0)

# Load the Cora dataset using torch_geometric
dataset = datasets.Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# Features and labels
X = data.x.to(device)
Y = data.y.to(device)

# Masks for training, validation, and test
train_mask = data.train_mask.to(device)
val_mask = data.val_mask.to(device)
test_mask = data.test_mask.to(device)

# Number of features and classes
num_feats = X.shape[1]
num_classes = dataset.num_classes

# Summary
num_classes, train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item()

(7, 140, 500, 1000)

In [71]:
class GCNLayer(nn.Module):
    def __init__(self, adj_matrix: torch.Tensor, in_feats: int, out_feats: int,
                 activation: Callable[[torch.Tensor], torch.Tensor] = None,
                 dropout: float = 0.0, bias: bool = True):
        super().__init__()
        
        # Convert edge_index to a square adjacency matrix
        edge_index = adj_matrix
        num_nodes = torch.max(edge_index) + 1  # Assuming node indices start from 0

        # Convert edge_index to adjacency matrix
        adj_matrix = torch.zeros((num_nodes, num_nodes), device=edge_index.device)
        adj_matrix[edge_index[0], edge_index[1]] = 1
        
        # Store the adjacency matrix
        self.adj_matrix = adj_matrix
        
        # Define the layer weights and biases
        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.bias = None

        # Activation function
        self.activation = activation

        # Dropout layer
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

        # Initialize weights
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, h: torch.Tensor):
        if self.dropout:
            h = self.dropout(h)
        h = torch.mm(h, self.weight)

        # Compute the degree matrix
        degree_inv_sqrt = torch.pow(self.adj_matrix.sum(dim=1), -0.5)
        degree_inv_sqrt[torch.isinf(degree_inv_sqrt)] = 0.0
        norm = torch.diag(degree_inv_sqrt)

        # Symmetric normalization: D^(-1/2) A D^(-1/2)
        adj_norm = torch.mm(norm, torch.mm(self.adj_matrix.float(), norm))

        # Aggregation
        h = torch.mm(adj_norm, h)

        # Apply bias if necessary
        if self.bias is not None:
            h = h + self.bias

        # Apply activation function
        if self.activation is not None:
            h = self.activation(h)

        return h

In [64]:
class GCN(nn.Module):
    def __init__(self, num_layers: int, adj_matrix: torch.Tensor, in_feats: int, hidden_feats: int,
                 out_feats: int, activation: Callable[[torch.Tensor], torch.Tensor],
                 dropout: float = 0.0, bias: bool = True):
        super().__init__()
        self.layers = nn.ModuleList()

        # First GCN layer
        self.layers.append(GCNLayer(adj_matrix, in_feats, hidden_feats, activation, dropout, bias))

        # Middle GCN layers
        for _ in range(num_layers - 2):
            self.layers.append(GCNLayer(adj_matrix, hidden_feats, hidden_feats, activation, dropout, bias))

        # Last GCN layer
        self.layers.append(GCNLayer(adj_matrix, hidden_feats, out_feats, None, 0.0, bias))

    def forward(self, features: torch.Tensor):
        h = features
        for layer in self.layers:
            h = layer(h)
        return h

In [48]:
class GDEFunc(nn.Module):
    def __init__(self, gnn: nn.Module):
        """General GDE function class. To be passed to an ODEBlock."""
        super().__init__()
        self.gnn = gnn
        self.nfe = 0  # Number of function evaluations (NFE)
    
    def set_graph(self, adj_matrix: torch.Tensor):
        """Set adjacency matrix for the GNN layers if needed."""
        for layer in self.gnn:
            if hasattr(layer, 'adj_matrix'):
                layer.adj_matrix = adj_matrix
            
    def forward(self, t, x):
        """Forward method to compute the GNN output."""
        self.nfe += 1
        x = self.gnn(x)
        return x

In [49]:
class ControlledGDEFunc(GDEFunc):
    def __init__(self, gnn: nn.Module):
        """Controlled GDE version. Input information is preserved longer via hooks to input node features X_0, 
           affecting all ODE function steps. Requires assignment of '.h0' before calling .forward."""
        super().__init__(gnn)
        self.nfe = 0
        self.h0 = None  # Placeholder for the initial node features
            
    def forward(self, t, x):
        """Forward method that concatenates initial node features with current features."""
        self.nfe += 1
        if self.h0 is not None:
            x = torch.cat([x, self.h0], dim=1)
        x = self.gnn(x)
        return x

In [53]:
class ODEBlock(nn.Module):
    def __init__(self, odefunc: nn.Module, method: str = 'dopri5', rtol: float = 1e-3, atol: float = 1e-4, adjoint: bool = True):
        """ Standard ODEBlock class. Can handle all types of ODE functions
            :method:str = {'euler', 'rk4', 'dopri5', 'adams'}
        """
        super().__init__()
        self.odefunc = odefunc
        self.method = method
        self.adjoint_flag = adjoint
        self.atol, self.rtol = atol, rtol

    def forward(self, x: torch.Tensor, T: int = 1):
        self.integration_time = torch.tensor([0, T]).float()
        self.integration_time = self.integration_time.type_as(x)

        if self.adjoint_flag:
            out = torchdiffeq.odeint_adjoint(self.odefunc, x, self.integration_time,
                                             rtol=self.rtol, atol=self.atol, method=self.method)
        else:
            out = torchdiffeq.odeint(self.odefunc, x, self.integration_time,
                                     rtol=self.rtol, atol=self.atol, method=self.method)
            
        return out[-1]
    
    def forward_batched(self, x: torch.Tensor, nn: int, indices: list, timestamps: set):
        """ Modified forward for ODE batches with different integration times """
        timestamps = torch.Tensor(list(timestamps))
        timestamps = timestamps.type_as(x)
        
        if self.adjoint_flag:
            out = torchdiffeq.odeint_adjoint(self.odefunc, x, timestamps,
                                             rtol=self.rtol, atol=self.atol, method=self.method)
        else:
            out = torchdiffeq.odeint(self.odefunc, x, timestamps,
                                     rtol=self.rtol, atol=self.atol, method=self.method)

        out = self._build_batch(out, nn, indices).reshape(x.shape)
        return out
    
    def _build_batch(self, odeout, nn, indices):
        b_out = []
        for i in range(len(indices)):
            b_out.append(odeout[indices[i], i*nn:(i+1)*nn])
        return torch.cat(b_out).to(odeout.device)
              
    def trajectory(self, x: torch.Tensor, T: int, num_points: int):
        self.integration_time = torch.linspace(0, T, num_points)
        self.integration_time = self.integration_time.type_as(x)
        out = torchdiffeq.odeint(self.odefunc, x, self.integration_time,
                                 rtol=self.rtol, atol=self.atol, method=self.method)
        return out

In [56]:
# Convert the edge index to a NetworkX graph
G = pyg_utils.to_networkx(data, to_undirected=True)

# Remove existing self-loops
G.remove_edges_from(nx.selfloop_edges(G))

# Add self-loops
G.add_edges_from(zip(G.nodes(), G.nodes()))

# Convert the graph back to PyTorch Geometric format
edge_index = pyg_utils.from_networkx(G).edge_index

# Update the data object with the new edge index
data.edge_index = edge_index

# Calculate the number of edges
n_edges = data.edge_index.size(1)

n_edges

13264

In [57]:
# Step 1: Compute the degrees of the nodes
degs = pyg_utils.degree(data.edge_index[0], num_nodes=data.num_nodes).float()

# Step 2: Compute the normalization values (D^-0.5)
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0  # Replace infinities with zero

# Step 3: Add the norm as a node feature to the data object
data.norm = norm.unsqueeze(1).to(device)

In [72]:
adj_matrix = edge_index.to(device)  # Assuming data.edge_index is the adjacency matrix in torch_geometric format

# Create the GNN layers using the adjacency matrix
gnn = nn.Sequential(
    GCNLayer(adj_matrix=adj_matrix, in_feats=64, out_feats=64, activation=nn.Softplus(), dropout=0.9),
    GCNLayer(adj_matrix=adj_matrix, in_feats=64, out_feats=64, activation=None, dropout=0.9)
).to(device)

# Define the graph ODE function
gdefunc = GDEFunc(gnn)

# Create the ODE block with the 'rk4' method
gde = ODEBlock(odefunc=gdefunc, method='rk4', atol=1e-3, rtol=1e-4, adjoint=False).to(device)

# Create the full model
m = nn.Sequential(
    GCNLayer(adj_matrix=adj_matrix, in_feats=num_feats, out_feats=64, activation=F.relu, dropout=0.4),
    gde,
    GCNLayer(adj_matrix=adj_matrix, in_feats=64, out_feats=n_classes, activation=None, dropout=0.0)
).to(device)

In [79]:
# Optimizer and Loss
opt = torch.optim.Adam(m.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

# Performance Logging
logger = PerformanceContainer(data={
    'train_loss': [], 'train_accuracy': [],
    'test_loss': [], 'test_accuracy': [],
    'forward_time': [], 'backward_time': [],
    'nfe': []
})

steps = 3000
verbose_step = 150
num_grad_steps = 0

for i in tqdm(range(steps)):  # Looping over epochs
    m.train()
    start_time = time.time()

    outputs = m(X)
    f_time = time.time() - start_time

    nfe = m._modules['1'].odefunc.nfe

    y_pred = outputs

    loss = criterion(y_pred[train_mask], Y[train_mask])
    opt.zero_grad()
    
    start_time = time.time()
    loss.backward()
    b_time = time.time() - start_time
    
    opt.step()
    num_grad_steps += 1

    with torch.no_grad():
        m.eval()

        # Calculate outputs again with zeroed dropout
        y_pred = m(X)
        m._modules['1'].odefunc.nfe = 0

        train_loss = loss.item()
        train_acc = accuracy(y_pred[train_mask], Y[train_mask]).item()
        test_acc = accuracy(y_pred[test_mask], Y[test_mask]).item()
        test_loss = criterion(y_pred[test_mask], Y[test_mask]).item()

        logger.deep_update(logger.data, dict(
            train_loss=[train_loss],
            train_accuracy=[train_acc],
            test_loss=[test_loss],
            test_accuracy=[test_acc],
            nfe=[nfe],
            forward_time=[f_time],
            backward_time=[b_time]
        ))

    if num_grad_steps % verbose_step == 0:
        print(f'[{num_grad_steps}], Loss: {train_loss:.3f}, Train Accuracy: {train_acc:.3f}, '
              f'Test Accuracy: {test_acc:.3f}, NFE: {nfe}')

  5%|▌         | 150/3000 [00:56<17:55,  2.65it/s]

[150], Loss: 0.083, Train Accuracy: 1.000, Test Accuracy: 0.802, NFE: 4


 10%|█         | 300/3000 [01:52<16:58,  2.65it/s]

[300], Loss: 0.029, Train Accuracy: 1.000, Test Accuracy: 0.806, NFE: 4


 15%|█▌        | 450/3000 [02:49<16:00,  2.65it/s]

[450], Loss: 0.018, Train Accuracy: 1.000, Test Accuracy: 0.807, NFE: 4


 20%|██        | 600/3000 [03:46<15:04,  2.65it/s]

[600], Loss: 0.018, Train Accuracy: 1.000, Test Accuracy: 0.809, NFE: 4


 25%|██▌       | 750/3000 [04:42<14:11,  2.64it/s]

[750], Loss: 0.013, Train Accuracy: 1.000, Test Accuracy: 0.801, NFE: 4


 30%|███       | 900/3000 [05:39<13:12,  2.65it/s]

[900], Loss: 0.012, Train Accuracy: 1.000, Test Accuracy: 0.808, NFE: 4


 35%|███▌      | 1050/3000 [06:36<12:19,  2.64it/s]

[1050], Loss: 0.011, Train Accuracy: 1.000, Test Accuracy: 0.806, NFE: 4


 40%|████      | 1200/3000 [07:32<11:19,  2.65it/s]

[1200], Loss: 0.015, Train Accuracy: 1.000, Test Accuracy: 0.814, NFE: 4


 45%|████▌     | 1350/3000 [08:29<10:22,  2.65it/s]

[1350], Loss: 0.015, Train Accuracy: 1.000, Test Accuracy: 0.807, NFE: 4


 50%|█████     | 1500/3000 [09:26<09:25,  2.65it/s]

[1500], Loss: 0.012, Train Accuracy: 1.000, Test Accuracy: 0.813, NFE: 4


 55%|█████▌    | 1650/3000 [10:22<08:29,  2.65it/s]

[1650], Loss: 0.008, Train Accuracy: 1.000, Test Accuracy: 0.811, NFE: 4


 60%|██████    | 1800/3000 [11:19<07:34,  2.64it/s]

[1800], Loss: 0.008, Train Accuracy: 1.000, Test Accuracy: 0.801, NFE: 4


 65%|██████▌   | 1950/3000 [12:16<06:36,  2.65it/s]

[1950], Loss: 0.008, Train Accuracy: 1.000, Test Accuracy: 0.809, NFE: 4


 70%|███████   | 2100/3000 [13:13<05:40,  2.64it/s]

[2100], Loss: 0.011, Train Accuracy: 1.000, Test Accuracy: 0.807, NFE: 4


 75%|███████▌  | 2250/3000 [14:09<04:43,  2.64it/s]

[2250], Loss: 0.006, Train Accuracy: 1.000, Test Accuracy: 0.823, NFE: 4


 80%|████████  | 2400/3000 [15:06<03:47,  2.64it/s]

[2400], Loss: 0.044, Train Accuracy: 1.000, Test Accuracy: 0.820, NFE: 4


 85%|████████▌ | 2550/3000 [16:03<02:50,  2.64it/s]

[2550], Loss: 0.006, Train Accuracy: 1.000, Test Accuracy: 0.811, NFE: 4


 90%|█████████ | 2700/3000 [17:00<01:53,  2.64it/s]

[2700], Loss: 0.011, Train Accuracy: 1.000, Test Accuracy: 0.809, NFE: 4


 95%|█████████▌| 2850/3000 [17:56<00:57,  2.63it/s]

[2850], Loss: 0.006, Train Accuracy: 1.000, Test Accuracy: 0.821, NFE: 4


100%|██████████| 3000/3000 [18:53<00:00,  2.65it/s]

[3000], Loss: 0.012, Train Accuracy: 1.000, Test Accuracy: 0.809, NFE: 4



