In [1]:
import re
import json
from wikipedia2vec import Wikipedia2Vec
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForMultipleChoice
import re
import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords
import string
from nltk.tokenize import word_tokenize
from rich import inspect
from tqdm import tqdm
import json

def parse_relation_wiki(r):
    """Turn a raw relation identifier into a plain, space-separated phrase.

    Two substitution passes run in order: punctuation/underscore characters
    become spaces, then every run of ``P`` followed by digits (e.g. ``P19``)
    is blanked out. Repeated spaces are collapsed afterwards and the result
    is stripped.

    Args:
        r: Raw relation string.

    Returns:
        The cleaned relation string.
    """
    blank_out = (
        "[-+.^:,'()<>_-]",  # separators and punctuation
        "P+[0-9]+",         # P-followed-by-digits tokens (matched anywhere, even mid-word)
    )
    cleaned = r
    for pattern in blank_out:
        cleaned = re.sub(pattern, " ", cleaned)
    return re.sub(" +", " ", cleaned).strip()
    
def parse_entity_wiki(e):
    """Clean a raw entity string into a readable, space-separated label.

    Punctuation and underscores are replaced by spaces, runs of ``P`` followed
    by digits are blanked out, multiple spaces are squeezed into one and the
    ends are trimmed.

    NOTE(review): only ``P<digits>`` tokens are removed; ``Q<digits>`` ids stay
    in the result — confirm that is intended for entities.

    Args:
        e: Raw entity string.

    Returns:
        The cleaned entity string.
    """
    no_punctuation = re.sub("[-+.^:,'()<>_-]", " ", e)
    no_property_ids = re.sub("P+[0-9]+", " ", no_punctuation)
    single_spaced = re.sub(" +", " ", no_property_ids)
    return single_spaced.strip()

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /home/ots2rng/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/ots2rng/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## Get labels for QIDs from Wikidata

In [None]:
!pip install sparqlwrapper

In [375]:
import sys
from SPARQLWrapper import SPARQLWrapper, JSON

# SPARQL endpoint that query() and query_alts() send their requests to.
endpoint_url = "https://query.wikidata.org/sparql"

def get_results(endpoint_url, query):
    """Run a SPARQL query against an endpoint and return the converted JSON result.

    Args:
        endpoint_url: URL of the SPARQL endpoint.
        query: SPARQL query text.

    Returns:
        The value produced by ``SPARQLWrapper.query().convert()`` with the
        return format set to JSON.
    """
    # TODO adjust user agent; see https://w.wiki/CX6
    major_minor = sys.version_info[:2]  # plain (major, minor) tuple
    client = SPARQLWrapper(
        endpoint_url,
        agent="WDQS-example Python/%s.%s" % major_minor,
    )
    client.setQuery(query)
    client.setReturnFormat(JSON)
    response = client.query()
    return response.convert()

def query(qids):
    """Fetch the English ``rdfs:label`` of each given entity.

    Args:
        qids: Text spliced verbatim into the ``VALUES ?qid { ... }`` clause;
            the callers in this notebook pass space-separated ``wd:Q...`` ids.

    Returns:
        The ``["results"]["bindings"]`` part of the endpoint's JSON response.
    """
    # The query text is split around the VALUES list so the ids can be
    # concatenated in between without touching the rest of the template.
    before_values = """
    SELECT ?qid ?label WHERE {
      VALUES ?qid { """
    after_values = """ }
      ?qid rdfs:label ?label
      FILTER(LANG(?label) = "en").
    }"""
    response = get_results(endpoint_url, before_values + qids + after_values)
    return response["results"]["bindings"]

def query_alts(qids):
    """Fetch English labels via the ``owl:sameAs`` target of each given entity.

    For every ``?qid`` in the VALUES list the query follows ``owl:sameAs`` to
    ``?qid_alt`` and returns that entity's English ``rdfs:label``, reported
    under the original ``?qid``.

    Args:
        qids: Text spliced verbatim into the ``VALUES ?qid { ... }`` clause;
            the callers in this notebook pass space-separated ``wd:Q...`` ids.

    Returns:
        The ``["results"]["bindings"]`` part of the endpoint's JSON response.
    """
    sparql_text = "".join((
        """
    SELECT ?qid ?label WHERE {
      VALUES ?qid { """,
        qids,
        """ }
      ?qid owl:sameAs ?qid_alt.
      ?qid_alt rdfs:label ?label.
      FILTER(LANG(?label) = "en").
    }""",
    ))
    return get_results(endpoint_url, sparql_text)["results"]["bindings"]

In [50]:
# Maps each QID to a label; seeded here with a label derived from the raw line.
qids_label = {}
with open("./data/wiki44k/meta.txt", "r") as infile:
    # The first tab-separated field of the header line is the number of
    # entity lines that follow.
    count = int(infile.readline().split("\t", 1)[0])
    print(count)
    qids = []
    for i in tqdm(range(count)):
        ent = infile.readline().strip()
        # First "Q<digits>" token on the line is taken as the entity id.
        qid = re.search("Q+[0-9]+", ent).group()
        qids.append(qid)
        qids_label[qid] = parse_entity_wiki(ent)

43934


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 43934/43934 [00:00<00:00, 776420.57it/s]


In [68]:
# Look up English labels in batches of 500 QIDs per SPARQL request and
# overwrite the corresponding entries of `qids_label`.
for i in tqdm(range(0, len(qids), 500)):
    values_qid = " ".join("wd:" + x for x in qids[i:i+500])
    results = query(values_qid)
    for result in results:
        # The binding holds the full entity URI; strip the prefix to get the bare QID.
        entity_uri = result["qid"]["value"]
        qids_label[entity_uri.replace("http://www.wikidata.org/entity/","")] = result["label"]["value"]

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 88/88 [02:00<00:00,  1.37s/it]


In [392]:
# Collect the QIDs that are still absent from `qids_label`.
# NOTE(review): the cell that reads meta.txt stores a label for every QID, so
# on a fresh top-to-bottom run this list looks like it will always be empty and
# the owl:sameAs lookup below will never fire (the recorded "213" suggests the
# cells were executed in a different state) — confirm the intended order.
qids_no_label = [qid for qid in qids if qid not in qids_label]
print(len(qids_no_label))

# Retry those QIDs through their owl:sameAs targets, again 500 per request.
for i in tqdm(range(0, len(qids_no_label), 500)):
    values_qid = " ".join("wd:" + x for x in qids_no_label[i:i+500])
    results = query_alts(values_qid)
    for result in results:
        entity_uri = result["qid"]["value"]
        qids_label[entity_uri.replace("http://www.wikidata.org/entity/","")] = result["label"]["value"]

213


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.13s/it]


In [394]:
# Cache the QID -> label mapping on disk as indented JSON.
with open("./data/wiki44k/qids_label.json", "w") as outfile:
    json.dump(qids_label, outfile, indent=4)

### Load the saved QID labels

In [2]:
# Read the QID -> label mapping back from the JSON file.
with open("./data/wiki44k/qids_label.json", "r") as infile:
    qids_label = json.loads(infile.read())

In [3]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

from tqdm import tqdm

def mean_pooling(model_output, attention_mask):
    """Average token embeddings along dimension 1, weighted by the attention mask.

    Args:
        model_output: Indexable whose first element holds the token embeddings.
        attention_mask: Mask tensor; it is given a trailing axis and expanded
            to the size of the token embeddings, so masked-out positions
            contribute nothing to the sum.

    Returns:
        The masked sum over dimension 1 divided by the mask total, where the
        divisor is clamped to at least 1e-9 to avoid division by zero.
    """
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_total = (token_embeddings * mask).sum(1)
    mask_total = mask.sum(1).clamp(min=1e-9)
    return masked_total / mask_total

def embed(text):
    """Encode a single text into a mean-pooled, L2-normalised embedding.

    Relies on the module-level ``tokenizer`` and ``model``.

    Args:
        text: The text to encode; it is wrapped in a one-element list before
            tokenisation.

    Returns:
        The pooled model output after ``F.normalize(..., p=2, dim=1)``.
    """
    batch = tokenizer([text], padding=True, truncation=True, return_tensors='pt')

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(**batch)

    pooled = mean_pooling(outputs, batch['attention_mask'])
    return F.normalize(pooled, p=2, dim=1)

def get_most_similar(text, k=1):
    """Return the ``k`` entries of ``entities`` closest to ``text``.

    Similarity is the dot product between each row of the module-level
    ``embeddings_pt`` and the embedding of ``text``.

    Args:
        text: Text to look up.
        k: Number of best matches to return.

    Returns:
        A 3-tuple ``(scores, indices, names)``: the two tensors from
        ``torch.topk`` followed by the matching ``entities`` entries.
    """
    query_vec = embed(text)
    scores = (embeddings_pt @ query_vec.T).reshape(-1)
    best_scores, best_indices = torch.topk(scores, k=k)
    names = [entities[pos] for pos in best_indices.tolist()]
    return best_scores, best_indices, names


# Load tokenizer and model from the HuggingFace Hub; both come from the same checkpoint.
checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

In [4]:
# Embed every label; the dict is keyed by label text, so repeated labels
# end up as a single entry.
sentences = list(qids_label.values())
embeddings = {sentence: embed(sentence) for sentence in tqdm(sentences)}

# `entities[n]` is the label whose embedding is row n of `embeddings_pt`.
entities = list(embeddings)
embeddings_pt = torch.vstack([embeddings[label] for label in entities])

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 43934/43934 [03:42<00:00, 197.86it/s]


In [5]:
from transformers import BertTokenizer

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Embed every entry of the BERT tokenizer's vocabulary with the same embed() helper.
sentences = list(bert_tokenizer.vocab)
embeddings_vocab = {}
for token in tqdm(sentences):
    embeddings_vocab[token] = embed(token)

# `entities_vocab[n]` is the token whose embedding is row n of `embeddings_pt_vocab`.
entities_vocab = list(embeddings_vocab)
embeddings_pt_vocab = torch.vstack([embeddings_vocab[token] for token in entities_vocab])

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30522/30522 [02:04<00:00, 244.89it/s]


In [6]:
# Dot product of every entity-label embedding with every vocabulary-token embedding.
# NOTE(review): this materialises the full matrix; with the row counts shown by
# the progress bars above (43934 x 30522) that is over a billion entries —
# consider computing the argmax in chunks if memory is tight.
sim = torch.matmul(embeddings_pt, embeddings_pt_vocab.T)
# For each entity row, the column (vocabulary position) with the highest score.
argmax = sim.argmax(dim=1)

In [7]:
# Display: entity label -> (vocabulary position, vocabulary token) of its best match.
{
    entities[row]: (token_pos, entities_vocab[token_pos])
    for row, token_pos in enumerate(argmax.tolist())
}

{'Japan': (2900, 'japan'),
 'Norway': (5120, 'norway'),
 'Hungary': (5872, 'hungary'),
 'Spain': (3577, 'spain'),
 'United States of America': (2637, 'america'),
 'Poland': (3735, 'poland'),
 'Italy': (3304, 'italy'),
 'New York City': (16392, 'nyc'),
 'Berlin': (4068, 'berlin'),
 'Los Angeles': (3349, 'angeles'),
 'Alexandria': (10297, 'alexandria'),
 'London': (2414, 'london'),
 'Paris': (3000, 'paris'),
 'Mexico': (3290, 'mexico'),
 'France': (2605, 'france'),
 'United Kingdom': (3725, 'britain'),
 'Russia': (3607, 'russia'),
 'Germany': (2762, 'germany'),
 'German': (2446, 'german'),
 'Rome': (4199, 'rome'),
 'North Macedonia': (11492, 'macedonia'),
 'Monaco': (14497, 'monaco'),
 'Algeria': (11337, 'algeria'),
 'Lyon': (10241, 'lyon'),
 'Łódź': (17814, 'łodz'),
 'Delft': (3972, 'del'),
 'William Shakespeare': (8101, 'shakespeare'),
 'Iran': (4238, 'iran'),
 'Kyrgyzstan': (23209, 'kyrgyzstan'),
 'Carl Linnaeus': (21610, 'linnaeus'),
 'Saxony': (13019, 'saxony'),
 'Karnataka': (12092

In [33]:
# Map each entity label to the position (in `entities_vocab`) of its best-matching token.
label2idx = {
    entities[row]: token_pos
    for row, token_pos in enumerate(argmax.tolist())
}

# Persist the mapping as indented JSON.
with open("./data/wiki44k/entities2tokens.json", "w") as outfile:
    json.dump(label2idx, outfile, indent=4)