In [None]:
import torch
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.explain import GNNExplainer, Explainer


#testing on one city, can possibly expand to more 
import numpy as np

data_npz = np.load(r"data\TAP-city\aberdeen_md.npz", allow_pickle=True)

# Extract components
x = torch.tensor(data_npz["x"], dtype=torch.float)
y = torch.tensor(data_npz["occur_labels"], dtype=torch.long)
edge_index = torch.tensor(data_npz["edge_index"].T, dtype=torch.long)  # shape [2, num_edges]
edge_attr = torch.tensor(data_npz["edge_attr"], dtype=torch.float)

# Construct the PyG data object
data = Data(x=x, y=y, edge_index=edge_index)


In [9]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


In [10]:
model = GCN(in_channels=x.shape[1], hidden_channels=16, out_channels=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(20):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = torch.nn.functional.cross_entropy(out, data.y)
    loss.backward()
    optimizer.step()

In [17]:
model.eval()

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=30),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

importance_scores = []

# 🚨 This loop may take a while!
for node_idx in range(data.x.shape[0]):
    explanation = explainer(data.x, data.edge_index, index=int(node_idx))

    # Extract only the mask for the node of interest
    feat_mask = explanation.node_mask[node_idx]

    importance_scores.append(feat_mask.cpu().detach().numpy())

# Aggregate
importance_scores = np.array(importance_scores)
avg_importance = importance_scores.mean(axis=0)

# Rank features
feature_ranking = sorted(
    [(i, float(score)) for i, score in enumerate(avg_importance)],
    key=lambda x: -x[1]
)

# Optional: name features
feature_names = ['highway', 'length', 'bridge', 'lanes', 'oneway']
print("\nTop Features (full graph):")
for idx, importance in feature_ranking:
    name = feature_names[idx] if idx < len(feature_names) else f"feature_{idx}"
    print(f"{name}: {importance:.4f}")



Top Features (full graph):
highway: 0.5108
bridge: 0.3863
oneway: 0.0328
lanes: 0.0182
length: 0.0013
