In [1]:
import sys, os, torch
# Treat the parent of the notebook's working directory as the repository root
# and prepend it to sys.path (only once, so re-running this cell does not stack
# duplicate entries). This is what lets the `src.*` imports below resolve.
_notebook_dir = os.getcwd()
PROJECT_ROOT = os.path.abspath(os.path.join(_notebook_dir, ".."))

_root_missing = PROJECT_ROOT not in sys.path
if _root_missing:
    sys.path.insert(0, PROJECT_ROOT)

from src.models.gcn import GCN
from src.models.gat import GAT
from src.tasks.node_classify import train_model
from src.utils.explain import get_attention, top_k_neighbors, plot_attention_subgraph

# GAT: train a graph attention network and inspect its attention weights

In [None]:
# 1) Train a GAT. Hyperparameters are named here so they can be tuned in one place.
# Model-constructor settings are passed as a dict; training settings as keyword args.
gat_model_kwargs = {"hid": 8, "heads": 12, "dropout": 0.7}
gat_train_kwargs = dict(epochs=300, lr=0.005, wd=5e-4, seed=4200)

# train_model returns a 5-tuple; only the last three items are kept for later cells.
_, _, model, data, ds = train_model(GAT, gat_model_kwargs, **gat_train_kwargs)

Epoch 001 | loss 1.9453 | val 0.212 | test 0.221
Epoch 020 | loss 1.7475 | val 0.778 | test 0.805
Epoch 040 | loss 1.4078 | val 0.792 | test 0.782
Epoch 060 | loss 1.2410 | val 0.782 | test 0.804
Epoch 080 | loss 1.0892 | val 0.806 | test 0.817
Epoch 100 | loss 1.0349 | val 0.796 | test 0.818
Epoch 120 | loss 0.9437 | val 0.792 | test 0.814
Epoch 140 | loss 0.8660 | val 0.794 | test 0.817
Epoch 160 | loss 0.7640 | val 0.800 | test 0.814
Epoch 180 | loss 0.8391 | val 0.794 | test 0.828
Epoch 200 | loss 0.7639 | val 0.790 | test 0.818
Epoch 220 | loss 0.6997 | val 0.802 | test 0.823
Epoch 240 | loss 0.7577 | val 0.792 | test 0.806
Epoch 260 | loss 0.8540 | val 0.792 | test 0.803
Epoch 280 | loss 0.6776 | val 0.790 | test 0.814
Epoch 300 | loss 0.7573 | val 0.798 | test 0.825
Best val: 0.814 | Final test: 0.819


In [None]:
# 2) Grab attention coefficients from the first GAT layer (layer=0).
# NOTE(review): train_model may hand back the model still in training mode; with
# dropout=0.7 the attention could then be stochastic unless get_attention switches
# modes itself (not visible here). eval() + no_grad() make the extraction
# deterministic and avoid building an autograd graph.
model.eval()
with torch.no_grad():
    ei_used, alpha = get_attention(model, data.x, data.edge_index, layer=0)

# 3) Pick a node and get its top-5 neighbors by attention weight.
node_id = 42
nbrs = top_k_neighbors(ei_used, alpha, node_id, k=5)
print("Top neighbors:", nbrs)

# 4) Plot a mini star graph centred on node_id.
# The project root was already put on sys.path by the first cell (and all imports
# are done), so no further sys.path changes are needed here.
# The output file name follows node_id, so choosing another node does not
# overwrite / mislabel node_42.png. Ensure the target directory exists first.
attention_dir = "../results/attention"
os.makedirs(attention_dir, exist_ok=True)
plot_attention_subgraph(
    G=None,
    center=node_id,
    nbrs=nbrs,
    save_path=f"{attention_dir}/node_{node_id}.png",
)