This is the unsuccessful version of the RAG system.  This script implements a document retrieval and processing pipeline using a combination of BM25 and Dense embeddings. It includes preprocessing, chunking, and retrieval to generate relevant contexts for given questions and options.

In [13]:
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset
import numpy as np
import heapq
import pickle
import os
import torch
from tqdm import tqdm

# Load SentenceTransformer Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

# Preprocess Data
class TextProcessor:
    """
    Handles text chunking for long documents.
    """
    def __init__(self, chunk_size=300, overlap=50):
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk_text(self, text):
        """
        Splits a long text into overlapping chunks.

        Args:
            text (str): Input text to be chunked.

        Returns:
            list: List of text chunks.
        """
        words = text.split()
        chunks = []
        for i in range(0, len(words), self.chunk_size - self.overlap):
            chunk = " ".join(words[i:i + self.chunk_size])
            chunks.append(chunk)
        return chunks

class DataProcessor:
    """
    Processes datasets by chunking and preparing data rows.
    """
    def __init__(self, text_processor):
        self.text_processor = text_processor

    def process_row(self, row, is_labeled):
        """
        Processes a single dataset row.

        Args:
            row (dict): Dataset row.
            is_labeled (bool): Indicates if the dataset is labeled.

        Returns:
            list: Processed chunks with additional metadata.
        """
        raw_text = row.get("context", {}).get("contexts", []) if is_labeled else row.get("output", "")
        raw_text = " ".join(raw_text) if isinstance(raw_text, list) else raw_text
        chunks = self.text_processor.chunk_text(raw_text)

        chunked_data = []
        for chunk in chunks:
            chunk_data = {"chunk_text": chunk}
            if is_labeled:
                chunk_data.update({
                    "label": row.get("context", {}).get("labels", ["unlabeled"])[0],
                    "meshes": row.get("context", {}).get("meshes", ["no_mesh"])
                })
            chunked_data.append(chunk_data)

        return chunked_data

    def process_dataset(self, dataset, is_labeled=False):
        """
        Processes an entire dataset.

        Args:
            dataset (list): List of dataset rows.
            is_labeled (bool): Indicates if the dataset is labeled.

        Returns:
            list: Processed dataset with chunked rows.
        """
        chunked_dataset = []
        for row in tqdm(dataset, desc="Processing Dataset"):
            chunked_rows = self.process_row(row, is_labeled)
            chunked_dataset.extend(chunked_rows)
        return chunked_dataset

class Preprocessor:
    """
    Handles preprocessing of text data and formatting of queries.
    """
    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer or (lambda x: x.lower().split())

    def preprocess_corpus(self, corpus):
        """
        Tokenizes the entire corpus.

        Args:
            corpus (list): List of documents.

        Returns:
            list: Tokenized corpus.
        """
        return [self.tokenizer(doc) for doc in corpus]

    def format_query(self, question, options):
        """
        Formats queries by combining question and options.

        Args:
            question (str): Input question.
            options (dict): Dictionary of options.

        Returns:
            dict: Formatted queries for each option.
        """
        return {key: f"[Q] {question} [O] {text}" for key, text in options.items()}

# BM25 Retriever
class BM25Retriever:
    """
    BM25-based retrieval system for ranking documents.
    """
    def __init__(self, corpus):
        self.preprocessor = Preprocessor()
        self.tokenized_corpus = self.preprocessor.preprocess_corpus(corpus)
        self.bm25 = BM25Okapi(self.tokenized_corpus)

    def retrieve(self, query, top_n=5):
        """
        Retrieves top-N relevant documents for a given query.

        Args:
            query (str): Input query.
            top_n (int): Number of top documents to retrieve.

        Returns:
            tuple: Indices and scores of the top documents.
        """
        tokenized_query = self.preprocessor.tokenizer(query)
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_n]
        return top_indices, [scores[i] for i in top_indices]

# Dense Retriever
class DenseRetriever:
    """
    Dense embedding-based retrieval system.
    """
    def __init__(self, corpus, embeddings_file="embeddings.pkl"):
        self.corpus = corpus
        self.embeddings_file = embeddings_file
        self.embeddings = self.load_or_generate_embeddings()

    def load_or_generate_embeddings(self):
        """
        Loads or generates embeddings for the corpus.

        Returns:
            torch.Tensor: Tensor of document embeddings.
        """
        if os.path.exists(self.embeddings_file):
            print("Loading existing embeddings...")
            with open(self.embeddings_file, 'rb') as f:
                embeddings = torch.tensor(pickle.load(f), device=device)
        else:
            print("Generating embeddings...")
            embeddings = embedding_model.encode(
                self.corpus, show_progress_bar=True, convert_to_tensor=True, device=device
            )
            with open(self.embeddings_file, 'wb') as f:
                pickle.dump(embeddings.cpu().numpy(), f)
        return embeddings

    def retrieve(self, query, top_n=5):
        """
        Retrieves top-N relevant documents for a given query.

        Args:
            query (str): Input query.
            top_n (int): Number of top documents to retrieve.

        Returns:
            tuple: Indices and scores of the top documents.
        """
        query_embedding = embedding_model.encode([query], convert_to_tensor=True, device=device)[0]
        query_embedding = query_embedding / torch.norm(query_embedding)  # Normalize query embedding

        scores = cosine_similarity(query_embedding.cpu().numpy().reshape(1, -1), self.embeddings.cpu().numpy())[0]
        top_indices = np.argsort(scores)[::-1][:top_n]
        return top_indices, [scores[i] for i in top_indices]

#Fusion of BM25 and Dense Results
class HybridRetriever:
    """
    Combines BM25 and Dense retrieval results using weighted fusion.
    """
    def __init__(self, bm25_retriever, dense_retriever):
        self.bm25_retriever = bm25_retriever
        self.dense_retriever = dense_retriever

    def retrieve(self, query, top_n=5, weight_bm25=0.6, weight_dense=0.4):
        """
        Retrieves top-N relevant documents by combining BM25 and Dense scores.

        Args:
            query (str): Input query.
            top_n (int): Number of top documents to retrieve.
            weight_bm25 (float): Weight for BM25 scores.
            weight_dense (float): Weight for Dense scores.

        Returns:
            tuple: Indices and scores of the top documents after fusion.
        """
        bm25_indices, bm25_scores = self.bm25_retriever.retrieve(query, top_n)
        dense_indices, dense_scores = self.dense_retriever.retrieve(query, top_n)

        combined_scores = {}
        for i, score in zip(bm25_indices, bm25_scores):
            combined_scores[i] = combined_scores.get(i, 0) + weight_bm25 * score

        for i, score in zip(dense_indices, dense_scores):
            combined_scores[i] = combined_scores.get(i, 0) + weight_dense * score

        # Deduplicate combined results
        unique_combined = {}
        for idx, score in combined_scores.items():
            context = corpus[idx][:700]
            if context not in unique_combined:
                unique_combined[context] = score

        # Rank and return deduplicated results
        top_combined = heapq.nlargest(top_n, unique_combined.items(), key=lambda x: x[1])
        return [corpus.index(ctx) for ctx, _ in top_combined], [score for _, score in top_combined]



# Summarization and Presentation
class ContextGenerator:
    """
    Generates deduplicated and diverse contexts from retrieved indices.
    """
    def __init__(self, corpus, similarity_threshold=0.85):
        self.corpus = corpus
        self.similarity_threshold = similarity_threshold

    def generate_context(self, indices):
        """
        Generates a list of diverse contexts for given indices.

        Args:
            indices (list): List of document indices.

        Returns:
            list: List of deduplicated contexts.
        """
        seen = set()  # Deduplication tracking
        contexts = []
        unique_embeddings = []

        for i in indices:
            context = self.corpus[i][:700]
            
            # Calculate similarity to existing contexts
            context_embedding = embedding_model.encode([context], convert_to_tensor=True, device=device)[0]
            if unique_embeddings:
                similarities = cosine_similarity(
                    context_embedding.cpu().numpy().reshape(1, -1),
                    torch.stack(unique_embeddings).cpu().numpy()
                )[0]
                if max(similarities) > self.similarity_threshold:
                    continue  # Skip if similar context already exists

            # Add to results
            contexts.append(context)
            unique_embeddings.append(context_embedding)

        return contexts

# Preprocess Knowledge Datasets
print("Processing datasets...")
text_processor = TextProcessor(chunk_size=300, overlap=50)
data_processor = DataProcessor(text_processor)

context_ds_artificial = load_dataset("qiaojin/PubMedQA", "pqa_artificial")
context_ds_labeled = load_dataset("qiaojin/PubMedQA", "pqa_labeled")
context_ds_unlabeled = load_dataset("qiaojin/PubMedQA", "pqa_unlabeled")
context_ds_knowledge = load_dataset("medalpaca/medical_meadow_wikidoc")

# Process and combine all datasets
chunked_artificial = data_processor.process_dataset(context_ds_artificial["train"], is_labeled=True)
chunked_labeled = data_processor.process_dataset(context_ds_labeled["train"], is_labeled=True)
chunked_unlabeled = data_processor.process_dataset(context_ds_unlabeled["train"], is_labeled=False)
chunked_knowledge = data_processor.process_dataset(context_ds_knowledge["train"], is_labeled=False)

all_chunked_data = chunked_artificial + chunked_labeled + chunked_unlabeled + chunked_knowledge
corpus = [chunk["chunk_text"] for chunk in all_chunked_data if "chunk_text" in chunk]
corpus = list(set(corpus))
print(f"Total number of unique chunks in the corpus: {len(corpus)}")

# Query Dataset and Test Query
query_ds = load_dataset("GBaker/MedQA-USMLE-4-options")

def retrieve_context_for_question(question, options):
    preprocessor = Preprocessor()
    formatted_queries = preprocessor.format_query(question, options)
    results = {}

    # Track seen contexts globally for all options 
    global_seen_contexts = set()
    for option_key, combined_query in formatted_queries.items():
        # Generate embedding for the query and option pair
        query_embedding = embedding_model.encode([combined_query], convert_to_tensor=True, device=device)[0]

        # Enhance priority of query keywords
        question_embedding = embedding_model.encode([f"[Q] {question}"], convert_to_tensor=True, device=device)[0]
        option_embedding = embedding_model.encode([f"[O] {options[option_key]}"], convert_to_tensor=True, device=device)[0]
        query_embedding = 0.5 * query_embedding + 0.3 * question_embedding + 0.2 * option_embedding
        query_embedding = query_embedding / torch.norm(query_embedding)

        # Perform Dense Retrieval with updated embedding
        scores = cosine_similarity(query_embedding.cpu().numpy().reshape(1, -1), dense_retriever.embeddings.cpu().numpy())[0]
        top_indices = np.argsort(scores)[::-1][:5]

        # Generate Context with Deduplication Across Options
        deduplicated_contexts = []
        for idx in top_indices:
            context = corpus[idx][:700]
            if context not in global_seen_contexts:
                deduplicated_contexts.append(context)
                global_seen_contexts.add(context)

        results[option_key] = deduplicated_contexts

    return results

# Initialize Components
bm25_retriever = BM25Retriever(corpus)
dense_retriever = DenseRetriever(corpus)
hybrid_retriever = HybridRetriever(bm25_retriever, dense_retriever)
context_generator = ContextGenerator(corpus)

# Test Query: Retrieve and Process the First 20 Questions
for i in range(20):  # Loop through the first 20 questions
    test_question = query_ds["train"][i]["question"]
    test_options = query_ds["train"][i]["options"]

    print(f"\nProcessing Question {i + 1}:")
    print(f"Question: {test_question}")
    print(f"Options: {test_options}")

    # Retrieve Contexts for the Current Question
    retrieved_contexts = retrieve_context_for_question(test_question, test_options)

    # Display Retrieved Contexts for Each Option
    print("\nRetrieved Context for Each Option:")
    for option, contexts in retrieved_contexts.items():
        print(f"\nOption: {option}")
        for j, ctx in enumerate(contexts):
            print(f"  Context {j + 1}: {ctx}")



Processing datasets...


Processing Dataset: 100%|██████████| 211269/211269 [00:19<00:00, 11056.26it/s]
Processing Dataset: 100%|██████████| 1000/1000 [00:00<00:00, 10943.24it/s]
Processing Dataset: 100%|██████████| 61249/61249 [00:03<00:00, 17467.51it/s]
Processing Dataset: 100%|██████████| 10000/10000 [00:00<00:00, 34404.47it/s]


Total number of unique chunks in the corpus: 250726
Loading existing embeddings...

Processing Question 1:
Question: A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?
Options: {'A': 'Ampicillin', 'B': 'Ceftriaxone', 'C': 'Doxycycline', 'D': 'Nitrofurantoin'}

Retrieved Context for Each Option:

Option: A
  Context 1: This study was designed to analyse factors potentially influencing children's return visits to physicians for symptoms of acute otitis media (AOM) 