# Healthcare RAG System Lab
## Overview

In this lab, you'll take on the role of a junior data scientist at a healthcare technology company that specializes in creating educational resources for patients. Your team has been tasked with developing a system that can automatically generate informative responses to common patient questions about medical conditions, treatments, and wellness practices.

The challenge is to ensure these responses are both accurate and grounded in authoritative medical information. Your specific assignment is to implement a Retrieval-Augmented Generation (RAG) system that can:
1. Understand patient questions about various health topics
2. Retrieve relevant information from a trusted knowledge base
3. Generate helpful, accurate responses based on that information
4. Avoid "hallucinated" content that could potentially misinform patients

This lab follows the generative AI implementation process we've studied, with particular focus on:
- Data Strategy and Knowledge Foundation
- Model Selection and Generation Control
- Evaluation Framework Development

## Setup

First, let's import the necessary libraries:

In [1]:
import torch
import pandas as pd
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cpu


## Part 1: Knowledge Base Setup

Let's create a sample medical knowledge base with information about common health conditions, treatments, and wellness practices:

In [None]:
# Create a sample medical knowledge base
knowledge_base = pd.DataFrame({
    'content': [
        "Diabetes is a chronic condition that affects how your body turns food into energy. There are three main types: Type 1, Type 2, and gestational diabetes. Type 2 diabetes is the most common form, accounting for about 90-95% of diabetes cases.",
        "Type 1 diabetes is an autoimmune reaction that stops your body from making insulin. Symptoms include increased thirst, frequent urination, hunger, fatigue, and blurred vision. It's usually diagnosed in children, teens, and young adults.",
        "Type 2 diabetes occurs when your body becomes resistant to insulin or doesn't make enough insulin. Risk factors include being overweight, being 45 years or older, having a parent or sibling with type 2 diabetes, and being physically active less than 3 times a week.",
        "Managing diabetes involves monitoring blood sugar levels, taking medications as prescribed, eating a healthy diet, maintaining a healthy weight, and getting regular physical activity. It's important to work with healthcare providers to develop a management plan.",
        "Hypertension, or high blood pressure, is when the force of blood pushing against the walls of your arteries is consistently too high. It's often called the 'silent killer' because it typically has no symptoms but significantly increases the risk of heart disease and stroke.",
        "Blood pressure is measured using two numbers: systolic (top number) and diastolic (bottom number). Normal blood pressure is less than 120/80 mm Hg. Hypertension is diagnosed when readings are consistently 130/80 mm Hg or higher.",
        "Lifestyle changes to manage hypertension include reducing sodium in your diet, getting regular physical activity, maintaining a healthy weight, limiting alcohol, quitting smoking, and managing stress. Medications may also be prescribed if lifestyle changes aren't enough.",
        "Regular physical activity offers numerous health benefits, including weight management, reduced risk of heart disease, strengthened bones and muscles, improved mental health, and enhanced ability to perform daily activities. Adults should aim for at least 150 minutes of moderate-intensity activity per week.",
        "A balanced diet should include a variety of fruits, vegetables, whole grains, lean proteins, and healthy fats. It's recommended to limit intake of added sugars, sodium, saturated fats, and processed foods. Proper nutrition helps prevent chronic diseases and supports overall health.",
        "Vaccination is one of the most effective ways to prevent infectious diseases. Vaccines work by helping the body recognize and fight specific pathogens. Common adult vaccines include influenza (flu), Tdap (tetanus, diphtheria, pertussis), shingles, and pneumococcal vaccines."
    ],
    'metadata': [
        {'topic': 'diabetes', 'subtopic': 'overview', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'diabetes', 'subtopic': 'type1', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'diabetes', 'subtopic': 'type2', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'diabetes', 'subtopic': 'management', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'hypertension', 'subtopic': 'overview', 'source': 'medical_guidelines', 'last_updated': '2023-07-22'},
        {'topic': 'hypertension', 'subtopic': 'diagnosis', 'source': 'medical_guidelines', 'last_updated': '2023-07-22'},
        {'topic': 'hypertension', 'subtopic': 'management', 'source': 'medical_guidelines', 'last_updated': '2023-07-22'},
        {'topic': 'wellness', 'subtopic': 'physical_activity', 'source': 'health_promotion', 'last_updated': '2023-05-15'},
        {'topic': 'wellness', 'subtopic': 'nutrition', 'source': 'health_promotion', 'last_updated': '2023-05-15'},
        {'topic': 'prevention', 'subtopic': 'vaccination', 'source': 'medical_guidelines', 'last_updated': '2023-08-05'}
    ]
})

print(f"Knowledge base loaded with {len(knowledge_base)} entries")
knowledge_base.head(2)
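Each entry carries metadata alongside its text. As a sketch of how that metadata might be used (an assumption, not one of the lab's required tasks), the hypothetical helper `filter_by_metadata` below keeps only entries on a given topic that were updated on or after a cutoff date. This could narrow retrieval or flag stale guidance. It works on plain lists, so it runs without pandas.

```python
from datetime import date

def filter_by_metadata(contents, metadata, topic=None, updated_since=None):
    """Hypothetical helper: return (index, content) pairs whose metadata matches.

    topic: keep only entries with this 'topic' value (None means any topic)
    updated_since: a datetime.date; keep entries whose 'last_updated' is on or after it
    """
    selected = []
    for idx, (content, meta) in enumerate(zip(contents, metadata)):
        if topic is not None and meta.get("topic") != topic:
            continue
        if updated_since is not None:
            # 'last_updated' is stored as an ISO-format string such as '2023-06-10'
            if date.fromisoformat(meta["last_updated"]) < updated_since:
                continue
        selected.append((idx, content))
    return selected


# Toy data mirroring the knowledge base's metadata layout
contents = ["Diabetes overview", "Hypertension overview", "Vaccination overview"]
metadata = [
    {"topic": "diabetes", "last_updated": "2023-06-10"},
    {"topic": "hypertension", "last_updated": "2023-07-22"},
    {"topic": "prevention", "last_updated": "2023-08-05"},
]
print(filter_by_metadata(contents, metadata, updated_since=date(2023, 7, 1)))
# → [(1, 'Hypertension overview'), (2, 'Vaccination overview')]
```

With the real DataFrame, `knowledge_base['content'].tolist()` and `knowledge_base['metadata'].tolist()` could be passed in directly.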

### Task 1: Create Document Embeddings

Complete the function below to create embeddings for each document in the knowledge base. These embeddings will be used to find relevant documents based on patient queries.

In [None]:
def create_document_embeddings(documents):
    """
    Create embeddings for a list of documents.
    
    Args:
        documents: List of text documents to embed
        
    Returns:
        Numpy array of document embeddings
    """
    # Initialize the sentence transformer model
    embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    # Generate embeddings for all documents (encode returns a NumPy array of shape (n_docs, dim))
    document_embeddings = embedding_model.encode(documents, convert_to_numpy=True)
    return document_embeddings

# Extract document content
documents = knowledge_base['content'].tolist()

# Create document embeddings
document_embeddings = create_document_embeddings(documents)

# Verify the shape of embeddings
if document_embeddings is not None:
    print(f"Generated embeddings with shape: {document_embeddings.shape}")
else:
    print("Embeddings not created yet.")
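Retrieval ranks documents by cosine similarity between embeddings. The illustrative sketch below uses plain Python and toy 3-dimensional vectors, which stand in for the 768-dimensional vectors that all-mpnet-base-v2 produces. Cosine similarity is the dot product divided by the product of the vector lengths. Once each vector is L2-normalized, it reduces to a plain dot product. `SentenceTransformer.encode` accepts `normalize_embeddings=True` to return unit-length vectors, which would allow that shortcut.

```python
import math

def cosine(a, b):
    # cos(theta) = (a . b) / (|a| * |b|)
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def l2_normalize(v):
    # Scale the vector to unit length
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

query_vec = [1.0, 2.0, 0.0]
doc_vec = [2.0, 1.0, 2.0]

direct = cosine(query_vec, doc_vec)
# For unit-length vectors, the dot product alone equals the cosine similarity
via_dot = sum(x * y for x, y in zip(l2_normalize(query_vec), l2_normalize(doc_vec)))
print(round(direct, 4), round(via_dot, 4))  # → 0.5963 0.5963
```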

## Part 2: Implementing the Retrieval Component

Now, let's implement the function to retrieve relevant documents based on a patient query.

In [None]:
def retrieve_documents(query, embeddings, contents, metadata, top_k=3, threshold=0.3):
    """
    Retrieve the most relevant documents for a given query.
    
    Args:
        query: The patient's question
        embeddings: The precomputed document embeddings
        contents: The text content of the documents
        metadata: The metadata for each document
        top_k: Maximum number of documents to retrieve
        threshold: Minimum similarity score to include a document
        
    Returns:
        List of (content, metadata, similarity_score) tuples
    """
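    # Initialize the embedding model (must be the same model used in create_document_embeddings).
    # Reloading it on every call is slow; for repeated queries, load it once and reuse it.
    embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

    # Embed the query (a single string yields a 1-D NumPy vector)
    query_embedding = embedding_model.encode(query, convert_to_numpy=True)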

    # Calculate similarity scores between query and all documents
    similarities = cosine_similarity([query_embedding], embeddings)[0]

    # Filter by threshold and get top k results
    sorted_indices = np.argsort(similarities)[::-1]  # Sort by highest similarity
    top_indices = [idx for idx in sorted_indices if similarities[idx] >= threshold][:top_k]

    # Return the top documents with their metadata and scores
    results = [(contents[idx], metadata[idx], similarities[idx]) for idx in top_indices]

    return results

# Test the retrieval function with a sample query
if document_embeddings is not None:
    sample_query = "What are the symptoms of Type 1 diabetes?"
    retrieved_docs = retrieve_documents(
        query=sample_query,
        embeddings=document_embeddings,
        contents=documents,
        metadata=knowledge_base['metadata'].tolist(),
        top_k=2
    )
    
    print(f"Query: {sample_query}")
    print("\nRetrieved Documents:")
    for i, (content, meta, score) in enumerate(retrieved_docs):
        print(f"{i+1}. [{score:.4f}] {content[:100]}...")
        print(f"   Topic: {meta['topic']}, Subtopic: {meta['subtopic']}")
else:
    print("Cannot test retrieval without document embeddings.")
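`retrieve_documents` sorts every score with `np.argsort` and then applies the threshold and the `top_k` cutoff. The sketch below reproduces that selection rule in plain Python on a made-up list of similarity scores. It uses `heapq.nlargest`, which avoids a full sort when the collection is large. For ten documents, either approach is fine. The name `select_top_k` is hypothetical.

```python
import heapq

def select_top_k(similarities, top_k=3, threshold=0.3):
    """Return indices of the top_k highest scores that are >= threshold, best first."""
    # Keep only candidates that clear the similarity threshold
    candidates = [(score, idx) for idx, score in enumerate(similarities) if score >= threshold]
    # nlargest returns the top_k (score, idx) pairs in descending order of score
    return [idx for score, idx in heapq.nlargest(top_k, candidates)]

scores = [0.12, 0.71, 0.45, 0.29, 0.66]
print(select_top_k(scores, top_k=2))        # → [1, 4]
print(select_top_k(scores))                 # → [1, 4, 2]
print(select_top_k(scores, threshold=0.9))  # → [] (nothing is relevant enough)
```

An empty result is a useful signal in its own right: it tells the caller that the knowledge base has nothing close to the question.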

## Part 3: Building the Generation Component

Now, let's implement the generation component that will use the retrieved documents to create informative responses.

In [None]:
# Initialize the generative model
def initialize_generator(model_name="gpt2"):
    """
    Initialize the generative model and tokenizer.
    
    Args:
        model_name: Name of the pretrained model to use
        
    Returns:
        Tokenizer and model objects
    """
    # Load the tokenizer and model, placing the model on the device selected in Setup
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    
    # GPT-2 has no padding token, so reuse the end-of-sequence token if none is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    return tokenizer, model

# Initialize the generator
tokenizer, model = initialize_generator()
if tokenizer and model:
    print(f"Initialized {model.config._name_or_path} with {model.num_parameters()} parameters")
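GPT-2 can attend to at most 1,024 positions (exposed as `model.config.n_positions`). Prompt tokens count against that window just as newly generated tokens do. In `generate`, `max_length` bounds the prompt plus the continuation, while `max_new_tokens` bounds only the continuation. The hypothetical helper below sketches the arithmetic for choosing a safe generation budget. It takes plain integers, so it runs without loading a model.

```python
def generation_budget(prompt_tokens, requested_new_tokens, context_window=1024):
    """Hypothetical helper: how many new tokens can be generated without
    exceeding the model's context window (1024 for GPT-2)."""
    remaining = context_window - prompt_tokens
    if remaining <= 0:
        # The prompt alone fills the window; the context must be shortened first
        return 0
    return min(requested_new_tokens, remaining)

print(generation_budget(200, 100))   # → 100 (plenty of room)
print(generation_budget(1000, 100))  # → 24 (a long prompt caps the continuation)
print(generation_budget(1100, 100))  # → 0 (prompt must be truncated)
```

Inside `generate_rag_response`, `prompt_tokens` would correspond to `inputs["input_ids"].shape[1]`.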

In [None]:
def generate_rag_response(query, contents, metadata, document_embeddings,
                          tokenizer, model, max_new_tokens=100):
    """
    Generate a response using Retrieval-Augmented Generation.
    
    Args:
        query: The patient's question
        contents: List of document contents
        metadata: List of document metadata
        document_embeddings: Precomputed embeddings for the documents
        tokenizer: The tokenizer for the language model
        model: The language model for generation
        max_new_tokens: Maximum number of tokens to generate after the prompt
        
    Returns:
        Dictionary with the query, the generated response, and the retrieved documents
    """
    # Retrieve relevant documents for the query
    retrieved_docs = retrieve_documents(query, document_embeddings, contents,
                                        metadata, top_k=3, threshold=0.3)

    # Format prompt with retrieved context
    if retrieved_docs:
        # Extract the document text from each (content, metadata, score) tuple
        context = "\n".join(doc[0] for doc in retrieved_docs)
        prompt = f"Patient Question: {query}\n\nRelevant Medical Information:\n{context}\n\nHelpful Response:"
    else:
        prompt = f"Patient Question: {query}\n\nHelpful Response:"
    
    # Tokenize the prompt and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    
    # Generate the response. max_new_tokens bounds only the generated continuation;
    # max_length would also count the prompt, which alone can exceed 100 tokens here.
    output_sequences = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
    )
    
    # Decode only the newly generated tokens, skipping the echoed prompt
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(output_sequences[0][prompt_length:], skip_special_tokens=True).strip()

    # Return the results
    return {
        "query": query,
        "response": response,
        "retrieved_documents": retrieved_docs
    }


# Test the RAG system with several queries
if document_embeddings is not None and tokenizer and model:
    test_queries = [
        "What are the different types of diabetes?",
        "How can I manage my high blood pressure through lifestyle changes?",
        "Why is regular physical activity important for health?",
        "What vaccines should adults consider getting?"
    ]
    
    for query in test_queries:
        print(f"\nQuery: {query}")
        result = generate_rag_response(
            query=query,
            contents=documents,
            metadata=knowledge_base['metadata'].tolist(),
            document_embeddings=document_embeddings,
            tokenizer=tokenizer,
            model=model
        )
        
        print("\nRetrieved Documents:")
        for i, (doc, meta, score) in enumerate(result["retrieved_documents"]):
            print(f"{i+1}. [{score:.4f}] Topic: {meta['topic']}, Subtopic: {meta['subtopic']}")
        
        print(f"\nGenerated Response:\n{result['response']}")
        print("-" * 80)
else:
    print("Cannot test generation without embeddings or model.")
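The prompt used in `generate_rag_response` simply concatenates the retrieved passages. One common refinement, sketched here as an assumption rather than as the lab's required design, has three parts. It numbers each source, tells the model to answer only from those sources, and falls back to an explicit "cannot answer" instruction when retrieval returns nothing. The function `build_grounded_prompt` is hypothetical. Plain GPT-2 is not instruction-tuned, so it may not follow such instructions reliably; an instruction-tuned model would be more likely to.

```python
def build_grounded_prompt(query, retrieved_docs):
    """Hypothetical prompt builder.

    retrieved_docs: list of (content, metadata, score) tuples, as returned by retrieve_documents
    """
    if not retrieved_docs:
        # No trusted context: ask the model to say so instead of guessing
        return (
            f"Patient Question: {query}\n\n"
            "No relevant medical information was found. "
            "Reply that you cannot answer and suggest consulting a healthcare provider.\n\n"
            "Helpful Response:"
        )
    # Number each source and record its topic, so answers can be traced back
    sources = "\n".join(
        f"[{i}] ({meta['topic']}/{meta['subtopic']}) {content}"
        for i, (content, meta, _score) in enumerate(retrieved_docs, start=1)
    )
    return (
        f"Patient Question: {query}\n\n"
        f"Sources:\n{sources}\n\n"
        "Answer using only the sources above and cite them by number.\n\n"
        "Helpful Response:"
    )

docs = [("Type 1 diabetes is an autoimmune reaction that stops your body from making insulin.",
         {"topic": "diabetes", "subtopic": "type1"}, 0.72)]
print(build_grounded_prompt("What causes Type 1 diabetes?", docs))
```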

## Part 4: Evaluation and Analysis

Let's implement a basic evaluation function to assess the quality of our generated responses.

In [None]:
import re
import string

def evaluate_response(response_data):
    """
    Evaluate the quality of a generated response using simple keyword-based heuristics.
    
    Args:
        response_data: Dictionary containing the query, response, and retrieved docs
        
    Returns:
        Dictionary of evaluation metrics
    """
    query = response_data["query"]
    response = response_data["response"]
    retrieved_docs = response_data["retrieved_documents"]

    medical_terms = {
        "diabetes", "insulin", "glucose", "hypertension", "blood pressure",
        "systolic", "diastolic", "cardiovascular", "cholesterol", "nutrition",
        "obesity", "physical activity", "vaccination", "immune", "prevention"
    }

    def tokenize(text):
        # Whitespace split, lowercased, with surrounding punctuation stripped so that
        # "diabetes?" matches "diabetes". A real tokenizer (e.g. NLTK) would be more robust.
        words = (word.strip(string.punctuation).lower() for word in text.split())
        return [word for word in words if word]

    response_words = tokenize(response)
    query_words = tokenize(query)

    # 1. Content Relevance: fraction of response words that also appear in the retrieved docs
    # (common words such as "the" inflate this score; it is only a rough signal)
    retrieved_vocab = set(tokenize(" ".join(doc[0] for doc in retrieved_docs)))
    if response_words:
        relevance_score = sum(1 for word in response_words if word in retrieved_vocab) / len(response_words)
    else:
        relevance_score = 0

    # 2. Response Length Appropriateness: compare response length to query length
    # Ratio > 1 indicates the response is longer than the query, < 1 indicates shorter
    length_ratio = len(response_words) / len(query_words) if query_words else 0

    # 3. Medical Terminology Usage: count whole-word matches, including multi-word terms
    # like "blood pressure" that a word-by-word comparison could never find
    response_lower = response.lower()
    medical_term_count = sum(len(re.findall(rf"\b{re.escape(term)}\b", response_lower))
                             for term in medical_terms)
    medical_term_usage = medical_term_count / len(response_words) if response_words else 0

    metrics = {
        "content_relevance": relevance_score,  # Higher is better
        "length_appropriateness": length_ratio,  # Ideally balanced (not too short or too long)
        "medical_terminology_usage": medical_term_usage,  # Should be a reasonable fraction
        # Caveat: keyword metrics can be gamed; "diabetes diabetes diabetes" scores highly
        # on terminology usage without telling the patient anything useful
    }

    return metrics

# Evaluate the responses for our test queries
if 'test_queries' in locals() and document_embeddings is not None and tokenizer and model:
    for query in test_queries:
        result = generate_rag_response(
            query=query,
            contents=documents,
            metadata=knowledge_base['metadata'].tolist(),
            document_embeddings=document_embeddings,
            tokenizer=tokenizer,
            model=model
        )
        
        metrics = evaluate_response(result)
        print(f"Query: {query}")
        print(f"Evaluation Metrics: {metrics}")
        print("-" * 80)
else:
    print("Cannot evaluate without test queries or necessary components.")
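The metrics above measure keyword overlap. They do not check whether each claim is supported by the retrieved text, which is the property that matters most for avoiding hallucinated content. The rough sketch below uses a simple heuristic, not a standard metric. `unsupported_sentences` flags response sentences whose content words mostly do not appear in the retrieved documents. The stop-word list and the 0.5 cutoff are arbitrary assumptions. A production system would more likely use an entailment (NLI) model or human review.

```python
import re

# Small illustrative stop-word list (an assumption; real lists are much longer)
STOP_WORDS = {"a", "an", "the", "and", "or", "of", "to", "in", "is", "are",
              "it", "for", "your", "you", "can"}

def content_words(text):
    # Lowercase alphabetic tokens with stop words removed
    return [w for w in re.findall(r"[a-z]+", text.lower()) if w not in STOP_WORDS]

def unsupported_sentences(response, retrieved_texts, min_support=0.5):
    """Return sentences where fewer than min_support of the content words
    appear anywhere in the retrieved texts."""
    context_vocab = set(content_words(" ".join(retrieved_texts)))
    flagged = []
    # Naive sentence split on ., ! or ? followed by whitespace
    for sentence in re.split(r"(?<=[.!?])\s+", response.strip()):
        words = content_words(sentence)
        if not words:
            continue
        support = sum(w in context_vocab for w in words) / len(words)
        if support < min_support:
            flagged.append(sentence)
    return flagged

context = ["Type 1 diabetes is an autoimmune reaction that stops your body from making insulin."]
answer = "Type 1 diabetes stops the body from making insulin. It can be cured with a strict diet."
print(unsupported_sentences(answer, context))  # → ['It can be cured with a strict diet.']
```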

## Reflection Questions

Answer the following questions about your RAG implementation and its potential applications in healthcare:

### How does the RAG approach improve factual accuracy compared to regular generation?

RAG mitigates the "overconfident guesser" problem that affects general-purpose models such as ChatGPT. When a typical LLM answers medical questions, it often gets things right, but it sometimes makes bizarre mistakes that even someone with basic medical knowledge would question, and it sounds 100% confident while doing so. A regular generation model may even claim "oh yes, I have your medical records right here" when it has not been given any information about the patient, because it produces whatever continuation seems most plausible for the prompt rather than reporting what it actually knows.

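With Retrieval-Augmented Generation, we substantially narrow the problem: instead of recalling facts from its parameters, the model is asked to answer from verified source passages supplied in the prompt. This grounds responses in authoritative text, although it does not strictly guarantee that the model stays faithful to that text.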
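One practical consequence of grounding is that the system can decline to answer when retrieval finds nothing trustworthy, instead of letting the model guess. The minimal sketch below assumes the `(content, metadata, score)` tuples that `retrieve_documents` returns and an arbitrary 0.5 confidence cutoff. `answer_or_abstain`, its fallback message, and `fake_generate` are all hypothetical, with `generate` standing in for whatever generation call is used.

```python
def answer_or_abstain(query, retrieved_docs, generate, min_score=0.5):
    """Hypothetical wrapper: only call the generator when retrieval is confident.

    retrieved_docs: list of (content, metadata, score) tuples, best first
    generate: any callable taking (query, retrieved_docs) and returning text
    """
    best_score = max((score for _, _, score in retrieved_docs), default=0.0)
    if best_score < min_score:
        # Nothing in the verified knowledge base matches well enough: do not guess
        return ("I don't have reliable information on that topic. "
                "Please consult a healthcare provider.")
    return generate(query, retrieved_docs)

def fake_generate(query, docs):
    # Stand-in generator for illustration: echoes the best-matching source
    return docs[0][0]

docs = [("Adults should aim for at least 150 minutes of moderate-intensity activity per week.",
         {"topic": "wellness"}, 0.81)]
print(answer_or_abstain("How much exercise do I need?", docs, fake_generate))
print(answer_or_abstain("Is my rash contagious?", [], fake_generate))
```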

### What are potential challenges or limitations of your current implementation?

I have no results on which to base an answer, because the notebook did not run to completion and so produced no responses to analyze.

### How might you enhance this system for a production healthcare environment?

First, several important areas should not be treated as "enhancements" at all; they should be design requirements from the start of the project. Diversity of the training data, regulatory compliance, and privacy (including HIPAA) should shape the design process rather than be added near the end as an afterthought.

We can make it:
 * Better - integrate with other healthcare systems to improve automation
 * Stronger - increase system resilience through careful coding and testing
 * Faster - deploy on scalable cloud infrastructure, for example using Kubernetes for container orchestration
 
We have the technology.
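Privacy by design can start with small habits, such as never writing raw patient text to logs. The rough sketch below assumes regex redaction is acceptable for illustration. The hypothetical `redact` function masks email addresses and US-style phone numbers. Patterns like these catch only a narrow set of identifiers and would not by themselves make a system HIPAA compliant.

```python
import re

# Illustrative patterns only; real de-identification needs far broader coverage
PATTERNS = {
    "EMAIL": re.compile(r"[\w.+-]+@[\w-]+(?:\.[\w-]+)+"),
    "PHONE": re.compile(r"\b\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\b"),
}

def redact(text):
    # Replace each match with a placeholder naming the identifier type
    for label, pattern in PATTERNS.items():
        text = pattern.sub(f"[{label}]", text)
    return text

print(redact("Call me at 555-123-4567 or email jane.doe@example.com about my insulin dose."))
# → Call me at [PHONE] or email [EMAIL] about my insulin dose.
```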


### What ethical considerations are particularly important for healthcare content generation?

An incorrect model prediction could, in the worst-case scenario, cause the death of a patient. Because of this, every healthcare decision and recommendation must be approached seriously and grounded in credible scientific information that supports the best patient outcomes.

Healthcare is an intensely personal field where privacy is essential. Imagine overhearing your doctor gossiping to another patient about **your** deeply personal and private medical condition. You would likely lose trust in your doctor, which might even lead you to skip essential care. Patients using model-generated content need confidence that their communications are confidential.