In [13]:
import argparse
import glob
import io
import json
import os
import pickle
import shutil
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.data import Data, Batch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    GCNConv,
    GINConv,
    global_add_pool,
    global_max_pool,
    global_mean_pool
)
from torch_geometric.utils import k_hop_subgraph, subgraph

import torch_geometric.transforms as T

from syn_dataset import SynGraphDataset
from spmotif_dataset import *
from utils import *
from model import GIN
from train_baseline import test_epoch
from tell_sigmoid import LogicalLayer

from graphxai.explainers._base import _BaseExplainer
from graphxai.utils import Explanation, node_mask_from_edge_mask
from pgexplainer import PGExplainer
from gstarx import GStarX

# Utilities

In [14]:
def get_best_baseline_path(dataset_name):
    """Return the directory of the best baseline run stored for ``dataset_name``.

    Every ``results/<dataset_name>/*/results.json`` file is loaded as one row
    of a DataFrame. Rows are ranked by highest ``val_acc_mean``; ties are
    broken by lowest ``val_acc_std`` and then lowest ``test_acc_std``. The
    ranked table is printed as a side effect.

    Args:
        dataset_name: Name of the sub-directory of ``results/`` to scan.

    Returns:
        The directory that contains the winning ``results.json``, or ``None``
        when no results file is found.
    """
    result_files = glob.glob(f'results/{dataset_name}/*/results.json')
    records = []
    for path in result_files:
        # Context manager so the file handle is closed deterministically.
        with open(path) as fh:
            records.append(json.load(fh))
    df = pd.DataFrame(records)
    if df.shape[0] == 0:
        return None
    df['fname'] = result_files
    # Ascending on accuracy / descending on the std columns puts the best run last.
    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True, False, False])
    print(df)
    # dirname is separator-agnostic, unlike stripping a hard-coded '/results.json'.
    return os.path.dirname(df.iloc[-1]['fname'])

def inverse_sigmoid(x):
    """Return the logit of ``x`` (the inverse of the sigmoid), element-wise."""
    odds = x / (1 - x)
    return odds.log()

@torch.no_grad()
def find_logic_rules(w, t_in, t_out, activations=None, max_rule_len=10, max_rules=100, min_support=1):
    """Enumerate sets of input indices whose summed weights reach ``t_out``.

    Candidates are the indices with ``w > 1e-5`` (and, when ``activations`` is
    given, with column sum ``>= min_support``), visited in descending weight
    order. A depth-first search adds a combination as soon as its weight sum
    is ``>= t_out`` and does not extend it further.

    Args:
        w: 1-D weight tensor.
        t_in: Unused by the search; kept for interface compatibility.
        t_out: Single-element tensor holding the output threshold.
        activations: Optional 2-D tensor indexed as ``[sample, feature]``. A
            combination is skipped when fewer than ``min_support`` rows have
            all of its features active.
        max_rule_len: Maximum number of indices in a rule.
        max_rules: The search stops once this many rules have been collected.
        min_support: Minimum support used with ``activations``.

    Returns:
        A set of tuples of sorted indices into ``w``.
    """
    t_out = t_out.item()
    sorted_idxs = torch.argsort(w, 0, descending=True)
    mask = w > 1e-5
    if activations is not None:
        mask = mask & (activations.sum(0) >= min_support)

    rules = set()

    # Candidate indices that pass the mask, still in descending weight order.
    idxs_to_visit = sorted_idxs[mask[sorted_idxs]]
    if idxs_to_visit.numel() == 0:
        return rules

    sorted_weights = w[idxs_to_visit]
    n_candidates = idxs_to_visit.shape[0]
    current_combination = []  # positions into idxs_to_visit, shared across the recursion

    def _search(start, current_sum):
        # Stop once enough rules have been collected.
        if len(rules) >= max_rules:
            return
        if len(current_combination) > max_rule_len:
            return

        # A satisfying combination is recorded and not extended any further.
        if current_sum >= t_out:
            literal_ids = idxs_to_visit[current_combination].cpu().detach().tolist()
            rules.add(tuple(sorted(literal_ids)))
            return

        # At the depth limit every extension would be rejected on entry, so
        # skip the (potentially expensive) sweep over the remaining candidates.
        if len(current_combination) >= max_rule_len:
            return

        # Prune when even taking every remaining weight cannot reach t_out.
        if current_sum + sorted_weights[start:].sum() < t_out:
            return

        for pos in range(start, n_candidates):
            # Skip extensions whose literals are jointly active in too few samples.
            if (activations is not None and len(current_combination) > 0
                    and activations[:, idxs_to_visit[current_combination + [pos]]].all(-1).sum().item() < min_support):
                continue

            current_combination.append(pos)
            _search(pos + 1, current_sum + sorted_weights[pos])
            current_combination.pop()

    _search(0, 0)
    return rules


def extract_rules(self, feature=None, activations=None, max_rule_len=float('inf'), max_rules=100, min_support=1, out_threshold=0.5):
    """Run ``find_logic_rules`` for one or all output features of a layer.

    ``self`` is the layer object; the function reads ``self.weight``,
    ``self.phi_in.t``, ``self.b`` and ``self.out_features``. The per-output
    threshold is ``-self.b + logit(out_threshold)``.

    Args:
        feature: Index of a single output feature, or ``None`` for all of them.
        activations, max_rule_len, max_rules, min_support: Forwarded unchanged
            to ``find_logic_rules``.
        out_threshold: Value passed through ``inverse_sigmoid`` to build the
            output threshold.

    Returns:
        A list with one rule set per processed output feature.
    """
    weights = self.weight
    in_threshold = self.phi_in.t
    out_thresholds = -self.b + inverse_sigmoid(torch.tensor(out_threshold))

    targets = range(self.out_features) if feature is None else [feature]

    # The search runs on CPU copies of the per-output tensors.
    return [
        find_logic_rules(
            weights[out_idx].to('cpu'),
            in_threshold.to('cpu'),
            out_thresholds[out_idx].to('cpu'),
            activations,
            max_rule_len,
            max_rules,
            min_support,
        )
        for out_idx in targets
    ]



def get_blues_color(value):
    """Look ``value`` up in matplotlib's "Blues" colormap and return the colour."""
    return plt.get_cmap("Blues")(value)

def plot_activations(batch_ids, batch, attr, soft=False):
    """Draw one or more graphs of ``batch`` with nodes coloured by ``attr``.

    Args:
        batch_ids: A graph id or a list of graph ids (values of ``batch.batch``).
        batch: Batched graph object; ``batch.batch``, ``batch.edge_index`` and
            ``batch.y`` are read.
        attr: Per-node values indexed like ``batch.batch``; each value is
            mapped through the "Blues" colormap.
        soft: If true, ``attr`` is min-max rescaled (with a 1e-6 offset)
            before colouring.

    One subplot is drawn per graph id and the figure is shown with
    ``plt.show()``. Nodes without any edge are drawn as well.
    """
    if not isinstance(batch_ids, list):
        batch_ids = [batch_ids]
    if soft:
        attr = (attr - attr.min() + 1e-6) / (attr.max() - attr.min() + 1e-6)
    fig, axs = plt.subplots(1, len(batch_ids), figsize=(16 * len(batch_ids), 8))
    # plt.subplots returns a bare Axes for a single panel; normalise to an array.
    if type(axs) != np.ndarray:
        axs = np.array([axs])
    for ax, batch_id in zip(axs, batch_ids):
        node_mask = batch.batch == batch_id  # nodes that belong to this graph
        node_indices = torch.nonzero(node_mask, as_tuple=True)[0]

        # Keep only edges whose two endpoints are both inside this graph.
        edge_mask = node_mask[batch.edge_index[0]] & node_mask[batch.edge_index[1]]
        graph_edges = batch.edge_index[:, edge_mask]

        # Map batch-level node ids to 0..n-1 positions within this graph.
        local_id = {global_id: pos for pos, global_id in enumerate(node_indices.tolist())}
        remapped_edges = [(local_id[u], local_id[v]) for u, v in graph_edges.t().tolist()]

        G = nx.Graph()
        # Add every node explicitly so that nodes without edges are not dropped.
        G.add_nodes_from(range(len(local_id)))
        G.add_edges_from(remapped_edges)

        nx.set_node_attributes(G, {pos: global_id for global_id, pos in local_id.items()}, "original_id")

        # Slice the attribute once per graph instead of once per node; the
        # dtype is preserved because the colormap treats ints and floats differently.
        graph_attr = attr[node_mask].detach().cpu()
        node_colors = [get_blues_color(graph_attr[node]) for node in G.nodes]
        # NOTE(review): an empty border list is what the original passed to
        # `edgecolors`; kept to leave the rendering unchanged.
        node_borders = []

        pos = nx.kamada_kawai_layout(G)

        nx.draw(
            G, pos,
            node_color=node_colors,
            edgecolors=node_borders,
            node_size=700,
            with_labels=False,
            ax=ax
        )

        ax.set_title(f"Class = {batch.y[batch_id]}")
    plt.show()

def calculate_fidelity(data, node_mask, model, remove_nodes=True, top_k=None):
    """Compute Fidelity, InvFidelity and harmonic fidelity of a hard node mask.

    Args:
        data: A single graph with ``x`` and ``edge_index``; it is moved to the
            device of ``model``.
        node_mask: Per-node 0/1 mask, 1 marking explanation nodes. It is moved
            to the device of ``model`` so it can index the graph tensors.
        model: Callable as ``model(x, edge_index)`` returning logits for one graph.
        remove_nodes: If true, perturb by deleting nodes (induced subgraph);
            otherwise by zeroing their features.
        top_k: Unused; kept for interface compatibility.

    Returns:
        Dict with
        ``"Fidelity"``: 1.0 if the predicted class changes when the
        explanation nodes are removed/zeroed, else 0.0;
        ``"InvFidelity"``: 1.0 if the predicted class changes when only the
        explanation nodes are kept, else 0.0;
        ``"HFidelity"``: ``(1 + nF)(1 - nIF) / (2 + nF - nIF)`` with
        ``nF = Fidelity * sparsity`` and ``nIF = InvFidelity * (1 - sparsity)``.
    """
    device = next(model.parameters()).device
    data = data.to(device)
    x = data.x.float()
    edge_index = data.edge_index
    # The mask usually arrives on CPU; align it with the graph tensors.
    node_mask = torch.as_tensor(node_mask).to(device)

    # Sparsity = fraction of nodes that are NOT part of the explanation.
    total_nodes = x.shape[0]
    sparsity = 1 - node_mask.sum().item() / total_nodes if total_nodes > 0 else 0

    important = node_mask == 1
    unimportant = node_mask == 0

    def _predict(keep, drop):
        """Predict on the graph restricted to ``keep`` (or with ``drop`` zeroed)."""
        if remove_nodes:
            sub_edge_index, _ = subgraph(keep, edge_index=edge_index, num_nodes=total_nodes, relabel_nodes=True)
            return model(x[keep], sub_edge_index)
        masked_x = x.clone()
        masked_x[drop] = 0
        return model(masked_x, edge_index)

    # Only the arg-max classes are used, so no autograd graph is needed.
    with torch.no_grad():
        original_pred = model(x, edge_index)
        masked_pred = _predict(keep=unimportant, drop=important)
        retained_pred = _predict(keep=important, drop=unimportant)

    original_label = original_pred.argmax(-1)
    inv_fidelity = (original_label != retained_pred.argmax(-1)).float().item()
    fidelity = (original_label != masked_pred.argmax(-1)).float().item()

    # Normalised fidelity must be built from `fidelity`, not `inv_fidelity`.
    n_fidelity = fidelity * sparsity
    n_inv_fidelity = inv_fidelity * (1 - sparsity)

    denominator = 2 + n_fidelity - n_inv_fidelity
    hfidelity = ((1 + n_fidelity) * (1 - n_inv_fidelity)) / denominator if denominator != 0 else 0

    return {
        "Fidelity": fidelity,
        "InvFidelity": inv_fidelity,
        "HFidelity": hfidelity
    }


def calculate_fidelity_topk(data, node_soft_mask, model, top_k, remove_nodes=True):
    """Compute Fidelity / InvFidelity from the ``top_k`` highest-scoring nodes.

    Args:
        data: A single graph with ``x`` and ``edge_index``; it is moved to the
            device of ``model``.
        node_soft_mask: 1-D per-node importance scores. They are detached and
            moved to the device of ``model``.
        model: Callable as ``model(x, edge_index)`` returning logits for one graph.
        top_k: Number of nodes to perturb. ``torch.topk`` raises if it exceeds
            the number of nodes.
        remove_nodes: If true, perturb by deleting nodes (induced subgraph);
            otherwise by zeroing their features.

    Returns:
        Dict with
        ``"Fidelity"``: 1.0 if the predicted class changes once the ``top_k``
        highest-scoring nodes are removed/zeroed, else 0.0;
        ``"InvFidelity"``: 1.0 if the predicted class changes once only the
        ``N - top_k`` highest-scoring nodes are kept, else 0.0.
    """
    device = next(model.parameters()).device
    data = data.to(device)
    x = data.x.float()
    edge_index = data.edge_index
    # Scores usually arrive on CPU; align them with the graph tensors.
    node_soft_mask = torch.as_tensor(node_soft_mask).detach().to(device)

    total_nodes = x.shape[0]

    def _top_mask(k):
        """Return a 0/1 mask marking the ``k`` highest-scoring nodes."""
        mask = torch.zeros_like(node_soft_mask)
        mask[torch.topk(node_soft_mask, k).indices] = 1
        return mask

    def _predict(keep, drop):
        """Predict on the graph restricted to ``keep`` (or with ``drop`` zeroed)."""
        if remove_nodes:
            sub_edge_index, _ = subgraph(keep, edge_index=edge_index, num_nodes=total_nodes, relabel_nodes=True)
            return model(x[keep], sub_edge_index)
        masked_x = x.clone()
        masked_x[drop] = 0
        return model(masked_x, edge_index)

    # Only the arg-max classes are used, so no autograd graph is needed.
    with torch.no_grad():
        original_pred = model(x, edge_index)

        top_mask = _top_mask(top_k)
        masked_pred = _predict(keep=top_mask == 0, drop=top_mask == 1)

        # NOTE(review): this keeps the N - top_k best nodes (drops the top_k
        # lowest-scoring ones) rather than keeping only the top_k — confirm
        # that this is the intended InvFidelity definition.
        rest_mask = _top_mask(total_nodes - top_k)
        retained_pred = _predict(keep=rest_mask == 1, drop=rest_mask == 0)

    original_label = original_pred.argmax(-1)
    inv_fidelity = (original_label != retained_pred.argmax(-1)).float().item()
    fidelity = (original_label != masked_pred.argmax(-1)).float().item()
    return {
        "Fidelity": fidelity,
        "InvFidelity": inv_fidelity,
    }

def node_imp_from_edge_imp(edge_index, n_nodes, edge_imp):
    """Average edge importances onto nodes.

    For every node id in ``range(n_nodes)`` the result holds the mean of
    ``edge_imp`` over the edges that have this node as either endpoint.
    A node that touches no edge gets the mean of an empty tensor (NaN).
    """
    sources, targets = edge_index[0], edge_index[1]
    node_imp = torch.zeros(n_nodes)
    for node in range(n_nodes):
        incident = (sources == node) | (targets == node)
        node_imp[node] = edge_imp[incident].mean()
    return node_imp
        
def generate_hard_masks(soft_mask):
    """Threshold a soft node mask at sparsity levels 0.5, 0.55, ... (step 0.05, below 1).

    For each level the threshold is the corresponding percentile of
    ``soft_mask``; nodes strictly above it get 1. If that selects no node, the
    mask falls back to "strictly above the minimum".

    Args:
        soft_mask: 1-D per-node scores (tensor or array-like). The scores are
            detached and copied to CPU before the percentiles are computed, so
            CUDA or gradient-tracking tensors are accepted.

    Returns:
        List of ``(sparsity_level, hard_mask)`` pairs, where ``sparsity_level``
        is a 0-dim tensor and ``hard_mask`` is an int tensor on CPU.
    """
    # NumPy cannot read CUDA or requires-grad tensors directly.
    soft_mask = torch.as_tensor(soft_mask).detach().cpu()
    scores = soft_mask.numpy()

    sparsity_levels = torch.arange(0.5, 1, 0.05)
    hard_masks = []
    for sparsity in sparsity_levels:
        threshold = float(np.percentile(scores, (sparsity * 100).numpy()))
        hard_mask = (soft_mask > threshold).int()
        if hard_mask.sum() == 0:
            # Nothing is strictly above the percentile (ties at the top):
            # keep everything above the minimum instead.
            # NOTE(review): a constant soft mask still yields an all-zero mask here.
            hard_mask = (soft_mask > soft_mask.min()).int()
        hard_masks.append(hard_mask)
    return list(zip(sparsity_levels, hard_masks))

# Model and Data Loading

In [15]:
# Configuration: which dataset / split seed of the baseline runs to explain.
dataset_name = 'MUTAG'
seed = 0

# Locate the best baseline run and fail with a clear message if none exists.
best_baseline_dir = get_best_baseline_path(dataset_name)
if best_baseline_dir is None:
    raise FileNotFoundError(f'No baseline results found under results/{dataset_name}/')
results_path = os.path.join(best_baseline_dir, str(seed))

# NOTE: pickle can execute arbitrary code; only load files produced by this project.
with open(os.path.join(results_path, 'data.pkl'), 'rb') as f:
    data = pickle.load(f)
with open(os.path.join(results_path, 'args.json'), 'r') as f:
    args = json.load(f)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dataset = get_dataset(dataset_name)
num_classes = dataset.num_classes
num_features = dataset.num_features
num_layers = args['num_layers']
hidden_dim = args['hidden_dim']

# Rebuild the baseline GIN and restore its best checkpoint.
model = GIN(num_classes=num_classes, num_features=num_features, num_layers=num_layers, hidden_dim=hidden_dim, dropout=0.1)
model.load_state_dict(torch.load(os.path.join(results_path, 'best.pt'), map_location=device))
model = model.to(device)

# Re-create the exact train/val/test split that was stored with the run.
train_indices = data['train_indices']
val_indices = data['val_indices']
test_indices = data['test_indices']
train_dataset = dataset[train_indices]
val_dataset = dataset[val_indices]
test_dataset = dataset[test_indices]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
val_acc = test_epoch(model, val_loader, device)
test_acc = test_epoch(model, test_loader, device)

# Output directory for all post-hoc explanation artefacts of this run.
os.makedirs(f'post_hoc/{dataset_name}/{seed}', exist_ok=True)

    val_acc_mean  test_acc_mean  val_acc_std  test_acc_std  \
0       0.931579       0.857895     0.065877      0.055755   
10      0.931579       0.800000     0.055755      0.064699   
6       0.936842       0.836842     0.059752      0.083954   
15      0.936842       0.821053     0.059752      0.066574   
14      0.942105       0.810526     0.063012      0.105846   
8       0.942105       0.847368     0.063012      0.083954   
1       0.942105       0.852632     0.063012      0.073601   
7       0.942105       0.878947     0.052338      0.092999   
4       0.942105       0.836842     0.052338      0.080204   
3       0.947368       0.831579     0.065643      0.069293   
9       0.947368       0.821053     0.042974      0.083033   
13      0.947368       0.836842     0.042974      0.063012   
11      0.952632       0.815789     0.067720      0.083218   
5       0.952632       0.826316     0.046084      0.078654   
2       0.952632       0.868421     0.038835      0.071263   
12      

# PGExplainer

In [16]:
def explain_pg_explainer(model, data):
    """Explain one graph with PGExplainer and score the explanation.

    Uses the notebook-level ``device``. Edge importances are averaged onto
    nodes, thresholded at several sparsity levels and evaluated with
    ``calculate_fidelity`` / ``calculate_fidelity_topk``.

    Returns:
        Dict with keys ``'data'``, ``'pred'`` (softmax probabilities as a NumPy
        array), ``'soft_mask'``, ``'res'`` (per-sparsity fidelity results plus
        the hard mask) and ``'res_topk'`` (results for k = 1, 3, 5).
    """
    # Move inputs first so the prediction below never mixes devices.
    data = data.to(device)
    model = model.to(device)
    x = data.x.float()
    r = {
        'data': data,
        'pred': model(x, data.edge_index).softmax(-1).detach().cpu().numpy(),
        'res': {}
    }
    # NOTE(review): a fresh PGExplainer is built per graph and no training call
    # is visible here — confirm get_explanation_graph works without prior training.
    pgexplainer = PGExplainer(model, explain_graph=True)
    exp = pgexplainer.get_explanation_graph(x=x, edge_index=data.edge_index)
    # Detach / CPU so the mask can be handed to NumPy, matching the IG explainer.
    soft_mask = node_imp_from_edge_imp(data.edge_index, data.x.shape[0], exp.edge_imp).detach().cpu()
    # Nodes without incident edges come back as NaN; treat them as unimportant.
    soft_mask[torch.isnan(soft_mask)] = 0
    r['soft_mask'] = soft_mask
    for sparsity, hard_mask in generate_hard_masks(soft_mask):
        sparsity = sparsity.item()
        r['res'][sparsity] = calculate_fidelity(data, hard_mask, model)
        r['res'][sparsity]['hard_mask'] = hard_mask

    r['res_topk'] = {k: calculate_fidelity_topk(data, soft_mask, model, k) for k in (1, 3, 5)}
    return r
    
# Explain every test graph with PGExplainer. Failures are reported instead of
# being silently dropped, so an empty or partial result file is noticeable.
pg_explainer_results = []
pg_explainer_failures = 0
for data in tqdm(test_dataset):
    try:
        data = data.to(device)
        data.x = data.x.float()
        r = explain_pg_explainer(model, data)
        pg_explainer_results.append(r)
    except Exception as e:
        pg_explainer_failures += 1
        tqdm.write(f'PGExplainer failed on a test graph: {e!r}')
if pg_explainer_failures:
    print(f'PGExplainer: {pg_explainer_failures} graph(s) failed, {len(pg_explainer_results)} explained')


# Persist the results and read them back to verify the file round-trips.
with open(f'post_hoc/{dataset_name}/{seed}/pgexplainer.pkl', 'wb') as f:
    pickle.dump(pg_explainer_results, f)
with open(f'post_hoc/{dataset_name}/{seed}/pgexplainer.pkl', 'rb') as f:
    pg_explainer_results = pickle.load(f)


100%|██████████| 19/19 [00:00<00:00, 453.43it/s]


# GNNExplainer

In [17]:
from graphxai.explainers import GNNExplainer

def explain_gnn_explainer(model, data):
    """Explain one graph with GNNExplainer and score the explanation.

    Uses the notebook-level ``device``. Edge importances are averaged onto
    nodes, thresholded at several sparsity levels and evaluated with
    ``calculate_fidelity`` / ``calculate_fidelity_topk``.

    Returns:
        Dict with keys ``'data'``, ``'pred'`` (softmax probabilities as a NumPy
        array), ``'soft_mask'``, ``'res'`` (per-sparsity fidelity results plus
        the hard mask) and ``'res_topk'`` (results for k = 1, 3, 5).
    """
    # Move inputs first so the prediction below never mixes devices.
    data = data.to(device)
    model = model.to(device)
    x = data.x.float()
    r = {
        'data': data,
        'pred': model(x, data.edge_index).softmax(-1).detach().cpu().numpy(),
        'res': {}
    }
    gnnexplainer = GNNExplainer(model)
    exp = gnnexplainer.get_explanation_graph(x=x, edge_index=data.edge_index)
    # Detach / CPU so the mask can be handed to NumPy, matching the IG explainer.
    soft_mask = node_imp_from_edge_imp(data.edge_index, data.x.shape[0], exp.edge_imp).detach().cpu()
    # Nodes without incident edges come back as NaN; treat them as unimportant.
    soft_mask[torch.isnan(soft_mask)] = 0
    r['soft_mask'] = soft_mask
    for sparsity, hard_mask in generate_hard_masks(soft_mask):
        sparsity = sparsity.item()
        r['res'][sparsity] = calculate_fidelity(data, hard_mask, model)
        r['res'][sparsity]['hard_mask'] = hard_mask

    r['res_topk'] = {k: calculate_fidelity_topk(data, soft_mask, model, k) for k in (1, 3, 5)}
    return r
    
# Explain every test graph with GNNExplainer. Failures are reported instead of
# being silently dropped, so an empty or partial result file is noticeable.
gnn_explainer_results = []
gnn_explainer_failures = 0
for data in tqdm(test_dataset):
    data = data.to(device)
    data.x = data.x.float()
    try:
        r = explain_gnn_explainer(model, data)
        gnn_explainer_results.append(r)
    except Exception as e:
        gnn_explainer_failures += 1
        tqdm.write(f'GNNExplainer failed on a test graph: {e!r}')
if gnn_explainer_failures:
    print(f'GNNExplainer: {gnn_explainer_failures} graph(s) failed, {len(gnn_explainer_results)} explained')

# Persist the results and read them back to verify the file round-trips.
with open(f'post_hoc/{dataset_name}/{seed}/gnnexplainer.pkl', 'wb') as f:
    pickle.dump(gnn_explainer_results, f)
with open(f'post_hoc/{dataset_name}/{seed}/gnnexplainer.pkl', 'rb') as f:
    gnn_explainer_results = pickle.load(f)

100%|██████████| 19/19 [00:12<00:00,  1.52it/s]


# Integrated Gradients (IG)

In [18]:
from graphxai.explainers import IntegratedGradExplainer

def explain_ig_explainer(model, data):
    """Explain one graph with Integrated Gradients and score the explanation.

    Uses the notebook-level ``device``. The explainer is asked to explain the
    class the model itself predicts. Node importances are thresholded at
    several sparsity levels and evaluated with ``calculate_fidelity`` /
    ``calculate_fidelity_topk``.

    Returns:
        Dict with keys ``'data'``, ``'pred'`` (softmax probabilities as a NumPy
        array), ``'soft_mask'``, ``'res'`` (per-sparsity fidelity results plus
        the hard mask) and ``'res_topk'`` (results for k = 1, 3, 5).
    """
    # Move inputs first so the prediction below never mixes devices.
    data = data.to(device)
    model = model.to(device)
    x = data.x.float()
    r = {
        'data': data,
        'pred': model(x, data.edge_index).softmax(-1).detach().cpu().numpy(),
        'res': {}
    }
    igexplainer = IntegratedGradExplainer(model, torch.nn.CrossEntropyLoss())
    predicted_label = torch.tensor(r['pred'].argmax(-1)).to(device)
    exp = igexplainer.get_explanation_graph(x=x, edge_index=data.edge_index, label=predicted_label)
    soft_mask = exp.node_imp.detach().cpu()
    r['soft_mask'] = soft_mask
    for sparsity, hard_mask in generate_hard_masks(soft_mask):
        sparsity = sparsity.item()
        r['res'][sparsity] = calculate_fidelity(data, hard_mask, model)
        r['res'][sparsity]['hard_mask'] = hard_mask
    r['res_topk'] = {k: calculate_fidelity_topk(data, soft_mask, model, k) for k in (1, 3, 5)}
    return r
    
# Explain every test graph with Integrated Gradients. Failures are reported
# instead of being silently dropped.
ig_explainer_results = []
ig_explainer_failures = 0
for data in tqdm(test_dataset):
    data = data.to(device)
    data.x = data.x.float()
    try:
        r = explain_ig_explainer(model, data)
        ig_explainer_results.append(r)
    except Exception as e:
        ig_explainer_failures += 1
        tqdm.write(f'IntegratedGradExplainer failed on a test graph: {e!r}')
if ig_explainer_failures:
    print(f'IntegratedGradExplainer: {ig_explainer_failures} graph(s) failed, {len(ig_explainer_results)} explained')


# Persist the results and read them back to verify the file round-trips.
with open(f'post_hoc/{dataset_name}/{seed}/ig.pkl', 'wb') as f:
    pickle.dump(ig_explainer_results, f)
with open(f'post_hoc/{dataset_name}/{seed}/ig.pkl', 'rb') as f:
    ig_explainer_results = pickle.load(f)

100%|██████████| 19/19 [00:04<00:00,  4.70it/s]


# GStarX

In [19]:
# Average class probabilities over the test graphs; passed to GStarX as `payoff_avg`.
preds = []
for data in test_dataset:
    try:
        data = data.to(device)
        data.x = data.x.float()
        # Only the values are needed, so do not keep autograd graphs alive.
        with torch.no_grad():
            pred = model(data.x, data.edge_index).softmax(-1)
        preds += [pred]
    except Exception as e:
        print(f'Prediction failed on a test graph: {e!r}')
preds = torch.concat(preds)
payoff_avg = preds.mean(0).tolist()
gstarx = GStarX(model, device, payoff_avg=payoff_avg)
gstarx_explainer_results = []
gstarx_explainer_failures = 0

# Explain every test graph with GStarX. Failures are reported instead of being
# silently dropped.
for data in tqdm(test_dataset):
    try:
        data = data.to(device)
        data.x = data.x.float()
        r = {
            'data': data,
            'pred': model(data.x.float(), data.edge_index).softmax(-1).detach().cpu().numpy(),
            'res': {}
        }
        soft_mask = torch.tensor(gstarx.explain(data, superadditive_ext=False, num_samples=5))
        r['soft_mask'] = soft_mask
        for sparsity, hard_mask in generate_hard_masks(soft_mask):
            level = sparsity.item()
            r['res'][level] = calculate_fidelity(data, hard_mask, model)
            r['res'][level]['hard_mask'] = hard_mask
        r['res_topk'] = {k: calculate_fidelity_topk(data, soft_mask, model, k) for k in (1, 3, 5)}
        gstarx_explainer_results.append(r)
    except Exception as e:
        gstarx_explainer_failures += 1
        tqdm.write(f'GStarX failed on a test graph: {e!r}')
if gstarx_explainer_failures:
    print(f'GStarX: {gstarx_explainer_failures} graph(s) failed, {len(gstarx_explainer_results)} explained')

# Persist the results and read them back to verify the file round-trips.
with open(f'post_hoc/{dataset_name}/{seed}/gstarx.pkl', 'wb') as f:
    pickle.dump(gstarx_explainer_results, f)
with open(f'post_hoc/{dataset_name}/{seed}/gstarx.pkl', 'rb') as f:
    gstarx_explainer_results = pickle.load(f)


100%|██████████| 19/19 [00:28<00:00,  1.52s/it]


# SubgraphX

In [20]:
from graphxai.explainers import SubgraphX
subgraphx_explainer = SubgraphX(model, sample_num=10)

# Explain every test graph with SubgraphX. Failures are reported instead of
# being silently dropped.
subgraphx_explainer_results = []
subgraphx_explainer_failures = 0
for data in tqdm(test_dataset):
    try:
        data = data.to(device)
        data.x = data.x.float()
        r = {
            'data': data,
            'pred': model(data.x.float(), data.edge_index).softmax(-1).detach().cpu().numpy(),
            'res': {}
        }
        # Explain the class the model itself predicts; the label is placed on
        # the same device as the inputs, as in the Integrated Gradients cell.
        predicted_label = torch.tensor(r['pred'].argmax(-1)).to(device)
        exp = subgraphx_explainer.get_explanation_graph(x=data.x.float(), edge_index=data.edge_index, label=predicted_label)
        # Detach / CPU so the mask can be handed to NumPy and pickled device-independently.
        soft_mask = exp.node_imp.detach().cpu()
        r['soft_mask'] = soft_mask
        for sparsity, hard_mask in generate_hard_masks(soft_mask):
            level = sparsity.item()
            r['res'][level] = calculate_fidelity(data, hard_mask, model)
            r['res'][level]['hard_mask'] = hard_mask
        r['res_topk'] = {k: calculate_fidelity_topk(data, soft_mask, model, k) for k in (1, 3, 5)}
    except Exception as e:
        subgraphx_explainer_failures += 1
        tqdm.write(f'SubgraphX failed on a test graph: {e!r}')
        continue
    subgraphx_explainer_results.append(r)
if subgraphx_explainer_failures:
    print(f'SubgraphX: {subgraphx_explainer_failures} graph(s) failed, {len(subgraphx_explainer_results)} explained')


# Persist the results and read them back to verify the file round-trips.
with open(f'post_hoc/{dataset_name}/{seed}/subgraphx.pkl', 'wb') as f:
    pickle.dump(subgraphx_explainer_results, f)
with open(f'post_hoc/{dataset_name}/{seed}/subgraphx.pkl', 'rb') as f:
    subgraphx_explainer_results = pickle.load(f)

100%|██████████| 19/19 [01:40<00:00,  5.28s/it]


# LogiX

In [21]:
def train_epoch(model_tell, loader, device, optimizer, num_classes, reg=1, sqrt_reg=False):
    """Run one optimisation pass over ``loader`` and return (loss, accuracy).

    The per-batch objective is binary cross-entropy against one-hot targets
    plus NLL of ``log_softmax(out)``, plus ``reg * model_tell.fc.phi_in.entropy``
    and either ``reg * sum(sqrt(clamp(weight)))`` averaged over outputs
    (``sqrt_reg=True``) or ``reg * model_tell.fc.reg_loss``.

    Batches with empty labels or NaNs in ``x`` / ``y`` are skipped. Any
    exception raised while processing a batch is printed and that batch is
    skipped.

    Returns:
        Tuple ``(total_loss, total_correct)``, both normalised by
        ``len(loader.dataset)``.
    """
    model_tell.train()

    epoch_loss = 0
    epoch_correct = 0

    for batch in loader:
        try:
            # Graphs without node features get constant all-ones features.
            if batch.x is None:
                batch.x = torch.ones((batch.num_nodes, model_tell.num_features))
            if batch.y.numel() == 0 or batch.x.isnan().any() or batch.y.isnan().any():
                continue
            batch.x = batch.x.float()
            targets = batch.y.reshape(-1).to(device).long()
            optimizer.zero_grad()

            out = model_tell(batch.x.float().to(device), batch.edge_index.to(device), batch.batch.to(device))
            pred = out.argmax(-1)

            one_hot_targets = torch.nn.functional.one_hot(targets, num_classes=num_classes).float().reshape(-1)
            loss = F.binary_cross_entropy(out.reshape(-1), one_hot_targets) + F.nll_loss(F.log_softmax(out, dim=-1), targets.long())
            loss += reg * model_tell.fc.phi_in.entropy
            if sqrt_reg:
                loss += reg * torch.sqrt(torch.clamp(model_tell.fc.weight, min=1e-5)).sum(-1).mean()
            else:
                loss += reg * model_tell.fc.reg_loss

            loss.backward()
            # Helper from `utils`; called between backward() and the optimizer step.
            zero_nan_gradients(model_tell)
            optimizer.step()

            epoch_loss += loss.item() * batch.num_graphs / len(loader.dataset)
            epoch_correct += pred.eq(targets).sum().item() / len(loader.dataset)
        except Exception as e:
            print(e)

    return epoch_loss, epoch_correct

# Second copy of the baseline GIN, restored from the same checkpoint.
model_tell = GIN(
    num_classes=num_classes,
    num_features=num_features,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
)
baseline_state = torch.load(os.path.join(results_path, 'best.pt'), map_location=device)
model_tell.load_state_dict(baseline_state)
model_tell = model_tell.to(device)

# Attach a LogicalLayer head sized like the input of the original `fc1` layer.
logical_head = LogicalLayer(model_tell.fc1.in_features, num_classes).to(device)
logical_head.phi_in.tau = 10
model_tell.fc = logical_head

def forward_tell(self):
    """Build a replacement forward function bound to the model ``self``.

    The returned ``fwd(x, edge_index, batch=None, activations=False)`` runs
    every layer in ``self.convs`` (applying ``self.dropout`` between layers),
    concatenates all layer outputs, pools them per graph with mean, max and
    sum readouts (in that order) and feeds the result to ``self.fc``.

    With ``activations=True`` it returns ``(output, self.fc.phi_in(pooled),
    per-layer node outputs)``; otherwise only the output. Extra positional and
    keyword arguments are accepted and ignored.
    """
    def fwd(x, edge_index, batch=None, activations=False, *args, **kwargs):
        # A missing batch vector means all nodes belong to one graph.
        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

        layer_outputs = []
        hidden = x
        for conv in self.convs:
            hidden = conv(hidden, edge_index)
            layer_outputs.append(hidden)
            hidden = self.dropout(hidden)

        stacked = torch.hstack(layer_outputs)
        pooled = torch.hstack([
            global_mean_pool(stacked, batch),
            global_max_pool(stacked, batch),
            global_add_pool(stacked, batch),
        ])

        acts = self.fc.phi_in(pooled)
        out = self.fc(pooled)
        if not activations:
            return out
        return out, acts, layer_outputs
    return fwd


# Route the model through the logical head.
model_tell.forward = forward_tell(model_tell)
model_tell.fc.phi_in.w.shape  # bare attribute access kept from the original cell
# Only the parameters of the logical head are optimised.
optimizer = torch.optim.Adam(model_tell.fc.parameters(), lr=0.001, weight_decay=0)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Two-phase schedule: stronger regularisation with the layer's own reg_loss
# for epochs 0..800, then a weaker one with the sqrt-weight penalty.
for i in range(2000):
    in_first_phase = i <= 800
    train_loss, train_acc = train_epoch(
        model_tell,
        train_loader,
        device,
        optimizer,
        num_classes,
        reg=0.1 if in_first_phase else 0.01,
        sqrt_reg=not in_first_phase,
    )
    val_acc = test_epoch(model_tell, val_loader, device)
    test_acc = test_epoch(model_tell, test_loader, device)
    if i % 10 == 0:
        print(i, train_loss, train_acc, val_acc, test_acc, (model_tell.fc.weight>1e-4).sum())


model_tell.forward = forward_tell(model_tell)


from torch_geometric.utils import k_hop_subgraph
# feat_map[j] describes column j of the pooled feature vector built in
# forward_tell: (readout, layer index, hidden dimension), readouts ordered
# mean, max, sum and layers concatenated in order within each readout.
feat_map = [
    (readout, l, d)
    for readout in ['mean', 'max', 'sum']
    for l in range(num_layers)
    for d in range(hidden_dim)
]

# The rules depend only on the trained logical layer, not on the graph being
# explained, so extract them once instead of once per test graph.
rules = extract_rules(model_tell.fc)

tell_explainer_results = []
for data in tqdm(test_dataset):
    try:
        data = data.to(device)
        data.x = data.x.float()
        pred_tell, rule_acts, layers_acts = model_tell(data.x.float(), data.edge_index, activations=True)

        pred = model(data.x.float(), data.edge_index)
        r = {
            'data': data,
            'pred': pred.softmax(-1).detach().cpu().numpy(),
            'res': {}
        }
        pred_c = r['pred'].argmax(-1).item()
        edge_index_cpu = data.edge_index.cpu()  # hoisted: reused for every node below
        soft_mask = torch.zeros(data.x.shape[0]).to(device)
        for c, class_rules in enumerate(rules):
            # Rules of the predicted class count positively, all others negatively.
            sign = 1 if pred_c == c else -1
            for rule in class_rules:
                for literal in rule:
                    agg, layer, dim = feat_map[literal]
                    acts = layers_acts[layer][:, dim]
                    if agg == 'max':
                        # Only the node(s) attaining the maximum carry a max-readout literal.
                        m = torch.zeros_like(soft_mask)
                        m[acts >= acts.max()] = sign * acts.max() * rule_acts[:, literal].item() * model_tell.fc.weight[c, literal]
                    else:
                        # 'mean' and 'sum' literals attribute proportionally to each node's activation.
                        m = sign * acts * rule_acts[:, literal].item() * model_tell.fc.weight[c, literal]
                    # Spread every positive node contribution over its 1-hop neighbourhood.
                    m_ = torch.zeros_like(m)
                    for node in range(len(m)):
                        if m[node] > 0:
                            try:
                                subset, _, _, _ = k_hop_subgraph(node, 1, edge_index_cpu)
                                m_[subset] += m[node]
                            except Exception:
                                # Fall back to crediting the node itself.
                                m_[node] = m[node]
                    soft_mask += m_

        soft_mask = soft_mask.detach().cpu()
        r['soft_mask'] = soft_mask
        for sparsity, hard_mask in generate_hard_masks(soft_mask):
            sparsity = sparsity.item()
            r['res'][sparsity] = calculate_fidelity(data, hard_mask, model)
            r['res'][sparsity]['hard_mask'] = hard_mask
        r['res_topk'] = {k: calculate_fidelity_topk(data, soft_mask, model, k) for k in (1, 3, 5)}
        tell_explainer_results.append(r)
    except Exception as e:
        print(e)


# The patched forward is a local closure, which pickle cannot serialise.
# NOTE(review): a model loaded from these files has `forward = None` and must
# be re-patched with forward_tell before use.
model_tell.forward = None

with open(f'post_hoc/{dataset_name}/{seed}/tell_model.pkl', 'wb') as f:
    pickle.dump(model_tell, f)


with open(f'post_hoc/{dataset_name}/{seed}/tell.pkl', 'wb') as f:
    pickle.dump(tell_explainer_results, f)
torch.save(model_tell, f'post_hoc/{dataset_name}/{seed}/model_tell.pt')
with open(f'post_hoc/{dataset_name}/{seed}/tell.pkl', 'rb') as f:
    tell_explainer_results = pickle.load(f)


0 75.15151733398437 0.33333333333333337 0.3684210526315789 0.3157894736842105 tensor(712, device='cuda:0')
10 73.50275146484375 0.33333333333333337 0.3684210526315789 0.3157894736842105 tensor(688, device='cuda:0')
20 71.27392781575521 0.3333333333333333 0.3684210526315789 0.3157894736842105 tensor(677, device='cuda:0')
30 67.9484770711263 0.33333333333333337 0.3684210526315789 0.3157894736842105 tensor(666, device='cuda:0')
40 63.7150765991211 0.33333333333333337 0.3684210526315789 0.3157894736842105 tensor(662, device='cuda:0')
50 59.62910196940104 0.3333333333333333 0.3684210526315789 0.3157894736842105 tensor(656, device='cuda:0')
60 56.259444885253906 0.33333333333333337 0.3684210526315789 0.3157894736842105 tensor(647, device='cuda:0')
70 25.23301063537598 0.6666666666666666 0.631578947368421 0.6842105263157895 tensor(627, device='cuda:0')
80 24.83815897623698 0.6666666666666666 0.631578947368421 0.6842105263157895 tensor(616, device='cuda:0')
90 9.92570612589518 0.33333333333333

100%|██████████| 19/19 [00:45<00:00,  2.39s/it]
