# Prepare dataset 

(1-hop subset from question's entities)

In [1]:
# OPTIONAL: load the prefix tree (trie), you need to additionally download
# https://huggingface.co/facebook/mgenre-wiki/blob/main/trie.py and 
# https://huggingface.co/facebook/mgenre-wiki/blob/main/titles_lang_all105_trie_with_redirect.pkl
# that is fast but memory inefficient prefix tree (trie) -- it is implemented with nested python `dict`
# NOTE: loading this map may take up to 10 minutes and occupy a lot of RAM!
# import pickle
# from trie import Trie
# with open("titles_lang_all105_marisa_trie_with_redirect.pkl", "rb") as f:
#     trie = Trie.load_from_dict(pickle.load(f))

# or a memory efficient but a bit slower prefix tree (trie) -- it is implemented with `marisa_trie` from
# https://huggingface.co/facebook/mgenre-wiki/blob/main/titles_lang_all105_marisa_trie_with_redirect.pkl
# from genre.trie import MarisaTrie
# with open("titles_lang_all105_marisa_trie_with_redirect.pkl", "rb") as f:
#     trie = pickle.load(f)

In [2]:
!wget -nc https://raw.githubusercontent.com/facebookresearch/GENRE/main/genre/trie.py
!wget -nc http://dl.fbaipublicfiles.com/GENRE/titles_lang_all105_trie_with_redirect.pkl
!wget -nc https://dl.fbaipublicfiles.com/GENRE/lang_title2wikidataID-normalized_with_redirect.pkl
!wget -nc https://dl.fbaipublicfiles.com/GENRE/wikidataID2lang_title-normalized_with_redirect.pkl

File ‘trie.py’ already there; not retrieving.

File ‘titles_lang_all105_trie_with_redirect.pkl’ already there; not retrieving.

File ‘lang_title2wikidataID-normalized_with_redirect.pkl’ already there; not retrieving.

File ‘wikidataID2lang_title-normalized_with_redirect.pkl’ already there; not retrieving.



In [3]:
import pickle
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from trie import Trie


mgenre_tokenizer = AutoTokenizer.from_pretrained("facebook/mgenre-wiki")
mgenre_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mgenre-wiki").eval()

# with open("titles_lang_all105_trie_with_redirect.pkl", "rb") as f:
#     mgenre_trie = Trie.load_from_dict(pickle.load(f))

In [13]:
import requests
import time
import pandas as pd

from transformers import Pipeline
from functools import lru_cache
from nltk.stem.porter import PorterStemmer

import sys
sys.path.insert(0, '../')

from caches.ner_to_sentence_insertion import NerToSentenceInsertion

import logging 
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [55]:
class MGENREPipeline(Pipeline):
    """MGENREPipeline - HF Pipeline for mGENRE EntityLinking model"""

    def _sanitize_parameters(self, **kwargs):
        forward_kwargs = {}
        if "num_beams" in kwargs:
            forward_kwargs["num_beams"] = kwargs.get("num_beams", 10)
        if "num_return_sequences" in kwargs:
            forward_kwargs["num_return_sequences"] = kwargs.get(
                "num_return_sequences", 10
            )
        return {}, forward_kwargs, {}

    def preprocess(self, input_):
        return self.tokenizer(
            input_,
            return_tensors="pt",
        )

    def _forward(
        self,
        input_tensors,
        num_beams=10,
        num_return_sequences=10,
    ):
        outputs = self.model.generate(
            **{k: v.to(self.device) for k, v in input_tensors.items()},
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            # prefix_allowed_tokens_fn=lambda batch_id, sent: mgenre_trie.get(sent.tolist()),
        )
        return outputs

    def postprocess(self, model_outputs):
        outputs = self.tokenizer.batch_decode(
            model_outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return outputs


class EntitiesSelection:
    def __init__(self, ner_model):
        self.stemmer = PorterStemmer()
        self.ner_model = ner_model

    def entities_selection(self, entities_list, mgenre_predicted_entities_list):
        final_preds = []

        for pred_text in mgenre_predicted_entities_list:
            labels = []
            try:
                _label, lang = pred_text.split(" >> ")
                if lang == "en":
                    labels.append(_label)
            except Exception as e:
                logger.error(f"Error {str(e)} with pred_text={pred_text}")

            if len(labels) > 0:
                for label in labels:
                    label = label.lower()
                    if self._check_label_fn(label, entities_list):
                        final_preds.append(pred_text)

        return list(dict.fromkeys(final_preds))

    @lru_cache(maxsize=8192)
    def _label_format_fn(self, label):
        return " ".join(
            [self.stemmer.stem(str(token)) for token in self.ner_model(label)]
        )

    def _check_label_fn(self, label, entities_list):
        label = self._label_format_fn(label)
        for entity in entities_list:
            entity = self._label_format_fn(entity)
            if label == entity:
                return True
        return False


class LabelToEntity:
    def __init__(
        self,
        title2wikidataID_path: str = 'lang_title2wikidataID-normalized_with_redirect.pkl',
    ):
        with open(title2wikidataID_path, "rb") as f:
            self.lang_title_to_wikidata_id = pickle.load(f)

    @lru_cache(maxsize=16384)
    def __call__(self, text):
        if ' >> ' not in text:
            text += ' >> ln'

        set_of_ids = self.lang_title_to_wikidata_id.get(
            tuple(reversed(text.split(" >> ")))
        )
        if set_of_ids is None:
            return None
        else:
            return max(
                set_of_ids,
                key=lambda y: int(y[1:]),
            )


class EntityToLabel:
    def __init__(
        self,
        wikidataID2title_path: str = 'wikidataID2lang_title-normalized_with_redirect.pkl',
        lang: str = 'en',
    ):
        with open(wikidataID2title_path, "rb") as f:
            self.wikidata_id_lang_title = pickle.load(f)
        
        self.lang = lang

    @lru_cache(maxsize=16384)
    def __call__(self, index):
        labels = self.wikidata_id_lang_title.get(index)
        if labels is not None:
            for _lng, label in labels:
                if self.lang == _lng:
                    return label
        else:
            return None


class EntityLinking():
    def __init__(self, mgenre: MGENREPipeline, ner: NerToSentenceInsertion, label_to_entity: LabelToEntity):
        self.mgenre = mgenre
        self.ner = ner
        self.entities_selection_module = EntitiesSelection(self.ner.model)
        self.label_to_entity = label_to_entity
    
    def __call__(self, text):
        text_with_labeling, entities_list = self.ner.entity_labeling(text, True)
        mgenre_predicted_entities_list = self.mgenre(text_with_labeling)
        linked_entities_list = self.entities_selection_module.entities_selection(
            entities_list, mgenre_predicted_entities_list
        )
        linked_entities_ids_list = [self.label_to_entity(label) for label in linked_entities_list]
        return {
            "text_with_labeling": text_with_labeling,
            "ner_entities": entities_list,
            "mgenre_predicted_entities_list": mgenre_predicted_entities_list,
            "linked_entities_list": linked_entities_list,
            "linked_entities_ids_list": linked_entities_ids_list,
        }

In [56]:
ner = NerToSentenceInsertion(model_path='../../ner/model-best/')
mgenre = MGENREPipeline(mgenre_model, mgenre_tokenizer)
label_to_entity = LabelToEntity('./lang_title2wikidataID-normalized_with_redirect.pkl')
entity_to_label = EntityToLabel('./wikidataID2lang_title-normalized_with_redirect.pkl')
entity_linker = EntityLinking(
    mgenre=mgenre,
    ner=ner,
    label_to_entity=label_to_entity,
)

In [67]:
SPARQL_ENDPOINT = "https://query.wikidata.org/sparql"

@lru_cache(maxsize=16384)
def wikidata_request(query):
    logger.info(f'Request to Wikidata with query:\n{query}')
    response = requests.get(
        SPARQL_ENDPOINT,
        params={"format": "json", "query": query},
        headers={"Accept": "application/json"},
    )
    to_sleep = 0.2
    while response.status_code == 429:
        if "retry-after" in response.headers:
            to_sleep += int(response.headers["retry-after"])
        to_sleep += 0.5
        logger.info(f'wikidata_request to sleep...')
        time.sleep(to_sleep)
        response = requests.get(
            SPARQL_ENDPOINT,
            params={"format": "json", "query": query},
            headers={"Accept": "application/json"},
        )
    return response.json()["results"]["bindings"]

def wikidata_get_corresponded_objects(entity):
    query = """
SELECT ?p ?item WHERE {
    wd:<E1> ?p ?item .
    ?article schema:about ?item .
    ?article schema:inLanguage "en" .
    ?article schema:isPartOf <https://en.wikipedia.org/> .
    SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
}
    """.replace('<E1>', entity)

    cor_objects = wikidata_request(query)
    cor_objects = [
        val['item']['value'].split('/')[-1]
        for val in cor_objects
    ]
    
    return list(dict.fromkeys(cor_objects))

In [12]:
!git clone https://github.com/askplatypus/wikidata-simplequestions.git

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
fatal: destination path 'wikidata-simplequestions' already exists and is not an empty directory.


In [14]:
train_df = pd.read_csv(
    './wikidata-simplequestions/annotated_wd_data_train_answerable.txt',
    sep='\t',
    names=['S', 'P', 'O', 'Q'],
)

valid_df = pd.read_csv(
    './wikidata-simplequestions/annotated_wd_data_valid_answerable.txt',
    sep='\t',
    names=['S', 'P', 'O', 'Q'],
)

test_df = pd.read_csv(
    './wikidata-simplequestions/annotated_wd_data_test_answerable.txt',
    sep='\t',
    names=['S', 'P', 'O', 'Q'],
)

In [79]:
from tqdm.auto import tqdm

def prepare_dataset(df):
    results = []

    for _, row in tqdm(df.iterrows(), total=df.index.size):
        linked_results = entity_linker(row['Q'])
        linked_results['one_hop_neighbors'] = {
            entity: [(eidx, entity_to_label(eidx)) for eidx in  wikidata_get_corresponded_objects(entity)]
            for entity in linked_results['linked_entities_ids_list'] if entity is not None
        }
        results.append(
            dict(**row.to_dict(), **linked_results)
        )

In [80]:
new_valid_df = pd.DataFrame(prepare_dataset(valid_df.iloc[:100]))

  0%|          | 0/100 [00:00<?, ?it/s]

