In [1]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import pickle
import torch.nn.utils.rnn as rnn_utils
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
import time
from tqdm import tqdm

In [4]:
with open("trec_covid_preprocessed_full.pkl", "rb") as f:
    processed_records = pickle.load(f)

In [5]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

class BioBERTReranker:
    def __init__(self, precomputed_embeddings=None):
        self.tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
        self.model = AutoModel.from_pretrained("dmis-lab/biobert-v1.1").to(device)
        self.model.eval()  
        
        self.doc_embeddings = precomputed_embeddings or {}
        
        self.query_cache = defaultdict(lambda: None)

    def _encode_batch(self, texts, batch_size=32):
        embeddings = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
            batch = texts[i:i+batch_size]
            inputs = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(device)
            
            with torch.no_grad(), torch.autocast(device_type=device.type):
                outputs = self.model(**inputs)
            
            batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            embeddings.append(batch_embeddings)
        
        return np.concatenate(embeddings, axis=0)

    def precompute_document_embeddings(self, documents, batch_size=32):
        doc_embeddings = {}
        for doc_id, chunks in tqdm(documents.items(), desc="Precomputing docs"):
            if chunks:  
                chunk_texts = [chunk['text'] for chunk in chunks]
                doc_embeddings[doc_id] = self._encode_batch(chunk_texts, batch_size)
        self.doc_embeddings = doc_embeddings
        return doc_embeddings

    def encode_query(self, query_text, batch_size=16):
        if self.query_cache[query_text] is None:
            self.query_cache[query_text] = self._encode_batch([query_text], batch_size)[0]
        return self.query_cache[query_text]

    def rerank(self, query_text, doc_ids, top_k=10, chunk_weight=0.9):
        start_time = time.time()
        
        query_embedding = self.encode_query(query_text)
        
        doc_similarities = []
        for doc_id in doc_ids:
            if doc_id not in self.doc_embeddings:
                continue
                
            chunk_embeddings = self.doc_embeddings[doc_id]
            if len(chunk_embeddings) == 0:
                continue
            
            similarities = cosine_similarity(
                [query_embedding],
                chunk_embeddings
            )[0]
            
            max_sim = np.max(similarities)
            avg_sim = np.mean(similarities)
            final_score = (chunk_weight * max_sim) + ((1 - chunk_weight) * avg_sim)
            
            doc_similarities.append((doc_id, final_score))
        
        doc_similarities.sort(key=lambda x: x[1], reverse=True)
        
        print(f"Reranked {len(doc_similarities)} docs in {time.time()-start_time:.2f}s")
        return doc_similarities[:top_k]

Using device: mps


In [6]:
if __name__ == "__main__":
    reranker = BioBERTReranker()

    sample_query = "COVID-19 transmission in children"
    candidate_doc_ids = ["8qnrcgnk", "785vg6d", "ejv2xln0"]  

    reranked = reranker.rerank(sample_query, candidate_doc_ids)
    print("Top results:", reranked[:3])

Processing batches: 100%|██████████| 1/1 [00:00<00:00, 28.61it/s]

Reranked 0 docs in 0.04s
Top results: []



