In [1]:
import itertools
import os
import random

import datasets
import pandas as pd
import ujson
from pywikidata import Entity
from tqdm.auto import tqdm

from kbqa.caches.ner_to_sentence_insertion import NerToSentenceInsertion
from kbqa.entity_linkink import build_mgenre_pipeline, EntitiesSelection

tqdm.pandas()  # registers `progress_apply` on pandas objects; prepare_data() below calls it

In [2]:
import requests
from joblib import Memory

# On-disk memoisation store used by the Wikidata search helper defined below.
memory = Memory(location='/tmp/cache', verbose=0)


@memory.cache
def get_wd_search_results(
    search_string: str,
    max_results: int = 500,
    language: str = 'en',
    mediawiki_api_url: str = "https://www.wikidata.org/w/api.php",
    user_agent: str = None,
    timeout: float = 30.0,
) -> list:
    """Return Wikidata entity ids matching ``search_string`` via ``wbsearchentities``.

    Pages through the MediaWiki ``continue`` offset until ``max_results`` ids have
    been collected or the API reports no further page. Results are memoised on disk
    by ``joblib.Memory``.

    Args:
        search_string: text to search for.
        max_results: maximum number of ids returned; values <= 0 return ``[]``
            without issuing a request.
        language: search language code.
        mediawiki_api_url: MediaWiki API endpoint.
        user_agent: ``User-Agent`` header value; defaults to ``"pywikidata"``.
        timeout: seconds to wait for each HTTP request before giving up.

    Returns:
        Entity ids in the order the API ranks them, at most ``max_results`` long.

    Raises:
        requests.HTTPError: on a non-2xx HTTP response.
        Exception: if the API reply does not report ``success == 1``.
    """
    if max_results <= 0:
        return []

    params = {
        'action': 'wbsearchentities',
        'language': language,
        'search': search_string,
        'format': 'json',
        # Never ask for more than needed; 50 is the per-request cap for regular clients.
        'limit': min(50, max_results),
    }
    headers = {'User-Agent': "pywikidata" if user_agent is None else user_agent}

    results = []
    offset = 0
    while True:
        params['continue'] = offset
        reply = requests.get(mediawiki_api_url, params=params, headers=headers, timeout=timeout)
        reply.raise_for_status()
        search_results = reply.json()

        # Error replies carry an 'error' object and no 'success' key; .get() makes them
        # raise the intended exception rather than a KeyError.
        if search_results.get('success') != 1:
            raise Exception('WD search failed')
        results.extend(item['id'] for item in search_results['search'])

        next_offset = search_results.get('search-continue')
        # Stop once enough ids are collected, there is no further page, or the offset
        # fails to advance (protects against an endless loop).
        if len(results) >= max_results or next_offset is None or next_offset <= offset:
            break
        offset = next_offset

    return results[:max_results]

In [10]:
def prepare_data(
    data_df,
    results_df,
    wd_search_results_top_k: int = 1,
    mgenre=None,
    ner=None,
    entities_selection=None,
    test_mode=False,
    seed=None,
):
    """Yield one JSON-serialisable record per (question, candidate answer entity) pair.

    Answer labels in ``results_df`` (every column whose name contains ``'answer_'``)
    are resolved to Wikidata ids with ``get_wd_search_results``. The predictions are
    joined with ``data_df`` on ``'question'``, and each joined row yields one record
    per deduplicated candidate id.

    Args:
        data_df: dataset split as a DataFrame with ``question``, ``id``,
            ``answerEntity``, ``questionEntity`` and ``complexityType`` columns
            (entity columns hold lists of dicts with a ``'name'`` key).
        results_df: predictions keyed by ``question`` with ``answer_*`` label
            columns. It is not modified.
        wd_search_results_top_k: Wikidata search hits kept per answer label.
        mgenre, ner, entities_selection: optional entity-linking pipeline. When all
            three are given (and ``test_mode`` is False), question entities are
            re-linked from the question text and no sampled negatives are added.
        test_mode: if True, candidates are only the resolved answer ids; golden
            answers and sampled negatives are not injected.
        seed: optional seed for sampling negative candidates. ``None`` keeps using
            the global ``random`` state.

    Yields:
        dict with keys ``id``, ``question``, ``answerEntity`` (one candidate id),
        ``questionEntity``, ``groundTruthAnswerEntity`` and ``complexityType``.
    """
    rng = random if seed is None else random.Random(seed)
    answers_cols = [c for c in results_df.columns if 'answer_' in c]

    answers_ids = results_df[answers_cols].progress_apply(
        lambda row: [
            get_wd_search_results(label, 5, language='en')[:wd_search_results_top_k]
            # Empty CSV cells load as NaN; searching for "nan" would inject bogus ids.
            for label in row.dropna().unique()[:5]
        ],
        axis=1
    ).apply(lambda ids_per_label: list(itertools.chain.from_iterable(ids_per_label)))

    # Derive a new frame instead of adding a column to the caller's DataFrame.
    df = results_df.assign(answers_ids=answers_ids).merge(data_df, on='question')

    for _, row in tqdm(df.iterrows(), total=df.index.size):
        golden_true_entity = [Entity(e['name']) for e in row['answerEntity']]

        # Keep only entity-typed mentions whose name looks like a Wikidata Q-id.
        question_entity = [
            Entity(e['name'])
            for e in row['questionEntity']
            if e['entityType'] == 'entity' and e['name'] not in [None, 'None', ''] and e['name'][0] == 'Q'
        ]

        if test_mode:
            candidates_ids = row['answers_ids']
        elif mgenre is None or ner is None or entities_selection is None:
            # Negatives: up to 5 random one-hop neighbours of the question entities
            # that are not golden answers.
            additional_candidates = [
                neighbour.idx
                for qe in question_entity
                for _, neighbour in qe.forward_one_hop_neighbours
                if neighbour not in golden_true_entity
            ]
            rng.shuffle(additional_candidates)
            candidates_ids = (
                additional_candidates[:5]
                + row['answers_ids']
                + [e.idx for e in golden_true_entity]
            )
        else:
            question_with_ner, entities_list = ner.entity_labeling(row['question'], True)
            mgenre_results = mgenre(question_with_ner)
            selected_entities = entities_selection(entities_list, mgenre_results)

            # get_wd_search_results returns id strings; wrap them in Entity so the
            # record below can read `.idx` (plain strings raised AttributeError).
            question_entity = [
                Entity(entity_id)
                for label in selected_entities
                for entity_id in get_wd_search_results(label, 1, language='en')[:1]
            ]
            candidates_ids = row['answers_ids'] + [e.idx for e in golden_true_entity]

        # Deduplicate while preserving first-seen order. Iterating a set of strings
        # gives a different order in each interpreter run (hash randomisation).
        for candidate_id in dict.fromkeys(candidates_ids):
            candidate_entity = Entity(candidate_id)
            yield {
                'id': row['id'],
                'question': row['question'],
                'answerEntity': [candidate_entity.idx],
                'questionEntity': [e.idx for e in question_entity],
                'groundTruthAnswerEntity': [e.idx for e in golden_true_entity],
                'complexityType': row['complexityType'],
            }


In [4]:
# Mintaka, English configuration (train / validation / test splits).
ds = datasets.load_dataset(path='AmazonScience/mintaka', name='en')
ds

Found cached dataset mintaka (/Users/ms/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'lang', 'question', 'answerText', 'category', 'complexityType', 'questionEntity', 'answerEntity'],
        num_rows: 14000
    })
    validation: Dataset({
        features: ['id', 'lang', 'question', 'answerText', 'category', 'complexityType', 'questionEntity', 'answerEntity'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['id', 'lang', 'question', 'answerText', 'category', 'complexityType', 'questionEntity', 'answerEntity'],
        num_rows: 4000
    })
})

In [5]:
results_validation_df: pd.DataFrame = pd.read_csv('./mintaka_results_validation.csv') # validation

# open(..., 'w') does not create parent directories; ensure the output folder exists
# so the cell also runs on a fresh checkout.
os.makedirs('to_subgraphs', exist_ok=True)
with open('to_subgraphs/mintaka_validation.jsonl', 'w') as f:
    for data_line in prepare_data(ds['validation'].to_pandas(), results_validation_df):
        f.write(ujson.dumps(data_line) + '\n')

100%|██████████| 2000/2000 [00:03<00:00, 573.52it/s]
100%|██████████| 2000/2000 [00:05<00:00, 365.25it/s]


In [6]:
results_train_df: pd.DataFrame = pd.read_csv('./mintaka_results_train.csv') # train

# open(..., 'w') does not create parent directories; ensure the output folder exists
# so the cell also runs on a fresh checkout.
os.makedirs('to_subgraphs', exist_ok=True)
with open('to_subgraphs/mintaka_train.jsonl', 'w') as f:
    for data_line in prepare_data(ds['train'].to_pandas(), results_train_df):
        f.write(ujson.dumps(data_line) + '\n')

100%|██████████| 14000/14000 [00:23<00:00, 593.79it/s]
100%|██████████| 14000/14000 [00:13<00:00, 1021.77it/s]


In [11]:
results_test_df: pd.DataFrame = pd.read_csv('./mintaka_results_test.csv') # test

# open(..., 'w') does not create parent directories; ensure the output folder exists
# so the cell also runs on a fresh checkout.
os.makedirs('to_subgraphs', exist_ok=True)
with open('to_subgraphs/mintaka_test.jsonl', 'w') as f:
    # test_mode: candidates are only the resolved answer ids (no golden answers injected).
    for data_line in prepare_data(ds['test'].to_pandas(), results_test_df, test_mode=True):
        f.write(ujson.dumps(data_line) + '\n')

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

In [12]:
df = pd.read_json('./to_subgraphs/mintaka_test.jsonl', lines=True)
# Number of distinct question ids that produced at least one candidate row.
df['id'].drop_duplicates().shape