In [1]:
from rosemary import jpt_parse_args, jpt_setup; jpt_setup()

import platform
import sys
sys.path.append('/dccstor/mit_fm/wpq/github/mitibm2023/external/open-instruct/'
                if platform.uname().processor == 'x86_64' 
                else '/gpfs/u/scratch/PTFM/PTFMqngp/github/mitibm2023/external/open-instruct/')

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

  warn(f'Install `torch` for functionalities dependent on torch')


In [2]:
import argparse
import os
import json
import random
import evaluate
import numpy as np
from eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model
from eval.tydiqa.run_eval import encoding_templates_with_context, encoding_templates_without_context
from transformers import GPT2LMHeadModel

[2023-08-10 10:43:02,290] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default="data/xorqa/")
parser.add_argument("--max_num_examples_per_lang", type=int, default=None, help="maximum number of examples per language to evaluate.")
parser.add_argument("--n_shot", type=int, default=1, help="number of examples to use for few-shot evaluation.")
parser.add_argument("--no_context", action="store_true", help="If given, we're evaluating a model without the gold context passage.")
parser.add_argument("--max_context_length", type=int, default=512, help="maximum number of tokens in the context passage.")
parser.add_argument("--save_dir", type=str, default="results/tydiqa/")
parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.")
parser.add_argument("--openai_engine", type=str, default=None, help="if specified, we will use the OpenAI API to generate the predictions.")
parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.")
parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
parser.add_argument("--use_chat_format", action="store_true", help="If given, the prompt will be encoded as a chat format with the roles in prompt.")

model_name_or_path = '../results/baselines/gpt2-medium'
model_name_or_path = '../results/baselines/huggyllama/llama-7b/'
model_name_or_path = "../results/baselines/EleutherAI/pythia-1.4b"

# python -m eval.tydiqa.run_eval \
#             --data_dir data/eval/tydiqa \
#             --n_shot 1 \
#             --max_num_examples_per_lang 100 \
#             --max_context_length 512 \
#             --model_name_or_path "{model_name_or_path}" \
#             --save_dir "{save_dir}" \
#             --eval_batch_size {batch_size} \
#             {'--no_context' if no_context else ''}
#             {'--use_chat_format' if use_chat_format else ''}
#             {'--load_in_8bit' if load_in_8bit else ''}
        
#     --max_num_examples_per_lang 100 \
cmd = f"""
    --data_dir ../data/eval/tydiqa/ \
    --n_shot 1 \
    --max_context_length 400 \
    --model_name_or_path {model_name_or_path} \
    --save_dir {model_name_or_path}/eval/tydiqa/ \
    --eval_batch_size 10 \
    --use_chat_format
"""

args = jpt_parse_args(parser, cmd)

# model_name_or_path and openai_engine cannot be both None or both not None.
assert (args.model_name_or_path is None) != (args.openai_engine is None), "Either model_name_or_path or openai_engine should be specified."

args

Namespace(data_dir='../data/eval/tydiqa/', max_num_examples_per_lang=None, n_shot=1, no_context=False, max_context_length=400, save_dir='../results/baselines/EleutherAI/pythia-1.4b/eval/tydiqa/', model_name_or_path='../results/baselines/EleutherAI/pythia-1.4b', tokenizer_name_or_path=None, openai_engine=None, eval_batch_size=10, load_in_8bit=False, gptq=False, use_chat_format=True)

In [4]:

random.seed(42)

print("Loading data...")


test_data = []
with open(os.path.join(args.data_dir, "tydiqa-goldp-v1.1-dev.json")) as fin:
    dev_data = json.load(fin)
    for article in dev_data["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                example = {
                    "id": qa["id"], 
                    "lang": qa["id"].split("-")[0],
                    "context": paragraph["context"],
                    "question": qa["question"],
                    "answers": qa["answers"]
                }
                test_data.append(example)
data_languages = set([example["lang"] for example in test_data])
if args.max_num_examples_per_lang:
    sampled_examples = []
    for lang in data_languages:
        examples_for_lang = [example for example in test_data if example["lang"] == lang]
        if len(examples_for_lang) > args.max_num_examples_per_lang:
            examples_for_lang = random.sample(examples_for_lang, args.max_num_examples_per_lang)
        sampled_examples += examples_for_lang
    test_data = sampled_examples
    
print(f"Loaded {len(test_data)} examples from {len(data_languages)} languages: {data_languages}")

    
from collections import Counter
print(Counter([x['lang'] for x in test_data]))

test_data[0]

Loading data...
Loaded 5077 examples from 9 languages: {'telugu', 'arabic', 'english', 'finnish', 'swahili', 'korean', 'bengali', 'indonesian', 'russian'}
Counter({'arabic': 921, 'russian': 812, 'finnish': 782, 'telugu': 669, 'indonesian': 565, 'swahili': 499, 'english': 440, 'korean': 276, 'bengali': 113})


{'id': 'arabic-2387335860751143628-1',
 'lang': 'arabic',
 'context': 'أقيمت البطولة 21 مرة، شارك في النهائيات 78 دولة، وعدد الفرق التي فازت بالبطولة حتى الآن 8 فرق، ويعد المنتخب البرازيلي الأكثر تتويجاً بالكأس حيث فاز بها 5 مرات أعوام: 1958، 1962، 1970، 1994 و2002. يليه المنتخب الإيطالي الذي أحرزها 4 مرات في أعوام: 1934، 1938، 1982 و2006، بالمشاركة مع المنتخب الألماني الذي حققها 4 مرات أيضاً أعوام: 1954، 1974 و1990 و2014، ثم الأوروغواي والأرجنتين وفرنسا برصيد بطولتين. بينما أحرزت منتخبات إنجلترا وإسبانيا البطولة مرة واحدة.',
 'question': 'كم عدد مرات فوز الأوروغواي ببطولة كاس العالم لكرو القدم؟',
 'answers': [{'text': 'بطولتين', 'answer_start': 394}]}

In [5]:

if args.n_shot > 0:
    train_data_for_langs = {lang: [] for lang in data_languages}
    with open(os.path.join(args.data_dir, "tydiqa-goldp-v1.1-train.json")) as fin:
        train_data = json.load(fin)
        for article in train_data["data"]:
            for paragraph in article["paragraphs"]:
                for qa in paragraph["qas"]:
                    lang = qa["id"].split("-")[0]
                    if lang in data_languages:
                        example = {
                            "id": qa["id"],
                            "lang": lang,
                            "context": paragraph["context"],
                            "question": qa["question"],
                            "answers": qa["answers"]
                        }
                        train_data_for_langs[lang].append(example)
        for lang in data_languages:
            # sample n_shot examples from each language
            train_data_for_langs[lang] = random.sample(train_data_for_langs[lang], args.n_shot)
    # assert that we have exactly n_shot examples for each language
    assert all([len(train_data_for_langs[lang]) == args.n_shot for lang in data_languages])


# assert we have templates for all data languages
assert all([lang in encoding_templates_with_context.keys() for lang in data_languages])
train_data_for_langs['telugu']

[{'id': 'telugu--4358566378423124470-15',
  'lang': 'telugu',
  'context': '2001 వ.సంవత్సరం జనాభా లెక్కల ప్రకారం గ్రామ జనాభా 3288.[1] ఇందులో పురుషుల సంఖ్య 1675, మహిళల సంఖ్య 1613, గ్రామంలో నివాసగృహాలు 802 ఉన్నాయి. చింతపర్రు పశ్చిమ గోదావరి జిల్లా, పాలకొల్లు మండలంలోని గ్రామం. ఇది మండల కేంద్రమైన పాలకొల్లు నుండి 4 కి. మీ. దూరంలో ఉంది. 2011 భారత జనగణన గణాంకాల ప్రకారం ఈ గ్రామం 977 ఇళ్లతో, 3440 జనాభాతో 238 హెక్టార్లలో విస్తరించి ఉంది. గ్రామంలో మగవారి సంఖ్య 1775, ఆడవారి సంఖ్య 1665. షెడ్యూల్డ్ కులాల సంఖ్య 899 కాగా షెడ్యూల్డ్ తెగల సంఖ్య 44. గ్రామం యొక్క జనగణన లొకేషన్ కోడ్ 588776[2].పిన్ కోడ్: 534250.',
  'question': 'చింతపర్రు గ్రామ వైశాల్యం ఎంత?',
  'answers': [{'answer_start': 322, 'text': '238 హెక్టార్ల'}]}]

In [14]:

if args.model_name_or_path:
    print("Loading model and tokenizer...")
    model, tokenizer = load_hf_lm_and_tokenizer(
        model_name_or_path=args.model_name_or_path, 
        tokenizer_name_or_path=args.tokenizer_name_or_path, 
        load_in_8bit=args.load_in_8bit, 
        gptq_model=args.gptq,
        use_fast_tokenizer=True,
    )
else:
    import tiktoken
    tokenizer = tiktoken.get_encoding("cl100k_base")

Using pad_token, but it is not set yet.


Loading model and tokenizer...


In [134]:

# reduce context length to max_context_length
if args.max_context_length:
    for example in test_data:
        tokenized_context = tokenizer.encode(example["context"])
        if len(tokenized_context) > args.max_context_length:
            example["context"] = tokenizer.decode(tokenized_context[:args.max_context_length])
    if args.n_shot > 0:
        for lang in data_languages:
            for example in train_data_for_langs[lang]:
                tokenized_context = tokenizer.encode(example["context"])
                if len(tokenized_context) > args.max_context_length:
                    example["context"] = tokenizer.decode(tokenized_context[:args.max_context_length])

In [135]:

if not os.path.exists(args.save_dir):
    os.makedirs(args.save_dir, exist_ok=True)
    

# wpq: for gpt-2 model, need to enforce `max_length` constraints to avoid `position_id` index errors.
if isinstance(model, GPT2LMHeadModel):
    max_input_seq_len = model.config.max_position_embeddings - 50
else:
    max_input_seq_len = 2048 - 50

In [143]:

prompts = []
num_use_less_shots = 0
for example in test_data:
    lang = example["lang"]
    
    # wpq: use less demonstrations if exceeds context lengths.
    for n_shot in list(range(args.n_shot+1)[::-1]):
        if args.no_context:
            prompt, q_template, a_template = encoding_templates_without_context[lang]
            p_template = ""
        else:
            prompt, p_template, q_template, a_template = encoding_templates_with_context[lang]

        prompt += "\n\n"

        if n_shot > 0:
            formatted_demo_examples = []
            for train_example in train_data_for_langs[lang]:
                if args.no_context:
                    formatted_demo_examples.append(
                        q_template + " " + train_example["question"] + "\n" + a_template + " " + train_example["answers"][0]["text"]
                    )
                else:
                    formatted_demo_examples.append(
                        p_template + " " + train_example["context"] + "\n" + q_template + " " + train_example["question"] + "\n" + a_template + " " + train_example["answers"][0]["text"]
                    )
            prompt += "\n\n".join(formatted_demo_examples) + "\n\n"

        if args.no_context:
            prompt += q_template + " " + format(example["question"]) + "\n"
        else:
            prompt += p_template + " " + format(example["context"]) + "\n" + q_template + " " + format(example["question"]) + "\n"

        if args.use_chat_format:
            prompt = "<|user|>\n" + prompt.strip() + "\n<|assistant|>\n" + a_template
        else:
            prompt += a_template
            
        tokenized_prompt_len = len(tokenizer(prompt, add_special_tokens=False)['input_ids'])
        if tokenized_prompt_len <= max_input_seq_len:
            break
    if n_shot != args.n_shot:
        num_use_less_shots += 1
        print(f'n_shot: {args.n_shot} -> {n_shot}')

    prompts.append(prompt)
    
print(f'frac test_data use less shots: {num_use_less_shots / len(test_data):.2f}')

n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 

n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 

n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 -> 0
n_shot: 1 

0.26708686232026785

In [137]:
args.n_shot=1
list(range(args.n_shot+1)[::-1])

[1, 0]

In [138]:
# max_input_seq_len
974

974

In [145]:

l = tokenizer(prompts).input_ids
lens = [len(x) for x in l]


In [146]:
max(lens)

974

In [141]:
np.quantile(lens, .8), np.quantile(lens, .5)


(735.0, 560.0)