Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade visualize_graph for explain module #8743

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion examples/explain/gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,6 @@ def forward(self, x, edge_index):
print(f"Feature importance plot has been saved to '{path}'")

path = 'subgraph.pdf'
explanation.visualize_graph(path)
explanation.visualize_graph(path, node_label=None, color_dict=None,
target_node=node_index, draw_node_idx=True)
print(f"Subgraph visualization plot has been saved to '{path}'")
14 changes: 14 additions & 0 deletions examples/explain/gnn_explainer_ba_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def test():

# Explanation ROC AUC over all test nodes:
targets, preds = [], []
# The color for each node class label. For BAshape dataset, according to GNNExplainer Fig.3
color_dict = {0: 'orange', 1: 'green', 2: 'green', 3: 'green'}

node_indices = range(400, data.num_nodes, 5)
for node_index in tqdm(node_indices, leave=False, desc='Train Explainer'):
target = data.y if explanation_type == 'phenomenon' else None
Expand All @@ -94,3 +97,14 @@ def test():

auc = roc_auc_score(torch.cat(targets), torch.cat(preds))
print(f'Mean ROC AUC (explanation type {explanation_type:10}): {auc:.4f}')

node_index = 500
explanation = explainer(data.x, data.edge_index, index=node_index,
target=data.y)
explanation.visualize_graph(
f"GNNExplainer_BAshapes_{node_index}.png",
node_label=data.y,
color_dict=color_dict,
target_node=node_index,
draw_node_idx=False,
)
5 changes: 3 additions & 2 deletions torch_geometric/explain/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def visualize_feature_importance(
return _visualize_score(score, feat_labels, path, top_k)

def visualize_graph(self, path: Optional[str] = None,
backend: Optional[str] = None):
backend: Optional[str] = None, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
backend: Optional[str] = None, **kwargs):
backend: Optional[str] = None,
nodel_label: Optional[Tensor] = None,
colors_dict: Optional[Dict[int, str] = None,
target_idx: Optional[int]=None):

Lets add the new arguments as optional arguemnts and add documentation for them. That way the end user is aware of the options available to them.

r"""Visualizes the explanation graph with edge opacity corresponding to
edge importance.

Expand All @@ -246,13 +246,14 @@ def visualize_graph(self, path: Optional[str] = None,
If set to :obj:`None`, will use the most appropriate
visualization backend based on available system packages.
(default: :obj:`None`)
**kwargs: include
"""
edge_mask = self.get('edge_mask')
if edge_mask is None:
raise ValueError(f"The attribute 'edge_mask' is not available "
f"in '{self.__class__.__name__}' "
f"(got {self.available_explanations})")
visualize_graph(self.edge_index, edge_mask, path, backend)
visualize_graph(self.edge_index, edge_mask, path, backend, **kwargs)


class HeteroExplanation(HeteroData, ExplanationMixin):
Expand Down
29 changes: 26 additions & 3 deletions torch_geometric/visualization/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def visualize_graph(
edge_weight: Optional[Tensor] = None,
path: Optional[str] = None,
backend: Optional[str] = None,
**kwargs,
) -> Any:
r"""Visualizes the graph given via :obj:`edge_index` and (optional)
:obj:`edge_weight`.
Expand Down Expand Up @@ -58,7 +59,8 @@ def visualize_graph(
backend = 'graphviz' if has_graphviz() else 'networkx'

if backend.lower() == 'networkx':
return _visualize_graph_via_networkx(edge_index, edge_weight, path)
return _visualize_graph_via_networkx(edge_index, edge_weight, path,
**kwargs)
elif backend.lower() == 'graphviz':
return _visualize_graph_via_graphviz(edge_index, edge_weight, path)

Expand Down Expand Up @@ -98,10 +100,15 @@ def _visualize_graph_via_networkx(
edge_index: Tensor,
edge_weight: Tensor,
path: Optional[str] = None,
**kwargs,
) -> Any:
import matplotlib.pyplot as plt
import networkx as nx

node_label = kwargs['node_label']
color_dict = kwargs['color_dict']
target_node = kwargs['target_node']

g = nx.DiGraph()
node_size = 800

Expand All @@ -127,10 +134,26 @@ def _visualize_graph_via_networkx(
),
)

node_color = ['white'] * len(g.nodes)
if node_label != None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if node_label is None won't all nodes be white?

assert color_dict != None
node_color = []
for i, node_id in enumerate(list(g.nodes)):
node_color.append(color_dict[int(node_label[node_id])])

if target_node != None:
for i, node_id in enumerate(list(g.nodes)):
if node_id == target_node:
print("kylin")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("kylin")

node_color[i] = 'red'
break

nodes = nx.draw_networkx_nodes(g, pos, node_size=node_size,
node_color='white', margins=0.1)
node_color=node_color, margins=0.1)
nodes.set_edgecolor('black')
nx.draw_networkx_labels(g, pos, font_size=10)

if kwargs['draw_node_idx'] == True:
nx.draw_networkx_labels(g, pos, font_size=10)

if path is not None:
plt.savefig(path)
Expand Down
Loading