In [2]:
# File: example.py
# Description: Example of GNNExplainer in link prediction.
# Author: Yuchuan Fu
# Created: 2023-11-23

# Reference Code:
# https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.explain.algorithm.GNNExplainer.html
# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/gnn_explainer_link_pred.py

In [1]:
import glob
import os.path as osp

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig
from torch_geometric.nn import GCNConv

In [2]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
# (A branch selecting the Apple 'mps' backend used to sit here and is left disabled.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
# Name of the Planetoid citation dataset to load.
# NOTE: the data-loading cell below rebinds `dataset` from this string to the loaded dataset object.
dataset = 'Cora'
# Cora: 2,708 scientific publications classified into seven classes, connected through 5,429 links.
# Each publication is described by a 0/1-valued word vector indicating the absence/presence of the
# corresponding dictionary word. The dictionary consists of 1,433 unique words.

# Alternative: dataset = 'PubMed'
# PubMed: 19,717 publications classified into three classes, connected through 44,338 links.
# Each publication is described by a TF-IDF weighted word vector over a dictionary
# of 500 unique words.

# DDI dataset (for reference only; not loaded in this notebook):
# 1,514 nodes representing drugs approved by the U.S. Food and Drug Administration,
# and 48,514 edges representing interactions between drugs. The dataset provides no node features,
# so features are supplied as fixed 128-dimensional node embeddings computed with Node2Vec.

In [15]:
# `__file__` is not defined inside a notebook, so the data root is given relative to the
# kernel's working directory instead of relative to this file.
path = osp.join('..', 'data', 'Planetoid')

# Pre-processing applied on every access to a graph of the dataset:
#   1. normalise the node features,
#   2. move the graph to `device`,
#   3. randomly split the (undirected) edges into train / validation / test
#      link-prediction sets, holding out 10% for validation and 10% for test.
# NOTE(review): the original comment here mentioned 10556 edges but was cut off
# mid-sentence; presumably that is Cora's edge count before the split - confirm.
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.1, num_test=0.1, is_undirected=True),
])

# `dataset` is the dataset name (a str) on a fresh kernel, but this cell rebinds it to
# the loaded dataset object. Accept both so that re-running the cell does not pass a
# dataset object where a name is expected.
dataset_name = dataset if isinstance(dataset, str) else dataset.name
dataset = Planetoid(path, dataset_name, transform=transform)

# The transform (including the random split) is re-applied on each `dataset[0]` access,
# so index exactly once and display the very split that the rest of the notebook uses.
train_data, val_data, test_data = dataset[0]
train_data, val_data, test_data

In [19]:
# Number of non-zero edge labels in the train, validation and test splits, in that order.
for split in (train_data, val_data, test_data):
    print(split.edge_label.count_nonzero())

In [23]:
# Toy example: unique(dim=0) drops the repeated row [1, 2] and returns the remaining rows sorted.
torch.unique(torch.Tensor([[1, 2], [3, 4], [0, 1], [1, 2]]), dim=0)
# Same operation on the training supervision edges, one (source, target) pair per row.
torch.unique(train_data.edge_label_index.T, dim=0)

In [24]:
class GCN(torch.nn.Module):
    """Two-layer GCN encoder with a dot-product decoder for link prediction.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first convolution.
        out_channels: Size of the node embeddings produced by ``encode``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings: conv1 -> ReLU -> conv2 over ``edge_index``."""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        """Score each candidate edge as the dot product of its endpoint embeddings.

        ``edge_label_index`` is unpacked into a source row and a target row, so it
        may describe many edges (two rows) or a single edge (two scalars).
        """
        src, dst = edge_label_index
        return (z[src] * z[dst]).sum(dim=-1)

    def forward(self, x, edge_index, edge_label_index):
        """Return one raw (un-squashed) score per edge in ``edge_label_index``, flattened to 1-D."""
        # Use this instance's own layers. The previous code went through the
        # module-level name `model`, which fails before that global exists and
        # silently evaluates a different network for any other GCN instance.
        z = self.encode(x, edge_index)
        return self.decode(z, edge_label_index).view(-1)

# Link predictor with 128 hidden channels and 64-dimensional node embeddings, trained with Adam.
model = GCN(
    in_channels=dataset.num_features,
    hidden_channels=128,
    out_channels=64,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [25]:
def train():
    """Run one optimisation step over the whole training split.

    Scores the labelled training edges with the global ``model``, applies binary
    cross-entropy on the raw scores against ``train_data.edge_label``, and
    updates the weights through the global ``optimizer``.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(
        train_data.x,
        train_data.edge_index,
        train_data.edge_label_index,
    )
    bce = F.binary_cross_entropy_with_logits(logits, train_data.edge_label)
    bce.backward()
    optimizer.step()
    return bce.item()

@torch.no_grad()
def test(data):
    """Return the ROC-AUC of the global ``model`` on the labelled edges of ``data``.

    The raw edge scores are passed through a sigmoid and compared with
    ``data.edge_label``; gradients are disabled for the whole call.
    """
    model.eval()
    logits = model(data.x, data.edge_index, data.edge_label_index)
    labels = data.edge_label.cpu().numpy()
    scores = logits.sigmoid().cpu().numpy()
    return roc_auc_score(labels, scores)

# Train for 200 epochs; report validation / test ROC-AUC on every 20th epoch.
for epoch in range(1, 201):
    loss = train()
    if epoch % 20:
        continue
    val_auc, test_auc = test(val_data), test(test_data)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

# Tell the explainer what the model does: an edge-level binary classifier whose
# forward pass returns raw scores (the sigmoid is only applied in `test`).
model_config = ModelConfig(mode='binary_classification', task_level='edge', return_type='raw')

In [31]:
# Shape of the labelled validation edges: one row of sources and one row of targets.
# Only the last expression of a cell is rendered, so this first line shows nothing.
val_data.edge_label_index.shape # 2, 1054

# Number of labelled validation edges (value recorded from one run in the trailing comment).
val_data.edge_label_index.size(1) # 1054

In [54]:
# Edges of the training split that are passed to the model as `edge_index`.
train_data.edge_index

In [60]:
# Indexing the column with a list ([0]) keeps the edge dimension, so the result stays 2-D.
val_data.edge_label_index[:, [0]].shape

In [57]:
# Source and target node ids of the first labelled validation edge, kept as a 2-D column.
val_data.edge_label_index[:, [0]]

In [None]:
# explanation_type is either "model" or "phenomenon". It selects what the explainer's loss is computed against:
#   "model"      - the model output;
#   "phenomenon" - the target output.

In [27]:
# Explain the model's own output for one validation edge (column 0 of the labelled edges).
edge_label_index = val_data.edge_label_index[:, 0]

gnn_explainer = GNNExplainer(epochs=200)
explainer = Explainer(
    model=model,
    algorithm=gnn_explainer,
    explanation_type='model',
    model_config=model_config,
    node_mask_type='attributes',
    edge_mask_type='object',
)
# The training split's features and edges are the graph the explanation is computed on.
explanation = explainer(
    train_data.x,
    train_data.edge_index,
    edge_label_index=edge_label_index,
)
print(f'Generated model explanations in {explanation.available_explanations}')

In [28]:
# Distinct values taken by the edge mask of the most recent explanation.
explanation.edge_mask.unique()

In [29]:
# Shape of the node mask produced with node_mask_type='attributes'
# (presumably the same shape as the feature matrix x - confirm against the output).
explanation.node_mask.shape

In [30]:
# Explain a chosen target (the phenomenon) for the same validation edge:
# the target is that edge's label, as a one-element integer tensor.
edge_label_index = val_data.edge_label_index[:, 0]
target = val_data.edge_label[0].unsqueeze(0).to(torch.long)

gnn_explainer = GNNExplainer(epochs=200)
explainer = Explainer(
    model=model,
    algorithm=gnn_explainer,
    explanation_type='phenomenon',
    model_config=model_config,
    node_mask_type='attributes',
    edge_mask_type='object',
)
explanation = explainer(
    train_data.x,
    train_data.edge_index,
    target=target,
    edge_label_index=edge_label_index,
)
available_explanations = explanation.available_explanations
print(f'Generated phenomenon explanations in {available_explanations}')

In [34]:
# IPython automagic (%pwd): shows the kernel's working directory, which the relative
# '../data' and '../outputs' paths in this notebook are resolved against.
pwd

In [62]:
# Collect every saved curve (.npy) of the Cora / GCN / GNNExplainer run.
# glob returns matches in no guaranteed order, so sort the paths: later cells index
# `res` by position (e.g. res[3]) and must hit the same file on every run and machine.
mylist = sorted(glob.glob("../outputs/cora/gcn/gnnexplainer/curves/*.npy"))

# Load with a comprehension so no loop variable leaks into the notebook namespace;
# the previous loop variable `path` overwrote the dataset root `path` defined earlier.
res = [np.load(curve_path) for curve_path in mylist]

# Show a compact summary (one shape per loaded curve) instead of dumping every array.
[curve.shape for curve in res]

In [69]:
# Inspect the fourth loaded curve; which file this is follows the order of `mylist`.
res[3]

In [81]:
def random_line(n, c0, cf, x):
    """Evaluate the straight reference line at step ``n``.

    The line equals ``c0`` at ``n == 0`` and changes by ``(cf - c0) / x`` per
    step, so it would reach ``cf`` at ``n == x``.

    Args:
        n: Step index; a scalar or a NumPy array of steps.
        c0: Value of the line at step 0.
        cf: Value the line reaches at step ``x``.
        x: Number of steps over which the line goes from ``c0`` to ``cf``.
    """
    return c0 + (cf - c0) / x * n


def compute_upper_area(c0, cf):
    """Return ``0.5 * (2 - c0 - cf)``.

    For endpoints inside [0, 1] this is the share of the unit square that lies
    above the straight line joining ``c0`` and ``cf``.
    """
    return 0.5 * (2 - c0 - cf)


def _area_ratio(area, region):
    """Return ``area / region``, or 0.0 when ``region`` is empty (``<= 0``)."""
    return area / region if region > 0 else 0.0


def linear_area_score(deletion_curve, normalize=False):
    """Score a curve against the straight line running from its first value toward its last.

    The mean gap where the curve lies below the line (``Am``) and the mean gap
    where it lies above the line (``Ap``) are each divided by the size of the
    region below / above the line, and the result is ``Am/lower - Ap/upper``.
    A curve that stays below the line therefore scores positive, one that
    stays above it scores negative.

    Args:
        deletion_curve: 1-D sequence or array of curve values. The area
            normalisation assumes values in [0, 1] - TODO confirm for the
            saved curves.
        normalize: When true, the curve is first passed through
            ``normalize_bounds``.

    Returns:
        The score as a float. A term whose normalising region is empty (for
        example a curve that starts and ends at 1.0) contributes 0 instead of
        the NaN / inf a plain division would give.
    """
    deletion_curve = np.asarray(deletion_curve)
    if normalize:
        # NOTE(review): normalize_bounds is not defined anywhere in the visible
        # notebook; it must be provided before calling with normalize=True.
        deletion_curve = normalize_bounds(deletion_curve)
    c0 = deletion_curve[0]  # first value, e.g. 0.975
    cf = deletion_curve[-1]  # last value, e.g. 0.501
    norm = deletion_curve.shape[0]  # number of points, e.g. 1434
    # Reference line sampled at steps 0 .. norm-1 in one vectorised call. Its last
    # sample is one step short of cf because the slope is (cf - c0) / norm.
    random_baseline = random_line(np.arange(norm), c0, cf, norm)
    upper_area = compute_upper_area(c0, cf)
    lower_area = 1. - upper_area
    Ap = np.maximum(deletion_curve - random_baseline, 0).sum() / norm
    Am = np.maximum(random_baseline - deletion_curve, 0).sum() / norm

    return _area_ratio(Am, lower_area) - _area_ratio(Ap, upper_area)

In [80]:
# Last value of the fourth curve, i.e. the `cf` endpoint that linear_area_score reads.
res[3][-1]

In [82]:
# Score the fourth loaded curve (no normalisation).
linear_area_score(res[3])

In [76]:
# Print the linear area score of every loaded curve, in the order they were loaded.
for deletion_curve in res:
    score = linear_area_score(deletion_curve)
    print(score)