In [1]:
!pip install python-terrier
import pyterrier as pt
if not pt.started():
    pt.init()

from transformers import BertModel
import torch
import torch.nn as nn

import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
import re
from pathlib import Path
import os
import pandas as pd

from pyterrier.measures import RR, nDCG, MAP
# Ensure NLTK data is downloaded only once
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)



PyTerrier 0.10.0 has loaded Terrier 5.8 (built by craigm on 2023-11-01 18:05) and terrier-helper 0.0.8

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


True

# Load dataset

In [2]:
# Example of loading a dataset from PyTerrier
dataset = pt.get_dataset('irds:msmarco-passage/trec-dl-2020')

topics = dataset.get_topics()
qrels = dataset.get_qrels()
corpus_iter = dataset.get_corpus_iter()
# Convert to an iterator
corpus_iterator = iter(corpus_iter)
first_doc = next(corpus_iterator)

msmarco-passage/trec-dl-2020 documents:   0%|          | 0/8841823 [00:00<?, ?it/s]

In [3]:
topics.head()

Unnamed: 0,qid,query
0,1030303,who is aziz hashim
1,1037496,who is rep scalise
2,1043135,who killed nicholas ii of russia
3,1045109,who owns barnhart crane
4,1049519,who said no one can make you feel inferior


In [4]:
qrels.head()

Unnamed: 0,qid,docno,label,iteration
0,23849,1020327,2,0
1,23849,1034183,3,0
2,23849,1120730,0,0
3,23849,1139571,1,0
4,23849,1143724,0,0


# Load models

This is loading the BertClassifier model, davinciContextModel and the DocQuery2DocModel model. The models are then loaded from the checkpoints and set to evaluation mode.

In [5]:
class BertClassifier(nn.Module):
    def __init__(self, num_labels=2):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Align the layer names with the checkpoint
        self.classification = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = outputs.pooler_output
        logits = self.classification(pooled_output)
        return logits

# Initialize the model
classifier_model = BertClassifier()

# Load the checkpoint
model_path = './model/doc_baseline.ckpt'
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# Extract the state_dict, assuming the state dict is saved under 'state_dict'
model_state_dict = checkpoint['state_dict']

# Adjust keys in the state_dict as needed
adjusted_model_state_dict = {}
for key, value in model_state_dict.items():
    # Remove the 'model.' prefix and handle the 'classifier' to 'classification' naming difference
    new_key = key.replace('model.', '').replace('classification', 'classifier')
    if 'position_ids' not in new_key:  # Ignore 'bert.embeddings.position_ids'
        adjusted_model_state_dict[new_key] = value

# Load the adjusted state dict
classifier_model.load_state_dict(adjusted_model_state_dict, strict=False)
classifier_model.eval()  # Set the model to evaluation mode


BertClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [6]:
class DocQuery2DocModel(nn.Module):
    def __init__(self, num_labels=2):
        super(DocQuery2DocModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Align the layer names with the checkpoint
        self.classification = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = outputs.pooler_output
        logits = self.classification(pooled_output)
        return logits

# Initialize your model
doc_query2doc_model = DocQuery2DocModel()

# Load the checkpoint
model_path = './model/wetransfer_model-1-doc2query-all-rankers_2024-03-21_1327/doc_query2doc.ckpt'
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# Assuming the checkpoint is compatible with the model architecture
doc_query2doc_model.load_state_dict(checkpoint, strict=False)

doc_query2doc_model.eval()  # Set the model to evaluation mode

DocQuery2DocModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise

In [7]:
class DavinciDocContextModel(nn.Module):
    def __init__(self, num_labels=1):  # Adjusted num_labels to 1 to match checkpoint
        super(DavinciDocContextModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classification = nn.Linear(768, num_labels)  # This layer now matches the checkpoint

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = outputs.pooler_output
        logits = self.classification(pooled_output)
        return logits

# Initialize your model with the corrected number of labels
davinci_doc_context_model = DavinciDocContextModel()

# Load the checkpoint
model_path = './model/doc_davinci_doc_context.ckpt'  # Adjust this path to your checkpoint's location
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# Load the model weights from the checkpoint
davinci_doc_context_model.load_state_dict(checkpoint['state_dict'], strict=False)  # Use strict=False for flexibility

davinci_doc_context_model.eval()  # Set the model to evaluation mode


DavinciDocContextModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [8]:
class TextPreprocessor:
    def __init__(self):
        self.stop_words = set(stopwords.words('english'))
        self.stemmer = PorterStemmer()

    def tokenize(self, text: str):
        return word_tokenize(text)

    def normalize(self, tokens):
        return [re.sub(r'\W+', '', token.lower()) for token in tokens if re.sub(r'\W+', '', token.lower())]

    def stopping(self, tokens):
        return [token for token in tokens if token not in self.stop_words]

    def stemming(self, tokens):
        return [self.stemmer.stem(token) for token in tokens]

    def preprocess(self, query: str, use_tokenize=True, use_normalize=True, use_stopping=True, use_stemming=False):
        if query is None or not isinstance(query, str):
            raise ValueError("Input query must be a non-empty string")
        tokens = self.tokenize(query) if use_tokenize else query.split()
        if use_normalize: tokens = self.normalize(tokens)
        if use_stopping: tokens = self.stopping(tokens)
        if use_stemming: tokens = self.stemming(tokens)
        return ' '.join(tokens)


def preprocess_queries(query: str) -> str:
    preprocessor = TextPreprocessor()
    return preprocessor.preprocess(query, use_tokenize=True, use_normalize=True, use_stopping=True, use_stemming=False)

# Context-Aware Query Rewriting

We need to adjust our logic here used the model from: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2

In [9]:
import requests
from transformers import BertTokenizer


class ContextAwareQueryRewriter:
    def __init__(self, classifier_model, context_analysis_model, inference_model_endpoint, tokenizer=None):
        self.classifier_model = classifier_model
        self.context_analysis_model = context_analysis_model
        self.inference_model_endpoint = inference_model_endpoint
        self.tokenizer = tokenizer

    def needs_rewriting(self, query):
        """
        Classification Phase: Determines whether a query needs rewriting based on ambiguity or specificity.
        :param query: 
        :return: 
        """
        inputs = self.tokenizer(query, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            output = self.classifier_model(**inputs)
        logits = output.logits if hasattr(output, 'logits') else output
        category = torch.argmax(F.log_softmax(logits, dim=1), dim=1).item()
        # Assuming 0 is "does not need rewriting" and 1 is "needs rewriting"
        needs_rewriting = category == 1
        return needs_rewriting

    def analyze_context(self, context):
        """
        Analyze the given context to extract features or insights that can guide the rewriting process.
        
        :param context: The context or background information relevant to the query.
        :return: A dictionary or structured object containing the insights extracted from the context.
        """
        # Tokenize the context and convert to input format expected by BERT
        inputs = self.tokenizer(context, return_tensors="pt", padding=True, truncation=True)

        # Generate embeddings or other representations using the context_analysis_model
        with torch.no_grad():
            outputs = self.context_analysis_model(**inputs)

        # Extract embeddings or other features from the outputs
        # This could be as simple as using the output of the last hidden layer
        embeddings = outputs.last_hidden_state

        insights = self.process_embeddings_to_insights(embeddings)

        print(insights)

        return insights

    def process_embeddings_to_insights(self, embeddings):
        """
        function would process the embeddings generated by a BERT model to extract actionable insights, such as key topics, sentiment, or entities, which can guide the query rewriting process. 
        :param embeddings: 
        :return: 
        """
        # Process the embeddings to extract insights
        # This could involve clustering, summarization, or other techniques
        embeddings_mean = torch.mean(embeddings, dim=1)
        insights = embeddings_mean.tolist()
        print(insights)
        return insights


    def rewrite_query(self, query, context):
        """
        Rewriting Phase: Rewrite the given query based on the provided context using model inference.
        
        :param query: The original search query.
        :param context: Context or background information relevant to the query.
        :return: A rewritten query that better captures the intent, considering the context.
        """
        prompt = f"Given the context: '{context}', rewrite the query to be more specific and clear."
        payload = {"inputs": prompt}
        response = self.query_model(payload)
        if response:
            rewritten_query = response[0]["generated_text"]
            return rewritten_query.strip()
        else:
            print("Error: Unable to rewrite query.")
            return None

    def query_model(self, payload):
        """
        Query the model endpoint with the given payload.
        
        :param payload: Input data to be sent to the model.
        :return: Response from the model endpoint.
        """
        headers = {"Authorization": "Bearer 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"}
        response = requests.post(self.inference_model_endpoint, headers=headers, json=payload)
        return response.json()


    def transform(self, queries, contexts):
        rewritten_queries = []
        for query, context in zip(queries, contexts):
            if self.needs_rewriting(query):
                # Optional: Analyze context to guide the rewriting process
                # context_features = self.analyze_context(context)
                rewritten_query = self.rewrite_query(query, context)
                rewritten_queries.append(rewritten_query)
            else:
                rewritten_queries.append(query)
        return rewritten_queries


# Initialize ContextAwareQueryRewriter
classifier_model = classifier_model
context_analysis_model = davinci_doc_context_model
inference_model_endpoint = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

query_rewriter = ContextAwareQueryRewriter(classifier_model, context_analysis_model, inference_model_endpoint, tokenizer)

# Test with a simple example query and context
query = "What are the major industries in France?"
context = "France is known for its diverse economy, with major industries including aerospace, automotive, luxury goods, and tourism. The country is also a leader in agriculture, particularly wine production."

rewritten_query = query_rewriter.rewrite_query(query, context)
print(rewritten_query)


Given the context: 'France is known for its diverse economy, with major industries including aerospace, automotive, luxury goods, and tourism. The country is also a leader in agriculture, particularly wine production.', rewrite the query to be more specific and clear. For example: 'What are the primary industries that contribute significantly to France's economy, and what specific products or sectors does each industry focus on?' Or 'Which industries hold a significant share of France's economy, and what are the prominent areas of expertise within each industry, such as in the case of aerospace, automotive, luxury goods, agriculture (particularly wine production), and tourism?'


### Step 1: Extracting Original Queries

You've already loaded the dataset and accessed the topics, which contain the queries. The code snippet you've shown extracts the queries from the `topics` DataFrame provided by PyTerrier:

In [10]:
# Assuming you've already loaded the dataset and have the 'topics' DataFrame
original_queries = [topic['query'] for topic_id, topic in topics.iterrows()]

### Step 2: Preprocessing Queries

You've defined a `TextPreprocessor` class for preprocessing queries. You can apply this preprocessing to each of the original queries to get them ready for the rewriting process:


In [11]:
# preprocessor = TextPreprocessor()
# preprocessed_queries = [preprocessor.preprocess(query) for query in original_queries]


### Step 3: Rewriting Queries

You've also created a `QueryRewriter` class that takes in a model and a tokenizer to rewrite queries based on a certain logic (in your case, this logic depends on the predictions from a BertClassifier model). Initialize the `QueryRewriter` with the necessary model and tokenizer, and then use it to rewrite the preprocessed queries:

In [12]:
from transformers import BertTokenizer

# Initialize the tokenizer for BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Initialize the QueryRewriter with your model and tokenizer
query_rewriter = ContextAwareQueryRewriter(classifier_model,doc_query2doc_model, tokenizer)

# Rewrite the preprocessed queries
rewritten_queries = query_rewriter.transform(original_queries)



TypeError: ContextAwareQueryRewriter.transform() missing 1 required positional argument: 'contexts'

In [None]:
print("Rewritten queries: ")
print(rewritten_queries[:10])
print("\nOriginal queries: ")
print(original_queries[:10])

## Evaluating Rewritten vs. Original Queries with Doc MSMARCO Reranked Models

### Objective

The goal is to utilize the four doc_msmarco reranked models to compare the performance of rewritten queries against original queries.

### Location of Reranked Models

The reranked models are stored within the `doc_msmarco` directory, under the `model` folder. Here's the structure:

```
model/
├── doc_baseline.ckpt
├── doc_davinci_doc_context.ckpt
└── wetransfer_model-1-doc2query-all-rankers_2024-03-21_1327/
    ├── best_rank_list/
    │   └── doc_msmarco/
    │       ├── doc_attention_2_davinci_final_3e-4.tsv
    │       ├── doc_davinci03_doc_context_400tok_5e-4.tsv
    │       ├── doc_linear_2_davinci_final_1e-4.tsv
    │       └── doc_query_2_doc_3e-4.tsv
    └── doc_query2doc.ckpt
```

### Procedure

1. **Retrieve Documents**: Initially, documents are fetched from the MS MARCO datasets for the purpose of reranking.

2. **Reranking**: Utilize the retrieved documents along with the queries (both original and rewritten) to perform reranking using the models specified above.

3. **Evaluation**: Assess the reranking process to determine the effectiveness of rewritten queries in comparison to the original ones.

### Step 1: Indexing the Corpus

Before we can retrieve and rerank documents based on queries, we need to ensure that the documents are indexed. You've mentioned creating an index but commented out the relevant code. Here's how you can do it properly with Pyterrier

In [None]:
index_location = str(Path("index").absolute())
index_exists = os.path.isfile(os.path.join(index_location, "data.properties"))

if not index_exists:
    indexer = pt.IterDictIndexer(index_location)
    index_ref = indexer.index(corpus_iter)
    print("Indexing completed.")
else:
    print("Index already exists, loading from disk.")
    index_ref = index_location

### Step 2: Retrieving Documents

You need to use the created index to retrieve documents based on both the original and rewritten queries. You have already outlined this step correctly, but make sure that `index` is correctly initialized using the `index_ref` from the indexing step:

In [None]:
# Assuming original_queries and rewritten_queries are your lists of queries
query_ids = range(1, len(original_queries) + 1)  # Generating query IDs

# Converting original queries to a DataFrame
original_queries_df = pd.DataFrame({
    'qid': query_ids,
    'query': original_queries
})

# Converting rewritten queries to a DataFrame
rewritten_queries_df = pd.DataFrame({
    'qid': query_ids,
    'query': rewritten_queries
})

In [None]:
# Initialize index from the created index reference
index = pt.IndexFactory.of(index_ref)

# Retrieve documents for both original and rewritten queries using BM25
original_res = pt.BatchRetrieve(index, wmodel="BM25").transform(original_queries_df)
rewritten_res = pt.BatchRetrieve(index, wmodel="BM25").transform(rewritten_queries_df)


### Step 3: Reranking the Results

After retrieving the initial set of documents for both the original and rewritten queries, the next step is to rerank these results. You've outlined the process, assuming you have a `reranker` defined. This reranker could be any model or method that takes a set of retrieved documents and reorders them based on relevance. Let's proceed with the assumption:

In [None]:
# This is how an entry of .tsv looks like: 966413	Q0	D660657	1	5.271960258483887	doc_attention_davinci_2_final_3e-4

# Example for loading one reranking file
reranking_path = './model/wetransfer_model-1-doc2query-all-rankers_2024-03-21_1327/best_rank_list/doc_msmarco/doc_davinci03_doc_context_400tok_5e-4.tsv'

reranking_df = pd.read_csv(
    reranking_path,
    sep='\t',
    header=None,
    names=['qid', 'placeholder', 'docid', 'rank', 'score', 'identifier']
)


# Display the first few rows to understand its structure
reranking_df.head()

In [None]:
original_res.head()

In [None]:
print(original_res)


In [None]:
# Convert 'docid' to string in both your results and reranking DataFrame
original_res['docid'] = original_res['docid'].astype(str)
rewritten_res['docid'] = rewritten_res['docid'].astype(str)
reranking_df['docid'] = reranking_df['docid'].astype(str)  # Assuming this column actually refers to 'docid' as discussed

# Merging with corrected data types
original_res_with_ranks = pd.merge(
    original_res,
    reranking_df,
    on='docid',
    how='left',
    suffixes=('', '_reranked')
)

rewritten_res_with_ranks = pd.merge(
    rewritten_res,
    reranking_df,
    on='docid',
    how='left',
    suffixes=('', '_reranked')
)


print(original_res_with_ranks.head)


### Step 4: Evaluating the Results

Finally, to evaluate the effectiveness of the rewritten queries versus the original queries, you'll use PyTerrier's evaluation functionalities. You need to compare the reranked results against the relevance judgments (`qrels`) using various metrics:

In [None]:
from pyterrier.measures import RR, nDCG, AP

# Ensure your reranked DataFrames are named appropriately and structured correctly
# They should have columns ['qid', 'docno', 'score']

# Perform the evaluation using PyTerrier's Experiment function
# results = pt.Experiment(
#     [original_res_sorted, rewritten_res_sorted],
#     topics,  # Your original queries DataFrame
#     qrels,  # The qrels DataFrame with relevance judgments
#     eval_metrics=[nDCG@10, AP(rel=10), RR(rel=10), nDCG@20],
#     names=["Original Reranked", "Rewritten Reranked"],
#     baseline=0,  # Optionally specify which result set is considered the 'baseline' for comparison
# )

# Display the evaluation results
print(results)
