In [1]:
from rosemary import jpt_parse_args, jpt_setup; jpt_setup()

import sys
sys.path.append('/dccstor/mit_fm/wpq/github/mitibm2023/external/open-instruct/')

  warn(f'Install `torch` for functionalities dependent on torch')


In [2]:
import argparse
import os
import re
import json
import tqdm
import glob
import torch
import random
import evaluate
from eval.utils import load_hf_lm_and_tokenizer, generate_completions, query_openai_chat_model
import torch

In [4]:
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default="data/bbh")
parser.add_argument("--save_dir", type=str, default="results/bbh")
parser.add_argument("--model_name_or_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
parser.add_argument("--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here.")
parser.add_argument("--openai_engine", type=str, default=None, help="if specified, we will use the OpenAI API to generate the predictions.")
parser.add_argument("--no_cot", action="store_true", help="if specified, chain of thoughts will be removed from the prompts.")
parser.add_argument("--max_num_examples_per_task", type=int, default=None, help="maximum number of examples to evaluate per task.")
parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
parser.add_argument("--load_in_8bit", action="store_true", help="load model in 8bit mode, which will reduce memory and speed up inference.")
parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
parser.add_argument("--use_chat_format", action="store_true", help="If given, the prompt will be encoded as a chat format with the roles in prompt.")


model_name_or_path = '../results/baselines/gpt2-large'
model_name_or_path = '../results/baselines/gpt2'
# model_name_or_path = '../results/gpt2-Large_human_mix'
# model_name_or_path = '../results/huggyllama:llama-7b_human_mix-trainer_savebystep/checkpoint-200'
model_name_or_path = '../results/baselines/huggyllama/llama-7b/'

cmd = f"""
    --data_dir ../data/eval/bbh/ \
    --save_dir {model_name_or_path}/eval/bbh/ \
    --max_num_examples_per_task 40 \
    --model_name_or_path {model_name_or_path} \
    --eval_batch_size 10
"""

args = jpt_parse_args(parser, cmd)

# model_name_or_path and openai_engine cannot be both None or both not None.
assert (args.model_name_or_path is None) != (args.openai_engine is None), "Either model_name_or_path or openai_engine should be specified."



In [5]:
random.seed(42)

data_size = 0
all_tasks = {}
task_files = glob.glob(os.path.join(args.data_dir, "bbh", "*.json"))
for task_file in tqdm.tqdm(task_files, desc="Loading tasks"):
    with open(task_file, "r") as f:
        task_name = os.path.basename(task_file).split(".")[0]
        all_tasks[task_name] = json.load(f)["examples"]
        print(f'{task_name}\t{len(all_tasks[task_name])}')
        data_size += len(all_tasks[task_name])
        if args.max_num_examples_per_task:
            all_tasks[task_name] = random.sample(all_tasks[task_name], args.max_num_examples_per_task)
data_size

Loading tasks:  33%|███▎      | 9/27 [00:00<00:00, 85.01it/s]

logical_deduction_seven_objects	250
ruin_names	250
dyck_languages	250
causal_judgement	187
sports_understanding	250
word_sorting	250
temporal_sequences	250
multistep_arithmetic_two	250
tracking_shuffled_objects_seven_objects	250
web_of_lies	250
formal_fallacies	250
salient_translation_error_detection	250
reasoning_about_colored_objects	250
disambiguation_qa	250
movie_recommendation	250
logical_deduction_five_objects	250
hyperbaton	250


Loading tasks: 100%|██████████| 27/27 [00:00<00:00, 84.00it/s]

geometric_shapes	250
navigate	250
object_counting	250
tracking_shuffled_objects_three_objects	250
snarks	178
penguins_in_a_table	146
date_understanding	250
boolean_expressions	250
logical_deduction_three_objects	250
tracking_shuffled_objects_five_objects	250





6511

In [6]:

all_prompts = {}
cot_prompt_files = glob.glob(os.path.join(args.data_dir, "cot-prompts", "*.txt"))
for cot_prompt_file in tqdm.tqdm(cot_prompt_files, desc="Loading prompts"):
    with open(cot_prompt_file, "r") as f:
        task_name = os.path.basename(cot_prompt_file).split(".")[0]
        task_prompt = "".join(f.readlines()[2:])
        if args.no_cot:
            prompt_fields = task_prompt.split("\n\n")
            new_prompt_fields = []
            for prompt_field in prompt_fields:
                if prompt_field.startswith("Q:"):
                    assert "So the answer is" in prompt_field, f"`So the answer is` not found in prompt field of {task_name}.txt."
                    assert "\nA:" in prompt_field, "`\nA:` not found in prompt field."
                    answer = prompt_field.split("So the answer is")[-1].strip()
                    question = prompt_field.split("\nA:")[0].strip()
                    new_prompt_fields.append(question + "\nA: " + answer)
                else:
                    new_prompt_fields.append(prompt_field)
            task_prompt = "\n\n".join(new_prompt_fields)
        all_prompts[task_name] = task_prompt

assert set(all_tasks.keys()) == set(all_prompts.keys()), "task names in task data and task prompts are not the same."


Loading prompts: 100%|██████████| 27/27 [00:00<00:00, 402.41it/s]


In [7]:
task_name = 'logical_deduction_seven_objects'
print(all_prompts[task_name])

A logical deduction task which requires deducing the order of a sequence of objects.

Q: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. In a golf tournament, there were three golfers: Amy, Eli, and Eve. Eve finished above Amy. Eli finished below Amy.
Options:
(A) Amy finished last
(B) Eli finished last
(C) Eve finished last
A: Let's think step by step.
(1) Eve finished above Amy: "(above) ? Eve ? Amy ? (below)".
(2) Eli finished below Amy: "(above) ? Amy ? Eli ? (below)".
(3) Combining (1) and (2) we get the following ordering: "(above) Eve Amy Eli (below)".
According to this ordering, the person who finished last (the one at the bottom of this list) is Eli.
Eli finished last. So the answer is (B).

Q: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a shelf, there are three boo

In [9]:

os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(os.path.join(args.save_dir, "predictions"), exist_ok=True)


if args.model_name_or_path:
    print("Loading model and tokenizer...")
    model, tokenizer = load_hf_lm_and_tokenizer(
        model_name_or_path=args.model_name_or_path, 
        tokenizer_name_or_path=args.tokenizer_name_or_path, 
        load_in_8bit=args.load_in_8bit, 
        gptq_model=args.gptq,
        device_map='auto',
    )
    
model.device, model.dtype

Loading model and tokenizer...


Using pad_token, but it is not set yet.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

(device(type='cuda', index=0), torch.bfloat16)

In [15]:
# get an idea of prompt length for the tasks.
# - no   cot: prompts are pretty short!
# - with cot: some prompt >2k sequence length, potentially a problem.

print('prompt lengths:')
[(k, tokenizer(x, return_tensors='pt').input_ids.shape[1]) 
       for k, x in all_prompts.items()]




prompt lengths:


[('date_understanding', 630),
 ('logical_deduction_five_objects', 703),
 ('movie_recommendation', 716),
 ('multistep_arithmetic_two', 986),
 ('object_counting', 535),
 ('word_sorting', 813),
 ('hyperbaton', 990),
 ('sports_understanding', 227),
 ('boolean_expressions', 519),
 ('tracking_shuffled_objects_seven_objects', 803),
 ('ruin_names', 1076),
 ('tracking_shuffled_objects_three_objects', 803),
 ('causal_judgement', 867),
 ('reasoning_about_colored_objects', 796),
 ('logical_deduction_seven_objects', 703),
 ('temporal_sequences', 1082),
 ('salient_translation_error_detection', 1488),
 ('tracking_shuffled_objects_five_objects', 803),
 ('geometric_shapes', 2329),
 ('disambiguation_qa', 969),
 ('dyck_languages', 946),
 ('navigate', 750),
 ('formal_fallacies', 1342),
 ('web_of_lies', 851),
 ('snarks', 808),
 ('penguins_in_a_table', 846),
 ('logical_deduction_three_objects', 703)]

In [73]:
task_name = next(iter(all_tasks.keys()))
task_examples = all_tasks[task_name]
prompt = all_prompts[task_name]
print(prompt)
print()
print(task_examples[0])

A logical deduction task which requires deducing the order of a sequence of objects.

Q: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. In a golf tournament, there were three golfers: Amy, Eli, and Eve. Eve finished above Amy. Eli finished below Amy.
Options:
(A) Amy finished last
(B) Eli finished last
(C) Eve finished last
A: Let's think step by step.
(1) Eve finished above Amy: "(above) ? Eve ? Amy ? (below)".
(2) Eli finished below Amy: "(above) ? Amy ? Eli ? (below)".
(3) Combining (1) and (2) we get the following ordering: "(above) Eve Amy Eli (below)".
According to this ordering, the person who finished last (the one at the bottom of this list) is Eli.
Eli finished last. So the answer is (B).

Q: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a shelf, there are three boo

In [74]:
# task_perf = eval_hf_model(
#     args, 
#     model, 
#     tokenizer, 
#     task_examples, 
#     prompt, 
#     save_path=os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl")
# )

# performance[task_name] = task_perf
# print(f"Task {task_name} - EM: {task_perf}")

examples = task_examples
task_prompt = prompt
save_path=os.path.join(args.save_dir, "predictions", f"{task_name}.jsonl")


In [75]:
targets = [example["target"] for example in examples]
if save_path:
    fout = open(save_path, "w")
save_path

'../results/huggyllama:llama-7b_human_mix-trainer_savebystep/checkpoint-200/eval/bbh/predictions/logical_deduction_seven_objects.jsonl'

In [76]:

prompts = []
for example in examples:
    if args.use_chat_format:
        prompt = "<|user|>\n" + task_prompt.strip() + "\n\nQ: " + example["input"] + "\n<|assistant|>\nA:"
    else:
        prompt = task_prompt.strip() + "\n\nQ: " + example["input"] + "\nA:"
    prompts.append(prompt)

print(prompts[0])

A logical deduction task which requires deducing the order of a sequence of objects.

Q: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. In a golf tournament, there were three golfers: Amy, Eli, and Eve. Eve finished above Amy. Eli finished below Amy.
Options:
(A) Amy finished last
(B) Eli finished last
(C) Eve finished last
A: Let's think step by step.
(1) Eve finished above Amy: "(above) ? Eve ? Amy ? (below)".
(2) Eli finished below Amy: "(above) ? Amy ? Eli ? (below)".
(3) Combining (1) and (2) we get the following ordering: "(above) Eve Amy Eli (below)".
According to this ordering, the person who finished last (the one at the bottom of this list) is Eli.
Eli finished last. So the answer is (B).

Q: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a shelf, there are three boo

In [38]:
for s in [' ', '\n', '\n\n', 'A: (B)\n\n']:
    print(s, tokenizer.encode(s, add_special_tokens=False))

  [259]

 [29871, 13]


 [29871, 13, 13]
A: (B)

 [319, 29901, 313, 29933, 29897, 13, 13]


In [45]:
l = []
for s in [[13], [13,13], [29871,13]]:
    l.append(tokenizer.decode(s, add_special_tokens=False))
l

['\n', '\n\n', '\n']

In [77]:
if args.no_cot:
    # get the last token because the tokenizer may add space tokens at the start.
    stop_id_sequences = tokenizer.encode("\n\n", add_special_tokens=False)
    stop_id_sequences = [stop_id_sequences[-2:]] if stop_id_sequences else None
else:
    # let's not use the stop sequence for cot now since it's too inefficient when the generation is long. 
    # instead, we'll do some post-processing to extract the answer.
    stop_sequnce = None

from transformers import GPT2LMHeadModel
if isinstance(model, GPT2LMHeadModel):
    # wpq: for gpt-2 model, need to enforce `max_length` constraints to avoid `position_id` index errors.
    generation_kwargs = {'max_length': model.config.max_position_embeddings} # 1024
else:
    # wpq: modify `max_new_tokens=512` to `128` for faster generation.
    # for non-cot multiple choice answers, e.g., ' (G).' requires just 5 tokens
    generation_kwargs = {'max_new_tokens': 10 if args.no_cot else 256}

outputs = generate_completions(
    model=model,
    tokenizer=tokenizer,
    prompts=prompts,
    batch_size=args.eval_batch_size if args.eval_batch_size else 1,
    stop_id_sequences=stop_id_sequences,
    **generation_kwargs,
)

Generating Completions: 100%|██████████| 40/40 [03:45<00:00,  5.64s/it]


[1, 29871, 313, 29954, 467]

In [63]:
exact_match = evaluate.load("exact_match")

predictions = []
for example, output in zip(examples, outputs):
    example["raw_output"] = output

    # only keep the first part of the output - this is mainly for vanilla language models.
    output = output.strip().split("\n\n")[0].strip()

    # extract the first answer after `So the answer is` and before the next period.
    # if there is no such answer, we will just use the raw output.
    results = re.search(r"So the answer is (.*?)\.", output)
    if results:
        prediction = results.group(1).strip()
    else:
        prediction = output.strip()

    example["prediction"] = prediction
    predictions.append(prediction)
    if save_path:
        fout.write(json.dumps(example) + "\n")        

assert len(predictions) == len(targets), "number of predictions and targets are not the same."
task_perf = exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]
task_perf

ValueError: I/O operation on closed file.

In [65]:
performance = {}
performance[task_name] = .1

In [66]:

with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout:
    performance["average_exact_match"] = sum(performance.values()) / len(performance)
    print(f"Average EM: {performance['average_exact_match']}")
    json.dump(performance, fout, indent=4)


Average EM: 0.1
