In [1]:
import pandas as pd
from pywikidata import Entity
from tqdm.auto import tqdm
import ujson
import itertools
import random
import datasets
from wd_api import get_wd_search_results
from pathlib import Path

# Hook tqdm into pandas (presumably to enable `progress_apply`-style helpers).
# NOTE(review): none of the visible cells use those helpers -- confirm whether this call is still needed.
tqdm.pandas()

In [2]:
# Model predictions for the validation and test questions, stored as JSON.
valid_predictions = pd.read_json('./mintaka_mixtral_50_valid_predictions.json')
test_predictions = pd.read_json('./mintaka_mixtral_50_test_predictions.json')

# The Mintaka dataset supplies the annotations (answer text, entities,
# complexity type) that are attached to every prediction row below.
ds = datasets.load_dataset("AmazonScience/mintaka")

# Join predictions and annotations on the question id and question text.
valid_df = valid_predictions.merge(
    ds['validation'].to_pandas(),
    on=['id', 'question'],
)

# The answers column is renamed to `model_answer` first so that both frames
# expose the same column name (the rename is a no-op if the column is absent).
test_df = test_predictions.rename(
    columns={'mixtral answers': 'model_answer'},
).merge(
    ds['test'].to_pandas(),
    on=['id', 'question'],
)

valid_df.head()

No config specified, defaulting to: mintaka/en
Found cached dataset mintaka (/root/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d)


  0%|          | 0/3 [00:00<?, ?it/s]

Unnamed: 0,id,question,model_answer,lang,answerText,category,complexityType,questionEntity,answerEntity
0,9ace9041,What is the fourth book in the Twilight series?,"[Breaking Dawn, Breaking Dawn, Breaking Dawn, ...",en,Breaking Dawn,books,ordinal,"[{'name': 'Q44523', 'entityType': 'entity', 'l...","[{'name': 'Q53945', 'label': 'Breaking Dawn'}]"
1,88bdb808,How many games are in the Uncharted series?,"[4, 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 6, 7, 10,...",en,6,videogames,count,"[{'name': 'Q1064135', 'entityType': 'entity', ...","[{'name': 'Q17150', 'label': 'Uncharted: Drake..."
2,ecfd471d,"As of 2015, which group held the record for th...","[U2, U2, U2, The Recording Academy's Grammy Ha...",en,U2,music,generic,"[{'name': 'Q41254', 'entityType': 'entity', 'l...","[{'name': 'Q396', 'label': 'U2'}]"
3,5d8dc3ff,Who is the oldest person to ever win an Academ...,"[James Ivory, James Ivory, James Ivory, James ...",en,James Ivory,movies,superlative,"[{'name': 'Q19020', 'entityType': 'entity', 'l...","[{'name': 'Q51577', 'label': 'James Ivory'}]"
4,118daa85,Which Mario Kart games do not feature Link as ...,"[Super Mario Kart, Mario Kart 64, Mario Kart: ...",en,"Super Mario Kart, Mario Kart 64, Mario Kart: S...",videogames,difference,"[{'name': 'Q188196', 'entityType': 'entity', '...","[{'name': 'Q1061560', 'label': 'Super Mario Ka..."


In [2]:
def label_to_entity(label: str, top_k: int = 3) -> "list[str] | None":
    """Link a textual label to WikiData entity IDs.

    The look-up goes through ``get_wd_search_results`` (the elasticsearch
    Wikimedia public API). Only the English language (en) is supported.

    Up to two searches are made: one with the label exactly as given and one
    with all single/double quote characters removed and surrounding
    whitespace stripped. The second search is skipped when the first one
    succeeded and the cleaned label is identical to the original, because it
    would repeat the very same request. If the first search failed, the
    second one is always attempted, so it doubles as a retry.

    Search failures are handled on a best-effort basis: a failing search is
    ignored as long as the other one succeeds.

    Parameters
    ----------
    label : str
        label of entity to search
    top_k : int, optional
        top K results from WikiData, by default 3

    Returns
    -------
    list[str] | None
        De-duplicated entity IDs in search order, at most ``top_k`` items
        (an empty list when the search worked but found nothing), or None
        if every search attempt raised.
    """
    collected = []
    any_search_succeeded = False

    try:
        collected.extend(get_wd_search_results(label, top_k, language='en')[:top_k])
        any_search_succeeded = True
    except Exception:
        # Best effort: the cleaned-label search below is still attempted.
        # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
        pass

    # Quotes around a generated answer frequently prevent an exact match.
    if isinstance(label, str):
        cleaned_label = label.replace("\"", "").replace("\'", "").strip()
    else:
        cleaned_label = label

    if not any_search_succeeded or cleaned_label != label:
        try:
            collected.extend(
                get_wd_search_results(cleaned_label, top_k, language='en')[:top_k]
            )
            any_search_succeeded = True
        except Exception:
            # Keep whatever the first search returned instead of discarding it.
            pass

    if not any_search_succeeded:
        return None

    # dict.fromkeys removes duplicates while preserving the search ranking.
    return list(dict.fromkeys(collected))[:top_k]


def data_to_subgraphs(df):
    """Yield one record per (question, candidate answer entity) pair.

    For every row whose ``complexityType`` is neither ``'count'`` nor
    ``'yesno'``, each distinct label in ``row['model_answer']`` is linked to
    WikiData entity IDs with ``label_to_entity`` and one dict is yielded per
    linked entity. Rows of the two skipped types produce no records
    (presumably because their answers are not entities -- confirm).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame providing the columns ``id``, ``question``, ``model_answer``
        (iterable of answer labels), ``complexityType``, ``questionEntity``
        (iterable of dicts with ``name`` and ``entityType``) and
        ``answerEntity`` (iterable of dicts with ``name``).

    Yields
    ------
    dict
        Keys: ``id``, ``question``, ``generatedAnswer``, ``answerEntity``,
        ``answerEntityLabel``, ``questionEntity``,
        ``groundTruthAnswerEntity`` and ``complexityType``.
    """
    for _, row in tqdm(df.iterrows(), total=df.index.size):
        if row['complexityType'] in ('count', 'yesno'):
            continue

        # Only annotations typed as 'entity' carry an entity ID in `name`.
        question_entity_ids = [
            e['name'] for e in row['questionEntity'] if e['entityType'] == 'entity'
        ]

        # dict.fromkeys drops repeated answers while keeping their order.
        for candidate_label in dict.fromkeys(row['model_answer']):
            # label_to_entity can return None when its look-up fails; treat
            # that as "no candidates" instead of crashing the whole export.
            for candidate_entity_id in label_to_entity(candidate_label) or []:
                candidate_entity = Entity(candidate_entity_id)
                yield {
                    'id': row['id'],
                    'question': row['question'],
                    'generatedAnswer': [candidate_label],
                    'answerEntity': [candidate_entity.idx],
                    'answerEntityLabel': [candidate_entity.label],
                    'questionEntity': question_entity_ids,
                    'groundTruthAnswerEntity': [e['name'] for e in row['answerEntity']],
                    'complexityType': row['complexityType'],
                }

**Validation** predictions can be used to train the reranker model.

**Test** predictions are used for reranking and for the subsequent evaluation.

In [None]:
# Output directory for the JSONL exports; created on demand.
output_path = Path('./to_subgraphs')
output_path.mkdir(parents=True, exist_ok=True)

# Validation split: one JSON object per line, one line per record yielded
# by data_to_subgraphs.
with open(output_path / 'mintaka_mixtral_valid.jsonl', 'w') as f:
    for data_line in data_to_subgraphs(valid_df):
        print(ujson.dumps(data_line), file=f)

In [None]:
# Test split: same JSONL layout as the validation export (one JSON object
# per line, one line per record yielded by data_to_subgraphs).
with open(output_path / 'mintaka_mixtral_test.jsonl', 'w') as f:
    for data_line in data_to_subgraphs(test_df):
        print(ujson.dumps(data_line), file=f)