In [1]:
import pyvene as pv
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

In [2]:
import os
import numpy as np
import pandas as pd
from copy import deepcopy
from tqdm.notebook import tqdm
from datasets import load_dataset

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    theme,
    element_text,
    ggtitle,
    xlab, ylab, ggsave,
    facet_wrap
)
from plotnine.scales import scale_y_reverse, scale_fill_cmap

In [4]:
def embed_to_distrib(model, embed, log=False, logits=False):

    vocab = pv.embed_to_distrib(model, embed, logits=logits)
    return vocab

# factual recall
def factual_recall(text, model, tokenizer):
    print(text)
    base = deepcopy(text)
    inputs = [
        tokenizer(base, return_tensors="pt").to(device),
    ]
    base = deepcopy(text)
    for i in range(len(inputs)):
        res = model(**inputs[i], output_hidden_states=True)
        last = res.last_hidden_state if hasattr(res, 'last_hidden_state') else res.hidden_states[-1]
        distrib = embed_to_distrib(model, last, logits=False)
        pv.top_vals(tokenizer, distrib[0][-1], n=10)
        
def create_noise_intervention(seq_len): 
    class NoiseIntervention(pv.ConstantSourceIntervention, pv.LocalistRepresentationIntervention):
        def __init__(self, embed_dim, **kwargs):
            super().__init__()
            self.interchange_dim = embed_dim
            self.seq_len = kwargs.get("seq_len", 1)
            rs = np.random.RandomState(1)
            prng = lambda *shape: rs.randn(*shape)
            self.noise = torch.from_numpy(
                prng(1, self.seq_len, embed_dim)).to(device)
            self.noise_level = 0.13462981581687927

        def forward(self, base, source=None, subspaces=None): 
            # source argument is ignored unlike in causal interventions, since we are adding noise without reference to any other input
            base[..., : self.interchange_dim] += self.noise * self.noise_level
            return base

        def __str__(self):
            return f"NoiseIntervention(embed_dim={self.embed_dim}, seq_len={self.seq_len})"

    return NoiseIntervention


def corrupted_config(model_type, seq_len):
    config = pv.IntervenableConfig(
        model_type=model_type,
        representations=[
            pv.RepresentationConfig(
                0,              # layer
                "block_input",  # intervention type
            ),
        ],
        intervention_types=[create_noise_intervention(seq_len)],
    )
    return config

def corrupted_recall(text, model, tokenizer, subject_pos):
    print(text)
    base = tokenizer(deepcopy(text), return_tensors="pt").to(device)
    seq_len = len(subject_pos)
    config = corrupted_config(type(model), seq_len)
    intervenable = pv.IntervenableModel(config, model)

    _, counterfactual_outputs = intervenable(
        base, unit_locations={"base": ([[subject_pos]])}
    )
    last = counterfactual_outputs.last_hidden_state if hasattr(counterfactual_outputs, 'last_hidden_state') else counterfactual_outputs.hidden_states[-1]
    distrib = embed_to_distrib(model, last, logits=False)
    pv.top_vals(tokenizer, distrib[0][-1], n=10)
    
def format_tokens(tokenizer, tokens):
    return [tokenizer.decode(tok).replace("\n", "\\n") for tok in tokens]

# restored run - corrupt input in some position, then restore the hidden state at a particular layer for some positions
def restore_corrupted_with_interval_config(
        layer, 
        stream="mlp_activation", 
        window=10, 
        num_layers=48,
        seq_len=1
    ):
    start = max(0, layer - window // 2)
    end = min(num_layers, layer - (-window // 2))
    config = pv.IntervenableConfig(
        representations=[
            pv.RepresentationConfig(
                0,       # layer
                "block_input",  # intervention type
            ),
        ] + [
            pv.RepresentationConfig(
                i,       # layer
                stream,  # intervention type
        ) for i in range(start, end)],
        intervention_types=\
            [create_noise_intervention(seq_len)]+[pv.VanillaIntervention]*(end-start),
        seq_len=seq_len,
    )
    return config

# corrupt all layers and positions 0, 1, 2, 3 ("The Space Needle", i.e. the subject of the fact) and restore at a single position at every layer
def restored_run(text, model, tokenizer, subject_pos, save_path):
    base = tokenizer(deepcopy(text), return_tensors="pt").to(device)
    yes_token = tokenizer.convert_tokens_to_ids("Yes")
    no_token = tokenizer.convert_tokens_to_ids("No")
    seq_len = len(subject_pos)
    
    for stream in ["block_output", "mlp_activation", "attention_output"]:
        data = []
        n_layers = model.config.n_layer if hasattr(model.config, "n_layer") else model.config.num_hidden_layers
        for layer_i in tqdm(range(n_layers)):
            for pos_i in range(len(base.input_ids[0])):
                config = restore_corrupted_with_interval_config(
                    layer_i, stream, 
                    window=1 if stream == "block_output" else 10,
                    num_layers=n_layers,
                    seq_len=seq_len
                )
                n_restores = len(config.representations) - 1
                intervenable = pv.IntervenableModel(config, model)
                _, counterfactual_outputs = intervenable(
                    base,
                    [None] + [base]*n_restores,
                    {
                        "sources->base": (
                            [None] + [[[pos_i]]]*n_restores,
                            [[subject_pos]] + [[[pos_i]]]*n_restores,
                        )
                    },
                )
                last = counterfactual_outputs.last_hidden_state if hasattr(counterfactual_outputs, 'last_hidden_state') else counterfactual_outputs.hidden_states[-1]
                distrib = embed_to_distrib(
                    model, last, logits=False
                )
                yes_prob = distrib[0][-1][yes_token].detach().cpu().item()
                no_prob = distrib[0][-1][no_token].detach().cpu().item()
                data.append({"layer": layer_i, "pos": pos_i, "yes_prob": yes_prob, "no_prob": no_prob})
        df = pd.DataFrame(data)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        df.to_csv(f"{save_path}/pyvene_rome_{stream}.csv")
        
def plot_activations(labels, breaks, colors, titles, path):
    for stream in ["block_output", "mlp_activation", "attention_output"]:
        df = pd.read_csv(f"{path}/pyvene_rome_{stream}.csv")
        df["layer"] = df["layer"].astype(int)
        df["pos"] = df["pos"].astype(int)
        df["p(Yes)"] = df["yes_prob"].astype(float)
        df["p(No)"] = df["no_prob"].astype(float)

        df_melted = df.melt(
            id_vars=["layer", "pos"],
            value_vars=["p(Yes)", "p(No)"],
            var_name="Token",
            value_name="Probability"
        )

        # Create the plot with facets
        plot = (
            ggplot(df_melted, aes(x="layer", y="pos"))
            + facet_wrap("~Token", scales="free_y")  # Side-by-side subplots for "p(Yes)" and "p(No)"
            + geom_tile(aes(fill="Probability"))
            + scale_fill_cmap(colors[stream])
            + xlab(titles[stream])
            + scale_y_reverse(
                limits=(-0.5, len(labels) - 0.5), 
                breaks=breaks, labels=labels
            )
            + ggtitle(f"{stream} Activations: p(Yes) and p(No)")
            + theme(figure_size=(10, 5))  # Adjust figure size for side-by-side plots
            + ylab("") 
            + theme(axis_text_y=element_text(angle=90, hjust=1))
        )
        display(plot)

# Load dataset


In [5]:
def load_boolq_dataset():
    dataset = load_dataset("boolq")
    df = pd.DataFrame({
        'question': dataset['train']['question'],
        'answer': dataset['train']['answer'],
        'passage': dataset['train']['passage']
    })
    return df

In [6]:
df = load_boolq_dataset()
df.head()

Unnamed: 0,question,answer,passage
0,do iran and afghanistan speak the same language,True,"Persian (/ˈpɜːrʒən, -ʃən/), also known by its ..."
1,do good samaritan laws protect those who help ...,True,Good Samaritan laws offer legal protection to ...
2,is windows movie maker part of windows essentials,True,Windows Movie Maker (formerly known as Windows...
3,is confectionary sugar the same as powdered sugar,True,"Powdered sugar, also called confectioners' sug..."
4,is elder scrolls online the same as skyrim,False,As with other games in The Elder Scrolls serie...


In [7]:
os.makedirs("output", exist_ok=True)

## GPT-2


In [8]:
# init GPT2-XL model
config, tokenizer, gpt = pv.create_gpt2(name="gpt2-xl")
gpt.to(device)

loaded model


GPT2Model(
  (wte): Embedding(50257, 1600)
  (wpe): Embedding(1024, 1600)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-47): 48 x GPT2Block(
      (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=4800, nx=1600)
        (c_proj): Conv1D(nf=1600, nx=1600)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=6400, nx=1600)
        (c_proj): Conv1D(nf=1600, nx=6400)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
)

In [9]:
def prepare_input_for_model(text, tokenizer, device=device):
    """Prepare input text for model processing"""
    inputs = tokenizer(text, return_tensors="pt")
    return {k: v.to(device) for k, v in inputs.items()}


def print_token_info(model, tokenizer, logits=None, input_ids=None, k=5, prefix="", only_if_changed=False, last_probs=None):
    """Print token information - either tokenized input or top k probabilities"""
    if logits is not None:
        # Print top k tokens and probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)
        top_k = torch.topk(probs, k)
        
        if only_if_changed and last_probs is not None:
            # Check if probabilities changed significantly
            if torch.allclose(probs, last_probs, rtol=1e-3):
                return probs
        
        print(f"\n{prefix}")
        for i, (prob, idx) in enumerate(zip(top_k.values, top_k.indices)):
            token = tokenizer.decode(idx)
            print(f"{i+1}. '{token}': {prob:.4f}")
        return probs
    
    elif input_ids is not None:
        # Print tokenized input information
        print("\nTokens:")
        for token in input_ids['input_ids'][0]:
            print(f"'{tokenizer.decode(token)}'")
    else:
        raise ValueError("Either input_ids or logits must be provided")

In [10]:
def process_boolq_example(question, model, tokenizer, device=device):
    """Process a single BoolQ example with debug information"""
    # Format question with proper prompt
    input_text = f"Question: {question} Answer:"
    print(f"\nDebug: Input text: {input_text}")
    
    input_ids = prepare_input_for_model(input_text, tokenizer, device)
    print("\nDebug: Tokenized input:")
    print_token_info(model, tokenizer, input_ids=input_ids)
    
    return input_text

In [11]:
def prepare_sample(sample):
    text = process_boolq_example(sample["question"], gpt, tokenizer)
    text_token = " Yes" if sample["answer"] is True else " No"
    return text, text_token

In [12]:
text, text_token = prepare_sample(df.iloc[2])


Debug: Input text: Question: is windows movie maker part of windows essentials Answer:

Debug: Tokenized input:

Tokens:
'Question'
':'
' is'
' windows'
' movie'
' maker'
' part'
' of'
' windows'
' essentials'
' Answer'
':'


In [13]:
print('GT: ', text_token)
factual_recall(text, gpt, tokenizer)

GT:   No
Question: is windows movie maker part of windows essentials Answer:
_yes                 0.13121595978736877
_No                  0.11776792258024216
\n                   0.10949345678091049
_Yes                 0.09982707351446152
_no                  0.08984081447124481
_Windows             0.04195360094308853
_It                  0.03211362287402153
_it                  0.022324759513139725
_windows             0.021807190030813217
_The                 0.01462388876825571


In [14]:
formated_tokens = format_tokens(tokenizer, tokenizer(text)["input_ids"])
print(formated_tokens)

['Question', ':', ' is', ' windows', ' movie', ' maker', ' part', ' of', ' windows', ' essentials', ' Answer', ':']


In [15]:
def get_subject_tokens_from_question(question, tokenizer):
    """Extract subject tokens from the question"""
    # Tokenize the question with the prompt
    full_text = f"Question: {question}? Answer: "
    tokens = tokenizer.encode(full_text)
    token_words = [tokenizer.decode(t).strip().lower() for t in tokens]
    
    # Question words to look for
    question_words = ["what", "is", "are", "does", "do", "has", "can", "did", "will", "was", "were"]
    
    # Find the position after question word and take exactly 3 tokens
    subject_positions = []
    for i, word in enumerate(token_words):
        if word in question_words and i + 3 < len(token_words):
            # Always take exactly 3 tokens as subject
            subject_positions = list(range(i + 1, i + 4))
            break
    
    # If no question word found or not enough tokens, take first 3 non-punctuation tokens
    if not subject_positions:
        count = 0
        for i, word in enumerate(token_words):
            if word not in ['question', ':', 'answer'] and count < 3:
                subject_positions.append(i)
                count += 1
            if count == 3:
                break
    
    return subject_positions

In [19]:
print(text)
subject_positions = get_subject_tokens_from_question(text, tokenizer)
print(subject_positions)
# print the subject words using formated_tokens
subject_words = [formated_tokens[pos] for pos in subject_positions]
print(subject_words)


corrupted_recall(text, gpt, tokenizer, subject_positions)

Question: is windows movie maker part of windows essentials Answer:
[5, 6, 7]
[' maker', ' part', ' of']
Question: is windows movie maker part of windows essentials Answer:
\n                   0.1348111629486084
_yes                 0.0988793894648552
_Yes                 0.07032006233930588
_windows             0.06847503036260605
_no                  0.06502274423837662
_Windows             0.05554073303937912
_No                  0.05104365199804306
_it                  0.01944822631776333
_It                  0.019329600036144257
\n\n                 0.018422432243824005


In [None]:
# path = 'output/gpt2xl'
# restored_run(text, gpt, tokenizer, subject_positions, path)

In [None]:
# titles={
#     "block_output": "single restored layer in GPT2-XL",
#     "mlp_activation": "center of interval of 10 patched mlp layer",
#     "attention_output": "center of interval of 10 patched attn layer"
# }

# colors={
#     "block_output": "Purples",
#     "mlp_activation": "Greens",
#     "attention_output": "Reds"
# } 
        
# labels = ['Is','windows','movie','maker','part','of','windows','essentials']
# breaks = list(range(0, len(labels)))
# plot_activations(labels, breaks, colors, titles, path)

In [None]:
def process_text(text, model, tokenizer, output_prefix):
    text, text_token = prepare_sample(df.iloc[2])
    print('GT: ', text_token)
    factual_recall(text, model, tokenizer)
    subject_positions = get_subject_tokens_from_question(text, tokenizer)
    
    path = f'{output_prefix}'
    restored_run(text, model, tokenizer, subject_positions, path) 

    titles={
        "block_output": "single restored layer in GPT2-XL",
        "mlp_activation": "center of interval of 10 patched mlp layer",
        "attention_output": "center of interval of 10 patched attn layer"
    }

    colors={
        "block_output": "Purples",
        "mlp_activation": "Greens",
        "attention_output": "Reds"
    }

    labels = format_tokens(tokenizer, tokenizer(text)["input_ids"])
    breaks = list(range(0, len(labels)))
    plot_activations(labels, breaks, colors, titles, path)
    

In [None]:
def process_multiple_examples(df, model, tokenizer, num_examples=10, random_seed=42, output_prefix = 'output/boolq_example'):
    """Process multiple examples from the BoolQ dataset"""
    np.random.seed(random_seed)
    selected_indices = np.random.choice(len(df), num_examples, replace=False)
    
    results = []
    for idx in selected_indices:
        example = df.iloc[idx]
        print(f"\nProcessing example {idx}:")
        print(f"Question: {example['question']}")
        print(f"Answer: {example['answer']}")
        
        output_prefix_idx = f"{output_prefix}_{idx}"
        
        process_text(example['question'], model, tokenizer, output_prefix_idx)
        
        results.append({
            'idx': idx,
            'question': example['question'],
            'answer': example['answer'],
            'output_prefix': output_prefix
        })
    
    # Save results summary
    pd.DataFrame(results).to_csv("output/processed_examples.csv", index=False)

In [None]:
process_multiple_examples(df, gpt, tokenizer, num_examples=2, random_seed=42, output_prefix = 'output/gpt2xl')

In [None]:
torch.cuda.empty_cache()

## GPT2-XL BoolQ


In [17]:
from huggingface_hub import login

hf_token = nothing to see here

login(token=hf_token)

In [None]:
# init GPT-2 BoolQ model
config, tokenizer, gpt_boolq = pv.create_gpt2(name='utahnlp/boolq_gpt2-xl_seed-1')
gpt_boolq.to(device);

In [None]:
print('GT: ', text_token)
factual_recall(text, gpt_boolq, tokenizer)

In [None]:
format_tokens(tokenizer, tokenizer(text)["input_ids"])

In [None]:
subject_positions = [6,7]

corrupted_recall(text, gpt_boolq, tokenizer, subject_positions)

In [None]:
# path = 'output/gpt2xl_boolq'
# restored_run(text, gpt_boolq, tokenizer, subject_positions, path)

In [None]:
# titles={
#     "block_output": "single restored layer in GPT2-XL BoolQ",
#     "mlp_activation": "center of interval of 10 patched mlp layer",
#     "attention_output": "center of interval of 10 patched attn layer"
# }

# colors={
#     "block_output": "Purples",
#     "mlp_activation": "Greens",
#     "attention_output": "Reds"
# } 
        
# labels = ['Is','windows','movie','maker','part','of','windows','essentials']
# breaks = list(range(0, len(labels)))
# plot_activations(labels, breaks, colors, titles, path)

In [None]:
process_multiple_examples(df, gpt_boolq, tokenizer, num_examples=2, random_seed=42, output_prefix = 'output/gpt2xl')