In [None]:
# Notebook bootstrap: import the rosemary notebook helpers and call `jpt_setup()` immediately.
# NOTE(review): what `jpt_setup()` configures is not visible in this notebook — confirm in rosemary.
from rosemary import jpt_parse_args, jpt_setup; jpt_setup()

import sys
# Extend the module search path before the `eval.*` imports in the next cell.
# NOTE(review): hardcoded absolute path; presumably this is the open-instruct checkout that
# provides the `eval` package — confirm, and consider making it configurable.
sys.path.append('/dccstor/mit_fm/wpq/github/mitibm2023/external/open-instruct/')

In [None]:
import argparse
import os
import tqdm
import re
import json
import random
import time
import evaluate
from transformers import GPT2LMHeadModel
from eval.utils import generate_completions, load_hf_lm_and_tokenizer, query_openai_chat_model, KeyWordsCriteria
from eval.gsm.examplars import EXAMPLARS as GSM_EXAMPLARS
import torch

In [None]:

# Command-line interface of the GSM evaluation. Flags are declared as data and
# registered in the listed order, so adding or tweaking one is a one-line change.
parser = argparse.ArgumentParser()
_ARGUMENT_SPECS = [
    ("--data_dir", {"type": str, "default": "data/mgsm"}),
    ("--max_num_examples", {"type": int, "default": None, "help": "maximum number of examples to evaluate."}),
    ("--save_dir", {"type": str, "default": "results/mgsm"}),
    ("--model_name_or_path", {"type": str, "default": None, "help": "if specified, we will load the model to generate the predictions."}),
    ("--tokenizer_name_or_path", {"type": str, "default": None, "help": "if specified, we will load the tokenizer from here."}),
    ("--openai_engine", {"type": str, "default": None, "help": "if specified, we will use the OpenAI API to generate the predictions."}),
    ("--n_shot", {"type": int, "default": 8, "help": "max number of examples to use for demonstration."}),
    ("--no_cot", {"action": "store_true", "help": "If given, we're evaluating a model without chain-of-thought."}),
    ("--eval_batch_size", {"type": int, "default": 1, "help": "batch size for evaluation."}),
    ("--load_in_8bit", {"action": "store_true", "help": "load model in 8bit mode, which will reduce memory and speed up inference."}),
    ("--gptq", {"action": "store_true", "help": "If given, we're evaluating a 4-bit quantized GPTQ model."}),
    ("--use_chat_format", {"action": "store_true", "help": "If given, the prompt will be encoded as a chat format with the roles in prompt."}),
]
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Model under evaluation. Other checkpoints that have been tried with this notebook:
#   'gpt2-Large'
#   '../results/gpt2-Large_human_mix'
#   't5-Large'
#   'google/flan-t5-large'
#   '../results/google/flan-t5-small'
#   'huggyllama/llama-7b'
#   '../results/baselines/mosaicml/mpt-7b'
model_name_or_path = '../results/baselines/mosaicml/mpt-7b'

# Simulated command line. The trailing backslashes join the lines into a single string.
cmd = f"""
    --data_dir ../data/eval/gsm/ \
    --save_dir {model_name_or_path}/eval/gsm/ \
    --max_num_examples 200 \
    --model_name_or_path {model_name_or_path} \
    --eval_batch_size 20 \
    --n_shot 8
"""

args = jpt_parse_args(parser, cmd)

# Exactly one generation backend must be selected: a local model or an OpenAI engine.
_has_local_model = args.model_name_or_path is not None
_has_openai_engine = args.openai_engine is not None
assert _has_local_model != _has_openai_engine, "Either model_name_or_path or openai_engine should be specified."

# Display the parsed arguments as the cell output.
args


In [None]:
# Metric used in the scoring cell to compare predictions with gold answers.
exact_match = evaluate.load("exact_match")

# Seed inside this cell so the subsample below is identical on every (re-)run.
random.seed(42)

print("Loading data...")
test_data = []
with open(os.path.join(args.data_dir, "test.jsonl")) as fin:
    for line in fin:
        example = json.loads(line)
        # The gold answer is whatever follows the `####` marker.
        answer = example["answer"].split("####")[1].strip()
        # some numbers are in the `x,xxx` format, and we want to remove the comma
        answer = re.sub(r"(\d),(\d)", r"\1\2", answer)
        # Validate that the answer parses as a number. A truthiness check on the
        # parsed value would wrongly reject a legitimate answer of 0, so only the
        # parse itself is tested.
        try:
            float(answer)
        except ValueError as err:
            raise ValueError(f"answer is not a valid number: {answer}") from err
        test_data.append({
            "question": example["question"],
            "answer": answer,
        })

# Optionally evaluate on a random subset to keep the run short.
if args.max_num_examples and len(test_data) > args.max_num_examples:
    test_data = random.sample(test_data, args.max_num_examples)


In [None]:
# Make sure the output directory exists; `exist_ok=True` already covers the
# "directory is already there" case, so no separate existence check is needed.
os.makedirs(args.save_dir, exist_ok=True)

if args.n_shot:
    # Work on a local selection instead of rebinding the imported GSM_EXAMPLARS,
    # so re-running this cell (e.g. with a larger n_shot) starts from the full list.
    exemplars = GSM_EXAMPLARS
    if len(exemplars) > args.n_shot:
        exemplars = random.sample(exemplars, args.n_shot)

    # Without chain-of-thought the demonstrations show only the final answer.
    answer_key = "short_answer" if args.no_cot else "cot_answer"
    # Both modes use the same "Question: " label as the test prompt below, so
    # the demonstrations and the query share one format.
    demonstrations = [
        "Question: " + exemplar["question"] + "\n" + "Answer: " + exemplar[answer_key]
        for exemplar in exemplars
    ]
    prompt_prefix = "Answer the following questions.\n\n" + "\n\n".join(demonstrations) + "\n\n"
else:
    prompt_prefix = "Answer the following question.\n\n"

# One prompt per test example: few-shot prefix, then the question, then an
# "Answer:" cue for the model to continue.
prompts = []
for example in test_data:
    question_block = "Question: " + example["question"].strip()
    if args.use_chat_format:
        prompt = "<|user|>\n" + prompt_prefix + question_block + "\n<|assistant|>\n" + "Answer:"
    else:
        prompt = prompt_prefix + question_block + "\nAnswer:"
    prompts.append(prompt)

# Show one fully rendered prompt as a sanity check (skipped when there is no data).
if prompts:
    print(prompts[0])

In [None]:
from eval.utils import load_hf_lm_and_tokenizer

# Gather the loader options in one place, then load model and tokenizer in a single call.
loader_options = {
    "model_name_or_path": args.model_name_or_path,
    "tokenizer_name_or_path": args.tokenizer_name_or_path,
    "load_in_8bit": args.load_in_8bit,
    "dtype": torch.bfloat16,
    "gptq_model": args.gptq,
    "use_fast_tokenizer": True,
    "device_map": 'auto',
}
model, tokenizer = load_hf_lm_and_tokenizer(**loader_options)

# Display where the model was placed and which dtype it uses.
model.device, model.dtype

In [None]:
from eval.utils import generate_completions


# Stop generation at the first newline. Only the last id of the encoded "\n" is
# used because the tokenizer may add space tokens at the start.
# wpq: the t5 tokenizer strips `\n`, so the encoding is empty there; in that case
# no stop sequence is used and generation runs to max length or end-of-sequence.
new_line_token = tokenizer.encode("\n", add_special_tokens=False)
if new_line_token:
    stop_id_sequences = [[new_line_token[-1]]]
else:
    stop_id_sequences = None

# wpq: gpt-2 needs a `max_length` cap to avoid `position_id` index errors;
# every other model generates up to 256 new tokens (lowered from 512).
generation_kwargs = (
    {'max_length': model.config.max_position_embeddings}  # 1024
    if isinstance(model, GPT2LMHeadModel)
    else {'max_new_tokens': 256}
)


# Time the generation. Note that only the first 20 prompts are generated here.
t0 = time.time()
outputs = generate_completions(
    model=model,
    tokenizer=tokenizer,
    prompts=prompts[:20],
    batch_size=args.eval_batch_size,
    stop_id_sequences=stop_id_sequences,
    **generation_kwargs,
)
t = time.time()-t0

print(f'Time = {t:.2f}')
# Display the first completion as the cell output.
outputs[0]


# batch_size = 20
# (4*60+44) / 20 = 14.2 / data
# 276.43 / 20 = 13.8 / data

In [None]:
outputs

In [None]:
# ## utils.generate_completions*


# if 'gpt2' in model_name_or_path:
#     generation_kwargs = {'max_length': model.config.max_position_embeddings} # 1024
# else:
#     generation_kwargs = {'max_new_tokens': 512}
# print(generation_kwargs)
    
# batch_size = 20
# i = 0
# batch_prompts = prompts[i:i+batch_size]
# batch_prompts = [
#     'Is the following sentence positive or negative: I hate the food',
#     'translate from english to german: the weather is great!']
# tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
# batch_input_ids = tokenized_prompts.input_ids
# attention_mask = tokenized_prompts.attention_mask

# print(batch_input_ids.shape, batch_input_ids.device)
# print(attention_mask.shape, attention_mask.device)
# model = model.to('cuda')
# batch_input_ids = batch_input_ids.cuda()
# attention_mask = attention_mask.cuda()
# print(model.device)


# from transformers import StoppingCriteriaList
# import time
# start = time.time()

# stopping_criteria = StoppingCriteriaList([KeyWordsCriteria(stop_id_sequences)]) if stop_id_sequences else None

# batch_outputs = model.generate(
#     input_ids=batch_input_ids,
#     attention_mask=attention_mask,
#     stopping_criteria=stopping_criteria,
#     **generation_kwargs,
# )
# end = time.time()
# end-start
# print(end-start)
# print(batch_outputs.shape)

# if stop_id_sequences:
#     for output_idx in range(batch_outputs.shape[0]):
#         for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]):
#             if any(batch_outputs[output_idx, token_idx: token_idx+len(stop_sequence)].tolist() == stop_sequence for stop_sequence in stop_id_sequences):
#                 batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
#                 break


# # 
# # 15s
# # 18.369792938232422

# num_return_sequences = 1
# # remove the prompt from the output
# # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs.
# # we changed our previous way of truncating the output token ids directly because some tokenizers (e.g., llama) won't add a space token before the first token.
# # space is important for some tasks (e.g., code completion).
# batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
# batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
# # duplicate the prompts to match the number of return sequences
# batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
# batch_generations = [
#     output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
# ]

# print([len(x) for x in batch_outputs])
# print([len(x) for x in batch_prompts])

In [None]:
# generation_kwargs = {'max_new_tokens': 512}
# disable_tqdm = False
# stop_id_sequences = [[new_line_token]]
# batch_size=args.eval_batch_size



# generations = []
# if not disable_tqdm:
#     progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")

# num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
# for i in range(0, len(prompts), batch_size):
#     batch_prompts = prompts[i:i+batch_size]
#     tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
#     batch_input_ids = tokenized_prompts.input_ids
#     attention_mask = tokenized_prompts.attention_mask

#     if model.device.type == "cuda":
#         print(torch.cuda.is_available())
#         print(batch_input_ids.device, batch_input_ids.shape)
#         batch_input_ids = batch_input_ids.cuda()
#         attention_mask = attention_mask.cuda()

#     try:
#         batch_outputs = model.generate(
#             input_ids=batch_input_ids,
#             attention_mask=attention_mask,
#             stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None,
#             **generation_kwargs
#         )

#         # the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate.
#         # so some outputs still have the stop sequence, which we need to remove.
#         if stop_id_sequences:
#             for output_idx in range(batch_outputs.shape[0]):
#                 for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]):
#                     if any(batch_outputs[output_idx, token_idx: token_idx+len(stop_sequence)].tolist() == stop_sequence for stop_sequence in stop_id_sequences):
#                         batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
#                         break

#         # remove the prompt from the output
#         # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs.
#         # we changed our previous way of truncating the output token ids directly because some tokenizers (e.g., llama) won't add a space token before the first token.
#         # space is important for some tasks (e.g., code completion).
#         batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
#         batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
#         # duplicate the prompts to match the number of return sequences
#         batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
#         batch_generations = [
#             output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
#         ]
#     except Exception as e:
#         print("Error when generating completions for batch:")
#         print(batch_prompts)
#         print("Error message:")
#         print(e)
#         print("Use empty string as the completion.")
#         batch_generations = [""] * len(batch_prompts) * num_return_sequences

#     generations += batch_generations

#     # for prompt, generation in zip(batch_prompts, batch_generations):
#     #     print("========")
#     #     print(prompt)
#     #     print("--------")
#     #     print(generation)

#     if not disable_tqdm:
#         progress.update(len(batch_prompts)//num_return_sequences)

In [None]:

# Score only the examples that actually have a model output. Generation may have
# been run on a prefix of the prompts, and the metric requires predictions and
# references of equal length.
if len(outputs) > len(test_data):
    raise ValueError(
        f"Got {len(outputs)} outputs for {len(test_data)} test examples; cannot align them."
    )
scored_examples = test_data[:len(outputs)]
if len(scored_examples) < len(test_data):
    print(f"Scoring {len(scored_examples)} of {len(test_data)} examples (the rest have no model output).")

# The prediction is the last number mentioned in the completion; if the
# completion contains no number at all, the whole text is used as the prediction.
predictions = []
for output in outputs:
    # replace numbers like `x,xxx` with `xxxx`
    cleaned_output = re.sub(r"(\d),(\d)", r"\1\2", output)
    numbers = re.findall(r"[-+]?\d*\.\d+|\d+", cleaned_output)
    predictions.append(numbers[-1] if numbers else cleaned_output)

print("Calculating accuracy...")
targets = [example["answer"] for example in scored_examples]

em_score = exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]
print(f"Exact match : {em_score}")

# One record per scored example, keeping the raw completion next to the parsed answer.
prediction_records = [{
    "question": example["question"],
    "answer": example["answer"],
    "model_output": output,
    "prediction": pred
} for example, output, pred in zip(scored_examples, outputs, predictions)]

# Create the output directory here as well so this cell does not depend on an
# earlier cell having done it.
os.makedirs(args.save_dir, exist_ok=True)

with open(os.path.join(args.save_dir, "predictions.jsonl"), "w") as fout:
    for record in prediction_records:
        fout.write(json.dumps(record) + "\n")

with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout:
    json.dump({
        "exact_match": em_score
    }, fout, indent=4)


In [None]:
# t5-Large, 0.02
# flan-t5-large, 0.03