# Aggregated SHAP values: understanding feature importances for finetuned BERT model

SHAP values are assigned PER TOKEN PER INSTANCE (local explanation). By aggregating SHAP values across instances, we can get insights about which features (i.e., tokens) are important for the classification task (global explanation).


Features == tokens


3 types of aggregation:

- at the token level: look at shap values of individual tokens across instances and average across occurrences of each token.

- at the ngram level: look at ngrams across instances (any value of n). The SHAP value of an ngram is the sum of the SHAP values of its constituent tokens divided by n. Then, average across occurrences of each ngram.

- at the chunk level: look at consecutive tokens that get the exact same SHAP value and treat them as a chunk whose SHAP value is that of any of its constituent tokens (since they are all the same). Then, average across occurrences of each chunk.


2 types of average:

- per class: average contribution of a token WHEN IT DOES CONTRIBUTE to a class prediction (magnitude of contribution when it contributes)

- overall: average contribution of a token across all instances where it appears, regardless of whether it contributes to the prediction or not (by averaging we can see whether on average it contributes more to pos/neg class)

In [None]:
import json
import os
import pandas as pd
import numpy as np
import pickle as pkl

# Project root: two directory levels above the current working directory.
PROJECT_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
# Directory holding the aggregated SHAP value files and the entity frequency file.
SHAP_DIR = os.path.join(PROJECT_DIR, 'classification/shap_values/coqa/')
# Sub-directory used for files whose name contains 'mask' (mask_token == [MASK]).
MASK_DIR = os.path.join(SHAP_DIR, 'masktoken_mask_default')

: 

In [None]:
# Entity frequencies, indexed as freq[entity_type][entity] by visualize_shap_results.
# Opened in a `with` block so the file handle is closed once the JSON is parsed.
with open(os.path.join(SHAP_DIR, 'bert_entity_frequency_test_empty.json'), 'r') as f:
    entity_frequency_dict = json.load(f)

# Frequencies for the mask_token == [MASK] runs are not loaded here; files whose
# name contains 'mask' cannot be visualized until this is set to a real dict.
entity_frequency_dict_mask = None

: 

In [None]:
def visualize_shap_results(shap_file, n_features=10, min_freq=20, min_chunk_size=1):
    """Plot and return the strongest aggregated SHAP values stored in ``shap_file``.

    After frequency filtering, the ``n_features`` lowest-ranked and the
    ``n_features`` highest-ranked entities are selected, drawn as a horizontal
    bar chart, and returned as a DataFrame.

    Args:
        shap_file: Name of a JSON file mapping entity -> aggregated SHAP value.
            The fourth ``_``-separated field of the name plus ``'s'`` is the
            entity-type key into the frequency dictionary (e.g.
            ``'agg-sv_bert_overall-avg_token_test.json'`` -> ``'tokens'``).
            If the name contains ``'class-avg'``, each value is read as a pair
            indexed ``[0]`` / ``[1]`` (reported as negative / positive
            contribution); otherwise each value is a single sortable number.
            If the name contains ``'mask'``, the file is read from ``MASK_DIR``
            with ``entity_frequency_dict_mask``; otherwise from ``SHAP_DIR``
            with ``entity_frequency_dict``.
        n_features: Number of entities to keep from each end of the ranking.
        min_freq: Entities whose frequency is below this value are discarded.
        min_chunk_size: Minimum number of space-separated tokens a chunk must
            have; only applied when the entity type is ``'chunks'``.

    Returns:
        A DataFrame whose ``entity`` column is suffixed with the entity's
        frequency, e.g. ``"token (42)"``. Per-class files yield the columns
        ``entity``, ``SV_neg_contribution``, ``SV_pos_contribution``; other
        files yield ``entity``, ``shap_value``, ``frequency``.

    Raises:
        ValueError: If the frequency dictionary selected for ``shap_file`` has
            not been loaded (is ``None``).
        KeyError: If an entity has no entry in the frequency dictionary.
    """
    entity_type = shap_file.split('_')[3] + 's'

    # Relevant for chunk-level SHAP values: using a different mask token leads
    # to different chunks, and values derived with mask_token == [MASK] are
    # stored in a different directory.
    if 'mask' in shap_file:
        freq_dict, load_from_dir = entity_frequency_dict_mask, MASK_DIR
    else:
        freq_dict, load_from_dir = entity_frequency_dict, SHAP_DIR

    # Fail early with a clear message instead of an opaque
    # "'NoneType' object is not subscriptable" further down.
    if freq_dict is None:
        raise ValueError(
            f"No entity frequency dictionary is loaded for '{shap_file}'; "
            "load it before visualizing these SHAP values."
        )

    with open(os.path.join(load_from_dir, shap_file), 'r') as f:
        shap_results = json.load(f)

    # Filter out chunks with length < min_chunk_size
    if entity_type == 'chunks':
        shap_results = {
            k: v for k, v in shap_results.items()
            if len(k.split(' ')) >= min_chunk_size
        }

    freqs = freq_dict[entity_type]

    if 'class-avg' in shap_file:
        # Filter out entities with frequency < min_freq, then rank the
        # survivors once per class (sorted() is stable, so filtering before
        # sorting gives the same order as sorting before filtering).
        frequent = [
            (entity, sv) for entity, sv in shap_results.items()
            if freqs[entity] >= min_freq
        ]
        by_neg = sorted(frequent, key=lambda item: item[1][0])
        by_pos = sorted(frequent, key=lambda item: item[1][1])

        top_n_neg = by_neg[:n_features]
        # A plain [-n_features:] slice returns the whole list when
        # n_features == 0, so compute the start index explicitly.
        top_n_pos = by_pos[max(len(by_pos) - n_features, 0):]

        columns = ['entity', 'SV_neg_contribution', 'SV_pos_contribution']
        neg_df = pd.DataFrame(
            [(entity, float(sv[0]), float(sv[1])) for entity, sv in top_n_neg],
            columns=columns,
        )
        pos_df = pd.DataFrame(
            [(entity, float(sv[0]), float(sv[1])) for entity, sv in top_n_pos],
            columns=columns,
        )

        # Outer merge on all shared columns: an entity selected at both ends
        # is kept only once.
        top_shap_values = pd.merge(neg_df, pos_df, how='outer')
        plot_kwargs = {
            'y': ['SV_neg_contribution', 'SV_pos_contribution'],
            'stacked': True,
        }
    else:
        ranked = sorted(shap_results.items(), key=lambda item: item[1])
        shap_df = pd.DataFrame(ranked, columns=['entity', 'shap_value'])

        # add frequency column
        shap_df['frequency'] = shap_df['entity'].apply(lambda x: freqs[x])
        # filter out entities with frequency < min_freq
        shap_df = shap_df[shap_df['frequency'] >= min_freq]

        top_n_neg = shap_df.head(n_features)
        top_n_pos = shap_df.tail(n_features)
        top_shap_values = pd.merge(top_n_neg, top_n_pos, how='outer')
        plot_kwargs = {'y': 'shap_value'}

    # add frequency info to the name of each entity
    top_shap_values['entity'] = top_shap_values['entity'].apply(
        lambda x: x + f" ({freqs[x]})"
    )

    top_shap_values.plot(x='entity', kind='barh', figsize=(20, 15), **plot_kwargs)

    return top_shap_values

: 

# Aggregations (global explanation)

## Token-level

### Overall avg

In [None]:
# Token level, overall average: 15 entities per end, minimum frequency 20.
top_tokens_overall = visualize_shap_results(
    'agg-sv_bert_overall-avg_token_test.json',
    n_features=15,
    min_freq=20,
)
# top_tokens_overall

: 

### Per class avg

In [None]:
# Token level, per-class average: 15 entities per end, minimum frequency 60.
top_tokens_class = visualize_shap_results(
    'agg-sv_bert_class-avg_token_test.json',
    n_features=15,
    min_freq=60,
)
# top_tokens_class

: 

## 3gram-level

### Overall avg

In [None]:
# 3-gram level, overall average: 15 entities per end, minimum frequency 10.
top_3grams_overall = visualize_shap_results(
    'agg-sv_bert_overall-avg_3gram_test.json',
    n_features=15,
    min_freq=10,
)

: 

### Per class avg

In [None]:
# 3-gram level, per-class average: 15 entities per end, minimum frequency 40.
top_3grams_class = visualize_shap_results(
    'agg-sv_bert_class-avg_3gram_test.json',
    n_features=15,
    min_freq=40,
)
# top_3grams_class

: 

## 5gram-level

### Overall avg

In [None]:
# 5-gram level, overall average: 15 entities per end, minimum frequency 10.
top_5grams_overall = visualize_shap_results(
    'agg-sv_bert_overall-avg_5gram_test.json',
    n_features=15,
    min_freq=10,
)

: 

### Per class avg

In [None]:
# 5-gram level, per-class average: 15 entities per end, minimum frequency 20.
top_5grams_class = visualize_shap_results(
    'agg-sv_bert_class-avg_5gram_test.json',
    n_features=15,
    min_freq=20,
)

: 

## Chunk-level

### Overall avg

In [None]:
# Chunk level, overall average: 15 entities per end, minimum frequency 30.
top_chunks_overall = visualize_shap_results(
    'agg-sv_bert_overall-avg_chunk_test.json',
    n_features=15,
    min_freq=30,
)
# top_chunks_overall

: 

### Per class avg

In [None]:
# Chunk level, per-class average: 15 entities per end, minimum frequency 40.
top_chunks_class = visualize_shap_results(
    'agg-sv_bert_class-avg_chunk_test.json',
    n_features=15,
    min_freq=40,
)
# top_chunks_class

: 

In [None]:
# Chunk level, per-class average, restricted to chunks of at least 3 tokens;
# the frequency threshold is lowered to 5.
top_chunks2_class = visualize_shap_results(
    'agg-sv_bert_class-avg_chunk_test.json',
    n_features=15,
    min_freq=5,
    min_chunk_size=3,
)
# top_chunks2_class

: 

# Instance-level (local explanation)

to do...