In [1]:
import os
import torch
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
import faiss
import numpy as np
from tqdm import tqdm
from typing import List

from langchain_huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.embeddings.base import Embeddings


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CustomHuggingFaceEmbeddings(Embeddings):  # Inherit from Embeddings base class
    """Embeddings adapter that mean-pools a Hugging Face encoder's last hidden state.

    Each text is tokenized (truncated to 512 tokens), run through ``model``,
    averaged over its real tokens, and L2-normalized.

    Args:
        model: Encoder whose forward output exposes ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``; called with
            ``return_tensors="pt"``.
        batch_size: Maximum number of texts sent through the model in one
            forward pass. Values below 1 are treated as 1.
    """

    def __init__(self, model, tokenizer, batch_size: int = 32):
        self.model = model
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def _model_device(self):
        """Return the device the wrapped model lives on."""
        device = getattr(self.model, "device", None)
        if device is None:
            device = next(self.model.parameters()).device
        return device

    def _embed_batch(self, texts: List[str], device) -> List[List[float]]:
        """Embed one mini-batch of texts and return plain Python lists."""
        inputs = self.tokenizer(texts, padding=True, truncation=True,
                                max_length=512, return_tensors="pt").to(device)
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state
            # Weight by the attention mask so padding positions do not leak
            # into the average; otherwise a text's vector would depend on the
            # longest text that happened to share its batch.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        return pooled.cpu().numpy().tolist()  # Convert to list format

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` in mini-batches.

        Returns:
            One L2-normalized vector per input text, in input order. An empty
            input yields an empty list without calling the model.
        """
        texts = list(texts)
        if not texts:
            return []

        step = max(1, int(self.batch_size))
        # Follow the model instead of assuming 'cuda', so inputs and weights
        # always end up on the same device.
        device = self._model_device()

        vectors: List[List[float]] = []
        for start in range(0, len(texts), step):
            vectors.extend(self._embed_batch(texts[start:start + step], device))
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        A lone text is never padded, so the masked mean equals the plain mean
        over all of its tokens.
        """
        return self.embed_documents([text])[0]

In [3]:
def load_vectorstore(path="./pubmedqa_vectorstore",
                     model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"):
    """Load the FAISS vector store from disk.

    Args:
        path: Directory passed to ``FAISS.load_local``.
        model_name: Hugging Face checkpoint used for the tokenizer and the
            encoder. It should be the same model that produced the stored
            vectors; the default is the checkpoint this notebook uses.

    Returns:
        The vector store returned by ``FAISS.load_local``, wired to a
        ``CustomHuggingFaceEmbeddings`` instance for query embedding.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to('cuda')

    embeddings = CustomHuggingFaceEmbeddings(model, tokenizer)

    # SECURITY: allow_dangerous_deserialization=True lets load_local unpickle
    # files found under `path`. Only point this at stores you created or trust.
    vectorstore = FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
    return vectorstore

In [4]:
def retrieve_contexts(vectorstore, query, k=5):
    """
    Retrieve relevant contexts for a given query using direct FAISS search.

    Args:
        vectorstore: object exposing ``embedding_function.embed_query``,
            ``index.search`` and ``docstore``. The docstore may either offer
            ``get(row)`` keyed by index row, or ``search(doc_id)`` combined
            with a ``vectorstore.index_to_docstore_id`` mapping.
        query: str, the search query
        k: int, number of results to request from the index

    Returns:
        List of ``(document, value)`` tuples in the order the index returned
        them, where ``value`` is the float reported by ``index.search`` for
        that hit. NOTE(review): whether lower or higher is better depends on
        the metric the index was built with - confirm.

    Raises:
        AttributeError: if the docstore has no ``get`` method and the
            vectorstore has no ``index_to_docstore_id`` mapping.
    """
    # The index expects a 2-D float32 array holding one query row.
    query_vec = vectorstore.embedding_function.embed_query(query)
    query_vec = np.asarray(query_vec, dtype="float32").reshape(1, -1)

    distances, indices = vectorstore.index.search(query_vec, k)

    docstore = vectorstore.docstore
    direct_lookup = getattr(docstore, "get", None)
    id_map = getattr(vectorstore, "index_to_docstore_id", None)
    if direct_lookup is None and id_map is None:
        raise AttributeError(
            "docstore has no 'get' method and vectorstore has no "
            "'index_to_docstore_id' mapping; cannot resolve search hits"
        )

    results = []
    for row, distance in zip(indices[0], distances[0]):
        row = int(row)  # plain int keys instead of numpy scalars
        if row < 0:
            # FAISS pads with -1 when fewer than k neighbours exist.
            continue

        if direct_lookup is not None:
            doc = direct_lookup(row)
        else:
            doc_id = id_map.get(row)
            doc = docstore.search(doc_id) if doc_id is not None else None
            if isinstance(doc, str):
                # search() reports a miss as a message string, not a document.
                doc = None

        if doc is not None:
            results.append((doc, float(distance)))

    return results

In [5]:
# Load the persisted index once; the query cell below reuses `vectorstore`.
vectorstore = load_vectorstore("./pubmedqa_vectorstore")

In [20]:
# Example query; k sets how many hits are requested from the index.
query = "symptoms of heart attack"
results = retrieve_contexts(vectorstore, query=query, k=5)

In [21]:
# Show each hit: the value reported by the index, then the chunk text,
# followed by a blank line.
for doc, score in results:
    print(f"Score: {score}", f"Content: {doc.page_content}\n", sep="\n")

Score: 0.14703282713890076
Content: (UNASEM) trial on the effect of thrombolysis in unstable angina. Patients without a previous myocardial infarction who presented with a typical history of unstable angina in the presence of abnormal findings on the electrocardiogram indicative of ischemia were included in the study. After baseline angiography, study medication (anistreplase or placebo) was given to 126 to 159 patients. Thirty-three patients did not receive medication because of significant main stem disease or normal coronary arteries or

Score: 0.15074096620082855
Content: at the time of worst bradyarrhythmia predicted the risk of TdP. However, the QT, corrected QT (QTc), and T(peak)-T(end) intervals correlated with the risk of TdP. The best single discriminator was a T(peak)-T(end) of 117 ms. LQT1-like and LQT3-like morphologies were rare during bradyarrhythmias. In contrast, LQT2-like "notched T waves" were observed in 55% of patients with TdP but in only 3% of patients with uncom