### Importing Libraries

In [1]:
import os
import json
from typing import List, Dict, Any
from tqdm import tqdm


import dotenv
from dotenv import load_dotenv

import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer, util


from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_core.embeddings import Embeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.schema import Document

  from .autonotebook import tqdm as notebook_tqdm


### Env Config

In [2]:
load_dotenv()

os.environ["NVIDIA_API_KEY"] = os.getenv("NVIDIA_API_KEY")
os.environ["GOOGLE_API_KEY"] = os.getenv("GOOGLE_API_KEY")
os.environ["HUGGINGFACE_TOKEN"] = os.getenv("HUGGINGFACE_TOKEN")

if os.environ.get("NVIDIA_API_KEY"):
    print("NVIDIA API key loaded successfully")
else:
    print("Failed to load NVIDIA API key")

if os.environ.get("GOOGLE_API_KEY"):
    print("Google API key loaded successfully")
else:
    print("Failed to load Google API key")

if os.environ.get("HUGGINGFACE_TOKEN"):
    print("Hugging Face token loaded successfully")
else:
    print("Failed to load Hugging Face token")

NVIDIA API key loaded successfully
Google API key loaded successfully
Hugging Face token loaded successfully


### RAG Ingestion Pipeline - Document Loading & Processing Component


In [3]:
def find_all_json_files(base_path: str) -> List[str]:
    json_files = []
    for root, _, files in os.walk(base_path):
        if 'diagnostic_kg' in root:      # No KG files
            continue
        for f in files:
            if f.endswith('.json'):
                json_files.append(os.path.join(root, f))
    return json_files

def read_json_file(path: str) -> Dict[str, Any]:
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        print(f"Error reading {path}: {e}")
        return {}

def extract_categories_from_path(path: str) -> Dict[str, str]:
    parts = path.split(os.sep)
    try:
        idx = parts.index('Finished')
        category = parts[idx+1] if len(parts) > idx+1 else "Unknown"
        subcategory = parts[idx+2] if len(parts) > idx+2 else "None"
        return {
            "category": category,
            "subcategory": subcategory
        }
    except ValueError:
        return {"category": "Unknown", "subcategory": "Unknown"}

def flatten_json_to_text(data: Dict[str, Any], parent_key='') -> str:
    """Convert nested JSON to a flattened text format optimized for clinical notes."""
    result = []
    
    # Clinical note section headers
    section_headers = {
        'input1': 'CHIEF COMPLAINT',
        'input2': 'HISTORY OF PRESENT ILLNESS',
        'input3': 'PAST MEDICAL HISTORY',
        'input4': 'FAMILY HISTORY',
        'input5': 'PHYSICAL EXAMINATION',
        'input6': 'PERTINENT RESULTS'
    }
    
    def _flatten(obj, prefix=''):
        if isinstance(obj, dict):
            for input_key in sorted(section_headers.keys()):
                if input_key in obj and obj[input_key]:
                    result.append(f"\n=== {section_headers[input_key]} ===")
                    result.append(obj[input_key].strip())
            
            for k in sorted(obj.keys()):
                if k not in section_headers:
                    v = obj[k]
                    new_key = f"{prefix}.{k}" if prefix else k
                    if isinstance(v, (dict, list)):
                        _flatten(v, new_key)
                    elif v:  
                        result.append(f"{new_key}: {v}")
        elif isinstance(obj, list):
            for i, item in enumerate(obj):
                _flatten(item, f"{prefix}[{i}]")
        elif obj and prefix: 
            result.append(f"{prefix}: {obj}")
    
    _flatten(data)
    return "\n".join(result)

def create_documents_from_json(files: List[str]) -> List[Document]:
    docs = []
    for path in files:
        data = read_json_file(path)
        if not data: continue
        
        meta = extract_categories_from_path(path)
        
        flattened_text = flatten_json_to_text(data)
        
        header = f"""
=== CLINICAL NOTE ===
Category: {meta['category']}
Subcategory: {meta['subcategory']}
Source: {os.path.basename(path)}

{flattened_text}
"""
        
        doc = Document(
            page_content=header.strip(),
            metadata={
                "source": path,
                "category": meta["category"],
                "subcategory": meta["subcategory"],
                "filename": os.path.basename(path),
                "document_type": "clinical_note"
            }
        )
        docs.append(doc)
    
    return docs

# Document Pipeline Execution (RAG Ingestion)
BASE_PATH = "mimic-iv-ext-direct-1.0.0/mimic-iv-ext-direct-1.0.0/samples/Finished"
json_files = find_all_json_files(BASE_PATH)
documents = create_documents_from_json(json_files)


In [4]:
print(f"Total documents created: {len(documents)}")
if documents:
    print("\nSample document preview:")
    sample = documents[0]
    print(f"Source: {sample.metadata['source']}")
    print(f"Category: {sample.metadata['category']}")
    print(f"Content length: {len(sample.page_content)} characters")
    print(f"Content preview: {sample.page_content}...\n")

Total documents created: 511

Sample document preview:
Source: mimic-iv-ext-direct-1.0.0/mimic-iv-ext-direct-1.0.0/samples/Finished/Acute Coronary Syndrome/NSTEMI/17183564-DS-13.json
Category: Acute Coronary Syndrome
Content length: 2792 characters
Content preview: === CLINICAL NOTE ===
Category: Acute Coronary Syndrome
Subcategory: NSTEMI
Source: 17183564-DS-13.json


=== CHIEF COMPLAINT ===
Chest Pain

=== HISTORY OF PRESENT ILLNESS ===
a man with PMHx of HTN and chronic hepatitis C who presents with substernal chest pain. The patient reports that he was sitting in a recliner this AM drinking coffee, when he began to experience chest pain. The pain dissipated sponatneously after a few minutes, and then returned for approximately 30 minutes. The pain resolved immediately with nitro and aspirin administered by the EMTs. Reports tingling down both extremities and associated diaphoresis and shortness of breath. Denies nausea, vomiting, palpitations, and loss of conscioussness. The patien

### RAG Text Splitting - Chunking Component 

In [5]:
# text_splitter = RecursiveCharacterTextSplitter(
#     chunk_size=1500,  
#     chunk_overlap=300,  
#     length_function=len,
#     separators=[
#         "\n=== ", # Split on section headers first
#         "\n\n",   # Then paragraphs
#         "\n",     # Then lines
#         ". ",     # Then sentences
#         ", ",     # Then clauses
#         " ",      # Then words
#         ""
#     ]
# )

# # Filter out tiny chunks that don't contain meaningful content
# chunks = text_splitter.split_documents(documents)

# meaningful_chunks = [chunk for chunk in chunks if len(chunk.page_content.strip()) > 100]  

# # Print statistics about the chunks
# print(f"Original documents: {len(documents)}")
# print(f"Raw chunks after splitting: {len(chunks)}")
# print(f"Filtered chunks (removing tiny chunks): {len(meaningful_chunks)}")



In [6]:
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

semantic_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2",token="huggingface_token")

similarity_threshold = 0.7  # Tune this

def semantic_chunk(text: str, max_chunk_size=1500) -> List[str]:
    sentences = text.split(". ")
    embeddings = semantic_model.encode(sentences, convert_to_tensor=True)

    chunks = []
    current_chunk = []
    current_chunk_tokens = 0

    for i in range(len(sentences)):
        sentence = sentences[i]
        current_chunk.append(sentence)
        current_chunk_tokens += len(sentence)

        # Determine whether to split based on similarity to next sentence
        if i < len(sentences) - 1:
            sim = util.cos_sim(embeddings[i], embeddings[i + 1]).item()
            if sim < similarity_threshold or current_chunk_tokens > max_chunk_size:
                chunks.append(". ".join(current_chunk).strip() + ".")
                current_chunk = []
                current_chunk_tokens = 0

    if current_chunk:
        chunks.append(". ".join(current_chunk).strip() + ".")

    return chunks

chunks = []
for doc in documents:
    semantic_chunks = semantic_chunk(doc.page_content)
    for chunk_text in semantic_chunks:
        if len(chunk_text.strip()) > 100:  # Filter tiny chunks
            chunks.append(Document(page_content=chunk_text.strip(), metadata=doc.metadata))

### RAG Vector Embedding - Embedding Component 


In [7]:
# class ClinicalEmbeddings(Embeddings):
#     def embed_documents(self, texts: List[str]) -> List[List[float]]:
#         return [self.embed_query(text) for text in texts]

#     def embed_query(self, text: str) -> List[float]:
#         text = f"""Clinical Context:
#         Patient Information and History: {text}
        
#         Analysis Framework:
#         - Primary symptoms and presentations
#         - Relevant medical, surgical, and family history
#         - Physical examination findings
#         - Diagnostic considerations (labs, imaging, specialized tests)
#         - Differential diagnoses across medical specialties
#         - Treatment implications and follow-up plan"""
        
#         inputs = tokenizer(
#             text, 
#             return_tensors="pt", 
#             padding=True, 
#             truncation=True, 
#             max_length=768  
#         )
#         inputs = {key: val.to(device) for key, val in inputs.items()}
        
#         with torch.no_grad():
#             outputs = model(**inputs)
#             # Using mean pooling instead of just CLS token
#             embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()[0]
        
#         return embeddings.tolist()

# clinical_embeddings = ClinicalEmbeddings()

In [8]:
def get_embeddings(embedding_type: str = "google") -> Embeddings:
    if embedding_type == "google":
        return GoogleGenerativeAIEmbeddings(
            model="models/embedding-001",
            task_type="retrieval_document"
        )
    else:  # Default to clinical embeddings
        # Set up Clinical ModernBERT
        model = AutoModel.from_pretrained('Simonlee711/Clinical_ModernBERT')
        tokenizer = AutoTokenizer.from_pretrained('Simonlee711/Clinical_ModernBERT')
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()
        return ClinicalEmbeddings()

### RAG Storage & Indexing - Vector Store Component 


In [9]:
db_directory = "chroma_db"  # Directory to store Chroma database

import shutil
if os.path.exists(db_directory):
    print(f"Removing existing database directory: {db_directory}")
    shutil.rmtree(db_directory)

embeddings = get_embeddings("google")

vectorstore = Chroma.from_documents(
    documents=chunks,
    embedding=embeddings,
    persist_directory=db_directory
)

vectorstore.persist()

  vectorstore.persist()


### RAG Retriever - Retrieval Component 


In [10]:
retriever = vectorstore.as_retriever(
    search_type="mmr",  # Using MMR for better diversity in results
    search_kwargs={
        "k": 5,  # Number of documents to retrieve
        "fetch_k": 10,  # Fetch more documents initially for MMR to choose from
        "lambda_mult": 0.7,  # Balance between relevance and diversity
    }
)

### RAG Prompt Engineering - Prompt Template Component


In [11]:
template = """You are an expert medical AI assistant specialized in clinical decision support. Use the following retrieved medical documents to answer the question.

Here's an example of how to analyze a medical query:
Example Query: "I have chest pain and shortness of breath"
Reasoning Process:
1. Consider urgent vs non-urgent presentation
2. Analyze key symptoms and their patterns
3. Review risk factors and medical history
4. Evaluate differential diagnoses
5. Determine appropriate level of care

Example Response:
Key Findings: Acute chest pain with dyspnea suggests cardiopulmonary etiology
Clinical Interpretation: Given symptoms, must rule out acute coronary syndrome
Recommendations: Immediate emergency evaluation recommended
Chain of Thought: Chest pain + shortness of breath → possible cardiac/pulmonary cause → risk of ACS → requires urgent assessment

CONTEXT DOCUMENTS:
{context}

CLINICAL QUERY: {question}

Please provide a detailed medical response following this structure:
1. Key Findings: Summarize the most relevant information
2. Clinical Interpretation: Analyze the medical significance
3. Chain of Thought: Show your reasoning process step by step
4. Recommendations: Suggest evidence-based approaches
5. References: Cite specific sections from the provided context

If you cannot provide a complete answer based on the available context, explicitly state the limitations.

Response:"""

In [12]:
# document formatting with clinical section emphasis
def format_docs(docs):
    formatted_sections = []
    for i, doc in enumerate(docs, 1):
        # Extract key clinical sections
        sections = doc.page_content.split("===")
        formatted = f"\nSOURCE {i}:\n"
        formatted += f"Category: {doc.metadata['category']}\n"
        formatted += f"Document Type: {doc.metadata['document_type']}\n"
        formatted += "-" * 40 + "\n"
        formatted += "\n".join(section.strip() for section in sections if section.strip())
        formatted_sections.append(formatted)
    return "\n\n".join(formatted_sections)

### RAG Generator/LLM - Generation Component 

In [13]:
# Initialize the LLM
# Palmyra-Med-70B-32k
llm = ChatNVIDIA(model="writer/palmyra-med-70b-32k")

# Create the RAG prompt
prompt = ChatPromptTemplate.from_template(template)


### RAG Orchestration - RAG Pipeline Integration Component 


In [14]:
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

### RAG Query Handling - Query Processing Component 

In [15]:

def process_medical_query(query: str, min_sources: int = 5, confidence_threshold: float = 0.7) -> dict:
    try:
        # Get retrieved documents and their similarity scores (distances)
        retrieved_docs_with_scores = retriever.vectorstore.similarity_search_with_score(query, k=min_sources)
        
        # Check source adequacy
        if len(retrieved_docs_with_scores) < min_sources:
            return {
                "answer": "Insufficient clinical evidence found. Please refine the query.",
                "sources": [],
                "confidence": 0.0,
                "warnings": ["Insufficient source documents found"]
            }

        # Separate documents and distances (lower = more similar)
        retrieved_docs = [doc for doc, distance in retrieved_docs_with_scores]
        distances = [distance for doc, distance in retrieved_docs_with_scores]

        # Enhanced confidence scoring with dynamic thresholding
        similarity_scores = [1 / (1 + d) for d in distances]
        avg_similarity = sum(similarity_scores) / len(similarity_scores)
        std_similarity = (sum((s - avg_similarity) ** 2 for s in similarity_scores) / len(similarity_scores)) ** 0.5
        
        # Adjust confidence threshold based on query complexity
        adjusted_threshold = confidence_threshold - (0.1 if len(query.split()) > 10 else 0)
        
        # Calculate final confidence score
        confidence = avg_similarity * (1 - std_similarity)  # Penalize high variance
        
        # Generate response
        response = rag_chain.invoke(query)

        # Enhanced warning system
        warnings = []
        if confidence < adjusted_threshold:
            warnings.append("Low confidence response - please verify with healthcare provider")
        if std_similarity > 0.3:  # High variance in relevance
            warnings.append("Mixed relevance in sources - interpretation may be limited")

        return {
            "answer": response,
            "sources": [doc.metadata.get("source", "unknown") for doc in retrieved_docs],
            "confidence": round(confidence, 2),
            "reasoning_confidence": round(avg_similarity, 2),
            "source_consistency": round(1 - std_similarity, 2),
            "warnings": warnings
        }

    except Exception as e:
        return {
            "answer": f"Error processing clinical query: {str(e)}",
            "sources": [],
            "confidence": 0.0,
            "warnings": ["Processing error occurred"]
        }


### Testing

In [16]:
test_queries = [
    "I have a headache and nausea from past week. What could be the cause?",
    "I’ve had lower back pain that radiates down my leg. Is this a sign of a nerve issue?",
    "I feel dizzy when I stand up and sometimes my vision gets blurry. Should I be concerned?",
    "My child has a sore throat and high fever. Could it be strep or something else?",
    "I’ve noticed blood in my stool. Should I be worried about colon cancer?",
]

print("\nRunning clinical test queries:")
for query in test_queries:
    print(f"\nCLINICAL QUERY: {query}")
    result = process_medical_query(query)
    print("\nRESPONSE:")
    print(result["answer"])
    print("\nSOURCES USED:", len(result["sources"]))
    print("CONFIDENCE SCORE:", f"{result['confidence']:.2f}")
    if result["warnings"]:
        print("WARNINGS:", ", ".join(result["warnings"]))
    print("-" * 80)


Running clinical test queries:

CLINICAL QUERY: I have a headache and nausea from past week. What could be the cause?

RESPONSE:
Query Analysis:

Key Findings: Headache and nausea from past week. No mention of fever or trauma to the head.
Clinical Interpretation: Migraine or tension headaches could be the cause of persistent headaches. Presence of nausea may suggest involvement of the gastrointestinal system. Hormonal imbalance and zygomatic sinusitis can be less likely causes.
Chain of Thought: 1. Headache + nausea → Rule out causes of tension headache, migraine, and sinusitis.
2. Is associated with trauma or fever? → If yes, sinus infection or subdural hematoma should be considered.
3. Duration of symptoms ≥ 1 week → Could be a migraine or tension headache if not improving with self-care.
4. Check previous medical history for hormonal imbalances, thyroid disorders, and migraines which might require specific management.
Recommendations: 1. Initial management of tension headaches incl

In [17]:
import sys
sys.path.append(os.path.abspath('.')) 

### Evaluation

In [18]:
# === LLM Setup ===
gemini_llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001", temperature=0)

def judge_relevance(note: str, answer: str) -> int:
    """
    Ask Gemini to rate how well the RAG-generated answer matches the content of the note.
    """
    prompt = f"""
You are a clinical QA evaluator.

Clinical Note:
{note}

RAG-Generated Answer:
{answer}

Question: Based only on the information in the clinical note, how relevant and accurate is this answer?
Rate from 1 (irrelevant/incorrect) to 5 (very accurate and relevant).
Output only the integer.
"""
    try:
        resp = gemini_llm.invoke(prompt)
        return int(resp.content.strip())
    except Exception as e:
        print(f"[Gemini Parse Error] {e}")
        return 0


BASE_PATH = "mimic-iv-ext-direct-1.0.0/mimic-iv-ext-direct-1.0.0/samples/Finished"
MAX_DOCS = 100

print("Loading documents...")
json_files = find_all_json_files(BASE_PATH)
documents = create_documents_from_json(json_files)

if MAX_DOCS and MAX_DOCS > 0:
    documents = documents[:MAX_DOCS]

print(f"Loaded {len(documents)} clinical notes.\n")

results: List[Dict[str, Any]] = []

print("Evaluating...\n")
for doc in tqdm(documents, desc="Evaluating Notes"):
    note_text = doc.page_content
    filename = doc.metadata.get("filename", "unknown.json")

    try:
        rag_out = process_medical_query(note_text, min_sources=5)
        answer = rag_out.get("answer", "")
        confidence = rag_out.get("confidence", None)
        relevance_score = judge_relevance(note_text, answer)
    except Exception as e:
        print(f"[Error] Failed on {filename}: {e}")
        continue

    results.append({
        "file": filename,
        "category": doc.metadata.get("category"),
        "subcategory": doc.metadata.get("subcategory"),
        "relevance_score": relevance_score,
        "confidence": confidence,
        "answer": answer,
    })

if results:
    avg_relevance = np.mean([r["relevance_score"] for r in results])
    confidences = [r["confidence"] for r in results if r["confidence"] is not None]
    avg_confidence = np.mean(confidences) if confidences else None

    print(f"\n=== Evaluation Summary ({len(results)} samples) ===")
    print(f"Mean Relevance Score (1–5): {avg_relevance:.2f}")
    if avg_confidence is not None:
        print(f"Mean Model Confidence     : {avg_confidence:.2f}")
else:
    print("\nNo successful evaluations.")


Loading documents...
Loaded 100 clinical notes.

Evaluating...



Evaluating Notes: 100%|██████████| 100/100 [16:03<00:00,  9.64s/it]


=== Evaluation Summary (100 samples) ===
Mean Relevance Score (1–5): 3.41
Mean Model Confidence     : 0.81



