In [None]:
import re
import string
from collections import Counter
def normalize(s):
    """Canonicalize an answer string for token-level comparison.

    Steps, applied in this order:
      1. lowercase the text;
      2. drop every character in ``string.punctuation``;
      3. replace the standalone words "a", "an", "the" with a space;
      4. collapse whitespace runs to single spaces and trim the ends.

    Args:
        s: Raw answer text.

    Returns:
        The normalized string (possibly empty).
    """
    punctuation = frozenset(string.punctuation)
    text = s.lower()
    text = "".join(c for c in text if c not in punctuation)
    # \b anchors keep words like "theory" or "apple" intact.
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def compute_f1(prediction, ground_truth):
    """Token-level F1 between a predicted answer and one reference answer.

    Both strings are passed through ``normalize`` and split on whitespace.
    Overlap is a multiset intersection: a token occurring k1 times in the
    prediction and k2 times in the reference contributes min(k1, k2) matches.

    Args:
        prediction: Predicted answer string, or None if no answer was produced.
        ground_truth: Reference answer string.

    Returns:
        F1 as a float in [0.0, 1.0]. Returns 0.0 when ``prediction`` is None
        or when the two token lists share no tokens (which includes the case
        where either side normalizes to an empty string).
    """
    if prediction is None:
        return 0.0
    prediction_tokens = normalize(prediction).split()
    ground_truth_tokens = normalize(ground_truth).split()

    # Counter intersection, not set intersection: precision/recall divide by
    # token counts that include duplicates, so the numerator must count
    # duplicates too (with sets, "cat cat" vs "cat cat" would score only 0.5).
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    # Also guards the divisions below against empty token lists.
    if num_same == 0:
        return 0.0

    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return 2 * precision * recall / (precision + recall)

def exact_match_score(prediction, ground_truth):
    """Exact match after ``normalize``: 1 if the strings agree, else 0.

    Args:
        prediction: Predicted answer string, or None if no answer was produced.
        ground_truth: Reference answer string.

    Returns:
        ``int`` 1 or 0 for a real prediction; ``0.0`` when ``prediction`` is None.
    """
    if prediction is not None:
        return int(normalize(prediction) == normalize(ground_truth))
    return 0.0

def evaluate(predictions, pred_key="pred_answer", gt_key="gt"):
    """Average token-level F1 and exact match over a list of prediction records.

    Each item is scored against every one of its references, and the best
    F1 and best EM are kept. Those per-item scores are then averaged.

    Args:
        predictions: Sequence of dict-like records. Each record must contain
            ``pred_key`` (the predicted answer string, or None) and ``gt_key``
            (a single reference string, or a list of acceptable references).
        pred_key: Record key holding the prediction. Defaults to
            ``"pred_answer"``; pass another field name (e.g.
            ``"pred_answer_ori"``) to score a different prediction field.
        gt_key: Record key holding the reference answer(s). Defaults to ``"gt"``.

    Returns:
        Dict with ``"avg_f1"`` and ``"avg_em"``, both 0 when ``predictions`` is
        empty. An item whose reference list is empty scores 0 on both metrics.
    """
    total = len(predictions)
    f1_total = 0
    em_total = 0

    for item in predictions:
        pred = item[pred_key]
        gts = item[gt_key]

        # If gt is a single string, wrap it in a list so both forms are handled uniformly.
        if isinstance(gts, str):
            gts = [gts]

        # Best score over all references. default=0.0 prevents max() from
        # raising ValueError on an item with an empty reference list.
        f1 = max((compute_f1(pred, gt) for gt in gts), default=0.0)
        em = max((exact_match_score(pred, gt) for gt in gts), default=0.0)
        # An exact match always counts as perfect F1. This matters e.g. when
        # both sides normalize to "", where compute_f1 returns 0.0 but EM is 1.
        if em == 1:
            f1 = 1

        f1_total += f1
        em_total += em

    return {
        "avg_f1": f1_total / total if total > 0 else 0,
        "avg_em": em_total / total if total > 0 else 0
    }

In [None]:
import json

# Per-shard result files, <prefix>_0.json .. <prefix>_7.json, merged into <prefix>.json.
SHARD_PREFIX = "7b_step200_musique_web_n4"
NUM_SHARDS = 8
json_file_paths = [f"{SHARD_PREFIX}_{i}.json" for i in range(NUM_SHARDS)]

# Concatenate the records of every shard, in shard order.
# Assumes each shard's top-level JSON value is a list of records — TODO confirm.
merged_data = []
for shard_path in json_file_paths:
    with open(shard_path, "r", encoding="utf-8") as shard_file:
        merged_data.extend(json.load(shard_file))

print(len(merged_data))

# NOTE(review): this writes to the current working directory, while the
# evaluation cell reads './evaluation_results/musique/7b_step200_musique_web_n4.json'.
# Confirm the file is moved there, or align the two paths.
with open(f"{SHARD_PREFIX}.json", "w", encoding="utf-8") as merged_file:
    json.dump(merged_data, merged_file, ensure_ascii=False, indent=2)


In [None]:
import json

# Load the merged predictions. The merge step writes UTF-8 with
# ensure_ascii=False, so read with an explicit UTF-8 encoding rather than the
# platform/locale default, which can fail or mis-decode non-ASCII answers.
with open('./evaluation_results/musique/7b_step200_musique_web_n4.json', 'r', encoding='utf-8') as f:
    combine_results = json.load(f)
print(len(combine_results))

# Items with no prediction at all; compute_f1/exact_match_score score these as 0.
count_none = sum(1 for item in combine_results if item['pred_answer'] is None)
print(count_none)

results = evaluate(combine_results)
print("Average F1:", results['avg_f1'])
print("Average EM:", results['avg_em'])