In [26]:
import os
import random
from utils import EXP_ROOT, load_jsonl, dump_jsonl

# Language code that prefixes every file read below (and the files written later).
lang = "en"

# Directory the finished datasets are written to.
dataset_root = os.path.join(EXP_ROOT, "datasets/kg-datasets/ja-0.5/eval_qa/99_adaxeval_datasets")
# Directory holding the outputs of the filtering stage.
filter_root = os.path.join(EXP_ROOT, "datasets/kg-datasets/ja-0.5/eval_qa/04_filtering")

# Load the three filtering-stage files in order: the in-language requests, the
# in-language generations, and the cross-lingual generations.
inlang_requests, inlang_generations, crosslang_generations = map(
    load_jsonl,
    (
        os.path.join(filter_root, f"{lang}_requests.jsonl"),
        os.path.join(filter_root, f"{lang}_generation.jsonl"),
        os.path.join(filter_root, f"{lang}_generation.crosslang.jsonl"),
    ),
)

In [27]:
# IDs of requests whose cross-lingual generation carries a score of exactly 2.
matched_reqids = {
    generation["request_id"]
    for generation in crosslang_generations
    if generation["generation"]['score'] == 2
}
# In-language requests restricted to those IDs, keeping their original order.
instances = [
    request
    for request in inlang_requests
    if request["request_id"] in matched_reqids
]
print("Number of matched request IDs:", len(matched_reqids))

Number of matched request IDs: 3237


In [29]:
## Knowledge memorization dataset
# Builds one multiple-choice record per matched request, using the
# fill-in-the-blank prompt as the question, and writes the records to
# `{lang}_knowledge_memorization.jsonl`. IDs of requests rejected here are
# collected in `filtered_requests`.
items = []
filtered_requests = set()
for request in instances:
    request_id = request["request_id"]
    metadata = request["metadata"]
    correct_answer = metadata['answer']
    distractors = metadata["distractors"]
    options = [correct_answer] + distractors

    # Reject requests where the answer or any distractor is an empty string.
    if any(option == "" for option in options):
        print(f"Skipping item {request_id} due to empty option.")
        filtered_requests.add(request_id)
        continue

    # The prompt must contain exactly one [BLANK] placeholder.
    if metadata["fill_in_blank"].count("[BLANK]") != 1:
        print(f"Skipping item {request_id} due to incorrect prompt format.")
        filtered_requests.add(request_id)
        continue

    # Shuffle with a private generator seeded by the request ID. This yields the
    # same permutation as seeding the module-level generator with that ID, but
    # leaves the global `random` state untouched for the rest of the notebook.
    random.Random(request_id).shuffle(options)

    # Split once; piece 0 is the document ID and piece 1 the sentence index.
    id_parts = request_id.split("_sentid:")
    record = {
        "id": request_id,
        "question": metadata["fill_in_blank"],
        "options": options,
        # Position of the correct answer after shuffling (first match).
        "answer_idx": options.index(correct_answer),
        "metadata": {
            "docid": id_parts[0],
            "sentid": int(id_parts[1]),
            "correct_answer": correct_answer,
            "triple": metadata['triple'],
            "sentence": metadata['sentence']
        }
    }
    items.append(record)

dump_jsonl(items, os.path.join(dataset_root, f"{lang}_knowledge_memorization.jsonl"))
# NOTE(review): the label says "filtered", but the value printed is the number
# of records kept (len(items)), not the number skipped. The text is left as-is
# so it keeps matching the recorded output; consider renaming it on the next run.
print("Number of filtered requests:", len(items))

Skipping item kakoronbunshu@@32/5/32_5_448_sentid:7 due to incorrect prompt format.
Number of filtered requests: 3236


In [30]:
## Knowledge generalization dataset
# Builds one multiple-choice record per matched request, using the request's
# `question` field as the question, and writes the records to
# `{lang}_knowledge_generalization.jsonl`.
# Depends on `filtered_requests` having been populated by the memorization cell,
# so run that cell first.
items = []
for request in instances:
    request_id = request["request_id"]

    # Skip requests that were already rejected when the memorization set was built.
    if request_id in filtered_requests:
        continue

    metadata = request["metadata"]
    correct_answer = metadata['answer']
    distractors = metadata["distractors"]
    options = [correct_answer] + distractors

    # Shuffle with a private generator seeded by the request ID. This yields the
    # same permutation as seeding the module-level generator with that ID, but
    # leaves the global `random` state untouched for the rest of the notebook.
    random.Random(request_id).shuffle(options)

    # Split once; piece 0 is the document ID and piece 1 the sentence index.
    id_parts = request_id.split("_sentid:")
    record = {
        "id": request_id,
        "question": metadata["question"],
        "options": options,
        # Position of the correct answer after shuffling (first match).
        "answer_idx": options.index(correct_answer),
        "metadata": {
            "docid": id_parts[0],
            "sentid": int(id_parts[1]),
            "correct_answer": correct_answer,
            "triple": metadata['triple'],
            "sentence": metadata['sentence']
        }
    }
    items.append(record)

dump_jsonl(items, os.path.join(dataset_root, f"{lang}_knowledge_generalization.jsonl"))
# NOTE(review): the label says "filtered", but the value printed is the number
# of records kept (len(items)), not the number skipped. The text is left as-is
# so it keeps matching the recorded output; consider renaming it on the next run.
print("Number of filtered requests:", len(items))

Number of filtered requests: 3236
