n_shot vs n-trails?

In [1]:
import os, re, json
from tqdm import tqdm
import torch, numpy as np
import argparse
from baukit import TraceDict

In [2]:
os.sys.path.append('../src/')

In [3]:
# Include prompt creation helper functions
from utils.prompt_utils import *
# from utils.intervention_utils import *
from utils.model_utils import *
# from utils.extract_utils import *
from utils.intervention_utils import replace_activation_w_avg

In [4]:
def activation_replacement_per_class_intervention(prompt_data, avg_activations, dummy_labels, model, model_config, tokenizer, last_token_only=True):
    """
    Experiment to determine top intervention locations through avg activation replacement. 
    Performs a systematic sweep over attention heads (layer, head) to track their causal influence on probs of key tokens.

    Parameters: 
    prompt_data: dict containing ICL prompt examples, and template information
    avg_activations: avg activation of each attention head in the model taken across n_trials ICL prompts
    dummy_labels: labels and indices for a baseline prompt with the same number of example pairs
    model: huggingface model
    model_config: contains model config information (n layers, n heads, etc.)
    tokenizer: huggingface tokenizer
    last_token_only: If True, only computes indirect effect for heads at the final token position. If False, computes indirect_effect for heads for all token classes

    Returns:   
    indirect_effect_storage: torch tensor containing the indirect_effect of each head for each token class.
    """
    device = model.device
    print("---------------------------Activation Replacement--------------")
    print(f"prompt_data: {prompt_data}")
    # Get sentence and token labels
    query_target_pair = prompt_data['query_target']
    print("query_target_pair:", query_target_pair)
    query = query_target_pair['input']
    token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query=query)
    
    print("token_labels is a list of triplet that contains (idx, token, meta_label)")
    print(f"token_labels: {token_labels}")
    print(f"prompt_string: {prompt_string}")
    
    print("dummy_labels", dummy_labels)
    idx_map, idx_avg = compute_duplicated_labels(token_labels, dummy_labels)
    print("idx_map", idx_map)
    print("idx_avg", idx_avg)
    
    idx_map = update_idx_map(idx_map, idx_avg)
    print("(updated)idx_avg", idx_avg)
      
    sentences = [prompt_string]# * model.config.n_head # batch things by head

    print("sentences", sentences)

    # Figure out tokens of interest
    tokens_of_interest = [query_target_pair['output']]
    print("tokens_of_interest", tokens_of_interest)
    
    if 'llama' in model_config['name_or_path']:
        ts = tokenizer(tokens_of_interest, return_tensors='pt').input_ids.squeeze()
        if tokenizer.decode(ts[1])=='' or ts[1]==29871: # avoid spacing issues
            token_id_of_interest = ts[2]
        else:
            token_id_of_interest = ts[1]
    else:
        token_id_of_interest = tokenizer(tokens_of_interest).input_ids[0][:1]
        
    inputs = tokenizer(sentences, return_tensors='pt').to(device)

    # Speed up computation by only computing causal effect at last token
    if last_token_only:
        token_classes = ['query_predictive']
        token_classes_regex = ['query_predictive_token']
    # Compute causal effect for all token classes (instead of just last token)
    else:
        token_classes = ['demonstration', 'label', 'separator', 'predictive', 'structural','end_of_example', 
                        'query_demonstration', 'query_structural', 'query_separator', 'query_predictive']
        token_classes_regex = ['demonstration_[\d]{1,}_token', 'demonstration_[\d]{1,}_label_token', 'separator_token', 'predictive_token', 'structural_token','end_of_example_token', 
                            'query_demonstration_token', 'query_structural_token', 'query_separator_token', 'query_predictive_token']
    

    indirect_effect_storage = torch.zeros(model_config['n_layers'], model_config['n_heads'],len(token_classes))
    print(f"indirect_effect_storage.shape: {indirect_effect_storage.shape}")

    # Clean Run of Baseline:
    clean_output = model(**inputs).logits[:,-1,:]
    print(f"clean_output.shape: {clean_output.shape}")
    clean_probs = torch.softmax(clean_output[0], dim=-1)
    print(f"clean_probs.shape: {clean_probs.shape}")

    # For every layer, head, token combination perform the replacement & track the change in meaningful tokens
    for layer in range(model_config['n_layers']):
        head_hook_layer = [model_config['attn_hook_names'][layer]]
        
        for head_n in range(model_config['n_heads']):
            for i,(token_class, class_regex) in enumerate(zip(token_classes, token_classes_regex)):
                print("token_class, class_regex: ", token_class, class_regex)
                reg_class_match = re.compile(f"^{class_regex}$")
                print("reg_class_match: ", reg_class_match)
                
                class_token_inds = [x[0] for x in token_labels if reg_class_match.match(x[2])]
                print("class_token_inds: ", class_token_inds)
                

                intervention_locations = [(layer, head_n, token_n) for token_n in class_token_inds]
                print("intervention_locations:", intervention_locations)
                intervention_fn = replace_activation_w_avg(layer_head_token_pairs=intervention_locations, avg_activations=avg_activations, 
                                                           model=model, model_config=model_config,
                                                           batched_input=False, idx_map=idx_map, last_token_only=last_token_only)
                with TraceDict(model, layers=head_hook_layer, edit_output=intervention_fn) as td:                
                    output = model(**inputs).logits[:,-1,:] # batch_size x n_tokens x vocab_size, only want last token prediction
                
                # TRACK probs of tokens of interest
                intervention_probs = torch.softmax(output, dim=-1) # convert to probability distribution
                indirect_effect_storage[layer,head_n,i] = (intervention_probs-clean_probs).index_select(1, torch.LongTensor(token_id_of_interest).to(device).squeeze()).squeeze()

    return indirect_effect_storage

In [5]:
import numpy as np
import torch
from tqdm import tqdm

def compute_indirect_effect(dataset, mean_activations, model, model_config, tokenizer, n_shots=10, n_trials=25, last_token_only=True, prefixes=None, separators=None, filter_set=None):
    """
    Computes Indirect Effect of each head in the model
    """
    print("---------------------------------------------------------------------")
    print(f"Starting computation of indirect effects with {n_shots} shots and {n_trials} trials...")
    n_test_examples = 1

    if prefixes is not None and separators is not None:
        print("Using custom prefixes and separators for dummy token labels...")
        dummy_gt_labels = get_dummy_token_labels(n_shots, tokenizer=tokenizer, prefixes=prefixes, separators=separators)
    else:
        print("Using default settings for dummy token labels...")
        dummy_gt_labels = get_dummy_token_labels(n_shots, tokenizer=tokenizer)
    
    print(f"Dummy ground truth labels generated with length: {len(dummy_gt_labels)}")
    print(f" 10 Dummy ground truth labels are:{dummy_gt_labels[:10]}")
    
    is_llama = 'llama' in model_config['name_or_path']
    prepend_bos = not is_llama
    print(f"Model prepend_bos setting: {prepend_bos}")

    if last_token_only:
        indirect_effect = torch.zeros(n_trials, model_config['n_layers'], model_config['n_heads'])
    else:
        indirect_effect = torch.zeros(n_trials, model_config['n_layers'], model_config['n_heads'], 10) # have 10 classes of tokens
    
    print(f"Initialized indirect effect tensor with zero tensor of shape: {indirect_effect.shape}")

    if filter_set is None:
        filter_set = np.arange(len(dataset['valid']))
    
    for i in tqdm(range(n_trials), total=n_trials):
        word_pairs = dataset['train'][np.random.choice(len(dataset['train']), n_shots, replace=False)]
        word_pairs_test = dataset['valid'][np.random.choice(filter_set, n_test_examples, replace=False)]
        
        print(f"Trial {i+1}/{n_trials}: Selected word pairs for prompts and testing.")

        print("Parameters to word_pairs_to_prompt_data:")
        print(f"word_pairs: {word_pairs}, query_target_pair: {word_pairs_test}, shuffle_labels: True, prepend_bos_token: {prepend_bos}")

        if prefixes is not None and separators is not None:
            prompt_data_random = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, shuffle_labels=True, 
                                                           prepend_bos_token=prepend_bos, prefixes=prefixes, separators=separators)
        else:
            prompt_data_random = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, 
                                                           shuffle_labels=True, prepend_bos_token=prepend_bos)
        print(f"Generated prompt data for trial {i+1}.")
        print("Prompt Data", prompt_data_random)
        
        ind_effects = activation_replacement_per_class_intervention(prompt_data=prompt_data_random, 
                                                                    avg_activations=mean_activations, 
                                                                    dummy_labels=dummy_gt_labels, 
                                                                    model=model, model_config=model_config, tokenizer=tokenizer, 
                                                                    last_token_only=last_token_only)
        indirect_effect[i] = ind_effects.squeeze()
        print(f"Indirect effects calculated for trial {i+1} with shape: {ind_effects.shape}.")

    print("Completed computation of indirect effects.")
    return indirect_effect

# Note: You'll need to define or adapt the functions get_dummy_token_labels and word_pairs_to_prompt_data,
# as well as activation_replacement_per_class_intervention, according to their requirements and outputs.

In [6]:

args = {
    'dataset_name': 'antonym',
    'model_name': 'gpt2',
    'root_data_dir': '../dataset_files',
    'save_path_root': '../results',
    'seed': 42,
    'n_shots': 10,
    'n_trials': 1,
    'test_split': 0.3,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'mean_activations_path': None,
    'last_token_only': True,
    'prefixes': {"input": "Q:", "output": "A:", "instructions": ""},
    'separators': {"input": "\\n", "output": "\\n\\n", "instructions": ""}
}

dataset_name = args['dataset_name']
model_name = args['model_name']
root_data_dir = args['root_data_dir']
save_path_root = f"{args['save_path_root']}/{dataset_name}"
seed = args['seed']
n_shots = args['n_shots']
n_trials = args['n_trials']
test_split = args['test_split']
device = args['device']
mean_activations_path = args['mean_activations_path']
last_token_only = args['last_token_only']
prefixes = args['prefixes']
separators = args['separators']

In [7]:
# Load Model & Tokenizer
torch.set_grad_enabled(False)
print("Loading Model")
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device)

Loading Model
Loading:  gpt2


In [8]:
set_seed(seed)

# Load the dataset
print("Loading Dataset")
dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)

Loading Dataset


In [9]:
if not os.path.exists(save_path_root):
    os.makedirs(save_path_root)

In [10]:
# Load or Re-Compute Mean Activations
if mean_activations_path is not None and os.path.exists(mean_activations_path):
    mean_activations = torch.load(mean_activations_path)
elif mean_activations_path is None and os.path.exists(f'{save_path_root}/{dataset_name}_mean_head_activations.pt'):
    mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt'
    mean_activations = torch.load(mean_activations_path)        
else:
    print("Computing Mean Activations")
    mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, 
                                                    n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators)
    torch.save(mean_activations, f'{save_path_root}/{dataset_name}_mean_head_activations.pt')

In [11]:
print("Computing Indirect Effect")
indirect_effect = compute_indirect_effect(dataset, mean_activations, model=model, model_config=model_config, tokenizer=tokenizer, 
                                            n_shots=n_shots, n_trials=n_trials, last_token_only=last_token_only, prefixes=prefixes, separators=separators)

Computing Indirect Effect
---------------------------------------------------------------------
Starting computation of indirect effects with 10 shots and 1 trials...
Using custom prefixes and separators for dummy token labels...
<function get_token_meta_labels at 0x2ba95b53dd80>
----------------------------
get_prompt_parts_and_labels
----------------------------
prompt_parts: ['<|endoftext|>', '', '', ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:', ' a', '\\n\\n'], ['Q:', ' a', '\\n', 'A:']]
prompt_part_labels: ['bos_token', 'instructions_token', 'separator_token', ['structural_token', 'demonstration_1_token', 'separator_toke

  0%|          | 0/1 [00:00<?, ?it/s]

Trial 1/1: Selected word pairs for prompts and testing.
Parameters to word_pairs_to_prompt_data:
word_pairs: {'input': ['valley', 'economic', 'fugitive', 'inward', 'unmarked', 'worthless', 'regress', 'borrowing', 'saline', 'risky'], 'output': ['mountain', 'uneconomic', 'law-abiding citizen', 'outward', 'marked', 'valuable', 'progress', 'lending', 'freshwater', 'safe']}, query_target_pair: {'input': ['secure'], 'output': ['insecure']}, shuffle_labels: True, prepend_bos_token: True
Generated prompt data for trial 1.
Prompt Data {'instructions': '', 'separators': {'input': '\\n', 'output': '\\n\\n', 'instructions': ''}, 'prefixes': {'input': 'Q:', 'output': 'A:', 'instructions': '<|endoftext|>'}, 'query_target': {'input': ' secure', 'output': ' insecure'}, 'examples': [{'input': ' valley', 'output': ' progress'}, {'input': ' economic', 'output': ' valuable'}, {'input': ' fugitive', 'output': ' outward'}, {'input': ' inward', 'output': ' freshwater'}, {'input': ' unmarked', 'output': ' saf

100%|██████████| 1/1 [00:02<00:00,  2.59s/it]

token_class, class_regex:  query_predictive query_predictive_token
reg_class_match:  re.compile('^query_predictive_token$')
class_token_inds:  [131]
intervention_locations: [(11, 8, 131)]
----------------replace_activation_w_avg-------------
Current layer name: transformer.h.11.attn.c_proj
Current layer index: 11
Intervening in layer: 11
Inputs are a tuple, taking the first element
Original input shape: torch.Size([1, 132, 768])
New input shape after reshaping for heads: torch.Size([1, 132, 12, 64])
Last token only mode
Layer 11: Patching activation at last token for head 8, token 131
GPT2 model detected
token_class, class_regex:  query_predictive query_predictive_token
reg_class_match:  re.compile('^query_predictive_token$')
class_token_inds:  [131]
intervention_locations: [(11, 9, 131)]
----------------replace_activation_w_avg-------------
Current layer name: transformer.h.11.attn.c_proj
Current layer index: 11
Intervening in layer: 11
Inputs are a tuple, taking the first element
Ori


