In [14]:
import os, re, json
from tqdm import tqdm
import torch, numpy as np
import argparse
from baukit import TraceDict

In [15]:
os.sys.path.append('../src/')

In [16]:
# Include prompt creation helper functions
from utils.prompt_utils import *
from utils.intervention_utils import *
from utils.model_utils import *
from utils.extract_utils import *

In [17]:
def activation_replacement_per_class_intervention(prompt_data, avg_activations, dummy_labels, model, model_config, tokenizer, last_token_only=True):
    """
    Experiment to determine top intervention locations through avg activation replacement. 
    Performs a systematic sweep over attention heads (layer, head) to track their causal influence on probs of key tokens.

    Parameters: 
    prompt_data: dict containing ICL prompt examples, and template information
    avg_activations: avg activation of each attention head in the model taken across n_trials ICL prompts
    dummy_labels: labels and indices for a baseline prompt with the same number of example pairs
    model: huggingface model
    model_config: contains model config information (n layers, n heads, etc.)
    tokenizer: huggingface tokenizer
    last_token_only: If True, only computes indirect effect for heads at the final token position. If False, computes indirect_effect for heads for all token classes

    Returns:   
    indirect_effect_storage: torch tensor containing the indirect_effect of each head for each token class.
    """
    device = model.device

    # Get sentence and token labels
    query_target_pair = prompt_data['query_target']

    query = query_target_pair['input']
    token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query=query)

    idx_map, idx_avg = compute_duplicated_labels(token_labels, dummy_labels)
    idx_map = update_idx_map(idx_map, idx_avg)
      
    sentences = [prompt_string]# * model.config.n_head # batch things by head

    # Figure out tokens of interest
    tokens_of_interest = [query_target_pair['output']]
    if 'llama' in model_config['name_or_path']:
        ts = tokenizer(tokens_of_interest, return_tensors='pt').input_ids.squeeze()
        if tokenizer.decode(ts[1])=='' or ts[1]==29871: # avoid spacing issues
            token_id_of_interest = ts[2]
        else:
            token_id_of_interest = ts[1]
    else:
        token_id_of_interest = tokenizer(tokens_of_interest).input_ids[0][:1]
        
    inputs = tokenizer(sentences, return_tensors='pt').to(device)

    # Speed up computation by only computing causal effect at last token
    if last_token_only:
        token_classes = ['query_predictive']
        token_classes_regex = ['query_predictive_token']
    # Compute causal effect for all token classes (instead of just last token)
    else:
        token_classes = ['demonstration', 'label', 'separator', 'predictive', 'structural','end_of_example', 
                        'query_demonstration', 'query_structural', 'query_separator', 'query_predictive']
        token_classes_regex = ['demonstration_[\d]{1,}_token', 'demonstration_[\d]{1,}_label_token', 'separator_token', 'predictive_token', 'structural_token','end_of_example_token', 
                            'query_demonstration_token', 'query_structural_token', 'query_separator_token', 'query_predictive_token']
    

    indirect_effect_storage = torch.zeros(model_config['n_layers'], model_config['n_heads'],len(token_classes))

    # Clean Run of Baseline:
    clean_output = model(**inputs).logits[:,-1,:]
    clean_probs = torch.softmax(clean_output[0], dim=-1)

    # For every layer, head, token combination perform the replacement & track the change in meaningful tokens
    for layer in range(model_config['n_layers']):
        head_hook_layer = [model_config['attn_hook_names'][layer]]
        
        for head_n in range(model_config['n_heads']):
            for i,(token_class, class_regex) in enumerate(zip(token_classes, token_classes_regex)):
                reg_class_match = re.compile(f"^{class_regex}$")
                class_token_inds = [x[0] for x in token_labels if reg_class_match.match(x[2])]

                intervention_locations = [(layer, head_n, token_n) for token_n in class_token_inds]
                intervention_fn = replace_activation_w_avg(layer_head_token_pairs=intervention_locations, avg_activations=avg_activations, 
                                                           model=model, model_config=model_config,
                                                           batched_input=False, idx_map=idx_map, last_token_only=last_token_only)
                with TraceDict(model, layers=head_hook_layer, edit_output=intervention_fn) as td:                
                    output = model(**inputs).logits[:,-1,:] # batch_size x n_tokens x vocab_size, only want last token prediction
                
                # TRACK probs of tokens of interest
                intervention_probs = torch.softmax(output, dim=-1) # convert to probability distribution
                indirect_effect_storage[layer,head_n,i] = (intervention_probs-clean_probs).index_select(1, torch.LongTensor(token_id_of_interest).to(device).squeeze()).squeeze()

    return indirect_effect_storage

In [37]:
import numpy as np

def word_pairs_to_prompt_data(word_pairs : dict,
                              instructions: str = "",
                              prefixes: dict = {"input":"Q:", "output":"A:","instructions":""},
                              separators: dict = {"input":"\n", "output":"\n\n", "instructions":""},
                              query_target_pair: dict = None, prepend_bos_token=False,
                              shuffle_labels=False, prepend_space=True) -> dict:
    """Takes a dataset of word pairs, and constructs a prompt_data dict with additional information to construct an ICL prompt.
    Parameters:
    word_pairs: dict of the form {'word1':['a', 'b', ...], 'word2':['c', 'd', ...]}
    instructions: prefix instructions for an ICL prompt
    prefixes: dict of ICL prefixes that are prepended to inputs, outputs and instructions
    separators: dict of ICL separators that are appended to inputs, outputs and instructions
    query_target_pair: dict with a single input-output pair acting as the query for the prompt
    prepend_bos_token: whether or not to prepend a BOS token to the prompt
    shuffle_labels: whether to shuffle the ICL labels
    prepend_space: whether to prepend a space to every input and output token

    Returns: 
    prompt_data: dict containing ICL prompt examples, and template information
    """
    print("---------------------------------------------------------------------")
    prompt_data = {}
    prompt_data['instructions'] = instructions
    print("Instructions:", instructions)
    prompt_data['separators'] = separators
    print("Separators:", separators)
    if prepend_bos_token:
        prefixes = {k:(v if k !='instructions' else '' + v) for (k,v) in prefixes.items()}
    prompt_data['prefixes'] = prefixes
    print("Prefixes:", prefixes)

    if query_target_pair is not None:
        query_target_pair = {k:(v[0] if isinstance(v, list) else v) for k,v in query_target_pair.items()}
    prompt_data['query_target'] = query_target_pair
    print("Query Target Pair:", query_target_pair)
        
    if shuffle_labels:
        randomized_pairs = [np.random.permutation(x).tolist() if i==1 else x for (i,x) in enumerate(list(word_pairs.values()))] # shuffle labels only
        if prepend_space:
            prompt_data['examples'] = [{'input':' ' + w1, 'output':' ' + w2} for (w1,w2) in list(zip(*randomized_pairs))]
            prompt_data['query_target'] = {k:' ' + v for k,v in query_target_pair.items()} if query_target_pair is not None else None
        else:
            prompt_data['examples'] = [{'input':w1, 'output':w2} for (w1,w2) in list(zip(*randomized_pairs))]
    else:    
        if prepend_space:
            prompt_data['examples'] = [{'input':' ' + w1, 'output':' ' + str(w2)} for (w1,w2) in list(zip(*word_pairs.values()))]
            prompt_data['query_target'] = {k:' ' + str(v) for k,v in query_target_pair.items()} if query_target_pair is not None else None
        else:
            prompt_data['examples'] = [{'input':w1, 'output':w2} for (w1,w2) in list(zip(*word_pairs.values()))]
    
    print("Prompt Data Examples:", prompt_data['examples'])
    print("---------------------------------------------------------------------")
    return prompt_data

In [38]:
import numpy as np
import torch
from tqdm import tqdm

def compute_indirect_effect(dataset, mean_activations, model, model_config, tokenizer, n_shots=10, n_trials=25, last_token_only=True, prefixes=None, separators=None, filter_set=None):
    """
    Computes Indirect Effect of each head in the model
    """
    print("---------------------------------------------------------------------")
    print(f"Starting computation of indirect effects with {n_shots} shots and {n_trials} trials...")
    n_test_examples = 1

    if prefixes is not None and separators is not None:
        print("Using custom prefixes and separators for dummy token labels...")
        dummy_gt_labels = get_dummy_token_labels(n_shots, tokenizer=tokenizer, prefixes=prefixes, separators=separators)
    else:
        print("Using default settings for dummy token labels...")
        dummy_gt_labels = get_dummy_token_labels(n_shots, tokenizer=tokenizer)
    
    print(f"Dummy ground truth labels generated with length: {len(dummy_gt_labels)}")
    print(f" 10 Dummy ground truth labels are:{dummy_gt_labels[:10]}")
    
    is_llama = 'llama' in model_config['name_or_path']
    prepend_bos = not is_llama
    print(f"Model prepend_bos setting: {prepend_bos}")

    if last_token_only:
        indirect_effect = torch.zeros(n_trials, model_config['n_layers'], model_config['n_heads'])
    else:
        indirect_effect = torch.zeros(n_trials, model_config['n_layers'], model_config['n_heads'], 10) # have 10 classes of tokens
    
    print(f"Initialized indirect effect tensor with zero tensor of shape: {indirect_effect.shape}")

    if filter_set is None:
        filter_set = np.arange(len(dataset['valid']))
    
    for i in tqdm(range(n_trials), total=n_trials):
        word_pairs = dataset['train'][np.random.choice(len(dataset['train']), n_shots, replace=False)]
        word_pairs_test = dataset['valid'][np.random.choice(filter_set, n_test_examples, replace=False)]
        
        print(f"Trial {i+1}/{n_trials}: Selected word pairs for prompts and testing.")

        print("Parameters to word_pairs_to_prompt_data:")
        print(f"word_pairs: {word_pairs}, query_target_pair: {word_pairs_test}, shuffle_labels: True, prepend_bos_token: {prepend_bos}")

        if prefixes is not None and separators is not None:
            prompt_data_random = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, shuffle_labels=True, 
                                                           prepend_bos_token=prepend_bos, prefixes=prefixes, separators=separators)
        else:
            prompt_data_random = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, 
                                                           shuffle_labels=True, prepend_bos_token=prepend_bos)
        print(f"Generated prompt data for trial {i+1}.")
        print("Prompt Data", prompt_data_random)
        
        ind_effects = activation_replacement_per_class_intervention(prompt_data=prompt_data_random, 
                                                                    avg_activations=mean_activations, 
                                                                    dummy_labels=dummy_gt_labels, 
                                                                    model=model, model_config=model_config, tokenizer=tokenizer, 
                                                                    last_token_only=last_token_only)
        indirect_effect[i] = ind_effects.squeeze()
        print(f"Indirect effects calculated for trial {i+1} with shape: {ind_effects.shape}.")

    print("Completed computation of indirect effects.")
    return indirect_effect

# Note: You'll need to define or adapt the functions get_dummy_token_labels and word_pairs_to_prompt_data,
# as well as activation_replacement_per_class_intervention, according to their requirements and outputs.

In [39]:

args = {
    'dataset_name': 'antonym',
    'model_name': 'gpt2',
    'root_data_dir': '../dataset_files',
    'save_path_root': '../results',
    'seed': 42,
    'n_shots': 10,
    'n_trials': 1,
    'test_split': 0.3,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'mean_activations_path': None,
    'last_token_only': True,
    'prefixes': {"input": "Q:", "output": "A:", "instructions": ""},
    'separators': {"input": "\\n", "output": "\\n\\n", "instructions": ""}
}

dataset_name = args['dataset_name']
model_name = args['model_name']
root_data_dir = args['root_data_dir']
save_path_root = f"{args['save_path_root']}/{dataset_name}"
seed = args['seed']
n_shots = args['n_shots']
n_trials = args['n_trials']
test_split = args['test_split']
device = args['device']
mean_activations_path = args['mean_activations_path']
last_token_only = args['last_token_only']
prefixes = args['prefixes']
separators = args['separators']

In [40]:
# Load Model & Tokenizer
torch.set_grad_enabled(False)
print("Loading Model")
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device)

Loading Model
Loading:  gpt2


In [41]:
set_seed(seed)

# Load the dataset
print("Loading Dataset")
dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)

Loading Dataset


In [42]:
if not os.path.exists(save_path_root):
    os.makedirs(save_path_root)

In [43]:
# Load or Re-Compute Mean Activations
if mean_activations_path is not None and os.path.exists(mean_activations_path):
    mean_activations = torch.load(mean_activations_path)
elif mean_activations_path is None and os.path.exists(f'{save_path_root}/{dataset_name}_mean_head_activations.pt'):
    mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt'
    mean_activations = torch.load(mean_activations_path)        
else:
    print("Computing Mean Activations")
    mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, 
                                                    n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators)
    torch.save(mean_activations, f'{save_path_root}/{dataset_name}_mean_head_activations.pt')

In [44]:
print("Computing Indirect Effect")
indirect_effect = compute_indirect_effect(dataset, mean_activations, model=model, model_config=model_config, tokenizer=tokenizer, 
                                            n_shots=n_shots, n_trials=n_trials, last_token_only=last_token_only, prefixes=prefixes, separators=separators)

Computing Indirect Effect
---------------------------------------------------------------------
Starting computation of indirect effects with 10 shots and 1 trials...
Using custom prefixes and separators for dummy token labels...
Dummy ground truth labels generated with length: 128
 10 Dummy ground truth labels are:[(0, 'bos_token'), (1, 'structural_token'), (2, 'structural_token'), (3, 'demonstration_1_token'), (4, 'separator_token'), (5, 'separator_token'), (6, 'structural_token'), (7, 'predictive_token'), (8, 'demonstration_1_label_token'), (9, 'end_of_example_token')]
Model prepend_bos setting: True
Initialized indirect effect tensor with zero tensor of shape: torch.Size([1, 12, 12])


  0%|          | 0/1 [00:00<?, ?it/s]

Trial 1/1: Selected word pairs for prompts and testing.
Parameters to word_pairs_to_prompt_data:
word_pairs: {'input': ['valley', 'economic', 'fugitive', 'inward', 'unmarked', 'worthless', 'regress', 'borrowing', 'saline', 'risky'], 'output': ['mountain', 'uneconomic', 'law-abiding citizen', 'outward', 'marked', 'valuable', 'progress', 'lending', 'freshwater', 'safe']}, query_target_pair: {'input': ['secure'], 'output': ['insecure']}, shuffle_labels: True, prepend_bos_token: True
---------------------------------------------------------------------
Instructions: 
Separators: {'input': '\\n', 'output': '\\n\\n', 'instructions': ''}
Prefixes: {'input': 'Q:', 'output': 'A:', 'instructions': ''}
Query Target Pair: {'input': 'secure', 'output': 'insecure'}
Prompt Data Examples: [{'input': ' valley', 'output': ' progress'}, {'input': ' economic', 'output': ' valuable'}, {'input': ' fugitive', 'output': ' outward'}, {'input': ' inward', 'output': ' freshwater'}, {'input': ' unmarked', 'output

100%|██████████| 1/1 [00:02<00:00,  2.24s/it]

Indirect effects calculated for trial 1 with shape: torch.Size([12, 12, 1]).
Completed computation of indirect effects.





In [25]:
# Write args to file
args['save_path_root'] = save_path_root
args['mean_activations_path'] = mean_activations_path
with open(f'{save_path_root}/indirect_effect_args.txt', 'w') as arg_file:
    json.dump(args, arg_file, indent=2)

torch.save(indirect_effect, f'{save_path_root}/{dataset_name}_indirect_effect.pt')

-----