In [1]:
from datasets import load_dataset
from token_shap import TokenSHAP
from nltk.corpus import words
from termcolor import colored
import random
import nltk

In [2]:
# Download the NLTK 'words' corpus; inject_random_words draws its random
# words from it via words.words().
nltk.download('words')

[nltk_data] Downloading package words to /home/ronig/nltk_data...
[nltk_data]   Package words is already up-to-date!


True

In [3]:
def inject_random_words(prompts, injection_rate=(0.2, 0.3), word_list=None, rng=None):
    """Insert random distractor words into each prompt.

    Args:
        prompts: Iterable of prompt strings. Each prompt is tokenised with
            ``str.split()`` and re-joined with single spaces, so runs of
            whitespace are normalised.
        injection_rate: ``(low, high)`` bounds. Per prompt, a fraction is drawn
            uniformly from this range and multiplied by the prompt's word
            count (truncated to an int) to get the number of words to inject.
        word_list: Optional sequence of candidate words. Defaults to the NLTK
            ``words`` corpus (``words.words()``).
        rng: Optional ``random.Random`` instance used for every random draw,
            which makes the injection reproducible. Defaults to the
            module-level ``random`` generator.

    Returns:
        A tuple ``(injected_prompts, dict_injected)``. ``injected_prompts`` is
        a list parallel to ``prompts``. ``dict_injected`` maps each original
        prompt to the list of words injected into it, in the order they were
        drawn (right-most insertion position first). If the same prompt string
        occurs more than once, the injected words of all occurrences are
        accumulated under that key.
    """
    if word_list is None:
        word_list = words.words()
    if rng is None:
        rng = random

    injected_prompts = []
    dict_injected = {}
    for prompt in prompts:
        words_in_prompt = prompt.split()
        # A prompt with k words has k + 1 distinct insertion positions.
        n_slots = len(words_in_prompt) + 1
        num_injections = int(len(words_in_prompt) * rng.uniform(*injection_rate))
        # Clamp so sample() never sees a negative count or more than n_slots.
        num_injections = max(0, min(num_injections, n_slots))
        injection_indices = rng.sample(range(n_slots), num_injections)
        random_words = []
        # Insert from the right so earlier insertions do not shift the
        # positions that are still pending.
        for index in sorted(injection_indices, reverse=True):
            random_word = rng.choice(word_list)
            words_in_prompt.insert(index, random_word)
            random_words.append(random_word)
        injected_prompts.append(' '.join(words_in_prompt))
        # extend() rather than assign: duplicate prompts must not overwrite
        # each other's record of injected words.
        dict_injected.setdefault(prompt, []).extend(random_words)
    return injected_prompts, dict_injected

def color_injected_words(original_prompts, injected_prompts, n):
    """Print ``n`` randomly chosen injected prompts with injected words in red.

    Args:
        original_prompts: List of the original prompt strings.
        injected_prompts: List parallel to ``original_prompts`` holding the
            prompts after word injection.
        n: Number of prompts to print. Indices are drawn independently, so the
            same prompt can be printed more than once.

    Injected words are found by walking both word sequences in order: a word
    of the injected prompt that matches the next not-yet-consumed original
    word is treated as original, and everything else is coloured. Unlike a
    set-membership test, this still highlights an injected word that happens
    to occur elsewhere in the original prompt. When an injected word equals
    the original word right after it, the later of the two identical words is
    the one coloured.
    """
    if not original_prompts:
        # Nothing to sample from; random.randint(0, -1) would raise.
        return

    for _ in range(n):
        idx = random.randint(0, len(original_prompts) - 1)
        original_words = original_prompts[idx].split()
        injected_words = injected_prompts[idx].split()

        colored_prompt = []
        next_original = 0  # position of the next original word still to be matched
        for word in injected_words:
            if next_original < len(original_words) and word == original_words[next_original]:
                colored_prompt.append(word)
                next_original += 1
            else:
                colored_prompt.append(colored(word, 'red'))

        print(' '.join(colored_prompt))

In [4]:
# Load the "tatsu-lab/alpaca" dataset; prompts are sampled from the
# 'instruction' column of its 'train' split.
ds = load_dataset("tatsu-lab/alpaca")

In [5]:
# Seed the global RNG so the sampled prompts (and any later draws from the
# module-level generator) are reproducible across kernel restarts. Without
# this, the on-disk SHAP cache is keyed by prompts that change on every run.
random.seed(42)
# random.sample requires a sequence, so materialise the column as a list.
prompts = random.sample(list(ds['train']['instruction']), 100)

In [6]:
# Inject random words into every sampled prompt. injected_prompts is parallel
# to prompts; dict_injected maps each original prompt to its injected words.
injected_prompts, dict_injected = inject_random_words(prompts)

In [7]:
# Sanity check: print 10 randomly chosen injected prompts with the injected
# words highlighted in red.
color_injected_words(prompts, injected_prompts, 10)

[31mbinotic[0m Generate a list of five popular streaming subscription services. [31mfusteric[0m
How many [31mspirometric[0m countries make up the European Union?
[31mplayfulness[0m Identify the type of the function y = x^2 [31mfranticly[0m + 3
What is the legal [31mtripsill[0m principle behind copyright [31mbiting[0m law?
Explain how infectious disease [31mGraptoloidea[0m spreads
Output the name of [31mcaliduct[0m the day [31mpatriarchy[0m of the week [31mJennifer[0m for [31mretroiridian[0m a given date in MM/DD/YYYY format.
[31mdacryocystotomy[0m List 5 popular dishes in US.
Give me [31mmedically[0m a sentence to [31mheyday[0m describe the feeling of joy.
Suggest a [31mcartographic[0m suitable input to the following instruction.
Generate a question [31munhastened[0m about the immune system


In [8]:
import json
import os

# Build the TokenSHAP explainer for the chosen model and tokenizer
model_name = "llama3"
tshap = TokenSHAP(
    model_name,
    tokenizer_path="NousResearch/Hermes-2-Theta-Llama-3-8B",
)

# JSON file used as the on-disk cache of SHAP values
save_path = "shap_values.json"

# Start from an empty cache and replace it with the file contents when a
# previous run left one behind
original_shap_values = {}
if os.path.exists(save_path):
    with open(save_path, 'r') as f:
        original_shap_values = json.load(f)

def save_shap_values(shap_values, save_path):
    """Write ``shap_values`` to ``save_path`` as JSON, atomically.

    The data is first dumped to a sibling temporary file and then moved over
    ``save_path`` with ``os.replace``. If serialisation fails or the kernel is
    interrupted mid-write, the previous checkpoint is left intact instead of
    being truncated.

    Args:
        shap_values: JSON-serialisable object to persist.
        save_path: Destination file path.

    Raises:
        TypeError: If ``shap_values`` is not JSON-serialisable (propagated
            from ``json.dump``).
        OSError: If the file cannot be written or replaced.
    """
    tmp_path = f"{save_path}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(shap_values, f)
        os.replace(tmp_path, save_path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up the partial file left behind by a failed dump.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# Compute SHAP values for every original prompt, checkpointing after each one.
# Prompts already present in the cache loaded from disk are skipped, so a run
# that dies part-way (e.g. on a network error) resumes instead of starting over.
for prompt in prompts:
    if prompt in original_shap_values:
        continue
    print(prompt)
    # The return value is not needed; the per-token values are read from
    # tshap.shapley_values after the call.
    tshap.analyze(prompt, sampling_ratio=0.2, splitter=' ')
    original_shap_values[prompt] = tshap.shapley_values
    save_shap_values(original_shap_values, save_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Identify the landmark built in 1060 near Athens.
Generate an animated gif with an astronaut sailing in a spaceship
Create a mnemonic device to remember the following words
Create a binary classification query which determines whether a given article talks about Covid-19.
Suggest a suitable input to the following instruction.
Compare the terms 'sublimation' and 'deposition'.
Generate a list of five popular streaming subscription services.
Deleted the second-to-last sentence of this paragraph.
Construct a for loop to count from 1 to 10.
Name two endangered species of plants and two endangered species of animals.
Write a short story about a computer that can predict the future.
Given an article, identify the main author's point of view.
Given a suitable input, generate a poem that captures the emotion of happiness.
Organize a list of tasks in chronological order.
Paraphrase the following sentence to emphasize the main idea.
Which chess piece moves in an "L" shape?
Cite three references fo

ChunkedEncodingError: Response ended prematurely

In [None]:
%%time
# Initialize TokenSHAP with your model & tokenizer
model_name = "llama3"
tshap = TokenSHAP(model_name, tokenizer_path="NousResearch/Hermes-2-Theta-Llama-3-8B")

# Checkpoint file for the injected-prompt SHAP values, mirroring the cache used
# for the original prompts so an interrupted run can be resumed.
injected_save_path = "injected_shap_values.json"
if os.path.exists(injected_save_path):
    with open(injected_save_path, 'r') as f:
        cached_injected = json.load(f)
else:
    cached_injected = {}

# Keep only cache entries that belong to the current injected prompts; entries
# from an earlier, different injection would otherwise leak into the analysis.
current_injected = set(injected_prompts)
injected_shap_values = {p: v for p, v in cached_injected.items() if p in current_injected}

for prompt in injected_prompts:
    if prompt in injected_shap_values:
        continue
    print(prompt)
    # The return value is not needed; the per-token values are read from
    # tshap.shapley_values after the call.
    tshap.analyze(prompt, sampling_ratio=0.2, splitter=' ')
    injected_shap_values[prompt] = tshap.shapley_values
    save_shap_values(injected_shap_values, injected_save_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


subreason Describe the different flavors of the following ice cream: tutor


In [None]:
from collections import defaultdict
import numpy as np

# Pool every SHAP value observed for a word, across the original and the
# injected prompts, then reduce each pool to its mean.
all_words = defaultdict(list)
for shap_by_prompt in (original_shap_values, injected_shap_values):
    for prompt_dict in shap_by_prompt.values():
        for word, value in prompt_dict.items():
            all_words[word].append(value)

word_shap = {}
for word, values in all_words.items():
    word_shap[word] = np.mean(values)
word_shap

In [None]:
import pandas as pd
# Total number of occurrences of each word over original + injected prompts.
word_freq = defaultdict(int)
for prompt in prompts + injected_prompts:
    for word in prompt.split():
        word_freq[word] += 1

# Ground-truth set of every word that was injected into any prompt.
injected_words = set(word for injected in dict_injected.values() for word in injected)

# Document frequency: in how many prompts of each collection a word appears.
# Built in one pass per collection instead of re-splitting every prompt for
# every word.
injected_doc_freq = defaultdict(int)
for prompt in injected_prompts:
    for word in set(prompt.split()):
        injected_doc_freq[word] += 1

original_doc_freq = defaultdict(int)
for prompt in prompts:
    for word in set(prompt.split()):
        original_doc_freq[word] += 1

# "correlation" here is the share of injected prompts containing the word
# minus the share of original prompts containing it (not a statistical
# correlation coefficient).
word_correlation = {
    word: (injected_doc_freq.get(word, 0) / len(injected_prompts))
          - (original_doc_freq.get(word, 0) / len(prompts))
    for word in word_shap
}

results = pd.DataFrame({
    'word': list(word_shap.keys()),
    'shap_value': list(word_shap.values()),
    'correlation': [word_correlation.get(word, 0) for word in word_shap.keys()],
    'frequency': [word_freq.get(word, 0) for word in word_shap.keys()],
    'is_injected': [word in injected_words for word in word_shap.keys()]
})

results

In [None]:
# Words whose mean SHAP value is below this threshold count as "low importance".
# Defined in this cell because it is the first one to use it; relying on a
# later cell's definition fails on a fresh kernel.
low_importance_threshold = 0.1

low_importance_words = set(results.loc[results['shap_value'] < low_importance_threshold, 'word'])
injected_words = set(word for injected in dict_injected.values() for word in injected)

low_importance_injected = low_importance_words.intersection(injected_words)
print(f"Number of low importance words that are also injected: {len(low_importance_injected)}")
if injected_words:
    print(f"Percentage of injected words that are low importance: {len(low_importance_injected) / len(injected_words) * 100:.2f}%")
else:
    # Avoid dividing by zero when nothing was injected.
    print("Percentage of injected words that are low importance: n/a (no injected words)")

# Cast the boolean flag to float explicitly so corrcoef receives two numeric series.
correlation_matrix = np.corrcoef(results['shap_value'], results['is_injected'].astype(float))
print(f"Correlation coefficient between SHAP values and being an injected word: {correlation_matrix[0, 1]:.4f}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Rows of `results` whose mean SHAP value falls below the threshold.
low_importance_threshold = 0.1
low_importance_words = results[results['shap_value'] < low_importance_threshold]

# Scatter: SHAP value vs. the injected-minus-original prompt share, with
# marker size and colour both encoding word frequency.
plt.figure(figsize=(12, 8))
scatter = plt.scatter(low_importance_words['shap_value'],
                      low_importance_words['correlation'],
                      c=low_importance_words['frequency'],
                      cmap='viridis',
                      s=low_importance_words['frequency'],
                      alpha=0.6)
plt.colorbar(scatter, label='Frequency')
plt.xlabel('SHAP Value')
plt.ylabel('Correlation with Injected Words')
plt.title('Low Importance Words: SHAP Value vs. Correlation with Injected Words')

# Label the ten points with the highest 'correlation'.
for _, row in low_importance_words.nlargest(10, 'correlation').iterrows():
    plt.annotate(row['word'], (row['shap_value'], row['correlation']))

plt.tight_layout()
plt.show()

# The heatmap below needs the 20 low-importance words with the highest
# 'correlation' (as its title states); this frame was previously never defined.
top_low_importance = low_importance_words.nlargest(20, 'correlation')

heatmap_data = top_low_importance[['shap_value', 'correlation', 'frequency']].astype(float)
heatmap_data['is_injected'] = top_low_importance['is_injected'].astype(int)

plt.figure(figsize=(12, 8))
sns.heatmap(heatmap_data.set_index(top_low_importance['word']),
            annot=True, cmap='YlOrRd', fmt='.2f')
plt.title('Top 20 Low Importance Words by Correlation with Injected Words')
plt.tight_layout()
plt.show()