In [1]:
import time
import numpy as np
from tqdm import tqdm

import json
import pandas as pd

from transformers import AutoTokenizer, AutoModel

from utils.eval_utils import micro_precision, micro_recall, f1_score
from utils.openai_utils import LLMTripletExtractor
from utils.verifier_utils import TripletFilter
from utils.structured_dynamic_index_utils_with_db import Aligner

import warnings
import os
import ast

warnings.filterwarnings('ignore')

import re
from unidecode import unidecode

from pymongo.mongo_client import MongoClient
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def compute_f1(prediction, truth):
    pred_tokens = prediction.split()
    truth_tokens = truth.split()
    
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)
    
    common_tokens = set(pred_tokens) & set(truth_tokens)
    
    # if there are no common tokens then f1 = 0
    if len(common_tokens) == 0:
        return 0
    
    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(truth_tokens)
    
    return 2 * (prec * rec) / (prec + rec)

In [3]:
import string
def normalize(input_string):
    input_string = unidecode(input_string)
    input_string = input_string.lower()
        # Replace all punctuation with a space
    input_string = re.sub(f"[{re.escape(string.punctuation)}]", " ", input_string)
    
    # Replace multiple spaces with a single space
    input_string = re.sub(r"\s+", " ", input_string)
    
    # Trim leading/trailing whitespace
    return input_string.strip()

In [4]:
def get_mongo_client(mongo_uri):
    client = MongoClient(mongo_uri)
    return client

mongo_client = get_mongo_client("mongodb://localhost:27018/?directConnection=true")
db = mongo_client.get_database("wikidata_ontology")
db.list_collection_names()

['properties',
 'property_aliases',
 'entity_aliases',
 'triplets',
 'filtered_triplets',
 'entity_type_aliases',
 'entity_types']

In [16]:
db.get_collection('triplets').delete_many({})

DeleteResult({'n': 15532, 'electionId': ObjectId('7fffffff0000000000000008'), 'opTime': {'ts': Timestamp(1751401842, 1554), 't': 8}, 'ok': 1.0, '$clusterTime': {'clusterTime': Timestamp(1751401842, 1554), 'signature': {'hash': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', 'keyId': 0}}, 'operationTime': Timestamp(1751401842, 1554)}, acknowledged=True)

In [17]:
list(db.get_collection('triplets').find({}, {"_id": 0, "sample_id": 1}))

[]

In [20]:
list(db.get_collection('triplets').find({"sample_id": 'a4ddb2ba-eb7e-4268-92ca-fec0036c35cc'}))

[{'_id': ObjectId('686446949362efa843c65a26'),
  'subject': 'Michael Joseph Jackson',
  'object': 'August 29, 1958',
  'relation': 'date of birth',
  'object_type': 'point in time',
  'subject_type': 'human',
  'completion_token_num': 0,
  'prompt_token_nums': 0,
  'qualifiers': [],
  'sample_id': 'a4ddb2ba-eb7e-4268-92ca-fec0036c35cc',
  'source_text_id': 0},
 {'_id': ObjectId('686446969362efa843c65a27'),
  'subject': 'Michael Joseph Jackson',
  'subject_type': 'human',
  'object_type': 'point in time',
  'relation': 'date of death',
  'object': 'June 25, 2009',
  'completion_token_num': 0,
  'prompt_token_nums': 0,
  'qualifiers': [],
  'sample_id': 'a4ddb2ba-eb7e-4268-92ca-fec0036c35cc',
  'source_text_id': 0},
 {'_id': ObjectId('686446979362efa843c65a28'),
  'subject': 'Michael Joseph Jackson',
  'object': 'singer',
  'relation': 'occupation',
  'object_type': 'profession',
  'subject_type': 'human',
  'completion_token_num': 0,
  'prompt_token_nums': 0,
  'qualifiers': [],
  'samp

In [21]:
list(db.get_collection('entity_aliases').find({"sample_id": '0'}, {"_id": 0, "label": 1, "entity_type": 1, "alias": 1, "sample_id": 1}))

[]

In [8]:
list(db.get_collection('triplets').find({"sample_id": '0'}))

[{'_id': ObjectId('685468cfdc5ef5b1d9abb334'),
  'subject': '"Billie Jean"',
  'relation': 'instance of',
  'object': 'song',
  'subject_type': 'creative work',
  'object_type': 'musical work',
  'qualifiers': [],
  'sample_id': '0'},
 {'_id': ObjectId('685468d7dc5ef5b1d9abb338'),
  'subject': '"Billie Jean"',
  'relation': 'instance of',
  'object': 'song',
  'subject_type': 'creative work',
  'object_type': 'musical work',
  'qualifiers': [],
  'sample_id': '0'},
 {'_id': ObjectId('685468dedc5ef5b1d9abb33a'),
  'subject': '"Billie Jean"',
  'relation': 'instance of',
  'object': 'song',
  'subject_type': 'creative work',
  'object_type': 'musical work',
  'qualifiers': [],
  'sample_id': '0'},
 {'_id': ObjectId('685469f5004d743558e8cc0e'),
  'subject': '"Billie Jean"',
  'relation': 'instance of',
  'object': 'song',
  'subject_type': 'creative work',
  'object_type': 'musical work',
  'qualifiers': [],
  'source_text_id': 0,
  'prompt_token_nums': 0,
  'completion_token_num': 0,
  '

In [5]:
from utils.structured_dynamic_index_utils_with_db import Aligner
aligner = Aligner(db)
model_name = 'gpt-4.1-mini'
# model_name = 'gpt-4.1'
extractor = LLMTripletExtractor(model=model_name)

with open("musique_200_test.json", "r") as f:
    ds = json.load(f)


In [6]:
aligner.retrieve_entity_type_hirerarchy("human")

['Q5',
 'Q103940464',
 'Q106559804',
 'Q10855152',
 'Q154954',
 'Q16889133',
 'Q215627',
 'Q21871294',
 'Q223557',
 'Q24229398',
 'Q26401003',
 'Q27043950',
 'Q35120',
 'Q3778211',
 'Q4406616',
 'Q488383',
 'Q5',
 'Q55983715',
 'Q66394244',
 'Q7048977',
 'Q7239',
 'Q729',
 'Q795052',
 'Q98119401',
 'Q99527517']

In [7]:
list(db.get_collection('entity_aliases').find({"sample_id": "2hop__121145_561444"}, {"_id": 0, "label": 1, "entity_type": 1, "sample_id": 1}))

[{'label': 'reality television series',
  'entity_type': 'genre',
  'sample_id': '2hop__121145_561444'},
 {'label': 'The Real L Word',
  'entity_type': 'television series',
  'sample_id': '2hop__121145_561444'},
 {'label': 'United States',
  'entity_type': 'country',
  'sample_id': '2hop__121145_561444'},
 {'label': 'Showtime',
  'entity_type': 'television network',
  'sample_id': '2hop__121145_561444'},
 {'label': 'June 20, 2010',
  'entity_type': 'point in time',
  'sample_id': '2hop__121145_561444'},
 {'label': 'Ilene Chaiken',
  'entity_type': 'human',
  'sample_id': '2hop__121145_561444'},
 {'label': 'Magical Elves Productions',
  'entity_type': 'production company',
  'sample_id': '2hop__121145_561444'},
 {'label': 'group of lesbians',
  'entity_type': 'group of humans',
  'sample_id': '2hop__121145_561444'},
 {'label': 'Los Angeles',
  'entity_type': 'city',
  'sample_id': '2hop__121145_561444'},
 {'label': 'Brooklyn',
  'entity_type': 'district',
  'sample_id': '2hop__121145_56

In [8]:
# for idx in db.get_collection('entity_aliases').list_indexes():
#     print(idx)

In [9]:
aligner.retrive_similar_entity_names("John", k=5, sample_id="2hop__135993_160249")

[{'entity': 'John Lennon', 'entity_type': 'person'},
 {'entity': 'He', 'entity_type': 'human'},
 {'entity': 'the university', 'entity_type': 'educational institution'},
 {'entity': '1982', 'entity_type': 'point in time'},
 {'entity': 'William Eliot', 'entity_type': 'human'}]

In [10]:
aligner.retrieve_entity_by_type("Derech Mitzvosecha", sample_id="2hop__121145_561444", entity_type="human")

{'Rabbi Menachem Mendel Schneersohn': 'Rabbi Menachem Mendel Schneersohn',
 'Rabbi Dovber Schneuri': 'Rabbi Dovber Schneuri',
 'Rabbi Mordecai M. Kaplan': 'Rabbi Mordecai M. Kaplan',
 'Chaya Mushka Schneersohn': 'Chaya Mushka Schneersohn',
 'Michael Cera': 'Clark Duke',
 'Jack Cohen': 'Jack Cohen',
 'Ilene Chaiken': 'Ilene Chaiken',
 'Satyajit Ray': 'Satyajit Ray',
 'Clark Duke': 'Clark Duke',
 'Jean Isidore Harispe': 'Jean Isidore Harispe'}

In [11]:
from pymongo.operations import SearchIndexModel

vector_search_index_model = SearchIndexModel(
    definition={
        "mappings": {
            "dynamic": True,
            "fields": {
                "alias_text_embedding": {
                    "dimensions": 768,
                    "similarity": "cosine",
                    "type": "knnVector",
                },
                "entity_type": {
                    "type": "token"
                },
                "sample_id": {
                    "type": "token"
                }
            },
        }
    },
    name="entities",
)

db.get_collection('entity_aliases').create_search_index(model=vector_search_index_model)

'entities'

In [12]:
len(set([elem['sample_id'] for elem in list(db.get_collection('entity_aliases').find({}, {"_id": 0, "sample_id": 1}))]))

50

In [13]:
with open("musique_200_test.json", "r") as f:
    ds = json.load(f)
    
ds = ds['data'][:50]

id2sample = {}
for elem in ds:
    id2sample[elem['id']] = elem

In [14]:
lens = []
for sample_id in id2sample.keys():
    lens.append(len(list(db.get_collection('triplets').find({"sample_id": sample_id}))))
assert all([item > 0 for item in lens])
# lens

In [23]:
sample_id2ans = {}
for sample_id in id2sample.keys():
    
    lens.append(len(list(db.get_collection('triplets').find({"sample_id": sample_id}))))
    question = id2sample[sample_id]['question']
    entities = extractor.extract_entities_from_question(question)
    identified_entities = []
    chosen_entities = []
    print(entities)
    if isinstance(entities, dict):
        entities = [entities]

    for ent in entities:
        similar_entities = aligner.retrive_similar_entity_names(entity_name=ent, k=5, sample_id=sample_id)
        exact_entity_match = [e for e in similar_entities if e['entity']==ent]
        if len(exact_entity_match) > 0:
            chosen_entities.extend(exact_entity_match)
        else:
            identified_entities.extend(similar_entities)
            
    chosen_entities.extend(extractor.identify_relevant_entities(question=question, entity_list=identified_entities))
    print("Chosen relevant entities: ", chosen_entities)
    entity_set = {(e['entity'], e['entity_type']) for e in chosen_entities}

    entities4search = list(entity_set)
    or_conditions = []

    for val, typ in entities4search:
        or_conditions.append({
            '$and': [
                {'subject': val},
                {'subject_type': typ}
            ]
        })
        or_conditions.append({
            '$and': [
                {'object': val},
                {'object_type': typ}
            ]
        })

    pipeline = [
        {
            '$match': {
                'sample_id': sample_id,
                '$or': or_conditions
            }
        }
    ]
    entities4search = [ent[0] for ent in entity_set]

    for i in range(5):
        
        or_conditions = []

        for ent in entities4search:
            or_conditions.append({
                '$and': [
                    {'subject': ent},
                ]
            })
            or_conditions.append({
                '$and': [
                    {'object': ent},
                ]
            })

        pipeline = [
            {
                '$match': {
                    'sample_id': sample_id,
                    '$or': or_conditions
                }
            }
        ]

        results = list(db.get_collection('triplets').aggregate(pipeline))

        for doc in results:
            entities4search.append(doc['subject'])
            entities4search.append(doc['object'])

            for q in doc['qualifiers']:
                entities4search.append(q['object'])

        entities4search = list(set(entities4search))
                        
    print(len(results))
    supporting_triplets = []
    for item in results:
        supporting_triplets.append({"subject": item['subject'], 'relation': item['relation'], 'object': item['object'], 'qualifiers': item['qualifiers']})
    

    ans = extractor.answer_question(question=question, triplets=supporting_triplets)
    print(question, ' | ', ans, " | ", id2sample[sample_id]['answer'])
    sample_id2ans[sample_id] = ans

['Derech Mitzvosecha', 'creator of Derech Mitzvosecha']
Chosen relevant entities:  [{'entity': 'second Rebbe of the Chabad Hasidic movement', 'entity_type': 'position (social role with a set of powers and responsibilities within an organization)'}, {'entity': 'third Rebbe', 'entity_type': 'position (social role with a set of powers and responsibilities within an organization)'}]
19
Who did the creator of Derech Mitzvosecha follow?  |  third Rebbe  |  Dovber Schneuri
['NBA scoring title', 'team', 'winner']
Chosen relevant entities:  [{'entity': 'NBA scoring title', 'entity_type': 'award'}, {'entity': 'NBA scoring title', 'entity_type': 'award'}, {'entity': 'Houston Rockets', 'entity_type': 'sports team'}]
263
What team drafted the winner of the NBA scoring title this year?  |  Golden State Warriors  |  Oklahoma City Thunder
['The Beach', 'Pao Sarasin']
Chosen relevant entities:  [{'entity': 'Pao Sarasin', 'entity_type': 'human'}, {'entity': 'Pao Sarasin', 'entity_type': 'person'}]
16
Wh

In [24]:
sample_id2ans

{'2hop__121145_561444': 'third Rebbe',
 '2hop__86689_728109': 'Golden State Warriors',
 '3hop1__462960_160545_62931': '""',
 '3hop1__68732_39743_24526': 'often below freezing point',
 '2hop__364489_861485': 'Minnesota History Center',
 '2hop__835710_7298': 'Josh Groban',
 '2hop__96062_159673': 'Deepwater Horizon',
 '3hop2__79512_16214_84681': 'early 1700s',
 '3hop1__831499_228453_10972': '""',
 '3hop2__230_89048_66294': 'Lawrence Hilton',
 '2hop__96414_47902': 'Damon',
 '3hop1__462960_160545_34754': 'occupied for over 30 years',
 '2hop__142699_67465': 'March 11, 2011',
 '2hop__115515_779396': 'University of Glasgow',
 '4hop1__166471_49925_13759_736921': '""',
 '2hop__622308_61845': 'Mido',
 '3hop1__773338_42197_18397': 'Politburo',
 '3hop2__87184_38738_76291': 'January 2015',
 '4hop2__161602_474028_88460_18966': 'Hokkien speakers',
 '2hop__6870_16335': 'intermarriage between Ashkenazi and non-Ashkenazi; many do not see historic markers as relevant to their life experiences as Jews',
 '

In [25]:
f1s = []
ems = []
for sample_id in id2sample.keys():
        question = id2sample[sample_id]['question']
        ans = sample_id2ans[sample_id]
        max_f1 = 0
        max_em = 0
        max_f1_entity = ''
        max_em_entity = ''
        gold_answers_variations = [id2sample[sample_id]['answer']]
        gold_answers_variations.extend(id2sample[sample_id]['answer_aliases'])
        for golden_answer in gold_answers_variations:
            golden_answer = normalize(golden_answer)
            ans = normalize(ans)
            f1 = compute_f1(prediction=ans, truth=golden_answer)
            em = golden_answer == ans


            if f1 > max_f1:
                max_f1_entity = golden_answer
            if em > max_em:
                max_em_entity = golden_answer
            max_f1 = max(max_f1, f1)
            max_em = max(max_em, em)
        
        print(sample_id, " | ", ans, " | ", max_em_entity, " | ", max_f1_entity)
        f1s.append(max_f1)
        ems.append(max_em)

sum(ems) / len(ems), sum(f1s)/len(f1s)

2hop__121145_561444  |  third rebbe  |    |  
2hop__86689_728109  |  golden state warriors  |    |  
3hop1__462960_160545_62931  |    |    |  
3hop1__68732_39743_24526  |  often below freezing point  |    |  
2hop__364489_861485  |  minnesota history center  |  minnesota history center  |  minnesota history center
2hop__835710_7298  |  josh groban  |    |  
2hop__96062_159673  |  deepwater horizon  |    |  
3hop2__79512_16214_84681  |  early 1700s  |    |  
3hop1__831499_228453_10972  |    |    |  
3hop2__230_89048_66294  |  lawrence hilton  |    |  lawrence hilton jacobs
2hop__96414_47902  |  damon  |    |  matt damon
3hop1__462960_160545_34754  |  occupied for over 30 years  |    |  
2hop__142699_67465  |  march 11 2011  |  march 11 2011  |  march 11 2011
2hop__115515_779396  |  university of glasgow  |  university of glasgow  |  university of glasgow
4hop1__166471_49925_13759_736921  |    |    |  
2hop__622308_61845  |  mido  |  mido  |  mido
3hop1__773338_42197_18397  |  politburo 

(0.3, 0.3802600195503421)

In [21]:
0.3, 0.39260606060606057

(0.3, 0.39260606060606057)

In [22]:
(0.3, 0.3816236559139785)

(0.3, 0.3816236559139785)

In [28]:
np.std([0.39260606060606057, 0.3802600195503421, 0.3816236559139785])

0.005526677188393516

In [14]:
with open("hotpotqa200.json", "r") as f:
    ds = json.load(f)

ds = ds[:50]
id2sample = {}
for elem in ds:
    id2sample[elem['_id']] = elem

sampled_ids = list(id2sample.keys())[:50]

In [21]:
sample_id2ans = {}
lens = []
for sample_id in id2sample.keys():
    
    lens.append(len(list(db.get_collection('triplets').find({"sample_id": sample_id}))))
    question = id2sample[sample_id]['question']
    entities = extractor.extract_entities_from_question(question)
    identified_entities = []
    chosen_entities = []
    print(entities)
    if isinstance(entities, dict):
        entities = [entities]

    for ent in entities:
        similar_entities = aligner.retrive_similar_entity_names(entity_name=ent, k=5, sample_id=sample_id)
        exact_entity_match = [e for e in similar_entities if e['entity']==ent]
        if len(exact_entity_match) > 0:
            chosen_entities.extend(exact_entity_match)
        else:
            identified_entities.extend(similar_entities)
            
    chosen_entities.extend(extractor.identify_relevant_entities(question=question, entity_list=identified_entities))
    print("Chosen relevant entities: ", chosen_entities)
    entity_set = {(e['entity'], e['entity_type']) for e in chosen_entities}

    entities4search = list(entity_set)
    or_conditions = []

    for val, typ in entities4search:
        or_conditions.append({
            '$and': [
                {'subject': val},
                {'subject_type': typ}
            ]
        })
        or_conditions.append({
            '$and': [
                {'object': val},
                {'object_type': typ}
            ]
        })

    pipeline = [
        {
            '$match': {
                'sample_id': sample_id,
                '$or': or_conditions
            }
        }
    ]
    entities4search = [ent[0] for ent in entity_set]

    for i in range(5):
        
        or_conditions = []

        for ent in entities4search:
            or_conditions.append({
                '$and': [
                    {'subject': ent},
                ]
            })
            or_conditions.append({
                '$and': [
                    {'object': ent},
                ]
            })

        pipeline = [
            {
                '$match': {
                    'sample_id': sample_id,
                    '$or': or_conditions
                }
            }
        ]

        results = list(db.get_collection('triplets').aggregate(pipeline))

        for doc in results:
            entities4search.append(doc['subject'])
            entities4search.append(doc['object'])

            for q in doc['qualifiers']:
                entities4search.append(q['object'])

        entities4search = list(set(entities4search))
                        
    print(len(results))
    supporting_triplets = []
    for item in results:
        supporting_triplets.append({"subject": item['subject'], 'relation': item['relation'], 'object': item['object'], 'qualifiers': item['qualifiers']})
    

    ans = extractor.answer_question(question=question, triplets=supporting_triplets)
    print(question, ' | ', ans, " | ", id2sample[sample_id]['answer'])
    sample_id2ans[sample_id] = ans

['VIVA Media AG', '2004', 'acronym']
Chosen relevant entities:  [{'entity': 'VIVA Media AG', 'entity_type': 'organization'}, {'entity': 'VIVA Media AG', 'entity_type': 'organization'}, {'entity': 'VIVA Media AG', 'entity_type': 'organization'}, {'entity': 'VIVA Media AG', 'entity_type': 'organization'}]
56
VIVA Media AG changed it's name in 2004. What does their new acronym stand for?  |  Viacom International Media Networks Europe  |  Gesellschaft mit beschränkter Haftung
['Jonny Craig', 'Pete Doherty', 'bands']
Chosen relevant entities:  [{'entity': 'Jonny Craig', 'entity_type': 'person'}, {'entity': 'Jonny Craig', 'entity_type': 'human'}, {'entity': 'Pete Doherty', 'entity_type': 'group of humans'}, {'entity': 'Jonny Craig', 'entity_type': 'musician'}, {'entity': 'Pete Doherty', 'entity_type': 'musician'}]
106
Which of Jonny Craig and Pete Doherty has been a member of more bands ?  |  Jonny Craig  |  Jonny" Craig
['The Missouri Compromise', 'governor']
Chosen relevant entities:  [{'e

In [22]:
f1s = []
ems = []
for sample_id in id2sample.keys():
        question = id2sample[sample_id]['question']
        ans = sample_id2ans[sample_id]
        max_f1 = 0
        max_em = 0
        max_f1_entity = ''
        max_em_entity = ''
        gold_answers_variations = [id2sample[sample_id]['answer']]
        # gold_answers_variations.extend(id2sample[sample_id]['answer_aliases'])
        for golden_answer in gold_answers_variations:
            golden_answer = normalize(golden_answer)
            ans = normalize(ans)
            f1 = compute_f1(prediction=ans, truth=golden_answer)
            em = golden_answer == ans


            if f1 > max_f1:
                max_f1_entity = golden_answer
            if em > max_em:
                max_em_entity = golden_answer
            max_f1 = max(max_f1, f1)
            max_em = max(max_em, em)
        
        print(sample_id, " | ", ans, " | ", max_em_entity, " | ", max_f1_entity)
        f1s.append(max_f1)
        ems.append(max_em)

sum(ems) / len(ems), sum(f1s)/len(f1s)

5a7613c15542994ccc9186bf  |  viacom international media networks europe  |    |  
5adf2fa35542993344016c11  |  jonny craig  |  jonny craig  |  jonny craig
5adfdef9554299025d62a36b  |  bath maine  |  bath maine  |  bath maine
5a7180205542994082a3e856  |  creature comforts  |  creature comforts  |  creature comforts
5a78bc6b554299148911f979  |  lifestyle magazine magazines focused on women interest  |    |  fortnightly women interest magazine
5abdd0f15542991f6610604d  |  failed coup attempt  |    |  a failed coup attempt
5a8e27d45542995a26add46a  |  2009  |    |  
5a881d2355429938390d3eeb  |  love and theft  |  love and theft  |  love and theft
5ae4a1ef55429970de88d9e7  |  2 march 1972  |  2 march 1972  |  2 march 1972
5ae6f2a7554299572ea5464a  |  romeo montague  |    |  
5a845d735542996488c2e52e  |  virginia  |  virginia  |  virginia
5adf5daf5542995534e8c79d  |  no only darren benjamin shepherd is american remi lange is french  |    |  no
5a7c49dc55429935c91b514f  |  science  |    |  
5

(0.48, 0.6642524493050809)

In [23]:
(0.48, 0.6509191159717475)
(0.48, 0.6520100250626567)
(0.48, 0.6642524493050809)

(0.48, 0.6520100250626567)

In [24]:
np.mean([0.6509191159717475, 0.6520100250626567, 0.6642524493050809])

0.6557271967798283

In [25]:
np.std([0.6509191159717475, 0.6520100250626567, 0.6642524493050809])

0.006044692913382812