In [None]:
import string
import re
import json
import sys
sys.path.append("../")
from evaluate_utils import calculate_short_answer_EM, rouge, bleu, select_candidate

# Settings that are interpolated into the output-file names built by
# eval_short_answer_EM; they must match the names used when the files were written.
# Version tag embedded in "tree-gen" style output file names.
version="v0915"
# NOTE(review): not referenced anywhere in the visible notebook — confirm whether it is still needed.
multi_docs="top10"
# Appears as "{src_granularity}to{granularity}" in "*-rerank-tree-gen" file names.
src_granularity=256
# Granularity tag used in "tree-gen" / "tree-rerank" file names.
granularity=128

def eval_short_answer_EM(dataset, chat_model, reference_format, split, search_engine, rewrite_method):
    """Evaluate one generation output file and return its metrics as a dict.

    The file is looked up under ``../html_data/{dataset}/{chat_model}/{search_engine}``.
    Its name depends on ``reference_format`` and on the module-level settings
    ``rerank_model``, ``version``, ``granularity``, ``src_granularity`` and
    ``context_window``; only the settings the given format needs are read.

    Args:
        dataset: Dataset name. ``"eli5"`` is scored as long-form answers
            (ROUGE/BLEU); every other dataset is scored as short answers.
        chat_model: Model tag. Each record stores the generated answer under
            the key ``f"{chat_model}_{reference_format}"``.
        reference_format: Reference/context format tag; selects the file-name pattern.
        split: Dataset split tag used in the file name.
        search_engine: Sub-directory name under the model directory.
        rewrite_method: Query-rewrite tag used in the file name.

    Returns:
        For ``"eli5"``: the ROUGE result (values multiplied by 100 and rounded
        to 2 decimals) merged with the BLEU result as returned by ``bleu.compute``.
        Otherwise ``{"hit1": float, "exact_match": float}``, both percentages
        rounded to 2 decimals.
        If anything fails while reading or scoring, the error is printed and a
        dict of zeros with the matching keys is returned instead of raising.
    """
    output_dir = f"../html_data/{dataset}/{chat_model}/{search_engine}"

    if reference_format in ["html-trim", "fill-chunk"]:
        output_file = f"{output_dir}/{chat_model}-{reference_format}-{rewrite_method}-{rerank_model}-{dataset}-{split}.jsonl"
    elif reference_format == "tree-gen":
        output_file = f"{output_dir}/{chat_model}-{reference_format}-{rewrite_method}-{version}-{granularity}-{dataset}-{split}.jsonl"
    elif reference_format == "tree-rerank":
        output_file = f"{output_dir}/{chat_model}-{reference_format}-{rewrite_method}-{rerank_model}-{granularity}-{dataset}-{split}.jsonl"
    elif reference_format in ["chunk-rerank-tree-gen", "tree-rerank-tree-gen"]:
        # The coarse window is only part of these file names, so it is resolved
        # here; other formats no longer depend on context_window being a known key.
        if dataset in ["asqa", "nq", "eli5"]:
            #. fine trim ratio 2/3
            coarse_context_window = {"2k": "3k", "4k": "6k", "8k": "12k", "16k": "24k", "128k": "192k"}[context_window]
        else:
            #. fine trim ratio 1/2
            coarse_context_window = {"2k": "4k", "4k": "8k", "8k": "16k", "16k": "32k", "128k": "192k"}[context_window]
        output_file = f"{output_dir}/{chat_model}-{reference_format}-{rewrite_method}-{rerank_model}-{src_granularity}to{granularity}-{coarse_context_window}-{version}-{dataset}-{split}.jsonl"
    else:
        # Every remaining format (including "llmlingua", "bgelargeen",
        # "e5-mistral") uses the plain naming scheme.
        output_file = f"{output_dir}/{chat_model}-{reference_format}-{rewrite_method}-{dataset}-{split}.jsonl"
    print(f"evaluating file {output_file}")

    try:
        # The context manager closes the handle; blank lines (e.g. a trailing
        # newline) are skipped instead of failing the whole evaluation.
        with open(output_file) as f:
            data_lines = [json.loads(line) for line in f if line.strip()]
        if not data_lines:
            raise ValueError(f"no records found in {output_file}")

        answer_key = f"{chat_model}_{reference_format}"
        generated_answers = [data_line[answer_key] for data_line in data_lines]
        if dataset == "eli5":
            #. eval long answer
            # Imported here so the function does not rely on a later cell
            # having imported tqdm first.
            import tqdm

            if "answer" in data_lines[0]:
                gold_answers = [data_line["answer"] for data_line in data_lines]
            else:
                gold_answers = [data_line["long_answers"] for data_line in data_lines]

            # ROUGE is computed against one reference chosen per example by
            # select_candidate; BLEU receives the full gold answers.
            selected_gold_answers = []
            for gen, gold in tqdm.tqdm(zip(generated_answers, gold_answers), total=len(generated_answers)):
                selected_gold_answers.append(select_candidate(gen, gold))
            rouge_result = rouge.compute(predictions=generated_answers, references=selected_gold_answers)
            rouge_result = {k: round(v * 100, 2) for k, v in rouge_result.items()}

            # NOTE(review): BLEU values are returned unscaled, unlike ROUGE — confirm this is intended.
            bleu_result = bleu.compute(predictions=generated_answers, references=gold_answers)
            return {**rouge_result, **bleu_result}

        # The gold short answers live under one of several keys depending on the dataset.
        if "answers" in data_lines[0]:
            answers = [data_line['answers'] for data_line in data_lines]
        elif "short_answers" in data_lines[0]:
            answers = [data_line['short_answers'] for data_line in data_lines]
        elif "answer" in data_lines[0]:
            answers = [data_line['answer'] for data_line in data_lines]
        else:
            raise NotImplementedError("answers not found in data_lines")

        per_example = [
            calculate_short_answer_EM(generated_answer, gold)
            for generated_answer, gold in zip(generated_answers, answers)
        ]
        n_examples = len(per_example)
        # Plain floats: a trailing comma here would silently turn hit1 into a 1-tuple.
        hit1 = round(sum(score["hit1"] for score in per_example) / n_examples * 100, 2)
        exact_match = round(sum(score["exact_match"] for score in per_example) / n_examples * 100, 2)
        print(f"chat_model: {chat_model}, reference_format: {reference_format}, dataset: {dataset}, split: {split}, hit1: {hit1}, exact_match: {exact_match}")
        return {
            "hit1": hit1,
            "exact_match": exact_match
        }
    except Exception as e:
        # Best effort: report the failure and return zeros so one missing or
        # broken file does not abort the whole report.
        print(f"error evaluating file {output_file}, error: {e}")
        #  print stack trace
        import traceback
        traceback.print_exc()
        if dataset == "eli5":
            return {
                "rouge1": .0,
                "rouge2": .0,
                "rougeL": .0,
                "bleu": .0,
            }
        return {
            "exact_match": .0,
            "hit1": .0,
        }


In [None]:
import tqdm
#  eval all datasets
#  generate latex table report
# context_window is read as a global by eval_short_answer_EM and is also part of the chat model tag.
context_window="4k"
# reference_formats[k] is reported under the display name syn_names[k]; keep both lists the same length and order.
reference_formats=["bm25", "bgelargeen", "e5-mistral", "llmlingua", "jinaai-reader", "tree-rerank-tree-gen"]
syn_names=["BM25", "BGE", "E5-Mistral", "LongLLMLingua", "JinaAI Reader", "HTML4RAG"]

# long context settings
# context_window="128k"
# reference_formats=["html", "raw-text", "markdown", "html-simple"]
# syn_names=["Vanilla HTML", "Raw Text", "Markdown", "HTML4RAG-Clean"]

# Order matters: it fixes the column order of the report tables.
datasets=["asqa", "hotpot-qa", "nq", "trivia-qa", "musique", "eli5"]

# These tags are interpolated into the output file path/name by eval_short_answer_EM.
split="test"
search_engine="bing"
rewrite_method="slimplmqr"
# rerank_model is read as a global by eval_short_answer_EM.
rerank_model="bgelargeen"

import multiprocessing
# Manager-backed list so worker processes can write their result cells back to the parent.
# One slot per (dataset, reference_format) pair, indexed as i*len(reference_formats)+j; "" means "no result yet".
res_list=multiprocessing.Manager().list([""]*len(datasets)*len(reference_formats))
# Worker processes of the batch currently in flight.
processes = []

def append_res2markdown_table(lidx, *args, **kwargs):
    """Evaluate one (dataset, format) pair and store its formatted table cell.

    ``args`` is forwarded positionally to ``eval_short_answer_EM`` (so
    ``args[0]`` is the dataset name); ``kwargs`` is accepted but not used.
    The formatted cell, with values joined by `` & ``, is written to the
    shared ``res_list`` at index ``lidx``.
    """
    metrics = eval_short_answer_EM(*args)

    if "hit1" in metrics:
        # Either metric may arrive wrapped in a 1-tuple; unwrap to the bare value.
        hit1_score, em_score = (
            value[0] if isinstance(value, tuple) else value
            for value in (metrics["hit1"], metrics["exact_match"])
        )
        # These two datasets report a single column, the others report two.
        single_column = args[0] in ["hotpot-qa", "musique"]
        cell = f" {hit1_score} " if single_column else f" {hit1_score} & {em_score} "
    elif "rouge1" in metrics:
        # Long-answer result: every value of the metrics dict becomes one column.
        cell = " & ".join([f"{value:.2f}" for value in metrics.values()])
    else:
        # Only a BLEU score is available; report it as a percentage.
        cell = f" {metrics['bleu'] * 100:.2f} "

    res_list[lidx] = cell

# One progress tick per (dataset, reference_format) job.
pbar=tqdm.tqdm(total=len(datasets)*len(reference_formats))

chat_model=f"llama70b{context_window}"
# chat_model=f"llama8b{context_window}"

# Launch one worker process per (dataset, reference_format) pair, in batches of 4.
for i, dataset in enumerate(datasets):
    for j, reference_format in enumerate(reference_formats):
        # Flat slot in res_list for this pair (row-major: dataset first, then format).
        lidx= i*len(reference_formats) + j
        p=multiprocessing.Process(target=append_res2markdown_table, args=(lidx, dataset, chat_model, reference_format, split, search_engine, rewrite_method))
        processes.append(p)
        p.start()
        # NOTE: the bar advances when a job is started, not when it finishes.
        pbar.update(1)
        # Once 4 workers are running, wait for the whole batch before starting more.
        if len(processes) >= 4:
            for p in processes:
                p.join()
            processes=[]
                
# Wait for the last, possibly partial, batch.
if processes:
    for p in processes:
        p.join()

pbar.close()

In [None]:
# Build and print the LaTeX report: one header row, then one row per reference format.
import re

latex_table = ["Dataset & EM & Hit@1 & EM & EM & Hit@1 & EM & Hit@1 & EM & ROUGE-L & BLEU"]

# Pad every display name to a common width (longest name + 2) so the first "&" lines up.
longest_syn_name = max(len(syn_name) for syn_name in syn_names) + 2
for i in range(len(reference_formats)):
    latex_table.append(syn_names[i].ljust(longest_syn_name) + "&")

# Rows are independent of each other, so each row is filled dataset by dataset.
for j, reference_format in enumerate(reference_formats):
    row = j + 1  # row 0 is the header
    for i, dataset in enumerate(datasets):
        lidx = i * len(reference_formats) + j
        appended = latex_table[row] + f"{res_list[lidx]} &"
        #. pad one-decimal numbers to two decimals, e.g. 5.5 -> 5.50
        latex_table[row] = re.sub(r"(\d+\.\d)(?!\d)", r"\g<1>0", appended)

# Drop the dangling column separator and terminate each row with "\\".
for line in latex_table:
    line = (line[:-1] if line.endswith("&") else line) + "\\\\"
    print(line)

In [None]:
#. create a markdown table
# The table starts with two non-data rows: the header and the "| --- |" separator.
markdown_table = ["| Dataset | EM | Hit@1 | EM | EM | Hit@1 | EM | Hit@1 | EM | ROUGE-L | BLEU |",
                  "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |"]
# Data rows begin after the header rows; using j+1 would write the first
# method's results onto the separator row and leave the last method's row empty.
first_data_row = len(markdown_table)

# Pad every display name to a common width (longest name + 2).
longest_syn_name = max([len(syn_name) for syn_name in syn_names]) +2
for i in range(len(reference_formats)):
    markdown_table.append(f"| {syn_names[i]}"+" "*(longest_syn_name-len(syn_names[i])) + "|")

for i, dataset in enumerate(datasets):
    for j, reference_format in enumerate(reference_formats):
        lidx= i*len(reference_formats) + j
        #. replace .x with .x0, e.g. 5.5 with 5.50 (applied to the metric values only)
        cell = re.sub(r"(\d+\.\d)(?!\d)", r"\g<1>0", f"{res_list[lidx]}")
        # Result cells are joined with the LaTeX separator "&"; Markdown needs
        # "|" so that multi-metric results occupy one column per metric.
        cell = cell.replace("&", "|")
        markdown_table[first_data_row + j] += f"{cell} |"

for line in markdown_table:
    print(line)
    