In [1]:
import os
import random
from tqdm.auto import tqdm
import pandas as pd
from pathlib import Path

import numpy as np
import ujson

import torch

import datasets
from evaluateqa.mintaka import evaluate as evaluate_mintaka
from evaluateqa.mintaka import calculate_metrics_for_prediction
from evaluateqa.mintaka.evaluate import normalize_and_tokenize_text

tqdm.pandas()

In [3]:
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [4]:
torch.manual_seed(8)
random.seed(8)
np.random.seed(0)

In [5]:
biggraph_path = Path('/workspace/kbqa/PyTorchBigGraph')
biggraph_names_path   = biggraph_path / 'wikidata_translation_v1_names.json'
biggraph_vectors_path = biggraph_path / 'wikidata_translation_v1_vectors.npy'

with open(biggraph_names_path, 'r') as f:
    biggraph_names = ujson.load(f)
biggraph_vectors = np.load(biggraph_vectors_path)

In [6]:
biggraph_name2id = {}
for idx, name in enumerate(tqdm(biggraph_names)):
    try:
        biggraph_name2id[name.split('>')[0].split('/')[-1]] = idx
    except: 
        pass

len(biggraph_names), len(biggraph_name2id)

  0%|          | 0/78413185 [00:00<?, ?it/s]

(78413185, 75413999)

In [7]:
entities = {
    'London': biggraph_name2id['Q84'],
    'Paris': biggraph_name2id['Q90'],
    'UK': biggraph_name2id['Q145'],
    'Imperial College London': biggraph_name2id['Q189022'],
    'France': biggraph_name2id['Q142'],
}

In [8]:
print(f"{'Entity1':25} {'Entity2':25} {'Dot_product'}")
for key, val in entities.items():
    for key2, val2 in entities.items():
        dot_product = np.dot(
            biggraph_vectors[val],
            biggraph_vectors[val2]
        )
        print(f"{key:25} {key2:25} {dot_product}")

Entity1                   Entity2                   Dot_product
London                    London                    60.947242736816406
London                    Paris                     16.715190887451172
London                    UK                        15.051984786987305
London                    Imperial College London   16.485960006713867
London                    France                    5.931976795196533
Paris                     London                    16.715190887451172
Paris                     Paris                     47.60779571533203
Paris                     UK                        5.8029632568359375
Paris                     Imperial College London   10.893638610839844
Paris                     France                    16.393808364868164
UK                        London                    15.051984786987305
UK                        Paris                     5.8029632568359375
UK                        UK                        44.623741149902344
UK             

In [9]:
test_df = pd.read_json('test_beam_search_preds_mintaka_with_types.json')
test_df.head()

Unnamed: 0,id,lang,question,answerText,category,complexityType,questionEntity,answerEntity,generated_text,sequences_scores,generated_entities,answerRetrievedType,filtered_by_type_preds
0,fae46b21,en,What man was a famous American author and also...,Mark Twain,history,intersection,"[{'name': 'Q1497', 'entityType': 'entity', 'la...","[{'name': 'Q7245', 'label': 'Mark Twain'}]","[Edgar Allan Poe, Ernest Hemingway, Charles Di...","[-0.2734780908, -0.3756849766, -0.418252229700...","[Q16867, Q23434, Q5686, Q131149, Q34597, Q3616...",Q5,"[Q16867, Q23434, Q5686, Q131149, Q34597, Q3616..."
1,bc8713cc,en,How many Academy Awards has Jake Gyllenhaal be...,1,movies,count,"[{'name': 'Q133313', 'entityType': 'entity', '...","[{'name': 'Q106291', 'label': 'Academy Award f...","[1, 2, 3, 4, 5, 11, 6, 0, 8, 7, 9, 10, 13, 12,...","[-0.6568749547, -0.7941160798, -0.851152122, -...","[1, 2, 3, 4, 5, 11, 6, 0, 8, 7, 9, 10, 13, 12,...",Number,"[1, 2, 3, 4, 5, 11, 6, 0, 8, 7, 9, 10, 13, 12,..."
2,d2a03f72,en,"Who is older, The Weeknd or Drake?",Drake,music,comparative,"[{'name': 'Q2121062', 'entityType': 'entity', ...","[{'name': 'Q33240', 'label': 'Drake'}]","[Drake, The Weeknd, Cody Jarrett, Dwight D. Ei...","[-0.0174380932, -0.8993775845, -1.415274024, -...","[Q7559, Q2121062, Q5140439, Q9916, Q713099, Q5...",Q5,"[Q2121062, Q5140439, Q9916, Q713099, Q513019, ..."
3,9a296167,en,How many children did Donald Trump have?,5,history,count,"[{'name': 'Q22686', 'entityType': 'entity', 'l...","[{'name': 'Q3713655', 'label': 'Donald Trump J...","[2, 3, 4, 5, 6, 1, 8, 9, 0, 7, 11, 10, 6 child...","[-0.49233829980000005, -1.0202715397, -1.06337...","[2, 3, 4, 5, 6, 1, 8, 9, 0, 7, 11, 10, Q348559...",Number,"[2, 3, 4, 5, 6, 1, 8, 9, 0, 7, 11, 10, 13]"
4,e343ad26,en,Is the main hero in Final Fantasy IX named Kuja?,No,videogames,yesno,"[{'name': 'Q474573', 'entityType': 'entity', '...",[],"[Yes, No, Yuna, Yuna, Yuna, Yuna and Kuja are ...","[-0.3390540481, -0.3550684452, -1.4538880587, ...","[True, False, None, None, None, None, None, No...",yesno,[]


In [10]:
test_df['answerRetrievedType'].value_counts().head(10)

answerRetrievedType
Q5          1008
Number       820
yesno        544
Q11424       244
Q7889        185
Q3624078      87
Q35657        79
Q482994       69
Q7725634      65
Q1093829      52
Name: count, dtype: int64

In [11]:
def rerank_candidates_by_biggraph(candidates, question_entities):
    if len(candidates) <= 0:
        return []

    scores = []
    for candidate_id in candidates:
        vecid = biggraph_name2id.get(candidate_id)
        dot_product_score = 0
        if vecid is not None:
            for question_entity in question_entities:
                if question_entity['entityType'] != 'entity':
                    continue
                question_entity_id = question_entity['name']
                qvecid = biggraph_name2id.get(question_entity_id)

                if qvecid is not None:
                    dot_product_score += np.dot(
                        biggraph_vectors[vecid],
                        biggraph_vectors[qvecid],
                    )
        scores.append(dot_product_score)

    return np.array(candidates)[np.argsort(scores)[::-1]].tolist()

In [12]:
dataset_split = 'test'
dataset = datasets.load_dataset('AmazonScience/mintaka')

def print_eval(generated_answers=None, mode='kg', df=None, groupbycols=['complexityType']):
    if df is None:
        if not isinstance(generated_answers, dict):
            answers = dict(zip(dataset[dataset_split]['id'], generated_answers))
        else:
            answers = generated_answers

        results_kg = evaluate_mintaka(
            predictions=answers,
            split=dataset_split,
            mode=mode,
        )
    else:
        results_kg = evaluate_mintaka(
            df_with_predictions=df,
            split=dataset_split,
            mode=mode,
            groupbycols=groupbycols,
        )
    
    if 'answerRetrievedType' in results_kg:
        items = sorted(
            results_kg['answerRetrievedType'].items(),
            key=lambda item: -item[1]['hits1 Number Correct Answer Of'][1]
        )[:10]
        # items = [(f"{key} ({Entity(key).label if 'Q' in key[:1] else ''})", val) for key, val in items]
        results_kg['answerRetrievedType'] = dict(items)

    print(f"{'Group':13s}  {'Hits@1':6s} (Correct Of Total)")
    print(f"{'All':13s}= {results_kg['All']['hits1']:2.4f} ({results_kg['All']['hits1 Number Correct Answer Of'][0]:4d} Of {results_kg['All']['hits1 Number Correct Answer Of'][1]:4d})", end='\n\n')
    for key in results_kg.keys():
        if 'All' == key:
            continue

        for key, val in results_kg[key].items():
            print(f"{key:13s}= {val['hits1']:2.4f} ({val['hits1 Number Correct Answer Of'][0]:4d} Of {val['hits1 Number Correct Answer Of'][1]:4d})")
        print('')
    return results_kg

No config specified, defaulting to: mintaka/en
Found cached dataset mintaka (/root/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d)


  0%|          | 0/3 [00:00<?, ?it/s]

In [13]:
test_df['isAnswerEntity'] = test_df['answerEntity'].apply(lambda entities: len(entities) > 0)

preds = test_df.progress_apply(
    lambda row: rerank_candidates_by_biggraph(row['generated_entities'], row['questionEntity'])[0] if row['isAnswerEntity'] else row['generated_entities'][0],
    axis=1
)

df = calculate_metrics_for_prediction(
    dict(zip(dataset[dataset_split]['id'], preds)),
    dataset_split,
    'kg',
)
df['answerRetrievedType'] = test_df['answerRetrievedType']
results_kg = print_eval(df=df, groupbycols=['complexityType', 'answerRetrievedType'])

  0%|          | 0/4000 [00:00<?, ?it/s]

Group          Hits@1 (Correct Of Total)
All          = 0.1757 ( 703 Of 4000)

comparative  = 0.4500 ( 180 Of  400)
count        = 0.0225 (   9 Of  400)
difference   = 0.1175 (  47 Of  400)
generic      = 0.1300 ( 104 Of  800)
intersection = 0.1900 (  76 Of  400)
multihop     = 0.0525 (  21 Of  400)
ordinal      = 0.0550 (  22 Of  400)
superlative  = 0.1300 (  52 Of  400)
yesno        = 0.4800 ( 192 Of  400)

Q5           = 0.1915 ( 193 Of 1008)
Number       = 0.0341 (  28 Of  820)
yesno        = 0.4651 ( 253 Of  544)
Q11424       = 0.1066 (  26 Of  244)
Q7889        = 0.0865 (  16 Of  185)
Q3624078     = 0.2299 (  20 Of   87)
Q35657       = 0.2532 (  20 Of   79)
Q482994      = 0.0145 (   1 Of   69)
Q7725634     = 0.0615 (   4 Of   65)
Q1093829     = 0.2692 (  14 Of   52)



In [14]:
preds = test_df.progress_apply(
    lambda row: rerank_candidates_by_biggraph(row['filtered_by_type_preds'], row['questionEntity'])[0] if row['isAnswerEntity'] else row['generated_entities'][0],
    axis=1
)

df = calculate_metrics_for_prediction(
    dict(zip(dataset[dataset_split]['id'], preds)),
    dataset_split,
    'kg',
)
df['answerRetrievedType'] = test_df['answerRetrievedType']
results_kg = print_eval(df=df, groupbycols=['complexityType', 'answerRetrievedType'])

  0%|          | 0/4000 [00:00<?, ?it/s]

Group          Hits@1 (Correct Of Total)
All          = 0.1872 ( 749 Of 4000)

comparative  = 0.4250 ( 170 Of  400)
count        = 0.0350 (  14 Of  400)
difference   = 0.1200 (  48 Of  400)
generic      = 0.1525 ( 122 Of  800)
intersection = 0.1950 (  78 Of  400)
multihop     = 0.0600 (  24 Of  400)
ordinal      = 0.0825 (  33 Of  400)
superlative  = 0.1700 (  68 Of  400)
yesno        = 0.4800 ( 192 Of  400)

Q5           = 0.1974 ( 199 Of 1008)
Number       = 0.0402 (  33 Of  820)
yesno        = 0.4651 ( 253 Of  544)
Q11424       = 0.1025 (  25 Of  244)
Q7889        = 0.1297 (  24 Of  185)
Q3624078     = 0.2759 (  24 Of   87)
Q35657       = 0.3165 (  25 Of   79)
Q482994      = 0.0580 (   4 Of   69)
Q7725634     = 0.0923 (   6 Of   65)
Q1093829     = 0.2308 (  12 Of   52)

