In [1]:
import transformers
import os, re, json
import torch, numpy
from collections import defaultdict
from tqdm import tqdm
import jsonlines
from dsets import CounterFactDataset
from util.globals import DATA_DIR



In [2]:
!huggingface-cli login --token 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

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/rseetharaman_umass_edu/.cache/huggingface/token
Login successful


In [3]:
from transformers import AutoTokenizer

mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
llama_tokenizer = AutoTokenizer.from_pretrained("/work/pi_dhruveshpate_umass_edu/rseetharaman_umass_edu/llama-2-7b-hf")

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [4]:
counterfacts = CounterFactDataset(data_dir=DATA_DIR)  # Dataset of known facts

Loaded dataset with 21919 elements


In [6]:
def is_present_in_vocab(vocab, space_token, attr):
    for v,idx in vocab.items():

        # b'\\u0120' - phi-2's and falcon's space token
        # b'\\u2581' - mistral's and gemma's space token

        # Condition to check if the attribute is an atomic token, first unicode character denotes beginning of word.
        if v[0].encode("unicode_escape")== space_token and v[1:]==attr:
            return (True, idx)
    return (False, -1)

In [7]:
def get_entries_with_atomic_tokens(tokenizer, space_token, counterfacts):
    data = []
    for i,cf in enumerate(tqdm(counterfacts)):
        is_present, idx = is_present_in_vocab(tokenizer.vocab, space_token, cf['requested_rewrite']['target_new']['str'])
        if is_present:
            cf.update({"token_id": idx})
            data.append(cf)
    return data

In [8]:
def save_records(records, model_name):
    with jsonlines.open(f"data_counterfactual_{model_name}.jsonl", "w") as writer:
        for d in records:
            writer.write(d)

In [9]:
s="\u2581Rome"
byte=b'\\u2581'
s[0].encode("unicode_escape") == byte

True

In [10]:
mistral_entries = get_entries_with_atomic_tokens(mistral_tokenizer, b'\\u2581', counterfacts)

  0%|          | 0/21919 [00:00<?, ?it/s]

100%|██████████| 21919/21919 [05:42<00:00, 64.02it/s]


In [11]:
mistral_counterfacts = []

In [55]:
def prepare_counterfacts(entries):
    counterfactuals = []
    for entry in entries:
        m = entry['requested_rewrite']
        subject = m['subject']
        attribute = m['target_new']['str']
        full_answer = m['prompt'].format(subject)+ ' {}.'.format(attribute)
        prompt = m['prompt'].format(subject)+' '
        token_id = entry['token_id']
        counterfactuals.append({
            "subject": attribute,
            "attribute": attribute,
            "prompt": prompt,
            "token_id": token_id,
            "full_answer": full_answer
        })
    return counterfactuals

In [56]:
mistral_counterfacts = prepare_counterfacts(mistral_entries)

In [57]:
import random
for i, m in enumerate(mistral_counterfacts):
    
    all_prompts = random.sample(mistral_counterfacts[:i]+mistral_counterfacts[i+1:], 2)
    all_prompts = [a['full_answer'] for a in all_prompts]
    all_prompts.append(m['full_answer'])
    random.shuffle(all_prompts)
    
    mistral_counterfacts[i]['original_prompt'] = mistral_counterfacts[i]['prompt']
    context = '\n'.join(all_prompts)

    mistral_counterfacts[i]['prompt'] = f"""Context information is given below.
Answer solely based on the context.
Context: 
{context}
Question: {m['original_prompt']}"""

In [58]:
print(mistral_counterfacts[190]['prompt'])

Context information is given below.
Answer solely based on the context.
Context: 
Palladam is located in the country of Iran.
The Waking Eyes, that was started in Seattle.
Jari Kurri is a professional soccer.
Question: Jari Kurri is a professional 


In [59]:
len(mistral_counterfacts)

14679

In [60]:
save_records(mistral_counterfacts, 'mistral')

In [61]:
llama_entries = get_entries_with_atomic_tokens(llama_tokenizer, b'\\u2581', counterfacts)

  0%|          | 0/21919 [00:00<?, ?it/s]

100%|██████████| 21919/21919 [05:52<00:00, 62.13it/s]


In [64]:
len(llama_entries)

13862

In [65]:
llama_counterfacts = prepare_counterfacts(llama_entries)

In [66]:
import random
for i, m in enumerate(llama_counterfacts):
    
    all_prompts = random.sample(llama_counterfacts[:i]+llama_counterfacts[i+1:], 2)
    all_prompts = [a['full_answer'] for a in all_prompts]
    all_prompts.append(m['full_answer'])
    random.shuffle(all_prompts)
    
    llama_counterfacts[i]['original_prompt'] = llama_counterfacts[i]['prompt']
    context = '\n'.join(all_prompts)

    llama_counterfacts[i]['prompt'] = f"""Context information is given below.
Answer solely based on the context.
Context: 
{context}
Question: {m['original_prompt']}"""

In [67]:
llama_counterfacts[189]

{'subject': 'mayor',
 'attribute': 'mayor',
 'prompt': 'Context information is given below.\nAnswer solely based on the context.\nContext: \nBrugmann Mountains is located in Europe.\nJozef Tomko, who has the position of mayor.\nBright Promise is to debut on BBC.\nQuestion: Jozef Tomko, who has the position of ',
 'token_id': 9105,
 'full_answer': 'Jozef Tomko, who has the position of mayor.',
 'original_prompt': 'Jozef Tomko, who has the position of '}

In [68]:
save_records(llama_counterfacts, 'llama')

In [69]:
len(llama_tokenizer.vocab)

32000

In [70]:
len(mistral_tokenizer.vocab)

32000