In [1]:
!pip install python-terrier
import pyterrier as pt
if not pt.started():
    pt.init()

from transformers import BertModel
import torch.nn.functional as F
import torch
import torch.nn as nn

import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
import re
from pathlib import Path
import os
import pandas as pd

from pyterrier.measures import RR, nDCG, MAP
# Ensure NLTK data is downloaded only once
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)



PyTerrier 0.10.0 has loaded Terrier 5.8 (built by craigm on 2023-11-01 18:05) and terrier-helper 0.0.8

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


True

# Load dataset

In [2]:
# Example of loading a dataset from PyTerrier
dataset = pt.get_dataset('irds:msmarco-passage/trec-dl-2020')

topics = dataset.get_topics()
qrels = dataset.get_qrels()
corpus_iter = dataset.get_corpus_iter()
# Convert to an iterator
corpus_iterator = iter(corpus_iter)
first_doc = next(corpus_iterator)

msmarco-passage/trec-dl-2020 documents:   0%|          | 0/8841823 [00:00<?, ?it/s]

In [3]:
topics.head()

Unnamed: 0,qid,query
0,1030303,who is aziz hashim
1,1037496,who is rep scalise
2,1043135,who killed nicholas ii of russia
3,1045109,who owns barnhart crane
4,1049519,who said no one can make you feel inferior


In [4]:
qrels.head()

Unnamed: 0,qid,docno,label,iteration
0,23849,1020327,2,0
1,23849,1034183,3,0
2,23849,1120730,0,0
3,23849,1139571,1,0
4,23849,1143724,0,0


# Load models

This is loading the BertClassifier model, davinciContextModel and the DocQuery2DocModel model. The models are then loaded from the checkpoints and set to evaluation mode.

In [5]:
class BertClassifier(nn.Module):
    def __init__(self, num_labels=2):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Align the layer names with the checkpoint
        self.classification = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = outputs.pooler_output
        logits = self.classification(pooled_output)
        return logits

# Initialize the model
classifier_model = BertClassifier()

# Load the checkpoint
model_path = './model/doc_baseline.ckpt'
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# Extract the state_dict, assuming the state dict is saved under 'state_dict'
model_state_dict = checkpoint['state_dict']

# Adjust keys in the state_dict as needed
adjusted_model_state_dict = {}
for key, value in model_state_dict.items():
    # Remove the 'model.' prefix and handle the 'classifier' to 'classification' naming difference
    new_key = key.replace('model.', '').replace('classification', 'classifier')
    if 'position_ids' not in new_key:  # Ignore 'bert.embeddings.position_ids'
        adjusted_model_state_dict[new_key] = value

# Load the adjusted state dict
classifier_model.load_state_dict(adjusted_model_state_dict, strict=False)
classifier_model.eval()  # Set the model to evaluation mode


BertClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [6]:
class TextPreprocessor:
    def __init__(self):
        self.stop_words = set(stopwords.words('english'))
        self.stemmer = PorterStemmer()

    def tokenize(self, text: str):
        return word_tokenize(text)

    def normalize(self, tokens):
        return [re.sub(r'\W+', '', token.lower()) for token in tokens if re.sub(r'\W+', '', token.lower())]

    def stopping(self, tokens):
        return [token for token in tokens if token not in self.stop_words]

    def stemming(self, tokens):
        return [self.stemmer.stem(token) for token in tokens]

    def preprocess(self, query: str, use_tokenize=True, use_normalize=True, use_stopping=True, use_stemming=False):
        if query is None or not isinstance(query, str):
            raise ValueError("Input query must be a non-empty string")
        tokens = self.tokenize(query) if use_tokenize else query.split()
        if use_normalize: tokens = self.normalize(tokens)
        if use_stopping: tokens = self.stopping(tokens)
        if use_stemming: tokens = self.stemming(tokens)
        return ' '.join(tokens)


def preprocess_queries(query: str) -> str:
    preprocessor = TextPreprocessor()
    return preprocessor.preprocess(query, use_tokenize=True, use_normalize=True, use_stopping=True, use_stemming=False)

# Context-Aware Query Rewriting

We need to adjust our logic here used the model from: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2



In [37]:
import requests
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import torch.nn.functional as F
import re

class ContextAwareQueryRewriter:
    def __init__(self, classifier_model, inference_model_endpoint, tokenizer=None):
        self.classifier_model = classifier_model
        self.inference_model_endpoint = inference_model_endpoint
        self.tokenizer = tokenizer if tokenizer else BertTokenizer.from_pretrained('bert-base-uncased')


    def needs_rewriting(self, query):
            inputs = self.tokenizer(query, return_tensors="pt", padding=True, truncation=True)
            with torch.no_grad():
                output = self.classifier_model(**inputs)
            logits = output.logits if hasattr(output, 'logits') else output
            category = torch.argmax(F.log_softmax(logits, dim=1), dim=1).item()
            return category == 1

    def extract_context(self, query):
        headers = {"Authorization": "Bearer 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"}
        API_URL = "https://api-inference.huggingface.co/models/ml6team/keyphrase-extraction-kbir-openkp"
        response = requests.post(API_URL, headers=headers, json={"inputs": query})
    
        if response.status_code == 200:
            keyphrases = response.json()
    
            # Create a context string by joining the 'word' values from each keyphrase.
            extracted_context = ', '.join([kp['word'].strip() for kp in keyphrases])
    
            return extracted_context
        else:
            print("Error: Unable to extract context.")
            return None


    def rewrite_query(self, query):
        """
        Rewriting Phase: Rewrite the given query based on the provided context using model inference.
        
        :param query: The original search query.
        :param context: Context or background information relevant to the query.
        :return: A rewritten query that better captures the intent, considering the context.
        """
        # Construct a prompt that is clearly demarcated to help separate the response from the prompt.
        context = self.extract_context(query)
        if not context:
            context = "[Context not available]"
        prompt = f"Rewrite the following into a precise query for information retrieval based on the context '{context}': '{query}'. Provide only the query."
        payload = {"inputs": prompt}
        response = self.query_model(payload)
        if response and "generated_text" in response[0]:
            rewritten_query = response[0]["generated_text"].split(prompt)[-1].strip()
            return rewritten_query
        else:
            return query


    def query_model(self, payload):
        headers = {"Authorization": "Bearer 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"}
        response = requests.post(self.inference_model_endpoint, headers=headers, json=payload)
        return response.json()

    def transform(self, queries):
        rewritten_queries = []
        for query in queries:
            # First, check if the query needs rewriting
            if self.needs_rewriting(query):
                # If it does, rewrite the query based on extracted context
                rewritten_query = self.rewrite_query(query)
                rewritten_queries.append(rewritten_query if rewritten_query else query)
            else:
                # If it doesn't need rewriting, add it as is
                rewritten_queries.append(query)
        return rewritten_queries


# Initialize ContextAwareQueryRewriter
classifier_model = classifier_model
inference_model_endpoint = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

query_rewriter = ContextAwareQueryRewriter(classifier_model, inference_model_endpoint, tokenizer)

In [42]:
context = "In the context of French cuisine"
original_query = "The main ingredient in a quiche [MASK] is"
refined_query = query_rewriter.fill_in_the_mask(original_query)

print(refined_query)


[{'score': 0.1272163987159729, 'token': 7954, 'token_str': 'meal', 'sequence': 'the main ingredient in a quiche meal is'}, {'score': 0.09382666647434235, 'token': 9850, 'token_str': 'cake', 'sequence': 'the main ingredient in a quiche cake is'}, {'score': 0.07288110256195068, 'token': 11350, 'token_str': 'soup', 'sequence': 'the main ingredient in a quiche soup is'}, {'score': 0.05692329257726669, 'token': 12901, 'token_str': 'sauce', 'sequence': 'the main ingredient in a quiche sauce is'}, {'score': 0.04582372307777405, 'token': 11642, 'token_str': 'sandwich', 'sequence': 'the main ingredient in a quiche sandwich is'}]


### Step 1: Extracting Original Queries

You've already loaded the dataset and accessed the topics, which contain the queries. The code snippet you've shown extracts the queries from the `topics` DataFrame provided by PyTerrier:

In [43]:
# Assuming you've already loaded the dataset and have the 'topics' DataFrame
original_queries = [topic['query'] for topic_id, topic in topics.iterrows()]

print(original_queries[:100])

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("Ashishkr/query_wellformedness_score")
model = AutoModelForSequenceClassification.from_pretrained("Ashishkr/query_wellformedness_score")
sentences = original_queries[:100]
features = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
model.eval()
with torch.no_grad():
    scores = model(**features).logits
print(scores)


['who is aziz hashim', 'who is rep scalise', 'who killed nicholas ii of russia', 'who owns barnhart crane', 'who said no one can make you feel inferior', 'who sings monk theme song', 'who was the highest career passer rating in the nfl', 'why do hunters pattern their shotguns', 'why do some places on my scalp feel sore', 'why is pete rose banned from hall of fame', 'who is thomas m cooley', 'definition of endorsing', 'which hormone increases calcium levels in the blood', 'define geon', 'where can the amazon rainforest is located', 'when are the four forces that act on an airplane in equilibrium', 'define pareto chart in statistics', 'what year were the timberwolves founded', 'what year did knee deep come out funkadelic', 'define define gallows', 'what tissue composes the hypodermis', 'what time zone is st paul minnesota in', 'what the best way to get clothes white', 'what temperature and humidity to dry sausage', 'what mental illnesses', 'what medium do radio waves travel through', 'wh

tokenizer_config.json:   0%|          | 0.00/312 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/740 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/501M [00:00<?, ?B/s]

Some weights of the model checkpoint at Ashishkr/query_wellformedness_score were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[0.9743],
        [0.1841],
        [0.6934],
        [0.6095],
        [0.7336],
        [0.3793],
        [0.6265],
        [0.9032],
        [0.8124],
        [0.6394],
        [0.6664],
        [0.2150],
        [1.0061],
        [0.2947],
        [0.1853],
        [0.7173],
        [0.1922],
        [0.9775],
        [0.7670],
        [0.1133],
        [0.6861],
        [0.7636],
        [0.1986],
        [0.1304],
        [0.0420],
        [1.0067],
        [0.6026],
        [0.2620],
        [0.7683],
        [0.2335],
        [0.3470],
        [0.8602],
        [0.9789],
        [0.1109],
        [0.5709],
        [0.4360],
        [0.7824],
        [0.9976],
        [0.2031],
        [0.6526],
        [0.9268],
        [0.9725],
        [0.1249],
        [0.7022],
        [0.0669],
        [0.1188],
        [0.5852],
        [0.6538],
        [0.0332],
        [0.4317],
        [0.2625],
        [0.0872],
        [0.1248],
        [0.2465],
        [0.0930],
        [0

In [44]:
# Assume scores is a tensor of shape (num_queries, num_classes) and you're interested in the score for being well-formed
well_formed_scores = torch.sigmoid(scores)[:, 0]  # Apply sigmoid if the scores are logits

# Set a threshold for selecting queries that need to be rewritten
threshold = 0.5

# Select queries with scores below the threshold
queries_to_rewrite = [query for query, score in zip(original_queries[:100], well_formed_scores) if score < threshold]

print("Queries to Rewrite:", queries_to_rewrite)


Queries to Rewrite: []


In [21]:
# rewrite the 50 queries
rewritten_queries = query_rewriter.transform(original_queries[:100])

print(rewritten_queries)

['who is aziz hashim', 'who is rep scalise', 'who killed nicholas ii of russia', 'who owns barnhart crane', "``Who spoke the phrase 'no one can make you feel inferior'?'''\n\nNote: For information retrieval queries, it's essential to be as specific and clear as possible to ensure accurate results. In this case, we're looking for the identities of individuals who have made the statement, so our query focuses on the speaker(s) and the exact phrase.", 'who sings monk theme song', 'who was the highest career passer rating in the nfl', 'why do hunters pattern their shotguns', '```\nwhy do certain areas on my scalp present with tenderness or discomfort?', 'why is pete rose banned from hall of fame', "query: thomas m cooley author\n\nThis query assumes a library catalog or information database context, reflecting that 'Thomas M. Cooley' is most likely an author or a topic of a published work.", 'definition of endorsing', 'which hormone increases calcium levels in the blood', 'define geon', 'L

### Step 2: Rewriting Queries

You've also created a `QueryRewriter` class that takes in a model and a tokenizer to rewrite queries based on a certain logic (in your case, this logic depends on the predictions from a BertClassifier model). Initialize the `QueryRewriter` with the necessary model and tokenizer, and then use it to rewrite the preprocessed queries:


### Step 3: Preprocessing Queries (Optional)

You've defined a `TextPreprocessor` class for preprocessing queries. You can apply this preprocessing to each of the original queries to get them ready for the rewriting process:

In [None]:

# preprocessor = TextPreprocessor()
# preprocessed_queries = [preprocessor.preprocess(query) for query in original_queries]

In [None]:
print("Rewritten queries: ")
print(rewritten_queries[:10])
print("\nOriginal queries: ")
print(original_queries[:10])

## Evaluating Rewritten vs. Original Queries with Doc MSMARCO Reranked Models

### Objective

The goal is to utilize the four doc_msmarco reranked models to compare the performance of rewritten queries against original queries.

### Location of Reranked Models

The reranked models are stored within the `doc_msmarco` directory, under the `model` folder. Here's the structure:

```
model/
├── doc_baseline.ckpt
├── doc_davinci_doc_context.ckpt
└── wetransfer_model-1-doc2query-all-rankers_2024-03-21_1327/
    ├── best_rank_list/
    │   └── doc_msmarco/
    │       ├── doc_attention_2_davinci_final_3e-4.tsv
    │       ├── doc_davinci03_doc_context_400tok_5e-4.tsv
    │       ├── doc_linear_2_davinci_final_1e-4.tsv
    │       └── doc_query_2_doc_3e-4.tsv
    └── doc_query2doc.ckpt
```

### Procedure

1. **Retrieve Documents**: Initially, documents are fetched from the MS MARCO datasets for the purpose of reranking.

2. **Reranking**: Utilize the retrieved documents along with the queries (both original and rewritten) to perform reranking using the models specified above.

3. **Evaluation**: Assess the reranking process to determine the effectiveness of rewritten queries in comparison to the original ones.

### Step 1: Indexing the Corpus

Before we can retrieve and rerank documents based on queries, we need to ensure that the documents are indexed. You've mentioned creating an index but commented out the relevant code. Here's how you can do it properly with Pyterrier

In [None]:
index_location = str(Path("index").absolute())
index_exists = os.path.isfile(os.path.join(index_location, "data.properties"))

if not index_exists:
    indexer = pt.IterDictIndexer(index_location)
    index_ref = indexer.index(corpus_iter)
    print("Indexing completed.")
else:
    print("Index already exists, loading from disk.")
    index_ref = index_location

### Step 2: Retrieving Documents

You need to use the created index to retrieve documents based on both the original and rewritten queries. You have already outlined this step correctly, but make sure that `index` is correctly initialized using the `index_ref` from the indexing step:

In [None]:
# Assuming original_queries and rewritten_queries are your lists of queries
query_ids = range(1, len(original_queries) + 1)  # Generating query IDs

# Converting original queries to a DataFrame
original_queries_df = pd.DataFrame({
    'qid': query_ids,
    'query': original_queries
})

# Converting rewritten queries to a DataFrame
rewritten_queries_df = pd.DataFrame({
    'qid': query_ids,
    'query': rewritten_queries
})

In [None]:
# Initialize index from the created index reference
index = pt.IndexFactory.of(index_ref)

# Retrieve documents for both original and rewritten queries using BM25
original_res = pt.BatchRetrieve(index, wmodel="BM25").transform(original_queries_df)
rewritten_res = pt.BatchRetrieve(index, wmodel="BM25").transform(rewritten_queries_df)


### Step 4: Evaluating the Results

Finally, to evaluate the effectiveness of the rewritten queries versus the original queries, you'll use PyTerrier's evaluation functionalities. You need to compare the reranked results against the relevance judgments (`qrels`) using various metrics:

In [None]:
from pyterrier.measures import RR, nDCG, AP

# Ensure your reranked DataFrames are named appropriately and structured correctly
# They should have columns ['qid', 'docno', 'score']

# Perform the evaluation using PyTerrier's Experiment function
# results = pt.Experiment(
#     [original_res_sorted, rewritten_res_sorted],
#     topics,  # Your original queries DataFrame
#     qrels,  # The qrels DataFrame with relevance judgments
#     eval_metrics=[nDCG@10, AP(rel=10), RR(rel=10), nDCG@20],
#     names=["Original Reranked", "Rewritten Reranked"],
#     baseline=0,  # Optionally specify which result set is considered the 'baseline' for comparison
# )

# Display the evaluation result