In [None]:
from torch import nn
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from syn_dataset import SynGraphDataset
from spmotif_dataset import *
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.utils import k_hop_subgraph
from utils import *
from sklearn.model_selection import train_test_split
import shutil
import glob
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pandas as pd
import argparse
import pickle
import json
import io
import os
from model import GIN, GINTELL

In [None]:
# Which benchmark to explain, and which training seed's checkpoint directory to load.
dataset_name, seed = 'MUTAG', 0
def get_best_baseline_path(dataset_name):
    """Return the run directory of the best baseline run for ``dataset_name``.

    Scans ``results/<dataset_name>/*/results.json`` and considers only runs
    whose path contains ``nogumbel=False``. The best run has the highest
    ``val_acc_mean``; ties are broken by lower ``val_acc_std``, then by lower
    ``test_acc_std``.

    Returns:
        The directory that holds the winning ``results.json``, or ``None`` when
        no eligible run exists (previously an ``IndexError`` was raised when
        runs existed but none matched the filter).
    """
    result_files = [
        path for path in glob.glob(f'results/{dataset_name}/*/results.json')
        if 'nogumbel=False' in path
    ]
    if not result_files:
        return None

    records = []
    for path in result_files:
        with open(path) as fh:  # close every results file instead of leaking handles
            records.append(json.load(fh))
    runs = pd.DataFrame(records)
    runs['fname'] = result_files
    # Ascending on the mean and descending on the stds puts the best run last.
    runs = runs.sort_values(
        by=['val_acc_mean', 'val_acc_std', 'test_acc_std'],
        ascending=[True, False, False],
    )
    # dirname works for any path separator, unlike stripping a literal '/results.json'.
    return os.path.dirname(runs.iloc[-1]['fname'])

def get_best_path(dataset_name):
    """Return the run directory of the best logic (TELL) run for ``dataset_name``.

    Scans ``results_logic/<dataset_name>/*/*/results.json`` and considers only
    runs whose path contains ``nogumbel=False``. The best run has the highest
    ``val_acc_mean``; ties are broken by lower ``val_acc_std``, then by lower
    ``test_acc_std``.

    Returns:
        The directory that holds the winning ``results.json``, or ``None`` when
        no eligible run exists (previously an ``IndexError`` was raised when
        runs existed but none matched the filter).
    """
    result_files = [
        path for path in glob.glob(f'results_logic/{dataset_name}/*/*/results.json')
        if 'nogumbel=False' in path
    ]
    if not result_files:
        return None

    records = []
    for path in result_files:
        with open(path) as fh:  # close every results file instead of leaking handles
            records.append(json.load(fh))
    runs = pd.DataFrame(records)
    runs['fname'] = result_files
    # Ascending on the mean and descending on the stds puts the best run last.
    runs = runs.sort_values(
        by=['val_acc_mean', 'val_acc_std', 'test_acc_std'],
        ascending=[True, False, False],
    )
    # dirname works for any path separator, unlike stripping a literal '/results.json'.
    return os.path.dirname(runs.iloc[-1]['fname'])


# Directory of the chosen logic run for this seed; fail loudly if nothing was found
# (os.path.join(None, ...) used to raise an unhelpful TypeError instead).
best_run_dir = get_best_path(dataset_name)
if best_run_dir is None:
    raise FileNotFoundError(f'no eligible nogumbel=False runs under results_logic/{dataset_name}/')
results_path = os.path.join(best_run_dir, str(seed))

In [None]:
import pickle
# NOTE: pickle.load can execute arbitrary code — only load files produced by our own runs.
with open(os.path.join(results_path, 'data.pkl'), 'rb') as fh:  # closes the file handle
    data = pickle.load(fh)

In [None]:
# Arguments stored alongside the selected run (presumably its training configuration).
with open(os.path.join(results_path, 'args.json'), 'r') as fh:  # closes the file handle
    args = json.load(fh)
args

In [None]:
# `torch.cuda.is_available` must be *called*: the bare function object is always
# truthy, which previously selected 'cuda' even on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
dataset = get_dataset(dataset_name)
num_classes, num_features = dataset.num_classes, dataset.num_features
# Architecture constants. `feat_map` decodes `fc` inputs with these, so they must
# match the loaded model's depth and width — TODO confirm against `args`.
num_layers, hidden_dim = 5, 32

In [None]:
indices = list(range(len(dataset)))
# Stratified split: 80% train; the remaining 20% is halved into validation and test.
# random_state=1 presumably reproduces the training-time split — TODO confirm.
train_indices, val_test_indices = train_test_split(
    indices,
    test_size=0.2,
    shuffle=True,
    stratify=dataset.data.y,
    random_state=1,
)
val_indices, test_indices = (
    val_test_indices[: len(val_test_indices) // 2],
    val_test_indices[len(val_test_indices) // 2 :],
)

train_dataset, val_dataset, test_dataset = (
    dataset[train_indices],
    dataset[val_indices],
    dataset[test_indices],
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=64, shuffle=False) for split in (val_dataset, test_dataset)
)

In [None]:
# torch.load returns the full model object here (it is used directly as a module).
# NOTE(review): torch >= 2.6 defaults to `weights_only=True`, which refuses full-module
# pickles; pass `weights_only=False` there, and only for trusted checkpoints.
model_tell = torch.load(os.path.join(results_path, 'best.pt'), map_location=device)
# Inference mode: normalisation layers (`input_bnorm`/`output_bnorm`, presumably
# BatchNorm) must use their stored running statistics instead of per-graph batch
# statistics, and must not keep updating those statistics while graphs are explained.
model_tell.eval()
# Additionally replace the model's `dropout` attribute with the identity.
model_tell.dropout = lambda x: x

In [None]:
import torch
import torch_scatter
def scatter_sum(x, edge_index):
    """Sum each node's incoming neighbour features.

    For every edge ``(edge_index[0, e], edge_index[1, e])`` the row
    ``x[edge_index[0, e]]`` is added to output row ``edge_index[1, e]`` (the
    target node); nodes with no incoming edge get zeros.

    Args:
        x: node feature tensor whose first dimension indexes nodes.
        edge_index: integer tensor; row 0 holds source node ids and row 1
            target node ids.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    src, dst = edge_index[0], edge_index[1]
    # Core `index_add_` performs the same scatter-add as
    # torch_scatter.scatter_add(..., dim_size=x.size(0)) without the compiled extension.
    out = torch.zeros_like(x)
    return out.index_add_(0, dst, x[src])

@torch.no_grad()
def forward_with_activations(self, x, edge_index, batch, *args, **kwargs):
    """Run the TELL model forward pass while recording per-layer activations.

    ``self`` is the model, passed explicitly (this is a free function). Extra
    positional/keyword arguments are accepted and ignored.

    Returns:
        ``(logits, records)``. ``records`` has one dict per conv layer followed
        by one dict for the ``fc`` head. Each dict holds:
          'x'     : the layer input with its complement appended, ``[h, 1 - h]``
          'x_bin' : ``phi_in`` of the layer input thresholded at 0.5 (for conv
                    layers ``phi_in`` is applied to 'x_sum')
          'y'     : the layer output
          'y_bin' : the layer output thresholded at 0.5
        Conv-layer dicts additionally hold 'x_sum', the neighbourhood sum of
        'x' computed with ``scatter_sum``.
    """
    def with_complement(t):
        return torch.hstack([t, 1 - t])

    records = []
    layer_outputs = []
    h = self.input_bnorm(x)
    for conv in self.convs:
        layer_in = with_complement(h)
        neigh_sum = scatter_sum(layer_in, edge_index)
        in_bin = conv.nn[0].phi_in(neigh_sum) >= 0.5
        h = conv(with_complement(h), edge_index)
        layer_outputs.append(h)
        records.append({
            'x': layer_in,
            'x_sum': neigh_sum,
            'x_bin': in_bin,
            'y': h,
            'y_bin': h >= 0.5,
        })

    # Graph readout: mean / max / sum pooling over the concatenated layer outputs.
    stacked = torch.hstack(layer_outputs)
    pooled = torch.hstack([
        global_mean_pool(stacked, batch),
        global_max_pool(stacked, batch),
        global_add_pool(stacked, batch),
    ])
    h = self.output_bnorm(pooled)
    head_in = with_complement(h)
    head_bin = self.fc.phi_in(head_in) >= 0.5
    logits = self.fc(with_complement(h))
    records.append({
        'x': head_in,
        'x_bin': head_bin,
        'y': logits,
        'y_bin': logits >= 0.5,
    })
    return logits, records

In [None]:
import torch

def inverse_sigmoid(x):
    """Logit, the inverse of ``torch.sigmoid``: ``log(x / (1 - x))``.

    ``x`` is expected to be a tensor with values in (0, 1); 0 maps to -inf and
    1 maps to +inf.
    """
    odds = x / (1 - x)
    return torch.log(odds)

@torch.no_grad()  # was a bare `torch.no_grad()` statement (missing '@'), i.e. a no-op
def find_logic_rules(w, t_in, t_out, activations=None, max_rule_len=10, max_rules=100, min_support=5):
    """Enumerate conjunctive rules whose summed weights reach ``t_out``.

    A rule is a set of input indices ``j`` (each with ``w[j] > 1e-5``) such
    that ``sum(w[j]) >= t_out``. The search is a depth-first enumeration over
    the positively weighted inputs in descending weight order. Once a partial
    combination reaches ``t_out`` it is recorded and not extended further.
    Branches whose remaining weight cannot reach ``t_out`` are pruned.

    Args:
        w: weight vector of one output unit.
        t_in: unused; kept for call compatibility.
        t_out: threshold the summed weights must reach (a one-element tensor
            or a Python number).
        activations: optional tensor indexed as ``activations[sample, input]``.
            When given, inputs that are active in fewer than ``min_support``
            samples are dropped. A non-empty combination is only extended by
            an input if at least ``min_support`` samples activate all of its
            members.
        max_rule_len: maximum number of inputs in a rule.
        max_rules: stop once this many rules have been collected.
        min_support: support threshold used together with ``activations``.

    Returns:
        set of tuples of input indices (each tuple sorted ascending). The
        tuple is empty when ``t_out <= 0``.
    """
    w = w.detach()
    t_out = float(t_out)  # also accepts plain numbers, not only tensors
    if activations is not None:
        # Keep the support mask on the same device as the weights.
        activations = activations.to(w.device)

    order = torch.argsort(w, 0, descending=True)
    keep = w > 1e-5
    if activations is not None:
        keep = keep & (activations.sum(0) >= min_support)

    # Eligible input indices, in descending weight order.
    idxs_to_visit = order[keep[order]]
    if idxs_to_visit.numel() == 0:
        return set()
    sorted_weights = w[idxs_to_visit]

    current_combination = []  # positions into idxs_to_visit
    result = set()

    def search(index, current_sum):
        if len(result) >= max_rules:
            return
        if len(current_combination) > max_rule_len:
            return
        if current_sum >= t_out:
            rule = idxs_to_visit[current_combination].tolist()
            result.add(tuple(sorted(rule)))
            return
        # A combination already at the length limit cannot be extended; stop here
        # instead of recursing one level deeper only to bail out.
        if len(current_combination) >= max_rule_len:
            return
        # Weights are positive and sorted, so this is the best achievable sum.
        if current_sum + sorted_weights[index:].sum() < t_out:
            return
        for i in range(index, idxs_to_visit.shape[0]):
            if activations is not None and current_combination:
                members = idxs_to_visit[current_combination + [i]]
                if activations[:, members].all(-1).sum().item() < min_support:
                    continue
            current_combination.append(i)
            search(i + 1, current_sum + sorted_weights[i])
            current_combination.pop()

    search(0, 0)
    return result


def extract_rules(self, feature=None, activations=None, max_rule_len=float('inf'), max_rules=100, min_support=1, out_threshold=0.5):
    """Extract logic rules for the output units of a logic layer.

    ``self`` is the layer, passed explicitly (this is a free function). For
    output unit ``i``, the rules are the input combinations found by
    ``find_logic_rules`` over ``self.weight[i]`` whose summed weights reach
    ``-self.b[i] + logit(out_threshold)``.

    Args:
        feature: index of a single output unit; ``None`` means all
            ``self.out_features`` units.
        activations: optional per-sample input activations forwarded to
            ``find_logic_rules``. They are moved to the CPU to match the weights.
        max_rule_len, max_rules, min_support: search limits forwarded to
            ``find_logic_rules``.
        out_threshold: output level (in (0, 1)) a unit must reach.

    Returns:
        list with one rule set per requested output unit.
    """
    with torch.no_grad():
        # Move parameters to the CPU once instead of once per output unit.
        weights = self.weight.to('cpu')
        t_in = self.phi_in.t.to('cpu')
        t_out = (-self.b + inverse_sigmoid(torch.tensor(out_threshold))).to('cpu')
        if activations is not None:
            # Weights are searched on the CPU. Activations left on another device
            # would make the support mask in find_logic_rules fail.
            activations = activations.to('cpu')

        features = range(self.out_features) if feature is None else [feature]
        return [
            find_logic_rules(weights[i], t_in, t_out[i], activations, max_rule_len, max_rules, min_support)
            for i in features
        ]

In [None]:
# Rules of the `fc` head with the default search limits: one rule set per output unit.
rules = extract_rules(model_tell.fc)

In [None]:
# Each rule is a tuple of `fc` input indices; decode an index with `feat_map`.
rules

In [None]:
# Decoder for `fc` input indices: entry k describes input k as
# (pos_neg, readout, layer, unit). The nesting order matches how the head input is
# built: [x, 1 - x] (pos/neg), x = [mean, max, sum] readouts, each readout being
# the concatenation of every layer's `hidden_dim` outputs.
feat_map = [
    (pos_neg, readout, l, d)
    for pos_neg in ('pos', 'neg')
    for readout in ('mean', 'max', 'sum')
    for l in range(num_layers)
    for d in range(hidden_dim)
]

In [None]:
from tqdm import tqdm
from torch_geometric.data import Batch
# Single-graph walk-through of the rule-based node explainer (first test graph only);
# its `data` and `soft_mask` feed the spot-check cells below.
# The `fc` rules depend only on the trained weights, so extract them once.
rules = extract_rules(model_tell.fc)

data = test_dataset[0].to(device)
data.x = data.x.float()
# `forward_with_activations` pools per graph and needs the `batch` vector a Batch
# provides. `data` itself stays a plain graph, which is what calculate_fidelity expects.
batch = Batch.from_data_list([data])
pred, rets = forward_with_activations(model_tell, batch.x, batch.edge_index, batch.batch)
pred_c = pred.argmax(-1).item()

head_bin = rets[-1]['x_bin']  # binarised `fc` inputs, one column per literal
edge_index_cpu = data.edge_index.cpu()
num_nodes = data.x.shape[0]
soft_mask = torch.zeros(num_nodes).to(device)
for c, class_rules in enumerate(rules):
    # Rules of the predicted class add evidence; rules of the other classes subtract it.
    sign = 1 if pred_c == c else -1
    for rule in class_rules:
        for literal in rule:
            # NOTE(review): `pos_neg` is unused, so complemented ('neg') literals are
            # credited to the nodes where the unit is *on*; confirm this is intended.
            pos_neg, agg, layer, unit = feat_map[literal]
            acts = rets[layer]['y_bin'][:, unit]
            scale = sign * head_bin[:, literal].item() * model_tell.fc.weight[c, literal]
            if agg == 'max':
                # Only the node(s) attaining the max receive credit.
                node_scores = torch.zeros_like(soft_mask)
                node_scores[acts >= acts.max()] = acts.max() * scale
            else:
                # 'mean' and 'sum' readouts credit every active node the same way.
                node_scores = acts * scale
            # Spread each positive node score over its 1-hop neighbourhood (itself included).
            spread = torch.zeros_like(node_scores)
            for node in (node_scores > 0).nonzero(as_tuple=True)[0].tolist():
                # Passing num_nodes keeps isolated trailing nodes valid (subset == [node]),
                # which the old bare `except:` used to paper over.
                subset, _, _, _ = k_hop_subgraph(node, 1, edge_index_cpu, num_nodes=num_nodes)
                spread[subset] += node_scores[node]
            soft_mask += spread

soft_mask = soft_mask.detach().cpu()

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.utils import subgraph, add_self_loops
from torch_geometric.data import Data,Batch

def calculate_fidelity(data, node_mask, model, remove_nodes=True, top_k=None):
    """Fidelity metrics of a hard node mask for one graph.

    ``node_mask == 1`` marks the explanation (important) nodes and
    ``node_mask == 0`` the rest. ``label`` is the class the model predicts on
    the full graph.

    * "Fidelity" (Fid+): drop in ``P(label)`` when the important nodes are
      removed (``remove_nodes=True``) or have their features zeroed.
    * "InvFidelity" (Fid-): drop in ``P(label)`` when only the important nodes
      are kept (the others removed / zeroed).
    * "HFidelity": harmonic mean of ``(1 + Fid+ * s) / 2`` and
      ``(1 - Fid- * (1 - s)) / 2``, where ``s = 1 - mask.sum() / num_nodes`` is
      the sparsity; 0 if the denominator vanishes.

    Args:
        data: a single graph; it is wrapped into a one-graph ``Batch`` here.
        node_mask: tensor with one 0/1 entry per node, on any device.
        model: called as ``model(x, edge_index, batch)`` and returning logits.
        remove_nodes: if True, masked-out nodes are deleted (induced subgraph);
            otherwise their features are zeroed and the structure is kept.
        top_k: unused; accepted for call compatibility.

    Returns:
        dict with float values under "Fidelity", "InvFidelity" and "HFidelity".
    """
    device = next(model.parameters()).device
    data = Batch.from_data_list([data]).to(device)
    # Hard masks are typically built on the CPU. `subgraph` indexes the mask with
    # `edge_index`, so both must share a device.
    node_mask = node_mask.to(device)
    important = node_mask == 1
    unimportant = node_mask == 0

    total_nodes = data.x.shape[0]
    sparsity = 1 - node_mask.sum().item() / total_nodes if total_nodes > 0 else 0

    def class_probs(keep, drop):
        """Softmax output on the graph restricted to `keep` (removal) or with `drop` zeroed."""
        if remove_nodes:
            edge_index, _ = subgraph(keep, edge_index=data.edge_index,
                                     num_nodes=total_nodes, relabel_nodes=True)
            graph = Data(x=data.x[keep], edge_index=edge_index)
        else:
            x = data.x.clone()
            x[drop] = 0
            graph = Data(x=x, edge_index=data.edge_index)
        graph = Batch.from_data_list([graph])
        return F.softmax(model(graph.x, graph.edge_index, graph.batch), dim=1)

    with torch.no_grad():
        original_pred = F.softmax(model(data.x, data.edge_index, data.batch), dim=1)
        label = original_pred.argmax(-1).item()
        # Fid+: remove (or zero) the explanation nodes.
        masked_pred = class_probs(keep=unimportant, drop=important)
        # Fid-: keep only the explanation nodes.
        retained_pred = class_probs(keep=important, drop=unimportant)

    fidelity = (original_pred[:, label] - masked_pred[:, label]).mean().item()
    inv_fidelity = (original_pred[:, label] - retained_pred[:, label]).mean().item()

    # Sparsity-weighted scores. Fid+ must come from `fidelity`; it previously reused
    # `inv_fidelity`, so HFidelity ignored Fid+ entirely.
    n_fidelity = fidelity * sparsity
    n_inv_fidelity = inv_fidelity * (1 - sparsity)

    denom = 2 + n_fidelity - n_inv_fidelity
    # Guard the actual denominator (the old check tested `1 + ...` instead).
    hfidelity = (1 + n_fidelity) * (1 - n_inv_fidelity) / denom if denom != 0 else 0

    return {
        "Fidelity": fidelity,
        "InvFidelity": inv_fidelity,
        "HFidelity": hfidelity
    }

In [None]:
# def generate_hard_masks(soft_mask):
#     sparsity_levels = torch.arange(0.5,1, 0.05)
#     hard_masks = []
#     for sparsity in sparsity_levels:
#         threshold = np.percentile(soft_mask, sparsity * 100)
#         hard_mask = (soft_mask > threshold).int()
#         if hard_mask.sum() == 0:
#             hard_mask = (soft_mask > soft_mask.min()).int()
#         hard_masks.append(hard_mask)
#     return list(zip(sparsity_levels, hard_masks))
import torch
import numpy as np

import torch

def generate_hard_masks(soft_mask):
    """Binarise a soft importance mask at a ladder of sparsity levels.

    For each level ``s`` in ``torch.arange(0.5, 1.0, 0.05)``, the
    ``int(s * numel)`` lowest-scoring entries are set to 0 and the rest to 1.
    For non-empty input at least one entry is always kept.

    Args:
        soft_mask: tensor of importance scores (any shape).

    Returns:
        list of ``(sparsity, hard_mask)`` pairs. ``sparsity`` is a 0-d float
        tensor and ``hard_mask`` is an int32 tensor shaped like ``soft_mask``.
    """
    scores = soft_mask.flatten()
    n = scores.numel()
    ranking = torch.argsort(scores)  # ascending: least important first

    levels = torch.arange(0.5, 1.0, 0.05)
    masks = []
    for level in levels:
        n_drop = min(int(level.item() * n), n - 1)  # never drop every element
        flat = torch.ones_like(scores, dtype=torch.int)
        flat[ranking[:n_drop]] = 0
        masks.append(flat.view_as(soft_mask))
    return list(zip(levels, masks))



In [None]:
# Binarise the debug graph's soft mask at every sparsity level: list of (sparsity, hard_mask).
masks = generate_hard_masks(soft_mask)

In [None]:
# Spot-check: fidelity of the first (least sparse, 0.5) hard mask on the debug graph.
calculate_fidelity(data, masks[0][1], model_tell, remove_nodes=True, top_k=None)

In [None]:
@torch.no_grad()
def rule_soft_mask(data, rets, pred_c, rules):
    """Per-node importance scores derived from the `fc` rules of `model_tell`.

    For every literal of every rule, the nodes that activate the literal's source
    unit (``rets[layer]['y_bin'][:, unit]``, decoded via the global ``feat_map``)
    are credited with the literal's `fc` weight. The credit is signed +1 for
    rules of the predicted class ``pred_c`` and -1 otherwise, and is gated by the
    binarised `fc` input for that literal (``rets[-1]['x_bin']``). 'max' readouts
    credit only the node(s) attaining the max. Each positive node score is then
    spread over the node's 1-hop neighbourhood (itself included).

    Returns a float tensor with one score per node, on the CPU.
    """
    head_bin = rets[-1]['x_bin']
    edge_index_cpu = data.edge_index.cpu()
    num_nodes = data.x.shape[0]
    soft_mask = torch.zeros(num_nodes).to(device)
    for c, class_rules in enumerate(rules):
        sign = 1 if pred_c == c else -1
        for rule in class_rules:
            for literal in rule:
                # NOTE(review): `pos_neg` is unused, so complemented ('neg') literals are
                # credited to the nodes where the unit is *on*; confirm this is intended.
                pos_neg, agg, layer, unit = feat_map[literal]
                acts = rets[layer]['y_bin'][:, unit]
                scale = sign * head_bin[:, literal].item() * model_tell.fc.weight[c, literal]
                if agg == 'max':
                    node_scores = torch.zeros_like(soft_mask)
                    node_scores[acts >= acts.max()] = acts.max() * scale
                else:  # 'mean' and 'sum' are credited identically
                    node_scores = acts * scale
                spread = torch.zeros_like(node_scores)
                for node in (node_scores > 0).nonzero(as_tuple=True)[0].tolist():
                    # num_nodes keeps isolated trailing nodes valid (subset == [node]).
                    subset, _, _, _ = k_hop_subgraph(node, 1, edge_index_cpu, num_nodes=num_nodes)
                    spread[subset] += node_scores[node]
                soft_mask += spread
    return soft_mask.detach().cpu()


tell_explainer_results = []
tell_explainer_failures = []  # (test-set position, exception) for graphs that failed
# The `fc` rules depend only on the trained weights: extract them once, not per graph.
rules = extract_rules(model_tell.fc)
for idx, data in enumerate(tqdm(test_dataset)):
    try:
        data = data.to(device)
        data.x = data.x.float()
        # NOTE(review): the previous version called an undefined `model`, used a
        # `model_tell(..., activations=True)` API and unpacked 4-tuple `feat_map`
        # entries into 3 names, so every graph failed. This uses the instrumented
        # forward on `model_tell`, like the single-graph debug cell. If a separate
        # black-box model was meant to be explained, substitute it for `model_tell`
        # in the prediction and fidelity calls below.
        batch = Batch.from_data_list([data])
        pred, rets = forward_with_activations(model_tell, batch.x, batch.edge_index, batch.batch)
        r = {
            'data': data,
            'pred': pred.softmax(-1).detach().cpu().numpy(),
            'res': {},
        }
        pred_c = r['pred'].argmax(-1).item()
        soft_mask = rule_soft_mask(data, rets, pred_c, rules)
        r['soft_mask'] = soft_mask
        for sparsity, hard_mask in generate_hard_masks(soft_mask):
            sparsity = sparsity.item()
            r['res'][sparsity] = calculate_fidelity(data, hard_mask, model_tell)
            r['res'][sparsity]['hard_mask'] = hard_mask
        # NOTE(review): `calculate_fidelity_topk` is not defined in this notebook; it
        # must come from a star import (e.g. `utils`) — confirm it exists.
        r['res_topk'] = {k: calculate_fidelity_topk(data, soft_mask, model_tell, k) for k in (1, 3, 5)}
        tell_explainer_results.append(r)
    except Exception as e:
        # Best-effort over the test set, but record which graph failed and why.
        tell_explainer_failures.append((idx, e))
        print(f'test graph {idx}: {type(e).__name__}: {e}')

