In [42]:
import os
import pickle
import json
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import torch
import faiss

In [43]:


# Load embedding model
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # Replace with a domain-specific model if needed
model = SentenceTransformer(MODEL_NAME)

# Check if CUDA is available
if torch.cuda.is_available():
    print("CUDA is available. Using GPU for embeddings.")
    model = model.to('cuda')
else:
    print("CUDA is not available. Using CPU for embeddings.")

# Chunking function for fixed-length chunks with overlap
def chunk_text(text, chunk_size=300, overlap=50):
    """Splits text into fixed-length chunks with overlap."""
    words = text.split()
    chunks = []
    for i in range(0, len(words), chunk_size - overlap):
        chunk = " ".join(words[i:i + chunk_size])
        chunks.append(chunk)
    return chunks

# Process dataset rows for labeled data
def process_row_labeled(row):
    """Processes a row to generate chunks and metadata for labeled data."""
    raw_text = row.get("context", {}).get("contexts", [])
    raw_text = " ".join(raw_text) if isinstance(raw_text, list) else raw_text
    
    chunks = chunk_text(raw_text)
   
    chunked_data = []
    for chunk in chunks:
        chunk_data = {
            "chunk_text": chunk,
            "label": row.get("context", {}).get("labels", ["unlabeled"])[0],
            "meshes": row.get("context", {}).get("meshes", ["no_mesh"])
        }
        chunked_data.append(chunk_data)
    return chunked_data

# Process dataset rows for unlabeled data
def process_row_unlabeled(row):
    """Processes a row to generate chunks for unlabeled data."""
    raw_text = row.get("output", "")
    
    chunks = chunk_text(raw_text)
    
    chunked_data = [{"chunk_text": chunk} for chunk in chunks]
    return chunked_data

# Process entire dataset
def process_dataset(dataset, is_labeled=False):
    """Processes the entire dataset into chunked data."""
    chunked_dataset = []
    for row in tqdm(dataset, desc="Processing Dataset"):
        if is_labeled:
            chunked_rows = process_row_labeled(row)
        else:
            chunked_rows = process_row_unlabeled(row)
        chunked_dataset.extend(chunked_rows)
    return chunked_dataset

# Generate 384-dimensional embeddings
def regenerate_embeddings(data, model):
    """Generate 384-dimensional embeddings for the dataset."""
    texts = [chunk["chunk_text"] for chunk in data]
    embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
    return embeddings

# Save embeddings to a .pkl file
def save_embeddings_to_file(data, file_path):
    """Saves the embedded data to a .pkl file."""
    with open(file_path, "wb") as f:
        pickle.dump(data, f)

# Load embeddings from a .pkl file
def load_embeddings_from_file(file_path):
    """Loads the embedded data from a .pkl file."""
    with open(file_path, "rb") as f:
        return pickle.load(f)

# Save metadata to a .pkl file
def save_metadata_to_file(data, file_path):
    """Saves the metadata to a .pkl file."""
    with open(file_path, "wb") as f:
        pickle.dump(data, f)

# Generate metadata from all_chunked_data
def generate_metadata(chunked_data):
    """Generates metadata for chunked data."""
    metadata = []
    for chunk in chunked_data:
        metadata_entry = {
            "chunk_text": chunk["chunk_text"],  # Text of the chunk
            "label": chunk.get("label", "unlabeled"),  # Label (if available)
            "meshes": chunk.get("meshes", []),  # Meshes (if available)
            "embedding": chunk.get("embedding", None)  # Embedding (if available)
        }
        metadata.append(metadata_entry)
    return metadata

# Rebuild FAISS Index
def rebuild_faiss_index(embeddings, dimension=384):
    """Create and save a new FAISS index with the specified dimension."""
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(embeddings, dtype=np.float32))  # Add embeddings to the index
    faiss.write_index(index, "faiss_index_384.bin")  # Save the new FAISS index
    print("FAISS index rebuilt and saved as 'faiss_index_384.bin'")
    return index


# Corrected query function
def query_options_enhanced_with_keywords(question, options, index, metadata, top_k=5, question_keywords=None, option_keywords=None):
    """Query embeddings for each option with keyword prioritization and retrieve relevant contexts."""
    results = {}
    all_retrieved_chunks = set()  # Track globally retrieved chunks for uniqueness

    for option in options:
        # Combine option and question with keyword emphasis
        query_text = f"Option: {option}. Question: {question}. Focus on relevant keywords."
        query_embedding = model.encode([query_text], convert_to_numpy=True)

        # Add keyword embeddings to query embedding
        if question_keywords or option_keywords:
            combined_keywords = (question_keywords or []) + (option_keywords or [])
            keyword_embeddings = model.encode(combined_keywords, convert_to_numpy=True)
            keyword_embedding = np.mean(keyword_embeddings, axis=0)  # Aggregate keyword embeddings
            query_embedding = 0.7 * query_embedding + 0.3 * keyword_embedding  # Adjust weights as needed

        # Ensure query embedding is in the correct format
        query_embedding = query_embedding.astype("float32")

        # Search in the FAISS index
        distances, indices = index.search(query_embedding.reshape(1, -1), top_k)

        # Retrieve chunks and prioritize uniqueness
        retrieved_chunks = [
            {
                "chunk_text": metadata[idx]["chunk_text"],
                "label": metadata[idx].get("label", "unlabeled"),
                "distance": distances[0][i]
            }
            for i, idx in enumerate(indices[0])
            if metadata[idx]["chunk_text"] not in all_retrieved_chunks
        ]

        # Add retrieved chunks to global set
        all_retrieved_chunks.update(chunk["chunk_text"] for chunk in retrieved_chunks)

        # Post-retrieval keyword filtering and scoring
        if question_keywords or option_keywords:
            combined_keywords = (question_keywords or []) + (option_keywords or [])
            retrieved_chunks = sorted(
                retrieved_chunks,
                key=lambda chunk: sum(
                    1 for keyword in combined_keywords if keyword.lower() in chunk["chunk_text"].lower()
                ),
                reverse=True  # Prioritize chunks with more keyword matches
            )

        # Store results
        results[option] = {"retrieved_chunks": retrieved_chunks}

    return results


# Load datasets
print("Loading datasets...")
context_ds_artificial = load_dataset("qiaojin/PubMedQA", "pqa_artificial")['train']
context_ds_labeled = load_dataset("qiaojin/PubMedQA", "pqa_labeled")['train']
context_ds_unlabeled = load_dataset("qiaojin/PubMedQA", "pqa_unlabeled")['train']
context_ds_knowledge = load_dataset("medalpaca/medical_meadow_wikidoc")['train']

# Process datasets
print("Processing labeled dataset...")
chunked_labeled = process_dataset(context_ds_labeled, is_labeled=True)
print("Processing artificial dataset...")
chunked_artificial = process_dataset(context_ds_artificial, is_labeled=True)
print("Processing unlabeled dataset...")
chunked_unlabeled = process_dataset(context_ds_unlabeled, is_labeled=False)
print("Processing knowledge dataset...")
chunked_knowledge = process_dataset(context_ds_knowledge, is_labeled=False)

# Combine all chunked data
print("Combining all datasets...")
all_chunked_data = chunked_labeled + chunked_artificial + chunked_unlabeled + chunked_knowledge

# Embedding and FAISS Index Handling
embedding_file = "embedded_combined.pkl"
if os.path.exists(embedding_file):
    print(f"Loading existing embeddings from '{embedding_file}'...")
    embedded_data = load_embeddings_from_file(embedding_file)
else:
    print("Generating new embeddings for combined dataset...")
    all_embeddings = regenerate_embeddings(all_chunked_data, model)
    save_embeddings_to_file(all_embeddings, embedding_file)
    print("Embeddings saved to file.")

# Check for FAISS index
faiss_index_file = "faiss_index_384.bin"
if os.path.exists(faiss_index_file):
    print(f"Loading existing FAISS index from '{faiss_index_file}'...")
    faiss_index = faiss.read_index(faiss_index_file)
else:
    print("Generating new FAISS index...")
    faiss_index = rebuild_faiss_index(all_embeddings)

# Save metadata
metadata_file = "metadata.pkl"
metadata = generate_metadata(all_chunked_data)
save_metadata_to_file(metadata, metadata_file)
print(f"Metadata saved to {metadata_file}.")

print("Embedding and FAISS indexing process completed.")


CUDA is available. Using GPU for embeddings.
Loading datasets...
Processing labeled dataset...


Processing Dataset: 100%|██████████| 1000/1000 [00:00<00:00, 8608.03it/s]


Processing artificial dataset...


Processing Dataset: 100%|██████████| 211269/211269 [00:17<00:00, 12233.65it/s]


Processing unlabeled dataset...


Processing Dataset: 100%|██████████| 61249/61249 [00:03<00:00, 17832.11it/s]


Processing knowledge dataset...


Processing Dataset: 100%|██████████| 10000/10000 [00:00<00:00, 34988.98it/s]


Combining all datasets...
Loading existing embeddings from 'embedded_combined.pkl'...
Loading existing FAISS index from 'faiss_index_384.bin'...
Metadata saved to metadata.pkl.
Embedding and FAISS indexing process completed.


In [44]:


import random
import pickle
import faiss
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("GBaker/MedQA-USMLE-4-options")["train"]

# Load FAISS index and metadata
print("Loading FAISS index and metadata...")
faiss_index = faiss.read_index("faiss_index_384.bin")
with open("metadata.pkl", "rb") as f:
    metadata = pickle.load(f)

# Adjust query function to handle dictionary-style options
def query_from_dataset_with_dict_options(dataset, index, metadata, top_k=5, num_samples=10):
    """Query FAISS index for random question-option pairs with dictionary-style options."""
    # Select random indices
    total_samples = len(dataset)
    random_indices = random.sample(range(total_samples), num_samples)  # Randomly select indices
    random_samples = [dataset[i] for i in random_indices]  # Retrieve random samples
    
    # Retrieve contexts for each sample
    retrieved_results = []
    for sample in random_samples:
        question = sample["question"]
        options_dict = sample["options"]  # Extract dictionary of options
        
        # Format options as a list of strings 
        options = [f"{key}. {value}" for key, value in options_dict.items()]
        
        # Query the index for this question and its options
        retrieved_contexts = query_options_enhanced_with_keywords(
            question, options, index, metadata, top_k=top_k
        )
        
        retrieved_results.append({
            "question": question,
            "options": options_dict,  # Keep the original dictionary for reference
            "retrieved_contexts": retrieved_contexts
        })
    return retrieved_results

# Retrieve contexts for 10 random samples
results = query_from_dataset_with_dict_options(dataset, faiss_index, metadata, top_k=5, num_samples=10)

# Display results
for idx, result in enumerate(results):
    print(f"\nQuestion {idx + 1}: {result['question']}")
    for option, data in result["retrieved_contexts"].items():
        print(f"Option {option}:")
        for chunk in data["retrieved_chunks"]:
            print(f"  - {chunk['chunk_text']} (Distance: {chunk['distance']})")



Loading FAISS index and metadata...

Question 1: A 27-year-old male is brought to the emergency department with a 1-week history of worsening headache. Over the past 2 days, he has become increasingly confused and developed nausea as well as vomiting. One week ago, he struck his head while exiting a car, but did not lose consciousness. His maternal uncle had a bleeding disorder. He appears in moderate distress. He is oriented to person and time but not to place. His temperature is 37.1°C (98.8°F), pulse is 72/min, respirations are 20/min, and blood pressure is 128/78 mm Hg. Cardiopulmonary examination is unremarkable. His abdomen is soft and nontender. Muscle strength is 5/5 in left upper and left lower extremities, and 3/5 in right upper and right lower extremities. Laboratory studies show:
Leukocyte Count 10,000/mm3
Hemoglobin 13.6 g/dL
Hematocrit 41%
Platelet Count 150,000/mm3
PT 13 seconds
aPTT 60 seconds
Serum
Sodium 140 mEq/L
Potassium 4.2 mEq/L
Chloride 101 mEq/L
Bicarbonate 24 

In [45]:
from transformers import BartTokenizer, BartForConditionalGeneration

# Load BART model and tokenizer
bart_model_name = "facebook/bart-base"
bart_tokenizer = BartTokenizer.from_pretrained(bart_model_name)
bart_model = BartForConditionalGeneration.from_pretrained(bart_model_name)

# Define a summarization function
def summarize_text_bart(text, max_length=100, min_length=30):
    """Summarizes the input text using BART-base."""
    input_ids = bart_tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = bart_model.generate(
        input_ids, max_length=max_length, min_length=min_length, length_penalty=2.0, num_beams=4, early_stopping=True
    )
    return bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Summarize chunk_text fields for each option
def summarize_chunk_texts_bart(retrieved_contexts, max_chunks=5):
    """
    Summarizes chunk_text fields for each option using BART-base.

    Args:
        retrieved_contexts (dict): The retrieved contexts structure containing options and their chunks.
        max_chunks (int): Maximum number of chunks to include in the summary per option.

    Returns:
        dict: Summarized context for each option.
    """
    summarized_contexts = {}
    for option, data in retrieved_contexts.items():
        # Extract up to max_chunks of chunk_text
        chunks = [chunk['chunk_text'] for chunk in data['retrieved_chunks'][:max_chunks]]
        
        # Combine chunk_text fields into a single text
        combined_text = " ".join(chunks)
        
        # Summarize the combined chunk_text
        summary = summarize_text_bart(combined_text)
        
        # Store the summary
        summarized_contexts[option] = summary
    
    return summarized_contexts

# Dynamic pipeline
def process_dynamic_summarization(dataset, index, metadata, top_k=5, num_samples=10, max_chunks=5):
    """Pipeline for processing, querying, and summarizing contexts dynamically."""
    # Retrieve question-option pairs
    results = query_from_dataset_with_dict_options(dataset, index, metadata, top_k=top_k, num_samples=num_samples)

    # Summarize retrieved contexts for each question
    summarized_results = []
    for result in results:
        question = result["question"]
        options = result["options"]
        retrieved_contexts = result["retrieved_contexts"]

        # Summarize retrieved contexts
        summarized_contexts = summarize_chunk_texts_bart(retrieved_contexts, max_chunks=max_chunks)

        # Combine summarized contexts with the question and options
        summarized_results.append({
            "question": question,
            "options": options,
            "summarized_contexts": summarized_contexts
        })
    
    return summarized_results

# Process and summarize 10 random question-option pairs
summarized_results = process_dynamic_summarization(dataset, faiss_index, metadata, top_k=5, num_samples=10, max_chunks=5)

# Display summarized results
for idx, result in enumerate(summarized_results):
    print(f"\nQuestion {idx + 1}: {result['question']}")
    for option, summary in result["summarized_contexts"].items():
        print(f"{option} :\n{summary}\n")



Question 1: A 22-year-old man presents with multiple, target-like skin lesions on his right and left upper and lower limbs. He says that the lesions appeared 4 days ago and that, over the last 24 hours, they have extended to his torso. Past medical history is significant for pruritus and pain on the left border of his lower lip 1 week ago, followed by the development of an oral ulcerative lesion. On physical examination, multiple round erythematous papules with a central blister, a pale ring of edema surrounding a dark red inflammatory zone, and an erythematous halo are noted. Mucosal surfaces are free of any ulcerative and exudative lesions. Which of the following statements best explains the pathogenesis underlying this patient’s condition?
A. Tumor necrosis factor (TNF) alpha production by CD4+ T cells in the skin :
To investigate whether tumour necrosis factor alpha (TNFalpha) is expressed in subacute cutaneous lupus erythematosus (SCLE) skin lesions. The in situ expression of TNF