In [2]:
import json
from collections import Counter

from SPARQLWrapper import SPARQLWrapper, JSON
import time
import re

from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from langdetect import detect

import faiss
from tenacity import retry, wait_random_exponential, before_sleep_log

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
ONTOLOGY_MAPPINGS_DIR = "../utils/ontology_mappings/"

## Collecting relation names and constraints

### Collecting labels and data types of relations

In [4]:
PROP_2_LABEL = {}
PROP_2_DATA_TYPE = {}

sparql = SPARQLWrapper("https://query.wikidata.org/sparql")

# SPARQL query for properties with data types: Item, Quantity, Point in time
query = """
SELECT ?property ?propertyLabel ?typeLabel WHERE {
  ?property a wikibase:Property .
  ?property wikibase:propertyType ?type .
  
  VALUES ?type { wikibase:WikibaseItem wikibase:Quantity wikibase:Time }
  
  BIND(
    IF(?type = wikibase:WikibaseItem, "Item",
      IF(?type = wikibase:Quantity, "Quantity",
        IF(?type = wikibase:Time, "Point in time", "Unknown")
      )
    ) AS ?typeLabel
  )
  
  SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
}
"""

sparql.setQuery(query)
sparql.setReturnFormat(JSON)

try:
    results = sparql.query().convert()

    for result in results["results"]["bindings"]:
        prop = result["property"]["value"].split("/")[-1]
        label = result.get("propertyLabel", {}).get("value", "No label")
        data_type = result.get("typeLabel", {}).get("value", "Unknown")

        PROP_2_LABEL[prop] = label
        PROP_2_DATA_TYPE[prop] = data_type        

except Exception as e:
    print(f"Error executing SPARQL query: {e}")

In [5]:
len(PROP_2_LABEL), len(PROP_2_DATA_TYPE)

(2432, 2432)

In [7]:
set(PROP_2_DATA_TYPE.values())

{'Item', 'Point in time', 'Quantity'}

In [8]:
with open(ONTOLOGY_MAPPINGS_DIR+"prop2data_type.json", 'w') as f:
    json.dump(PROP_2_DATA_TYPE, f)

In [7]:
with open(ONTOLOGY_MAPPINGS_DIR+"prop2label.json", 'w') as f:
    json.dump(PROP_2_LABEL, f)

### Collecting relation aliases

In [19]:
len(set(PROP_2_LABEL.keys())), len(set(PROP_2_LABEL.values()))

(2414, 2414)

In [22]:
@retry(wait=wait_random_exponential(multiplier=1, max=60))
def get_property_aliases(property_id):
    sparql = SPARQLWrapper("https://query.wikidata.org/sparql")
    
    query = f"""
    SELECT ?alias WHERE {{
      wd:{property_id} skos:altLabel ?alias .
      FILTER (lang(?alias) = "en")
    }}
    """
    
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    results = sparql.query().convert()
    
    aliases = [result["alias"]["value"] for result in results["results"]["bindings"]]
    return aliases

PROP2ALIASES = {}

for property_id in tqdm(PROP_2_LABEL.keys()):
    PROP2ALIASES[property_id] = get_property_aliases(property_id)

100%|██████████| 2414/2414 [13:28<00:00,  2.99it/s]  


In [25]:
alias_set = set()
alias_list = []
for prop, aliases in PROP2ALIASES.items():
    alias_set.update(aliases)
    alias_list.extend(aliases)
print(len(alias_set), len(alias_list))

7932 8483


In [139]:
with open(ONTOLOGY_MAPPINGS_DIR+"prop2label.json", 'r') as f:
    PROP_2_LABEL = json.load(f)

with open(ONTOLOGY_MAPPINGS_DIR+"prop2aliases.json", 'r') as f:
    PROP2ALIASES = json.load(f)

In [141]:
for prop in PROP_2_LABEL:
    print(PROP_2_LABEL[prop], PROP2ALIASES[prop])

head of government ['executive power headed by', 'government headed by', 'head of national government', 'president', 'chancellor', 'mayor', 'governor', 'prime minister', 'premier', 'first minister']
transport network ['subway system', 'light rail system', 'highway system', 'road type', 'network', 'transport network', 'tram network', 'metro system', 'part of network', 'routes system', 'system of routes', 'trail system', 'transit network']
country ['host country', 'state', 'land', 'sovereign state']
place of birth ['birth location', 'birth place', 'birthplace', 'location of birth', 'POB', 'birth city', 'born', 'born at', 'born in', 'location born']
place of death ['POD', 'died', 'deathplace', 'location of death', 'death location', 'death place', 'died in', 'killed in']
sex or gender ['gender or sex', 'gender expression', 'biological sex', 'gender identity', 'sex', 'gender']
father ['child of', 'dad', 'daddy', 'daughter of', 'has father', 'is child of', 'is daughter of', 'is son of', 'par

### Collecting subject and value constraints of relations

In [11]:
from SPARQLWrapper import SPARQLWrapper, JSON

@retry(wait=wait_random_exponential(multiplier=1, max=60))
def get_constraints(property_id):
    """Retrieve value-type and subject-type constraints for a specified Wikidata property."""
    
    sparql = SPARQLWrapper("https://query.wikidata.org/sparql")
    
    query = f"""
    SELECT ?constraintType ?entity ?entityLabel WHERE {{
      VALUES ?property {{ wd:{property_id} }}  

      ?property p:P2302 ?statement.  # Property constraints
      ?statement ps:P2302 ?constraintEntity.  # Constraint type

      VALUES ?constraintEntity {{ wd:Q21510865 wd:Q21503250 }}  # Value-type & Subject-type constraints

      ?statement pq:P2308 ?entity.  # The constrained entity type (allowed type)

      BIND(
        IF(?constraintEntity = wd:Q21510865, "Value-type constraint", "Subject type constraint")
        AS ?constraintType
      )
    }}
    """
    # SERVICE wikibase:label {{ bd:serviceParam wikibase:language "en". }}
    
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    results = sparql.query().convert()

    constraints = {"Value-type constraint": [], "Subject type constraint": []}
    for result in results["results"]["bindings"]:
        constraints[result["constraintType"]["value"]].append(result["entity"]["value"].split("/")[-1])

    return constraints

# Example usage:
property_id = "P40"  # Replace with any Wikidata property ID
constraints = get_constraints(property_id)

print(constraints)

{'Value-type constraint': ['Q5', 'Q729', 'Q4886', 'Q95074', 'Q178885', 'Q207174', 'Q795052', 'Q2135501', 'Q4271324', 'Q13002315', 'Q16979650', 'Q21070568', 'Q21070598', 'Q21191150', 'Q24334299', 'Q64520857', 'Q75855169', 'Q115537581'], 'Subject type constraint': ['Q5', 'Q729', 'Q4886', 'Q95074', 'Q178885', 'Q207174', 'Q215627', 'Q219160', 'Q795052', 'Q2135501', 'Q3046146', 'Q4271324', 'Q13002315', 'Q16979650', 'Q21070568', 'Q21070598', 'Q24334299', 'Q75855169', 'Q115537581']}


In [12]:
constraint_dict = {}

for prop in tqdm(PROP_2_LABEL.keys()):
    constraint_dict[prop] = get_constraints(prop)
    time.sleep(0.1)
len(constraint_dict)

100%|██████████| 2414/2414 [21:12<00:00,  1.90it/s]


2414

In [5]:
with open(ONTOLOGY_MAPPINGS_DIR+"prop2data_type.json", 'r') as f:
    PROP_2_DATA_TYPE = json.load(f)

In [6]:
PROP_2_DATA_TYPE

{'P6': 'Item',
 'P16': 'Item',
 'P17': 'Item',
 'P19': 'Item',
 'P20': 'Item',
 'P21': 'Item',
 'P22': 'Item',
 'P25': 'Item',
 'P26': 'Item',
 'P27': 'Item',
 'P30': 'Item',
 'P31': 'Item',
 'P35': 'Item',
 'P36': 'Item',
 'P37': 'Item',
 'P38': 'Item',
 'P39': 'Item',
 'P40': 'Item',
 'P47': 'Item',
 'P50': 'Item',
 'P53': 'Item',
 'P54': 'Item',
 'P57': 'Item',
 'P58': 'Item',
 'P59': 'Item',
 'P61': 'Item',
 'P65': 'Item',
 'P66': 'Item',
 'P69': 'Item',
 'P78': 'Item',
 'P81': 'Item',
 'P84': 'Item',
 'P85': 'Item',
 'P86': 'Item',
 'P87': 'Item',
 'P88': 'Item',
 'P91': 'Item',
 'P92': 'Item',
 'P97': 'Item',
 'P98': 'Item',
 'P101': 'Item',
 'P102': 'Item',
 'P103': 'Item',
 'P105': 'Item',
 'P106': 'Item',
 'P108': 'Item',
 'P110': 'Item',
 'P111': 'Item',
 'P112': 'Item',
 'P113': 'Item',
 'P114': 'Item',
 'P115': 'Item',
 'P118': 'Item',
 'P119': 'Item',
 'P121': 'Item',
 'P122': 'Item',
 'P123': 'Item',
 'P126': 'Item',
 'P127': 'Item',
 'P128': 'Item',
 'P129': 'Item',
 'P1

In [10]:
constraint_dict['P2294']

{'Value-type constraint': ['Q309314'], 'Subject type constraint': ['Q56061']}

In [13]:
wo_constraint = []
for prop in constraint_dict:
    if len(constraint_dict[prop]["Value-type constraint"]) == 0 and len(constraint_dict[prop]["Subject type constraint"]) == 0:
            wo_constraint.append(prop)
len(wo_constraint)

581

In [14]:
quantity_props = []
time_props = []
other_props = []
for prop in wo_constraint:
    if PROP_2_DATA_TYPE[prop] == "Quantity": 
        quantity_props.append(prop)
    elif PROP_2_DATA_TYPE[prop] == "Point in time": 
        time_props.append(prop)
    else:
        other_props.append(prop)
        # print(PROP_2_LABEL[prop])
len(time_props), len(quantity_props), len(other_props)

(28, 295, 258)

In [9]:
for prop in constraint_dict:
    if PROP_2_DATA_TYPE[prop] == "Point in time":
        constraint_dict[prop]["Value-type constraint"].append('Q186408')

    elif PROP_2_DATA_TYPE[prop] == 'Quantity':
        constraint_dict[prop]["Value-type constraint"].append('Q309314')

In [11]:
with open(ONTOLOGY_MAPPINGS_DIR+'prop2constraints.json', 'w') as f:
    json.dump(constraint_dict, f)

In [3]:
with open(ONTOLOGY_MAPPINGS_DIR+'prop2constraints.json', 'r') as f:
    constraint_dict = json.load(f)

In [5]:
with open(ONTOLOGY_MAPPINGS_DIR+"prop2data_type.json", 'r') as f:
    PROP_2_DATA_TYPE = json.load(f)

## Colecting entities' metadata

In [31]:
entities = set()
for prop, constraint in constraint_dict.items():

    for const_type in constraint:
        for entity in constraint[const_type]:
            entities.add(entity)
entities = list(entities)

In [32]:
len(entities)

3460

### Collecting entities' hierarchy of superclasses

In [33]:
@retry(wait=wait_random_exponential(multiplier=1, max=60))
def get_subclass_hierarchy(entity_id):
      sparql = SPARQLWrapper("https://query.wikidata.org/sparql")

      # SPARQL query to get all subclasses (direct and indirect) of the given entity
      query = f"""
      SELECT DISTINCT ?subclass ?subclassLabel WHERE {{
          {{
              wd:{entity_id} wdt:P31/wdt:P279* ?subclass.
          }}
            UNION
          {{
              wd:{entity_id} wdt:P279* ?subclass.
          }}
      }}
      """
    # SERVICE wikibase:label {{ bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }}

      sparql.setQuery(query)
      sparql.setReturnFormat(JSON)

      results = sparql.query().convert()

      subclass_hierarchy = []

      for result in results["results"]["bindings"]:
          subclass_id = result["subclass"]["value"].split("/")[-1]
          subclass_hierarchy.append(subclass_id)

      return subclass_hierarchy

ENTITY_2_HIERARCHY = {}
for entity_id in tqdm(entities):
    hierarchy = get_subclass_hierarchy(entity_id)
    ENTITY_2_HIERARCHY[entity_id] = hierarchy

len(ENTITY_2_HIERARCHY)

100%|██████████| 3460/3460 [29:57<00:00,  1.93it/s]  


3460

In [34]:
ents = []
for item in ENTITY_2_HIERARCHY.values():
    ents.extend(item)
len(set(ents)), len(entities)

(7074, 3460)

In [35]:
# leaving only entity types that are used in constraints
for entity in tqdm(ENTITY_2_HIERARCHY):
    filtered_super_entities = [item for item in ENTITY_2_HIERARCHY[entity] if item in entities]
    ENTITY_2_HIERARCHY[entity] = filtered_super_entities

100%|██████████| 3460/3460 [00:03<00:00, 865.48it/s]


In [36]:
ents = []
for item in ENTITY_2_HIERARCHY.values():
    ents.extend(item)
len(set(ents)), len(entities)

(3460, 3460)

In [37]:
with open(ONTOLOGY_MAPPINGS_DIR + 'entity_hierarchy.json', 'w') as f:
    json.dump(ENTITY_2_HIERARCHY, f)

### Collecting entity types' labels

In [38]:
BATCH_SIZE = 50

@retry(wait=wait_random_exponential(multiplier=1, max=60))
def fetch_labels(batch):
    entity_values = " ".join(f"wd:{entity}" for entity in batch)
    
    query = f"""
    SELECT ?entity ?entityLabel WHERE {{
      VALUES ?entity {{ {entity_values} }}
      SERVICE wikibase:label {{ bd:serviceParam wikibase:language "en". }}
    }}
    """
    
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    
    try:
        results = sparql.query().convert()
        return {
            result["entity"]["value"].split("/")[-1]: result.get("entityLabel", {}).get("value", "No label")
            for result in results["results"]["bindings"]
        }
    except Exception as e:
        print(f"Error with batch {batch[:5]}...: {e}")
        return {}

ENTITY_2_LABEL = {}

for i in range(0, len(entities), BATCH_SIZE):
    batch = entities[i:i + BATCH_SIZE]
    print(f"Processing batch {i // BATCH_SIZE + 1}/{(len(entities) // BATCH_SIZE) + 1}")
    
    labels = fetch_labels(batch)
    ENTITY_2_LABEL.update(labels)
    
# for entity, label in all_labels.items():
#     print(f"{entity}: {label}")

Processing batch 1/70


Processing batch 2/70
Processing batch 3/70
Processing batch 4/70
Processing batch 5/70
Processing batch 6/70
Processing batch 7/70
Processing batch 8/70
Processing batch 9/70
Processing batch 10/70
Processing batch 11/70
Processing batch 12/70
Processing batch 13/70
Processing batch 14/70
Processing batch 15/70
Processing batch 16/70
Processing batch 17/70
Processing batch 18/70
Processing batch 19/70
Processing batch 20/70
Processing batch 21/70
Processing batch 22/70
Processing batch 23/70
Processing batch 24/70
Processing batch 25/70
Processing batch 26/70
Processing batch 27/70
Processing batch 28/70
Processing batch 29/70
Processing batch 30/70
Processing batch 31/70
Processing batch 32/70
Processing batch 33/70
Processing batch 34/70
Processing batch 35/70
Processing batch 36/70
Processing batch 37/70
Processing batch 38/70
Processing batch 39/70
Processing batch 40/70
Processing batch 41/70
Processing batch 42/70
Processing batch 43/70
Processing batch 44/70
Processing batch 45

In [39]:
len(ENTITY_2_LABEL)

3460

In [40]:
len(set(ENTITY_2_LABEL.keys())), len(set(ENTITY_2_LABEL.values()))

(3460, 3402)

In [41]:
label2entity = {}
for entity, label in ENTITY_2_LABEL.items():
    if label not in label2entity:
        label2entity[label] = []
    label2entity[label].append(entity)

for label, entities in label2entity.items():
    if len(entities) > 1:
        print(label, entities)

kinship ['Q171318', 'Q109664302']
repository ['Q2145117', 'Q3133368', 'Q108296843']
article ['Q712597', 'Q191067']
attribute ['Q109674924', 'Q2722260']
crossing ['Q10816681', 'Q62059481']
lineage ['Q1642895', 'Q1517820']
bibliography ['Q134995', 'Q1631107']
test ['Q27318', 'Q1003030']
commission ['Q63705303', 'Q55657615']
statement ['Q613299', 'Q2684591']
motif ['Q1229071', 'Q68614425']
color ['Q22006653', 'Q1075']
competition ['Q476300', 'Q841654', 'Q23807345']
cabinet ['Q640506', 'Q6866562']
process ['Q3249551', 'Q10843872']
territory ['Q183366', 'Q4835091']
class ['Q16889133', 'Q18204', 'Q217594']
language ['Q4113741', 'Q34770']
parish ['Q7137411', 'Q102496']
goal ['Q109405570', 'Q4503831']
report ['Q10870555', 'Q10429085']
character ['Q3241972', 'Q95074']
monastery ['Q6021560', 'Q44613']
season ['Q27020041', 'Q10688145']
group ['Q83478', 'Q16887380']
record ['Q107435521', 'Q1241356']
position ['Q1781513', 'Q4164871']
register ['Q19386377', 'Q286576']
space ['Q2133296', 'Q107']
epit

### Collecting descriptions for entity types with duplicated labels

In [42]:
@retry(wait=wait_random_exponential(multiplier=1, max=60))
def get_entity_info(entity_id):
    sparql = SPARQLWrapper("https://query.wikidata.org/sparql")
    
    query = f"""
    SELECT ?entityLabel ?entityDescription WHERE {{
      wd:{entity_id} rdfs:label ?entityLabel .
      wd:{entity_id} schema:description ?entityDescription .
      FILTER (lang(?entityLabel) = "en")
      FILTER (lang(?entityDescription) = "en")
    }}
    """
    
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    results = sparql.query().convert()
    if results["results"]["bindings"]:
        result = results["results"]["bindings"][0]
        return {
            "label": result["entityLabel"]["value"],
            "description": result["entityDescription"]["value"]
        }
    else:
        return None

for label, entities in label2entity.items():
    if len(entities) > 1:
        for entity_id in entities:
            info = get_entity_info(entity_id)
            ENTITY_2_LABEL[entity_id] = info['label'] + " (" + info['description'] +")"

In [43]:
len(set(ENTITY_2_LABEL.keys())), len(set(ENTITY_2_LABEL.values()))

(3460, 3460)

### Collecting entity types' aliases

In [122]:
@retry(wait=wait_random_exponential(multiplier=1, max=60))
def get_entity_aliases(entity_id):
    chinese_japanese_pattern = re.compile(r"[\u4E00-\u9FFF\u3400-\u4DBF\uF900-\uFAFF\u3040-\u309F\u30A0-\u30FF\u31F0-\u31FF\uFF00-\uFFEF]")
    sparql = SPARQLWrapper("https://query.wikidata.org/sparql")
    
    query = f"""
    SELECT ?alias WHERE {{
      wd:{entity_id} skos:altLabel ?alias .
      FILTER (lang(?alias) = "en")
    }}
    """
    
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    results = sparql.query().convert()

    aliases = []
    for result in results["results"]["bindings"]:

        alias = result["alias"]["value"]
        if not chinese_japanese_pattern.search(alias):
          aliases.append(alias)
        # except Exception as e:
        #    continue
    
    return aliases

ENTITY_2_ALIASES = {}
for entity_id in tqdm(ENTITY_2_LABEL.keys()):
    ENTITY_2_ALIASES[entity_id] = get_entity_aliases(entity_id)

100%|██████████| 3460/3460 [20:08<00:00,  2.86it/s]  


In [142]:
# [(ENTITY_2_LABEL[ent], aliases) for ent, aliases in ENTITY_2_ALIASES.items()]

In [90]:
ENTITY_2_LABEL["Q186408"], get_entity_aliases("Q186408")

('point in time',
 ['date', 'moment', 'given moment', 'instant', 'time point', 'timepoint'])

In [135]:
ENTITY_2_LABEL["Q309314"], get_entity_aliases("Q309314")

('quantity', ['amount', 'number', 'qty', 'quantulum'])

## Building inverse mapping - object/subject type constraint to relation

### Checking that relations with 'point in time' and 'quantity' data types don't have other constraints

In [147]:
for prop, data_type in PROP_2_DATA_TYPE.items():
    if data_type == "Quantity":    
        val_constraints = [ENTITY_2_LABEL[ent] for ent in constraint_dict[prop]["Value-type constraint"]]
        subj_constraints = [ENTITY_2_LABEL[ent] for ent in constraint_dict[prop]["Subject type constraint"]]            
        # print(PROP_2_LABEL[prop], subj_constraints, val_constraints)
        assert len(val_constraints) == 0
        # no value constraints for data type quantity

In [148]:
for prop, data_type in PROP_2_DATA_TYPE.items():
    if data_type == "Point in time":    
        val_constraints = [ENTITY_2_LABEL[ent] for ent in constraint_dict[prop]["Value-type constraint"]]
        subj_constraints = [ENTITY_2_LABEL[ent] for ent in constraint_dict[prop]["Subject type constraint"]]            
        # print(PROP_2_LABEL[prop], subj_constraints, val_constraints)
        assert len(val_constraints) == 0
        # no value constraints for data type quantity

In [134]:
constraint_dict

{'P6': {'Value-type constraint': ['Q5', 'Q95074'],
  'Subject type constraint': ['Q7188',
   'Q56061',
   'Q327333',
   'Q524572',
   'Q640506',
   'Q1006644',
   'Q1048835',
   'Q1145276',
   'Q6866562',
   'Q64034456']},
 'P16': {'Value-type constraint': ['Q5503',
   'Q34442',
   'Q498002',
   'Q639030',
   'Q924286',
   'Q3241753',
   'Q15640053',
   'Q25631158',
   'Q30014735',
   'Q124130081'],
  'Subject type constraint': ['Q376799',
   'Q1067164',
   'Q25377652',
   'Q28043022',
   'Q44667495',
   'Q120367839']},
 'P17': {'Value-type constraint': ['Q6256',
   'Q7275',
   'Q43702',
   'Q48349',
   'Q59281',
   'Q148837',
   'Q161243',
   'Q170156',
   'Q182547',
   'Q253836',
   'Q1145276',
   'Q1151405',
   'Q1250464',
   'Q1620908',
   'Q1896989',
   'Q2577883',
   'Q3024240',
   'Q3238337',
   'Q3624078',
   'Q3895768',
   'Q5982983',
   'Q10711424',
   'Q15239622',
   'Q15304003',
   'Q28171280'],
  'Subject type constraint': []},
 'P19': {'Value-type constraint': ['Q4130',
 

In [6]:
subj2prop_constraints = {"<ANY SUBJECT>": []}
# Q309314 - quantity, Q186408 -  point in time 
obj2prop_constraint = {"<ANY OBJECT>": [], "Q309314": [], 'Q186408': []}

for prop, constraint in constraint_dict.items():

    if PROP_2_DATA_TYPE[prop] == "Point in time":
        obj2prop_constraint['Q186408'].append(prop)
    
    elif PROP_2_DATA_TYPE[prop] == 'Quantity':
        obj2prop_constraint['Q309314'].append(prop)
    
    elif len(constraint["Value-type constraint"]) == 0:
        obj2prop_constraint["<ANY OBJECT>"].append(prop)
    
    else:
        for entity in constraint["Value-type constraint"]:
            if entity not in obj2prop_constraint:
                obj2prop_constraint[entity] = []
            obj2prop_constraint[entity].append(prop)

    
    if len(constraint["Subject type constraint"]) == 0:
        subj2prop_constraints["<ANY SUBJECT>"].append(prop)

    else:
        for entity in constraint["Subject type constraint"]:
            if entity not in subj2prop_constraints:
                subj2prop_constraints[entity] = []
            subj2prop_constraints[entity].append(prop)


len(subj2prop_constraints), len(obj2prop_constraint)

(2194, 2143)

In [7]:
with open(ONTOLOGY_MAPPINGS_DIR+"subj_constraint2prop.json", 'w') as f:
    json.dump(subj2prop_constraints, f)

with open(ONTOLOGY_MAPPINGS_DIR+"obj_constraint2prop.json", 'w') as f:
    json.dump(obj2prop_constraint, f)

In [257]:
with open(ONTOLOGY_MAPPINGS_DIR+'entity_type2label.json', 'w') as f:
    json.dump(ENTITY_2_LABEL, f)

In [258]:
with open(ONTOLOGY_MAPPINGS_DIR+'prop2aliases.json', 'w') as f:
    json.dump(PROP2ALIASES, f)

In [259]:
with open(ONTOLOGY_MAPPINGS_DIR+'entity_type2aliases.json', 'w') as f:
    json.dump(ENTITY_2_ALIASES, f)

In [260]:
with open(ONTOLOGY_MAPPINGS_DIR+'entity_type2hierarchy.json', 'w') as f:
    json.dump(ENTITY_2_HIERARCHY, f)

## Index utils

In [11]:
tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
model = AutoModel.from_pretrained('facebook/contriever').to('cuda:4')

In [12]:
def mean_pooling(token_embeddings, mask):
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
    return sentence_embeddings

def embed_entity_batch(entity_list):
    inputs = tokenizer(entity_list, padding=True, truncation=True, return_tensors='pt')

    outputs = model(**inputs.to('cuda:4'))
    embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
    return embeddings

## Relation index

In [31]:
with open(ONTOLOGY_MAPPINGS_DIR+"prop2label.json", 'r') as f:
    PROP_2_LABEL = json.load(f)

with open(ONTOLOGY_MAPPINGS_DIR+"prop2aliases.json", 'r') as f:
    PROP2ALIASES = json.load(f)

In [32]:
prop_name_id_pairs = []

for p, aliases in PROP2ALIASES.items():
    prop_name_id_pairs.append((PROP_2_LABEL[p], p))

    for alias in aliases:
        prop_name_id_pairs.append((alias, p))

prop_names = [item[0] for item in prop_name_id_pairs]
prop_ids = [item[1] for item in prop_name_id_pairs]
enum_prop_ids = {i: prop_id for i, prop_id in enumerate(prop_ids)}

In [33]:
prop_embeddings = []
batch_size = 100

for i in tqdm(range(0, len(prop_names), batch_size)):

    if i + batch_size > len(prop_names):
        prop_list = prop_names[i: len(prop_names)]
    else:
        prop_list = prop_names[i: i + batch_size]

    prop_embeddings.append(embed_entity_batch(prop_list).detach().to('cpu'))

100%|██████████| 109/109 [00:04<00:00, 26.81it/s]


In [34]:
prop_output = np.array(torch.concat(prop_embeddings))
prop_output.shape

(10897, 768)

In [35]:
dim = prop_output.shape[1]
metric = faiss.METRIC_INNER_PRODUCT
prop_index = faiss.index_factory(dim, "IDMap,Flat", metric)

prop_index.add_with_ids(prop_output, list(enum_prop_ids.keys()))

In [36]:
print(prop_index.is_trained)

True


In [37]:
from time import time
start, end = 200, 210
before = time()
distances, indices = prop_index.search(prop_output[start:end, :], 5)
after = time()
after - before

0.00814199447631836

In [38]:
indices, distances

(array([[ 200, 2272, 9878,  229, 9881],
        [ 201, 1304,  605,  200, 1298],
        [ 202, 1607, 1610, 8314, 8313],
        [9717,  830,  203,  445, 2278],
        [ 905,  316,  233,  204, 1301],
        [ 205,  449, 6314, 9628, 5369],
        [2798,  206, 7301,  207, 1389],
        [ 207, 2802,  219, 8612,  217],
        [ 208, 6013, 6022, 6009, 6005],
        [ 209,  126, 2828, 2823,  170]]),
 array([[1.3455201 , 1.2125266 , 1.2077221 , 1.1134841 , 1.1079599 ],
        [1.4369535 , 1.0882614 , 0.9400271 , 0.93929845, 0.92932105],
        [1.4525927 , 1.1482468 , 1.1236124 , 1.0907669 , 1.0507144 ],
        [1.5604218 , 1.5604218 , 1.5604218 , 1.5604217 , 1.3541275 ],
        [1.222655  , 1.222655  , 1.222655  , 1.222655  , 1.0452965 ],
        [1.5717002 , 1.161402  , 1.1291238 , 1.0859156 , 1.0595844 ],
        [1.3546925 , 1.3546925 , 1.1879663 , 1.1140074 , 1.021622  ],
        [2.1679156 , 1.5630033 , 1.4717267 , 1.4477936 , 1.4357615 ],
        [1.2989877 , 1.0566187 , 1.054

In [39]:
for ind, prop_name in zip(indices, prop_names[start:end]):
    print(prop_name, [prop_names[i] for i in ind], [PROP_2_LABEL[enum_prop_ids[i]] for i in ind])

writer ['writer', 'writer of', 'writer for', 'screenwriter', 'wrote for'] ['author', 'notable work', 'has written for', 'screenwriter', 'has written for']
poet ['poet', 'text poet', 'literary movement', 'writer', 'lyricist'] ['author', 'lyricist', 'movement', 'author', 'lyricist']
playwright ['playwright', 'theatre company', 'theater company', 'stage designer', 'scenic designer'] ['author', 'production company', 'production company', 'scenographer', 'scenographer']
creator ['creator', 'creator', 'creator', 'creator', 'creator of'] ['collection creator', 'creator', 'author', 'founded by', 'notable work']
written by ['written by', 'written by', 'written by', 'written by', 'lyrics by'] ['developer', 'composer', 'screenwriter', 'author', 'lyricist']
co-author ['co-author', 'co-founder', 'co-pilot', 'co-manager of', 'co-located with'] ['author', 'founded by', 'co-driver', 'coach of sports team', 'colocated with']
family ['family', 'family', 'siblings', 'member of family', 'family name'] ['r

In [40]:
faiss.write_index(prop_index, ONTOLOGY_MAPPINGS_DIR+"wikidata_relations.index")

In [41]:
with open(ONTOLOGY_MAPPINGS_DIR+'enum_prop_ids.json', 'w') as f:
    json.dump(prop_ids, f)

In [43]:
prop_id2enum = {}
for i, prop_id in enumerate(prop_ids):
    if prop_id not in prop_id2enum:
        prop_id2enum[prop_id] = []
    prop_id2enum[prop_id].append(i)
prop_id2enum

{'P6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 'P16': [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
 'P17': [25, 26, 27, 28, 29],
 'P19': [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40],
 'P20': [41, 42, 43, 44, 45, 46, 47, 48, 49],
 'P21': [50, 51, 52, 53, 54, 55, 56],
 'P22': [57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67],
 'P25': [68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81],
 'P26': [82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99],
 'P27': [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
 'P30': [110],
 'P31': [111, 112, 113, 114, 115, 116, 117, 118, 119, 120],
 'P35': [121, 122, 123, 124, 125, 126, 127, 128, 129, 130],
 'P36': [131,
  132,
  133,
  134,
  135,
  136,
  137,
  138,
  139,
  140,
  141,
  142,
  143,
  144,
  145,
  146,
  147,
  148,
  149],
 'P37': [150, 151, 152, 153, 154],
 'P38': [155, 156, 157],
 'P39': [158, 159, 160, 161, 162, 163, 164, 165, 166, 167],
 'P40': [168,
  169,
  170,
  1

In [44]:
with open(ONTOLOGY_MAPPINGS_DIR+'propid2enum.json', 'w') as f:
    json.dump(prop_id2enum, f)

## Indexing ontology labels

In [4]:
with open(ONTOLOGY_MAPPINGS_DIR+'entity_type2label.json', 'r') as f:
    ENTITY_2_LABEL = json.load(f)
# ENTITY_2_LABEL

with open(ONTOLOGY_MAPPINGS_DIR+'entity_type2hierarchy.json', 'r') as f:
    ENTITY_2_HIERARCHY = json.load(f)
# ENTITY_2_HIERARCHY

with open(ONTOLOGY_MAPPINGS_DIR+'entity_type2aliases.json', 'r') as f:
    ENTITY_2_ALIASES = json.load(f)

In [5]:
ENTITY_2_ALIASES["Q5"]

['human being',
 'humans',
 'people',
 'person',
 'individual human',
 'individual Homo sapiens',
 'non-fictional human',
 'nonfictional human',
 'modern humans']

In [6]:
entity_items = list(ENTITY_2_LABEL.items())
entity_ids = [int(item[0][1:]) for item in entity_items]
entity_names = [item[1] for item in entity_items]
len(set(entity_ids)), len(set(entity_names))

(3460, 3460)

In [7]:
entity_name_id_pairs = []

for e, aliases in ENTITY_2_ALIASES.items():
    entity_name_id_pairs.append((ENTITY_2_LABEL[e], e))

    for alias in aliases:
        entity_name_id_pairs.append((alias, e))

entity_names = [item[0] for item in entity_name_id_pairs]
entity_ids = [item[1] for item in entity_name_id_pairs]
enum_entity_ids = {i: entity_id for i, entity_id in enumerate(entity_ids)}
len(entity_names)

12297

In [27]:
entity2id = {}
for entity_id, entity_name in ENTITY_2_LABEL.items():
    entity2id[entity_name] = entity_id

with open(ONTOLOGY_MAPPINGS_DIR+'label2entity.json', 'w') as f:
    json.dump(entity2id, f)

with open(ONTOLOGY_MAPPINGS_DIR+'enum_entity_ids.json', 'w') as f:
    # json.dump(enum_entity_ids, f)
    json.dump(entity_ids, f)

In [13]:
entity_embeddings = []
batch_size = 100

for i in tqdm(range(0, len(entity_names), batch_size)):

    if i + batch_size > len(entity_names):
        entity_list = entity_names[i: len(entity_names)]
    else:
        entity_list = entity_names[i: i + batch_size]

    entity_embeddings.append(embed_entity_batch(entity_list).detach().to('cpu'))

100%|██████████| 123/123 [00:06<00:00, 19.83it/s]


In [14]:
entity_output = np.array(torch.concat(entity_embeddings))
entity_output.shape

(12297, 768)

In [15]:
dim = entity_output.shape[1]
metric = faiss.METRIC_INNER_PRODUCT
entity_index = faiss.index_factory(dim, "IDMap,Flat", metric)

# index = faiss.index_factory(dim, "Flat")  # Equivalent to IndexFlatL2
# entity_index = faiss.IndexIDMap(index)  # Wrap it in an IDMap


entity_index.add_with_ids(entity_output, list(enum_entity_ids.keys()))

In [16]:
print(entity_index.is_trained)

True


In [17]:
from time import time

start, end = 100, 110
before = time()
distances, indices = entity_index.search(entity_output[start:end, :], 3)
after = time()
after - before

0.00577998161315918

In [18]:
indices, distances

(array([[  100,   101,   103],
        [  101,   103,   100],
        [  103,   102,   101],
        [  103,   101,   102],
        [  104,   105,   100],
        [  105,   104,   102],
        [  106,   107, 11570],
        [  107,   106, 11570],
        [  108,  6661,  6663],
        [ 6509,  5401,  4868]]),
 array([[1.3560436, 1.3521374, 1.284651 ],
        [1.4821671, 1.380825 , 1.3521374],
        [1.3684261, 1.3655431, 1.2988927],
        [1.4372516, 1.380825 , 1.3684261],
        [1.3088332, 1.2300593, 1.1386945],
        [1.3916845, 1.2300593, 1.2141753],
        [1.7491164, 1.7176946, 1.4938099],
        [1.8628964, 1.7176946, 1.4420071],
        [1.7601748, 1.2296405, 1.2208571],
        [1.7292591, 1.7292591, 1.7292591]], dtype=float32))

In [19]:
for ind, entity_name in zip(indices, entity_names[start:end]):
    print(entity_name, [entity_names[i] for i in ind], [ENTITY_2_LABEL[enum_entity_ids[i]] for i in ind])

literary award ['literary award', 'literature award', 'literature prize'] ['literary award', 'literary award', 'literary award']
literature award ['literature award', 'literature prize', 'literary award'] ['literary award', 'literary award', 'literary award']
literary prize ['literature prize', 'literary prize', 'literature award'] ['literary award', 'literary award', 'literary award']
literature prize ['literature prize', 'literature award', 'literary prize'] ['literary award', 'literary award', 'literary award']
book award ['book award', 'book prize', 'literary award'] ['literary award', 'literary award', 'literary award']
book prize ['book prize', 'book award', 'literary prize'] ['literary award', 'literary award', 'literary award']
fictional religion ['fictional religion', 'fictional religions', 'fictional religious occupation'] ['fictional religion', 'fictional religion', 'fictional religious occupation']
fictional religions ['fictional religions', 'fictional religion', 'fictional

In [46]:
distances, idx = entity_index.search(embed_entity_batch(['person']).detach().cpu().numpy(), 10)
distances, idx

(array([[2.0584254, 2.0584254, 2.0584254, 1.7754977, 1.7754977, 1.6061485,
         1.5236574, 1.5170953, 1.4831569, 1.4694436]], dtype=float32),
 array([[ 9736,  9115,  1279,  9737,  9118,   609,  8939,  4717,  4730,
         10819]]))

In [47]:
print([ENTITY_2_LABEL[enum_entity_ids[i]] for i in idx[0]])

['person', 'grammatical person', 'human', 'person', 'grammatical person', 'hypothetical person', 'group of humans', 'duo', 'duo', 'individual']


In [48]:
faiss.write_index(entity_index, ONTOLOGY_MAPPINGS_DIR+"wikidata_ontology_entities.index")

In [49]:
entity_index = faiss.read_index(ONTOLOGY_MAPPINGS_DIR+"wikidata_ontology_entities.index")

In [101]:
def search_within_subset(query, subset_ids, faiss_id_index, k):

    subset_vectors = np.array([faiss_id_index.index.reconstruct(int(i)) for i in subset_ids])
    
    subset_index = faiss.IndexFlatL2(dim)
    subset_index.add(subset_vectors)
    
    distances, subset_indices = subset_index.search(query, k)
    retrieved_ids = np.array(subset_ids)[subset_indices]

    return distances, retrieved_ids

In [123]:
before = time()
subset_ids = list(np.random.randint(0, len(entity_names), 100))
dist, idx = search_within_subset(query=embed_entity_batch(['person']).detach().cpu().numpy(), subset_ids=subset_ids, faiss_id_index=entity_index, k=10)
print([ENTITY_2_LABEL[enum_entity_ids[i]] for i in idx[0]])

after = time()
after - before

['people', 'grammatical person', 'religious adherent', 'by-law', 'volcanic eruption', 'universe', 'violence', 'sacrament', 'musical group', 'solution']


0.025350570678710938

In [132]:
before = time()
subset_ids = list(np.random.randint(0, len(entity_names), 100))
dist, idx = search_within_subset(query=embed_entity_batch(['film']).detach().cpu().numpy(), subset_ids=subset_ids, faiss_id_index=entity_index, k=10)
print([ENTITY_2_LABEL[enum_entity_ids[i]] for i in idx[0]])
after = time()
after - before

['film award', 'film festival edition', 'comedy troupe', 'filmography', 'district', 'Wikimedia disambiguation page', 'association football club', 'ballet company', 'patent', 'theatre company (organization that produces theatrical performances)']


0.02289271354675293

In [138]:
before = time()
subset_ids = list(np.random.randint(0, len(entity_names), 1000))
dist, idx = search_within_subset(query=embed_entity_batch(['company']).detach().cpu().numpy(), subset_ids=subset_ids, faiss_id_index=entity_index, k=10)
print([ENTITY_2_LABEL[enum_entity_ids[i]] for i in idx[0]])
after = time()
after - before

['film production company', 'production company', 'company', 'video game developer', 'theatre company (organization that produces theatrical performances)', 'factory', 'video game developer', 'hotel', 'film festival', 'record label']


0.023937702178955078

In [75]:
entity_index.index.reconstruct(3570)

array([-7.14140013e-02, -5.10838442e-03,  5.03451675e-02, -8.86862427e-02,
       -8.38634968e-02,  1.55223720e-02,  4.14387509e-03, -2.00053435e-02,
       -2.76929624e-02, -5.16494438e-02, -7.80235752e-02,  8.93565826e-04,
       -2.71469206e-02, -7.41515309e-02, -4.63954434e-02,  1.92716923e-02,
       -2.74377130e-02, -7.01681301e-02, -1.58312060e-02, -1.70979612e-02,
        2.76822764e-02, -4.09433916e-02, -2.68200226e-02, -3.30740474e-02,
       -2.18856577e-02,  1.63533725e-02,  9.61226318e-03,  1.76195092e-02,
        5.26813976e-03, -1.90336406e-02,  1.22226588e-03, -3.83749045e-02,
        1.90886948e-03, -5.19278198e-02, -1.75596531e-02, -8.97471309e-02,
       -1.51305683e-02,  1.48628335e-02, -2.77183373e-02,  2.21263058e-03,
       -2.63622049e-02, -2.72796657e-02,  2.69438680e-02, -9.15787965e-02,
       -4.60187159e-02, -8.35824311e-02, -1.13960085e-02, -2.52104215e-02,
       -2.95539107e-02,  5.04589006e-02,  2.73709763e-02, -6.07187580e-03,
       -1.78940762e-02, -