In [34]:
!pip install python-terrier
import pyterrier as pt
if not pt.started():
    pt.init()

from transformers import BertModel
import torch
import torch.nn as nn



In [35]:
class BertClassifier(nn.Module):
    def __init__(self, num_labels=2):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Align the layer names with the checkpoint
        self.classification = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = outputs.pooler_output
        logits = self.classification(pooled_output)
        return logits

# Initialize the model
model = BertClassifier()

# Load the checkpoint
model_path = './model/doc_baseline.ckpt'
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# Extract the state_dict, assuming the state dict is saved under 'state_dict'
model_state_dict = checkpoint['state_dict']

# Adjust keys in the state_dict as needed
adjusted_model_state_dict = {}
for key, value in model_state_dict.items():
    # Remove the 'model.' prefix and handle the 'classifier' to 'classification' naming difference
    new_key = key.replace('model.', '').replace('classification', 'classifier')
    if 'position_ids' not in new_key:  # Ignore 'bert.embeddings.position_ids'
        adjusted_model_state_dict[new_key] = value

# Load the adjusted state dict
model.load_state_dict(adjusted_model_state_dict, strict=False)
model.eval()  # Set the model to evaluation mode


BertClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [36]:
class QueryRewriter(pt.transformer.TransformerBase):
    def __init__(self, model, tokenizer):
        super(QueryRewriter, self).__init__()
        self.model = model
        self.tokenizer = tokenizer

    def transform(self, queries):
        rewritten_queries = []
        for query in queries:
            # Tokenize the query
            inputs = self.tokenizer.encode_plus(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]

            # Predict with the model
            with torch.no_grad():
                # Here we directly use the model's output
                output = self.model(input_ids=input_ids, attention_mask=attention_mask)
                logits = output

                # Assuming your model's output is logits and you're doing classification
                predictions = torch.argmax(logits, dim=-1)

            # Assuming 1 signifies ambiguity and requires rewriting
            if predictions.item() == 1:
                rewritten_query = self.rewrite_query(query)
            else:
                rewritten_query = query

            rewritten_queries.append(rewritten_query)

        return rewritten_queries

    def rewrite_query(self, query):
        # TODO: This is a placeholder for the actual query rewriting logic
        # I'm just simply appending something to indicate a rewrite
        return query + " (rewritten)"


In [37]:
from transformers import BertTokenizer

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Initialize the model (assuming you've already loaded your custom model)
model = BertClassifier()
model.eval()  # Ensure the model is in evaluation mode

# Initialize the QueryRewriter with your model and tokenizer
query_rewriter = QueryRewriter(model, tokenizer)


  super(QueryRewriter, self).__init__()


In [38]:
# Sample queries
queries = [
    "What's the weather like today?",
    "Tell me about historical monuments.",
    "What are the symptoms of the flu?"
]

# Use the QueryRewriter to transform (rewrite) the queries
rewritten_queries = query_rewriter.transform(queries)

# Print the original and rewritten queries for comparison
for original, rewritten in zip(queries, rewritten_queries):
    print(f"Original: {original}\nRewritten: {rewritten}\n")


Original: What's the weather like today?
Rewritten: What's the weather like today?

Original: Tell me about historical monuments.
Rewritten: Tell me about historical monuments.

Original: What are the symptoms of the flu?
Rewritten: What are the symptoms of the flu?


In [39]:
import pandas as pd

rankings_path = './model/wetransfer_model-1-doc2query-all-rankers_2024-03-21_1327/best_rank_list/doc_msmarco/doc_davinci03_doc_context_400tok_5e-4.tsv'
rankings = pd.read_csv(rankings_path, sep='\t', header=None, names=['qid', 'docid', 'rank', 'score'])

print(rankings.head)

precomputed_ranker = pt.apply.generic(lambda df: rankings)

<bound method NDFrame.head of                  qid  docid      rank  \
966413  Q0   D721816      1  5.391621   
        Q0   D104404      2  5.362526   
        Q0  D3547764      3  5.334213   
        Q0   D160447      4  5.331561   
        Q0   D147193      5  5.302257   
...              ...    ...       ...   
1108307 Q0  D2610362     96  1.231557   
        Q0  D3309585     97  1.115268   
        Q0  D1512095     98  0.114295   
        Q0  D1663165     99 -0.648574   
        Q0  D2135708    100 -2.481846   

                                                        score  
966413  Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
        Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
        Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
        Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
        Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
...                                                       ...  
1108307 Q0  msmarco_doc_davin

In [40]:
# Assuming your model is already initialized and set to evaluation mode as `model`
# Initialize the QueryRewriter with your model and tokenizer
query_rewriter = QueryRewriter(model, tokenizer)

# Ensure the precomputed_ranker is correctly defined as before
precomputed_ranker = pt.apply.generic(lambda df: rankings)

# Now create the pipeline with the query_rewriter that already has the model and tokenizer
pipeline = query_rewriter >> precomputed_ranker

# Execute the search with the pipeline
results = pipeline.search("your initial query")

  super(QueryRewriter, self).__init__()


In [41]:
print(results.head)

<bound method NDFrame.head of                  qid  docid      rank  \
1120868 Q0  D1000007     31  5.018914   
156493  Q0  D1000170     77  4.205341   
        Q0  D1000172     70  4.344245   
        Q0  D1000173     21  4.870865   
1104492 Q0  D1000501     35  1.614483   
...              ...    ...       ...   
1117387 Q0   D999412     94 -3.046582   
451602  Q0   D999449     43  4.828393   
1115392 Q0   D999450     67 -0.689782   
1104031 Q0   D999643     44  2.001423   
964223  Q0   D999926     17  4.770953   

                                                        score  
1120868 Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
156493  Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
        Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
        Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
1104492 Q0  msmarco_doc_davinci03_doc_context_1200_400tok_...  
...                                                       ...  
1117387 Q0  msmarco_doc_davin