In [None]:
from torch import nn
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from syn_dataset import SynGraphDataset
from spmotif_dataset import *
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool
from utils import *
from sklearn.model_selection import train_test_split
import shutil
import glob
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pandas as pd
import argparse
import pickle
import json
import io
from model_node import GIN, GINTELL
from train_baseline_node import test_epoch

In [None]:
# Dataset whose best TELL run is analysed, and the seed sub-directory of that run to load.
dataset_name, seed = 'BaShapes', 2
def get_best_baseline_path(dataset_name):
    """Return the directory of the best baseline run for ``dataset_name``.

    Scans ``results/<dataset_name>/*/results.json``, keeps only runs whose path
    contains ``nogumbel=False`` and picks the one with the highest
    ``val_acc_mean`` (ties broken by the lowest ``val_acc_std``, then the lowest
    ``test_acc_std``).

    Returns:
        The run directory (parent of ``results.json``), or ``None`` when no run
        exists or none matches the ``nogumbel=False`` filter.
    """
    import os

    paths = glob.glob(f'results/{dataset_name}/*/results.json')
    records = []
    for path in paths:
        with open(path) as fh:  # close every file instead of leaking the handle
            records.append(json.load(fh))
    df = pd.DataFrame(records)
    if df.shape[0] == 0:
        return None
    df['fname'] = paths
    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True, False, False])
    df = df[df.fname.str.contains('nogumbel=False', regex=False)]
    if df.shape[0] == 0:
        # Previously df.iloc[-1] raised IndexError when no run matched the filter.
        return None
    # dirname instead of str.replace('/results.json', ''): also correct with '\\' separators.
    return os.path.dirname(df.iloc[-1]['fname'])

def get_best_path(dataset_name):
    """Return the directory of the best logic (TELL) run for ``dataset_name``.

    Scans ``results_logic/<dataset_name>/*/*/results.json``, keeps only runs
    whose path contains ``nogumbel=False`` and picks the one with the highest
    ``val_acc_mean`` (ties broken by the lowest ``val_acc_std``, then the
    lowest ``test_acc_std``).

    Returns:
        The run directory (parent of ``results.json``), or ``None`` when no run
        exists or none matches the ``nogumbel=False`` filter.
    """
    import os

    paths = glob.glob(f'results_logic/{dataset_name}/*/*/results.json')
    records = []
    for path in paths:
        with open(path) as fh:  # close every file instead of leaking the handle
            records.append(json.load(fh))
    df = pd.DataFrame(records)
    if df.shape[0] == 0:
        return None
    df['fname'] = paths
    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True, False, False])
    df = df[df.fname.str.contains('nogumbel=False', regex=False)]
    if df.shape[0] == 0:
        # Previously df.iloc[-1] raised IndexError when no run matched the filter.
        return None
    # dirname instead of str.replace('/results.json', ''): also correct with '\\' separators.
    return os.path.dirname(df.iloc[-1]['fname'])


# Directory of the chosen seed inside the best logic run; fail loudly when nothing was found
# (os.path.join(None, ...) would otherwise raise an unhelpful TypeError).
_best_run_dir = get_best_path(dataset_name)
if _best_run_dir is None:
    raise FileNotFoundError(f"no usable results.json found under results_logic/{dataset_name}/")
results_path = os.path.join(_best_run_dir, str(seed))

In [None]:
import pickle
# NOTE: unpickling can execute arbitrary code — only load data.pkl files you produced yourself.
with open(os.path.join(results_path, 'data.pkl'), 'rb') as fh:  # closes the file handle
    data = pickle.load(fh)

In [None]:
# Training arguments saved next to the checkpoint of the selected run.
with open(os.path.join(results_path, 'args.json'), 'r') as fh:  # closes the file handle
    args = json.load(fh)
args

In [None]:
# `is_available` must be *called*: the bare function object is always truthy, which
# previously forced 'cuda' even on machines without a GPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Load the dataset and record the dimensions used by the cells below.
dataset = get_dataset(dataset_name)
num_classes, num_features = dataset.num_classes, dataset.num_features
# NOTE(review): hard-coded; they must match the architecture of the loaded checkpoint —
# TODO confirm against args.json.
num_layers, hidden_dim = 5, 32

In [None]:
# Load the full pickled model object (not a state_dict). Unpickling can execute arbitrary
# code, so only load trusted, locally produced checkpoints. torch >= 2.6 defaults to
# weights_only=True, which rejects full module pickles, so opt out explicitly; torch < 1.13
# does not know the keyword and falls back to the plain call.
_best_ckpt = os.path.join(results_path, 'best.pt')
try:
    model_tell = torch.load(_best_ckpt, map_location=device, weights_only=False)
except TypeError:
    model_tell = torch.load(_best_ckpt, map_location=device)
# model_tell.dropout = lambda x:x

In [None]:
# Accuracy of the checkpoint as loaded, on the validation and test node masks of dataset[0].
val_acc, test_acc = (
    test_epoch(model_tell, dataset[0], dataset[0].val_mask, device),
    test_epoch(model_tell, dataset[0], dataset[0].test_mask, device),
)


In [None]:
# Test accuracy of the checkpoint as loaded, before any pruning.
test_acc

In [None]:
# Node labels of dataset[0], the graph whose masks are used throughout.
dataset[0].y

In [None]:
import torch

def inverse_sigmoid(x):
    """Logit: map probabilities ``x`` back to pre-sigmoid values, ``log(x / (1 - x))``."""
    odds = x / (1 - x)
    return torch.log(odds)

@torch.no_grad()  # previously a bare `torch.no_grad()` statement, which did nothing
def find_logic_rules(w, t_in, t_out, activations=None, max_rule_len=10, max_rules=100, min_support=5):
    """Enumerate conjunctions of inputs whose summed weight reaches the output threshold.

    Depth-first search over input indices, visited in order of decreasing
    weight. A combination becomes a rule as soon as the sum of its weights is
    ``>= t_out``; its supersets are not explored.

    Args:
        w: 1-D tensor of input weights of one output unit; only entries
            ``> 1e-5`` are considered.
        t_in: Input thresholds. Unused; kept for interface compatibility.
        t_out: One-element tensor, the weight sum a combination must reach.
        activations: Optional binary matrix (rows x inputs). When given, inputs
            active in fewer than ``min_support`` rows are dropped, and a
            combination is only extended while at least ``min_support`` rows
            have all of its inputs active.
        max_rule_len: Maximum number of inputs in a rule.
        max_rules: Stop once this many rules have been found.
        min_support: Minimum row support used together with ``activations``.

    Returns:
        Set of rules, each a sorted tuple of input indices into ``w``. If
        ``t_out <= 0`` the empty tuple is returned as the (always true) rule.
    """
    t_out = t_out.item()  # plain float threshold
    sorted_idxs = torch.argsort(w, 0, descending=True)
    mask = w > 1e-5
    if activations is not None:
        mask = mask & (activations.sum(0) >= min_support)

    # Candidate inputs, largest weight first.
    idxs_to_visit = sorted_idxs[mask[sorted_idxs]]
    result = set()
    if idxs_to_visit.numel() == 0:
        return result

    sorted_weights = w[idxs_to_visit]
    # Positions into idxs_to_visit (not raw input indices) of the combination being built.
    current_combination = []

    def find_logic_rules_recursive(index, current_sum):
        # Stop if the maximum number of rules has been reached.
        if len(result) >= max_rules:
            return

        if len(current_combination) > max_rule_len:
            return

        # The combination already reaches the threshold: record it, do not extend it.
        if current_sum >= t_out:
            c = idxs_to_visit[current_combination].cpu().detach().tolist()
            result.add(tuple(sorted(c)))
            return

        # Prune if even adding every remaining (positive) weight cannot reach t_out.
        remaining_max_sum = current_sum + sorted_weights[index:].sum()
        if remaining_max_sum < t_out:
            return

        for i in range(index, idxs_to_visit.shape[0]):
            # Prune combinations supported by too few rows of `activations`.
            if activations is not None and len(current_combination) > 0 and activations[:, idxs_to_visit[current_combination + [i]]].all(-1).sum().item() < min_support:
                continue

            current_combination.append(i)
            find_logic_rules_recursive(i + 1, current_sum + sorted_weights[i])
            current_combination.pop()

    find_logic_rules_recursive(0, 0)
    return result


def extract_rules(self, feature=None, activations=None, max_rule_len=float('inf'), max_rules=100, min_support=1, out_threshold=0.5):
    """Extract logic rules for one or all output units of a TELL layer.

    For output unit ``u`` the weight-sum threshold is
    ``logit(out_threshold) - self.b[u]``, and ``find_logic_rules`` enumerates
    input combinations of ``self.weight[u]`` reaching it (computed on CPU).

    NOTE(review): the threshold ignores any sigmoid temperature (``forward_tell``
    uses ``sigmoid(tau * z)``); it is exact only for ``out_threshold=0.5``.

    Args:
        feature: Output unit to explain; ``None`` explains all ``self.out_features`` units.
        activations, max_rule_len, max_rules, min_support: Forwarded to ``find_logic_rules``.
        out_threshold: Output probability at which the unit counts as active.

    Returns:
        A list with one rule set per requested output unit.
    """
    weights = self.weight
    thresholds_in = self.phi_in.t
    thresholds_out = -self.b + inverse_sigmoid(torch.tensor(out_threshold))

    units = range(self.out_features) if feature is None else [feature]
    return [
        find_logic_rules(
            weights[u].to('cpu'),
            thresholds_in.to('cpu'),
            thresholds_out[u].to('cpu'),
            activations,
            max_rule_len,
            max_rules,
            min_support,
        )
        for u in units
    ]

In [None]:
import torch
from torch_geometric.data import Data, Batch
import networkx as nx
import matplotlib.pyplot as plt

def plot_activations(batch_ids, batch, attr):
    """Draw selected graphs of a PyG batch, outlining nodes whose ``attr`` equals 1.

    Args:
        batch_ids: A graph id within ``batch``, or a list/tuple of ids (one panel each).
        batch: Batched graph providing ``batch``, ``edge_index`` and per-graph ``y``.
        attr: Per-node values for the whole batch; nodes with value 1 get a red
            outline, all others black.
    """
    import numpy as np  # `np` is not imported anywhere else in this notebook

    if not isinstance(batch_ids, (list, tuple)):
        batch_ids = [batch_ids]
    fig, axs = plt.subplots(1, len(batch_ids), figsize=(16 * len(batch_ids), 8))
    if not isinstance(axs, np.ndarray):
        axs = np.array([axs])  # a single panel returns one Axes, not an array
    for ax, batch_id in zip(axs, batch_ids):
        node_mask = batch.batch == batch_id  # nodes belonging to graph `batch_id`
        node_indices = torch.nonzero(node_mask, as_tuple=True)[0]
        graph_attr = attr[node_mask]

        # Keep edges whose two endpoints both lie in this graph.
        edge_mask = node_mask[batch.edge_index[0]] & node_mask[batch.edge_index[1]]
        subgraph_edges = batch.edge_index[:, edge_mask]

        # Relabel batch-global node ids to 0..n-1 within this graph.
        node_mapping = {old_idx.item(): new_idx for new_idx, old_idx in enumerate(node_indices)}
        G = nx.Graph()
        # Add all nodes first so isolated nodes are drawn too (previously only edge
        # endpoints were added) and so node order is 0..n-1.
        G.add_nodes_from(range(len(node_indices)))
        G.add_edges_from((node_mapping[u], node_mapping[v]) for u, v in subgraph_edges.t().tolist())
        nx.set_node_attributes(G, {v: k for k, v in node_mapping.items()}, "original_id")

        node_colors = ["lightblue"] * G.number_of_nodes()
        node_borders = ["red" if graph_attr[node] == 1 else "black" for node in G.nodes]

        pos = nx.kamada_kawai_layout(G)
        nx.draw(
            G, pos,
            node_color=node_colors,
            edgecolors=node_borders,  # border colours encode `attr`
            node_size=700,
            with_labels=True,
            ax=ax,
        )
        ax.set_title(f"Class = {batch.y[batch_id]}")
    plt.show()


In [None]:
def train_sparsity_epoch(model_tell, data, mask, device, optimizer, num_classes, conv_reg=1, fc_reg=1):
    """Run one full-batch optimisation step of the task loss plus sparsity regularisers.

    The task loss is binary cross-entropy of the outputs against one-hot targets
    plus NLL of their log-softmax, computed only on the nodes selected by
    ``mask``. Each TELL layer adds ``mean_u sum_i sqrt(clamp(w_ui, 1e-5))`` plus
    its ``phi_in.entropy``: weighted by ``conv_reg`` for every conv's
    ``nn[0]`` and by ``fc_reg`` for ``model_tell.fc``.

    Args:
        mask: Node mask selecting the supervised nodes (e.g. ``train_mask``);
            ``None`` uses every node.

    Returns:
        ``(loss, accuracy)``: the scalar loss of this step and the accuracy on
        the masked nodes.
    """
    model_tell.train()
    data = data.to(device)
    if data.x is None:
        data.x = torch.ones((data.num_nodes, model_tell.num_features))

    y = data.y.reshape(-1).to(device).long()
    if mask is None:
        mask = torch.ones_like(y, dtype=torch.bool)
    mask = mask.to(device)

    optimizer.zero_grad()
    model_tell.fc.phi_in.tau = 10
    out = model_tell(data.x.float().to(device), data.edge_index.to(device))

    # Supervise only the masked nodes; previously `mask` was ignored and the loss
    # used every node's label, including validation/test nodes.
    out_m, y_m = out[mask], y[mask]
    loss = (
        F.binary_cross_entropy(out_m.reshape(-1), F.one_hot(y_m, num_classes=num_classes).float().reshape(-1))
        + F.nll_loss(F.log_softmax(out_m, dim=-1), y_m)
    )
    for conv in model_tell.convs:
        tell = conv.nn[0]
        # Conv layers are weighted by conv_reg (previously fc_reg was used here).
        loss = loss + conv_reg * (torch.sqrt(torch.clamp(tell.weight, min=1e-5)).sum(-1).mean() + tell.phi_in.entropy)
    fc = model_tell.fc
    loss = loss + fc_reg * (torch.sqrt(torch.clamp(fc.weight, min=1e-5)).sum(-1).mean() + fc.phi_in.entropy)

    loss.backward()
    zero_nan_gradients(model_tell)
    optimizer.step()

    correct = out_m.argmax(-1).eq(y_m).sum().item()
    return loss.item(), correct / max(y_m.numel(), 1)


In [None]:
import torch
import torch_scatter

def scatter_sum(x, edge_index):
    """Sum neighbour rows into each target node and add the node's own row.

    For every edge ``(src, dst)`` (``edge_index[0]``, ``edge_index[1]``), row
    ``x[src]`` is added to output row ``dst``; finally ``x`` itself is added.
    The output has the same number of rows as ``x``.
    """
    src, dst = edge_index[0], edge_index[1]
    neighbour_rows = x[src]
    aggregated = torch_scatter.scatter_add(neighbour_rows, dst, dim=0, dim_size=x.size(0))
    return aggregated + x

In [None]:
@torch.no_grad()
def forward_with_activations(self, x, edge_index, *args, **kwargs):
    """Re-run the model's forward pass while recording every layer's inputs and outputs.

    Returns:
        ``(out, records)``: ``out`` is the output of ``self.fc``. ``records`` has
        one dict per conv layer with keys ``'x'`` (input literals ``[h, 1 - h]``),
        ``'x_sum'`` (``scatter_sum`` of those literals over ``edge_index``),
        ``'x_bin'`` (``phi_in(x_sum) >= 0.5``), ``'y'`` (conv output) and
        ``'y_bin'`` (``y >= 0.5``), followed by one dict for ``self.fc`` with the
        same keys except ``'x_sum'``.

    Extra positional/keyword arguments are accepted and ignored.
    NOTE(review): this mirrors the model's forward by hand (input batch norm,
    convs, concatenated conv outputs, output batch norm, fc). It neither calls
    ``self.eval()`` nor applies dropout — confirm it matches ``self.forward``.
    """
    def with_negation(t):
        # Each feature next to its complement, as fed to every conv and to fc.
        return torch.hstack([t, 1 - t])

    records = []
    h = self.input_bnorm(x)
    conv_outputs = []
    for conv in self.convs:
        literals = with_negation(h)
        aggregated = scatter_sum(literals, edge_index)
        literals_bin = conv.nn[0].phi_in(aggregated) >= 0.5
        h = conv(with_negation(h), edge_index)
        conv_outputs.append(h)
        records.append({'x': literals, 'x_sum': aggregated, 'x_bin': literals_bin, 'y': h, 'y_bin': h >= 0.5})

    h = self.output_bnorm(torch.hstack(conv_outputs))
    fc_literals = with_negation(h)
    fc_record = {'x': fc_literals, 'x_bin': self.fc.phi_in(fc_literals) >= 0.5}
    out = self.fc(with_negation(h))
    fc_record['y'] = out
    fc_record['y_bin'] = out >= 0.5
    records.append(fc_record)
    return out, records

In [None]:
def sigmoid(x, tau=10):
    """Temperature-scaled logistic function ``1 / (1 + exp(-tau * x))``."""
    scaled = -tau*x
    return 1 / (1 + torch.exp(scaled))
    
def forward_tell(self, tau):
    """Build a replacement ``forward`` for a TELL layer with a fixed sigmoid temperature.

    The returned closure maps inputs through ``self.phi_in``, then returns
    ``sigmoid(phi_in(x) @ self.weight.T + self.b, tau)``. As side effects it sets
    ``self.max_in`` (column-wise max of the transformed inputs), ``self.reg_loss``
    (mean over units of the summed clamped ``weight_s`` or ``weight``, depending on
    ``self.use_weight_sigma``) and ``self.entropy_loss`` (``phi_in.entropy`` when
    present plus the mean binary entropy of the outputs).
    """
    def fw(x):
        transformed = self.phi_in(x)
        self.max_in, _ = transformed.max(0)

        penalised = self.weight_s if self.use_weight_sigma else self.weight
        reg_loss = 0
        reg_loss += torch.clamp(penalised, min=1e-5).sum(-1).mean()

        entropy_loss = 0
        if self.phi_in.entropy is not None:
            entropy_loss += self.phi_in.entropy
        self.reg_loss = reg_loss

        o = sigmoid(transformed @ self.weight.t() + self.b, tau=tau)
        output_entropy = -(o*torch.log(o+1e-8) + (1-o)*torch.log(1-o + 1e-8)).mean()
        self.entropy_loss = entropy_loss + output_entropy
        return o
    return fw



# Extract Rules for Last Layer

In [None]:
# Ensure the model sits on `device` before the pruning cells below.
model_tell = model_tell.to(device)

In [None]:
import torch
import io

def clone_model(model):
    """Return an independent deep copy of ``model`` via an in-memory save/load round trip.

    Tensors are restored on the devices they were saved from (no ``map_location``).
    NOTE: the whole module is pickled, so any non-picklable attribute stored on it
    (e.g. a locally defined function assigned as ``forward``) makes this fail.
    """
    buffer = io.BytesIO()
    torch.save(model, buffer)
    buffer.seek(0)  # rewind before reading back
    try:
        # torch >= 2.6 defaults to weights_only=True, which rejects full nn.Module pickles.
        return torch.load(buffer, weights_only=False)
    except TypeError:
        # torch < 1.13 has no `weights_only` keyword.
        buffer.seek(0)
        return torch.load(buffer)


In [None]:
# --- Prune the final TELL layer (fc) ---------------------------------------------------
# Only fc.weight_sigma / fc.weight_exp are handed to the optimizer, so only fc weights move
# (conv regularisers still produce gradients, but no conv parameter is stepped here).
optimizer = torch.optim.Adam([model_tell.fc.weight_sigma, model_tell.fc.weight_exp], lr=0.01, weight_decay=0)

# for param in model_tell.parameters():
#     param.requires_grad = False
    
print("Pruning last layer")
# model_tell.fc.weight_sigma.requires_grad = True
# model_tell.fc.weight_exp.requires_grad = True
# model_tell.fc.phi_in.w.requires_grad = True
# model_tell.fc.phi_in.b.requires_grad = True
best_weights = clone_model(model_tell)
val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)
# Number of fc weights still considered active (> 1e-4).
n_w =  (model_tell.fc.weight>1e-4).sum().item()
# Selection key compared as a tuple: higher validation accuracy first, then fewer active weights.
best_situation = (val_acc, -n_w)
# Stop after `max_patience` epochs without a newly accepted snapshot.
patience = max_patience = 50

for i in range(1000):
    # model_tell.fc.forward = forward_tell(model_tell.fc, 10)
    train_loss, train_acc = train_sparsity_epoch(model_tell, dataset[0], dataset[0].train_mask, device, optimizer, num_classes, conv_reg=0.1, fc_reg=0.1)
    # model_tell.fc.forward = forward_tell(model_tell.fc, 1000)
    val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)
    # Test accuracy is only logged; snapshot selection below uses validation accuracy.
    test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)
    n_w =  (model_tell.fc.weight>1e-4).sum().item()
    # Accept a snapshot if it is strictly better by (val_acc, -n_w), or sparser while keeping
    # at least 95% of the current best val_acc.
    # NOTE(review): the 95% tolerance is relative to the *latest accepted* snapshot, so
    # accuracy can drift down ~5% per accepted step — confirm this ratchet is intended.
    if (val_acc, -n_w) > best_situation or (-n_w > best_situation[1] and val_acc >= 0.95*best_situation[0]):
        best_weights = clone_model(model_tell)
        patience = max_patience
        best_situation = (val_acc, -n_w)
    patience -= 1 
    if i%10 == 0:
        print(i, train_loss, train_acc, val_acc, test_acc, n_w, patience)
    if patience == 0:
        break
# Restore the best accepted snapshot and report its accuracies and remaining fc weights.
model_tell = clone_model(best_weights)
val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)
test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)
n_w =  (model_tell.fc.weight>1e-4).sum().item()
print(val_acc, test_acc, n_w)

In [None]:
# Rules of the pruned final layer: one rule set per output unit, without activation-based support filtering.
extract_rules(model_tell.fc)

In [None]:
# --- Prune the conv TELL layers, from the last conv back to the first --------------------
model_tell = model_tell.to(device)


best_weights = clone_model(model_tell)

for l in reversed(range(len(model_tell.convs))):
    # Only this conv's weight_sigma / weight_exp are optimised in this round.
    optimizer = torch.optim.Adam([model_tell.convs[l].nn[0].weight_sigma, model_tell.convs[l].nn[0].weight_exp], lr=0.005)
    
    val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)
    test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)
    # Active weights (> 1e-4) of the layer currently being pruned.
    n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()
    # Selection key compared as a tuple: higher validation accuracy first, then fewer active weights.
    best_situation = (val_acc, -n_w)
    print(best_situation)
    patience = max_patience = 50
    for i in range(3000):
        # tell_layer.forward = forward_tell(tell_layer, 10)
        train_loss, train_acc = train_sparsity_epoch(model_tell, dataset[0], dataset[0].train_mask, device, optimizer, num_classes, conv_reg=0.1, fc_reg=0.1)
        # model_tell.fc.forward = forward_tell(model_tell.fc, 1000)
        val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)
        # Test accuracy is only logged; snapshot selection uses validation accuracy.
        test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)
        n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()
        # Same acceptance rule as for fc.
        # NOTE(review): the 95% tolerance ratchets down with each accepted snapshot — confirm intended.
        if (val_acc, -n_w) > best_situation or (-n_w > best_situation[1] and val_acc >= 0.95*best_situation[0]):
            best_weights = clone_model(model_tell)
            patience = max_patience
            best_situation = (val_acc, -n_w)
        patience -= 1 
        if i%10 == 0:
            print(i, train_loss, train_acc, val_acc, test_acc, n_w, patience)
        if patience == 0:
            break    
    # Continue from the best snapshot found so far (an earlier layer's if none was accepted here).
    model_tell = clone_model(best_weights)
    val_acc = test_epoch(model_tell, dataset[0], dataset[0].val_mask, device)
    test_acc = test_epoch(model_tell, dataset[0], dataset[0].test_mask, device)
    n_w =  (model_tell.convs[l].nn[0].weight>1e-4).sum().item()
    print(val_acc, test_acc, n_w)

In [None]:
# Record every layer's inputs/outputs of the pruned model on the full graph dataset[0].
activations = None
with torch.no_grad():
    _, activations = forward_with_activations(
        model_tell,
        dataset[0].x.to(device),
        dataset[0].edge_index.to(device),
    )


In [None]:
# Final-layer rules, keeping only inputs (and combinations) active on at least
# min_support nodes (extract_rules default: 1) of the recorded fc inputs.
last_layer_rules = extract_rules(model_tell.fc, activations=activations[-1]['x_bin'].cpu())

In [None]:
# Name every input of the final layer. fc sees [h, 1 - h], where h stacks all conv outputs,
# so index -> (polarity, conv layer, unit). The widths come from the recorded activations
# instead of the hard-coded num_layers / hidden_dim, so the map stays correct if they differ.
conv_widths = [record['y'].shape[1] for record in activations[:-1]]
feat_map = [
    (pos_neg, l, d)
    for pos_neg in ['pos', 'neg']
    for l, width in enumerate(conv_widths)
    for d in range(width)
]
assert len(feat_map) == activations[-1]['x'].shape[1], "feat_map does not match the final layer's input width"

In [None]:
# Raw final-layer rules: per output unit, a set of sorted tuples of fc input indices.
last_layer_rules

In [None]:
# Translate each rule's fc input indices into (index, polarity, conv layer, unit) via feat_map,
# dropping literals that are active on every node (they constrain nothing).
last_layer_rules_renamed = []
for c in range(len(last_layer_rules)):
    s = set()
    for t in last_layer_rules[c]:
        new_t = []
        for i in t:
            if not activations[-1]['x_bin'][:,i].all().item():
                new_t.append((i, *feat_map[i]))
        # A rule whose literals were all always-true becomes the empty tuple.
        s.add(tuple(new_t))
    last_layer_rules_renamed.append(s)

In [None]:
# Final-layer rules with literals named as (fc index, polarity, conv layer, unit).
last_layer_rules_renamed

In [None]:
from torch_geometric.data import Batch, Data

In [None]:
from torch_geometric.utils import to_networkx, subgraph

In [None]:
def get_subgraph(data, node_mask):
    """Return the subgraph of ``data`` induced by ``node_mask`` as a networkx graph.

    Nodes are relabelled to ``0..k-1``; the graph is whatever ``to_networkx``
    produces for the induced edges and the selected rows of ``data.x``.
    """
    kept_nodes = torch.where(node_mask)[0]
    induced = subgraph(kept_nodes, data.edge_index, relabel_nodes=True)
    induced_edge_index = induced[0]
    kept_features = data.x[node_mask]
    return to_networkx(Data(x=kept_features, edge_index=induced_edge_index))

In [None]:
# Re-display of last_layer_rules_renamed (repeats an earlier cell; unchanged since then).
last_layer_rules_renamed

In [None]:
def find_minimal_sets(list_of_sets):
    """Return the inclusion-minimal elements of ``list_of_sets``.

    Each element is an iterable of hashable literals, compared as a set. An
    element is dropped if another element is a proper subset of it, or if an
    earlier element has exactly the same literals. Duplicates are therefore
    kept once, at their first position; previously every copy of a duplicate
    was dropped.

    Returns:
        A list of the surviving elements (original objects), in input order.
    """
    candidates = list(list_of_sets)
    as_sets = [set(c) for c in candidates]  # build each set once instead of per comparison
    minimal_sets = []
    for i, (candidate, cand_set) in enumerate(zip(candidates, as_sets)):
        dominated = any(
            other < cand_set or (other == cand_set and j < i)
            for j, other in enumerate(as_sets)
        )
        if not dominated:
            minimal_sets.append(candidate)
    return minimal_sets

In [None]:
# Keep only the inclusion-minimal rules of each class (rebinds the same name to lists).
last_layer_rules_renamed = [find_minimal_sets(r) for r in last_layer_rules_renamed]

In [None]:
# Replace the forward of every TELL layer (each conv's nn[0] and fc) with forward_tell at tau=10.
# NOTE(review): the patched forward is a closure stored on the instance; pickling-based
# copies (clone_model / torch.save) will likely fail after this cell — confirm before cloning.
tells = [c.nn[0] for c in model_tell.convs] + [model_tell.fc]
for tell in tells:
    tell.forward = forward_tell(tell, 10)

In [None]:
def find_step_intervals(w, b, xmin, xmax, tau=5, resolution=1000):
    """Locate where ``step(w * x + b, tau)`` exceeds 0.5 on ``[xmin, xmax]``.

    The range is sampled at ``resolution`` evenly spaced points (``step`` comes
    from the ``tell`` module, imported in a later cell). Each ``(start, end)``
    pair brackets one active run: interior endpoints are the neighbouring
    inactive samples, and a run touching the border uses the border sample
    (``xmin`` / ``xmax``) itself.

    Returns:
        List of ``(start, end)`` float tuples in increasing order.
    """
    grid = torch.linspace(xmin, xmax, resolution)
    active = (step(w * grid + b, tau) > 0.5).tolist()
    points = grid.tolist()

    intervals = []
    prev_on = active[0]
    run_start = points[0] if prev_on else None
    for k in range(1, len(points)):
        now_on = active[k]
        if now_on and not prev_on:
            run_start = points[k - 1]  # entering a run: last sample before it
        elif prev_on and not now_on:
            intervals.append((run_start, points[k]))  # leaving a run: first sample after it
            run_start = None
        prev_on = now_on

    if prev_on and run_start is not None:
        intervals.append((run_start, points[-1]))
    return intervals


In [None]:
from tell import step

In [None]:
# Trace the surviving final-layer rules of `class_to_explain` back to conv-layer features.
conv_rules = []  # work queue of (conv layer, feature); feature >= layer width denotes the negated unit
fc_rules = []    # one entry per covering rule: [(literal, step intervals of that fc input), ...]
model_tell = model_tell.cpu()
class_to_explain = 0
phi_in = model_tell.fc.phi_in
fc_in = activations[-1]['x']
fc_in_bin = activations[-1]['x_bin']
for rule in last_layer_rules_renamed[class_to_explain]:
    print(rule)
    ands = []
    # Nodes satisfying every literal of the rule; an empty rule holds on all nodes
    # (previously node_mask stayed None and `.sum()` crashed).
    node_mask = torch.ones(fc_in_bin.shape[0], dtype=torch.bool)
    for literal in rule:
        feat_idx = literal[0]
        # Out-of-place AND: on CPU, `.cpu()` returns the same storage, so the previous
        # in-place `&=` wrote straight into activations[-1]['x_bin'].
        node_mask = node_mask & fc_in_bin[:, feat_idx].detach().cpu()
        intervals = find_step_intervals(phi_in.w[feat_idx].cpu(), phi_in.b[feat_idx].cpu(), fc_in[:, feat_idx].min().cpu(), fc_in[:, feat_idx].max().cpu(), tau=10, resolution=1000)
        ands.append((literal, intervals))
    print(node_mask)
    if node_mask.sum().item() == 0:
        continue
    fc_rules.append(ands)

    for (_idx, pos_neg, layer_idx, unit), _intervals in ands:
        # Queue the conv unit behind this literal. Negated literals are encoded as
        # unit + width (width = out_features of that conv's nn[0]), the convention the
        # conv-expansion cell uses to flip the coverage mask (previously polarity was dropped).
        width = model_tell.convs[layer_idx].nn[0].weight.shape[0]
        conv_rules.append((layer_idx, unit if pos_neg == 'pos' else unit + width))

In [None]:
# Ad-hoc inspection: median value of fc input column 159 over all nodes.
# NOTE(review): hard-coded index left over from exploration — likely stale for other runs.
activations[-1]['x'][:,159].median()

In [None]:
# Final-layer rules of class_to_explain that cover at least one node, with each literal's step intervals.
fc_rules

In [None]:
# Expand the queued (layer, feature) pairs down to the input layer.
# conv_rules_details[layer, feat] = (minimal rules of unit feat % ho, node mask where the literal
# holds); feat >= ho denotes the negated unit. Literals of layer L's rules index its input
# [h, 1 - h] and are queued as (L - 1, i) so they are expanded in turn.
conv_rules_details = {}
while conv_rules:
    layer, feat = conv_rules.pop(0)
    if (layer, feat) in conv_rules_details:
        # Duplicate queue entry (e.g. the same fc literal in several rules); re-expanding
        # it would recompute an identical result and queue nothing new.
        continue
    phi_in = model_tell.convs[layer].nn[0].phi_in
    ho = model_tell.convs[layer].nn[0].weight.shape[0]
    rules = extract_rules(model_tell.convs[layer].nn[0], feature=feat%ho)
    refined_rules = []
    for c in range(len(rules)):
        s = set()
        for t in rules[c]:
            new_t = []
            for i in t:
                # Skip literals active on every node; they constrain nothing.
                if not activations[layer]['x_bin'][:,i].all().item():
                    intervals = find_step_intervals(phi_in.w[i].cpu(), phi_in.b[i].cpu(), (activations[layer]['x_sum'][:, i]).min().cpu(), (activations[layer]['x_sum'][:, i]).max().cpu(), tau=10, resolution=1000)
                    new_t.append((i, tuple(intervals)))
                    if layer!=0:
                        if (layer-1, i) not in conv_rules_details and (layer-1, i) not in conv_rules:
                            conv_rules.append((layer-1, i))
            if new_t:
                s.add(tuple(new_t))
            else:
                s.add((True,))  # every literal was trivially true: the rule always holds
        refined_rules.append(s)

    conv_rules_details[layer, feat] = (find_minimal_sets(refined_rules[0]), activations[layer]['y_bin'][:,feat%ho] if feat < ho else ~activations[layer]['y_bin'][:,feat%ho])
    print(conv_rules)

In [None]:
# conv_rules_details

In [None]:
from collections import Counter

# Number of expanded features with at least one rule, per conv layer (1-based layer index).
dict(Counter(layer + 1 for (layer, _feat), (layer_rules, _mask) in conv_rules_details.items() if layer_rules))

In [None]:
# Per conv layer: for each expanded (layer, feature), the fraction of nodes on which column
# feature % width of activations[layer]['x_bin'] is active.
# NOTE(review): x_bin is the binarised *input* of `layer`, while the keys index that layer's
# *output* units (conv_rules_details[key][1] holds the output-side mask) — confirm which is meant.
acts = []
for i in range(len(activations) - 1):  # one entry per conv layer (was hard-coded to 5)
    rs = [activations[x[0]]['x_bin'][:,[x[1]%activations[x[0]]['x_bin'].shape[1]]].float().mean().item() for x in conv_rules_details.keys() if x[0] == i]
    acts.append(rs)

In [None]:
# Per-layer activation rates of the expanded features.
acts

In [None]:
# Full expansion: (layer, feature) -> (minimal rules, node mask where the literal holds).
conv_rules_details