In [1]:
import os
os.chdir('..')

In [2]:
from torch import nn
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from utils.syn_dataset import SynGraphDataset
from utils.spmotif_dataset import *
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool
from utils.utils import *
from sklearn.model_selection import train_test_split
import shutil
import glob
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pandas as pd
import argparse
import pickle
import json
import io
from models.model import GIN, GINTELL
import sys
from train_logic import test_epoch

import torch
import io
import tqdm


def get_best_baseline_path(dataset_name):
    l = glob.glob(f'results/{dataset_name}/*/results.json')
    fl = [json.load(open(f)) for f in l]
    df = pd.DataFrame(fl)
    if df.shape[0] == 0: return None
    df['fname'] = l
    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True,False,False])
    df = df[df.fname.str.contains('nogumbel=False')]
    fname = df.iloc[-1]['fname']
    fname = fname.replace('/results.json', '')
    return fname

def get_best_path(dataset_name):
    l = glob.glob(f'results_logic/{dataset_name}/*/*/results.json')
    fl = [json.load(open(f)) for f in l]
    df = pd.DataFrame(fl)
    if df.shape[0] == 0: return None
    df['fname'] = l
    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True,False,False])
    df = df[df.fname.str.contains('nogumbel=False')]
    print(df.tail())
    fname = df.iloc[-1]['fname']
    fname = fname.replace('/results.json', '')
    return fname

import torch

def inverse_sigmoid(x):
    """Computes the inverse of the sigmoid function (logit function)."""
    return torch.log(x / (1 - x))

torch.no_grad()
def find_logic_rules(w, t_in, t_out, activations=None, max_rule_len=10, max_rules=100, min_support=5):
    w = w.clone()
    t_in = t_in.clone()
    t_out = t_out.clone()
    t_out = t_out.item()
    ordering_scores = w
    sorted_idxs = torch.argsort(ordering_scores, 0, descending=True)
    mask = w > 1e-5
    if activations is not None:
        mask = mask & (activations.sum(0) >= min_support)
    total_result = set()

    # Filter and sort indices based on the mask
    idxs_to_visit = sorted_idxs[mask[sorted_idxs]]
    if idxs_to_visit.numel() == 0:
        return total_result

    # Sort weights based on the filtered indices
    sorted_weights = w[idxs_to_visit]
    current_combination = []
    result = set()

    def find_logic_rules_recursive(index, current_sum):
        # Stop if the maximum number of rules has been reached
        if len(result) >= max_rules:
            return

        if len(current_combination) > max_rule_len:
            return

        # Check if the current combination satisfies the condition
        if current_sum >= t_out:
            c = idxs_to_visit[current_combination].cpu().detach().tolist()
            c = tuple(sorted(c))
            result.add(c)
            return

        # Prune if remaining weights can't satisfy t_out
        remaining_max_sum = current_sum + sorted_weights[index:].sum()
        if remaining_max_sum < t_out:
            return

        # Explore further combinations
        for i in range(index, idxs_to_visit.shape[0]):
            # Prune based on activations if provided
            if activations is not None and len(current_combination) > 0 and activations[:, idxs_to_visit[current_combination + [i]]].all(-1).sum().item() < min_support:
                continue

            current_combination.append(i)
            find_logic_rules_recursive(i + 1, current_sum + sorted_weights[i])
            current_combination.pop()

    # Start the recursive process
    find_logic_rules_recursive(0, 0)
    return result


def extract_rules(self, feature=None, activations=None, max_rule_len=float('inf'), max_rules=5, min_support=10, out_threshold=0.5):
    ws = self.weight
    t_in = self.phi_in.t
    t_out = -self.b + inverse_sigmoid(torch.tensor(out_threshold))

    rules = []
    if feature is None:
        features = range(self.out_features)
    else:
        features = [feature]
    for i in features:
        w = ws[i].to('cpu')
        ti = t_in.to('cpu')
        to = t_out[i].to('cpu')
        rules.append(find_logic_rules(w, ti, to, activations, max_rule_len, max_rules, min_support))

    return rules


def hoyer_sparsity_loss(weights, lambda_=1.0, epsilon=1e-12):
    """
    Hoyer's sparsity loss to promote sparsity.
    
    Args:
        weights (torch.Tensor): The weights to regularize.
        lambda_ (float): Regularization strength.
        epsilon (float): Small value to prevent division by zero.
    
    Returns:
        torch.Tensor: The Hoyer's sparsity loss.
    """
    l1_norm = torch.sum(torch.abs(weights), -1)
    l2_norm = torch.sqrt(torch.sum(weights**2, -1) + epsilon)
    hoyer = (torch.sqrt(torch.tensor(weights.numel())) - l1_norm / l2_norm) / \
            (torch.sqrt(torch.tensor(weights.numel())) - 1 + epsilon)
    loss = lambda_ * (1 - hoyer)
    return loss.mean()


def train_sparsity_epoch(model_tell, loader, device, optimizer, num_classes, conv_reg=1, fc_reg=1):
    model_tell.train()
    
    total_loss = 0
    total_correct = 0
    
    for data in loader:
        try:
            loss = 0
            if data.x is None:
                data.x = torch.ones((data.num_nodes, model_tell.num_features))
            if data.y.numel() == 0: continue
            if data.x.isnan().any(): continue
            if data.y.isnan().any(): continue
            y = data.y.reshape(-1).to(device).long()
            optimizer.zero_grad()

            model_tell.fc.phi_in.tau = 10
            out = model_tell(data.x.float().to(device), data.edge_index.to(device), data.batch.to(device))       
            pred = out.argmax(-1)
            loss += F.binary_cross_entropy(out.reshape(-1), torch.nn.functional.one_hot(y, num_classes=num_classes).float().reshape(-1)) + F.nll_loss(F.log_softmax(out, dim=-1), y.long())
            # tells = [c.nn[0] for c in model_tell.convs] + [model_tell.fc]
            for conv in model_tell.convs:
                loss += conv_reg*(torch.sqrt(torch.clamp(conv.nn[0].weight, min=1e-5)).sum(-1).mean()+ conv.nn[0].phi_in.entropy)
                #oss += (hoyer_sparsity_loss(torch.clamp(conv.nn[0].weight, min=1e-5)) + conv.nn[0].reg_loss + conv.nn[0].phi_in.entropy) + conv_reg*(torch.sqrt(torch.clamp(conv.nn[0].weight, min=1e-5)).sum(-1).mean()+ conv.nn[0].phi_in.entropy)
                # loss += conv_reg*(torch.sqrt(torch.clamp(conv.nn[0].weight, min=1e-5)).sum(-1).mean()+ conv.nn[0].phi_in.entropy)
            # loss += fc_reg*(torch.sqrt(torch.clamp(model_tell.fc.weight, min=1e-5)).sum(-1).mean() + model_tell.fc.phi_in.entropy)
            loss += (hoyer_sparsity_loss(torch.clamp(model_tell.fc.weight, min=1e-5)) + model_tell.fc.reg_loss + model_tell.fc.phi_in.entropy) + fc_reg*(torch.sqrt(torch.clamp(model_tell.fc.weight, min=1e-5)).sum(-1).mean() + model_tell.fc.phi_in.entropy)
            # loss += fc_reg*(torch.sqrt(torch.clamp(model_tell.fc.weight, min=1e-5)).sum(-1).mean() + model_tell.fc.phi_in.entropy)
            loss.backward()
            zero_nan_gradients(model_tell)
            optimizer.step()
            total_loss += loss.item() * data.num_graphs / len(loader.dataset)
            total_correct += pred.eq(y).sum().item() / len(loader.dataset)
        except Exception as e:
            print(e)
            pass

    return total_loss, total_correct

import torch
import torch_scatter

import torch
import torch_scatter
def scatter_sum(x, edge_index):
    # Get target nodes (i.e., the nodes receiving the messages)
    target_nodes = edge_index[1]
    
    # Perform scatter sum
    out = torch_scatter.scatter_add(x[edge_index[0]], target_nodes, dim=0, dim_size=x.size(0)) + x

    return out
@torch.no_grad()
def forward_with_activations(self, x, edge_index, batch, *args, **kwargs):
    returns = []
    x = self.input_bnorm(x)
    xs = []
    for i, conv in enumerate(self.convs):
        ret = {}
        ret['x'] = torch.hstack([x, 1-x])
        ret['x_sum'] = scatter_sum(ret['x'], edge_index)
        ret['x_bin'] = conv.nn[0].phi_in(ret['x_sum']) >= 0.5
        x = conv(torch.hstack([x, 1-x]), edge_index)
        xs.append(x)
        ret['y'] = x
        ret['y_bin'] = x>=0.5
        returns.append(ret)
    ret = {}
    x_mean = global_mean_pool(torch.hstack(xs), batch)
    x_max = global_max_pool(torch.hstack(xs), batch)
    x_sum = global_add_pool(torch.hstack(xs), batch)
    x = torch.hstack([x_mean, x_max, x_sum])
    x = self.output_bnorm(x)
    ret['x'] = torch.hstack([x, 1-x])
    ret['x_bin'] = self.fc.phi_in(ret['x']) >= 0.5
    x = self.fc(torch.hstack([x, 1-x]))
    ret['y'] = x
    ret['y_bin'] = x>=0.5
    returns.append(ret)
    return x, returns

def sigmoid(x, tau=10):
    return 1/(1+torch.exp(-tau*x))
    
def forward_tell(self, tau):
    def fw(x):

        # x = self.phi_in(torch.hstack([x, 1-x]))
        x = self.phi_in(x)
        self.max_in, _ = x.max(0)
        reg_loss = 0
        entropy_loss = 0
        if self.use_weight_sigma:
            reg_loss += torch.clamp(self.weight_s, min=1e-5).sum(-1).mean()
        else:
            reg_loss += torch.clamp(self.weight, min=1e-5).sum(-1).mean()
        if self.phi_in.entropy is not None:
            entropy_loss += self.phi_in.entropy
        # print('b', reg_loss, entropy_loss)
        self.reg_loss = reg_loss
        
        w = self.weight
        o = sigmoid(x @ w.t() + self.b, tau=tau)
        
        self.entropy_loss = entropy_loss + -(o*torch.log(o+1e-8) + (1-o)*torch.log(1-o + 1e-8)).mean()
        return o
    return fw


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def clone_model(model):
    # Save the model to an in-memory buffer
    buffer = io.BytesIO()
    torch.save(model, buffer)
    buffer.seek(0)
    cloned_model = torch.load(buffer)
    return cloned_model

In [4]:
##### import tqdm
final_res = {}

for dataset_name in ['MUTAG']:
    print(dataset_name)
    final_res[dataset_name] = []
    for seed in range(1):
        results_path = os.path.join(get_best_path(dataset_name), str(seed))
        data = pickle.load(open(os.path.join(results_path, 'data.pkl'), 'rb'))
        args = json.load(open(os.path.join(results_path, 'args.json'), 'r'))
        device = torch.device('cuda') if torch.cuda.is_available else torch.device('cpu')
        dataset = get_dataset(dataset_name)
        num_classes = dataset.num_classes
        num_features = dataset.num_features
    
        indices = list(range(len(dataset)))
        train_indices, val_test_indices = train_test_split(indices, test_size=0.2,
        shuffle=True, stratify=dataset.data.y, random_state=1)
        
        val_indices = val_test_indices[:len(val_test_indices)//2]
        test_indices = val_test_indices[len(val_test_indices)//2:]
        
        train_dataset = dataset[train_indices]
        val_dataset = dataset[val_indices]
        test_dataset = dataset[test_indices]
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    
        model_tell = torch.load(os.path.join(results_path, 'best.pt'), map_location=device)
        model_tell = model_tell.to(device)

        initial_acc = test_acc = test_epoch(model_tell, test_loader, device)
        
        best_weights = clone_model(model_tell)
        layers = [model_tell.fc, *[l.nn[0] for l in model_tell.convs[::-1]]]
        
        initial_weights = [(layer.weight>1e-4).sum().item() for layer in layers]
        for l in reversed(range(len(layers))):
            print('Layer', l)
            layer = layers[l]
            optimizer = torch.optim.Adam([layers[l].weight_sigma, layers[l].weight_exp], lr=0.005)
            val_acc = test_epoch(model_tell, val_loader, device)
            test_acc = test_epoch(model_tell, test_loader, device)
            n_w =  (layer.weight>1e-4).sum().item()
            best_situation = (val_acc, -n_w)
            print(best_situation)
            patience = max_patience = 150
            for i in tqdm.tqdm(range(1000)):
                # model_tell.fc.forward = forward_tell(model_tell.fc, 10)
                train_loss, train_acc = train_sparsity_epoch(model_tell, train_loader, device, optimizer, num_classes, conv_reg=0.1, fc_reg=0.01)
                val_acc = test_epoch(model_tell, val_loader, device)
                test_acc = test_epoch(model_tell, test_loader, device)
                n_w =  (layers[l].weight>1e-4).sum().item()
                # if (val_acc, -n_w) > best_situation:
                if (val_acc, -n_w) > best_situation or (-n_w > best_situation[1] and val_acc >= 0.999*best_situation[0]):
                    best_weights = clone_model(model_tell)
                    patience = max_patience
                    best_situation = (val_acc, -n_w)
                patience -= 1 
                if patience == 0:
                    break
            model_tell = clone_model(best_weights)
            layers = [model_tell.fc, *[l.nn[0] for l in model_tell.convs[::-1]]]
        
        final_weights = [(layer.weight>1e-4).sum().item() for layer in layers]
        final_acc = test_epoch(model_tell, test_loader, device)
        final_res[dataset_name].append([initial_acc, initial_weights, final_acc, final_weights])

MUTAG
   val_acc_mean  test_acc_mean  val_acc_std  test_acc_std  \
0           1.0       0.842105          NaN           NaN   

                                               fname  
0  results_logic/MUTAG/batch_size=32|conv_reg=0.0...  
Layer 2
(0.9473684210526315, -252)


100%|██████████| 1000/1000 [01:04<00:00, 15.40it/s]


Layer 1
(0.9473684210526315, -767)


100%|██████████| 1000/1000 [01:04<00:00, 15.39it/s]


Layer 0
(0.9473684210526315, -70)


100%|██████████| 1000/1000 [01:04<00:00, 15.55it/s]


In [5]:
final_res

{'MUTAG': [[0.8947368421052632,
   [70, 767, 252],
   0.8947368421052632,
   [26, 174, 85]]]}