In [12]:
import os
import hydra
import json
import random
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from transformers import GPT2Tokenizer, BertTokenizer, AutoTokenizer, BloomTokenizerFast, GPTNeoXTokenizerFast, LlamaTokenizer
from intervention_models.intervention_model import load_model
from tqdm import tqdm
import pickle
import sys
import yaml
from utils.number_utils import convert_to_words
import copy
import numpy as np

class DotDict(dict):
    """Dictionary whose entries can also be read, written and deleted as attributes."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; a missing key
        # yields None (dict.get semantics) instead of raising.
        return self.get(name)

    def __setattr__(self, name, value):
        # Attribute assignment stores the value as a regular dict entry.
        self[name] = value

    def __delattr__(self, name):
        # Removing an absent key raises KeyError, just like `del d[name]`.
        del self[name]
# Read the run configuration; DotDict exposes the YAML keys as attributes.
yaml_file_path = "./conf/config.yaml"
with open(yaml_file_path, "r") as f:
    args = DotDict(yaml.safe_load(f))

n_shots = str(args.n_shots)
max_n = str(args.max_n)
representation = str(args.representation)

# Path of the pickled interventions: <data_dir>/<model>/intervention_..._<flags>.pkl
file_name = args.data_dir
file_name += '/' + str(args.model)
file_name += '/intervention_' + n_shots + '_shots_' + 'max_' + max_n + '_' + representation
file_name += '_further_templates' if args.extended_templates else ''
file_name += '_mpt2' if args.mpt_data_version_2 else ''
file_name += '.pkl'

# file_name already begins with args.data_dir. Joining it onto data_dir a second
# time only worked for an absolute data_dir (os.path.join drops the earlier
# component); with a relative data_dir the directory prefix was duplicated.
path_to_data = file_name

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point data_dir at trusted locations.
with open(path_to_data, 'rb') as f:
    intervention_list = pickle.load(f)
print("Loaded data from", path_to_data)
if args.debug_run:
    # Debug runs keep only the first two interventions.
    intervention_list = intervention_list[:2]

Loaded data from /shared-network/shared/2024_ml_master/data/mosaicml/mpt-7b/intervention_1_shots_max_20_words_further_templates_mpt2.pkl


In [7]:
# Build the model object from the loaded configuration.
model = load_model(args)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

hf_device_map: {'transformer.wte': 0, 'transformer.emb_drop': 0, 'transformer.blocks.0': 0, 'transformer.blocks.1': 0, 'transformer.blocks.2': 0, 'transformer.blocks.3': 1, 'transformer.blocks.4': 1, 'transformer.blocks.5': 1, 'transformer.blocks.6': 1, 'transformer.blocks.7': 1, 'transformer.blocks.8': 2, 'transformer.blocks.9': 2, 'transformer.blocks.10': 2, 'transformer.blocks.11': 2, 'transformer.blocks.12': 2, 'transformer.blocks.13': 3, 'transformer.blocks.14': 3, 'transformer.blocks.15': 3, 'transformer.blocks.16': 3, 'transformer.blocks.17': 3, 'transformer.blocks.18': 4, 'transformer.blocks.19': 4, 'transformer.blocks.20': 4, 'transformer.blocks.21': 4, 'transformer.blocks.22': 4, 'transformer.blocks.23': 5, 'transformer.blocks.24': 5, 'transformer.blocks.25': 5, 'transformer.blocks.26': 5, 'transformer.blocks.27': 5, 'transformer.blocks.28': 6, 'transformer.blocks.29': 6, 'transformer.blocks.30': 6, 'transformer.blocks.31': 6, 'transformer.norm_f': 6}
MPTConfig {
  "_name_or_

In [8]:
# Load the tokenizer published under "mosaicml/mpt-7b" and print a few probes:
# its repr, the pieces of a short prompt, the text for one id, and the pad id.
tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b")
print(tokenizer)

probe_pieces = tokenizer.tokenize("a 14", add_special_tokens=False)
print(probe_pieces)

probe_text = tokenizer.decode([204592])
print(probe_text)

pad_id = tokenizer.pad_token_id
print(pad_id)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


GPTNeoXTokenizerFast(name_or_path='mosaicml/mpt-7b', vocab_size=50254, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<|padding|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	50254: AddedToken("                        ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50255: AddedToken("                       ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50256: AddedToken("                      ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50257: AddedToken("                     ", rstrip=False, lstrip=False, single

In [13]:
import pandas as pd

# One row per intervention, one column per instance attribute.
intervention_data = [item.__dict__ for item in intervention_list]
df = pd.DataFrame(intervention_data)


In [14]:
# Display the column labels of the interventions DataFrame.
df.columns

Index(['op3_pos', 'operator_word', 'operands_alt', 'operands_base',
       'operator_pos', 'op2_pos', 'op1_pos', 'res_alt_tok', 'res_base_tok',
       'res_string', 'res_base_string', 'res_alt_string', 'device',
       'multitoken', 'is_llama', 'is_opt', 'is_bloom', 'is_mistral',
       'is_persimmon', 'representation', 'extended_templates', 'template_id',
       'n_vars', 'base_string', 'alt_string', 'few_shots', 'few_shots_t2',
       'equation', 'enc', 'len_few_shots', 'len_few_shots_t2',
       'base_string_tok_list', 'alt_string_tok_list', 'base_string_tok',
       'alt_string_tok', 'pred_alt_string', 'pred_res_alt_tok'],
      dtype='object')

In [15]:
# Count how often each distinct value of the 'equation' column occurs.
equation_counts = df['equation'].value_counts()
print(equation_counts)


equation
({x} * {y} * {z})    25
({x}*{y} * {z})      24
({x}*{y}*{z})        21
({x}+{y}+{z})        11
({x}+{y} + {z})      11
(({x}-{y})*{z})       5
({x} -{y}-{z})        3
Name: count, dtype: int64


In [None]:
# Gather every intervention's predicted string with its first character
# dropped (presumably a leading space — confirm).
raw_predictions = (item.pred_alt_string for item in intervention_list)
pred_alt_string = [text[1:] for text in raw_predictions]
print(pred_alt_string)

['thirteen', 'seven', '', 'ten', 'eight', 'ten', 'fifteen', 'ten', 'fourteen', 'thirteen', 'nine', 'twelve', '', 'fourteen', '', '13', '13', 'seven', 'fifteen', 'nine', '', 'six', '', '', 'six', 'eight', '', '13', 'ten', '14', '13', '14', '12', 'eleven', 'six', '14', 'five', 'six', 'thirteen', 'seven', 'fifteen', 'thirteen', 'five', '3', '3', '6', '4', '6', '6', '6', '6', '6', '8', '6', '6', '6', '8', '3', '4', '8', '6', '3', '8', '6', '6', '6', '4', '6', '12', '12', '8', '12', '6', '6', '6', '9', '8', '8', '8', '6', '8', '6', '8', '8', '6', '6', '8', '9', 'four', 'two', 'two', 'two', 'two', 'four', 'two', 'six', 'two', 'six', 'four', 'six', 'two', 'twenty', 'twenty', '-', '(', 'twenty', '-', 'thirty', 'thirty', 'twenty', '18']


In [None]:
# Gather the res_alt_string attribute of every intervention, in list order.
res_alt_string = [item.res_alt_string for item in intervention_list]
print(res_alt_string)

['nineteen', 'eleven', 'eighteen', 'fourteen', 'twelve', 'fourteen', 'nineteen', 'fourteen', 'sixteen', 'nineteen', 'eleven', 'eighteen', 'fifteen', 'sixteen', 'eighteen', 'fifteen', 'fifteen', 'nine', 'nineteen', 'eleven', 'eighteen', 'ten', 'seventeen', 'sixteen', 'nine', 'ten', 'eighteen', 'seventeen', 'twelve', 'sixteen', 'fifteen', 'sixteen', 'fourteen', 'thirteen', 'eight', 'eighteen', 'one', 'four', 'three', 'eight', 'five', 'eleven', 'three', 'twelve', 'twelve', 'eighteen', 'twelve', 'twelve', 'eighteen', 'eighteen', 'twelve', 'eighteen', 'sixteen', 'eighteen', 'twelve', 'eighteen', 'sixteen', 'twelve', 'sixteen', 'sixteen', 'eighteen', 'twelve', 'sixteen', 'twelve', 'twelve', 'eighteen', 'sixteen', 'twelve', 'eighteen', 'eighteen', 'eight', 'eighteen', 'twelve', 'twelve', 'twelve', 'eighteen', 'sixteen', 'sixteen', 'eight', 'twelve', 'eight', 'twelve', 'sixteen', 'eight', 'twelve', 'twelve', 'eight', 'eighteen', 'twelve', 'twelve', 'eight', 'eight', 'twelve', 'twelve', 'eight'

In [None]:
# Gather the res_base_string attribute of every intervention, in list order.
res_base_string = [item.res_base_string for item in intervention_list]
print(res_base_string)

['nineteen', 'eleven', 'eighteen', 'fourteen', 'twelve', 'fourteen', 'nineteen', 'fourteen', 'sixteen', 'nineteen', 'eleven', 'eighteen', 'fifteen', 'sixteen', 'eighteen', 'fifteen', 'fifteen', 'nine', 'nineteen', 'eleven', 'eighteen', 'ten', 'seventeen', 'sixteen', 'nine', 'ten', 'eighteen', 'seventeen', 'twelve', 'sixteen', 'fifteen', 'sixteen', 'fourteen', 'thirteen', 'eight', 'eighteen', 'one', 'four', 'three', 'eight', 'five', 'eleven', 'three', 'twelve', 'twelve', 'eighteen', 'twelve', 'twelve', 'eighteen', 'eighteen', 'twelve', 'eighteen', 'sixteen', 'eighteen', 'twelve', 'eighteen', 'sixteen', 'twelve', 'sixteen', 'sixteen', 'eighteen', 'twelve', 'sixteen', 'twelve', 'twelve', 'eighteen', 'sixteen', 'twelve', 'eighteen', 'eighteen', 'eight', 'eighteen', 'twelve', 'twelve', 'twelve', 'eighteen', 'sixteen', 'sixteen', 'eight', 'twelve', 'eight', 'twelve', 'sixteen', 'eight', 'twelve', 'twelve', 'eight', 'eighteen', 'twelve', 'twelve', 'eight', 'eight', 'twelve', 'twelve', 'eight'

In [None]:
# Pairwise exact string match between predictions and base results; the mean
# of the boolean list is the match rate.
# NOTE(review): predictions come from pred_alt_string but are compared with
# res_base_string rather than res_alt_string — confirm this is intended.
equal = [predicted == expected for predicted, expected in zip(pred_alt_string, res_base_string)]
match_rate = np.mean(equal)
print(match_rate)
print(equal)

0.0
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]


In [None]:
# NOTE(review): the recorded run of this cell ended in
# "IndexError: list index out of range", and the following cell repeats the
# same load_model(args) call — candidate for removal.
model = load_model(args)

You are using a model of type gpt_neox to instantiate a model of type gpt_neo. This is not supported for all configurations of models and can yield errors.


IndexError: list index out of range

In [None]:
model = load_model(args)

# Pick the tokenizer class from the model-family flags; branches are checked
# in this order and the first match wins.
if model.is_gpt2 or model.is_gptneo or model.is_opt:
    tokenizer_class = GPT2Tokenizer
elif model.is_bert:
    tokenizer_class = BertTokenizer
elif model.is_gptj or model.is_flan or model.is_pythia:
    tokenizer_class = AutoTokenizer
elif model.is_bloom:
    tokenizer_class = BloomTokenizerFast
elif model.is_neox:
    tokenizer_class = GPTNeoXTokenizerFast
elif model.is_llama:
    tokenizer_class = LlamaTokenizer
else:
    tokenizer_class = None

if not tokenizer_class:
    raise Exception(f'Tokenizer for model {args.model} not found')

# Model names containing 'goat' load their tokenizer from a fixed LLaMA id;
# every other model uses its own name.
tokenizer_id = 'decapoda-research/llama-7b-hf' if 'goat' in args.model else args.model
tokenizer = tokenizer_class.from_pretrained(tokenizer_id, cache_dir=args.transformers_cache_dir)
model.create_vocab_subset(tokenizer, args)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

hf_device_map: {'gpt_neox.embed_in': 0, 'gpt_neox.layers.0': 0, 'gpt_neox.layers.1': 0, 'gpt_neox.layers.2': 0, 'gpt_neox.layers.3': 0, 'gpt_neox.layers.4': 1, 'gpt_neox.layers.5': 1, 'gpt_neox.layers.6': 1, 'gpt_neox.layers.7': 1, 'gpt_neox.layers.8': 1, 'gpt_neox.layers.9': 2, 'gpt_neox.layers.10': 2, 'gpt_neox.layers.11': 2, 'gpt_neox.layers.12': 2, 'gpt_neox.layers.13': 2, 'gpt_neox.layers.14': 3, 'gpt_neox.layers.15': 3, 'gpt_neox.layers.16': 3, 'gpt_neox.layers.17': 3, 'gpt_neox.layers.18': 3, 'gpt_neox.layers.19': 4, 'gpt_neox.layers.20': 4, 'gpt_neox.layers.21': 4, 'gpt_neox.layers.22': 4, 'gpt_neox.layers.23': 4, 'gpt_neox.layers.24': 5, 'gpt_neox.layers.25': 5, 'gpt_neox.layers.26': 5, 'gpt_neox.layers.27': 5, 'gpt_neox.layers.28': 5, 'gpt_neox.layers.29': 6, 'gpt_neox.layers.30': 6, 'gpt_neox.layers.31': 6, 'gpt_neox.layers.32': 6, 'gpt_neox.layers.33': 6, 'gpt_neox.layers.34': 7, 'gpt_neox.layers.35': 7, 'gpt_neox.final_layer_norm': 7, 'embed_out': 7}


In [None]:
# Take the `.attention` attribute of whatever get_hidden_states(1) returns and display it.
# NOTE(review): the variable is named layer_0 while the index passed is 1 —
# confirm which layer is intended.
layer_0 = model.get_hidden_states(1).attention
layer_0

GPTNeoXAttention(
  (rotary_emb): RotaryEmbedding()
  (query_key_value): Linear(in_features=5120, out_features=15360, bias=True)
  (dense): Linear(in_features=5120, out_features=5120, bias=True)
)

In [None]:
# Select one intervention and print its few-shot fields, its equation, and the
# first tokenised base prompt both as ids and as decoded text.
intervention = intervention_list[100]
for shown in (intervention, intervention.len_few_shots, intervention.few_shots, intervention.equation):
    print(shown)

base_tokens = intervention.base_string_tok[0]
print(base_tokens)
print(tokenizer.decode(base_tokens))

<interventions.Intervention object at 0x7faba81201f0>
9
four * two * two = sixteen. 
({x}*{y}*{z})
tensor([12496,   475,   767,   475,   767,   426, 25279,    15,   767,   475,
          767,   475,  1264,   426])
four * two * two = sixteen. two * two * three =


In [None]:

# Token positions that get overwritten in the tokenised prompt, keyed by the
# intervention's equation template.
equation_position_operands = {
    "({x}+{y}+{z})": [0, 2, 4],
    "({x}+{y} + {z})": [3, 5, 7],
    "({x} -{y}-{z})": [3, 5, 7],
    "({x}*{y} * {z})": [3, 5, 7],
    "({x} * {y} * {z})": [3, 5, 7],
    "({x}*{y}*{z})": [0, 2, 4],
    "(({x}-{y})*{z})": [4, 6, 9],
}

# Candidate replacement strings for the operand tokens.
symbols = ["alpha", "beta", "gamma", "delta", "epsilon", "zeta",
           "eta", "theta", "iota", "number", "result",
           "x", "y", "z", "a", "b", "c"]

# Fixed seed so that re-running the cell picks the same replacement symbols.
SYMBOL_SEED = 0


def encode_single_token(tokenizer, text):
    """Return the token id of `text`, which must encode to exactly one token.

    Special tokens are disabled so index 0 can never be a BOS/CLS id that some
    tokenizers prepend, and a multi-token encoding is rejected instead of being
    silently truncated to its first piece.

    Raises:
        ValueError: if `text` does not encode to exactly one token.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) != 1:
        raise ValueError(f'{text!r} encodes to {len(token_ids)} tokens, expected exactly one')
    return token_ids[0]


words_to_n = {convert_to_words(str(i)): i for i in range(args.max_n + 1)}

# Work on a copy so the original intervention stays untouched.
new_intervention = copy.deepcopy(intervention)

# Position of the few-shot result token inside the tokenised base prompt.
result_pos = new_intervention.len_few_shots - 3

# strip() removes the leading space without eating a real character when the
# decoded token has none.
few_shot_result = tokenizer.decode(new_intervention.base_string_tok[0][result_pos]).strip()
# Doubles as a sanity check: a KeyError here means the token at result_pos is
# not a number word in the configured range.
few_shot_result_int = words_to_n[few_shot_result]

# Overwrite the few-shot result with a fixed replacement value.
new_result = 10
new_result_string = ' ' + str(new_result)
print(new_result_string)
new_result_enc = encode_single_token(tokenizer, new_result_string)
print(new_result_enc)
new_intervention.base_string_tok[0][result_pos] = new_result_enc
new_few_shot_string = tokenizer.decode(
    new_intervention.base_string_tok[0][:new_intervention.len_few_shots - 1]
)
print(new_few_shot_string)

position_operands = equation_position_operands[intervention.equation]
print(position_operands)

# Only symbols that map to a single token can replace a single operand token.
single_token_symbols = [
    symbol for symbol in symbols
    if len(tokenizer.encode(' ' + symbol, add_special_tokens=False)) == 1
]
if not single_token_symbols:
    raise ValueError('None of the candidate symbols encodes to a single token')

rng = random.Random(SYMBOL_SEED)
for pos in position_operands:
    new_operand_str = ' ' + rng.choice(single_token_symbols)
    new_operand_enc = encode_single_token(tokenizer, new_operand_str)
    new_intervention.base_string_tok[0][pos] = new_operand_enc
print(tokenizer.decode(new_intervention.base_string_tok[0]))

 10
884
four * two * two = 10.
[0, 2, 4]
 z * b * number = 10. two * two * three =


: 