In [None]:
"""
Hierarchical Explanation via Divisive Generation (HEDGE) AOPC Calculation for BERT Sequence Classification Models
==================================================================
This script demonstrates how to compute HEDGE attributions and AOPC (Area Over the Perturbation Curve)
for a HuggingFace Transformer model. HEDGE helps explain predictions by identifying which parts of text
are most important for a model's decision.
The script is designed to be run locally or on a hosted runtime like Colab.

INSTRUCTIONS:
-------------
1. Install requirements (see below).
2. Set your own Hugging Face model and tokenizer, and provide the path to your test data (a CSV with 'EssayText' and 'essay_score' columns).
3. Run the script!

REQUIREMENTS:
-------------
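%pip install transformers datasets lime pandas torch

This line is part of the docstring and is not executed automatically. In a notebook (e.g. Colab),
run it in its own cell before the rest of the script; `%pip` installs into the running kernel's
environment. Optional extras: matplotlib (for HEDGE.visualize_tree) and psutil (for the RAM check).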
"""
from __future__ import absolute_import, division, print_function

import argparse
import glob
import itertools
import logging
import math
import os
import random
from copy import copy, deepcopy
from itertools import combinations

import numpy as np
import pandas as pd
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import TextClassificationPipeline

Optional: Log in to your Hugging Face Hub account

In [None]:
from huggingface_hub import login

# Read the access token from the HF_TOKEN environment variable instead of pasting it
# into the notebook, where it would be saved (and shared) together with the file.
# When HF_TOKEN is unset, login() asks for the token interactively.
login(token=os.environ.get("HF_TOKEN"))

Optional: Check GPU and RAM availability

In [None]:
# -- Optional: Check GPU and RAM availability --
def print_gpu_ram_info():
    """Best-effort report of the attached GPU(s) and the machine's total RAM.

    Prints the raw ``nvidia-smi`` output when that command succeeds, and the
    total RAM reported by ``psutil`` when the package can be imported. Neither
    check is fatal: each failure prints a short notice instead.
    """
    import subprocess

    # GPU: any failure running or decoding nvidia-smi means "no usable GPU".
    try:
        smi_output = subprocess.check_output(['nvidia-smi']).decode()
    except Exception:
        print('No GPU found or not connected to a GPU.')
    else:
        print("GPU Info:\n", smi_output)

    # RAM: only a missing psutil is tolerated.
    try:
        from psutil import virtual_memory
    except ImportError:
        print('psutil not installed, skipping RAM check.')
        return
    total_gb = virtual_memory().total / 1e9
    print('Your runtime has {:.1f} GB of available RAM\n'.format(total_gb))

# Optional: comment out the next line to skip the hardware report.
print_gpu_ram_info()

User Configuration

In [None]:
# --- User Configuration ---
# Provide the name of your model (must be compatible with HuggingFace Transformers)
MODEL_NAME = "your_model"  # <-- CHANGE THIS to your model

# Provide the name of your tokenizer; by default the tokenizer stored with MODEL_NAME is used
TOKENIZER = MODEL_NAME  # <-- CHANGE THIS if not the same as your MODEL_NAME

# Path to your test CSV file (should have at least 'EssayText' and 'essay_score' columns)
TEST_CSV_PATH = "path/to/your/test_data.csv"  # <-- CHANGE THIS to your test data path

# Number of classes in your classification problem
NUM_LABELS = 2  # <-- CHANGE THIS to your number of classes

# Class names (must match your dataset)
CLASS_NAMES = [str(i) for i in range(NUM_LABELS)]  # or use your actual class names

# Random seed for reproducibility, applied to every RNG library the notebook imports
RANDOM_STATE = 0
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)  # seeds the RNG on all devices, including CUDA

# Folder to save HEDGE attribution score computations
DIR = "path/to/your/folder" # <-- CHANGE THIS to your folder

Load Model and Tokenizer

In [None]:
# --- Load Model and Tokenizer ---
# Prefer the first GPU when CUDA is available, otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Classification head sized to NUM_LABELS; tokenizer loaded from its own identifier.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

print("Using device:", device)
# Move the weights, then switch to inference mode (disables dropout).
model.to(device)
model.eval()

Load Data

In [None]:
# --- Load Data ---
# If using HuggingFace datasets (requires `from datasets import load_dataset`):
# test_set = load_dataset('csv', data_files=TEST_CSV_PATH)['train']
# test_doc = list(test_set['EssayText'])

# Or load with pandas:
test = pd.read_csv(TEST_CSV_PATH)
# The main loop reads the essay text and its gold label from these two columns,
# so fail here with a clear message rather than with a KeyError later.
missing_columns = {'EssayText', 'essay_score'} - set(test.columns)
assert not missing_columns, f"Your CSV file is missing required column(s): {sorted(missing_columns)}"
test_doc = list(test['EssayText'])
test.head()

Define Functions to Compute HEDGE Attribution Scores

In [None]:
class HEDGE:
    """Hierarchical Explanation via Divisive Generation (HEDGE) for a single input.

    Builds a top-down hierarchy over the input's tokens. Starting from the whole
    sequence, each level splits one existing span at the cut point whose two
    halves have the smallest approximate Shapley interaction (computed within a
    window of neighbouring spans). Every span in the hierarchy is then scored
    as ``2 * p - 1``, where ``p`` is the model's probability for its original
    prediction when only that span is visible.

    Feature indices used throughout are 0-based and skip the first token:
    feature ``k`` refers to ``input_ids[0][k + 1]``. The first and last tokens
    are always kept visible and never belong to a span, so the code assumes a
    single special token at each end (e.g. BERT's [CLS] ... [SEP]).
    """

    def __init__(self, model, inputs, device, max_level=-1, thre=0.3):
        """Run the model once on the full input and record its prediction.

        Args:
            model: classification model called as ``model(**inputs)``. Element
                ``[1]`` of its output is used as the logits; for HuggingFace
                models this holds when ``labels`` is passed (element 0 is the loss).
            inputs: dict of batch-size-1 tensors (``input_ids``, ``labels`` and
                whatever else the model accepts).
            device: torch device on which all masked inputs are built.
            max_level: initial value of ``self.max_level``; replaced by the
                deepest level built in ``shapley_topdown_tree``.
            thre: stored as ``self.thre``; not read by any method of this class.
        """
        self.device = device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            score = model(**inputs)[1]
        score_norm = torch.softmax(score, dim=1)

        self.pred_label = torch.argmax(score_norm, dim=1).item()
        self.max_level = max_level
        self.output = []
        self.fea_num = len(inputs['input_ids'][0]) - 2  # tokens between the two end tokens
        self.level = 0
        self.thre = thre

        # Baseline subtracted by set_contribution_func. It cancels out of every
        # score computed below: the four terms of an interaction score carry it
        # with coefficients +1 -1 -1 +1, and node scores add it back. The
        # full-input probability is therefore used instead of spending an
        # extra forward pass on a fully-masked input.
        self.bias = score_norm[0][self.pred_label].item()

    def set_contribution_func(self, model, fea_set, inputs):
        """Probability of ``self.pred_label`` with only ``fea_set`` visible, minus ``self.bias``.

        Args:
            model: the classification model.
            fea_set: iterable whose items are feature indices (int) or
                iterables of feature indices (spans).
            inputs: dict with at least ``input_ids`` and ``labels``.

        Returns:
            float: masked probability of the predicted label minus ``self.bias``.
        """
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        input_ids = inputs['input_ids'][0]

        # Start from an all-hidden sequence that keeps only the two end tokens.
        mask_input = torch.zeros(input_ids.shape, dtype=torch.long, device=self.device)
        mask_attention = torch.zeros(input_ids.shape, dtype=torch.long, device=self.device)
        mask_type = torch.zeros(input_ids.shape, dtype=torch.long, device=self.device)
        mask_input[0] = input_ids[0]
        mask_input[-1] = input_ids[-1]
        mask_attention[0] = 1
        mask_attention[-1] = 1

        # Reveal the requested features (+1 skips the first token).
        positions = []
        for fea_idx in fea_set:
            if isinstance(fea_idx, int):
                positions.append(fea_idx + 1)
            else:
                positions.extend(idx + 1 for idx in fea_idx)
        if positions:
            visible = torch.tensor(positions, dtype=torch.long, device=self.device)
            mask_input[visible] = input_ids[visible]
            mask_attention[visible] = 1

        temp = {
            'input_ids': torch.unsqueeze(mask_input, 0),
            'attention_mask': torch.unsqueeze(mask_attention, 0),
            'token_type_ids': torch.unsqueeze(mask_type, 0),
            'labels': inputs['labels']
        }

        with torch.no_grad():
            score = model(**temp)[1]
        score_norm = torch.softmax(score, dim=1)

        return score_norm[0][self.pred_label].item() - self.bias

    def _interaction_term(self, model, inputs, context, pair, left_span, right_span):
        """Second-order difference of the contribution of ``pair`` given ``context``."""
        return (self.set_contribution_func(model, context + pair, inputs)
                - self.set_contribution_func(model, context + right_span, inputs)
                - self.set_contribution_func(model, context + left_span, inputs)
                + self.set_contribution_func(model, context, inputs))

    def shapley_interaction_score_approx(self, model, input, feature_set, left, right, win_size):
        """Approximate Shapley interaction between adjacent spans ``left`` and ``right``.

        Only the spans within ``win_size`` positions on either side of the pair
        are used as context, instead of all spans, which bounds the number of
        forward passes.

        Args:
            model: the classification model.
            input: model inputs dict (see ``set_contribution_func``).
            feature_set: list of spans (each a feature index or a list of them).
            left, right: positions in ``feature_set``; must satisfy ``left + 1 == right``.
            win_size: number of neighbouring spans considered on each side.

        Returns:
            float: the interaction score, or ``-1`` if the spans are not adjacent.
        """
        if left + 1 != right:
            print("Not adjacent interaction")
            return -1

        fea_num = len(feature_set)
        curr_set_lr = [feature_set[left], feature_set[right]]
        curr_set_l = [feature_set[left]] if isinstance(feature_set[left], int) else feature_set[left]
        curr_set_r = [feature_set[right]] if isinstance(feature_set[right], int) else feature_set[right]

        # Up to win_size spans on each side; slicing clips at the sequence ends.
        left_set = feature_set[max(0, left - win_size):left]
        right_set = feature_set[right + 1:right + win_size + 1]
        adj_set = left_set + right_set

        score = 0.0
        for size in range(len(adj_set) + 1):
            weight = self.get_shapley_interaction_weight(fea_num, size)
            # combinations(adj_set, 0) yields a single empty context.
            for context in combinations(adj_set, size):
                term = self._interaction_term(model, input, list(context),
                                              curr_set_lr, curr_set_l, curr_set_r)
                score += term * weight
        return score

    def get_shapley_interaction_weight(self, d, s):
        """Weight ``s! (d - s - 2)! / (2 (d - 1)!)`` for a context of ``s`` spans out of ``d``."""
        return math.factorial(s) * math.factorial(d - s - 2) / math.factorial(d - 1) / 2

    def shapley_topdown_tree(self, model, inputs, win_size):
        """Build the span hierarchy top-down by repeatedly making the weakest split.

        Level 0 is the whole sequence. Each following level splits exactly one
        span of the previous level: over all spans and all cut points, the cut
        with the lowest approximate interaction between its two halves. This
        continues until every span holds a single feature.

        Returns:
            dict level -> list of spans, or ``-1`` when the input has no
            features. Also sets ``self.max_level`` and ``self.hier_tree``.
        """
        fea_num = self.fea_num
        if fea_num == 0:
            return -1
        fea_set = [list(range(fea_num))]
        hier_tree = {0: fea_set}
        level = 0  # stays 0 for a single feature, where no split happens
        for level in range(1, fea_num):
            min_inter_score = 1e8
            pos_opt = 0
            inter_idx_opt = 0
            for pos, subset in enumerate(fea_set):
                sen_len = len(subset)
                if sen_len == 1:
                    continue
                new_fea_set = [ele for x, ele in enumerate(fea_set) if x != pos]
                score_buff = []
                # Try every cut point inside this span, keeping the other spans fixed.
                for idx in range(1, sen_len):
                    leave_one_set = deepcopy(new_fea_set)
                    leave_one_set.insert(pos, subset[0:idx])
                    leave_one_set.insert(pos + 1, subset[idx:])
                    score_buff.append(self.shapley_interaction_score_approx(
                        model, inputs, leave_one_set, pos, pos + 1, win_size))
                inter_score = np.array(score_buff)
                min_inter_idx = np.argmin(inter_score)
                minter = inter_score[min_inter_idx]
                if minter < min_inter_score:
                    min_inter_score = minter
                    inter_idx_opt = min_inter_idx
                    pos_opt = pos

            # Apply the best split found at this level.
            new_fea_set = [ele for x, ele in enumerate(fea_set) if x != pos_opt]
            subset = fea_set[pos_opt]
            new_fea_set.insert(pos_opt, subset[0:inter_idx_opt + 1])
            new_fea_set.insert(pos_opt + 1, subset[inter_idx_opt + 1:])
            fea_set = new_fea_set
            hier_tree[level] = fea_set
        self.max_level = level
        self.hier_tree = hier_tree
        return hier_tree

    def compute_shapley_hier_tree(self, model, inputs, win_size):
        """Build the hierarchy and score every span in it.

        Each span is scored as ``2 * p - 1``, where ``p`` is the probability of
        ``self.pred_label`` when only that span is visible.

        Returns:
            dict level -> list of (span, score), also stored in
            ``self.hier_tree``. For an input without features, ``{0: []}``.
        """
        hier_tree = self.shapley_topdown_tree(model, inputs, win_size)
        if not isinstance(hier_tree, dict):
            # Only the end tokens are present: there is nothing to explain.
            self.max_level = 0
            self.hier_tree = {0: []}
            return self.hier_tree

        # A span usually survives unchanged across many levels; score it once.
        span_scores = {}
        self.hier_tree = {}
        for level in range(self.max_level + 1):
            self.hier_tree[level] = []
            for subset in hier_tree[level]:
                key = tuple(subset)
                if key not in span_scores:
                    contribution = self.set_contribution_func(model, subset, inputs)
                    span_scores[key] = 2 * (contribution + self.bias) - 1
                self.hier_tree[level].append((subset, span_scores[key]))
        return self.hier_tree

    def get_importance_phrase(self, num=-1):
        """Return the spans from levels >= 1 ranked by score, highest first.

        A span is skipped when it has the same features as the previously kept
        span; this drops consecutive copies of a span that persists across levels.

        Args:
            num: maximum number of spans to return; ``-1`` means up to 10000.

        Returns:
            (phrase_list, score_list): spans (lists of feature indices) and their scores.
        """
        hier_list = [(fea_set, score)
                     for level in range(1, self.max_level + 1)
                     for fea_set, score in self.hier_tree[level]]
        hier_list = sorted(hier_list, key=lambda item: item[1], reverse=True)
        if num == -1:
            num = 10000
        phrase_list = []
        score_list = []
        pre_items = []
        for items, score in hier_list:
            if len(phrase_list) == num:
                break
            if set(items) != set(pre_items):
                phrase_list.append(items)
                score_list.append(score)
                pre_items = items
        return phrase_list, score_list

    def collect_unsplit_items(self):
        """Return (span, score) pairs none of whose features appear in the next level.

        The result is sorted by score, highest first.
        """
        words_list = []
        for level in range(self.max_level):
            # Features present anywhere in the next level (loop-invariant per level).
            next_level_items = set()
            for item in self.hier_tree[level + 1]:
                next_level_items.update(item[0])
            for fea in self.hier_tree[level]:
                if not set(fea[0]) & next_level_items:
                    words_list.append((fea[0], fea[1]))
        return sorted(words_list, key=lambda item: item[1], reverse=True)

    def complete_hier_tree(self):
        """Return a copy of the scored hierarchy in which each level also contains
        the spans from ``collect_unsplit_items`` that do not overlap that level."""
        hier_tree = {}
        word_list = self.collect_unsplit_items()

        for level in range(self.max_level + 1):
            hier_tree[level] = []
            ele_list = []
            for subset in self.hier_tree[level]:
                hier_tree[level].append(subset)
                ele_list += subset[0]
            for ele in word_list:
                if len(set(ele_list).intersection(set(ele[0]))) == 0:
                    hier_tree[level].append((ele[0], ele[1]))
        return hier_tree

    def get_last_level_phrases(self, inputs):
        """Token ids of each span at the deepest level, spans ordered by score (highest first)."""
        # .cpu() is required before .numpy() when the ids live on a GPU.
        text = inputs['input_ids'][0].detach().cpu().numpy()
        hier_tree = self.complete_hier_tree()
        last_level = hier_tree[self.max_level]
        ordered_list = sorted(last_level, key=lambda item: item[1], reverse=True)
        return [[text[idx + 1] for idx in idx_set[0]] for idx_set in ordered_list]

    def visualize_tree(self, batch, tokenizer, fontsize=10, tag=''):
        """Draw the scored hierarchy as horizontal bars and save it as a PNG.

        Each level is one row; each span is a bar coloured by its score on a
        fixed [-1, 1] colour scale and labelled with its tokens. The figure is
        saved to ``visualization_sentence_{tag}.png`` in the working directory.
        """
        import matplotlib as mpl
        import matplotlib.pyplot as plt

        # .cpu() is required before .numpy() when the ids live on a GPU.
        text = batch['input_ids'][0].detach().cpu().numpy()
        levels = self.max_level
        cnorm = mpl.colors.Normalize(vmin=-1, vmax=1, clip=False)
        # The colour map is reversed when the predicted label is not 1
        # (1 stands for positive, 0 for negative).
        cmap_name = 'RdYlBu' if self.pred_label == 1 else 'RdYlBu_r'
        cmapper = mpl.cm.ScalarMappable(norm=cnorm, cmap=cmap_name)

        fig, ax = plt.subplots(figsize=(12, 7))
        ax.xaxis.set_visible(False)
        ax.set_yticks(list(range(0, levels + 1)))
        ax.set_yticklabels(['Level ' + str(idx) for idx in range(levels + 1)])
        ax.set_ylim(levels + 0.5, 0 - 0.5)
        sep_len = 0.3
        for key in range(levels + 1):
            for fea in self.hier_tree[key]:
                len_fea = len(fea[0])
                start_fea = fea[0] if type(fea[0]) == int else fea[0][0]
                start = sep_len * start_fea + start_fea + 0.5
                width = len_fea + sep_len * (len_fea - 1)
                fea_color = cmapper.to_rgba(fea[1])
                r, g, b, _ = fea_color
                ax.barh(key, width=width, height=0.5, left=start, color=fea_color)
                text_color = 'white' if r * g * b < 0.3 else 'black'
                for i, idx in enumerate(fea[0]):
                    word_pos = start + sep_len * i + i + 0.5
                    # +1 skips the leading [CLS] token. int() makes the tokenizer treat
                    # the numpy id as a single id; convert_ids_to_tokens exists on both
                    # slow and fast tokenizers, unlike the slow-only ids_to_tokens.
                    word_str = tokenizer.convert_ids_to_tokens(int(text[idx + 1]))
                    ax.text(word_pos, key, word_str, ha='center', va='center',
                            color=text_color, fontsize=fontsize)
        fig.colorbar(cmapper, ax=ax)
        fig.savefig('visualization_sentence_{}.png'.format(tag))

Define Functions to Compute AOPC

In [None]:
def compute_aopc_top_20_percent(hedge, model, inputs, word_list, score_list, tokenizer, device):
    """AOPC of HEDGE word attributions when the top-20% words are deleted cumulatively.

    Only single-token phrases from ``word_list`` are ranked. The top 20% of
    them (at least one) are deleted from the input ids step by step: step
    ``k`` removes the ``k`` highest-scoring tokens. AOPC is the mean, over
    those steps, of the drop in the probability of ``hedge.pred_label``.

    Args:
        hedge: a fitted HEDGE instance (only ``pred_label`` is read).
        model: classification model whose output has a ``.logits`` attribute.
        inputs: tokenizer output dict of batch-size-1 tensors plus ``labels``.
        word_list: phrases as lists of feature indices (feature ``k`` is the
            token at position ``k + 1``, after [CLS]), e.g. from
            ``HEDGE.get_importance_phrase``.
        score_list: importance score of each phrase in ``word_list``.
        tokenizer: tokenizer used to render token strings for the report.
        device: torch device for model inputs.

    Returns:
        (aopc, top_tokens_scores): ``aopc`` is a float (``nan`` when there is
        no single-token phrase to delete); ``top_tokens_scores`` lists
        ``(token, score rounded to 3 decimals, phrase)`` for every
        single-token phrase, highest score first.
    """
    inputs = {key: val.to(device) for key, val in inputs.items()}
    input_ids = inputs['input_ids']
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    def _confidence(model_inputs):
        # Probability of the label HEDGE explained, without building an autograd graph.
        with torch.no_grad():
            logits = model(**model_inputs).logits
        return torch.softmax(logits, dim=1)[0][hedge.pred_label]

    original_confidence = _confidence(inputs)

    # Rank the single-token phrases (leaves of the hierarchy) by score.
    word_scores = [(phrase, score) for phrase, score in zip(word_list, score_list) if len(phrase) == 1]
    word_scores = sorted(word_scores, key=lambda item: item[1], reverse=True)
    top_tokens_scores = [(tokens[phrase[0] + 1], round(score, 3), phrase) for phrase, score in word_scores]

    num_tokens = sum(len(phrase) for phrase, _ in word_scores)
    top_20_percent_count = max(1, int(0.2 * num_tokens))  # always delete at least one token
    top_tokens = []
    for phrase, _ in word_scores:
        top_tokens.extend(phrase)
        if len(top_tokens) >= top_20_percent_count:
            break

    if not top_tokens:
        # Nothing to delete, so the perturbation curve is empty and AOPC is undefined.
        return float('nan'), top_tokens_scores

    seq_len = input_ids.shape[1]
    total_score_drop = 0.0
    for k in range(1, len(top_tokens) + 1):
        # Delete the k most important tokens directly from the ids. The +1 maps a
        # feature index to its token position after [CLS]. Working on ids avoids
        # re-tokenizing joined word pieces, which would not reproduce the input.
        removed = {idx + 1 for idx in top_tokens[:k]}
        keep = torch.tensor([pos for pos in range(seq_len) if pos not in removed],
                            dtype=torch.long, device=input_ids.device)
        perturbed_input = {key: inputs[key][:, keep]
                           for key in ('input_ids', 'attention_mask', 'token_type_ids')
                           if key in inputs}
        perturbed_input['labels'] = inputs['labels']
        total_score_drop += (original_confidence - _confidence(perturbed_input)).item()

    return total_score_drop / len(top_tokens), top_tokens_scores

Main Loop: Calculate AOPC for Test Set

In [None]:
# --- Explain each test essay with HEDGE and score the explanation with AOPC ---
WIN_SIZE = 2  # neighbouring spans per side used by HEDGE's interaction approximation
dir_ = DIR
os.makedirs(dir_, exist_ok=True)

aopc_scores = []
essays_and_scores = zip(test['EssayText'].values, test['essay_score'].values)
for i, (ex, sc) in enumerate(essays_and_scores):
    # Truncate to the model's maximum input length so long essays do not overflow the model.
    inputs = tokenizer(ex, return_tensors='pt', truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    inputs['labels'] = torch.tensor([int(sc)]).to(device)

    hedge = HEDGE(model, inputs, device)
    print(ex)
    print('processing ---------------------', i)
    hedge.compute_shapley_hier_tree(model, inputs, WIN_SIZE)
    word_list, score_list = hedge.get_importance_phrase()
    # Notebook-level list of word pieces for the current essay, defined before the
    # AOPC call so any code reading it from the global namespace sees this essay.
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    aopc_hedge, attribution_score_list = compute_aopc_top_20_percent(hedge, model, inputs, word_list, score_list, tokenizer, device)
    aopc_scores.append(aopc_hedge)

    print(f"HEDGE attribution scores: {attribution_score_list}")
    print(f"AOPC using HEDGE scores: {aopc_hedge}")

    # os.path.join adds the path separator, so each result file lands inside DIR.
    with open(os.path.join(dir_, str(i) + '.txt'), 'w') as t:
        t.write(str(aopc_hedge))
        t.write('\n')
        t.write(str(attribution_score_list))
    print(i)

# nanmean skips essays for which no AOPC could be computed.
if aopc_scores:
    print(f"Mean AOPC over {len(aopc_scores)} essays: {np.nanmean(aopc_scores)}")