## Import libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
import warnings
import tarfile
import requests

torch_dev = torch.accelerator.current_accelerator() if torch.accelerator.is_available() \
    else torch.device('cpu')

## Obtain the dataset

In [2]:
os.makedirs("dataset", exist_ok=True)

response = requests.get("https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz", stream=True)
with open("dataset/cora.tgz", "wb") as f:
    for chunk in response.iter_content(chunk_size=8192):
        f.write(chunk)

data_dir = "dataset/cora_extracted"

with tarfile.open("dataset/cora.tgz", "r:gz") as tar:
    os.makedirs(data_dir, exist_ok=True)
    tar.extractall(path=data_dir)

data_dir = os.path.join(data_dir, "cora")

citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

print(citations)

print(papers)

      target  source
0          0      21
1          0     905
2          0     906
3          0    1909
4          0    1940
...      ...     ...
5424    1873     328
5425    1873    1876
5426    1874    2586
5427    1876    1874
5428    1897    2707

[5429 rows x 2 columns]
      paper_id  term_0  term_1  term_2  term_3  term_4  term_5  term_6  \
0          462       0       0       0       0       0       0       0   
1         1911       0       0       0       0       0       0       0   
2         2002       0       0       0       0       0       0       0   
3          248       0       0       0       0       0       0       0   
4          519       0       0       0       0       0       0       0   
...        ...     ...     ...     ...     ...     ...     ...     ...   
2703      2370       0       0       0       0       0       0       0   
2704      2371       0       0       0       0       0       0       0   
2705      2372       0       0       0       0       0   

  tar.extractall(path=data_dir)


### Split the dataset

In [3]:
# Obtain random indices
random_indices = np.random.permutation(range(papers.shape[0]))

# 50/50 split
train_data = papers.iloc[random_indices[: len(random_indices) // 2]]
test_data = papers.iloc[random_indices[len(random_indices) // 2 :]]

### Prepare the graph data

In [4]:
# Obtain paper indices which will be used to gather node states
# from the graph later on when training the model
train_indices = train_data["paper_id"].to_numpy()
test_indices = test_data["paper_id"].to_numpy()

# Obtain ground truth labels corresponding to each paper_id
train_labels = train_data["subject"].to_numpy()
test_labels = test_data["subject"].to_numpy()

# Define graph, namely an edge tensor and a node feature tensor
edges = torch.tensor(citations[["target", "source"]].values, dtype=torch.long)
node_states = torch.tensor(papers.sort_values("paper_id").iloc[:, 1:-1].values, dtype=torch.float32)

# Print shapes of the graph
print("Edges shape:         ", edges.shape)
print("Node features shape: ", node_states.shape)

Edges shape:          torch.Size([5429, 2])
Node features shape:  torch.Size([2708, 1433])


## Build the model

### (Multi-head) graph attention layer

In [5]:
class GraphAttention(nn.Module):
    def __init__(self, units, kernel_initializer="glorot_uniform"):
        super(GraphAttention, self).__init__()
        self.units = units
        self.kernel_initializer = kernel_initializer
        # We delay weight creation until we see the input dimensions.
        self.kernel = None
        self.kernel_attention = None
        self.built = False

    def build(self, input_dim, device):
        self.kernel = nn.Parameter(torch.empty(input_dim, self.units, device=device))
        self.kernel_attention = nn.Parameter(torch.empty(self.units * 2, 1, device=device))

        match self.kernel_attention:
            case "glroot_uniform":
                nn.init.xavier_uniform_(self.kernel)
                nn.init.xavier_uniform_(self.kernel_attention)
            case _:
                # Fallback initializer; you might change this if needed.
                nn.init.kaiming_uniform_(self.kernel, nonlinearity='leaky_relu')
                nn.init.kaiming_uniform_(self.kernel_attention, nonlinearity='leaky_relu')
        self.built = True

    def forward(self, node_states, edges):
        # Build weights on first forward pass (if not built)
        if not self.built:
            self.build(node_states.size(1), node_states.device)
        
        # (1) Linearly transform node states.
        # Shape: (N, units)
        node_states_transformed = torch.matmul(node_states, self.kernel)

        # (2) Compute pairwise attention scores.
        node_states_expanded = node_states_transformed[edges]
        node_states_expanded = node_states_expanded.reshape(edges.size(0), -1)
        attention_scores = F.leaky_relu(torch.matmul(node_states_expanded, self.kernel_attention))
        attention_scores = attention_scores.squeeze(-1)

        # (3) Normalize attention scores.
        attention_scores_clipped = torch.clamp(attention_scores, -2, 2)
        attention_scores_exp = torch.exp(attention_scores_clipped)
        num_nodes = node_states.size(0)
        attention_sum = torch.zeros(num_nodes, device=attention_scores.device)
        attention_sum = attention_sum.index_add(0, edges[:, 0], attention_scores_exp)
        normalized_attention = attention_scores_exp / attention_sum[edges[:, 0]]

        # (4) Aggregate neighbor features: weighted sum over neighbors.
        node_states_neighbors = node_states_transformed[edges[:, 1]]
        weighted_neighbors = node_states_neighbors * normalized_attention.unsqueeze(1)
        out = torch.zeros_like(node_states_transformed)
        out = out.index_add(0, edges[:, 0], weighted_neighbors)
        return out

class MultiHeadGraphAttention(nn.Module):
    def __init__(self, units, num_heads=8, merge_type="concat"):
        super(MultiHeadGraphAttention, self).__init__()
        self.num_heads = num_heads
        self.merge_type = merge_type
        # Create a ModuleList of independent attention heads.
        self.attention_layers = nn.ModuleList(
            [GraphAttention(units) for _ in range(num_heads)]
        )

    def forward(self, atom_features, pair_indices):
        # Compute the output for each attention head.
        head_outputs = [att(atom_features, pair_indices) for att in self.attention_layers]

        # Merge the outputs: either concatenate along the feature dimension or average them.
        if self.merge_type == "concat":
            # Concatenate along the last dimension.
            out = torch.cat(head_outputs, dim=1)
        else:
            # Stack and then average over the head dimension.
            out = torch.mean(torch.stack(head_outputs, dim=0), dim=0)
        # Apply a ReLU nonlinearity (same as tf.nn.relu).
        return F.relu(out)

### Implement training logic

In [6]:
class GraphAttentionNetwork(nn.Module):
    def __init__(self, node_states, edges, hidden_units, num_heads, num_layers, output_dim):
        super(GraphAttentionNetwork, self).__init__()
        self.node_states = node_states  # fixed graph node features
        self.edges = edges              # fixed graph connectivity

        # Preprocessing: linearly transform the initial node features.
        in_features = node_states.size(1)
        self.preprocess = nn.Linear(in_features, hidden_units * num_heads)

        # Create a list of multi-head attention layers.
        # Each layer applies attention and then a residual connection.
        self.attention_layers = nn.ModuleList([
            MultiHeadGraphAttention(hidden_units, num_heads, merge_type="concat")
            for _ in range(num_layers)
        ])

        # Final output layer maps the aggregated representations to logits.
        self.output_layer = nn.Linear(hidden_units * num_heads, output_dim)

    def forward(self):
        # Preprocess the node features.
        x = self.preprocess(self.node_states)
        x = F.relu(x)
        # Sequentially pass through all attention layers with residual connections.
        for att_layer in self.attention_layers:
            x = att_layer(x, self.edges) + x
        outputs = self.output_layer(x)
        # The network produces raw logits.
        return outputs

### Train and evaluate

In [7]:
# ----------------------------
# Hyper-Parameters and Data Preparation
# ----------------------------
# Hyper-parameters (matching the TensorFlow script)
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
# OUTPUT_DIM is assumed to be the number of classes (len(class_values))
# Make sure that you have defined 'class_values' when processing your dataset.
OUTPUT_DIM = len(class_values)

NUM_EPOCHS = 100
BATCH_SIZE = 256
LEARNING_RATE = 3e-1
MOMENTUM = 0.9

# Convert training and test indices/labels to PyTorch tensors if not already done.
# (These values are provided by your data preparation code.)
node_states = node_states.to(torch_dev)
edges = edges.to(torch_dev)
train_indices = torch.tensor(train_indices, dtype=torch.long).to(torch_dev)
train_labels = torch.tensor(train_labels, dtype=torch.long).to(torch_dev)
test_indices = torch.tensor(test_indices, dtype=torch.long).to(torch_dev)
test_labels = torch.tensor(test_labels, dtype=torch.long).to(torch_dev)

# Make sure the node_states and edges are already torch.Tensors.
# For example, they may have been defined as:
# edges = torch.tensor(citations[["target", "source"]].values, dtype=torch.long)
# node_states = torch.tensor(papers.sort_values("paper_id").iloc[:, 1:-1].values, dtype=torch.float32)

# ----------------------------
# Build Model, Loss Function, and Optimizer
# ----------------------------
model = GraphAttentionNetwork(node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM)
model = model.to(torch_dev)

# Use SGD with momentum
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
# CrossEntropyLoss in PyTorch combines LogSoftmax and NLLLoss (from logits)
loss_fn = nn.CrossEntropyLoss()

# ----------------------------
# Training Loop with Mini-Batches
# ----------------------------
# For a graph model, the forward pass computes representations for all nodes.
# To mimic batching, we compute the forward pass each mini-batch and then gather the predictions
# corresponding to the batch indices.  (For small graphs, the extra forward passes are acceptable.)
num_train = len(train_indices)
print("Starting training...")

for epoch in range(NUM_EPOCHS):
    model.train()  # Set to training mode
    permutation = torch.randperm(num_train, device=torch_dev)
    epoch_loss = 0.0

    for i in range(0, num_train, BATCH_SIZE):
        optimizer.zero_grad()

        # Select mini-batch
        batch_perm = permutation[i : i + BATCH_SIZE]
        batch_indices = train_indices[batch_perm]
        batch_labels = train_labels[batch_perm]

        # Forward pass: compute output for entire graph (stored on device)
        outputs = model()  # outputs shape: (N, OUTPUT_DIM)
        batch_logits = outputs[batch_indices]  # gather logits for mini-batch

        loss = loss_fn(batch_logits, batch_labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / (num_train / BATCH_SIZE)
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} – Average Loss: {avg_loss:.4f}")

# ----------------------------
# Evaluation on the Test Set
# ----------------------------
model.eval()  # Set model to evaluation mode
with torch.no_grad():
    outputs = model()
    # Gather test node predictions.
    test_logits = outputs[test_indices]
    # Compute predicted labels.
    predicted_labels = torch.argmax(test_logits, dim=1)
    correct = (predicted_labels == test_labels).sum().item()
    test_accuracy = correct / len(test_labels)

print("--" * 38)
print(f"Test Accuracy {test_accuracy * 100:.1f}%")

Starting training...
Epoch 1/100 – Average Loss: 6.6442
Epoch 2/100 – Average Loss: 2.9619
Epoch 3/100 – Average Loss: 2.4086
Epoch 4/100 – Average Loss: 2.1745
Epoch 5/100 – Average Loss: 2.1651
Epoch 6/100 – Average Loss: 2.0022
Epoch 7/100 – Average Loss: 1.9628
Epoch 8/100 – Average Loss: 1.8903
Epoch 9/100 – Average Loss: 1.8635
Epoch 10/100 – Average Loss: 1.8068
Epoch 11/100 – Average Loss: 1.7912
Epoch 12/100 – Average Loss: 1.7204
Epoch 13/100 – Average Loss: 1.6415
Epoch 14/100 – Average Loss: 1.6164
Epoch 15/100 – Average Loss: 1.5927
Epoch 16/100 – Average Loss: 1.5172
Epoch 17/100 – Average Loss: 1.4171
Epoch 18/100 – Average Loss: 1.3862
Epoch 19/100 – Average Loss: 1.3109
Epoch 20/100 – Average Loss: 1.2025
Epoch 21/100 – Average Loss: 1.2593
Epoch 22/100 – Average Loss: 1.2733
Epoch 23/100 – Average Loss: 1.1259
Epoch 24/100 – Average Loss: 1.2949
Epoch 25/100 – Average Loss: 1.4896
Epoch 26/100 – Average Loss: 1.2924
Epoch 27/100 – Average Loss: 1.1853
Epoch 28/100 – A

### Predict (probabilities)

In [8]:
# Set model to evaluation mode and compute predictions.
model.eval()
with torch.no_grad():
    # Forward pass through the full graph; outputs has shape (N, OUTPUT_DIM)
    outputs = model()
    # Gather logits for test node indices.
    test_logits = outputs[test_indices]
    # Convert logits to probabilities using softmax along the class dimension.
    test_probs = F.softmax(test_logits, dim=1)

# Create a mapping from label integer to the class name.
mapping = {v: k for k, v in class_idx.items()}

# Print probabilities for the first 10 test examples.
for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):
    # If label is a tensor, convert to int.
    label_int = int(label) if torch.is_tensor(label) else label
    print(f"Example {i+1}: {mapping[label_int]}")
    # probs is a tensor; converting to list for easy iteration.
    for j, c in zip(probs.tolist(), class_idx.keys()):
        print(f"\tProbability of {c: <24} = {j * 100:7.3f}%")
    print("---" * 20)

Example 1: Probabilistic_Methods
	Probability of Case_Based               =   0.000%
	Probability of Genetic_Algorithms       =   0.000%
	Probability of Neural_Networks          =   0.000%
	Probability of Probabilistic_Methods    =   0.000%
	Probability of Reinforcement_Learning   =   0.000%
	Probability of Rule_Learning            = 100.000%
	Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 2: Probabilistic_Methods
	Probability of Case_Based               =   0.133%
	Probability of Genetic_Algorithms       =   0.157%
	Probability of Neural_Networks          =  95.433%
	Probability of Probabilistic_Methods    =   1.003%
	Probability of Reinforcement_Learning   =   0.153%
	Probability of Rule_Learning            =   0.892%
	Probability of Theory                   =   2.229%
------------------------------------------------------------
Example 3: Case_Based
	Probability of Case_Based               =   9.746%
	Probabili