In [1]:

from tqdm import tqdm
import pandas as pd
import datasets
from pathlib import Path
from transformers import T5ForConditionalGeneration, AutoTokenizer
from typing import List, Optional, Dict
import pickle
import torch
from torch.utils.data import DataLoader

# Register tqdm's pandas integration; this is what provides the
# `Series.progress_apply` method used further down in this notebook.
tqdm.pandas()

# IPython autoreload extension in mode 2: imported modules are reloaded before
# each cell runs, so edits to local modules are picked up without a restart.
%load_ext autoreload
%autoreload 2

from pywikidata import Entity
from kbqa.utils.train_eval import get_best_checkpoint_path

from trie import MarisaTrie

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!git clone https://github.com/askplatypus/wikidata-simplequestions.git
!wget -nc https://dl.fbaipublicfiles.com/GENRE/lang_title2wikidataID-normalized_with_redirect.pkl

fatal: destination path 'wikidata-simplequestions' already exists and is not an empty directory.
File ‘lang_title2wikidataID-normalized_with_redirect.pkl’ already there; not retrieving.



In [3]:
# Tokenizer used below to encode every allowed name for the T5 trie.
tokenizer = AutoTokenizer.from_pretrained("google/t5-large-ssm")

# Load the pickled title mapping downloaded by the wget cell. Its keys are
# later unpacked as (language, name) pairs; the values are presumably
# Wikidata ids, going by the file name -- TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load this pickle when the download source is trusted.
with open("./lang_title2wikidataID-normalized_with_redirect.pkl", "rb") as mapping_file:
    lang_title2wikidataID = pickle.load(mapping_file)

In [4]:
# Read the train / valid / test "answerable" splits. Each file is
# tab-separated without a header row; the four columns are named S, P, O, Q
# (presumably subject, predicate, object, question -- TODO confirm).
train_df, valid_df, test_df = [
    pd.read_csv(split_path, sep="\t", names=["S", "P", "O", "Q"])
    for split_path in (
        "./wikidata-simplequestions/annotated_wd_data_train_answerable.txt",
        "./wikidata-simplequestions/annotated_wd_data_valid_answerable.txt",
        "./wikidata-simplequestions/annotated_wd_data_test_answerable.txt",
    )
]

In [5]:
# Start from every English name in the title mapping; its keys unpack as
# (language, name) pairs and only the "en" ones are kept.
allowed_names_en = [
    title for language, title in lang_title2wikidataID.keys() if language == "en"
]

# Extend the list with the labels of every id found in the S, P and O columns
# of each split. Duplicates across splits are left in at this stage.
for df in [train_df, valid_df, test_df]:
    allowed_names_en.extend(
        df["S"]
        .progress_apply(lambda entity_id: Entity(entity_id).label)
        .unique()
        .tolist()
    )
    # Ids in the P column have "R" replaced by "P" before the lookup
    # (presumably "R..." marks a reversed property -- TODO confirm).
    allowed_names_en.extend(
        df["P"]
        .progress_apply(lambda prop_id: Entity(prop_id.replace("R", "P")).label)
        .unique()
        .tolist()
    )
    allowed_names_en.extend(
        df["O"]
        .progress_apply(lambda entity_id: Entity(entity_id).label)
        .unique()
        .tolist()
    )

100%|██████████| 19481/19481 [00:04<00:00, 4290.21it/s]
100%|██████████| 19481/19481 [00:00<00:00, 196902.06it/s]
100%|██████████| 19481/19481 [00:01<00:00, 10951.96it/s]
100%|██████████| 2821/2821 [00:00<00:00, 5616.17it/s]
100%|██████████| 2821/2821 [00:00<00:00, 223974.63it/s]
100%|██████████| 2821/2821 [00:00<00:00, 12071.05it/s]
100%|██████████| 5622/5622 [00:00<00:00, 5796.05it/s]
100%|██████████| 5622/5622 [00:00<00:00, 227504.41it/s]
100%|██████████| 5622/5622 [00:00<00:00, 11889.87it/s]


In [6]:
# Drop duplicate names, then keep only non-empty strings; any entry that is
# not a `str` (or is "") is discarded before tokenization.
allowed_names_en = [
    candidate
    for candidate in set(allowed_names_en)
    if isinstance(candidate, str) and candidate != ""
]

# Encode each remaining name on its own with the T5 tokenizer and keep just
# the token ids.
allowed_names_en_tok = [
    tokenizer(entity_name)["input_ids"] for entity_name in tqdm(allowed_names_en)
]

100%|██████████| 14753570/14753570 [12:13<00:00, 20117.12it/s]


In [10]:
# Prefix every token sequence with the tokenizer's pad token id
# (presumably because T5 decoding starts from the pad token -- TODO confirm).
tok_names_padded = [
    [tokenizer.pad_token_id, *token_ids] for token_ids in allowed_names_en_tok
]

# Build the trie over the prefixed sequences. `cache_fist_branch` is passed
# exactly as spelled; it is presumably the keyword MarisaTrie defines.
trie = MarisaTrie(sequences=tok_names_padded, cache_fist_branch=True)

# Persist the trie next to the notebook so it can be reloaded later.
with open("./wdsq_t5_trie.pkl", "wb") as trie_file:
    pickle.dump(trie, trie_file)

In [19]:
# Second tokenizer, for the mGENRE model.
mgenre_tokenizer = AutoTokenizer.from_pretrained("facebook/mgenre-wiki")

# Encode every allowed name with the " >> en" suffix appended, keeping only
# the token ids.
allowed_names_en_tok_mgenre = [
    mgenre_tokenizer(entity_name + " >> en")["input_ids"]
    for entity_name in tqdm(allowed_names_en)
]

100%|██████████| 14753570/14753570 [12:57<00:00, 18975.63it/s]


In [20]:
# Prefix every mGENRE token sequence with the tokenizer's sep token id
# (presumably the token mGENRE decoding starts from -- TODO confirm).
tok_names_padded_mgenre = [
    [mgenre_tokenizer.sep_token_id, *token_ids]
    for token_ids in allowed_names_en_tok_mgenre
]

# Build the trie over the prefixed sequences. Note that this rebinds `trie`,
# replacing whatever object that name held before this cell ran.
trie = MarisaTrie(sequences=tok_names_padded_mgenre, cache_fist_branch=True)

# Persist the mGENRE trie next to the notebook.
with open("./wdsq_mgenre_trie.pkl", "wb") as trie_file:
    pickle.dump(trie, trie_file)