In [1]:
import logging
from math import sqrt
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (
    ExplainerConfig,
    MaskType,
    ModelConfig,
    ModelMode,
    ModelTaskLevel,
)

from torch_geometric.explain.algorithm.base import ExplainerAlgorithm


In [2]:
from abc import abstractmethod
from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.explain import Explanation
from torch_geometric.explain.config import (
    ExplainerConfig,
    ModelConfig,
    ModelReturnType,
)
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import k_hop_subgraph

In [3]:
from torch_geometric.datasets import Entities
from torch_geometric.nn import FastRGCNConv, RGCNConv
from torch_geometric.utils import k_hop_subgraph

path = '/Users/macoftraopia/Documents/GitHub/RGCN-Explainer/../data/Entities' #osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities')
dataset = Entities(path, 'AIFB')
data = dataset[0]

# BGS and AM graphs are too big to process them in a full-batch fashion.
# Since our model does only make use of a rather small receptive field, we
# filter the graph to only contain the nodes that are at most 2-hop neighbors
# away from any training/test node.
node_idx = torch.cat([data.train_idx, data.test_idx], dim=0)
node_idx, edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx, 2, data.edge_index, relabel_nodes=True)

data.num_nodes = node_idx.size(0)
data.edge_index = edge_index
data.edge_type = data.edge_type[edge_mask]
data.train_idx = mapping[:data.train_idx.size(0)]
data.test_idx = mapping[data.train_idx.size(0):]
node_idx = 0

In [17]:
data.node_stores

[{'edge_index': tensor([[   0,    1,    1,  ..., 6908, 6908, 6908],
         [  26, 2003, 6825,  ..., 3612, 3612, 3881]]), 'edge_type': tensor([13, 13, 13,  ...,  0,  5,  2]), 'train_idx': tensor([6348, 2772, 1274, 3084, 2295, 1023,  615, 1909,  529, 4696, 6412, 2129,
         4636, 5713, 6843, 4778, 4063,  396, 5200, 5278, 5462, 3954,  998, 5387,
         4666,    3, 4759, 4333, 1073, 6787, 3085, 5517, 6824, 2611, 4273, 2808,
         3867,  602, 3068, 6901, 4717,   59, 4354, 1130, 2853, 2363, 5927, 6021,
         1173, 2612, 4132, 5091,  826, 3764, 6867, 5821, 5293,  704, 6658, 5569,
         5014, 3022, 4979, 4256, 5982, 5148, 4165, 2852, 3087, 4812, 4091, 6366,
         6578, 1895, 4798, 3449, 5662, 2584, 5668, 1592, 6083, 3053,  350,  632,
         5697,  112, 2243, 2476, 5798, 6836, 3248, 4083, 3445, 4850, 5367, 3179,
         5090, 5667, 1818,   98, 5751, 1460, 6337, 2475, 2779, 5711, 2882,  367,
          919, 6463, 1654, 3914, 4039, 6809, 4709, 6461, 1917, 6218, 2001, 4161,
  

Model

In [4]:
import argparse
import os.path as osp

import torch
import torch.nn.functional as F



# Trade memory consumption for faster computation.
#if args.dataset in ['AIFB', 'MUTAG']:
RGCNConv = FastRGCNConv




class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = RGCNConv(data.num_nodes, 16, dataset.num_relations,
                              num_bases=30)
        self.conv2 = RGCNConv(16, dataset.num_classes, dataset.num_relations,
                              num_bases=30)

    def forward(self, edge_index, edge_type):
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)
    

    def forward(self, data):
        edge_index, edge_type = data.edge_index, data.edge_type
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)

    def forward(self, edge_type, edge_index):
        #edge_type edge_type.unsqueeze(1)
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')# if args.dataset == 'AM' else device
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.edge_type, data.edge_index )
    #out = model(data)
    loss = F.nll_loss(out[data.train_idx], data.train_y)
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.edge_type ,data.edge_index).argmax(dim=-1)
    #pred = model(data).argmax(dim=-1)
    train_acc = float((pred[data.train_idx] == data.train_y).float().mean())
    test_acc = float((pred[data.test_idx] == data.test_y).float().mean())
    torch.save(pred[data.test_idx], 'aifb_chk/prediction_aifb')
    return train_acc, test_acc


for epoch in range(1, 51):
    loss = train()
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f} '
          f'Test: {test_acc:.4f}')
    
torch.save(model, 'aifb_chk/model_aifb')    

Epoch: 01, Loss: 1.3934, Train: 0.9429 Test: 0.7500
Epoch: 02, Loss: 0.7704, Train: 0.9643 Test: 0.8333
Epoch: 03, Loss: 0.3287, Train: 0.9714 Test: 0.8611
Epoch: 04, Loss: 0.1438, Train: 0.9857 Test: 0.8611
Epoch: 05, Loss: 0.0820, Train: 0.9857 Test: 0.9167
Epoch: 06, Loss: 0.0562, Train: 0.9857 Test: 0.9167
Epoch: 07, Loss: 0.0392, Train: 0.9857 Test: 0.9167
Epoch: 08, Loss: 0.0240, Train: 1.0000 Test: 0.9167
Epoch: 09, Loss: 0.0117, Train: 1.0000 Test: 0.9167
Epoch: 10, Loss: 0.0063, Train: 1.0000 Test: 0.8889
Epoch: 11, Loss: 0.0054, Train: 1.0000 Test: 0.9167
Epoch: 12, Loss: 0.0047, Train: 1.0000 Test: 0.8611
Epoch: 13, Loss: 0.0034, Train: 1.0000 Test: 0.8611
Epoch: 14, Loss: 0.0020, Train: 1.0000 Test: 0.8889
Epoch: 15, Loss: 0.0011, Train: 1.0000 Test: 0.8889
Epoch: 16, Loss: 0.0006, Train: 1.0000 Test: 0.8889
Epoch: 17, Loss: 0.0004, Train: 1.0000 Test: 0.8889
Epoch: 18, Loss: 0.0003, Train: 1.0000 Test: 0.8889
Epoch: 19, Loss: 0.0002, Train: 1.0000 Test: 0.8889
Epoch: 20, L

K hop subgraph

In [5]:
def k_hop_subgraph(
    node_idx: int,#Union[int, List[int], Tensor],
    num_hops: int,
    edge_index: Tensor,
    relabel_nodes: bool = False,
    num_nodes: Optional[int] = None,
    flow: str = 'source_to_target',
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    
    #num_nodes = maybe_num_nodes(edge_index, num_nodes)
    num_nodes = num_nodes
    assert flow in ['source_to_target', 'target_to_source']
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    if isinstance(node_idx, (int, list, tuple)):
        node_idx = torch.tensor([node_idx], device=row.device).flatten()
    else:
        node_idx = node_idx.to(row.device)

    subsets = [node_idx]

    for _ in range(num_hops):
        node_mask.fill_(False)
        node_mask[subsets[-1]] = True
        torch.index_select(node_mask, 0, row, out=edge_mask)
        subsets.append(col[edge_mask])

    subset, inv = torch.cat(subsets).unique(return_inverse=True)
    inv = inv[:node_idx.numel()]

    node_mask.fill_(False)
    node_mask[subset] = True
    edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        node_idx = row.new_full((num_nodes, ), -1)
        node_idx[subset] = torch.arange(subset.size(0), device=row.device)
        edge_index = node_idx[edge_index]

    return subset, edge_index, inv, edge_mask

k_hop_subgraph(0,2, data.edge_index, num_nodes=data.num_nodes)

(tensor([   0,   26,  890, 1496, 2037, 2577, 4368, 5108, 5446, 5517, 6804, 6867]),
 tensor([[   0,   26,   26,   26,   26,   26,   26,   26,   26,   26,   26,   26,
            26,   26,  890, 1496, 1496, 1496, 2037, 2577, 2577, 4368, 4368, 5108,
          5108, 5446, 5517, 5517, 5517, 5517, 5517, 6804, 6804, 6804, 6867, 6867,
          6867, 6867],
         [  26,    0,  890, 1496, 2037, 2577, 4368, 5108, 5446, 5517, 5517, 6804,
          6867, 6867,   26,   26, 5517, 6867,   26,   26, 5517,   26, 5108,   26,
          4368,   26,   26,   26, 1496, 2577, 6804,   26, 5517, 6867,   26,   26,
          1496, 6804]]),
 tensor([0]),
 tensor([ True, False, False,  ..., False, False, False]))

Make masks

In [6]:
def _get_hard_masks(
    model: torch.nn.Module,
    index: Optional[Union[int, Tensor]],
    edge_index: Tensor,
    num_nodes: int,
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
    r"""Returns hard node and edge masks that only include the nodes and
    edges visited during message passing."""
    if index is None:
        return None, None  # Consider all nodes and edges.

    index, _, _, edge_mask = k_hop_subgraph(
        index,
        num_hops= 2,
        edge_index=edge_index,
        num_nodes=num_nodes,
        flow=ExplainerAlgorithm._flow(model),
    )

    node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
    node_mask[index] = True

    return node_mask, edge_mask

_get_hard_masks(model, index = node_idx, edge_index=data.edge_index, num_nodes=data.num_nodes)

(tensor([ True, False, False,  ..., False, False, False]),
 tensor([ True, False, False,  ..., False, False, False]))

Train

ExplainerAlgorithm : to inherit properties from 

In [7]:
from torch_geometric.data.data import Data, warn_or_raise
from typing import List, Optional
import copy

In [9]:
class Explanation(Data):
    r"""Holds all the obtained explanations of a homogenous graph.

    The explanation object is a :obj:`~torch_geometric.data.Data` object and
    can hold node-attribution, edge-attribution, feature-attribution. It can
    also hold the original graph if needed.

    Args:
        node_mask (Tensor, optional): Node-level mask with shape
            :obj:`[num_nodes]`. (default: :obj:`None`)
        edge_mask (Tensor, optional): Edge-level mask with shape
            :obj:`[num_edges]`. (default: :obj:`None`)
        node_feat_mask (Tensor, optional): Node-level feature mask with shape
            :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`)
        edge_feat_mask (Tensor, optional): Edge-level feature mask with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        **kwargs (optional): Additional attributes.
    """
    def __init__(
        self,
        node_mask: Optional[Tensor] = None,
        edge_mask: Optional[Tensor] = None,
        node_feat_mask: Optional[Tensor] = None,
        edge_feat_mask: Optional[Tensor] = None,
        **kwargs,
    ):
        super().__init__(
            node_mask=node_mask,
            edge_mask=edge_mask,
            node_feat_mask=node_feat_mask,
            edge_feat_mask=edge_feat_mask,
            **kwargs,
        )

    @property
    def available_explanations(self) -> List[str]:
        """Returns the available explanation masks."""
        return [
            key for key in self.keys
            if key.endswith('_mask') and self[key] is not None
        ]

    def validate(self, raise_on_error: bool = True) -> bool:
        r"""Validates the correctness of the explanation"""
        status = super().validate()

        if 'node_mask' in self and self.num_nodes != self.node_mask.size(0):
            status = False
            warn_or_raise(
                f"Expected a 'node_mask' with {self.num_nodes} nodes "
                f"(got {self.node_mask.size(0)} nodes)", raise_on_error)

        if 'edge_mask' in self and self.num_edges != self.edge_mask.size(0):
            status = False
            warn_or_raise(
                f"Expected an 'edge_mask' with {self.num_edges} edges "
                f"(got {self.edge_mask.size(0)} edges)", raise_on_error)

        if 'node_feat_mask' in self:
            if 'x' in self and self.x.size() != self.node_feat_mask.size():
                status = False
                warn_or_raise(
                    f"Expected a 'node_feat_mask' of shape "
                    f"{list(self.x.size())} (got shape "
                    f"{list(self.node_feat_mask.size())})", raise_on_error)
            elif self.num_nodes != self.node_feat_mask.size(0):
                status = False
                warn_or_raise(
                    f"Expected a 'node_feat_mask' with {self.num_nodes} nodes "
                    f"(got {self.node_feat_mask.size(0)} nodes)",
                    raise_on_error)

        if 'edge_feat_mask' in self:
            if ('edge_attr' in self
                    and self.edge_attr.size() != self.edge_feat_mask.size()):
                status = False
                warn_or_raise(
                    f"Expected an 'edge_feat_mask' of shape "
                    f"{list(self.edge_attr.size())} (got shape "
                    f"{list(self.edge_feat_mask.size())})", raise_on_error)
            elif self.num_edges != self.edge_feat_mask.size(0):
                status = False
                warn_or_raise(
                    f"Expected an 'edge_feat_mask' with {self.num_edges} "
                    f"edges (got {self.edge_feat_mask.size(0)} edges)",
                    raise_on_error)

        return status

    def get_explanation_subgraph(self) -> 'Explanation':
        r"""Returns the induced subgraph, in which all nodes and edges with
        zero attribution are masked out."""
        return self._apply_masks(
            node_mask=self.node_mask > 0 if 'node_mask' in self else None,
            edge_mask=self.edge_mask > 0 if 'edge_mask' in self else None,
        )

    def get_complement_subgraph(self) -> 'Explanation':
        r"""Returns the induced subgraph, in which all nodes and edges with any
        attribution are masked out."""
        return self._apply_masks(
            node_mask=self.node_mask == 0 if 'node_mask' in self else None,
            edge_mask=self.edge_mask == 0 if 'edge_mask' in self else None,
        )

    def _apply_masks(
        self,
        node_mask: Optional[Tensor] = None,
        edge_mask: Optional[Tensor] = None,
    ) -> 'Explanation':
        out = copy.copy(self)

        if edge_mask is not None:
            for key, value in self.items():
                if key == 'edge_index':
                    out.edge_index = value[:, edge_mask]
                elif self.is_edge_attr(key):
                    out[key] = value[edge_mask]

        if node_mask is not None:
            out = out.subgraph(node_mask)

        return out



In [14]:
A = Explanation(data, edge_mask=edge_mask)
A.available_explanations


['node_mask', 'edge_mask']

In [15]:
class ExplainerAlgorithm(torch.nn.Module):
    r"""Abstract base class for explainer algorithms."""
    @abstractmethod
    def forward(
        self,
        model: torch.nn.Module,
        edge_index: Tensor,
        *,
        target: Tensor,
        x = None, 
        index: Optional[Union[int, Tensor]] = None,
        target_index: Optional[int] = None,
        **kwargs,
    ) -> Explanation:
        r"""Computes the explanation.

        Args:
            model (torch.nn.Module): The model to explain.
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The input edge indices.
            target (torch.Tensor): The target of the model.
            index (Union[int, Tensor], optional): The index of the model
                output to explain. Can be a single index or a tensor of
                indices. (default: :obj:`None`)
            target_index (int, optional): The index of the model outputs to
                reference in case the model returns a list of tensors, *e.g.*,
                in a multi-task learning scenario. Should be kept to
                :obj:`None` in case the model only returns a single output
                tensor. (default: :obj:`None`)
            **kwargs (optional): Additional keyword arguments passed to
                :obj:`model`.
        """

    @abstractmethod
    def supports(self) -> bool:
        r"""Checks if the explainer supports the user-defined settings provided
        in :obj:`self.explainer_config` and :obj:`self.model_config`."""
        pass

    ###########################################################################

    @property
    def explainer_config(self) -> ExplainerConfig:
        r"""Returns the connected explainer configuration."""
        if not hasattr(self, '_explainer_config'):
            raise ValueError(
                f"The explanation algorithm '{self.__class__.__name__}' is "
                f"not yet connected to any explainer configuration. Please "
                f"call `{self.__class__.__name__}.connect(...)` before "
                f"proceeding.")
        return self._explainer_config

    @property
    def model_config(self) -> ModelConfig:
        r"""Returns the connected model configuration."""
        if not hasattr(self, '_model_config'):
            raise ValueError(
                f"The explanation algorithm '{self.__class__.__name__}' is "
                f"not yet connected to any model configuration. Please call "
                f"`{self.__class__.__name__}.connect(...)` before "
                f"proceeding.")
        return self._model_config

    def connect(
        self,
        explainer_config: ExplainerConfig,
        model_config: ModelConfig,
    ):
        r"""Connects an explainer and model configuration to the explainer
        algorithm."""
        self._explainer_config = ExplainerConfig.cast(explainer_config)
        self._model_config = ModelConfig.cast(model_config)

        if not self.supports():
            raise ValueError(
                f"The explanation algorithm '{self.__class__.__name__}' does "
                f"not support the given explanation settings.")

    # Helper functions ########################################################

    @staticmethod
    def _post_process_mask(
        mask: Optional[Tensor],
        num_elems: int,
        hard_mask: Optional[Tensor] = None,
        apply_sigmoid: bool = True,
    ) -> Optional[Tensor]:
        r""""Post processes any mask to not include any attributions of
        elements not involved during message passing."""
        if mask is None:
            return mask

        if mask.size(0) == 1:  # common_attributes:
            mask = mask.repeat(num_elems, 1)

        mask = mask.detach().squeeze(-1)

        if apply_sigmoid:
            mask = mask.sigmoid()

        if hard_mask is not None:
            mask[~hard_mask] = 0.

        return mask

    @staticmethod
    def _get_hard_masks(
        model: torch.nn.Module,
        index: Optional[Union[int, Tensor]],
        edge_index: Tensor,
        num_nodes: int,
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        r"""Returns hard node and edge masks that only include the nodes and
        edges visited during message passing."""
        if index is None:
            return None, None  # Consider all nodes and edges.

        index, _, _, edge_mask = k_hop_subgraph(
            index,
            num_hops=ExplainerAlgorithm._num_hops(model),
            edge_index=edge_index,
            num_nodes=num_nodes,
            flow=ExplainerAlgorithm._flow(model),
        )

        node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
        node_mask[index] = True

        return node_mask, edge_mask

    @staticmethod
    def _num_hops(model: torch.nn.Module) -> int:
        r"""Returns the number of hops the :obj:`model` is aggregating
        information from.
        """
        num_hops = 0
        for module in model.modules():
            if isinstance(module, MessagePassing):
                num_hops += 1
        return num_hops

    @staticmethod
    def _flow(model: torch.nn.Module) -> str:
        r"""Determines the message passing flow of the :obj:`model`."""
        for module in model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'

    def _to_log_prob(self, y: Tensor) -> Tensor:
        r"""Converts the model output to log-probabilities.

        Args:
            y (Tensor): The output of the model.
        """
        if self.model_config.return_type == ModelReturnType.probs:
            return y.log()
        if self.model_config.return_type == ModelReturnType.raw:
            return y.log_softmax(dim=-1)
        if self.model_config.return_type == ModelReturnType.log_probs:
            return y
        raise NotImplementedError


In [16]:
class GNNExplainer(ExplainerAlgorithm):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
    structures and node features that play a crucial role in the predictions
    made by a GNN.

    The following configurations are currently supported:

    - :class:`torch_geometric.explain.config.ModelConfig`

        - :attr:`task_level`: :obj:`"node"`, :obj:`"edge"`, or :obj:`"graph"`

    - :class:`torch_geometric.explain.config.ExplainerConfig`

        - :attr:`node_mask_type`: :obj:`"object"`, :obj:`"common_attributes"`
          or :obj:`"attributes"`

        - :attr:`edge_mask_type`: :obj:`"object"` or :obj:`None`

    .. note::

        For an example of using :class:`GNNExplainer`, see
        `examples/gnn_explainer.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/gnn_explainer.py>`_ and
        `examples/gnn_explainer_ba_shapes.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/gnn_explainer_ba_shapes.py>`_.

    Args:
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
    """

    coeffs = {
        'edge_size': 0.005,
        'edge_reduction': 'sum',
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
        'EPS': 1e-15,
    }

    def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.coeffs.update(kwargs)
        self.Features = False

        self.node_mask = self.edge_mask = None

    def supports(self) -> bool:
        task_level = self.model_config.task_level
        if task_level not in [
                ModelTaskLevel.node, ModelTaskLevel.edge, ModelTaskLevel.graph
        ]:
            logging.error(f"Task level '{task_level.value}' not supported")
            return False

        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(f"Edge mask type '{edge_mask_type.value}' not "
                          f"supported")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [
                MaskType.common_attributes, MaskType.object,
                MaskType.attributes
        ]:
            logging.error(f"Node mask type '{node_mask_type.value}' not "
                          f"supported.")
            return False

        return True

    def forward(
        self,
        model: torch.nn.Module,
        edge_index: Tensor,
        *,
        target: Tensor,
        x: Tensor = None,
        index: Optional[Union[int, Tensor]] = None,
        target_index: Optional[int] = None,
        **kwargs,
    ) -> Explanation:
        hard_node_mask = hard_edge_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            # We need to compute hard masks to properly clean up edges and
            # nodes attributions not involved during message passing:
            hard_node_mask, hard_edge_mask = self._get_hard_masks(
                model, index, edge_index, num_nodes=x.size(0))

        self._train(model, x, edge_index, target=target, index=index,
                    target_index=target_index, **kwargs)

        node_mask = self._post_process_mask(self.node_mask, x.size(0),
                                            hard_node_mask, apply_sigmoid=True)
        edge_mask = self._post_process_mask(self.edge_mask, edge_index.size(1),
                                            hard_edge_mask, apply_sigmoid=True)
        print(self.edge_mask)
        self._clean_model(model)

        # TODO Consider dropping differentiation between `mask` and `feat_mask`
        node_feat_mask = None
        if self.explainer_config.node_mask_type in {
                MaskType.attributes, MaskType.common_attributes
        }:
            node_feat_mask, node_mask = node_mask, None

        return Explanation(x=x, edge_index=edge_index, edge_mask=edge_mask,
                           node_mask=node_mask, node_feat_mask=node_feat_mask)

    def _train(
        self,
        model: torch.nn.Module,
        #x: Optional[Tensor],
        edge_index: Tensor,
        *,
        target: Tensor,
        x = None, 
        index: Optional[Union[int, Tensor]] = None,
        target_index: Optional[int] = None,
        **kwargs,
    ):
        if self.Features:
            self._initialize_masks(x, edge_index, features=True)

            parameters = [self.node_mask]  # We always learn a node mask.
            if self.explainer_config.edge_mask_type is not None:
                set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
                parameters.append(self.edge_mask)

            optimizer = torch.optim.Adam(parameters, lr=self.lr)

            for _ in range(self.epochs):
                optimizer.zero_grad()

                h = x * self.node_mask.sigmoid()
                y_hat, y = model(h, edge_index, **kwargs), target

                if target_index is not None:
                    y_hat, y = y_hat[target_index], y[target_index]
                if index is not None:
                    y_hat, y = y_hat[index], y[index]

                loss = self._loss(y_hat, y)

                loss.backward()
                optimizer.step()
        else:
            self._initialize_masks(edge_index) 
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
            parameters = [self.edge_mask]
            optimizer = torch.optim.Adam(parameters, lr=self.lr)

            for _ in range(self.epochs):
                optimizer.zero_grad()

                h = x * self.node_mask.sigmoid()
                y_hat, y = model(h, edge_index, **kwargs), target

                if target_index is not None:
                    y_hat, y = y_hat[target_index], y[target_index]
                if index is not None:
                    y_hat, y = y_hat[index], y[index]

                if self.Features:
                    loss = self._loss(y_hat, y, Features=True)
                else:
                    loss = self._loss(y_hat, y)

                loss.backward()
                optimizer.step()


    def _initialize_masks(self, edge_index: Tensor, x = None):
        if self.Features:
        
            node_mask_type = self.explainer_config.node_mask_type
            node_mask_type = MaskType.object
            edge_mask_type = self.explainer_config.edge_mask_type

            device = x.device
            (N, F), E = x.size(), edge_index.size(1)

            std = 0.1
            if node_mask_type == MaskType.object:
                self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
            elif node_mask_type == MaskType.attributes:
                self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
            elif node_mask_type == MaskType.common_attributes:
                self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
            else:
                raise NotImplementedError

            if edge_mask_type == MaskType.object:
                std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
                self.edge_mask = Parameter(torch.randn(E, device=device) * std)
            elif edge_mask_type is not None:
                raise NotImplementedError
        else:
            edge_mask_type = self.explainer_config.edge_mask_type
            E = edge_index.size(1)
            if edge_mask_type == MaskType.object:
                std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
                self.edge_mask = Parameter(torch.randn(E, device=device) * std)
                print('edge mask: ', self.edge_mask)
            elif edge_mask_type is not None:
                raise NotImplementedError



    def _loss_regression(self, y_hat: Tensor, y: Tensor) -> Tensor:
        return F.mse_loss(y_hat, y)

    def _loss_classification(self, y_hat: Tensor, y: Tensor) -> Tensor:
        if y.dim() == 0:  # `index` was given as an integer.
            y_hat, y = y_hat.unsqueeze(0), y.unsqueeze(0)

        y_hat = self._to_log_prob(y_hat)

        return (-y_hat).gather(1, y.view(-1, 1)).mean()

    def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
        if self.model_config.mode == ModelMode.regression:
            loss = self._loss_regression(y_hat, y)
        elif self.model_config.mode == ModelMode.classification:
            loss = self._loss_classification(y_hat, y)
        else:
            raise NotImplementedError

        if self.explainer_config.edge_mask_type is not None:
            m = self.edge_mask.sigmoid()
            edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
            loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
            ent = -m * torch.log(m + self.coeffs['EPS']) - (
                1 - m) * torch.log(1 - m + self.coeffs['EPS'])
            loss = loss + self.coeffs['edge_ent'] * ent.mean()

        if self.Features:
            m = self.node_mask.sigmoid()
            node_feat_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
            loss = loss + self.coeffs['node_feat_size'] * node_feat_reduce(m)
            ent = -m * torch.log(m + self.coeffs['EPS']) - (
                1 - m) * torch.log(1 - m + self.coeffs['EPS'])
            loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

    def _clean_model(self, model):
        clear_masks(model)
        self.node_mask = None
        self.edge_mask = None


In [55]:
class GNNExplainer_:
    r"""Deprecated version for :class:`GNNExplainer`."""

    coeffs = GNNExplainer.coeffs

    conversion_node_mask_type = {
        'feature': 'common_attributes',
        'individual_feature': 'attributes',
        'scalar': 'object',
    }

    conversion_return_type = {
        'log_prob': 'log_probs',
        'prob': 'probs',
        'raw': 'raw',
        'regression': 'raw',
    }

    def __init__(
        self,
        model: torch.nn.Module,
        epochs: int = 100,
        lr: float = 0.01,
        return_type: str = 'log_prob',
        feat_mask_type: str = 'feature',
        allow_edge_mask: bool = True,
        **kwargs,
    ):
        assert feat_mask_type in ['feature', 'individual_feature', 'scalar']

        explainer_config = ExplainerConfig(
            explanation_type='model',
            node_mask_type=self.conversion_node_mask_type[feat_mask_type],
            edge_mask_type=MaskType.object if allow_edge_mask else None,
        )
        model_config = ModelConfig(
            mode='regression'
            if return_type == 'regression' else 'classification',
            task_level=ModelTaskLevel.node,
            return_type=self.conversion_return_type[return_type],
        )

        self.model = model
        self._explainer = GNNExplainer(epochs=epochs, lr=lr, **kwargs)
        self._explainer.connect(explainer_config, model_config)

    @torch.no_grad()
    def get_initial_prediction(self, *args, **kwargs) -> Tensor:

        training = self.model.training
        self.model.eval()

        out = self.model(*args, **kwargs)
        if self._explainer.model_config.mode == ModelMode.classification:
            out = out.argmax(dim=-1)

        self.model.train(training)

        return out

    def explain_graph(
        self,
        x: Tensor,
        edge_index: Tensor,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        self._explainer.model_config.task_level = ModelTaskLevel.graph

        explanation = self._explainer(
            self.model,
            x,
            edge_index,
            target=self.get_initial_prediction(x, edge_index, **kwargs),
            **kwargs,
        )
        return self._convert_output(explanation, edge_index)

    def explain_node(
        self,
        node_idx: int,
        x: Tensor,
        edge_index: Tensor,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        self._explainer.model_config.task_level = ModelTaskLevel.node
        explanation = self._explainer(
            self.model,
            x,
            edge_index,
            target=self.get_initial_prediction(x, edge_index, **kwargs),
            index=node_idx,
            **kwargs,
        )
        return self._convert_output(explanation, edge_index, index=node_idx,
                                    x=x)
    
    def explain_rel_node(
        self,
        node_idx: int,
        edge_index: Tensor,
        edge_type: Tensor = None, 
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        self._explainer.model_config.task_level = ModelTaskLevel.node
        explanation = self._explainer(
            self.model,
            x,
            edge_index,
            target=self.get_initial_prediction(edge_index, edge_type, **kwargs),
            index=node_idx,
            **kwargs,
        )
        return self._convert_output(explanation, edge_index, index=node_idx,
                                    x=x)

    def _convert_output(self, explanation, edge_index, index=None, x=None):
        if 'node_mask' in explanation.available_explanations:
            node_mask = explanation.node_mask
        else:
            if (self._explainer.explainer_config.node_mask_type ==
                    MaskType.common_attributes):
                node_mask = explanation.node_feat_mask[0]
            else:
                node_mask = explanation.node_feat_mask

        edge_mask = None
        if 'edge_mask' in explanation.available_explanations:
            edge_mask = explanation.edge_mask
        else:
            if index is not None:
                _, edge_mask = self._explainer._get_hard_masks(
                    self.model, index, edge_index, num_nodes=x.size(0))
                edge_mask = edge_mask.to(x.dtype)
            else:
                edge_mask = torch.ones(edge_index.shape[1],
                                       device=edge_index.device)

        return node_mask, edge_mask


In [56]:
class Net(torch.nn.Module):
    def __init__(self,data):
        super().__init__()
        self.conv1 = RGCNConv(data.num_nodes, 16, dataset.num_relations,
                              num_bases=30)
        self.conv2 = RGCNConv(16, dataset.num_classes, dataset.num_relations,
                              num_bases=30)

    def forward(self, edge_index, edge_type):
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)
    

    def forward(self):
        edge_index, edge_type = data.edge_index, data.edge_type
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return F.log_softmax(x, dim=1)

    # def forward(self, edge_type, edge_index):
    #     #edge_type edge_type.unsqueeze(1)
    #     x = F.relu(self.conv1(None, edge_index, edge_type))
    #     x = self.conv2(x, edge_index, edge_type)
    #     return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')# if args.dataset == 'AM' else device
model, data = Net(data).to(device), data.to(device)

In [62]:
def explain_node(node_idx, model,data):
    explanation = GNNExplainer(
            model,
            edge_index,
            target= model(),
            index=node_idx
        )
    print(explanation)
    return _convert_output(explanation, edge_index, data, index=node_idx)
    

In [45]:
def _convert_output(explanation, edge_index,data,  index=0, x=None):

    # node_mask = explanation.node_mask

    # edge_mask = explanation.edge_mask
    # else:
    #if index is not None:
    _, edge_mask = _get_hard_masks(
            model, index, edge_index, num_nodes=data.num_nodes)
        #edge_mask = edge_mask.to(data.num_nodes.dtype)
    # else:
    #     edge_mask = torch.ones(edge_index.shape[1],
    #                             device=edge_index.device)

    return edge_mask


In [63]:
explain_node(0,model,data)

GNNExplainer(
  (epochs): Net(
    (conv1): RGCNConv(6909, 16, num_relations=90)
    (conv2): RGCNConv(16, 4, num_relations=90)
  )
)


tensor([ True, False, False,  ..., False, False, False])