We provide a step-by-step guide for applying the ToTER framework (last updated: 24.02.21).

In [1]:
import os

# Set device/tokenizer environment variables before torch is imported
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # GPU index to use; adjust for your machine
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch

import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
from Utils import *
import numpy as np
from beir.datasets.data_loader import GenericDataLoader
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

import pathlib, os
from tqdm.auto import tqdm, trange

## Data load

In [3]:
dataset = "scidocs"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join("datasets")
data_path = util.download_and_unzip(url, out_dir)

corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
retriever = EvaluateRetrieval()

k_values=[10,50,100,500, 1000, 2500]

## Backbone retriever

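Contriever-MS has been trained on massive training labels from the general domain (MS-MARCO). We use it as the backbone retriever and score the entire corpus for each query with a dot product.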

In [4]:
from IPython.utils import io

with torch.no_grad():
    with io.capture_output() as captured:
        text_encoder = DRES(models.SentenceBERT("facebook/contriever-msmarco"), batch_size=32)
        retriever = EvaluateRetrieval(text_encoder, score_function="dot", k_values=k_values + [len(corpus)])
        CTR_full_results = retriever.retrieve(corpus, queries)
print_metric(retriever.evaluate(qrels, CTR_full_results, retriever.k_values))

NDCG@10: 0.16524 , NDCG@100: 0.23594
Recall@100: 0.37807 , Recall@500: 0.54943 , Recall@1000: 0.62155 , Recall@2500: 0.7239


## Training Phase: Class relevance learning using the taxonomy

#### 1. Load pre-computed PLM representations

We encode each text (queries and documents) and each topic class with a pre-trained language model (PLM). Using these PLM representations directly for retrieval gives unsatisfactory results, because they have not been fine-tuned for retrieval.

In [5]:
c_id_list, q_id_list = np.load('resource/test_doc_id_list.npy'), np.load('resource/test_q_id_list.npy')
c_emb_list, q_emb_list = np.load('resource/test_doc_emb.npy'), np.load('resource/test_q_emb.npy')

In [6]:
score_mat = np.matmul(q_emb_list, c_emb_list.T)
PLM_results = eval_full_score_mat(score_mat, q_id_list, c_id_list)
print_metric(retriever.evaluate(qrels, PLM_results, k_values))

NDCG@10: 0.02657 , NDCG@100: 0.05212
Recall@100: 0.10853 , Recall@500: 0.24578 , Recall@1000: 0.33232 , Recall@2500: 0.46243
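
`eval_full_score_mat` is imported from the repository's `Utils` module, and its implementation is not shown in this guide. Assuming it turns a dense query-by-document score matrix into BEIR's nested `{query_id: {doc_id: score}}` results format, a minimal NumPy sketch of that conversion might look like this. `score_mat_to_results` and its `top_k` option are hypothetical.

```python
import numpy as np

def score_mat_to_results(score_mat, q_ids, doc_ids, top_k=None):
    """Hypothetical helper: convert a (num_queries x num_docs) score matrix into
    BEIR's {query_id: {doc_id: score}} format, optionally keeping only top_k docs."""
    results = {}
    for row, qid in enumerate(q_ids):
        scores = score_mat[row]
        if top_k is None or top_k >= len(scores):
            idx = np.argsort(-scores)  # full descending sort
        else:
            # argpartition finds the top_k without a full sort; then sort only those
            part = np.argpartition(-scores, top_k - 1)[:top_k]
            idx = part[np.argsort(-scores[part])]
        results[str(qid)] = {str(doc_ids[j]): float(scores[j]) for j in idx}
    return results

# Toy example: 2 queries, 3 documents, 4-dim embeddings
q_emb = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
d_emb = np.array([[0.9, 0.1, 0.0, 0.0], [0.2, 0.8, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
toy_scores = q_emb @ d_emb.T  # dot-product scores, as in the cell above
toy_results = score_mat_to_results(toy_scores, ["q1", "q2"], ["d1", "d2", "d3"], top_k=2)
```

Using `np.argpartition` avoids fully sorting every row when only the top documents per query are needed, which matters for a corpus of this size.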


#### 2. Conduct class relevance learning

This example uses a simple MLP classifier; more sophisticated classifiers may further improve the final results.

In [None]:
from Class_relevance_learning import train_classifier
train_classifier()
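
`train_classifier` lives in the `Class_relevance_learning` module, and its internals are not shown here. To illustrate what a simple MLP classifier for class relevance learning can look like, here is a minimal NumPy sketch. It assumes the task is framed as multi-label classification from a PLM text embedding to topic classes, with sigmoid outputs trained by binary cross-entropy. `TinyMLP`, the toy labels, and all hyperparameters are hypothetical; the notebook's own `Classifier` is a PyTorch module (it is restored with `load_state_dict`).

```python
import numpy as np

class TinyMLP:
    """Hypothetical one-hidden-layer MLP mapping a text embedding to
    independent per-class probabilities (multi-label topic classification)."""

    def __init__(self, in_dim, hidden_dim, num_class, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(0.0, np.sqrt(2.0 / in_dim), (in_dim, hidden_dim))
        self.b1 = np.zeros(hidden_dim)
        self.W2 = rng.normal(0.0, np.sqrt(2.0 / hidden_dim), (hidden_dim, num_class))
        self.b2 = np.zeros(num_class)

    def forward(self, x):
        self.h_pre = x @ self.W1 + self.b1
        self.h = np.maximum(self.h_pre, 0.0)          # ReLU hidden layer
        logits = self.h @ self.W2 + self.b2
        return 1.0 / (1.0 + np.exp(-logits))          # sigmoid per class

    def train_step(self, x, y, lr=0.5):
        """One full-batch gradient-descent step on mean binary cross-entropy.
        Returns the loss measured before the update."""
        p = self.forward(x)
        eps = 1e-9
        loss = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
        # Gradient of mean BCE w.r.t. the logits (sigmoid folded in)
        g_logits = (p - y) / y.size
        g_W2 = self.h.T @ g_logits
        g_b2 = g_logits.sum(axis=0)
        g_h = g_logits @ self.W2.T
        g_h[self.h_pre <= 0] = 0.0                    # backprop through ReLU
        g_W1 = x.T @ g_h
        g_b1 = g_h.sum(axis=0)
        for param, grad in ((self.W1, g_W1), (self.b1, g_b1),
                            (self.W2, g_W2), (self.b2, g_b2)):
            param -= lr * grad                        # in-place update
        return loss

# Toy data: 64 random 8-dim "embeddings" and 3 topic classes;
# class c is labeled relevant when feature c is positive.
rng = np.random.default_rng(1)
X = rng.normal(size=(64, 8))
Y = (X[:, :3] > 0).astype(float)

model = TinyMLP(in_dim=8, hidden_dim=16, num_class=3)
first_loss = model.train_step(X, Y)
for _ in range(300):
    last_loss = model.train_step(X, Y)
class_probs = model.forward(X)   # shape (64, 3), one probability per topic class
```

In practice this would be written in PyTorch, like the notebook's `Classifier`, and trained on mini-batches; NumPy is used here only to keep the sketch dependency-light.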

## Inference Phase

We first compute topic distributions for queries and documents with the trained classifier, and then filter them to retain only the topic classes with a sufficient degree of relevance.

In [None]:
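model = Classifier(num_class=4028)
model.load_state_dict(torch.load('resource/Classifier_model', map_location='cpu'))
model.eval()  # inference mode (disables dropout, if the classifier uses it)

with torch.no_grad():
    # Topic distributions (class relevance scores) for documents and queries
    c_emb_tensor = torch.FloatTensor(c_emb_list)
    c_clf_output = model(c_emb_tensor).numpy()

    q_emb_tensor = torch.FloatTensor(q_emb_list)
    q_clf_output = model(q_emb_tensor).numpy()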

In [6]:
filtered_X = filtering(q_clf_output)
filtered_Y = filtering(c_clf_output)

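`filtering` is defined in `Utils`, and its exact rule is not shown. One plausible rule, assumed here purely for illustration, keeps the classes whose predicted relevance clears a probability threshold and optionally caps the number of classes per text. `filter_topics`, `threshold`, and `top_k` are hypothetical.

```python
import numpy as np

def filter_topics(probs, threshold=0.5, top_k=None):
    """Hypothetical topic filter: zero out classes whose predicted probability
    is below `threshold`; if `top_k` is set, also keep at most top_k classes per row."""
    filtered = np.where(probs >= threshold, probs, 0.0)
    if top_k is not None and top_k < probs.shape[1]:
        # Column indices of all but the top_k highest-probability classes per row
        drop = np.argsort(-probs, axis=1)[:, top_k:]
        np.put_along_axis(filtered, drop, 0.0, axis=1)
    return filtered

# Toy class-probability matrix: 2 texts x 4 topic classes
probs = np.array([[0.9, 0.2, 0.6, 0.7],
                  [0.1, 0.3, 0.4, 0.95]])
by_threshold = filter_topics(probs, threshold=0.5)
by_threshold_and_k = filter_topics(probs, threshold=0.5, top_k=2)
```

After filtering, a class counts as present for a text only if its entry is nonzero, which is the condition `topic_filter` checks with `> 0`.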

### 1. Search space adjustment (SSA)

Before applying a retriever, we filter out the large number of irrelevant documents that share few topic classes with the query. In this example, we retain 10% of the corpus for each query.

In [7]:
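from tqdm.contrib.concurrent import process_map

def topic_filter(i):
    """Keep the topical relatedness scores of the documents sharing the most topic
    classes with query i and zero out the rest. Reads the globals filtered_X,
    filtered_Y, topical_relatedness_mat, and adjustment_percent."""
    query_topics = filtered_X[i] > 0
    # Number of topic classes each document shares with the query
    num_shared = ((filtered_Y * query_topics) > 0).sum(1)
    co_topic_rank = np.argsort(-num_shared)
    tmp = topical_relatedness_mat[i].copy()
    # Keep the top int(N * adjustment_percent) + 1 documents; zero out the rest
    is_filtered = co_topic_rank[int(len(co_topic_rank) * adjustment_percent) + 1:]
    tmp[is_filtered] = 0.
    return tmp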

In [8]:
topical_relatedness_mat = compute_topical_relatedness(filtered_X, filtered_Y)

with io.capture_output() as captured:
    adjustment_percent = 0.1
    r = process_map(topic_filter, range(q_clf_output.shape[0]), max_workers=40)
    SSA_score_mat = np.asarray(r)

non_zero_percent = (((SSA_score_mat > 0).sum(1) / (topical_relatedness_mat > 0).sum(1)).mean() * 100)
print("For each query, we retain", round(non_zero_percent, 2), "% of the documents in the corpus.")

For each query, we retain 10.0 % of the documents in the corpus.
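
The ranking idea inside `topic_filter` can also be expressed in vectorized form, which makes it easy to check on toy data. In this minimal NumPy sketch, a document's rank is the number of topic classes it shares with the query, and only the top `fraction` of documents keep a nonzero score. The function names are hypothetical, and because `compute_topical_relatedness` is not shown, a plain dot product of the filtered topic vectors stands in for it. Note also that `topic_filter` keeps `int(N * adjustment_percent) + 1` documents, whereas this sketch keeps `int(N * fraction)`.

```python
import numpy as np

def shared_topic_counts(q_topics, d_topics):
    """Count, for every (query, document) pair, the topic classes both keep (> 0).
    Casting to float32 lets the matrix product use BLAS; small counts stay exact."""
    return (q_topics > 0).astype(np.float32) @ (d_topics > 0).astype(np.float32).T

def search_space_mask(counts, fraction):
    """For each query, mark the `fraction` of documents with the most shared topics."""
    keep = max(1, int(counts.shape[1] * fraction))
    order = np.argsort(-counts, axis=1, kind="stable")[:, :keep]
    mask = np.zeros(counts.shape, dtype=bool)
    np.put_along_axis(mask, order, True, axis=1)
    return mask

# Toy filtered topic vectors: 1 query, 5 documents, 4 topic classes
q_topics = np.array([[0.8, 0.0, 0.6, 0.0]])
d_topics = np.array([[0.7, 0.0, 0.5, 0.0],
                     [0.0, 0.9, 0.0, 0.0],
                     [0.3, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 0.0, 0.8],
                     [0.6, 0.2, 0.4, 0.1]])

counts = shared_topic_counts(q_topics, d_topics)   # shared topics: 2, 0, 1, 0, 2
mask = search_space_mask(counts, fraction=0.5)     # keeps documents 0 and 4
# Stand-in relatedness (hypothetical): dot product of the filtered topic vectors
relatedness = q_topics @ d_topics.T
ssa_scores = np.where(mask, relatedness, 0.0)
```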


In [9]:
SSA_results = eval_full_score_mat(SSA_score_mat, q_id_list, c_id_list)
SSA_results = return_topK_result(SSA_results, topk=2500)
print('Recall@2500:', retriever.evaluate(qrels, SSA_results, [2500])[2]['Recall@2500'])

Recall@2500: 0.82947


### 2. Class relevance matching (CRM) for the retriever

We can now run retrieval on the reduced search space. Search space adjustment benefits subsequent retrieval by shrinking the search space while preserving topically relevant documents that PLM-based retrievers might otherwise overlook.

In [10]:
# Contriever-MS scores, restricted to the documents that survived SSA
CTR_SSA_results = {}
for qid in tqdm(SSA_results):
    CTR_SSA_results[qid] = {cid: CTR_full_results[qid][cid] for cid in SSA_results[qid]}

print_metric(retriever.evaluate(qrels, CTR_SSA_results, retriever.k_values))

NDCG@10: 0.16749 , NDCG@100: 0.24042
Recall@100: 0.38867 , Recall@500: 0.58482 , Recall@1000: 0.6813 , Recall@2500: 0.82947


We then score documents by both semantic similarity (from the retriever) and topical relatedness (from our classifier), z-normalizing each score and summing them. Topical relatedness can help handle lexical mismatches and fill in missing context, complementing semantic similarity.

In [11]:
normalized_CRM_results = z_normalize(SSA_results)
normalized_CTR_results = z_normalize(CTR_full_results)

In [12]:
CRM_results = {}
for qid in tqdm(SSA_results):
    # Final score = z-normalized retriever score + z-normalized topical relatedness
    CRM_results[qid] = {
        cid: normalized_CTR_results[qid][cid] + normalized_CRM_results[qid][cid]
        for cid in SSA_results[qid]
    }

print_metric(retriever.evaluate(qrels, CRM_results, retriever.k_values))

NDCG@10: 0.17636 , NDCG@100: 0.26056
Recall@100: 0.4337 , Recall@500: 0.636 , Recall@1000: 0.71805 , Recall@2500: 0.82947
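
`z_normalize` comes from `Utils` and is not shown. Putting the two scores on a common scale before adding them keeps either source from dominating simply because of its numeric range. Below is a minimal sketch assuming per-query z-score normalization (subtract the query's mean score, divide by its standard deviation) followed by a plain sum over the candidate documents; `z_normalize_results` and `fuse_scores` are hypothetical names. Under that assumption, the cells above normalize the retriever scores over the full ranking (`CTR_full_results`) and the topical scores over the SSA candidates, so each source uses its own statistics.

```python
import numpy as np

def z_normalize_results(results):
    """Hypothetical per-query z-score normalization of a BEIR-style
    {query_id: {doc_id: score}} dict."""
    normalized = {}
    for qid, doc_scores in results.items():
        scores = np.fromiter(doc_scores.values(), dtype=float)
        mean, std = scores.mean(), scores.std()
        if std == 0:
            std = 1.0  # all scores tie: avoid division by zero
        normalized[qid] = {cid: (s - mean) / std for cid, s in doc_scores.items()}
    return normalized

def fuse_scores(candidates, *normalized_dicts):
    """Sum several normalized scores for each candidate document of each query."""
    return {qid: {cid: sum(scores[qid][cid] for scores in normalized_dicts)
                  for cid in doc_ids}
            for qid, doc_ids in candidates.items()}

# Toy scores for one query over three candidate documents
topical = {"q1": {"d1": 0.9, "d2": 0.1, "d3": 0.6}}    # e.g., topical relatedness
semantic = {"q1": {"d1": 2.0, "d2": 1.0, "d3": 3.0}}   # e.g., retriever scores
fused = fuse_scores(topical, z_normalize_results(semantic), z_normalize_results(topical))
ranking = sorted(fused["q1"], key=fused["q1"].get, reverse=True)
```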


### 3. Query enrichment by core phrases (QEP) for the reranker

In this last stage, a reranker reorders the top-ranked candidates (here, the top 100) from the retriever. We delve deeper into each topic by focusing on class-related phrases.

In [None]:
top_result = return_topK_result(CRM_results, topk=100)
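
`return_topK_result` (from `Utils`, not shown) trims each query's candidate list before reranking. MonoT5 scores each (query, document) pair jointly, so bounding the candidate list bounds the reranking cost. A minimal sketch of such a trim, assuming it simply keeps the highest-scoring documents; `top_k_results` is a hypothetical name.

```python
import heapq

def top_k_results(results, topk):
    """Hypothetical helper: keep only the `topk` highest-scoring documents per
    query in a BEIR-style {query_id: {doc_id: score}} dict, in descending order."""
    return {qid: dict(heapq.nlargest(topk, doc_scores.items(), key=lambda kv: kv[1]))
            for qid, doc_scores in results.items()}

example = {"q1": {"d1": 0.3, "d2": 1.7, "d3": -0.4, "d4": 0.9}}
trimmed = top_k_results(example, topk=2)
```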

In [17]:
from beir.reranking.models import MonoT5
from beir.reranking import Rerank

T5_model = MonoT5('castorini/monot5-base-msmarco-10k', token_false='▁false', token_true='▁true')
reranker = Rerank(T5_model, batch_size=1024)

T5_rerank_results = reranker.rerank(corpus, queries, top_result, top_k=100)
print('NDCG@10:', retriever.evaluate(qrels, T5_rerank_results, [10])[0]['NDCG@10'])

NDCG@10: 0.17992


Users familiar with a domain often omit context from their queries, which makes it difficult to judge relevance accurately. We identify this missing context in the form of core phrases and use them to enrich the query.

In [30]:
import pickle

# Load the pre-computed queries enriched with core phrases
with open('resource/topic_enriched_queries', 'rb') as f:
    topic_enriched_queries = pickle.load(f)

example_qid = "2d52f69dd4686a3e66f5a8a1650a24bcea43530e"
print("Original query:\n" + queries[example_qid])
print("Topic-enriched query:\n" + topic_enriched_queries[example_qid])

Original query:
Provable data possession at untrusted stores
Topic-enriched query:
Provable data possession at untrusted stores, relevant topics: data, key, server, access_control, security, system, encryption, allows, untrusted
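
The topic-enriched queries are loaded pre-computed, so how the core phrases are selected is not shown here. The sketch below only reproduces the string format visible in the example output, assuming the core phrases for a query have already been chosen and ordered; `enrich_query` and `max_phrases` are hypothetical.

```python
def enrich_query(query, core_phrases, max_phrases=None):
    """Hypothetical helper: append core phrases to a query using the
    ", relevant topics: ..." format. `core_phrases` is assumed pre-selected."""
    phrases = list(core_phrases)[:max_phrases]   # slicing with None keeps all
    if not phrases:
        return query
    return f"{query}, relevant topics: {', '.join(phrases)}"

example = enrich_query(
    "Provable data possession at untrusted stores",
    ["data", "key", "server", "access_control", "security",
     "system", "encryption", "allows", "untrusted"],
)
```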


In [20]:
T5_rerank_results_QEP = reranker.rerank(corpus, topic_enriched_queries, top_result, top_k=100)
print('NDCG@10:', retriever.evaluate(qrels, T5_rerank_results_QEP, [10])[0]['NDCG@10'])

NDCG@10: 0.18645
