In [None]:
# general imports
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import copy
import random
import seaborn as sns
import json
from typing import Optional, Tuple, Union

# another code we wrote
sys.path.append('../')
import llm_utils
import opt_utils
from plot_utils import plot_aux_wrapper
import exp_ra_utils

In [3]:
# The prompt helpers come from the function_vectors package; if the import fails,
# print the error together with a hint instead of stopping the notebook here.
try:
    from function_vectors.src.utils.prompt_utils import (
        load_dataset,
        word_pairs_to_prompt_data,
        create_prompt,
    )
except Exception as error:
    print(
        'could not import from function_vectors package with the following error:',
        error,
        'Make sure you first pull relevant submodules. See README.md for more info.',
        sep='\n',
    )

In [4]:
# Timestamp of this run; printed in the debug banner below.
START_TIME = time.strftime("%Y/%m/%d-%H:%M:%S")

# DEBUG_FLAG switches several argument defaults further down (model, shots, samples) to
# small values for local testing. It is set automatically: availability of the Apple MPS
# backend is used as a proxy for "running on a local Apple-silicon machine".
# To force it by hand, assign DEBUG_FLAG = True after this block.
DEBUG_FLAG = False
try:
    DEBUG_FLAG = torch.backends.mps.is_available()
except Exception:
    # e.g. a torch build that does not expose `torch.backends.mps`: stay in non-debug mode.
    # Catching `Exception` instead of using a bare `except` lets KeyboardInterrupt and
    # SystemExit propagate normally.
    pass

if DEBUG_FLAG:
    print('*'*40)
    print(f'    DEBUG MODE [{START_TIME}]')
    print('*'*40)

****************************************
    DEBUG MODE [2024/09/28-16:53:17]
****************************************


In [None]:
import argparse

# Experiment configuration, exposed as command-line style arguments so the notebook can
# also be driven as a script. Several defaults shrink when DEBUG_FLAG is set.
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='gpt2' if DEBUG_FLAG else 'gpt2-xl')
parser.add_argument('--model_args', type=str, default='')  # "key=value,key=value" forwarded to from_pretrained
parser.add_argument('--layer_opt_filter', type=str, default='attn_only')
parser.add_argument('--out_folder', type=str, default='tmp_plots1')
parser.add_argument('--postfix_name', type=str, default='')
parser.add_argument('--disable_pad_token', action='store_true')
parser.add_argument('--root_data_dir', type=str, default='../function_vectors/dataset_files')
parser.add_argument('--dataset_name', type=str, default='antonym_len_8')
parser.add_argument('--n_shots_icl', type=str, default='0,5' if DEBUG_FLAG else '0,1,5,10')  # comma-separated shot counts
parser.add_argument('--n_samples', type=int, default=5 if DEBUG_FLAG else 25)
parser.add_argument('--metric_to_use', type=str, default='f1_score')  # f1_score, exact_match_score, first_word_score. not really used as we examine only accuracy of top-{1,2,3} predictions
parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})
# The default is a plain string (not a torch.device object) so that `args.device` has the
# same type whether or not --device was passed; it is converted with torch.device(...) later.
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--check_if_result_already_exists', action='store_true')

In [None]:
# parse_known_args keeps going when extra, unrecognised arguments are present
# and returns them separately so they can be reported.
args, unknown = parser.parse_known_args()
print(f'unknown args: {unknown}')
print(f'args: {args}')

In [17]:
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
if not args.disable_pad_token:
    # Reuse the EOS token as the padding token unless --disable_pad_token was given.
    print(f'adding pad token: {tokenizer.eos_token}')
    tokenizer.pad_token = tokenizer.eos_token

# Not blocking, just to prevent warnings and get faster tokenization.
# A plain assignment of a str into os.environ needs no error suppression, so the former
# bare `try/except: pass` around it was dropped.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# True when the tokenizer ended up with a pad token configured (see the branch above).
padding_flag = tokenizer.pad_token_id is not None

adding pad token: <|endoftext|>


In [18]:
# Turn the --device argument into a torch.device and log it together with the CUDA
# availability / version reported by torch.
device = torch.device(args.device)
print(f'Using device: {device} [cuda available? => {torch.cuda.is_available()}, cuda version: {torch.version.cuda}, args.device = "{args.device}"]')

Using device: cpu [cuda available? => False, cuda version: None, args.device = "cpu"]


In [19]:
# Parse --model_args ("key1=value1,key2=value2") into keyword arguments for from_pretrained.
# Values are kept as strings; empty items (including the empty default) are skipped.
model_extra_args = {}
for arg in args.model_args.split(','):
    if arg == '':
        continue
    # Split on the first '=' only, so a value may itself contain '='.
    k, found_equals, v = arg.partition('=')
    if not found_equals:
        raise ValueError(f'invalid --model_args item "{arg}": expected the form key=value')
    model_extra_args[k] = v
print(f'model_extra_args: {model_extra_args}')

# Load the causal LM in eval mode with every parameter frozen, then move it to `device`.
# (Selected parameters get requires_grad switched back on in a later cell.)
model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_extra_args).eval().requires_grad_(False).to(device)
# Project helper wrapping the model; the values below are read from it.
model_aux = llm_utils.model_extra(model=model, device=device)
# Independent copy of the Hugging Face config, unaffected by later changes to `model.config`.
model_config = copy.deepcopy(model.config)
# NOTE: `config` is the llm_utils layer-naming config (e.g. `config.attn_o`), not the HF config.
config = model_aux.config
# del model

# Architecture sizes as reported by model_aux.
n_embd = model_aux.n_embd
n_head = model_aux.n_head
head_size = model_aux.head_size
n_layer = model_aux.n_layer

# NOTE(review): presumably padding helpers for the key / value projections — confirm in llm_utils.
pad_k = model_aux.pad_k
pad_v = model_aux.pad_v


# Choose the predicate that decides, by parameter name, which parameters are selected later.
if args.layer_opt_filter != 'attn_only':
    params_names_filter = opt_utils.get_filter_by_name(args.layer_opt_filter)  # tested for attn with gpt2 and opt
else:
    def params_names_filter(x):
        """Return True for attention weight parameters, excluding layer-norm ones."""
        # tested for attn with gpt2 and opt
        mentions_attention = 'attn' in x or 'attention' in x
        return mentions_attention and 'weight' in x and 'layer_norm' not in x

# "0,1,5" -> [0, 1, 5]
n_shots_icl = list(map(int, args.n_shots_icl.split(',')))
print(f'n_shots_icl: {n_shots_icl}')
n_samples = args.n_samples
metric_to_use = args.metric_to_use

model_extra_args: {}
Loading config from /Users/ks/Documents/research/projD2/opt07/bl2/clean_bu_code_v2/flow_graph_configs/config_gpt2.json
{'config_name': 'gpt2', 'layer_format': 'transformer.h.{}', 'layer_mlp_format': 'transformer.h.{}.mlp', 'layer_attn_format': 'transformer.h.{}.attn', 'ln1': 'transformer.h.{}.ln_1', 'attn_q': 'transformer.h.{}.attn.c_attn', 'attn_k': 'transformer.h.{}.attn.c_attn', 'attn_v': 'transformer.h.{}.attn.c_attn', 'attn_o': 'transformer.h.{}.attn.c_proj', 'ln2': 'transformer.h.{}.ln_2', 'mlp_ff1': 'transformer.h.{}.mlp.c_fc', 'mlp_ff2': 'transformer.h.{}.mlp.c_proj', 'mlp_act': 'transformer.h.{}.mlp.act', 'include_mlp_bias': True, 'include_attn_bias': True}
n_shots_icl: [0, 5]


In [20]:
# Show the Hugging Face configuration of the loaded model for the run log.
print(model.config)

GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.40.1",
  "use_cache": true,
  "vocab_size": 50257
}



In [21]:
# Show the module tree of the loaded model (layer names are what the name filters match against).
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [None]:
# Every output file name starts with this prefix; the leading "ap1" is a code for the
# version of the code used in this notebook.
output_prefix = '_'.join([
    'ap1',
    args.model_name.replace("/", "-"),
    args.dataset_name,
    f'[{args.n_shots_icl}]',
    str(n_samples),
    args.postfix_name,
    '',  # keeps the trailing underscore
])
plot_wrapper = plot_aux_wrapper(
    output_folder=args.out_folder,
    output_prefix=output_prefix,
    local_show=True,
)

# Layer stride for plots: every layer for small models, every 2nd / 4th for deeper ones.
if n_layer <= 12:
    show_every_n_layer = 1
elif n_layer <= 24:
    show_every_n_layer = 2
else:
    show_every_n_layer = 4

In [23]:
# Model families matched here do not get a BOS token prepended to the prompt by hand.
# NOTE(review): presumably because their tokenizers already add BOS themselves — confirm.
# The match is case-insensitive so that checkpoints such as "NousResearch/Llama-2-7b-hf"
# or local paths with a capitalised family name are recognised as well.
is_llama = any(family in args.model_name.lower() for family in ('llama', 'facebook/opt'))
prepend_bos = not is_llama

# Prompt template pieces taken from the arguments.
prefixes=args.prefixes
separators=args.separators

# Fixed options for this notebook.
compute_ppl=False
shuffle_labels=False
generate_str=False

In [24]:
# Load the task dataset through the project helper; it is indexed below as dataset[split][i].
dataset = exp_ra_utils.data_loading_wrapper(args.dataset_name, 
                                            root_data_dir=args.root_data_dir,
                                            seed=args.seed)

Load dataset: antonym_len_8_1 with 1501 train, 194 valid and 450 test samples
Example: 1) {'input': 'figured', 'output': 'plain'}
Example: 2) {'input': 'bent', 'output': 'straight'}


In [25]:
# Sanity check: build an example prompt with the project helper using the same prompt
# settings as the experiment; per the output below it prints the ICL and zero-shot prompts
# and the model's top predictions for each.
sentence_example = exp_ra_utils.data_print_and_test_example(model, tokenizer=tokenizer,
                                                            dataset=dataset,
                                                            prepend_bos=prepend_bos,
                                                            prefixes=prefixes,
                                                            separators=separators,
                                                            shuffle_labels=shuffle_labels)

ICL prompt:
 '<|endoftext|>Q: historic\nA: modern\n\nQ: figured\nA: plain\n\nQ: bent\nA: straight\n\nQ: convergence\nA: divergence\n\nQ: more\nA: less\n\nQ: lifelong\nA:' 


Zero-Shot Prompt:
 '<|endoftext|>Q: lifelong\nA:'
Input Sentence: '<|endoftext|>Q: historic\nA: modern\n\nQ: figured\nA: plain\n\nQ: bent\nA: straight\n\nQ: convergence\nA: divergence\n\nQ: more\nA: less\n\nQ: lifelong\nA:' 

Input Query: 'lifelong', Target: 'temporary'

ICL Prompt Top K Vocab Probs:
 [(' lifelong', 0.17095), (' lifetime', 0.12667), (' long', 0.04274), (' life', 0.02816), (' more', 0.02066)] 

Input Sentence: '<|endoftext|>Q: lifelong\nA:' 

Input Query: 'lifelong', Target: 'temporary'

Zero-Shot Prompt Top K Vocab Probs:
 [(' I', 0.1022), (' Yes', 0.03696), (' The', 0.02445), (' No', 0.02331), (' It', 0.01985)] 



In [26]:
# Collect the parameters accepted by the name filter and re-enable gradients only for them;
# everything else stays frozen.
params, params_names = [], []
for param_name, param in model.named_parameters():
    if not params_names_filter(param_name):
        continue
    param.requires_grad_(True)
    params_names.append(param_name)
    params.append(param)

print(f'Number of params to optimize: {len(params)}')
print('Going to edit only the following layers (showing only the first 10):', params_names[:10])
print('Going to edit only the following layers (showing only the last 10):', params_names[-10:])

# The learning rate is irrelevant here: only the collected VJPs are of interest.
opt = torch.optim.SGD(params, lr=1e-3, weight_decay=0)
min_loss_for_update = 0

Number of params to optimize: 24
Going to edit only the following layers (showing only the first 10): ['transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.2.attn.c_attn.weight', 'transformer.h.2.attn.c_proj.weight', 'transformer.h.3.attn.c_attn.weight', 'transformer.h.3.attn.c_proj.weight', 'transformer.h.4.attn.c_attn.weight', 'transformer.h.4.attn.c_proj.weight']
Going to edit only the following layers (showing only the last 10): ['transformer.h.7.attn.c_attn.weight', 'transformer.h.7.attn.c_proj.weight', 'transformer.h.8.attn.c_attn.weight', 'transformer.h.8.attn.c_proj.weight', 'transformer.h.9.attn.c_attn.weight', 'transformer.h.9.attn.c_proj.weight', 'transformer.h.10.attn.c_attn.weight', 'transformer.h.10.attn.c_proj.weight', 'transformer.h.11.attn.c_attn.weight', 'transformer.h.11.attn.c_proj.weight']


In [27]:
# Only the attention output projection layers are hooked: together with the attention
# values they are all that is needed to calculate the reverse attention.
short_lm_config = [config.attn_o]

In [28]:
def get_experiment_collector(example_index, split='train', annotate=False, dataset=dataset):
    """Run one forward/backward pass per ICL shot count and collect the results.

    For every entry of the global ``n_shots_icl`` a prompt is built around the query/target
    pair ``dataset[split][example_index]``, preceded by that many demonstrations drawn at
    random (via ``np.random``) from the same split, never including the query itself. The
    model is run on prompt + target with the loss restricted to the target tokens, and the
    loss is back-propagated while backward hooks are attached to the layers listed in
    ``short_lm_config``. Parameter gradients are cleared afterwards; no optimizer step is taken.

    Args:
        example_index: index, within ``dataset[split]``, of the query/target pair.
        split: dataset split used for both the query and the demonstrations.
        annotate: if True, print the token length and the full prompt for each shot count.
        dataset: dataset to draw from; defaults to the module-level ``dataset`` as bound
            when this function was defined.

    Returns:
        dict mapping each shot count to a dict with the keys ``'grad'`` (deep copy of the
        backward-hook collector), ``'kv_cache'`` (``output.past_key_values``) and
        ``'attentions'`` (``output.attentions``).

    Raises:
        ValueError: if appending the target does not add any token to the tokenized prompt,
            in which case the loss mask over the prompt tokens could not be built correctly.
    """
    collector = {}
    for curr_n_shots in n_shots_icl:
        grad_collector = llm_utils.wrap_model(model, layers_to_check=short_lm_config,
                                              return_hooks_handler=True, forward=False, max_len=1000)
        # The hooks stay attached to the shared model until removed, so make sure they are
        # removed even when something below raises.
        try:
            opt.zero_grad()

            if curr_n_shots == 0:
                word_pairs = {'input': [], 'output': []}
            else:
                # Redraw until the demonstrations do not contain the query example itself.
                while True:
                    demo_indices = np.random.choice(len(dataset[split]), curr_n_shots, replace=False)
                    if example_index not in demo_indices:
                        break
                word_pairs = dataset[split][demo_indices]
            word_pairs_test = dataset[split][example_index]

            prompt_kwargs = dict(query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos,
                                 shuffle_labels=shuffle_labels)
            if prefixes is not None and separators is not None:
                prompt_kwargs.update(prefixes=prefixes, separators=separators)
            prompt_data = word_pairs_to_prompt_data(word_pairs, **prompt_kwargs)

            target = prompt_data['query_target']['output']
            if generate_str:
                target = [target] if not isinstance(target, list) else target
            else:
                target = target[0] if isinstance(target, list) else target

            sentence = [create_prompt(prompt_data)]

            device = model.device
            inputs = tokenizer(sentence, return_tensors='pt').to(device)
            prompt_len = inputs.input_ids.shape[1]

            # Tokenize prompt + target and mask out every prompt position in the labels.
            target_completion = "".join(sentence + [target])
            nll_inputs = tokenizer(target_completion, return_tensors='pt').to(device)
            nll_targets = nll_inputs.input_ids.clone()
            target_len = nll_targets.shape[1] - prompt_len
            if target_len <= 0:
                # With target_len == 0 the slice below would mask nothing, and with a negative
                # value it would mask the wrong positions; both give a silently wrong loss.
                raise ValueError(
                    f'target "{target}" did not add any tokens to the prompt '
                    f'(target_len={target_len}); cannot build the loss mask'
                )
            nll_targets[:, :-target_len] = -100  # This is the accepted value to skip indices when computing loss in nn.CrossEntropyLoss

            if annotate:
                tmp_print = target_completion.replace("\n", "\\n")
                print(f'[{curr_n_shots}-shots] len={nll_inputs.input_ids.shape[1]}, sentence_with_target="{tmp_print}", only target="{target}"')

            output = model(**nll_inputs, labels=nll_targets, output_attentions=True, output_hidden_states=True, use_cache=True)

            # Compute gradients (the hooks record during backward) but do not apply an optimizer step.
            output.loss.backward()
            opt.zero_grad()

            collector[curr_n_shots] = {
                # Copied before the hooks are removed in the `finally` clause below.
                'grad': copy.deepcopy(grad_collector),
                'kv_cache': output.past_key_values,
                'attentions': output.attentions,
            }
        finally:
            llm_utils.remove_collector_hooks(grad_collector)
    return collector

In [29]:
# Nested containers: shot count -> layer -> head -> list with one attention map per example.
all_rev_attns = {
    shots: {layer: {head: [] for head in range(n_head)} for layer in range(n_layer)}
    for shots in n_shots_icl
}
all_forward_attns = {
    shots: {layer: {head: [] for head in range(n_head)} for layer in range(n_layer)}
    for shots in n_shots_icl
}

for example_index in range(n_samples):
    # Report progress (and print the prompts) for every fifth example.
    if not example_index % 5:
        print(f'example_index: {example_index}/{n_samples}')
    collector = get_experiment_collector(example_index=example_index, annotate=(example_index % 5 == 0))

    for n_shots in n_shots_icl:
        for layer_index in range(n_layer):
            for head_index in range(n_head):
                forward_attn_map, rev_attn_map = exp_ra_utils.get_forward_and_reversed_attn(
                    collector[n_shots], layer_index, head_index, config, head_size
                )
                all_rev_attns[n_shots][layer_index][head_index].append(rev_attn_map)
                all_forward_attns[n_shots][layer_index][head_index].append(forward_attn_map)

example_index: 0/5
[0-shots] len=8, sentence_with_target="<|endoftext|>Q: historic\nA: modern", only target=" modern"
[5-shots] len=53, sentence_with_target="<|endoftext|>Q: destructive\nA: constructive\n\nQ: elevation\nA: depression\n\nQ: bodily\nA: spiritual\n\nQ: disparate\nA: similar\n\nQ: figured\nA: plain\n\nQ: historic\nA: modern", only target=" modern"


# Alter the Model

In [30]:
# Average the collected maps over the examples, per shot count / layer / head.
# Assumes all prompts of one shot count have the same length and format, so that their
# maps have the same shape and can be averaged element-wise.
reversed_attention_intervention = {}
forward_attention_intervention = {}
for n_shots in n_shots_icl:
    reversed_attention_intervention[n_shots] = {}
    forward_attention_intervention[n_shots] = {}
    for layer_index in range(n_layer):
        reversed_attention_intervention[n_shots][layer_index] = {}
        forward_attention_intervention[n_shots][layer_index] = {}
        for head_index in range(n_head):
            rev_maps = all_rev_attns[n_shots][layer_index][head_index]
            mean_rev_attn = sum(rev_maps) / len(rev_maps)
            reversed_attention_intervention[n_shots][layer_index][head_index] = mean_rev_attn.detach().clone().to(device)

            forward_maps = all_forward_attns[n_shots][layer_index][head_index]
            mean_forward_attn = sum(forward_maps) / len(forward_maps)
            forward_attention_intervention[n_shots][layer_index][head_index] = mean_forward_attn.detach().clone().to(device)

print(f'Finish averaging the reverse attentions for {n_samples} samples')
# Layer 0 / head 0 exist for every model, unlike the previously hard-coded layer 2 / head 3.
print(f'The amount of reverse attentions for each head and layer is {len(all_rev_attns[n_shots_icl[0]][0][0])}')

# Map "sequence length of the averaged map" -> shot count; the patched forward pass uses it
# to pick the intervention that fits the current input. Valid for reversed and forward maps.
map_lens_to_n_shots = {}
for n_shots in n_shots_icl:
    curr_len = reversed_attention_intervention[n_shots][0][0].shape[0]
    if map_lens_to_n_shots.get(curr_len, n_shots) != n_shots:
        # Two shot counts with the same length would make the lookup ambiguous.
        raise ValueError(
            f'prompt length {curr_len} is shared by n_shots={map_lens_to_n_shots[curr_len]} '
            f'and n_shots={n_shots}; cannot map lengths to shot counts unambiguously'
        )
    map_lens_to_n_shots[curr_len] = n_shots

Finish averaging the reverse attentions for 5 samples
The amount of reverse attentions for each head and layer is 5


# How we edit the forward pass
We use a method known as function patching (sometimes referred to as "monkey patching") to replace the forward pass of the model with our own forward pass.

We copy from transformers (the library by HuggingFace) the forward pass of the model, modify it, and then we replace the forward pass call of the model with our own version.

In the copied forward pass we add our new code, which dynamically adds an attention map to the forward pass's attention maps.

Our code is marked with `####` (all the rest of the code is from transformers).

In [31]:
def create_aux_func_opt(self, layer_index, lr, intervention_with_reversed_attention=True):
    """Build a replacement ``forward`` for one attention module that adds an intervention map.

    The returned function is a copy of an attention forward pass (per the surrounding notes,
    taken from the transformers library) with one addition, delimited by the ``####`` lines:
    after the softmax and before dropout, ``lr * intervention_map`` is added in place to the
    attention weights of every head.

    Args:
        self: the attention module whose forward pass is being replaced; the returned
            closure reads its projections and settings (``q_proj``, ``k_proj``, ``v_proj``,
            ``out_proj``, ``num_heads``, ``head_dim``, ...).
        layer_index: layer index used to look up the intervention maps.
        lr: scalar factor applied to the intervention map before it is added.
        intervention_with_reversed_attention: if True use ``reversed_attention_intervention``,
            otherwise ``forward_attention_intervention`` (both are notebook globals).

    Returns:
        A ``forward`` function taking the same arguments as the original, minus ``self``.
    """
    def forward(
        # self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attention forward pass with the added intervention. Input shape: Batch x Time x Channel.

        Returns ``(attn_output, attn_weights_reshaped, past_key_value)``; the second item is
        None unless ``output_attentions`` is True.

        Raises:
            ValueError: on unexpected tensor sizes, or when no intervention is registered in
                ``map_lens_to_n_shots`` for the current key length.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Fold batch and heads into one leading dimension for the batched matmuls below.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            # Clamp from below to the dtype's minimum after adding the mask.
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        ################# our trick
        # At this point attn_weights are post-softmax probabilities of shape
        # (bsz * num_heads, tgt_len, src_len).
        if len(attn_weights.shape) != 3: # assuming [n_head, seq_len, seq_len]
            print(f'Warning: len(attn_weights.shape) != 3: attn_weights.shape: {attn_weights.shape}')
        
        # NOTE(review): the loop uses the notebook-global `n_head` (not self.num_heads) and
        # indexes the folded (bsz * num_heads) dimension directly, so only the first `n_head`
        # rows are modified — this looks like it assumes batch size 1; confirm before batching.
        for head_index in range(n_head):
            # forward_attn_map, rev_attn_map = get_forward_and_reversed_attn(collector, layer_index, head_index)
            # attn_weights[0, head_index] += lr * rev_attn_map
            # The intervention is selected by the key length (src_len) of the current input.
            if attn_weights.shape[2] not in map_lens_to_n_shots:
                print(f'Error: could not find intervention for layer_index={layer_index}, head_index={head_index}')
                print(f'attn_weights.shape: {attn_weights.shape}')
                raise ValueError(f'could not find intervention for layer_index={layer_index}, head_index={head_index}')
            curr_n_shots = map_lens_to_n_shots[attn_weights.shape[2]]
            if intervention_with_reversed_attention:
                intervention_map = reversed_attention_intervention[curr_n_shots][layer_index][head_index]
            else:
                intervention_map = forward_attention_intervention[curr_n_shots][layer_index][head_index]
            # In-place add: when output_attentions is True, attn_weights is a view of
            # attn_weights_reshaped, so the returned attentions include the intervention too.
            attn_weights[head_index] += lr * intervention_map
        #################

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    
        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
    
    return forward

In [32]:
def create_aux_func_gpt2(self, layer_index, lr, intervention_with_reversed_attention=True):
    """Build a replacement ``forward`` for one GPT-2 style attention module.

    The returned closure performs the QKV projection, head split, optional KV
    cache concatenation and output projection using the attributes of ``self``,
    but computes attention through a local ``_attn`` that adds
    ``lr * intervention_map`` to every head's attention map (after softmax,
    dropout and head masking) before multiplying by the values.

    The closure reads these notebook-level globals at call time:
    ``n_head``, ``map_lens_to_n_shots``, ``reversed_attention_intervention`` and
    ``forward_attention_intervention``.

    Args:
        self: The attention module being patched. It is passed explicitly and
            captured by the closure, so the returned function takes no ``self``.
        layer_index: Index used to select this layer's intervention maps.
        lr: Scalar multiplier applied to the intervention map before it is added.
        intervention_with_reversed_attention: If True, maps are taken from
            ``reversed_attention_intervention``; otherwise from
            ``forward_attention_intervention``.

    Returns:
        A ``forward(hidden_states, layer_past=None, attention_mask=None,
        head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None,
        use_cache=False, output_attentions=False)`` function returning
        ``(attn_output, present)`` plus ``attn_weights`` when
        ``output_attentions`` is truthy.

    Raises:
        ValueError: Raised by the returned function when the key length has no
            entry in ``map_lens_to_n_shots``, or when cross-attention is
            requested on a module without ``q_attn``.
    """
    if getattr(self, 'reorder_and_upcast_attn', False):
        # forward() below dispatches to self._upcast_and_reordered_attn in this case,
        # which never reaches the patched _attn, so say so once at patch time.
        print(f'Warning: layer_index={layer_index} has reorder_and_upcast_attn enabled; '
              f'the attention intervention is NOT applied on that code path')

    def _attn(query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        ################# our trick
        if attn_weights.shape[0] != 1:
            # Only batch element 0 receives the intervention below.
            print(f'Warning: attn_weights.shape[0] != 1: attn_weights.shape: {attn_weights.shape}')

        # The key length selects which n_shots setting the intervention maps belong to.
        # It is the same for every head, so resolve it once instead of once per head.
        key_len = attn_weights.shape[3]
        if key_len not in map_lens_to_n_shots:
            # This check used to live inside the head loop, where it could only
            # ever fire on the first head; the reported head index stays 0.
            print(f'Error: could not find intervention for layer_index={layer_index}, head_index={0}')
            print(f'attn_weights.shape: {attn_weights.shape}')
            raise ValueError(f'could not find intervention for layer_index={layer_index}, head_index={0}')
        curr_n_shots = map_lens_to_n_shots[key_len]

        if intervention_with_reversed_attention:
            intervention_table = reversed_attention_intervention
        else:
            intervention_table = forward_attention_intervention
        layer_maps = intervention_table[curr_n_shots][layer_index]

        for head_index in range(n_head):
            # Each map must broadcast against one head's (query_len, key_len) attention map.
            attn_weights[0, head_index] += lr * layer_maps[head_index]
        #################

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            # NOTE: this path uses the module's own implementation, so no intervention is applied.
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = _attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)

    return forward


In [33]:
# Multipliers ("lr") swept by the intervention experiment: a hand-picked set around zero,
# followed by progressively coarser negative ranges:
#   -2..-8 (step -2), -10..-35 (step -5), -40..-120 (step -20).
# NOTE(review): 0.25 sits between -0.1 and -0.5 in an otherwise descending list;
# it may have been meant as -0.25 -- confirm before changing, since it alters the sweep.
lr_to_examine = (
    [3, 2, 1, 0.5, 0, -0.1, 0.25, -0.5, -1]
    + np.arange(-2, -10, -2).tolist()
    + np.arange(-10, -40, -5).tolist()
    + np.arange(-40, -121, -20).tolist()
)
if DEBUG_FLAG:
    # Tiny sweep for quick local runs.
    lr_to_examine = [3, 0, -3, -15]

# Column layout of the per-direction result tables.
columns = ['n_shots', 'lr', 'top_1', 'top_2', 'top_3']

# Build the output paths with os.path.join (as done for the meta-data file) so an empty
# or slash-terminated out_folder does not yield a root-anchored or doubled-slash path.
df_rev_path_out = os.path.join(args.out_folder, f'{output_prefix}_intervention_via_reverse_attn.csv')
df_forward_path_out = os.path.join(args.out_folder, f'{output_prefix}_intervention_via_forward_attn.csv')

In [34]:
# Run description stored next to the result CSVs; its presence also serves as the
# "this configuration already ran" marker checked below.
meta_data = {
             'df_rev_path_out': df_rev_path_out,
             'df_forward_path_out': df_forward_path_out,
             'example_sentence': sentence_example,  # to verify we ran what we wanted
             'ds_train_len': len(dataset['train']),
             'ds_valid_len': len(dataset['valid']),
             'ds_test_len': len(dataset['test']),
             'range_of_percentages': lr_to_examine,  # key name kept as-is; holds the swept lr values
             'run_args': str(args),
             'debug_flag_on': DEBUG_FLAG}
meta_data_out = os.path.join(args.out_folder, f'{output_prefix}_exp_intervention_meta_data.json')
if args.check_if_result_already_exists and os.path.exists(meta_data_out):
    print(f'File already exists: {meta_data_out}. Terminating this run. If you wish to re-run, please remove the file or disable the flag --check_if_result_already_exists')
    print(f'Run args: {args}')
    print('Exit...')
    exit(0)

# Make sure the destination folder exists before writing; this is a no-op when it
# already does. An empty out_folder means "current directory", which needs no creation.
if args.out_folder:
    os.makedirs(args.out_folder, exist_ok=True)

with open(meta_data_out, 'w') as f:
    json.dump(meta_data, f)

In [None]:
# Resolve which forward-patching factory matches the model family once, up front.
# This avoids repeating the string tests for every layer of every lr, and makes an
# unsupported model fail before the (expensive) model load rather than after it.
model_name_lower = args.model_name.lower()
if 'gpt2' in model_name_lower:
    create_aux_func = create_aux_func_gpt2
elif 'opt' in model_name_lower:
    create_aux_func = create_aux_func_opt
else:
    print(f'Error: could not find model type for {args.model_name}')
    raise ValueError(f'could not find model type for {args.model_name}')

# Best-effort release of a previously loaded model. `model` may not exist yet, and
# moving it may fail; neither should stop the run. Catch Exception rather than using
# a bare except so KeyboardInterrupt / SystemExit still propagate.
try:
    model = model.to('cpu')  # not really needed, but just to be sure
    del model
except Exception:
    pass

# re-initialize the model just to make sure we used one that was not fine-tuned or altered in any way
model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_extra_args).eval().requires_grad_(False).to(device)

for df_path, rev_or_forward in [(df_rev_path_out, True), (df_forward_path_out, False)]:
    direction_name = "reverse" if rev_or_forward else "forward"
    df = pd.DataFrame(data=[], columns=columns)

    print(f'Start intervention via {direction_name} attention...')
    # change all attn layers forward pass function to take the intervention on the attention maps
    # notice we do not need to reinitialize the model, as we overwrite the forward function at each iteration of the lr loop (no leakings between iterations)
    for lr_index, lr in enumerate(lr_to_examine):
        # Every 5th lr value gets annotated evaluation output and an intermediate CSV save.
        is_checkpoint_step = lr_index % 5 == 0

        print(f'\n******** lr: {lr} ********')
        print('Connecting to layer', end='')
        for layer_index in range(n_layer):
            print(f' {layer_index}', end='')
            attn_layer = llm_utils.rgetattr(model, config.layer_attn_format.format(layer_index))
            # Assigning to the instance attribute replaces the patch from the previous lr.
            attn_layer.forward = create_aux_func(attn_layer, layer_index=layer_index,
                                                 lr=lr, intervention_with_reversed_attention=rev_or_forward)
        print(' --> all layer are ready!')

        # first, evaluate the model on the ICL task
        print("ICL Results:")
        curr_res_icl = exp_ra_utils.aux_eval_model(model, tokenizer=tokenizer,
                                                   model_name=args.model_name,
                                                   dataset=dataset,
                                                   metric_to_eval=metric_to_use,
                                                   n_shots_list=n_shots_icl,
                                                   prefixes=prefixes, separators=separators,
                                                   compute_ppl=True,
                                                   annotate=is_checkpoint_step)

        # One result row per n_shots setting: element [1] of the first three 'clean_topk' entries.
        for n_shots, res in curr_res_icl.items():
            top_k = res['clean_topk']
            df.loc[len(df)] = [n_shots, lr, top_k[0][1], top_k[1][1], top_k[2][1]]

        if is_checkpoint_step:
            print('Backup saving...')
            df.to_csv(df_path, index=False)

    df.to_csv(df_path, index=False)
    print(f'Finish intervention via {direction_name} attention. Info saved at {df_path}\n')

In [None]:
# Final cell: prints a completion marker once every cell above has executed.
print('done')