In [62]:
!git clone https://github.com/usnistgov/trec_eval.git && cd trec_eval && make

fatal: la ruta de destino 'trec_eval' ya existe y no es un directorio vacío.


In [63]:
import json
import os
import warnings
from re import sub
from typing import Dict

import numpy as np
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from rank_bm25 import BM25Okapi
from tqdm import tqdm

In [64]:
# Shared NLP resources used by clean_text(): an English Snowball stemmer and the
# NLTK English stop-word list, held in a set for fast membership checks.
stemmer = SnowballStemmer("english")
stop_words = {word for word in stopwords.words("english")}

def clean_text(text):
    """Normalize a passage or query for BM25 indexing and querying.

    Lower-cases ``text`` and replaces every character outside ``[a-z0-9]`` with
    a space. It then drops words found in the module-level ``stop_words`` set
    and stems the remaining words with the module-level ``stemmer``.

    Args:
        text: Raw passage or query string.

    Returns:
        The processed words joined by single spaces, with no leading or trailing
        whitespace. Returns an empty string if no word survives the filtering.
    """
    # A single pass covers newlines, tabs and invisible marks such as \u200e,
    # because none of them is in [a-z0-9].
    # NOTE: this also discards non-ASCII letters, e.g. "café" -> "caf".
    cln_text = sub(r'[^a-z0-9]', ' ', text.lower())

    # str.split() with no argument collapses whitespace runs and trims both ends,
    # so the join already produces single-spaced, stripped output.
    return ' '.join(
        stemmer.stem(word) for word in cln_text.split() if word not in stop_words
    )

def tokenizer(text):
    """Split ``text`` on whitespace and return its unigrams followed by its bigrams.

    Bigrams are adjacent word pairs joined by a single space, e.g.
    ``"a b c"`` -> ``["a", "b", "c", "a b", "b c"]``.
    """
    words = text.split()
    word_pairs = [" ".join(pair) for pair in zip(words, words[1:])]
    return words + word_pairs

def load_qrels(docs_dir: str, fqrels: str, ndocs: int = 40) -> Dict[str, Dict[str, int]]:
    """Build binary relevance judgments keyed by question and passage ``ID``.

    Questions in ``fqrels`` reference passages by ``(DocumentID, PassageID)``.
    The document files are read first to translate each such pair into the
    passage's ``ID`` field.

    Args:
        docs_dir: Directory holding ``1.json`` .. ``{ndocs}.json``. Each file is
            a JSON list of passages with ``ID``, ``DocumentID`` and ``PassageID``
            keys.
        fqrels: JSON file containing a list of questions. Each question has a
            ``QuestionID`` and a ``Passages`` list of
            ``{"DocumentID", "PassageID"}`` references.
        ndocs: Number of numbered document files to read (default 40).

    Returns:
        ``{QuestionID: {passage ID: 1}}``. Questions with an empty ``Passages``
        list are omitted.

    Raises:
        KeyError: if a question references a (DocumentID, PassageID) pair that
            is not present in any document file.

    Warns:
        UserWarning: if the same (DocumentID, PassageID) pair appears more than
            once. The first passage ``ID`` seen is kept, which makes judgments
            for that pair ambiguous.
    """
    # DocumentID -> PassageID -> passage ID (first occurrence wins).
    did2pid2id: Dict[str, Dict[str, str]] = {}
    duplicates = []
    for i in range(1, ndocs + 1):
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(os.path.join(docs_dir, f"{i}.json"), encoding="utf-8") as f:
            doc = json.load(f)
        for psg in doc:
            pid2id = did2pid2id.setdefault(psg["DocumentID"], {})
            # The old check compared psg["ID"] against PassageID keys and so
            # could never detect a repeated (DocumentID, PassageID) pair.
            if psg["PassageID"] in pid2id:
                duplicates.append((psg["DocumentID"], psg["PassageID"]))
                continue
            pid2id[psg["PassageID"]] = psg["ID"]

    if duplicates:
        warnings.warn(
            f"{len(duplicates)} duplicate (DocumentID, PassageID) pair(s) in "
            f"{docs_dir}; keeping the first passage ID for each, e.g. {duplicates[:3]}"
        )

    with open(fqrels, encoding="utf-8") as f:
        questions = json.load(f)

    qrels: Dict[str, Dict[str, int]] = {}
    for question in questions:
        for psg in question["Passages"]:
            pid = did2pid2id[psg["DocumentID"]][psg["PassageID"]]
            qrels.setdefault(question["QuestionID"], {})[pid] = 1
    return qrels

In [65]:
# Write the relevance judgments in TREC qrels format ("qid iter docno rel"),
# one judged passage per line; trec_eval reads this file as ./qrels.
qrels = load_qrels("ObliQADataset/StructuredRegulatoryDocuments", "ObliQADataset/ObliQA_test.json")
with open("qrels", "w") as qrels_file:
    qrels_file.writelines(
        f"{qid} Q0 {pid} {rel}\n"
        for qid, rels in qrels.items()
        for pid, rel in rels.items()
    )

In [66]:
# Flatten every passage of the ndocs document files into one list of dicts.
# "text" (the PassageID prepended to the passage body) is what gets indexed;
# "ID" is the identifier later written to the run file.
ndocs = 40
collection = []
for i in range(1, ndocs + 1):
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(os.path.join("ObliQADataset/StructuredRegulatoryDocuments", f"{i}.json"), encoding="utf-8") as f:
        doc = json.load(f)
    for psg in doc:
        collection.append(
            dict(
                text=psg["PassageID"] + " " + psg["Passage"],
                ID=psg["ID"],
                DocumentId=psg["DocumentID"],  # key was misspelled "DcoumentId"
                PassageId=psg["PassageID"],
            )
        )

In [67]:
# Index every passage: normalize it with clean_text, then expand it into
# unigram + bigram terms so word pairs also become BM25 terms.
tokenized_corpus = [
    tokenizer(clean_text(passage["text"]))
    for passage in collection
]
bm25 = BM25Okapi(tokenized_corpus, b=0.75, k1=1.5)

# Object array of the passage dicts, indexed by the numpy index arrays that
# retrieval produces.
collection_array = np.array(collection)

len(tokenized_corpus)

13732

In [68]:
# Score each test question against the whole collection with BM25 and keep the
# TOP_K highest-scoring passages (with their scores) per QuestionID.
TOP_K = 10  # matches the @10 cutoffs passed to trec_eval

# Load the questions once; the file doesn't need to stay open during scoring.
with open("ObliQADataset/ObliQA_test.json", encoding="utf-8") as f:
    test_questions = json.load(f)

retrieved = {}
for question in tqdm(test_questions):
    tokenized_query = tokenizer(clean_text(question['Question']))
    doc_scores = bm25.get_scores(tokenized_query)

    # argpartition requires kth < len(doc_scores). Clamp it for corpora with
    # <= TOP_K passages; larger corpora use exactly the previous kth.
    kth = min(TOP_K, len(doc_scores) - 1)
    # Indices of the TOP_K highest scores, in no particular order (avoids a full sort)...
    top_indices = np.argpartition(-doc_scores, kth)[:TOP_K]
    # ...then sorted by descending score.
    top_indices = top_indices[np.argsort(-doc_scores[top_indices])]

    # Copy each passage dict and attach its BM25 score.
    retrieved[question["QuestionID"]] = [
        {**collection_array[i], 'score': doc_scores[i]} for i in top_indices
    ]

100%|██████████| 2786/2786 [03:26<00:00, 13.51it/s]


In [69]:
# Write the run in TREC format, "qid iter docno rank score run_tag", with ranks
# starting at 1 in the order the hits were retrieved.
with open("rankings.trec", "w") as run_file:
    run_file.writelines(
        f"{qid} 0 {hit['ID']} {rank} {hit['score']} bm25\n"
        for qid, hits in retrieved.items()
        for rank, hit in enumerate(hits, start=1)
    )

In [71]:
# Evaluate the BM25 run (./rankings.trec) against the qrels: recall at cutoff 10
# (recall.10) and mean average precision truncated at rank 10 (map_cut.10).
!trec_eval/trec_eval -m recall.10 -m map_cut.10 ./qrels ./rankings.trec

recall_10             	all	0.7771
map_cut_10            	all	0.6338
