In [1]:
from tqdm import tqdm
import pandas as pd
import spacy
import torch
import pickle
import numpy as np
from nltk.stem.porter import *

tqdm.pandas()

import sys
sys.path.insert(0,'../../')

%load_ext autoreload
%autoreload 2

from wikidata.wikidata_entity_to_label import WikidataEntityToLabel
from wikidata.wikidata_redirects import WikidataRedirectsCache
from metrics import recall

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = pd.read_csv('data_third_iteration.csv')
dataset.head()

Unnamed: 0,subject,property,object,question,subject_text,flag,subject_text_add,subject_text_all
0,Q7358590,P20,Q1637790,Where did roger marquis die,Roger Marquis,1.0,,Roger Marquis
1,Q154335,P509,Q12152,what was the cause of death of yves klein,"Yves Klein, The Void (artwork)",0.0,,"Yves Klein, The Void (artwork)"
2,Q2747238,P413,Q5059480,What position does carlos gomez play?,"Carlos Gómez, Carlos Gomez, Gómez, Carlos",0.0,,"Carlos Gómez, Carlos Gomez, Gómez, Carlos"
3,Q62498,P21,Q6581097,how does engelbert zaschka identify,"Engelbert Zaschka, Englebert Zaschka, Rotation...",0.0,,"Engelbert Zaschka, Englebert Zaschka, Rotation..."
4,Q182485,P413,Q1143358,what position does pee wee reese play in baseball,"Pee Wee Reese, Harold H. Reese, Harold Henry &...",0.0,,"Pee Wee Reese, Harold H. Reese, Harold Henry &..."


In [3]:
df = pd.read_pickle('./WD_SQ_test_with_entities_rerank_v1.pkl')
df.head(3)

Unnamed: 0,S,P,O,Q,Q_with_NER,q_without_ner_mayhewsw,q_with_ner_mayhewsw,Q_with_NER_entities,q_without_ner_mayhewsw_entities,q_with_ner_mayhewsw_entities,entities_after_rerank_v1
0,Q7358590,P20,Q1637790,Where did roger marquis die,Where Did [START] Roger Marquis [END] Die,[START] Where did Roger Marquis die [END],Where did [START] Roger Marquis [END] Die,"[{'id': 'Q7358590', 'texts': ['Roger Marquis >...","[{'id': 'Q8012493', 'texts': ['List of stories...","[{'id': 'Q7358590', 'texts': ['Roger Marquis >...","[Q7358590, Q6598240, Q8012493, Q8068232, Q4993..."
1,Q154335,P509,Q12152,what was the cause of death of yves klein,What Was The Cause Of Death Of [START] Yves Kl...,[START] What was the cause of death of Yves Kl...,What was the cause of death of [START] Yves Kl...,"[{'id': 'Q154335', 'texts': ['Yves Klein >> en...","[{'id': 'Q154335', 'texts': ['Yves Klein >> en...","[{'id': 'Q154335', 'texts': ['Yves Klein >> en...","[Q154335, Q1931388, Q633234]"
2,Q2747238,P413,Q5059480,What position does carlos gomez play?,What Position Does [START] Carlos Gomez [END] ...,[START] What position does Carlos Gomez play? ...,What position does [START] Carlos Gomez [END] ...,"[{'id': 'Q2747238', 'texts': ['Carlos Gómez >>...","[{'id': 'Q2747238', 'texts': ['Carlos Gómez >>...","[{'id': 'Q2747238', 'texts': ['Carlos Gómez >>...","[Q2747238, Q5042155, Q203210, Q62592284, Q5555..."


In [4]:
df = dataset.merge(df, left_on='question', right_on='Q')[[
    'subject', 'property', 'object',
    'question',
    'subject_text',
    'Q_with_NER_entities',
]]

In [5]:
def check_label_fn(label, entities_list):
    if label in entities_list:
        return True
    return False


def entities_selection(q, preds, ner_model, check_label_fn=check_label_fn):
    final_preds = []

    doc = ner_model(q)
    entities = [e.text.lower() for e in doc.ents]

    for pred in preds:
        label = None
        for text in pred['texts']:
            _label, lang = text.split(' >> ')
            if lang == 'en':
                label = _label

        if label is not None:
            label = label.lower()
            if check_label_fn(label, entities):
                if isinstance(pred['scores'], torch.Tensor):
                    pred['scores'] = pred['scores'].cpu().numpy().tolist()
                    pred['score'] = pred['score'].cpu().numpy().tolist()
                final_preds.append(pred)
        
    return final_preds


In [6]:
ner_model = spacy.load('../../../ner_model')

df['baseline_entities_selection'] = df.progress_apply(
    lambda row: entities_selection(row['question'], row['Q_with_NER_entities'], ner_model),
    axis=1
)

100%|██████████| 5676/5676 [02:25<00:00, 38.94it/s]


## Without selection

In [34]:
def entities_to_labels(entities):
    labels = []
    for e in entities:
        for text in e['texts']:
            label, lang = text.split(' >> ')
            if lang == 'en':
                labels.append(label)
    return labels

targets = df['subject_text'].fillna('').apply(
    lambda s: s.split(', ')
).values.tolist()
candidates = df['Q_with_NER_entities'].apply(entities_to_labels).values.tolist()
candidates = [c[:1] for c in candidates]

recall_with_redirects = recall(
    targets,
    candidates
)
print("recall_with_redirects: ", recall_with_redirects)


recall_without_redirects = recall(
    [t[:1] for t in targets],
    candidates
)
print("recall_without_redirects: ", recall_without_redirects)


not_missed_candidates_idxs = [idx for idx, cl in enumerate(candidates) if cl != []]

recall_with_redirects_without_missed = recall(
    np.array(targets, dtype=object)[not_missed_candidates_idxs].tolist(),
    np.array(candidates, dtype=object)[not_missed_candidates_idxs].tolist()
)
print("recall_with_redirects_without_missed: ", recall_with_redirects_without_missed)

recall_without_redirects_without_missed = recall(
    np.array([t[:1] for t in targets])[not_missed_candidates_idxs].tolist(),
    np.array(candidates)[not_missed_candidates_idxs].tolist()
)
print("recall_without_redirects_without_missed: ", recall_without_redirects_without_missed)

recall: 100%|██████████| 5676/5676 [00:00<00:00, 125340.87it/s]


recall_with_redirects:  0.755461592670895


recall: 100%|██████████| 5676/5676 [00:00<00:00, 428165.70it/s]
  np.array(targets)[not_missed_candidates_idxs].tolist(),
  np.array(candidates)[not_missed_candidates_idxs].tolist()


recall_without_redirects:  0.7274489076814659


recall: 100%|██████████| 5642/5642 [00:00<00:00, 207584.90it/s]
  np.array(candidates)[not_missed_candidates_idxs].tolist()


recall_with_redirects_without_missed:  0.760014179369018


recall: 100%|██████████| 5642/5642 [00:00<00:00, 432421.44it/s]

recall_without_redirects_without_missed:  0.7318326834455867





## Selection baseline

In [35]:
candidates = df['baseline_entities_selection'].apply(entities_to_labels).values.tolist()
candidates = [c[:1] for c in candidates]

recall_with_redirects = recall(
    targets,
    candidates
)
print("recall_with_redirects: ", recall_with_redirects)


recall_without_redirects = recall(
    [t[:1] for t in targets],
    candidates
)
print("recall_without_redirects: ", recall_without_redirects)


not_missed_candidates_idxs = [idx for idx, cl in enumerate(candidates) if cl != []]

recall_with_redirects_without_missed = recall(
    np.array(targets, dtype=object)[not_missed_candidates_idxs].tolist(),
    np.array(candidates, dtype=object)[not_missed_candidates_idxs].tolist()
)
print("recall_with_redirects_without_missed: ", recall_with_redirects_without_missed)

recall_without_redirects_without_missed = recall(
    np.array([t[:1] for t in targets])[not_missed_candidates_idxs].tolist(),
    np.array(candidates)[not_missed_candidates_idxs].tolist()
)
print("recall_without_redirects_without_missed: ", recall_without_redirects_without_missed)

recall: 100%|██████████| 5676/5676 [00:00<00:00, 290628.94it/s]


recall_with_redirects:  0.5181465821000705


recall: 100%|██████████| 5676/5676 [00:00<00:00, 562903.31it/s]
  np.array(targets)[not_missed_candidates_idxs].tolist(),
  np.array(candidates)[not_missed_candidates_idxs].tolist()


recall_without_redirects:  0.5042283298097252


recall: 100%|██████████| 3633/3633 [00:00<00:00, 272304.84it/s]
  np.array(candidates)[not_missed_candidates_idxs].tolist()


recall_with_redirects_without_missed:  0.8095238095238095


recall: 100%|██████████| 3633/3633 [00:00<00:00, 492291.75it/s]

recall_without_redirects_without_missed:  0.7877786952931461





## Selection v1

In [8]:
from functools import lru_cache

stemmer = PorterStemmer()

@lru_cache(maxsize=8192)
def label_format_fn(label, stemmer=stemmer):
    ' '.join([stemmer.stem(str(token)) for token in ner_model(label)])


def check_label_fn(label, entities_list):
    label = label_format_fn(label)
    for entity in entities_list:
        entity = label_format_fn(entity)
        if label == entity:
            return True
    return False


df['entities_selection_v1'] = df.progress_apply(
    lambda row: entities_selection(row['question'], row['Q_with_NER_entities'], ner_model, check_label_fn),
    axis=1
)

100%|██████████| 5676/5676 [10:08<00:00,  9.32it/s]


In [32]:
candidates = df['entities_selection_v1'].apply(entities_to_labels).values.tolist()
candidates = [c[:1] for c in candidates]

recall_with_redirects = recall(
    targets,
    candidates,
    label_preprocessor_fn=label_format_fn
)
print("recall_with_redirects: ", recall_with_redirects)


recall_without_redirects = recall(
    [t[:1] for t in targets],
    candidates,
    label_preprocessor_fn=label_format_fn
)
print("recall_without_redirects: ", recall_without_redirects)


not_missed_candidates_idxs = [idx for idx, cl in enumerate(candidates) if cl != []]

recall_with_redirects_without_missed = recall(
    np.array(targets, dtype=object)[not_missed_candidates_idxs].tolist(),
    np.array(candidates, dtype=object)[not_missed_candidates_idxs].tolist(),
    label_preprocessor_fn=label_format_fn
)
print("recall_with_redirects_without_missed: ", recall_with_redirects_without_missed)

recall_without_redirects_without_missed = recall(
    np.array([t[:1] for t in targets])[not_missed_candidates_idxs].tolist(),
    np.array(candidates)[not_missed_candidates_idxs].tolist(),
    label_preprocessor_fn=label_format_fn
)
print("recall_without_redirects_without_missed: ", recall_without_redirects_without_missed)

recall: 100%|██████████| 5676/5676 [07:24<00:00, 12.76it/s] 


recall_with_redirects:  0.9793868921775899


recall: 100%|██████████| 5676/5676 [01:10<00:00, 80.75it/s] 
  np.array(targets)[not_missed_candidates_idxs].tolist(),
  np.array(candidates)[not_missed_candidates_idxs].tolist(),


recall_without_redirects:  0.9793868921775899


recall: 100%|██████████| 5559/5559 [05:10<00:00, 17.90it/s]
  np.array(candidates)[not_missed_candidates_idxs].tolist(),


recall_with_redirects_without_missed:  1.0


recall: 100%|██████████| 5559/5559 [01:06<00:00, 84.00it/s] 

recall_without_redirects_without_missed:  1.0





In [25]:
def mgenre_entities_view(preds):
    final_preds = []
    for pred in preds:
        label = None
        for text in pred['texts']:
            _label, lang = text.split(' >> ')
            if lang == 'en':
                label = _label

        if label is not None:
            label = label.lower()
            final_preds.append({'texts': pred['texts'], 'id': pred['id']})
    return final_preds


_df = df.copy()
for col in ['Q_with_NER_entities', 'baseline_entities_selection', 'entities_selection_v1']:
    _df[col] = _df[col].apply(mgenre_entities_view)
    
_df.to_excel('WDSQ_EL_selection.xlsx')