## Medical Question Answering with the Retrieval Augmented Generation Design Pattern
Use the Python 3 (Data Science 3.0) kernel image and an `ml.m5.2xlarge` instance for this notebook.

The pattern consists of generating embeddings for all existing documents and indexing them in a vector store. Then, for every user query, we generate an embedding of the query and search the store by embedding distance. The search results act as context for the LLM to generate an output.
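
As a minimal sketch of this flow, the example below uses plain Python with no external services. The word-overlap score is only a toy stand-in for embedding distance, and `tokens`, `retrieve`, `fake_llm`, and the sample chunks are hypothetical names and data chosen for illustration; they are not part of the notebook.

```python
import re

def tokens(text):
    """Lowercase word set; a toy stand-in for a real embedding."""
    return set(re.findall(r"[a-z]+", text.lower()))

def retrieve(query, chunks, k=1):
    """Rank chunks by word overlap with the query (stand-in for vector search)."""
    q = tokens(query)
    ranked = sorted(chunks, key=lambda c: len(q & tokens(c)), reverse=True)
    return ranked[:k]

def fake_llm(prompt):
    """Hypothetical LLM stub that only reports the prompt size."""
    return f"(answer generated from a {len(prompt)}-character prompt)"

chunks = [
    "Acute kidney injury manifests as decreased GFR and elevated serum creatinine.",
    "Heart failure can result from systolic or diastolic dysfunction.",
]
query = "What causes heart failure?"

# 1. Retrieve the most relevant chunk, 2. place it in the prompt, 3. generate.
context = retrieve(query, chunks, k=1)
prompt = "Context:\n" + "\n\n".join(context) + f"\n\nQuestion: {query}\nAnswer:"
answer = fake_llm(prompt)
print(prompt)
print(answer)
```

In the notebook itself, the retrieval step is handled by the embedding endpoint plus FAISS, and the generation step by the Mistral endpoint.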

Challenges:
* How to manage large document(s) that exceed the token limit
* How to find the document(s) relevant to the question being asked

## Key components

LLM (Large Language Model): Mistral-7b-instruct, available through Amazon SageMaker. This model is used to understand the document chunks and provide an answer in a human-friendly manner.

Embeddings Model: BGE Small available through Amazon SageMaker. This model will be used to generate a numerical representation of the textual documents.

Vector Store: FAISS, available through LangChain. In this notebook we use this in-memory vector store to hold both the embeddings and the documents. In an enterprise context it could be replaced with a persistent store such as Amazon OpenSearch Service, RDS Postgres with pgvector, ChromaDB, Pinecone or Weaviate.

Index: VectorIndex. The index compares the input embedding with the document embeddings to find the relevant documents.
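
To make the role of the index concrete, here is a minimal brute-force sketch in plain Python. It assumes squared Euclidean distance as the metric (lower means closer); the two-dimensional vectors and the `search` helper are hypothetical and only illustrate the idea, not the FAISS implementation.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def search(index, query_vec, k=2):
    """Return the k (distance, doc_id) pairs closest to query_vec."""
    scored = [(squared_l2(vec, query_vec), doc_id) for doc_id, vec in index.items()]
    return sorted(scored)[:k]

# Hypothetical document embeddings keyed by document id.
index = {
    "doc_a": [0.0, 1.0],
    "doc_b": [1.0, 0.0],
    "doc_c": [0.9, 0.1],
}

results = search(index, [1.0, 0.0], k=2)
for distance, doc_id in results:
    print(doc_id, distance)
```

A real vector index performs the same comparison, but over thousands of high-dimensional vectors and with optimized data structures.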

### Dataset
To explain this architecture pattern we are using the documents from MedQA. These documents include medical textbooks such as Pathology, Anatomy, Pharmacology and others.

Download the textbooks that are part of the MedQA Q&A dataset, released with Jin, Di, et al. "What Disease does this Patient Have? A Large-scale Open Domain Question Answering Dataset from Medical Exams." arXiv preprint arXiv:2009.13081 (2020).

More details are available here https://github.com/jind11/MedQA

* Data source : @article{jin2020disease,
  title={What Disease does this Patient Have? A Large-scale Open Domain Question Answering Dataset from Medical Exams},
  author={Jin, Di and Pan, Eileen and Oufattole, Nassim and Weng, Wei-Hung and Fang, Hanyi and Szolovits, Peter},
  journal={arXiv preprint arXiv:2009.13081},
  year={2020} }
  
  

### Data preparation

> **NOTICE**: "This link leads to a Third-Party Dataset. AWS does not own, nor does it have any control over the Third-Party Dataset. You should perform your own independent assessment, and take measures to ensure that you comply with your own specific quality control practices and standards, and the local rules, laws, regulations, licenses and terms of use that apply to you, your content, and the Third-Party Dataset. AWS does not make any representations or warranties that the Third-Party Dataset is secure, virus-free, accurate, operational, or compatible with your own environment and standards. AWS does not make any representations, warranties or guarantees that any information in the Third-Party Dataset will result in a particular outcome or result."

1. The full dataset can be downloaded from https://drive.google.com/file/d/1ImYUSLk9JbgHXOemfvyiDiirluZHPeQw/view?usp=sharing. You can read more about the dataset at https://github.com/jind11/MedQA#data.
2. To speed up this lab, a smaller version of the dataset has already been prepared at https://d2qrbbbqnxtln.cloudfront.net/Pathology_Robbins.txt

##### Prerequisites

In [1]:
%pip install faiss-cpu==1.7.4 --quiet

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jupyter-ai 2.29.1 requires faiss-cpu!=1.8.0.post0,<2.0.0,>=1.8.0, but you have faiss-cpu 1.7.4 which is incompatible.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# %pip install langchain==0.0.222 --quiet

In [2]:
%%capture 

!pip install PyYAML --quiet

In [3]:
!pip install --upgrade pydantic langchain --quiet


#### Imports

In [4]:
import boto3
import json
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import TextLoader
from langchain.chains import RetrievalQA
from langchain.llms import SagemakerEndpoint
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.vectorstores import FAISS
from typing import Any, Dict, List, Optional
import os
import logging
import requests
import yaml
import faiss

##### Setup logging

In [5]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [6]:
logger.info(f'Using requests=={requests.__version__}')
logger.info(f'Using pyyaml=={yaml.__version__}')

Using requests==2.32.3
Using pyyaml==6.0.2


#### Setup essentials

In [29]:
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-sentencesimilarity-20250327-021732' #INSERT EMBEDDING ENDPOINT NAME IF DIFFERENT
TEXT_GENERATION_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-llm-mistral-7b-ins-20250327-022258' #INSERT TEXT GENERATION ENDPOINT NAME IF DIFFERENT

REGION_NAME = boto3.session.Session().region_name

#### Encode passages (chunks) using JumpStart's BGE text embedding model

We use only 1 of the 20 textbooks in the dataset. Generating embeddings for one textbook (for example, Pathology) takes about 6 minutes. You can index more textbooks if you allow enough additional execution time.

To implement the RAG approach, this notebook uses the LangChain framework, which integrates with many services and tools and makes it efficient to build patterns such as RAG.

In [8]:
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader, TextLoader

loader = DirectoryLoader("./", glob="**/Pathology*.txt", loader_cls=TextLoader)

documents = loader.load()
# In our testing, recursive character splitting works better with this data set
text_splitter = RecursiveCharacterTextSplitter(
    # Maximum chunk size and overlap between consecutive chunks, in characters
    chunk_size = 1000,
    chunk_overlap  = 100,
)
docs = text_splitter.split_documents(documents)
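
To illustrate what `chunk_size` and `chunk_overlap` mean, here is a simplified fixed-window sketch in plain Python. It is not how `RecursiveCharacterTextSplitter` works internally: that splitter prefers natural boundaries such as paragraph and line breaks, so its chunks are often shorter than `chunk_size`. The helper `split_with_overlap` is a hypothetical name used only for this example.

```python
def split_with_overlap(text, chunk_size=1000, chunk_overlap=100):
    """Cut text into windows of chunk_size characters that share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reached the end of the text
    return chunks

# Synthetic 2,500-character text standing in for a textbook.
text = "".join(str(i % 10) for i in range(2500))
chunks = split_with_overlap(text)
print([len(c) for c in chunks])
```

The overlap means the end of one chunk is repeated at the start of the next, which reduces the chance that a sentence relevant to a query is cut in half at a chunk boundary.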

In [9]:
print(docs[0])

page_content='Plasma Membrane: Protection and Nutrient Acquisition

Biosynthetic Machinery: Endoplasmic Reticulum and Golgi Apparatus

Waste Disposal: Lysosomes and Proteasomes

Modular Signaling Proteins, Hubs, and

Components of the Extracellular Matrix

Proliferation and the Cell Cycle

Pathology literally translates to the study of suffering (Greek pathos = suffering, logos = study); as applied to modern medicine, it is the study of disease. Virchow was certainly correct in asserting that disease originates at the cellular level, but we now realize that cellular disturbances arise from alterations in molecules (genes, proteins, and others) that influence the survival and behavior of cells. Thus, the foundation of modern pathology is understanding the cellular and molecular abnormalities that give rise to diseases. It is helpful to consider these abnormalities in the context of normal cellular structure and function, which is the theme of this introductory chapter.' metadata={'sourc

In [10]:
avg_doc_length = lambda documents: sum([len(doc.page_content) for doc in documents])//len(documents)
avg_char_count_pre = avg_doc_length(documents)
avg_char_count_post = avg_doc_length(docs)
print(f'Average length among {len(documents)} documents loaded is {avg_char_count_pre} characters.')
print(f'After the split we have {len(docs)} documents more than the original {len(documents)}.')
print(f'Average length among {len(docs)} documents (after split) is {avg_char_count_post} characters.')

Average length among 1 documents loaded is 3784898 characters.
After the split we have 5171 documents more than the original 1.
Average length among 5171 documents (after split) is 744 characters.


In [30]:
# Embedding Setup
class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size defines how many input texts will
                be grouped together as request. If None, will use the
                chunk size specified by the class.

        Returns:
            List of embeddings, one for each text.
        """
        results = []
        _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size

        for i in range(0, len(texts), _chunk_size):
            response = self._embedding_func(texts[i : i + _chunk_size])
            results.extend(response)
        return results


class ContentHandler(EmbeddingsContentHandler):  # Inherit from EmbeddingsContentHandler
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
        # Use "embedding" mode for both documents and queries.
        input_str = json.dumps({"text_inputs": prompt, "mode": "embedding", **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[float]: #Expects to return a list of floats
        response_json = json.loads(output.read().decode("utf-8"))
        embeddings = response_json["embedding"]
        return embeddings


content_handler = ContentHandler()

sagemakerEndpointEmbeddingsJumpStart = SagemakerEndpointEmbeddingsJumpStart(
    endpoint_name=TEXT_EMBEDDING_MODEL_ENDPOINT_NAME,
    region_name=REGION_NAME,
    content_handler=content_handler
)


# Load Data and Split (important to call after sagemakerEndpointEmbeddingsJumpStart is initiated)
loader = TextLoader("./Pathology_Robbins.txt")  # Replace with your data file
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
docs = text_splitter.split_documents(documents)

print(f"Number of chunks: {len(docs)}")



Number of chunks: 5171
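
The `embed_documents` override above sends the chunks to the endpoint in small batches. The sketch below shows the same batching logic in isolation, with a hypothetical `fake_embed_batch` function standing in for the SageMaker endpoint call so that it runs without AWS access.

```python
def embed_in_batches(texts, embed_batch, batch_size=5):
    """Call embed_batch once per group of batch_size texts and collect the results."""
    results = []
    for i in range(0, len(texts), batch_size):
        results.extend(embed_batch(texts[i:i + batch_size]))
    return results

calls = []  # records the size of every simulated request

def fake_embed_batch(batch):
    """Hypothetical stand-in for the endpoint: one 1-dimensional vector per text."""
    calls.append(len(batch))
    return [[float(len(text))] for text in batch]

vectors = embed_in_batches([f"chunk {n}" for n in range(12)], fake_embed_batch)
print(calls, len(vectors))
```

With 5,171 chunks and a batch size of 5, the same loop issues 1,035 requests to the embedding endpoint, which contributes to the indexing time.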


In [31]:
print(docs[0].page_content)

Plasma Membrane: Protection and Nutrient Acquisition

Biosynthetic Machinery: Endoplasmic Reticulum and Golgi Apparatus

Waste Disposal: Lysosomes and Proteasomes

Modular Signaling Proteins, Hubs, and

Components of the Extracellular Matrix

Proliferation and the Cell Cycle

Pathology literally translates to the study of suffering (Greek pathos = suffering, logos = study); as applied to modern medicine, it is the study of disease. Virchow was certainly correct in asserting that disease originates at the cellular level, but we now realize that cellular disturbances arise from alterations in molecules (genes, proteins, and others) that influence the survival and behavior of cells. Thus, the foundation of modern pathology is understanding the cellular and molecular abnormalities that give rise to diseases. It is helpful to consider these abnormalities in the context of normal cellular structure and function, which is the theme of this introductory chapter.



---
## Semantic Similarity with Amazon SageMaker JumpStart Embedding Models

Semantic search refers to searching for information based on the meaning and concepts of words and phrases, rather than just matching keywords. Embedding models, such as the BGE model used in this notebook or Amazon Titan Embeddings, enable semantic search by representing words and sentences as dense vectors that encode their meaning.

Semantic matching is extremely helpful for RAG because it returns results that are conceptually related to the user's query, even if they don't contain the exact keywords. This leads to more relevant and useful search results which can be injected into our LLM's prompts.
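
As a small worked example of comparing meanings through vectors, the sketch below computes cosine similarity in plain Python. The three-dimensional vectors are made-up values for illustration only; real embeddings from the endpoint have far more dimensions.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical embeddings: the first two texts are about related concepts.
kidney_injury = [0.9, 0.1, 0.0]
renal_failure = [0.8, 0.2, 0.1]
heart_anatomy = [0.0, 0.2, 0.9]

related = cosine_similarity(kidney_injury, renal_failure)
unrelated = cosine_similarity(kidney_injury, heart_anatomy)
print(f"related: {related:.3f}, unrelated: {unrelated:.3f}")
```

Texts about related concepts end up with vectors pointing in similar directions even when they share no keywords, which is what lets vector search return conceptually relevant chunks.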

Let's look at the embedding the endpoint returns for a single document chunk.

In [32]:
sample_embedding = np.array(sagemakerEndpointEmbeddingsJumpStart.embed_query(docs[0].page_content))
print("Sample embedding of a document chunk: ", sample_embedding)
print("Size of the embedding: ", sample_embedding.shape)

Sample embedding of a document chunk:  [-2.68548373e-02 -5.14841871e-03  8.22765604e-02  2.51330547e-02
 -5.26014669e-03 -5.87520450e-02  7.25993514e-02  5.83153265e-03
  1.96829159e-03 -1.03373043e-02  2.18124315e-02 -7.02288449e-02
  3.04869432e-02 -2.42926478e-02  1.07806809e-02  1.18084755e-02
 -4.04602120e-04  4.70750406e-02 -4.19485047e-02  2.68056672e-02
 -1.20089045e-02 -4.66680191e-02  7.61874812e-03 -7.36146793e-02
  3.83745246e-02  5.43740168e-02 -7.34584220e-03 -4.35480615e-03
 -3.49703655e-02 -1.43278882e-01  5.94927871e-04 -7.77929882e-03
  3.37314117e-03  3.28561850e-02  1.15412083e-02 -2.63297465e-02
 -1.76403560e-02 -4.29618685e-03 -3.88793573e-02  5.76247927e-03
  5.42547517e-02  8.73222575e-03 -4.60612774e-02  6.76298980e-03
  3.03594731e-02 -4.08621617e-02  3.95240774e-03 -2.54773833e-02
  4.18353826e-02 -4.78082784e-02 -4.71962839e-02 -1.01800738e-02
  1.61627885e-02  4.23433036e-02  4.70918603e-03 -1.95064899e-02
  5.22197708e-02  3.16964127e-02  5.02787903e-02  4

Now create embeddings for the entire document set. For a single medical textbook, this takes about 6 minutes.

In [14]:
#FAISS Indexing
db = FAISS.from_documents(docs, sagemakerEndpointEmbeddingsJumpStart) #embeddings
db.save_local("faiss_index")


In [33]:
loaded_db = FAISS.load_local(
    "faiss_index", 
    sagemakerEndpointEmbeddingsJumpStart,
    allow_dangerous_deserialization=True  # Required because the saved index is unpickled; only load index files you trust
)

The embeddings are now stored in the FAISS vector store, saved to disk, and reloaded.

Next, we create a user query and retrieve the most similar chunks from the vector store.

In [34]:
# Method 1: Simple similarity search
query = "What is acute kidney injury?"
k = 3  # number of results you want to retrieve
# Use a separate name so the full list of chunks in `docs` is not overwritten
similar_docs = loaded_db.similarity_search(query, k=k)

# Print the results
for doc in similar_docs:
    print(doc.page_content)
    print("-------------------")

Acute tubular injury (ATI) is a clinicopathologic entity characterized by damage to tubular epithelial cells and an acute decline in renal function, often associated with shedding of granular casts and tubular cells into the urine. Clinicians use the term acute tubular necrosis, but frank necrosis is rarely observed in a kidney biopsy, so pathologists prefer the term acute tubular injury. The constellation of changes, broadly termed acute kidney injury, manifests clinically as decreased GFR with concurrent elevation of serum creatinine. ATI is the most common cause of acute kidney injury and may produce oliguria (defined as urine output of <400 mL/day).

http://ebooksmedicine.net

There are two forms of ATI that differ in the underlying causes.
-------------------
Shiga toxin–associated HUS is characterized by the sudden onset, usually after a gastrointestinal or flulike prodromal episode, of bleeding manifestations (especially hematemesis and melena), severe oliguria, hematuria, micro

In [25]:
# Method 2: Using similarity_search_with_score
relevant_documents_with_scores = loaded_db.similarity_search_with_score(query, k=4)
for doc, score in relevant_documents_with_scores:
    print(f"Score: {score}")
    print(f"Content: {doc.page_content}")
    print("---")

Score: 0.31192612648010254
Content: Acute tubular injury (ATI) is a clinicopathologic entity characterized by damage to tubular epithelial cells and an acute decline in renal function, often associated with shedding of granular casts and tubular cells into the urine. Clinicians use the term acute tubular necrosis, but frank necrosis is rarely observed in a kidney biopsy, so pathologists prefer the term acute tubular injury. The constellation of changes, broadly termed acute kidney injury, manifests clinically as decreased GFR with concurrent elevation of serum creatinine. ATI is the most common cause of acute kidney injury and may produce oliguria (defined as urine output of <400 mL/day).

http://ebooksmedicine.net

There are two forms of ATI that differ in the underlying causes.
---
Score: 0.4701605439186096
Content: Shiga toxin–associated HUS is characterized by the sudden onset, usually after a gastrointestinal or flulike prodromal episode, of bleeding manifestations (especially hem

---
## Create RAG with LangChain and an LLM hosted on SageMaker


In [35]:
import boto3
import json
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from typing import Any, Dict, List, Optional
import os


# Make sure the AWS SDK uses the same region as the endpoints
os.environ["AWS_DEFAULT_REGION"] = REGION_NAME

# Text generation with Mistral
class ContentHandlerMistral(LLMContentHandler):
    content_type = "application/json"
    accepts = "text/plain"

    def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
        # The prompt is sent to the endpoint unchanged.
        # Default generation parameters; values in model_kwargs override them.
        parameters = {
            "max_new_tokens": 500,
            "temperature": 0.7,
            "top_p": 0.95,
            "do_sample": True,
            **model_kwargs,
        }
        input_str = json.dumps({"inputs": prompt, "parameters": parameters})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        try:
            raw_response = output.read().decode("utf-8")
            print(f"Raw response from model: {raw_response}")  # Debug print
            response_json = json.loads(raw_response)
            
            # The endpoint may return a list of generations or a single dict
            if isinstance(response_json, list) and len(response_json) > 0:
                if "generated_text" in response_json[0]:
                    return response_json[0]["generated_text"]
            elif isinstance(response_json, dict):
                return response_json.get("generated_text", "")
            
            print(f"Unexpected response format: {response_json}")  # Debug print
            return ""
        except Exception as e:
            print(f"Error processing output: {str(e)}")
            return ""

content_handler_mistral = ContentHandlerMistral()

llm = SagemakerEndpoint(
    endpoint_name=TEXT_GENERATION_MODEL_ENDPOINT_NAME,
    region_name=REGION_NAME,
    content_handler=content_handler_mistral,
    model_kwargs={"max_new_tokens": 500}  # Forwarded to transform_input
)
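
The content handler above passes the prompt to the endpoint unchanged. Mistral instruct models are fine-tuned on prompts wrapped in `[INST] ... [/INST]` tags, so one possible refinement is to wrap the prompt before sending it. The helper below is a hypothetical sketch, not part of the notebook; whether a beginning-of-sequence token also has to be added by the client depends on the serving container, so it is left out here.

```python
def format_mistral_instruct(prompt: str) -> str:
    """Wrap a prompt in the [INST] ... [/INST] tags used by Mistral instruct models."""
    return f"[INST] {prompt.strip()} [/INST]"

formatted = format_mistral_instruct("  What is acute kidney injury?\n")
print(formatted)
```

If you try this, compare the answers with and without the wrapper on a few queries before adopting it.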


In [36]:
# RAG QA Chain
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=loaded_db.as_retriever(search_kwargs={"k": 1}),  # Retrieve the single closest chunk
    return_source_documents=True,
    verbose=True  # Log when the chain starts and finishes
)
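
The `"stuff"` chain type places all retrieved chunks into a single prompt. The sketch below shows the idea in plain Python; the template wording and the `stuff_prompt` helper are assumptions for illustration and are not guaranteed to match LangChain's built-in prompt.

```python
# Illustrative template; LangChain's actual default wording may differ.
STUFF_TEMPLATE = (
    "Use the following pieces of context to answer the question at the end. "
    "If you don't know the answer, just say that you don't know.\n\n"
    "{context}\n\nQuestion: {question}\nHelpful Answer:"
)

def stuff_prompt(question, documents):
    """Join every retrieved chunk into one context block and fill the template."""
    context = "\n\n".join(documents)
    return STUFF_TEMPLATE.format(context=context, question=question)

retrieved = [
    "Systolic dysfunction results from weakened heart muscle contraction.",
    "Diastolic dysfunction is an inability of the heart to relax and fill.",
]
prompt = stuff_prompt("What are the main causes of heart failure?", retrieved)
print(prompt)
```

Because every chunk is placed in the prompt, the number of retrieved chunks (`k`) multiplied by the chunk size has to stay within the model's token limit.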


In [37]:
# Example Question
try:
    query = "What are the main causes of heart failure?"
    result = qa_chain({"query": query})
    print("Question:", query)
    print("Answer:", result["result"])
    print("Source Documents:", result["source_documents"])
except Exception as e:
    print(f"Error during query: {str(e)}")




> Entering new RetrievalQA chain...
Raw response from model: {"generated_text": " Heart failure can be caused by either systolic or diastolic dysfunction. Systolic dysfunction results from weakened heart muscle contractions, often due to ischemic heart disease or hypertension. Diastolic dysfunction is characterized by an inability of the heart to relax and fill properly, and can be caused by conditions such as massive left ventricular hypertrophy, myocardial fibrosis, amyloid deposition, or constrictive pericarditis. Approximately half of heart failure cases are caused by diastolic dysfunction, and it is more common in older adults, diabetic patients, and women. Other causes of heart failure include valve dysfunction and rapid increases in blood volume or pressure."}

> Finished chain.
Question: What are the main causes of heart failure?
Answer:  Heart failure can be caused by either systolic or diastolic dysfunction. Systolic dysfunction results from weakened heart m