In [3]:
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.loader import DataLoader

import cider
import models
import utils

import nci
import ba
import mutag

In [14]:
# Training split of the Mutagenicity dataset, plus its first 8 graphs as a tiny
# subset for quick single-graph experiments.
train_set = mutag.Mutagenicity(mode="training")
train_set_mini = train_set[0:8]

# Seed the shuffling so batch composition (and anything computed from it) is
# reproducible under Restart & Run All.
LOADER_SEED = 0
# Pinned host memory only speeds up host-to-GPU copies; enabling it on a
# CPU-only machine just triggers a warning.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
    pin_memory=PIN_MEMORY,
    drop_last=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)
train_loader_mini = DataLoader(
    train_set_mini,
    batch_size=1,
    shuffle=True,
    pin_memory=PIN_MEMORY,
    drop_last=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

In [8]:
# Hidden widths of the CIDER explainer. They must match the tensors stored in
# explainer_para_path, because load_state_dict is strict by default and rejects
# size mismatches.
hidden_channels1 = 128
hidden_channels2 = 128
hidden_channels3 = 128

# The notebook was run on GPU index 3. Fall back to the first GPU, or to the CPU,
# so it still runs on machines with fewer GPUs or none at all.
CUDA_DEVICE_INDEX = 3
if torch.cuda.is_available() and torch.cuda.device_count() > CUDA_DEVICE_INDEX:
    device = torch.device(f"cuda:{CUDA_DEVICE_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# NOTE(review): task_model_para_path is not read by any cell in this notebook.
# The task model is a submodule of CIDER, so its weights are restored from the
# (strictly loaded) explainer checkpoint instead.
task_model_para_path = "./params/mutag_net.pt"
explainer_para_path = "./params/explainer_mutag.ckpt"

In [10]:
# Task model: the GCN graph classifier whose predictions CIDER explains. Its
# sizes must agree with the task_model.* tensors in the explainer checkpoint,
# which a later cell loads strictly.
task_model = models.GcnEncoderGraph(
    input_dim=train_set.num_features,
    num_layers=3,
    hidden_dim=50,
    embedding_dim=10,
    pred_hidden_dims=[10, 10],
    label_dim=2,
)

# CIDER explainer wrapping the task model. The three hidden widths come from
# the configuration cell.
explainer = cider.CIDER(
    train_set.num_features,
    task_model=task_model,
    hidden_channels1=hidden_channels1,
    hidden_channels2=hidden_channels2,
    hidden_channels3=hidden_channels3,
)

In [18]:
# Move every parameter and buffer of the explainer, including the wrapped
# task_model, to the configured device. Module.to works in place and returns
# the module, so the cell still displays its summary.
explainer.to(device=device)

CIDER(
  (gcn_shared): Sequential(
    (0): GCNConv(14, 128)
  )
  (gcn_mu_causal): GCNConv(128, 128)
  (gcn_mu_non_causal): GCNConv(128, 128)
  (gcn_logvar_causal): GCNConv(128, 128)
  (gcn_logvar_non_causal): GCNConv(128, 128)
  (decoder_causal): InnerProductDecoderMLP(
    (fc): Linear(in_features=128, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=128, bias=True)
  )
  (decoder_non_causal): InnerProductDecoderMLP(
    (fc): Linear(in_features=128, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=128, bias=True)
  )
  (task_model): GcnEncoderGraph(
    (conv_first): GCNConv(14, 50)
    (conv_block): ModuleList(
      (0): GCNConv(50, 50)
    )
    (conv_last): GCNConv(50, 10)
    (act): ReLU()
    (pred_model): Sequential(
      (0): Linear(in_features=110, out_features=10, bias=True)
      (1): ReLU()
      (2): Linear(in_features=10, out_features=10, bias=True)
      (3): ReLU()
      (4): Linear(in_features=10, out_feature

In [11]:
# Restore the trained explainer. Loading is strict and task_model is a CIDER
# submodule, so this also restores the task model's weights.
# map_location="cpu" lets a checkpoint saved on any GPU index deserialize on
# any machine; load_state_dict then copies each tensor onto the device of the
# corresponding module parameter.
explainer.load_state_dict(torch.load(explainer_para_path, map_location="cpu"))

<All keys matched successfully>

In [15]:
# Accuracy of the task model on the full training split. This uses a dedicated
# loader because train_loader shuffles with drop_last=True. Reusing it would
# skip up to 31 graphs (len(train_set) % 32), and which graphs are skipped
# changes between runs, so the reported number would not be reproducible.
train_eval_loader = DataLoader(
    train_set,
    batch_size=32,
    shuffle=False,
    drop_last=False,
)
# Switch the task model to inference mode in case it contains dropout or
# batch-norm layers (its full layer list is not visible in this notebook).
# NOTE(review): redundant if utils.evaluate_graphs_accuracy already calls .eval().
task_model.eval()
utils.evaluate_graphs_accuracy(train_eval_loader, task_model, device)

0.8181000899011088

In [16]:
# Node features and edge index of the first graph in the mini subset, placed on
# the same device as the explainer.
x, edge_index = (
    train_set_mini[0].x.to(device),
    train_set_mini[0].edge_index.to(device),
)

In [20]:
# Explanations for the first mini-subset graph. In the recorded output they come
# back as a dict keyed by the strings '0.9' ... '0.1' (presumably keep ratios or
# thresholds; confirm in cider.py).
# NOTE(review): every recorded entry has edge_index=[2, 58]. Compare with
# edge_index.shape of the input graph to confirm that edges are actually pruned.
# NOTE(review): if CIDER samples its latents while in training mode, call
# explainer.eval() first to get deterministic explanations. Confirm in cider.py.
explainations = explainer.get_explainations(x, edge_index, device=device)
# An assignment alone produces no cell output, so display the dict explicitly.
# Otherwise the recorded output cannot be reproduced by Restart & Run All.
explainations

{'0.9': Data(x=[29, 14], edge_index=[2, 58]),
 '0.8': Data(x=[29, 14], edge_index=[2, 58]),
 '0.7': Data(x=[29, 14], edge_index=[2, 58]),
 '0.6': Data(x=[29, 14], edge_index=[2, 58]),
 '0.5': Data(x=[29, 14], edge_index=[2, 58]),
 '0.4': Data(x=[29, 14], edge_index=[2, 58]),
 '0.3': Data(x=[29, 14], edge_index=[2, 58]),
 '0.2': Data(x=[29, 14], edge_index=[2, 58]),
 '0.1': Data(x=[29, 14], edge_index=[2, 58])}