# evaluate grafting effectiveness across a range of TREC contexts

In [None]:
import pickle
# Load the precomputed neural chunk dictionary from disk. Later cells index it
# by target word.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load dictionaries that come from a trusted source.
with open("./neural_chunk_dictionary/neural_chunk_dictionary.pkl", "rb") as file:
    neural_chunk_dictionary = pickle.load(file)

In [None]:
from datasets import load_dataset

# Load the TREC dataset and keep a handle on the train split's feature schema.
dataset = load_dataset("trec")
train_features = dataset["train"].features

# Print every coarse label name next to its integer id.
coarse_labels = train_features['coarse_label'].names
print("Coarse label categories:")
for idx, name in enumerate(coarse_labels):
    print(f"{idx}: {name}")

# Print every fine-grained label name next to its integer id.
fine_labels = train_features["fine_label"].names
print("\nFine-grained label categories:")
for idx, name in enumerate(fine_labels):
    print(f"{idx}: {name}")

# Label	Name	Meaning / Question Type Examples
# 0	ABBR	Abbreviation – Questions asking about acronyms or abbreviations.
# "What does HTML stand for?"
# 1	DESC	Description / Definition – Asking for definitions, explanations, or descriptions.
# "What is photosynthesis?"
# 2	ENTY	Entity – Questions asking about a thing or object (e.g., color, currency, food).
# "What is the currency of France?"
# 3	HUM	Human – Asking about people or groups.
# "Who discovered America?"
# 4	LOC	Location – Questions about places, countries, cities, etc.
# "Where is the Eiffel Tower?"
# 5	NUM	Numeric – Questions that expect a number as an answer (e.g., date, size, price).
# "How many people live in Japan?"


In [None]:
n_it = 100  # number of examples taken from the head of the train split
train_split = dataset["train"]
batched_prompts = []
batched_labels = []
for i in range(n_it):
    example = train_split[i]  # fetch the row once, then read both fields
    batched_prompts.append(example['text'])
    batched_labels.append(example['coarse_label'])  # integer coarse-label id

In [3]:
from huggingface_hub import login
import transformers
import torch
import random
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Authenticate with the Hugging Face Hub. The token is read from the HF_TOKEN
# environment variable so a real credential never has to be pasted into (and
# committed with) the notebook; the original placeholder remains the fallback.
import os

login(token=os.environ.get("HF_TOKEN", "YOURTOKENHERE"))

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch generators with one value.

    Args:
        seed: Seed passed unchanged to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only touch the CUDA generators when a CUDA device is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def load_model(model_id="meta-llama/Meta-Llama-3-8B", device="cuda"):
    """Load a causal language model in bfloat16 and its matching tokenizer.

    Args:
        model_id: Model identifier handed to both `from_pretrained` calls.
        device: Device the loaded model is moved to.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
    )
    causal_lm = causal_lm.to(device)
    matching_tokenizer = AutoTokenizer.from_pretrained(model_id)
    return causal_lm, matching_tokenizer


# Seed every RNG before the model is created.
set_seed(42)

# Request deterministic cuDNN kernels and disable cuDNN autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Prefer GPU 2 (the original choice), but fall back to GPU 0 when fewer than
# three GPUs are visible, and to the CPU when CUDA is unavailable, instead of
# failing with an invalid device ordinal.
if torch.cuda.is_available():
    preferred_gpu = 2
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Load the Llama model and the associated tokenizer.
model, tokenizer = load_model(device=device)
# Pad on the left so padding tokens are prepended and each prompt's own tokens
# sit at the end of its padded row.
tokenizer.padding_side = "left"


Loading checkpoint shards: 100%|██████████████████| 4/4 [00:00<00:00,  5.68it/s]


In [4]:
def get_model_response_batch(input_texts):
    """Sample one generation per prompt from the unmodified model.

    Args:
        input_texts: List[str] of prompts, tokenized together as one padded
            batch.

    Returns:
        List[str] with one decoded sequence per prompt, special tokens
        stripped, decoded from whatever `model.generate` returns.
    """
    assert isinstance(input_texts, list), "input_texts should be a list of strings"

    # padding=True needs a pad token; reuse EOS when the tokenizer has none.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    encoded = tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    generation_kwargs = dict(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=200,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        top_p=0.5,
        use_cache=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated = model.generate(**generation_kwargs)

    return tokenizer.batch_decode(generated, skip_special_tokens=True)


def load_dictionary_and_perturb_batch(word='cake', token_idx=6, input_texts=None):
    """Generate text while grafting a word's stored activations into layers 2-9.

    For each decoder layer 2..9 a forward pre-hook overwrites selected hidden
    dimensions of the layer input with the constant values stored for `word`
    in the global `neural_chunk_dictionary`. The overwrite is applied to the
    three sequence positions `token_idx-3 .. token_idx-1`, and only on forward
    passes whose sequence length is greater than `token_idx`.

    Args:
        word: Key into `neural_chunk_dictionary`.
        token_idx: Exclusive end of the three-position window that is
            overwritten. Negative values index from the end of the sequence
            (e.g. -1 overwrites the three positions before the last one).
            NOTE(review): for 0 <= token_idx < 3 the slice start is not
            positive, so the window may select nothing — confirm intent.
        input_texts: List[str] of prompts, processed as one padded batch.

    Returns:
        List[str] of decoded generations, one per prompt.
    """
    assert isinstance(input_texts, list), "input_texts should be a list of strings"
    model.eval()

    # Same guard as get_model_response_batch: padding=True needs a pad token,
    # so this function must not depend on the control function running first.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Tokenize batch
    input_tokens = tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    ).to(device)

    attention_mask = input_tokens["attention_mask"]

    # Hook factory: binds one layer's values/positions into a pre-hook.
    def create_modify_input_hook(constant_values=None, constant_positions=None, token_idx=3):
        def modify_input_hook(module, layer_args):
            # Work on a copy of the first positional input so the caller's
            # tensor is left untouched.
            # assumes shape [batch, seq, hidden] — TODO confirm
            modified_input = layer_args[0].clone()
            if modified_input.shape[1] > token_idx:
                modified_input[:, token_idx-3:token_idx, constant_positions] = constant_values
            return (modified_input,)
        return modify_input_hook

    # The entry is either {'layer': ...} directly or wrapped one level deeper
    # under index 0 (an extra timestep level).
    entry = neural_chunk_dictionary[word]
    if 'layer' in entry:
        layerchunk = entry['layer']
    else:
        layerchunk = entry[0]['layer']

    hooks = []
    try:
        # Register one hook per layer; the input of layer L is patched with
        # dictionary entry L-1.
        for layer_to_perturb in range(2, 10):
            layer_entry = layerchunk[layer_to_perturb - 1]
            constant_values = layer_entry['constant_values']
            constant_positions = layer_entry['constant_positions'][:, 1]

            constant_positions_tensor = torch.tensor(constant_positions, dtype=torch.int64, device=device)
            constant_values_tensor = torch.tensor(constant_values, dtype=torch.bfloat16, device=device)

            hook = model.model.layers[layer_to_perturb].register_forward_pre_hook(
                create_modify_input_hook(
                    constant_values=constant_values_tensor,
                    constant_positions=constant_positions_tensor,
                    token_idx=token_idx
                )
            )
            hooks.append(hook)

        # Generate output with perturbation.
        # NOTE(review): top_p is 0.9 here but 0.5 in get_model_response_batch;
        # confirm the sampling difference between conditions is intended.
        with torch.no_grad():
            output_after_perturb = model.generate(
                input_ids=input_tokens['input_ids'],
                attention_mask=attention_mask,
                max_length=200,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                top_p=0.9,
                use_cache=True
            )
    finally:
        # Always detach the hooks, even if registration or generation fails,
        # so later calls run on an unmodified model.
        for hook in hooks:
            hook.remove()

    # Decode batch
    return tokenizer.batch_decode(output_after_perturb, skip_special_tokens=True)




def load_dictionary_and_freeze_batch(word='cake', token_idx=6, input_texts=None):
    """Generate text while zeroing a word's chunk dimensions in layers 2-9.

    For each decoder layer 2..9 a forward pre-hook sets the hidden dimensions
    listed for `word` in the chunk dictionary to zero. The zeroing is applied
    to the three sequence positions `token_idx-3 .. token_idx-1` (not to
    `token_idx` itself), and only on forward passes whose sequence length is
    greater than `token_idx`.

    Args:
        word: Key into the chunk dictionary.
        token_idx: Exclusive end of the three-position window that is zeroed.
            Negative values index from the end of the sequence.
        input_texts: List[str] of prompts, processed as one padded batch.

    Returns:
        List[str] – generated outputs after freezing, one per prompt.
    """
    assert isinstance(input_texts, list), "input_texts must be a list of strings"
    model.eval()

    # Same guard as get_model_response_batch: padding=True needs a pad token.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Tokenize batch
    input_tokens = tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    ).to(device)

    attention_mask = input_tokens["attention_mask"]

    # Load dictionary (re-read on every call; shadows any global of this name).
    # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
    with open("./neural_chunk_dictionary/neural_chunk_dictionary.pkl", "rb") as file:
        neural_chunk_dictionary = pickle.load(file)

    # Accept both layouts that load_dictionary_and_perturb_batch accepts:
    # {'layer': ...} directly, or the same mapping nested under index 0.
    entry = neural_chunk_dictionary[word]
    if 'layer' in entry:
        layerchunk = entry['layer']
    else:
        layerchunk = entry[0]['layer']

    # Hook factory to zero out chunk dimensions
    def create_freeze_hook(constant_positions=None, token_idx=3):
        def modify_input_hook(module, layer_args):
            # Work on a copy of the first positional input.
            # assumes shape [batch, seq, hidden] — TODO confirm
            modified_input = layer_args[0].clone()
            if modified_input.shape[1] > token_idx:
                modified_input[:, token_idx-3:token_idx, constant_positions] = 0  # freeze those dims
            return (modified_input,)
        return modify_input_hook

    hooks = []
    try:
        # Register hooks for layers 2 to 9; layer L uses dictionary entry L-1.
        for layer_to_perturb in range(2, 10):
            constant_positions = layerchunk[layer_to_perturb - 1]['constant_positions'][:, 1]
            constant_positions_tensor = torch.tensor(constant_positions, dtype=torch.int64, device=device)
            hook = model.model.layers[layer_to_perturb].register_forward_pre_hook(
                create_freeze_hook(constant_positions=constant_positions_tensor, token_idx=token_idx)
            )
            hooks.append(hook)

        # Generate with freezing
        with torch.no_grad():
            output_after = model.generate(
                input_ids=input_tokens['input_ids'],
                attention_mask=attention_mask,
                max_length=200,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                top_p=0.5
            )
    finally:
        # Always detach the hooks, even if registration or generation fails.
        for hook in hooks:
            hook.remove()

    return tokenizer.batch_decode(output_after, skip_special_tokens=True)


In [None]:
import re
import pickle
import numpy as np
import pandas as pd
from datasets import load_dataset
from collections import defaultdict
import random

# === Load TREC dataset ===
dataset = load_dataset("trec", split="train")
coarse_label_names = dataset.features["coarse_label"].names
label_ids = dataset["coarse_label"]
questions = dataset["text"]

# === Build category → prompt mapping ===
# Maps each coarse label name to the list of question texts carrying it.
category_to_prompts = defaultdict(list)
for text, label_id in zip(questions, label_ids):
    category = coarse_label_names[label_id]
    category_to_prompts[category].append(text)

# === Sampling parameters ===
n_it = 5  # number of prompts per category
categories = list(category_to_prompts.keys())

# === Load chunk dictionary ===
# NOTE(review): this path differs from the
# "./neural_chunk_dictionary/neural_chunk_dictionary.pkl" path used elsewhere
# in this notebook — confirm which file is intended.
# NOTE(review): pickle.load can execute arbitrary code; trusted files only.
with open("./neural_chunk_dictionary.pkl", "rb") as file:
    neural_chunk_dictionary = pickle.load(file)

words = list(neural_chunk_dictionary.keys())

# === Run per-category perturbation analysis ===
# One dict per (word, category, condition); converted to a DataFrame below.
results = []

for category in categories:
    print(f"\n--- Processing category: {category} ---")
    prompt_pool = category_to_prompts[category]
    batched_prompts = random.sample(prompt_pool, min(n_it, len(prompt_pool)))

    for word in words:
        # Only string keys are treated as target words.
        if not isinstance(word, str):
            continue
        print(f'>>> Processing word: {word}')

        # Control outputs are re-sampled for every word, so each word gets
        # its own independent control draw.
        control_outputs = get_model_response_batch(batched_prompts)

        # Perturbed outputs; token_idx=-1 targets the three positions just
        # before the last one (slice -4:-1 inside the hook).
        perturb_outputs = load_dictionary_and_perturb_batch(
            word=word, token_idx=-1, input_texts=batched_prompts
        )

        # Whole-word, case-insensitive match of the target word.
        pattern = re.compile(rf"\b{re.escape(word)}\b", re.IGNORECASE)
        n_control = sum(bool(pattern.search(out)) for out in control_outputs)
        n_perturb = sum(bool(pattern.search(out)) for out in perturb_outputs)

        # Fraction of prompts whose output contains the word.
        p_control = n_control / len(batched_prompts)
        p_perturb = n_perturb / len(batched_prompts)

        # The perturbation function also accepts entries that are not plain
        # {'layer': ...} dicts; only call .get on real dicts so those entries
        # do not raise AttributeError here.
        entry = neural_chunk_dictionary[word]
        if isinstance(entry, dict):
            chunk_category = entry.get("category", "Unknown")
        else:
            chunk_category = "Unknown"

        # Control row first, then the perturbed row.
        for condition, probability in (
            ("Without Perturbation (Control)", p_control),
            ("With Perturbation", p_perturb),
        ):
            results.append({
                "Target Word": word,
                "TREC_Category": category,
                "Condition": condition,
                "Occurrence Probability": probability,
                "Chunk Category": chunk_category,
            })

# === Save to DataFrame ===
df = pd.DataFrame(results)
df.to_csv("perturbation_effectiveness_by_trec_category.csv", index=False)


In [None]:
# Read the saved results table.
df = pd.read_csv("perturbation_effectiveness_by_trec_category_d=e_niter=50.csv")

# Lift pandas' display limits: show every row and column, and do not wrap
# wide frames.
for display_option in ('display.max_rows', 'display.max_columns', 'display.width'):
    pd.set_option(display_option, None)

df

In [12]:
import pandas as pd

# Mean and standard deviation of the occurrence probability for every
# (TREC category, condition) pair.
probability_by_group = df.groupby(["TREC_Category", "Condition"])["Occurrence Probability"]
grouped = probability_by_group.agg(["mean", "std"]).reset_index()


def _format_mean_std(row):
    """Render one summary row as "mean ± std" with three decimals."""
    return f"{row['mean']:.3f} ± {row['std']:.3f}"


grouped["Mean ± Std"] = grouped.apply(_format_mean_std, axis=1)

# One row per category, one column per condition.
pivoted = grouped.pivot(
    index="TREC_Category", columns="Condition", values="Mean ± Std"
).reset_index()
pivoted.columns.name = None  # drop the "Condition" axis label left by the pivot

# Render the table without the integer index column.
display(pivoted.style.hide(axis="index"))



TREC_Category,With Perturbation,Without Perturbation (Control)
ABBR,0.559 ± 0.324,0.149 ± 0.246
DESC,0.490 ± 0.268,0.156 ± 0.215
ENTY,0.481 ± 0.286,0.126 ± 0.218
HUM,0.467 ± 0.286,0.119 ± 0.200
LOC,0.475 ± 0.286,0.107 ± 0.194
NUM,0.453 ± 0.270,0.115 ± 0.184


In [16]:
# Results for the d=m run.
df = pd.read_csv("perturbation_effectiveness_by_trec_category_d=m_niter=50.csv")

# Per (TREC category, condition): mean and standard deviation of the
# occurrence probability.
summary_columns = ["mean", "std"]
grouped = (
    df.groupby(["TREC_Category", "Condition"])["Occurrence Probability"]
    .agg(summary_columns)
    .reset_index()
)


def _mean_std_label(row):
    """Format a summary row as "mean ± std" with three decimals."""
    return f"{row['mean']:.3f} ± {row['std']:.3f}"


grouped["Mean ± Std"] = grouped.apply(_mean_std_label, axis=1)

# Conditions become columns, categories become rows.
pivoted = grouped.pivot(index="TREC_Category", columns="Condition", values="Mean ± Std")
pivoted = pivoted.reset_index()
pivoted.columns.name = None  # remove the axis name pandas attaches in pivot

# Render the table without the integer index column.
display(pivoted.style.hide(axis="index"))


TREC_Category,With Perturbation,Without Perturbation (Control)
ABBR,0.308 ± 0.269,0.149 ± 0.246
DESC,0.281 ± 0.224,0.156 ± 0.215
ENTY,0.225 ± 0.241,0.126 ± 0.218
HUM,0.215 ± 0.238,0.119 ± 0.200
LOC,0.205 ± 0.224,0.107 ± 0.194
NUM,0.218 ± 0.214,0.115 ± 0.184


In [None]:
import pandas as pd

# Load the CSV
df = pd.read_csv("perturbation_effectiveness_by_trec_category_d=e_niter=50.csv")

# Group by Target Word and Condition, averaging across TREC categories
grouped = (
    df.groupby(["Target Word", "Condition"])["Occurrence Probability"]
    .mean()
    .reset_index()
)

# Sort by Target Word, then Condition. Sorting on Target Word alone uses a
# non-stable algorithm by default, so the order of the two condition rows
# within a word would not be guaranteed; the second key makes it deterministic.
grouped = grouped.sort_values(by=["Target Word", "Condition"])

# Display result
grouped

In [None]:
import pandas as pd

# Read the d=e results.
df = pd.read_csv("perturbation_effectiveness_by_trec_category_d=e_niter=50.csv")

# Average the occurrence probability over TREC categories for every
# (target word, condition) pair.
mean_by_word = df.groupby(["Target Word", "Condition"])["Occurrence Probability"].mean()
grouped = mean_by_word.reset_index()

# Reshape to one row per target word with one column per condition.
pivoted = grouped.pivot(
    index="Target Word",
    columns="Condition",
    values="Occurrence Probability",
)

# Turn the index back into a column and drop the leftover axis name.
pivoted = pivoted.reset_index()
pivoted.columns.name = None

# Order rows alphabetically by target word.
pivoted = pivoted.sort_values(by="Target Word")

# Show result
pivoted


In [None]:
import pandas as pd

# Read the d=e results.
df = pd.read_csv("perturbation_effectiveness_by_trec_category_d=e_niter=50.csv")

# Average the occurrence probability over TREC categories for every
# (target word, condition) pair.
mean_by_word = df.groupby(["Target Word", "Condition"])["Occurrence Probability"].mean()
grouped = mean_by_word.reset_index()

# Reshape to one row per target word with one column per condition.
pivoted = grouped.pivot(
    index="Target Word",
    columns="Condition",
    values="Occurrence Probability",
)
pivoted.columns.name = None  # clean column names
pivoted = pivoted.reset_index()

# Effect of the perturbation: perturbed probability minus control probability.
perturbed = pivoted["With Perturbation"]
control = pivoted["Without Perturbation (Control)"]
pivoted["Delta (Perturb - Control)"] = perturbed - control

# Largest effect first.
pivoted = pivoted.sort_values(by="Delta (Perturb - Control)", ascending=False)

# Display result
pivoted
