In [None]:
'''
Create forward and reversed attention maps, and save per-head mean/std tables of their norms.
'''

In [None]:
# general imports
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import copy
import seaborn as sns
import json

# another code we wrote
sys.path.append('../')
import llm_utils
import exp_ra_utils
import opt_utils
from plot_utils import plot_aux_wrapper

In [3]:
try:
    from function_vectors.src.utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt
    from function_vectors.src.utils.eval_utils import decode_to_vocab, sentence_eval, n_shot_eval_no_intervention, get_answer_id
except Exception as error:
    print('could not import from function_vectors package with the following error:')
    print(error)
    print('Make sure you first pull relevant submodules. See README.md for more info.')

In [None]:
# Timestamp of this run; it is printed in the debug banner below.
START_TIME = time.strftime("%Y/%m/%d-%H:%M:%S")
# DEBUG_FLAG = True  # for local testing
DEBUG_FLAG = False  # automatically set below

try:
    # since only Apple M1-3 supports this, we can use this as a flag for local testing
    DEBUG_FLAG = torch.backends.mps.is_available()
except Exception:
    # The MPS backend could not be queried: keep DEBUG_FLAG = False.
    # `Exception` (instead of a bare `except`) lets KeyboardInterrupt / SystemExit propagate.
    pass
if DEBUG_FLAG:
    print('*'*40)
    print(f'    DEBUG MODE [{START_TIME}]')
    print('*'*40)

In [5]:
# Command-line style configuration of the notebook.
# The defaults of --model_name and --n_samples are smaller when DEBUG_FLAG is on.
import argparse
parser = argparse.ArgumentParser()
# --- model ---
parser.add_argument('--model_name', type=str, default='gpt2' if DEBUG_FLAG else 'gpt2-xl')
# comma separated `key=value` pairs, forwarded as extra keyword arguments to `from_pretrained`
parser.add_argument('--model_args', type=str, default='')
# --- outputs ---
parser.add_argument('--out_folder', type=str, default='tmp_plots1')
parser.add_argument('--postfix_name', type=str, default='')
# when set, the tokenizer's pad token is NOT replaced by its eos token
parser.add_argument('--disable_pad_token', action='store_true')
# --- data ---
parser.add_argument('--root_data_dir', type=str, default='../function_vectors/dataset_files')
parser.add_argument('--dataset_name', type=str, default='antonym')
parser.add_argument('--n_shots_icl', type=int, default=0)
parser.add_argument('--n_samples', type=int, default=5 if DEBUG_FLAG else 25)
parser.add_argument('--metric_to_use', type=str, default='f1_score')  # f1_score, exact_match_score, first_word_score
# --- prompt template (values given on the command line are parsed as JSON) ---
parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})    
# --- runtime ---
# NOTE: the default is a `torch.device` object (not a string); the strings 'AUTO' and '' are resolved to cuda/cpu in a later cell
parser.add_argument('--device', type=str, default=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
parser.add_argument('--seed', type=int, default=42)
# when set, the run terminates early if its meta-data json file already exists
parser.add_argument('--check_if_result_already_exists', action='store_true')

_StoreTrueAction(option_strings=['--check_if_result_already_exists'], dest='check_if_result_already_exists', nargs=0, const=True, default=False, type=None, choices=None, required=False, help=None, metavar=None)

In [6]:
# `parse_known_args` (rather than `parse_args`) returns the arguments this parser does not define
# instead of failing on them, e.g. the `--f=<kernel json>` argument that shows up when running inside Jupyter.
args, unknown = parser.parse_known_args()
print('unknown args:', unknown)
print('args:', args)

unknown args: ['--f=/Users/ks/Library/Jupyter/runtime/kernel-v2-37936PB31iTO0ySOS.json']
args: Namespace(model_name='gpt2', model_args='', out_folder='tmp_plots1', postfix_name='', disable_pad_token=False, root_data_dir='../function_vectors/dataset_files', dataset_name='antonym', n_shots_icl=0, n_samples=5, metric_to_use='f1_score', prefixes={'input': 'Q:', 'output': 'A:', 'instructions': ''}, separators={'input': '\n', 'output': '\n\n', 'instructions': ''}, device=device(type='cpu'), seed=42, check_if_result_already_exists=False)


In [7]:
# Load the tokenizer that matches the requested model.
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
if not args.disable_pad_token:
    # reuse the eos token as the pad token (unless disabled with --disable_pad_token)
    print(f'adding pad token: {tokenizer.eos_token}')
    tokenizer.pad_token = tokenizer.eos_token

# not blocking, just to prevent warnings and faster tokenization.
# This is a plain assignment of a string to os.environ, so it needs no try/except guard.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
padding_flag = tokenizer.pad_token_id is not None  # this is up to if the tokenizer was configured with padding or not

adding pad token: <|endoftext|>


In [8]:
# Resolve the torch device to run on: 'AUTO' or an empty string means
# "use CUDA when it is available, otherwise fall back to the CPU".
if args.device in ('AUTO', ''):
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
else:
    # anything else is handed to torch.device as-is
    device = torch.device(args.device)
print(
    f'Using device: {device} '
    f'[cuda available? => {torch.cuda.is_available()}, '
    f'cuda version: {torch.version.cuda}, '
    f'args.device = "{args.device}"]'
)

Using device: cpu [cuda available? => False, cuda version: None, args.device = "cpu"]


In [None]:
# Parse --model_args ("key1=value1,key2=value2") into keyword arguments for `from_pretrained`.
# Values are kept as strings.
model_extra_args = {}
for arg in args.model_args.split(','):
    if arg == '':
        continue
    # Split on the FIRST '=' only, so that a value may itself contain '='.
    k, sep, v = arg.partition('=')
    if not sep:
        raise ValueError(f'--model_args entries must have the form key=value, got: "{arg}"')
    model_extra_args[k] = v
print(f'model_extra_args: {model_extra_args}')

# Load the model in eval mode with gradients disabled for all parameters
# (the parameters of interest are re-enabled in a later cell).
model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_extra_args).eval().requires_grad_(False).to(device)
model_aux = llm_utils.model_extra(model=model, device=device)
model_config = copy.deepcopy(model.config)
config = model_aux.config
# del model

# Architecture sizes, as exposed by the project helper `llm_utils.model_extra`.
n_embd = model_aux.n_embd
n_head = model_aux.n_head
head_size = model_aux.head_size
n_layer = model_aux.n_layer

pad_k = model_aux.pad_k
pad_v = model_aux.pad_v

# Filter (by parameter name) that selects which parameters get gradients enabled later on.
# params_names_filter = opt_utils.get_filter_by_name(args.filter_layers)
params_names_filter = opt_utils.get_filter_by_name('attn_only')
# params_names_filter = lambda x: ('attn' in x or 'attention' in x) and 'weight' in x

n_shots_icl = args.n_shots_icl
print(f'n_shots_icl: {n_shots_icl}')
metric_to_use = args.metric_to_use
n_samples = args.n_samples

In [10]:
# show the configuration of the loaded model (sanity check)
print(model.config)

GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.40.1",
  "use_cache": true,
  "vocab_size": 50257
}



In [11]:
# show the module structure of the loaded model (sanity check)
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [12]:
# the prefix is a code for which version of the code we use in this notebook
output_prefix = f'av1_{args.model_name.replace("/", "-")}_{args.dataset_name}_[{args.n_shots_icl}]_{n_samples}_{metric_to_use}_{args.postfix_name}_'

plot_wrapper = plot_aux_wrapper(output_folder=args.out_folder, 
                            output_prefix=output_prefix,
                            local_show=True)

show_every_n_layer = 1 if n_layer <= 12 else 2 if n_layer <= 24 else 4

In [13]:
# Models whose name contains one of these markers get NO explicit BOS token prepended to the prompt.
is_llama = any(marker in args.model_name for marker in ('llama', 'facebook/opt'))
prepend_bos = not is_llama

# prompt template pieces, taken from the run arguments
prefixes = args.prefixes
separators = args.separators

# fixed switches for this notebook
compute_ppl = False
shuffle_labels = False
generate_str = False

In [14]:
# Load the requested dataset through the project helper.
# Later cells index the result with the 'train', 'valid' and 'test' splits.
dataset = exp_ra_utils.data_loading_wrapper(args.dataset_name, 
                                            root_data_dir=args.root_data_dir,
                                            seed=args.seed)

Load dataset: antonym with 1678 train, 216 valid and 504 test samples
Example: 1) {'input': 'lesbian', 'output': 'straight'}
Example: 2) {'input': 'homegrown', 'output': 'imported'}


In [15]:
# Print an example prompt and the model's predictions for it (see the cell output);
# the returned value is stored in the run meta-data later on, to verify we ran what we wanted.
sentence_example = exp_ra_utils.data_print_and_test_example(model, tokenizer=tokenizer,
                                                            dataset=dataset,
                                                            prepend_bos=prepend_bos,
                                                            prefixes=prefixes,
                                                            separators=separators,
                                                            shuffle_labels=shuffle_labels)

ICL prompt:
 '<|endoftext|>Q: noise\nA: silence\n\nQ: lesbian\nA: straight\n\nQ: homegrown\nA: imported\n\nQ: default\nA: customized\n\nQ: disrespect\nA: respect\n\nQ: damn\nA:' 


Zero-Shot Prompt:
 '<|endoftext|>Q: damn\nA:'
Input Sentence: '<|endoftext|>Q: noise\nA: silence\n\nQ: lesbian\nA: straight\n\nQ: homegrown\nA: imported\n\nQ: default\nA: customized\n\nQ: disrespect\nA: respect\n\nQ: damn\nA:' 

Input Query: 'damn', Target: 'bless'

ICL Prompt Top K Vocab Probs:
 [(' damn', 0.06769), (' disrespect', 0.02297), (' I', 0.01337), (' stupid', 0.01174), (' insult', 0.01142)] 

Input Sentence: '<|endoftext|>Q: damn\nA:' 

Input Query: 'damn', Target: 'bless'

Zero-Shot Prompt Top K Vocab Probs:
 [(' I', 0.08618), ('\n', 0.02441), (' The', 0.01667), (' It', 0.01592), (' Yeah', 0.01445)] 



In [16]:
# Base paths (without the `_mean.csv` / `_std.csv` suffix) of the tables written at the end of the run.
df_forward_path = f'{args.out_folder}/{output_prefix}_forward_attn_maps_norms'
df_backward_path = f'{args.out_folder}/{output_prefix}_backward_attn_maps_norms'

# Everything needed to identify this run afterwards.
meta_data = {
             'df_forward_path': df_forward_path,
             'df_backward_path': df_backward_path,
             'example_sentence': sentence_example,  # to verify we ran what we wanted
             'ds_train_len': len(dataset['train']),
             'ds_valid_len': len(dataset['valid']),
             'ds_test_len': len(dataset['test']),
             'start_time': START_TIME,
             'run_args': str(args),
             'debug_flag_on': DEBUG_FLAG}
meta_data_out = os.path.join(args.out_folder, f'{output_prefix}_get_mean_attention_tables.json')
if args.check_if_result_already_exists and os.path.exists(meta_data_out):
    print(f'File already exists: {meta_data_out}. Terminating this run. If you wish to re-run, please remove the file or disable the flag --check_if_result_already_exists')
    print(f'Run args: {args}')
    print('Exit...')
    exit(0)

# Make sure the output folder exists before writing into it; otherwise `open(..., 'w')` below
# (and the csv export at the end of the run) fails with FileNotFoundError on a fresh folder.
# An empty --out_folder means "current directory", which needs no creation.
if args.out_folder:
    os.makedirs(args.out_folder, exist_ok=True)

# NOTE(review): the meta-data file is written BEFORE the tables are computed, so with
# --check_if_result_already_exists a run that crashed later on is still treated as done.
with open(meta_data_out, 'w') as f:
    json.dump(meta_data, f)
    

In [17]:
# Pick the parameters selected by the name filter and enable gradients for them only
# (only the relevant layers are trained, all the rest stay frozen).
selected_named_params = [
    (name, tensor)
    for name, tensor in model.named_parameters()
    if params_names_filter(name)
]
for _, tensor in selected_named_params:
    tensor.requires_grad_(True)
params = [tensor for _, tensor in selected_named_params]
params_names = [name for name, _ in selected_named_params]

print(f'Number of params to optimize: {len(params)}')
print('Going to edit only the following layers (showing only the first 10):', params_names[:10])
print('Going to edit only the following layers (showing only the last 10):', params_names[-10:])

# The learning rate does not really matter, as we are only interested in the VJPs which we collect.
opt = torch.optim.SGD(params, lr=1e-3, weight_decay=0)
min_loss_for_update = 0

Number of params to optimize: 24
Going to edit only the following layers (showing only the first 10): ['transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.c_proj.weight', 'transformer.h.3.attn.c_attn.weight', 'transformer.h.3.attn.c_proj.weight', 'transformer.h.4.attn.c_attn.weight', 'transformer.h.4.attn.c_proj.weight']
Going to edit only the following layers (showing only the last 10): ['transformer.h.7.attn.c_attn.weight', 'transformer.h.7.attn.c_proj.weight', 'transformer.h.8.attn.c_attn.weight', 'transformer.h.8.attn.c_proj.weight', 'transformer.h.9.attn.c_attn.weight', 'transformer.h.9.attn.c_proj.weight', 'transformer.h.10.attn.c_attn.weight', 'transformer.h.10.attn.c_proj.weight', 'transformer.h.11.attn.c_attn.weight', 'transformer.h.11.attn.c_proj.weight']


In [18]:
short_lm_config = [config.attn_o]  # the only type of layers we need to calculate the reverse attn (with the attn values)

In [19]:
def get_experiment_collector(example_index, split='train', annotate=False, dataset=dataset, curr_n_shots=n_shots_icl):
    """Run one forward + backward pass on a single example and collect its attention data.

    A prompt is built from `curr_n_shots` randomly drawn demonstrations plus the query
    `dataset[split][example_index]`; the loss is computed on the target tokens only and
    back-propagated while the collector returned by `llm_utils.wrap_model` is attached.

    Args:
        example_index: index of the query example inside `dataset[split]`.
        split: name of the dataset split to take the query and the demonstrations from.
        annotate: if True, print the number of shots, the tokenized length and the full sentence.
        dataset: the dataset to use (the default is bound when this function is defined).
        curr_n_shots: number of in-context demonstrations (the default is bound when this
            function is defined).

    Returns:
        A dict with the single key `curr_n_shots`, mapping to a dict with:
            'grad': a deep copy of the collector returned by `llm_utils.wrap_model(..., forward=False)`,
            'kv_cache': `output.past_key_values` of the forward pass,
            'attentions': `output.attentions` of the forward pass.

    Raises:
        ValueError: if appending the target to the prompt adds no tokens, i.e. there is
            nothing to compute the loss on.
    """
    collector = {}
    j = example_index
    # No forward collector is used here (kept as an empty dict so the cleanup below stays symmetric).
    hs_collector = {}
    grad_collector = llm_utils.wrap_model(model, layers_to_check=short_lm_config,
                                          return_hooks_handler=True, forward=False, max_len=1000)

    # From here on the collector is attached to the model: always detach it in `finally`,
    # otherwise an exception below would leave it attached for every subsequent call.
    try:
        opt.zero_grad()
        if curr_n_shots == 0:
            word_pairs = {'input': [], 'output': []}
        else:
            # draw the demonstrations, re-drawing until the query itself is not among them
            random_samples_without_current_j = np.random.choice(len(dataset[split]), curr_n_shots, replace=False)
            while j in random_samples_without_current_j:
                random_samples_without_current_j = np.random.choice(len(dataset[split]), curr_n_shots, replace=False)

            word_pairs = dataset[split][random_samples_without_current_j]
        word_pairs_test = dataset[split][j]
        if prefixes is not None and separators is not None:
            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                                    shuffle_labels=shuffle_labels, prefixes=prefixes, separators=separators)
        else:
            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos, shuffle_labels=shuffle_labels)

        # Get relevant parts of the Prompt
        query, target = prompt_data['query_target']['input'], prompt_data['query_target']['output']
        query = query[0] if isinstance(query, list) else query
        if generate_str:
            target = [target] if not isinstance(target, list) else target
        else:
            target = target[0] if isinstance(target, list) else target

        sentence = [create_prompt(prompt_data)]

        # Figure out tokens of interest.
        # NOTE(review): the result is not used below; the call is kept because the helper lives
        # outside this file and may validate the prompt/target pair - confirm before removing.
        target_token_id = get_answer_id(sentence[0], target, tokenizer)

        device = model.device
        inputs = tokenizer(sentence, return_tensors='pt').to(device)
        prompt_len = inputs.input_ids.shape[1]

        # REF1: https://github.com/ericwtodd/function_vectors/blob/1e8a9a0f3583c547efcee2b4add4e880c25a96d3/src/utils/intervention_utils.py#L151
        # REF2: https://huggingface.co/docs/transformers/en/perplexity
        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
        # to the left by 1.
        target_completion = "".join(sentence + [target])
        nll_inputs = tokenizer(target_completion, return_tensors='pt').to(device)
        nll_targets = nll_inputs.input_ids.clone()
        target_len = nll_targets.shape[1] - prompt_len
        if target_len <= 0:
            # With target_len == 0 the slice `[:, :-target_len]` would be empty, nothing would be
            # masked and the loss would silently be computed over the whole prompt.
            raise ValueError(f'The target adds no tokens to the prompt (target_len={target_len}); cannot compute the loss')
        # Mask the prompt part (same positions as `[:, :-target_len]`).
        nll_targets[:, :prompt_len] = -100  # This is the accepted value to skip indices when computing loss in nn.CrossEntropyLoss

        if annotate:
            tmp_print = target_completion.replace("\n", "\\n")
            print(f'[{curr_n_shots}-shots] len={nll_inputs.input_ids.shape[1]}, sentence_with_target={tmp_print}')

        output = model(**nll_inputs, labels=nll_targets, output_attentions=True, output_hidden_states=True, use_cache=True)

        # compute gradients but do not apply the step of the optimizer:
        # the parameter gradients themselves are discarded right away.
        output.loss.backward()
        opt.zero_grad()

        collector[curr_n_shots] = {
            'grad': copy.deepcopy(grad_collector),
            'kv_cache': output.past_key_values,
            'attentions': output.attentions,
        }
        return collector
    finally:
        llm_utils.remove_collector_hooks(hs_collector)
        llm_utils.remove_collector_hooks(grad_collector)

In [20]:
# We could also have cached all collectors, but that is not feasible with big models and
# many examples (one collector of gpt2-XL can get up to 500MB), so only the maps are kept.

# {layer index: {head index: [one map per sample]}}
all_forward_attns = {
    layer_index: {head_index: [] for head_index in range(n_head)}
    for layer_index in range(n_layer)
}
all_rev_attns = {
    layer_index: {head_index: [] for head_index in range(n_head)}
    for layer_index in range(n_layer)
}

for j in range(n_samples):
    is_report_step = (j % 10 == 0)
    if is_report_step:
        print(f'j={j}')
    collector = get_experiment_collector(example_index=j, annotate=is_report_step)
    sample_collector = collector[n_shots_icl]

    for layer_index in range(n_layer):
        forward_per_head = all_forward_attns[layer_index]
        rev_per_head = all_rev_attns[layer_index]
        for head_index in range(n_head):
            forward_attn_map, rev_attn_map = exp_ra_utils.get_forward_and_reversed_attn(
                sample_collector, layer_index, head_index, config, head_size)
            rev_per_head[head_index].append(rev_attn_map)
            forward_per_head[head_index].append(forward_attn_map)
            

j=0
[0-shots] len=8, sentence_with_target=<|endoftext|>Q: noise\nA: silence


In [21]:
# For every (layer, head): mean and std, over the collected samples, of the norm of the
# forward and of the reversed attention maps. Tables are laid out as (head, layer).
forward_attn_maps_norms_mean = np.zeros((n_head, n_layer))
backward_attn_maps_norms_mean = np.zeros((n_head, n_layer))
forward_attn_maps_norms_std = np.zeros((n_head, n_layer))
backward_attn_maps_norms_std = np.zeros((n_head, n_layer))
for layer_index in range(n_layer):
    for head_index in range(n_head):
        # compute each per-sample norm once and reuse it for both statistics
        rev_norms = [x.norm().item() for x in all_rev_attns[layer_index][head_index]]
        forward_norms = [x.norm().item() for x in all_forward_attns[layer_index][head_index]]
        forward_attn_maps_norms_mean[head_index, layer_index] = np.mean(forward_norms)
        backward_attn_maps_norms_mean[head_index, layer_index] = np.mean(rev_norms)
        forward_attn_maps_norms_std[head_index, layer_index] = np.std(forward_norms)
        backward_attn_maps_norms_std[head_index, layer_index] = np.std(rev_norms)

# save all tables as df (rows: heads, columns: layers)
layer_labels = [f'layer_{i}' for i in range(n_layer)]
head_labels = [f'head_{i}' for i in range(n_head)]
df_forward_attn_maps_norms_mean = pd.DataFrame(forward_attn_maps_norms_mean, columns=layer_labels, index=head_labels)
df_backward_attn_maps_norms_mean = pd.DataFrame(backward_attn_maps_norms_mean, columns=layer_labels, index=head_labels)
df_forward_attn_maps_norms_std = pd.DataFrame(forward_attn_maps_norms_std, columns=layer_labels, index=head_labels)
df_backward_attn_maps_norms_std = pd.DataFrame(backward_attn_maps_norms_std, columns=layer_labels, index=head_labels)
print(f'Saving all tables as csv in {args.out_folder}/{output_prefix}_<table-name>.csv')
df_forward_attn_maps_norms_mean.to_csv(f'{df_forward_path}_mean.csv')
df_backward_attn_maps_norms_mean.to_csv(f'{df_backward_path}_mean.csv')
df_forward_attn_maps_norms_std.to_csv(f'{df_forward_path}_std.csv')
df_backward_attn_maps_norms_std.to_csv(f'{df_backward_path}_std.csv')


Saving all tables as csv in tmp_plots1/av1_gpt2_antonym_[0]_5_f1_score___<table-name>.csv


In [None]:
# All tables have been saved at this point.
print('done')

# NOTE(review): this raise presumably serves as a deliberate stop, so that running the whole
# notebook (or a script exported from it) ends here and does not continue into the plotting
# cells below - confirm before removing it.
raise ValueError('Done')

# Post analysis


In [None]:
# One heatmap (heads x layers) per table computed above.
tables_to_plot = [
    (forward_attn_maps_norms_mean, 'forward mean'),
    (forward_attn_maps_norms_std, 'forward std'),
    (backward_attn_maps_norms_mean, 'reversed mean'),
    (backward_attn_maps_norms_std, 'reversed std'),
]

# The figure gets wider when there are more layers than heads.
ratio = 1.5 * n_layer / n_head if n_layer > n_head else 2.5

# Tick positions: every `show_every_n_layer` layers, and half that step (at least 1) for heads.
layer_tick_values = np.arange(0, n_layer, show_every_n_layer)
head_tick_values = np.arange(0, n_head, max(1, show_every_n_layer//2))

for norm_table, table_name in tables_to_plot:
    fig, ax = plt.subplots(figsize=(ratio * 3, 3))
    # mean tables in purple, std tables in grey
    if 'mean' in table_name:
        cmap = 'Purples'
    else:
        cmap = 'Greys'
    sns.heatmap(norm_table, ax=ax, annot=False, fmt=".2f", cmap=cmap, vmin=0, cbar_kws={'label': 'norm'})
    # "+ 0.5" places each tick at the centre of its heatmap cell
    ax.set_xticks(layer_tick_values + 0.5)
    ax.set_yticks(head_tick_values + 0.5)
    ax.set_xticklabels(layer_tick_values, rotation=0)
    ax.set_yticklabels(head_tick_values, rotation=0)
    ax.set_xlabel('layer index')
    ax.set_ylabel('head index')
    # draw a thin frame around the heatmap
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
    plt.tight_layout()
    plot_wrapper(plt, title=f'{table_name} attention map norms', save_also_without_title=True)