In [1]:
!pip install python-terrier
import pyterrier as pt
if not pt.started():
    pt.init()

import os
import re
import warnings
from pathlib import Path

import nltk
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
from pyterrier.measures import RR, nDCG, MAP
from transformers import BertModel
# Fetch the NLTK 'punkt' and 'stopwords' resources; quiet=True is passed so the
# downloader does not print its progress output.
# NOTE(review): presumably these back the word_tokenize / stopwords imports
# above -- confirm against the cells that use them.
# The value of the last call is what this cell displays as its output.
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)



PyTerrier 0.10.0 has loaded Terrier 5.8 (built by craigm on 2023-11-01 18:05) and terrier-helper 0.0.8

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


True

In [2]:
class DocQuery2DocModel(nn.Module):
    """BERT encoder topped with a single linear classification layer.

    The submodules are stored as ``bert`` and ``classification``; those
    attribute names become the parameter-name prefixes in the state dict, so
    they have to stay aligned with the checkpoint the model is loaded from.

    Args:
        num_labels: Output width of the linear head (logits per input).
    """

    def __init__(self, num_labels=2):
        super().__init__()
        # Encoder first, head second: keeps the submodule registration order.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # 768 is the encoder's feature width feeding the head.
        self.classification = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the encoder and return the head's logits.

        All three arguments are handed to the BERT encoder by keyword,
        unchanged; the encoder's ``pooler_output`` is then passed through
        the ``classification`` layer and the result is returned.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.classification(encoded.pooler_output)

# Build the model; the checkpoint weights loaded below overwrite the
# pretrained BERT weights and fill in the classification head.
doc_query2doc_model = DocQuery2DocModel()

# Load the checkpoint onto the CPU so no GPU is required.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
model_path = './model/wetransfer_model-1-doc2query-all-rankers_2024-03-21_1327/doc_query2doc.ckpt'
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# A checkpoint may be a bare state dict or a dict that wraps the weights under
# 'state_dict'. Passing the wrapper itself to load_state_dict(strict=False)
# would match no parameter and silently leave the model un-fine-tuned, so
# unwrap it when present.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    doc_query2doc_state_dict = checkpoint['state_dict']
else:
    doc_query2doc_state_dict = checkpoint

# strict=False tolerates key differences, but it must not hide them: report
# whatever was not loaded so a mismatched checkpoint is visible.
doc_query2doc_load_result = doc_query2doc_model.load_state_dict(doc_query2doc_state_dict, strict=False)
if doc_query2doc_load_result.missing_keys or doc_query2doc_load_result.unexpected_keys:
    warnings.warn(
        f"Checkpoint {model_path!r} does not fully match DocQuery2DocModel: "
        f"{len(doc_query2doc_load_result.missing_keys)} missing key(s) "
        f"(first: {doc_query2doc_load_result.missing_keys[:5]}), "
        f"{len(doc_query2doc_load_result.unexpected_keys)} unexpected key(s) "
        f"(first: {doc_query2doc_load_result.unexpected_keys[:5]})"
    )

# Switch to evaluation mode (disables dropout); as the last expression, the
# returned module is also what the cell displays.
doc_query2doc_model.eval()

DocQuery2DocModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise

In [3]:
class DavinciDocContextModel(nn.Module):
    """BERT encoder topped with a single linear classification layer.

    The submodules are stored as ``bert`` and ``classification``; those
    attribute names become the parameter-name prefixes in the state dict, so
    they have to stay aligned with the checkpoint the model is loaded from.

    Args:
        num_labels: Output width of the linear head. Defaults to 1, chosen to
            match the checkpoint this model is loaded from.
    """

    def __init__(self, num_labels=1):
        super().__init__()
        # Encoder first, head second: keeps the submodule registration order.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # 768 is the encoder's feature width feeding the head.
        self.classification = nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the encoder and return the head's logits.

        All three arguments are handed to the BERT encoder by keyword,
        unchanged; the encoder's ``pooler_output`` is then passed through
        the ``classification`` layer and the result is returned.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.classification(encoded.pooler_output)

# Build the model with its default single-logit head, which matches the checkpoint.
davinci_doc_context_model = DavinciDocContextModel()

# Load the checkpoint onto the CPU so no GPU is required.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
model_path = './model/doc_davinci_doc_context.ckpt'  # Adjust this path to your checkpoint's location
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# The weights are expected under 'state_dict'; fall back to treating the
# checkpoint as a bare state dict so both layouts load.
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    davinci_doc_context_state_dict = checkpoint['state_dict']
else:
    davinci_doc_context_state_dict = checkpoint

# strict=False tolerates key differences, but it must not hide them: report
# whatever was not loaded so a mismatched checkpoint is visible instead of
# silently leaving parts of the model at their initial values.
davinci_doc_context_load_result = davinci_doc_context_model.load_state_dict(
    davinci_doc_context_state_dict, strict=False
)
if davinci_doc_context_load_result.missing_keys or davinci_doc_context_load_result.unexpected_keys:
    warnings.warn(
        f"Checkpoint {model_path!r} does not fully match DavinciDocContextModel: "
        f"{len(davinci_doc_context_load_result.missing_keys)} missing key(s) "
        f"(first: {davinci_doc_context_load_result.missing_keys[:5]}), "
        f"{len(davinci_doc_context_load_result.unexpected_keys)} unexpected key(s) "
        f"(first: {davinci_doc_context_load_result.unexpected_keys[:5]})"
    )

# Switch to evaluation mode (disables dropout); as the last expression, the
# returned module is also what the cell displays.
davinci_doc_context_model.eval()


DavinciDocContextModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen