In [None]:
from datasets import load_dataset
from token_shap import TokenSHAP
from nltk.corpus import words
from termcolor import colored
import random
import nltk

In [None]:
# Fetch the NLTK `words` corpus only when it is not already installed locally,
# so re-running the notebook does not need network access every time.
try:
    nltk.data.find('corpora/words')
except LookupError:
    nltk.download('words')

In [None]:
def inject_random_words(prompts, injection_rate=(0.2, 0.3), word_list=None, rng=None):
    """Insert random distractor words into each prompt.

    For every prompt, a rate is drawn uniformly from ``injection_rate`` and
    multiplied by the prompt's word count (truncated to an int) to get the
    number of words to insert. The words are inserted at distinct random
    positions, including before the first and after the last word.

    Args:
        prompts: Iterable of prompt strings; each is split on whitespace.
        injection_rate: ``(low, high)`` bounds passed to ``uniform``.
        word_list: Optional sequence of candidate words to inject. When
            omitted, the NLTK ``words`` corpus (``words.words()``) is used.
        rng: Optional object exposing ``uniform``, ``sample`` and ``choice``
            (e.g. ``random.Random(seed)``) for reproducible runs. When
            omitted, the global ``random`` module is used.

    Returns:
        A tuple ``(injected_prompts, dict_injected)``:
        ``injected_prompts`` is a list parallel to ``prompts`` holding the
        space-joined prompts after injection; ``dict_injected`` maps each
        original prompt string to the list of words injected into it. If the
        same prompt string occurs more than once, the injected words of all
        its occurrences are accumulated under that key.
    """
    if word_list is None:
        word_list = words.words()
    if rng is None:
        rng = random

    injected_prompts = []
    dict_injected = {}
    for prompt in prompts:
        words_in_prompt = prompt.split()
        slot_count = len(words_in_prompt) + 1
        num_injections = int(len(words_in_prompt) * rng.uniform(*injection_rate))
        # Keep the sample size valid even for rates outside [0, 1].
        num_injections = max(0, min(num_injections, slot_count))
        injection_indices = rng.sample(range(slot_count), num_injections)

        random_words = []
        # Insert from the right so earlier insertions do not shift the
        # positions that are still pending.
        for index in sorted(injection_indices, reverse=True):
            random_word = rng.choice(word_list)
            words_in_prompt.insert(index, random_word)
            random_words.append(random_word)

        injected_prompts.append(' '.join(words_in_prompt))
        # Accumulate rather than assign: a repeated prompt must not erase the
        # words injected into its earlier occurrence.
        dict_injected.setdefault(prompt, []).extend(random_words)
    return injected_prompts, dict_injected

def color_injected_words(original_prompts, injected_prompts, n):
    """Print ``n`` randomly chosen injected prompts with injected words in red.

    Each printed prompt is picked by drawing a random index (with
    replacement, so the same prompt may be printed more than once). The
    injected prompt is walked left to right against the words of the original
    prompt at the same index: a word that matches the next not-yet-consumed
    original word is printed plainly, and every other word is treated as
    injected and wrapped with ``colored(word, 'red')``.

    Aligning in order (instead of testing membership in the set of original
    words) also highlights an injected word that happens to occur elsewhere
    in the original prompt.

    Args:
        original_prompts: Sequence of original prompt strings.
        injected_prompts: Sequence parallel to ``original_prompts`` holding
            the prompts after injection.
        n: Number of prompts to print.

    Returns:
        None. Output is written with ``print``.
    """
    for _ in range(n):
        idx = random.randint(0, len(original_prompts) - 1)
        original_words = original_prompts[idx].split()
        injected_words = injected_prompts[idx].split()

        colored_prompt = []
        next_original = 0  # position of the next original word still to be matched
        for word in injected_words:
            if next_original < len(original_words) and word == original_words[next_original]:
                colored_prompt.append(word)
                next_original += 1
            else:
                colored_prompt.append(colored(word, 'red'))

        print(' '.join(colored_prompt))

In [None]:
# Load the "tatsu-lab/alpaca" dataset; its `train` split's `instruction`
# column is the pool the prompts are sampled from.
ds = load_dataset("tatsu-lab/alpaca")

In [None]:
# Seed the global RNG so the sampled prompts (and every later draw from the
# `random` module) are the same on each Restart & Run All.
random.seed(42)

# list(...) hands random.sample a plain sequence whatever container type the
# `datasets` column access returns.
prompts = random.sample(list(ds['train']['instruction']), 100)

In [None]:
# Build the perturbed prompt set: injected_prompts is parallel to prompts and
# holds each prompt with random words inserted; dict_injected maps an original
# prompt string to the words that were inserted into it.
injected_prompts, dict_injected = inject_random_words(prompts)

In [None]:
# Spot-check the injection: print 3 randomly chosen injected prompts, with the
# words identified as injected shown in red.
color_injected_words(prompts, injected_prompts, 3)

In [None]:
%%time
# Initialize TokenSHAP with your model & tokenizer
# NOTE(review): "llama3" is presumably a model tag resolved by TokenSHAP's
# backend — confirm it is available in the environment running this notebook.
model_name = "llama3"
tshap = TokenSHAP(model_name, tokenizer_path = "NousResearch/Hermes-2-Theta-Llama-3-8B")
# Maps each ORIGINAL prompt string to the Shapley values read from `tshap`
# right after analysing it. A prompt string that occurs twice keeps only the
# values from its last analysis.
original_shap_values = {}
for prompt in prompts:
    # The return value is not read in this cell; analyze() presumably
    # refreshes tshap.shapley_values as a side effect — TODO confirm.
    results = tshap.analyze(prompt, sampling_ratio = 0.2, splitter = ' ')
    # NOTE(review): assumes shapley_values is a fresh mapping per analyze()
    # call; if the same object were reused, every entry here would alias it.
    original_shap_values[prompt] = tshap.shapley_values

In [None]:
%%time
# Initialize TokenSHAP with your model & tokenizer
# NOTE(review): "llama3" is presumably a model tag resolved by TokenSHAP's
# backend — confirm it is available in the environment running this notebook.
model_name = "llama3"
tshap = TokenSHAP(model_name, tokenizer_path = "NousResearch/Hermes-2-Theta-Llama-3-8B")
# Maps each INJECTED prompt string to the Shapley values read from `tshap`
# right after analysing it. A prompt string that occurs twice keeps only the
# values from its last analysis.
injected_shap_values = {}
for prompt in injected_prompts:
    # The return value is not read in this cell; analyze() presumably
    # refreshes tshap.shapley_values as a side effect — TODO confirm.
    results = tshap.analyze(prompt, sampling_ratio = 0.2, splitter = ' ')
    # NOTE(review): assumes shapley_values is a fresh mapping per analyze()
    # call; if the same object were reused, every entry here would alias it.
    injected_shap_values[prompt] = tshap.shapley_values

In [None]:
from collections import defaultdict
import numpy as np

# Pool the SHAP values recorded for each word across both runs: the original
# prompts are visited first, then the injected prompts.
# NOTE(review): this assumes the keys of each per-prompt mapping are the plain
# words of the prompt — confirm against TokenSHAP's output format.
all_words = defaultdict(list)
for shap_by_prompt in (original_shap_values, injected_shap_values):
    for prompt_dict in shap_by_prompt.values():
        for word, value in prompt_dict.items():
            all_words[word].append(value)

# Average the pooled values to get one mean SHAP value per word.
word_shap = dict(zip(all_words, map(np.mean, all_words.values())))
word_shap

In [None]:
import pandas as pd
# Raw occurrence counts over the original and injected prompts together
# (every occurrence counts, so a word repeated in one prompt counts each time).
word_freq = defaultdict(int)
for prompt in prompts + injected_prompts:
    for word in prompt.split():
        word_freq[word] += 1

# Every word that was inserted into any prompt.
injected_words = set([word for inserted in dict_injected.values() for word in inserted])

# Split each prompt once up front. The membership tests below then cost O(1)
# per prompt instead of re-splitting every prompt for every word.
injected_prompt_words = [set(prompt.split()) for prompt in injected_prompts]
original_prompt_words = [set(prompt.split()) for prompt in prompts]

# Despite its name, 'correlation' is a difference of prevalences: the share of
# injected prompts containing the word minus the share of original prompts
# containing it.
word_correlation = {}
for word in word_shap.keys():
    in_injected = sum(1 for prompt_words in injected_prompt_words if word in prompt_words)
    in_original = sum(1 for prompt_words in original_prompt_words if word in prompt_words)
    word_correlation[word] = (in_injected / len(injected_prompts)) - (in_original / len(prompts))

# One row per word that has a mean SHAP value.
results = pd.DataFrame({
    'word': list(word_shap.keys()),
    'shap_value': list(word_shap.values()),
    'correlation': [word_correlation.get(word, 0) for word in word_shap.keys()],
    'frequency': [word_freq.get(word, 0) for word in word_shap.keys()],
    'is_injected': [word in injected_words for word in word_shap.keys()]
})

results

In [None]:
# Words whose mean SHAP value is below this cut-off count as "low importance".
# It is defined here because this is the first cell that needs it; without it
# the cell raises NameError on a fresh kernel.
low_importance_threshold = 0.1

low_importance_words = set(results[results['shap_value'] < low_importance_threshold]['word'])
injected_words = set([word for inserted in dict_injected.values() for word in inserted])

low_importance_injected = low_importance_words.intersection(injected_words)
print(f"Number of low importance words that are also injected: {len(low_importance_injected)}")
# Avoid ZeroDivisionError when no word was injected at all (e.g. only very
# short prompts); the share is undefined in that case.
if injected_words:
    injected_low_share = len(low_importance_injected) / len(injected_words) * 100
else:
    injected_low_share = float('nan')
print(f"Percentage of injected words that are low importance: {injected_low_share:.2f}%")

# Pearson correlation between the mean SHAP value and the injected flag
# (the boolean column is cast to 0.0/1.0 explicitly).
correlation_matrix = np.corrcoef(results['shap_value'], results['is_injected'].astype(float))
print(f"Correlation coefficient between SHAP values and being an injected word: {correlation_matrix[0, 1]:.4f}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Rows of `results` whose mean SHAP value is below the cut-off.
low_importance_threshold=0.1
low_importance_words = results[results['shap_value'] < low_importance_threshold]

# Scatter plot: SHAP value against the 'correlation' column, with both marker
# colour and marker size encoding the word's frequency.
plt.figure(figsize=(12, 8))
scatter = plt.scatter(low_importance_words['shap_value'],
                      low_importance_words['correlation'],
                      c=low_importance_words['frequency'],
                      cmap='viridis',
                      s=low_importance_words['frequency'],
                      alpha=0.6)
plt.colorbar(scatter, label='Frequency')
plt.xlabel('SHAP Value')
plt.ylabel('Correlation with Injected Words')
plt.title('Low Importance Words: SHAP Value vs. Correlation with Injected Words')

# Label the 10 low-importance words with the highest 'correlation'.
for _, row in low_importance_words.nlargest(10, 'correlation').iterrows():
    plt.annotate(row['word'], (row['shap_value'], row['correlation']))

plt.tight_layout()
plt.show()

# The heatmap needs the rows to display. `top_low_importance` was never
# defined (NameError on a fresh kernel); per the heatmap title it is the 20
# low-importance words with the highest 'correlation'.
top_low_importance = low_importance_words.nlargest(20, 'correlation')

heatmap_data = top_low_importance[['shap_value', 'correlation', 'frequency']].astype(float)
heatmap_data['is_injected'] = top_low_importance['is_injected'].astype(int)

plt.figure(figsize=(12, 8))
sns.heatmap(heatmap_data.set_index(top_low_importance['word']),
            annot=True, cmap='YlOrRd', fmt='.2f')
plt.title('Top 20 Low Importance Words by Correlation with Injected Words')
plt.tight_layout()
plt.show()