In [1]:
import torch
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
from torch_geometric.nn import GNNExplainer
from tqdm import tqdm



class GNNExplainerUpdated(GNNExplainer):
    """GNNExplainer extended with a graph-level explanation routine.

    Construction is inherited unchanged from ``GNNExplainer``; this subclass
    only adds :meth:`explain_graph`. Adapted from the reference implementation:
    https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/gnn_explainer.py
    """

    def explain_graph(self, x, edge_index, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role to explain the prediction made by the GNN for a graph.

        The model is first evaluated once without masks to obtain the label
        to explain; both masks are then optimised with Adam for
        ``self.epochs`` steps against that label.

        Args:
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            **kwargs (optional): Additional arguments passed to the GNN module.

        Returns:
            The sigmoid of the learned node feature mask and the sigmoid of
            the learned edge mask, both detached from the autograd graph.

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """

        self.model.eval()
        self.__clear_masks__()

        # All nodes belong to the same (single) graph, so every node gets
        # batch index 0.
        batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

        # Get the initial prediction; no gradients are needed for it.
        with torch.no_grad():
            out = self.model(x=x, edge_index=edge_index, batch=batch, **kwargs)
            log_logits = self.__to_log_prob__(out)
            pred_label = log_logits.argmax(dim=-1)

        self.__set_masks__(x, edge_index)

        pbar = None
        try:
            self.to(x.device)

            optimizer = torch.optim.Adam(
                [self.node_feat_mask, self.edge_mask], lr=self.lr)

            if self.log:  # pragma: no cover
                pbar = tqdm(total=self.epochs)
                pbar.set_description('Explain graph')

            for _ in range(self.epochs):
                optimizer.zero_grad()
                # Scale every feature column by its (sigmoid-squashed) mask.
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
                out = self.model(x=h, edge_index=edge_index, batch=batch,
                                 **kwargs)
                log_logits = self.__to_log_prob__(out)
                # -1 is the node index handed to the inherited loss helper,
                # as in the referenced implementation.
                loss = self.__loss__(-1, log_logits, pred_label)
                loss.backward()
                optimizer.step()

                if pbar is not None:  # pragma: no cover
                    pbar.update(1)

            # Copy the results out before the masks are cleared below.
            node_feat_mask = self.node_feat_mask.detach().sigmoid()
            edge_mask = self.edge_mask.detach().sigmoid()
        finally:
            # Always undo what __set_masks__ installed and release the
            # progress bar, even if the optimisation raised part-way through.
            if pbar is not None:  # pragma: no cover
                pbar.close()
            self.__clear_masks__()

        return node_feat_mask, edge_mask

In [3]:
from models import DeeperGCN

# NOTE(review): the arguments look like (number of layers, hidden width, device),
# which matches the printed module below (5 DeepLayer entries, 256-wide input
# encoder) -- confirm against models.DeeperGCN.
model = DeeperGCN(5, 256, device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved on a GPU still loads when this notebook runs on a CPU-only machine.
model.load_state_dict(torch.load('gnn.pt', map_location=device))
model.eval()

DeeperGCN(
  (layers): ModuleList(
    (0): DeepLayer(
      (layer): DeepGCNLayer(block=res+)
      (encoder): Linear(in_features=256, out_features=128, bias=True)
    )
    (1): DeepLayer(
      (layer): DeepGCNLayer(block=res+)
      (encoder): Linear(in_features=128, out_features=64, bias=True)
    )
    (2): DeepLayer(
      (layer): DeepGCNLayer(block=res+)
      (encoder): Linear(in_features=64, out_features=32, bias=True)
    )
    (3): DeepLayer(
      (layer): DeepGCNLayer(block=res+)
      (encoder): Linear(in_features=32, out_features=16, bias=True)
    )
    (4): DeepLayer(
      (layer): DeepGCNLayer(block=res+)
      (encoder): Linear(in_features=16, out_features=8, bias=True)
    )
  )
  (in_encoder): Linear(in_features=6, out_features=256, bias=True)
  (lin): Linear(in_features=8, out_features=2, bias=True)
)

In [4]:
from graph_data import GraphData
from utils.mesh import Mesh

graph_data = GraphData()

# Number of per-node slots used by the accumulators below.
# NOTE(review): presumably `points` counts flattened xyz coordinates, hence
# the division by 3 -- confirm against GraphData.
length = graph_data.points // 3

# Global accumulators filled across all explained samples:
# node_relevance[i] collects every edge relevance seen at node i,
# node_count[i] counts how many values were collected for node i.
node_relevance = [[] for _ in range(length)]
node_count = [0] * length

def node_relevance_mean(node_relevance, node_count):
    """Return the mean relevance of every node.

    Args:
        node_relevance: Sequence where entry ``i`` holds the relevance values
            collected for node ``i``.
        node_count: Sequence where entry ``i`` is the number of values
            collected for node ``i``.

    Returns:
        A list with one mean per entry of ``node_count``. A node with a count
        of 0 (no value collected) gets 0.0 instead of raising
        ``ZeroDivisionError``.
    """
    _node_relevance_mean = []
    for idx, count in enumerate(node_count):
        if count:
            mean = sum(node_relevance[idx]) / count
        else:
            # Nothing was collected for this node; there is no mean to compute.
            mean = 0.0
        _node_relevance_mean.append(mean)

    return _node_relevance_mean

def node_relevance_std(node_relevance, node_count):
    """Return the population standard deviation of every node's relevance.

    Args:
        node_relevance: Sequence where entry ``i`` holds the relevance values
            collected for node ``i``.
        node_count: Sequence where entry ``i`` is the number of values
            collected for node ``i``.

    Returns:
        A list with one standard deviation per entry of ``node_count``
        (variance is divided by the count, not by count - 1). A node with a
        count of 0 gets 0.0 instead of raising ``ZeroDivisionError``.
    """
    _node_relevance_std = []
    for idx, count in enumerate(node_count):
        if not count:
            # Nothing was collected for this node; report zero spread.
            _node_relevance_std.append(0.0)
            continue

        values = node_relevance[idx]
        mean = sum(values) / count
        var = sum((value - mean) ** 2 for value in values) / count
        _node_relevance_std.append(var ** 0.5)

    return _node_relevance_std

def calculate_node_relevance(edge_index, edge_mask):
    """Spread each edge's relevance onto the two nodes that the edge connects.

    Besides returning the per-sample result, the values are also merged into
    the module-level ``node_relevance`` and ``node_count`` accumulators.

    Args:
        edge_index: Iterable of edges; ``edge[0]`` and ``edge[1]`` are the
            indices of the edge's two end nodes.
        edge_mask: Iterable of relevance values, one per edge, in the same
            order as ``edge_index``.

    Returns:
        A tuple ``(_node_relevance, _node_count)`` for this sample only:
        the list of relevance values collected at each node, and how many
        values each node received.
    """
    global node_relevance
    global node_count

    # Size both per-sample buffers exactly like the global accumulators so the
    # zip() merges below never silently drop trailing nodes.
    _node_relevance = [[] for _ in range(length)]
    _node_count = [0] * length

    for edge, relevance in zip(edge_index, edge_mask):
        for node in (edge[0], edge[1]):
            _node_relevance[node].append(relevance)
            _node_count[node] += 1

    # update global node_relevance and node_count
    node_relevance = [a + b for a, b in zip(node_relevance, _node_relevance)]
    node_count = [a + b for a, b in zip(node_count, _node_count)]

    return (_node_relevance, _node_count)
        
    

# Each explanation runs 100 optimisation steps of the mask learner.
explainer = GNNExplainerUpdated(model, epochs=100, return_type='log_prob')

# `count` tracks how many graphs have been explained so far;
# `epochs` is the number of graphs to explain (not training epochs).
count = 0
epochs = 25

while count < epochs:
    random_sample = graph_data.return_random_graph().to(device)

    # Only graphs labelled 1 are explained; draw again for anything else.
    if not (random_sample.y == 1):
        continue

    _node_feat_mask, _edge_mask = explainer.explain_graph(
        random_sample.x,
        random_sample.edge_index,
    )

    # Everything below is plain Python / NumPy work, so move the data to the CPU.
    random_sample = random_sample.to('cpu')

    # Transpose so that iterating yields one edge (pair of node indices) at a time.
    edge_index = random_sample.edge_index.numpy().T
    edge_mask = _edge_mask.to('cpu').numpy()

    # Per-sample statistics; the call also feeds the global accumulators.
    _node_relevance, _node_count = calculate_node_relevance(edge_index, edge_mask)
    _node_relevance_mean = node_relevance_mean(_node_relevance, _node_count)
    _node_relevance_std = node_relevance_std(_node_relevance, _node_count)

    count += 1
    print(f'{count}/{epochs}')

    # Write the per-node mean and std of this sample onto its own mesh.
    mesh = Mesh(
        verts=random_sample.pos.numpy(),
        connectivity=random_sample.face.numpy().T,
    )
    mesh.writeVTU(
        filename=f'gnn_explainer_results/random_sample_mean_{count}.vtu',
        scalars=_node_relevance_mean,
    )
    mesh.writeVTU(
        filename=f'gnn_explainer_results/random_sample_std_{count}.vtu',
        scalars=_node_relevance_std,
    )
        

Explain graph: 100%|██████████| 100/100 [00:07<00:00, 13.38it/s]
  self.normals[i, :] = n/np.linalg.norm(n)
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

1/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.15it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.90it/s]

2/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.14it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.63it/s]

3/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.12it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

4/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.15it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.63it/s]

5/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.12it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.63it/s]

6/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

7/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.64it/s]

8/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.14it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.63it/s]

9/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.66it/s]

10/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.15it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

11/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

12/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.63it/s]

13/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.88it/s]

14/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.63it/s]

15/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.12it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

16/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.12it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

17/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.62it/s]

18/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.12it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 17.04it/s]

19/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.15it/s]
Explain graph:   2%|▏         | 2/100 [00:00<00:05, 16.63it/s]

20/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

21/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.15it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

22/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

23/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.14it/s]
Explain graph:   0%|          | 0/100 [00:00<?, ?it/s]

24/25


Explain graph: 100%|██████████| 100/100 [00:06<00:00, 16.13it/s]


25/25


In [5]:
# Statistics over every explained sample, taken from the global accumulators.
# The results are stored under their own names: assigning them back to
# `node_relevance_mean` / `node_relevance_std` would replace the helper
# functions with lists and break any re-run of this cell or of the loop above.
final_node_relevance_mean = node_relevance_mean(node_relevance, node_count)
final_node_relevance_std = node_relevance_std(node_relevance, node_count)

# NOTE(review): the geometry comes from the last sampled graph; this assumes
# every sample shares the same node ordering / mesh topology -- confirm.
mesh = Mesh(verts=random_sample.pos.numpy(), connectivity=random_sample.face.numpy().T)
mesh.writeVTU(filename='gnn_explainer_results/random_sample_final_mean.vtu', scalars=final_node_relevance_mean)
mesh.writeVTU(filename='gnn_explainer_results/random_sample_final_std.vtu', scalars=final_node_relevance_std)