## Medical Question Answering with the Retrieval Augmented Generation (RAG) Design Pattern
Use the Python 3 (Data Science 3.0) kernel image and an `ml.m5.xlarge` instance for this notebook.

This pattern first generates embeddings for all existing documents and indexes them in a vector store. Then, for every user query, it generates an embedding of the query and searches the store by embedding distance. The search results act as context for the LLM to generate an output.
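The flow described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: `toy_embed` is a hypothetical bag-of-words stand-in for a real embedding model, the three sample sentences are made up for the example, and the LLM call itself is left out.

```python
import math
import re
from collections import Counter


def toy_embed(text: str) -> Counter:
    """Hypothetical stand-in for an embedding model: word counts as a sparse vector."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine_similarity(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(query: str, index: list, k: int = 2) -> list:
    """Return the k document texts most similar to the query."""
    q = toy_embed(query)
    ranked = sorted(index, key=lambda item: cosine_similarity(q, item[1]), reverse=True)
    return [text for text, _ in ranked[:k]]


# 1. Embed and index the documents once.
documents = [
    "Aspirin inhibits platelet aggregation.",
    "The femur is the longest bone in the human body.",
    "Insulin lowers blood glucose levels.",
]
index = [(doc, toy_embed(doc)) for doc in documents]

# 2. Embed the query and fetch the nearest documents.
question = "Which bone is the longest?"
top_docs = retrieve(question, index, k=1)

# 3. Pass the retrieved text to the LLM as context (the LLM call is omitted here).
prompt = f"CONTEXT: {' '.join(top_docs)}\nQUESTION: {question}\nANSWER:"
print(prompt)
```

A real embedding model captures meaning rather than word overlap, but the three steps (index, retrieve, prompt) stay the same.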

Challenges:
* How to manage large document(s) that exceed the token limit
* How to find the document(s) relevant to the question being asked
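The first challenge is commonly handled by cutting each document into smaller, overlapping chunks. The sketch below shows the idea with fixed-size character windows; `split_with_overlap` is a hypothetical helper written for this illustration and is simpler than a separator-aware splitter, which also tries to break on paragraph and sentence boundaries.

```python
def split_with_overlap(text: str, chunk_size: int, chunk_overlap: int) -> list:
    """Cut text into windows of chunk_size characters that share chunk_overlap characters."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks


chunks = split_with_overlap("abcdefghij", chunk_size=4, chunk_overlap=1)
print(chunks)  # ['abcd', 'defg', 'ghij']
```

The overlap means a sentence that straddles a chunk boundary still appears whole in at least one chunk, at the cost of storing some text twice.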

## Key components

LLM (Large Language Model): Falcon-7b-instruct (matching the text generation endpoint name configured in this notebook), available through Amazon SageMaker. This model will be used to understand the document chunks and provide an answer in a human-friendly manner.

Embeddings Model: GPT-J 6B available through Amazon SageMaker. This model will be used to generate a numerical representation of the textual documents.

Vector Store: FAISS available through LangChain. In this notebook we use this in-memory vector store to hold both the embeddings and the documents. In an enterprise context this could be replaced with a persistent store such as Amazon OpenSearch Service, Amazon RDS for PostgreSQL with pgvector, ChromaDB, Pinecone or Weaviate.

Index: VectorIndex. The index compares the input embedding with the document embeddings to find the relevant documents.
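As a rough illustration of what the index does, the sketch below runs a brute-force nearest-neighbour search over a handful of vectors. It assumes squared Euclidean (L2) distance; other metrics such as cosine similarity or inner product are also common. The function names and the two-dimensional sample vectors are hypothetical, and a real vector store uses optimized data structures rather than a Python loop.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension")
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(query_vec, doc_vecs, k=2):
    """Return (position, distance) pairs for the k document vectors closest to the query."""
    scored = [(i, squared_l2(query_vec, v)) for i, v in enumerate(doc_vecs)]
    scored.sort(key=lambda pair: pair[1])  # smallest distance first
    return scored[:k]


doc_vecs = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
hits = nearest([0.9, 1.2], doc_vecs, k=2)
print(hits)
```

The positions returned by the search are then used to look up the original document chunks.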

### Dataset
To explain this architecture pattern we use documents from MedQA. These documents include medical textbooks on Pathology, Anatomy, Pharmacology and other subjects.

Download the textbooks that are part of the MedQA Q&A dataset, released with Jin, Di, et al. "What Disease does this Patient Have? A Large-scale Open Domain Question Answering Dataset from Medical Exams." arXiv preprint arXiv:2009.13081 (2020).

More details are available at https://github.com/jind11/MedQA

* Data source : @article{jin2020disease,
  title={What Disease does this Patient Have? A Large-scale Open Domain Question Answering Dataset from Medical Exams},
  author={Jin, Di and Pan, Eileen and Oufattole, Nassim and Weng, Wei-Hung and Fang, Hanyi and Szolovits, Peter},
  journal={arXiv preprint arXiv:2009.13081},
  year={2020} }
  
  

### Data preparation

Download the data from 
https://d1.awsstatic.com/whitepapers/architecture/AWS_Well-Architected_Framework.pdf

##### Prerequisites

In [1]:
%pip install faiss-cpu==1.7.4 --quiet

DEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063
DEPRECATION: textract 1.6.5 has a non-standard dependency specifier extract-msg<=0.29.*. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of textract or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063
Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip install langchain==0.0.222 --quiet

DEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063
DEPRECATION: textract 1.6.5 has a non-standard dependency specifier extract-msg<=0.29.*. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of textract or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063
ERROR: Cannot uninstall 'PyYAML'. It is a distutils installed project and thus we cannot accurately determine which files belong to it which would lead to only a partial uninstall.
Note: you may need to restart the kernel t

In [3]:
%%capture 

!pip install PyYAML

#### Imports

In [4]:
import requests
import logging 
import boto3
import yaml
import json

##### Setup logging

In [5]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

In [15]:
!pip install pypdf

Collecting pypdf
  Obtaining dependency information for pypdf from https://files.pythonhosted.org/packages/e3/a8/daf130ed0e6ead60f99b037c360e3ed910a2cd0accdaf612589b8ba83187/pypdf-3.15.5-py3-none-any.whl.metadata
  Downloading pypdf-3.15.5-py3-none-any.whl.metadata (7.1 kB)
Downloading pypdf-3.15.5-py3-none-any.whl (272 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 272.6/272.6 kB 4.4 MB/s eta 0:00:00
Installing collected packages: pypdf
Successfully installed pypdf-3.15.5

##### Log versions of dependencies 

In [6]:
logger.info(f'Using requests=={requests.__version__}')
logger.info(f'Using pyyaml=={yaml.__version__}')

Using requests==2.31.0
Using pyyaml==6.0


#### Setup essentials

In [7]:
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-textembedding-gpt-j-6b-fp16'
TEXT_GENERATION_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-llm-falcon-7b-instruct-bf16'

REGION_NAME = boto3.session.Session().region_name

#### Encode passages (chunks) using JumpStart's GPT-J text embedding model

We are specifically using only 1 of the 20 textbooks from the dataset. It takes about 6 minutes to generate embeddings for one textbook (for example, Pathology). You can index more textbooks if you allow enough additional execution time.

To follow the RAG approach, this notebook uses the LangChain framework, which integrates with a range of services and tools and makes it easier to build patterns such as RAG.

In [16]:
import numpy as np
import pypdf  # required by PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader, PyPDFLoader

# Load the PDF; PyPDFLoader returns one document per page.
loader = PyPDFLoader("../data/AWS_Well-Architected_Framework.pdf")

# Alternative: load every PDF found in a directory.
# loader = DirectoryLoader("./data/", glob="*.pdf", loader_cls=PyPDFLoader)

documents = loader.load()
# In our testing, the character-based split works better with this PDF data set.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,    # maximum number of characters per chunk
    chunk_overlap=100,  # characters shared between consecutive chunks
)
docs = text_splitter.split_documents(documents)

In [17]:
print(docs[0])

page_content='ArchivedAWS Well-Architected Framework\nJuly 2020\nThis whitepaper describes the AWS Well-Architected Framework. It provides guidance to help cus-\ntomers apply best practices in the design, delivery, and maintenance of AWS environments. We address\ngeneral design principles as well as specific best practices and guidance in ﬁve conceptual areas that\nwe define as the pillars  of the Well-Architected Framework.This paper has been archived.\nThe latest version is available at:\nhttps://docs.aws.amazon.com/wellarchitected/latest/framework/welcome.html' metadata={'source': 'data/AWS_Well-Architected_Framework.pdf', 'page': 0}


In [18]:
def avg_doc_length(documents):
    """Average page_content length in characters, rounded down."""
    return sum(len(doc.page_content) for doc in documents) // len(documents)

avg_char_count_pre = avg_doc_length(documents)
avg_char_count_post = avg_doc_length(docs)
print(f'Average length among {len(documents)} documents loaded is {avg_char_count_pre} characters.')
print(f'After the split we have {len(docs)} documents, up from the original {len(documents)}.')
print(f'Average length among {len(docs)} documents (after split) is {avg_char_count_post} characters.')

Average length among 97 documents loaded is 2259 characters.
After the split we have 282 documents, up from the original 97.
Average length among 282 documents (after split) is 826 characters.


In [19]:
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.embeddings import SagemakerEndpointEmbeddings
from typing import Any, Dict, List, Optional
from langchain.llms.sagemaker_endpoint import ContentHandlerBase


class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size defines how many input texts will
                be grouped together as one request.

        Returns:
            List of embeddings, one for each text.
        """
        results = []
        # Never request more texts than exist, and keep the step at least 1
        # so that range() does not fail on an empty input list.
        _chunk_size = max(1, min(chunk_size, len(texts)))

        for i in range(0, len(texts), _chunk_size):
            response = self._embedding_func(texts[i : i + _chunk_size])
            results.extend(response)
        return results


class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

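    def transform_input(self, prompt: List[str], model_kwargs: Optional[Dict] = None) -> bytes:
        # The endpoint expects the batch of texts under the "text_inputs" key.
        input_str = json.dumps({"text_inputs": prompt, **(model_kwargs or {})})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        # `output` is the response body stream; it holds one embedding per input text.
        response_json = json.loads(output.read().decode("utf-8"))
        embeddings = response_json["embedding"]
        return embeddings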


content_handler = ContentHandler()

sagemakerEndpointEmbeddingsJumpStart = SagemakerEndpointEmbeddingsJumpStart(
    endpoint_name=TEXT_EMBEDDING_MODEL_ENDPOINT_NAME,
    region_name=REGION_NAME,
    content_handler=content_handler,
)
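To make the request handling above easier to follow, here is a small offline sketch of the two ideas involved: grouping texts into batches, and the JSON round trip that uses the `text_inputs` and `embedding` keys. The helper names are hypothetical and the response is faked with an in-memory stream, so nothing here calls a real endpoint.

```python
import io
import json


def batches(items, size):
    """Yield consecutive slices of at most `size` items."""
    size = max(1, min(size, len(items)))  # guard against a zero step
    for i in range(0, len(items), size):
        yield items[i:i + size]


def encode_request(texts):
    """Serialize a batch of texts the way the content handler does."""
    return json.dumps({"text_inputs": texts}).encode("utf-8")


def decode_response(body):
    """Read a response stream and return the list of embeddings."""
    return json.loads(body.read().decode("utf-8"))["embedding"]


texts = ["a", "b", "c", "d", "e", "f", "g"]
groups = list(batches(texts, 5))
payload = encode_request(groups[0])

# Fake response (assumption): one two-dimensional vector per input text.
fake_body = io.BytesIO(
    json.dumps({"embedding": [[0.1, 0.2]] * len(groups[0])}).encode("utf-8")
)
vectors = decode_response(fake_body)
print(len(groups), len(vectors))
```

Sending several texts per request reduces the number of round trips to the endpoint; the batch size of 5 simply mirrors the default used in the class above.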

In [20]:
print(docs[0].page_content)

ArchivedAWS Well-Architected Framework
July 2020
This whitepaper describes the AWS Well-Architected Framework. It provides guidance to help cus-
tomers apply best practices in the design, delivery, and maintenance of AWS environments. We address
general design principles as well as specific best practices and guidance in ﬁve conceptual areas that
we define as the pillars  of the Well-Architected Framework.This paper has been archived.
The latest version is available at:
https://docs.aws.amazon.com/wellarchitected/latest/framework/welcome.html


In [21]:
sample_embedding = np.array(sagemakerEndpointEmbeddingsJumpStart.embed_query(docs[0].page_content))
print("Sample embedding of a document chunk: ", sample_embedding)
print("Size of the embedding: ", sample_embedding.shape)

Sample embedding of a document chunk:  [-0.00153714  0.00214407  0.0035825  ...  0.00850618 -0.00149574
 -0.02181229]
Size of the embedding:  (4096,)


Now create embeddings for the entire document set. Note that for a single medical textbook this takes about 6 minutes.

In [22]:
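from tqdm.contrib.concurrent import process_map
from multiprocessing import cpu_count

def generate_embeddings(x):
    # Return a (text, embedding) pair for one chunk.
    return (x, sagemakerEndpointEmbeddingsJumpStart.embed_query(x))

# One worker process per available CPU core.
workers = 1 * cpu_count()

# Plain text of every chunk, in the same order as `docs`.
texts = [i.page_content for i in docs]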

In [23]:
workers

2

In [24]:
data = process_map(generate_embeddings, texts, max_workers=workers, chunksize=100)

  0%|          | 0/282 [00:00<?, ?it/s]
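Because each embedding request mostly waits on the network, a thread pool is a common alternative to a process pool for this step. The sketch below is not the approach used in this notebook; it swaps in `concurrent.futures.ThreadPoolExecutor`, and `fake_embed` is a hypothetical stand-in for the endpoint call so that the example runs offline.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_embed(text):
    """Hypothetical stand-in for an endpoint call: returns a tiny made-up vector."""
    return [float(len(text)), float(text.count(" "))]


def embed_all(texts, max_workers=4):
    """Embed texts concurrently and return (text, embedding) pairs in input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        vectors = list(pool.map(fake_embed, texts))  # map() preserves input order
    return list(zip(texts, vectors))


pairs = embed_all(["short chunk", "a somewhat longer chunk"])
print(pairs)
```

The result has the same shape as the `(text, embedding)` pairs produced above, so it could be passed to a vector store in the same way.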

Next, we insert the embeddings into the FAISS vector store.

In [25]:
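from langchain.vectorstores import FAISS
# Create the store from the first two chunks (embedded through the endpoint) ...
faiss = FAISS.from_documents(docs[0:2], sagemakerEndpointEmbeddingsJumpStart)
# ... then add the precomputed (text, embedding) pairs for all chunks.
faiss.add_embeddings(data)
# Persist the index to the local "faiss_index" directory.
faiss.save_local("faiss_index")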

Next, we create a user query and use it to retrieve a response from the vector search and the LLM combined.

In [26]:
query = "What are the 5 pillars of well architected framework?"

In [27]:
query_embedding = faiss.embedding_function(query)
np.array(query_embedding)

array([ 0.00698841, -0.02292958,  0.01051318, ..., -0.01012999,
       -0.00661008, -0.001499  ])

In [28]:
relevant_documents = faiss.similarity_search_by_vector(query_embedding)
context = ""
print(f'{len(relevant_documents)} documents are fetched which are relevant to the query.')
print('----')
for i, rel_doc in enumerate(relevant_documents):
    print(f'## Document {i+1}: {rel_doc.page_content}.......')
    print('---')
    context += rel_doc.page_content
context = context.replace("\n", " ")

4 documents are fetched which are relevant to the query.
----
## Document 1: can emerge that is driven by customer need. Technology leaders (such as a CTOs or
development managers), carrying out Well-Architected reviews across all your work-
loads will allow you to better understand the risks in your technology portfolio. Using
this approach, you can identify themes across teams that your organization could ad-
dress by mechanisms, training, or lunchtime talks where your principal engineers can
share their thinking on specific areas with multiple teams.
3Working backward is a fundamental part of our innovation process. We start with the customer and what
they want, and let that define and guide our efforts.
4.......
---
## Document 2: ternet scale. We prefer to use data to define best practice, but we also use subject
matter experts, like principal engineers, to set them. As principal engineers see new
best practices emerge, they work as a community to ensure that teams follow them.
In
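When the retrieved chunks are joined directly, the end of one chunk runs into the start of the next. One possible refinement, sketched below, is to join the chunks with an explicit separator and stop once a character budget is reached. `build_context` and its `max_chars` parameter are hypothetical names for this illustration, and the budget is counted in characters rather than model tokens.

```python
def build_context(chunks, max_chars=2000, separator="\n\n"):
    """Join chunks with a separator until the character budget is used up."""
    parts, used = [], 0
    for chunk in chunks:
        text = " ".join(chunk.split())  # collapse newlines and repeated spaces
        cost = len(text) + (len(separator) if parts else 0)
        if used + cost > max_chars:
            break  # adding this chunk would exceed the budget
        parts.append(text)
        used += cost
    return separator.join(parts)


context = build_context(
    ["first\nchunk", "second   chunk", "third chunk"], max_chars=30
)
print(repr(context))  # 'first chunk\n\nsecond chunk'
```

Because the search returns the closest chunks first, cutting off at the budget drops the least relevant chunks.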

Now create a prompt template to prompt the model with the above context from the vector search. We specifically instruct the model to answer using only the context provided.

In [29]:
template = """
        You are a helpful, polite, fact-based agent.
        If you don't know the answer, just say that you don't know.
        Please answer the following question using the context provided. 

        CONTEXT: 
        {context}
        =========
        QUESTION: {question} 
        ANSWER: """


In [30]:
prompt = template.format(context=context, question=query)
print(prompt)


        You are a helpful, polite, fact-based agent.
        If you don't know the answer, just say that you don't know.
        Please answer the following question using the context provided. 

        CONTEXT: 
        can emerge that is driven by customer need. Technology leaders (such as a CTOs or development managers), carrying out Well-Architected reviews across all your work- loads will allow you to better understand the risks in your technology portfolio. Using this approach, you can identify themes across teams that your organization could ad- dress by mechanisms, training, or lunchtime talks where your principal engineers can share their thinking on specific areas with multiple teams. 3Working backward is a fundamental part of our innovation process. We start with the customer and what they want, and let that define and guide our efforts. 4ternet scale. We prefer to use data to define best practice, but we also use subject matter experts, like principal engineers, to set th

Invoke the endpoint to generate a response from the LLM.

In [31]:
smr_client = boto3.client("sagemaker-runtime")

In [32]:
response_model = smr_client.invoke_endpoint(
    EndpointName=TEXT_GENERATION_MODEL_ENDPOINT_NAME,
    Body=json.dumps(
        {"inputs": prompt, "parameters": {"max_new_tokens": 500}}
    ),
    ContentType="application/json",
)
response = json.loads(response_model["Body"].read())


In [33]:
print(response[0]["generated_text"])

1. Security, 2. Reliability, 3. Performance, 4. Cost optimization, 5. Operational excellence.
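The request and response handling above can be wrapped in two small helpers. This is a sketch with hypothetical function names; it assumes the response is a JSON list whose first item has a `generated_text` field, as in the cell above. Whether the endpoint echoes the prompt at the start of `generated_text` depends on its configuration, so the helper strips the prompt only when it is actually present.

```python
import json


def build_payload(prompt, max_new_tokens=500):
    """Serialize the request body for the text generation endpoint."""
    return json.dumps(
        {"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}}
    )


def extract_answer(raw_body, prompt):
    """Parse the response body and return the generated answer without the prompt."""
    response = json.loads(raw_body)
    text = response[0]["generated_text"]
    if text.startswith(prompt):
        text = text[len(prompt):]  # drop the echoed prompt, if any
    return text.strip()


# Fake response body (assumption) standing in for response_model["Body"].read().
fake_body = json.dumps([{"generated_text": "Q: hi\nA: hello "}]).encode("utf-8")
print(extract_answer(fake_body, "Q: hi\nA:"))  # hello
```

Keeping the parsing in one place makes it easier to adapt if a different model returns its output in another shape.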
