In [None]:
import json
import numpy as np
import torch
import faiss
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import os
import pickle

# Select the torch backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def main():
    """Precompute the BM25 and BioBERT artifacts for the document collection.

    Reads ``preprocessed_cord19.jsonl`` (one JSON document per line) from the
    current working directory and writes, also to the working directory:

    * ``bm25_corpus.npy``   -- 1-D object array with one token list per document
    * ``metadata.json``     -- list of ``doc_id`` / ``original_title`` /
      ``biobert_text`` records, in input order
    * ``bm25_model.pkl``    -- pickled ``BM25Okapi`` index over the token lists
    * ``biobert_embeddings.npy`` -- one row per document holding the hidden
      state at the first token position of the model's last layer

    Raises:
        ValueError: if the input file contains no documents.
    """
    bm25_corpus = []
    metadata = []
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open('preprocessed_cord19.jsonl', encoding='utf-8') as f:
        for line in tqdm(f, desc="Loading documents"):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            doc = json.loads(line)
            bm25_corpus.append(doc['bm25']['combined'].split())
            metadata.append({
                'doc_id': doc['doc_id'],
                'original_title': doc['original_title'],
                'biobert_text': doc['biobert']['combined']
            })

    if not metadata:
        # Fail before writing anything: an empty corpus would otherwise only
        # surface later as an obscure error from the indexing/concatenation.
        raise ValueError("preprocessed_cord19.jsonl contains no documents")

    # Fill the object array element by element. np.array(list_of_lists,
    # dtype=object) collapses into a 2-D array whenever every token list
    # happens to have the same length, which would change the saved layout.
    corpus_array = np.empty(len(bm25_corpus), dtype=object)
    for idx, tokens in enumerate(bm25_corpus):
        corpus_array[idx] = tokens
    np.save('bm25_corpus.npy', corpus_array)

    with open('metadata.json', 'w', encoding='utf-8') as f:
        json.dump(metadata, f)

    bm25 = BM25Okapi(bm25_corpus)
    with open('bm25_model.pkl', 'wb') as f:
        pickle.dump(bm25, f)

    tokenizer = AutoTokenizer.from_pretrained("monologg/biobert_v1.1_pubmed")
    model = AutoModel.from_pretrained("monologg/biobert_v1.1_pubmed").to(device)
    # Make inference mode explicit (no dropout) for reproducible embeddings.
    model.eval()

    embeddings = []
    # Smaller batches on CPU; larger ones when an accelerator is in use.
    batch_size = 128 if device != 'cpu' else 32
    for i in tqdm(range(0, len(metadata), batch_size), desc="Generating embeddings"):
        batch = [m['biobert_text'] for m in metadata[i:i+batch_size]]
        # Pad to the longest text in the batch; truncate to at most 512 tokens.
        inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Keep only the first-token vector per document and move it to host
        # memory so accelerator memory is not held across batches.
        embeddings.append(outputs.last_hidden_state[:,0,:].cpu().numpy())

    full_embeddings = np.concatenate(embeddings)
    np.save('biobert_embeddings.npy', full_embeddings)
    

# Script entry point: run the precomputation, then tell the user what to do next.
if __name__ == "__main__":
    main()
    print("Precomputing complete! You can now use search_engine.py")




Loading documents: 148752it [00:21, 6766.43it/s]
Generating embeddings: 100%|██████████| 1163/1163 [1:30:13<00:00,  4.65s/it]


Precomputing complete! You can now use search_engine.py
