# Doc2Query

In [1]:
#from Doc2Query.Doc2Query import Doc2Query
from Doc2Query import Doc2Query
import pyterrier as pt
from tira.third_party_integrations import ensure_pyterrier_is_loaded, persist_and_normalize_run
from tira.rest_api_client import Client

In [2]:
# Make sure PyTerrier is loaded before any `pt.*` call in the cells below.
ensure_pyterrier_is_loaded()
# Create a REST client to the TIRA platform for retrieving the pre-indexed data.
# NOTE(review): `tira` is not referenced by any later cell visible in this
# notebook -- confirm whether the client is still needed.
tira = Client()

PyTerrier 0.10.0 has loaded Terrier 5.8 (built by craigm on 2023-11-01 18:05) and terrier-helper 0.0.8

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


In [3]:
# Identifier of the collection; the `irds:` prefix makes PyTerrier resolve it
# through its ir_datasets integration (the cell output shows an IRDSDataset).
DATASET_ID = 'irds:ir-lab-sose-2024/ir-acl-anthology-20240504-training'

# Dataset handle that every later cell reads documents, topics and qrels from.
pt_dataset = pt.get_dataset(DATASET_ID)

# Display the concrete dataset class as a quick sanity check.
type(pt_dataset)

pyterrier.datasets.IRDSDataset

In [5]:
# Optional, disabled by default: interactive Hugging Face Hub login.
# NOTE(review): presumably only needed when the chosen model requires
# authentication -- confirm, and delete this cell if it is never used.
#from huggingface_hub import notebook_login
#notebook_login()

## Expand the Documents

In [6]:
# Build the Doc2Query expander used by the following cells.
# When working on Licca, replace the model identifier with the local model path.
# NOTE(review): temperature=0.7 suggests sampled generation and no random seed
# is set anywhere in this notebook, so generated queries may differ between
# runs -- confirm.
# `promting_technique` is spelled exactly as this notebook passes it;
# presumably it matches the Doc2Query signature, so it is left untouched.
Doc2Query_object = Doc2Query(
    "google/flan-t5-small",
    temperature=0.7,
    promting_technique='One-Shot',
)
# Alternative configuration, kept for reference:
#Doc2Query_object = Doc2Query("gpt2", temperature=0.7, promting_technique='Few-Shot')

In [7]:
# Pull the text documents out of the PyTerrier dataset.
# The method name indicates a DataFrame is returned; the implementation lives
# in the project's Doc2Query module and is not visible here.
documents = Doc2Query_object.getDocumentsDfFromPtDataset(pt_dataset)

ir-lab-sose-2024/ir-acl-anthology-20240504-training documents: 100%|██████████| 126958/126958 [00:02<00:00, 48786.16it/s]


In [8]:
# Generate queries for the dataset and extend the documents by the queries.
# NOTE(review): presumably this runs the language model over every document
# loaded above, so it is likely the long-running step of the notebook and its
# result is not cached -- confirm, and consider persisting the output.
expanded_documents = Doc2Query_object.expandDocumentsByQueries(documents)

In [None]:
# Spot-check whether the generated queries make sense.
# Only the first few rows are shown: the loader output above reports 126,958
# documents, and dumping every text into the cell output would bloat the
# notebook and bury the rest of the narrative.
expanded_documents[['text']].head(5).to_dict()

In [None]:
# Create the index using PyTerrier.
# overwrite=True replaces any index already stored at this path on each run.
indexer = pt.IterDictIndexer(
    "./indexes/index_Doc2Query-flan-t5-small-oneShot-BM25",
    overwrite=True,
    fields=["text"],
    meta=["docno"]
)

# Index the documents: the indexer receives one dict per document.
# The list gets its own name so that `expanded_documents` stays a DataFrame;
# rebinding it made this cell (and the inspection cell above) fail with an
# AttributeError when re-run.
expanded_records = expanded_documents.to_dict(orient='records')
indexref = indexer.index(expanded_records)

# Retrieve documents using BM25
bm25 = pt.BatchRetrieve(indexref, wmodel="BM25")

# Perform retrieval for the dataset topics, requesting the 'text' variant.
run = bm25(pt_dataset.get_topics('text'))

# Evaluate the results against the relevance judgments.
# The result is not called `eval`, which would shadow the Python builtin for
# the rest of the kernel session.
qrels_df = pt_dataset.get_qrels()
eval_results = pt.Evaluate(run, qrels_df, metrics=["map", "ndcg", "ndcg_cut.10", "recip_rank", "recall_100"])
print("Evaluation Metrics:")
print(eval_results)

In [None]:
# Filter run to include only judged documents.
# "Judged" is a per-query property: a document judged for one query may be
# unjudged for another. Filtering on docno alone would keep such unjudged
# (query, document) pairs, so the run is matched on (qid, docno) instead.
qrels_df = pt_dataset.get_qrels()
judged_pairs = qrels_df[['qid', 'docno']].drop_duplicates()
filtered_run = run.merge(judged_pairs, on=['qid', 'docno'], how='inner')

# Evaluate the results on the judged-only run.
# The result is not called `eval`, which would shadow the Python builtin.
eval_results = pt.Evaluate(filtered_run, qrels_df, metrics=["map", "ndcg", "ndcg_cut.10", "recip_rank", "recall_100"])
print("Evaluation Metrics:")
print(eval_results)

In [None]:
# Store the run file so later evaluations can pick it up.
# The full `run` is passed here, not the judged-only `filtered_run`.
persist_and_normalize_run(
    run,
    system_name='Doc2Query-flan-t5-small-oneShot-BM25',
    default_output='../runs',
)