Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GNNExplainer with GCN #3508

Merged
merged 6 commits into from
Nov 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 11 additions & 9 deletions examples/gnn_explainer.py
Expand Up @@ -9,39 +9,41 @@

dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
transform = T.Compose([T.GCNNorm(), T.NormalizeFeatures()])
dataset = Planetoid(path, dataset, transform=transform)
data = dataset[0]


class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCNConv(dataset.num_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)
self.conv1 = GCNConv(dataset.num_features, 16, normalize=False)
self.conv2 = GCNConv(16, dataset.num_classes, normalize=False)

def forward(self, x, edge_index):
x = F.relu(self.conv1(x, edge_index))
def forward(self, x, edge_index, edge_weight):
x = F.relu(self.conv1(x, edge_index, edge_weight))
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
x = self.conv2(x, edge_index, edge_weight)
return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index = data.x, data.edge_index
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight

for epoch in range(1, 201):
model.train()
optimizer.zero_grad()
log_logits = model(x, edge_index)
log_logits = model(x, edge_index, edge_weight)
loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()

explainer = GNNExplainer(model, epochs=200, return_type='log_prob')
node_idx = 10
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index,
edge_weight=edge_weight)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=data.y)
plt.show()
21 changes: 13 additions & 8 deletions examples/gnn_explainer_ba_shapes.py
Expand Up @@ -8,8 +8,9 @@
from torch_geometric.datasets import BAShapes
from torch_geometric.nn import GCN, GNNExplainer
from torch_geometric.utils import k_hop_subgraph
import torch_geometric.transforms as T

dataset = BAShapes()
dataset = BAShapes(transform=T.GCNNorm())
data = dataset[0]

idx = torch.arange(data.num_nodes)
Expand All @@ -18,14 +19,14 @@
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
model = GCN(data.num_node_features, hidden_channels=20, num_layers=3,
out_channels=dataset.num_classes).to(device)
out_channels=dataset.num_classes, normalize=False).to(device)
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.005)


def train():
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
out = model(data.x, data.edge_index, data.edge_weight)
loss = F.cross_entropy(out[train_idx], data.y[train_idx])
torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
loss.backward()
Expand All @@ -36,7 +37,7 @@ def train():
@torch.no_grad()
def test():
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=-1)
pred = model(data.x, data.edge_index, data.edge_weight).argmax(dim=-1)

train_correct = int((pred[train_idx] == data.y[train_idx]).sum())
train_acc = train_correct / train_idx.size(0)
Expand All @@ -56,14 +57,18 @@ def test():

model.eval()
targets, preds = [], []
explainer = GNNExplainer(model, epochs=300, return_type='raw', log=False)
expl = GNNExplainer(model, epochs=300, return_type='raw', log=False)

# Explanation ROC AUC over all test nodes:
self_loop_mask = data.edge_index[0] != data.edge_index[1]
for node_idx in tqdm(data.expl_mask.nonzero(as_tuple=False).view(-1).tolist()):
_, edge_mask = explainer.explain_node(node_idx, data.x, data.edge_index)
_, expl_edge_mask = expl.explain_node(node_idx, data.x, data.edge_index,
edge_weight=data.edge_weight)
subgraph = k_hop_subgraph(node_idx, num_hops=3, edge_index=data.edge_index)
targets.append(data.edge_label[subgraph[3]].cpu())
preds.append(edge_mask[subgraph[3]].cpu())
expl_edge_mask = expl_edge_mask[self_loop_mask]
subgraph_edge_mask = subgraph[3][self_loop_mask]
targets.append(data.edge_label[subgraph_edge_mask].cpu())
preds.append(expl_edge_mask[subgraph_edge_mask].cpu())

auc = roc_auc_score(torch.cat(targets), torch.cat(preds))
print(f'Mean ROC AUC: {auc:.4f}')