# Retrieval Augmented Generation (RAG) using Foundation Models in SageMaker

In this notebook we demonstrate how to use Retrieval Augmented Generation (RAG) to build a question-and-answer chatbot to converse with the **Construction Doc** using Foundation Models in SageMaker.

Foundation models are usually trained offline, so the model has no knowledge of any data created after it was trained. Additionally, foundation models are trained on very general domain corpora, making them less effective for domain-specific tasks. Retrieval Augmented Generation (RAG) retrieves data from outside a foundation model and augments your prompts by adding the relevant retrieved data as context. For more information about RAG model architectures, see [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/abs/2005.11401).

With RAG, the external data used to augment your prompts can come from multiple data sources, such as document repositories, databases, or APIs. The first step is to convert your documents and any user queries into a compatible format so that a relevancy search can be performed. To make the formats compatible, a document collection (the knowledge library) and the user-submitted queries are converted to numerical representations using embedding language models. Embedding is the process by which text is given a numerical representation in a vector space. RAG model architectures compare the embeddings of user queries with the vectors of the knowledge library. The original user prompt is then augmented with relevant context from similar documents within the knowledge library, and this augmented prompt is sent to the foundation model. You can update knowledge libraries and their embeddings asynchronously.
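The retrieve-then-augment flow described above can be sketched in a few lines of plain Python. This is only an illustration under loud assumptions: the `embed` function is a toy bag-of-words counter standing in for a real embedding model, and the names `embed`, `cosine`, `retrieve`, `augment`, and the sample `library` are hypothetical and not part of this workshop's code.

```python
import math
from collections import Counter
from typing import Dict, List


def embed(text: str) -> Dict[str, float]:
    """Toy embedding: word counts (a stand-in for a real embedding model)."""
    return dict(Counter(text.lower().split()))


def cosine(a: Dict[str, float], b: Dict[str, float]) -> float:
    """Cosine similarity between two sparse vectors."""
    dot = sum(value * b.get(token, 0.0) for token, value in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def retrieve(query: str, library: List[str], k: int = 1) -> List[str]:
    """Return the k library entries most similar to the query."""
    query_vec = embed(query)
    ranked = sorted(library, key=lambda doc: cosine(query_vec, embed(doc)), reverse=True)
    return ranked[:k]


def augment(query: str, context: List[str]) -> str:
    """Build the augmented prompt: retrieved context followed by the question."""
    joined = "\n".join(context)
    return f"Context:\n{joined}\n\nQuestion: {query}"


# Hypothetical knowledge library
library = [
    "chunking splits large documents into smaller passages",
    "embeddings map text to vectors in a shared space",
    "opensearch stores vectors for similarity search",
]
question = "how are vectors stored for search"
top = retrieve(question, library, k=1)
prompt = augment(question, top)
print(prompt)
```

In the notebook itself, the embedding model is a SageMaker endpoint and the similarity search runs inside Amazon OpenSearch; the sketch only mirrors the order of the steps.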

In the previous sections of this workshop, you deployed the **Llama 2** Foundation Model to SageMaker Endpoints and used these models for various Natural Language Processing (NLP) tasks such as text summarization, common sense reasoning, translation, and question answering. In this section, we use these SageMaker endpoints to create vector embeddings that are stored in Amazon OpenSearch. We then use these embeddings in a RAG model for a question-and-answer chatbot. The diagram below depicts this architecture.

We will also use **LangChain**, an open-source framework for developing and interfacing with applications powered by language models.

## Prerequisites

The following are the prerequisites for this notebook:
1. Run the Jupyter Notebook titled `01-deploy-text2text-model.ipynb`. This notebook deploys the FLAN-T5-XL LLM to a SageMaker Endpoint.
2. Deploy the SageMaker JumpStart text embeddings model called `GPT-J 6B Embedding FP16`.
3. [Not required if you do step 2.] Run the Jupyter Notebook titled `02-deploy-text2emb-model.ipynb`. This notebook deploys the gpt-j-6b-fp16 LLM to a SageMaker Endpoint.
4. [Non-AWS Event] Run the Jupyter Notebook titled `03-create-vector-store.ipynb`. This notebook creates an Amazon OpenSearch Cluster and the required Index for the vector database. This notebook is not required if you are running an AWS Event.

In [100]:
%pip install -U pip --quiet
%pip install --upgrade sagemaker --quiet 
%pip install langchain --quiet
%pip install opensearch-py --quiet
%pip install regex --quiet
%pip install tqdm --quiet
%pip install requests_aws4auth --quiet
%pip install PyPDF2 --quiet 
%pip install pypdf --quiet

Note: you may need to restart the kernel to use updated packages.


In [101]:
# Setup SageMaker Session
import sagemaker, boto3, json
from sagemaker.session import Session

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

In [102]:
# Restore SageMaker Model Endpoints from previous notebooks
#%store -r LLM

LLM = "jumpstart-dft-GA-text2text-flan-t5-xl"
ELLM = "jumpstart-dft-ga-textembedding-gpt-j-6b-fp16"
# %store -r embeddings_model_endpoint_name
#%store -r ELLM

In [103]:
# Set variables for Amazon OpenSearch
CFN_STACK_NAME = "GenAI-Opensearch"

import sys
import logging
logger = logging.getLogger()
logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)

import boto3
from typing import List
stacks = boto3.client('cloudformation').list_stacks()
stack_found = CFN_STACK_NAME in [stack['StackName'] for stack in stacks['StackSummaries']]

def get_cfn_outputs(stackname: str) -> dict:
    cfn = boto3.client('cloudformation')
    outputs = {}
    for output in cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Outputs']:
        outputs[output['OutputKey']] = output['OutputValue']
    return outputs

def get_cfn_parameters(stackname: str) -> dict:
    cfn = boto3.client('cloudformation')
    params = {}
    for param in cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Parameters']:
        params[param['ParameterKey']] = param['ParameterValue']
    return params

if stack_found is True:
    outputs = get_cfn_outputs(CFN_STACK_NAME)
    params = get_cfn_parameters(CFN_STACK_NAME)
    logger.info(f"cfn outputs={outputs}\nparams={params}")

    opensearch_domain_endpoint = f"https://{outputs['OpenSearchDomainEndpoint']}"
    opensearch_domain_name = outputs['OpenSearchDomainName']
    aws_region = outputs['Region']
    opensearch_secretid = outputs['OpenSearchSecret']
    # keep the ARN in its own variable so it does not overwrite the domain name
    opensearch_domain_arn = outputs['OpenSourceDomainArn']
    # ARN of the secret is of the following format arn:aws:secretsmanager:region:account_id:secret:my_path/my_secret_name-autoid
    os_creds_secretid_in_secrets_manager = "-".join(outputs['OpenSearchSecret'].split(":")[-1].split('-')[:-1])
else:
    logger.info(f"cloud formation stack {CFN_STACK_NAME} not found, set parameters manually here")




2023-07-26 16:56:06,937,1847989606,MainProcess,INFO,cfn outputs={'OpenSourceDomainArn': 'arn:aws:es:us-east-1:924118560136:domain/opensearchservi-orshpbtsh2xx', 'OpenSearchDomainEndpoint': 'search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com', 'Region': 'us-east-1', 'OpenSearchDomainName': 'opensearchservi-orshpbtsh2xx', 'OpenSearchSecret': 'arn:aws:secretsmanager:us-east-1:924118560136:secret:OpenSearchSecret-GenAI-Opensearch-7B1eTo'}
params={'OpenSearchIndexName': 'gen_workshop_index', 'OpenSearchUsername': 'opensearchuser', 'OpenSearchPassword': '****'}
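The secret-name extraction used in the cell above can be checked in isolation. The sketch below applies the same string operations to a made-up ARN; the function name `secret_name_from_arn`, the account ID, and the six-character suffix are hypothetical placeholders.

```python
def secret_name_from_arn(secret_arn: str) -> str:
    """Strip the ARN prefix and the trailing auto-generated suffix from a secret ARN."""
    # Everything after the last ':' is "<secret-name>-<autoid>"
    resource = secret_arn.split(":")[-1]
    # Drop the final '-<autoid>' segment and rejoin the remaining parts
    return "-".join(resource.split("-")[:-1])


# Hypothetical ARN in the documented format
example_arn = (
    "arn:aws:secretsmanager:us-east-1:123456789012:"
    "secret:OpenSearchSecret-GenAI-Opensearch-AbCdEf"
)
secret_name = secret_name_from_arn(example_arn)
print(secret_name)
```

Note that this approach assumes the secret name contains no `:` character; hyphens inside the name are preserved because only the last hyphen-separated segment is dropped.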


## Chunk your Data and Load into Amazon OpenSearch

In this section we will chunk the data into smaller documents. Chunking is a technique for splitting large texts into smaller pieces. It is an important step because it improves the relevance of the search results returned for a query in our RAG model, which in turn improves the quality of the chatbot.
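To make the idea of chunk size and overlap concrete, here is a minimal sketch of fixed-size character chunking in plain Python. It is a simplification: the helper `chunk_text` is hypothetical, and it cuts purely by character position, whereas a splitter such as LangChain's `RecursiveCharacterTextSplitter` also tries to break on separators such as paragraphs and sentences.

```python
from typing import List


def chunk_text(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Split text into windows of chunk_size characters that overlap by chunk_overlap."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the current window already reaches the end of the text
    return chunks


sample = "abcdefghij"
chunks = chunk_text(sample, chunk_size=4, chunk_overlap=1)
print(chunks)
```

The overlap means that text near a chunk boundary appears in both neighbouring chunks, which reduces the chance that a relevant passage is cut in half.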

In [104]:
import pathlib
from PyPDF2 import PdfReader
from langchain.document_loaders import PyPDFLoader

In [105]:
file = "SOCI.pdf"
path = f"{pathlib.Path().absolute()}/{file}"
loader = PyPDFLoader(path)
data = loader.load()

In [106]:
print (f'You have {len(data)} document(s) in your data')
print (f'There are {len(data[0].page_content)} characters in your document')

You have 6 document(s) in your data
There are 4189 characters in your document


In [107]:
from langchain.text_splitter import RecursiveCharacterTextSplitter,CharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1600, chunk_overlap=200)
docs = text_splitter.split_documents(data)

print (f'Now you have {len(docs)} documents')
print (f'There are {len(docs[0].page_content)} characters in your document')

Now you have 15 documents
There are 1547 characters in your document


In [108]:
# Helper function to process document

import regex as re

def postproc(s):
    s = s.replace(u'\xa0', u' ') # no-break space 
    s = s.replace('\n', ' ') # new-line
    s = re.sub(r'\s+', ' ', s) # multiple spaces
    return s

for doc in docs:
    doc.page_content = postproc(doc.page_content)

In [109]:
# Review a sample chunk for correctness
docs[9]

Document(page_content="Amazon Confidential Page 4 of 6 4. Filtering Lambda uses the SOCIRepositoryImageTagFilters CFN parameter to determine if the image from the event matches any of the filters. If there is a match, the Fil tering Lambda invokes the SOCIIndexBuilder Lambda with the image's descriptor. 5. SOCIIndexBuilder Lambda generates SOCI index artifacts and uploads them into ECR. Note : Auto -Index generation currently supports a max imum compressed image size of 6 GiB. Q7. How do I know that an index has been generated for my image? ECR sends an EventBridge notification once a container image is successfully associated with index arti fact. You can also see the ‘Image I ndex ’ and the ‘Other’ artifact s associated to the image using ECR DescribeImages API and on the ECR Console. When a contain er image is deleted from ECR, the associated index artifact s will also be automatically deleted. Note: It is possible to push and associate multiple indices to the same image , although 

In [110]:
# Limit the number of total chunks to 4000
MAX_DOCS = 4000
if len(docs) > MAX_DOCS:
    docs = docs[:MAX_DOCS]

### Before populating the vector store, compute embeddings to validate that the endpoint responds without exceptions
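One way to run such a check is sketched below. The helper `validate_embeddings` and the stub `fake_embed` are hypothetical and exist only for illustration; the stub stands in for a real endpoint call so the example runs without AWS access.

```python
from typing import Callable, List, Sequence


def validate_embeddings(
    embed_fn: Callable[[List[str]], List[List[float]]],
    samples: Sequence[str],
) -> int:
    """Embed a few samples and return the embedding dimension, raising on inconsistencies."""
    vectors = embed_fn(list(samples))
    if len(vectors) != len(samples):
        raise ValueError(f"expected {len(samples)} embeddings, got {len(vectors)}")
    dimensions = {len(vector) for vector in vectors}
    if len(dimensions) != 1 or 0 in dimensions:
        raise ValueError(f"inconsistent or empty embeddings: dimensions={dimensions}")
    return dimensions.pop()


def fake_embed(texts: List[str]) -> List[List[float]]:
    """Hypothetical stand-in for an embeddings endpoint: returns 3-dimensional vectors."""
    return [[float(len(text)), 1.0, 0.0] for text in texts]


dimension = validate_embeddings(fake_embed, ["first chunk", "second chunk"])
print(f"embedding dimension: {dimension}")
```

With a LangChain embeddings object such as the one created later in this notebook, `embeddings.embed_documents` could be passed as `embed_fn` on a handful of chunks before the full load.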

### Read credentials from AWS Secrets Manager
The credentials for the OpenSearch cluster are stored in AWS Secrets Manager. Our code reads the credentials from there and provides them to the opensearch-py package (through the LangChain API).

In [111]:
%%writefile credentials.py

"""
Retrieve credentials password for given username from AWS SecretsManager
"""
import json
import boto3

def get_credentials(secret_id: str, region_name: str) -> dict:
    
    client = boto3.client('secretsmanager', region_name=region_name)
    response = client.get_secret_value(SecretId=secret_id)
    secrets_value = json.loads(response['SecretString'])    
    
    return secrets_value

Overwriting credentials.py


In [112]:
import time
import json
import logging
from typing import List
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

logger = logging.getLogger(__name__)

# extend the SagemakerEndpointEmbeddings class from langchain to provide a custom embedding function
class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    def embed_documents(
        self, texts: List[str], chunk_size: int = 5
    ) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size defines how many input texts will
                be grouped together as request. If None, will use the
                chunk size specified by the class.

        Returns:
            List of embeddings, one for each text.
        """
        results = []
        _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
        st = time.time()
        for i in range(0, len(texts), _chunk_size):
            response = self._embedding_func(texts[i:i + _chunk_size])
            results.extend(response)
        time_taken = time.time() - st
        logger.info(f"got results for {len(texts)} in {time_taken}s, length of embeddings list is {len(results)}")
        return results


# class for serializing/deserializing requests/responses to/from the embeddings model
class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs={}) -> bytes:

        input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
        return input_str.encode('utf-8') 

    def transform_output(self, output: bytes) -> List[List[float]]:

        response_json = json.loads(output.read().decode("utf-8"))
        embeddings = response_json["embedding"]
        if len(embeddings) == 1:
            return [embeddings[0]]
        return embeddings
    

def create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name: str, aws_region: str) -> SagemakerEndpointEmbeddingsJumpStart:
    # all set to create the objects for the ContentHandler and 
    # SagemakerEndpointEmbeddingsJumpStart classes
    content_handler = ContentHandler()

    # note the name of the LLM Sagemaker endpoint, this is the model that we would
    # be using for generating the embeddings
    embeddings = SagemakerEndpointEmbeddingsJumpStart( 
        endpoint_name=embeddings_model_endpoint_name,
        region_name=aws_region, 
        content_handler=content_handler
    )
    return embeddings

Next, we create the embeddings object and generate the document embeddings in batches.

In [113]:
#embeddings = SagemakerEndpointEmbeddingsJumpStart(endpoint_name="hf-textgeneration1-gpt-j-6b-fp16-2023-07-24-03-32-58-243", region_name=aws_region, content_handler=content_handler)
!echo $ELLM
embeddings = create_sagemaker_embeddings_from_js_model(ELLM, aws_region)

jumpstart-dft-ga-textembedding-gpt-j-6b-fp16


### Create embeddings of your documents to get ready for semantic search

In [114]:
from credentials import get_credentials
from langchain.vectorstores import OpenSearchVectorSearch


creds = get_credentials(opensearch_secretid, aws_region)
http_auth = (creds['username'], creds['password'])
opensearch_index_name = "genai9-index"




docsearch = OpenSearchVectorSearch.from_texts(
    index_name=opensearch_index_name,
    texts=[d.page_content for d in docs],
    metadatas=[d.metadata for d in docs],
    embedding=embeddings,
    opensearch_url=opensearch_domain_endpoint,
    http_auth=http_auth,
    bulk_size=4000,
)

#text_splitter = RecursiveCharacterTextSplitter(
        # Set a really small chunk size, just to show.
#        chunk_size=1600,
#        chunk_overlap=200,
#        length_function=len,
#    )

#for doc in data:
#    doc.metadata['timestamp'] = time.time()
#    doc.metadata['embeddings_model'] = ELLM
    
#chunks = text_splitter.create_documents([doc.page_content for doc in data], metadatas=[doc.metadata for doc in data])


#docsearch = OpenSearchVectorSearch.from_documents(documents=docs,embedding=embeddings,opensearch_url=opensearch_domain_endpoint,http_auth=http_auth)                                                            
                                                              


2023-07-26 16:56:09,866,3956364129,MainProcess,INFO,got results for 15 in 1.499828577041626s, length of embeddings list is 15
2023-07-26 16:56:10,300,base,MainProcess,INFO,GET https://search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com:443/genai9-index [status:200 request:0.433s]
2023-07-26 16:56:10,406,base,MainProcess,INFO,POST https://search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com:443/_bulk [status:200 request:0.062s]
2023-07-26 16:56:10,440,base,MainProcess,INFO,POST https://search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com:443/_bulk [status:200 request:0.023s]
2023-07-26 16:56:10,463,base,MainProcess,INFO,POST https://search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com:443/genai9-index/_refresh [status:200 request:0.022s]


## Question answering over Documents 

So far, we have chunked a large document into smaller ones, created vector embeddings, and stored them in an OpenSearch vector database. Now, we can answer questions over this document data.

Since we have created an index over the data, we can do a semantic search over the documents; this way only the most relevant documents to answer the question are passed via the prompt to the Large Language Model (LLM). You save both time and money by not passing all the documents to the LLM.

We use LangChain's **question_answering** `stuff` document chain in this example. For further details on document chains, see the LangChain [documentation](https://python.langchain.com/docs/modules/chains/document/).
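Conceptually, a `stuff` chain places all retrieved documents into a single prompt and sends that prompt to the LLM in one call. The sketch below illustrates the idea in plain Python; the template wording and the helper `stuff_prompt` are assumptions made for illustration and are not LangChain's actual prompt or API.

```python
from typing import List

# Hypothetical prompt template; the wording used by LangChain may differ
PROMPT_TEMPLATE = (
    "Use the following pieces of context to answer the question at the end.\n\n"
    "{context}\n\n"
    "Question: {question}\n"
    "Answer:"
)


def stuff_prompt(documents: List[str], question: str) -> str:
    """Concatenate every retrieved document into one context block and fill the template."""
    context = "\n\n".join(documents)
    return PROMPT_TEMPLATE.format(context=context, question=question)


retrieved = [
    "SOCIIndexBuilder Lambda generates SOCI index artifacts and uploads them into ECR.",
    "ECR sends an EventBridge notification once an image is associated with an index artifact.",
]
prompt = stuff_prompt(retrieved, "Where are SOCI index artifacts uploaded?")
print(prompt)
```

Because every document goes into one prompt, this approach only works when the retrieved chunks fit within the model's context window, which is one reason the chunk size chosen earlier matters.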

In [115]:
from typing import Dict

from langchain import PromptTemplate, SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains.question_answering import load_qa_chain
import json

parameters = {
    "do_sample": True,
    "top_p": 0.95,
    "temperature": 1e-10,

}

class SageMakerLLMContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # input_str = json.dumps({prompt: prompt, **model_kwargs})
        input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        # return response_json[0]["generated_text"]
        return response_json['generated_texts'][0]
    
custom_attributes = "accept_eula=true"    


sagemaker_llm_content_handler= SageMakerLLMContentHandler()

chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name=LLM,
        # credentials_profile_name="credentials-profile-name",
        region_name=aws_region,
        model_kwargs={"temperature": 1e-10},
        content_handler=sagemaker_llm_content_handler,
    ),
    chain_type="stuff"
)

In [116]:
query = "What is SOCI?"
ss_docs = docsearch.similarity_search(query, include_metadata=True)
chain.run(input_documents=ss_docs, question=query)

2023-07-26 16:56:10,663,base,MainProcess,INFO,POST https://search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com:443/genai9-index/_search [status:200 request:0.034s]


'indices'

In [117]:
for person in ['sherlock', 'ettie']:
    for query_template in [
                    "How old is {PERSON}?",
                    "What is {PERSON} current position and what is the name of the organization he/she currently works for?"
                 ]:
    
        query = query_template.format(PERSON=person)
        print('Q:', query)

        sim_docs = docsearch.similarity_search(query, include_metadata=True)
        answer = chain.run(input_documents=sim_docs, question=query)    
        print('A:', answer)
        print('\n---\n')

2023-07-26 16:56:11,062,base,MainProcess,INFO,POST https://search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com:443/genai9-index/_search [status:200 request:0.019s]


Q: How old is sherlock?


2023-07-26 16:56:11,347,base,MainProcess,INFO,POST https://search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com:443/genai9-index/_search [status:200 request:0.019s]


A: not enough information

---

Q: What is sherlock current position and what is the name of the organization he/she currently works for?


2023-07-26 16:56:11,630,base,MainProcess,INFO,POST https://search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com:443/genai9-index/_search [status:200 request:0.020s]


A: not enough information

---

Q: How old is ettie?


2023-07-26 16:56:11,916,base,MainProcess,INFO,POST https://search-opensearchservi-orshpbtsh2xx-doerunuins3frn4m3dl33j5w5a.us-east-1.es.amazonaws.com:443/genai9-index/_search [status:200 request:0.019s]


A: not enough information

---

Q: What is ettie current position and what is the name of the organization he/she currently works for?
A: not enough information

---

