## Retrieval Augmented Question & Answering with SageMaker Jumpstart Foundation Model using LangChain and Amazon OpenSearch Serverless

### Context
Previously, we showed how to build a movie assistant RAG chatbot using Knowledge Bases for Amazon Bedrock. In this notebook, we explore another option: building a RAG chatbot with an LLM (Meta Llama 2) hosted on Amazon SageMaker through SageMaker JumpStart.


### Architecture
![qna-rag](images/langchain-sagemaker-qa-rag.png)

### Challenges
- How to manage large document(s) that could potentially exceed the token limit
- How to find the document(s) relevant to the question being asked

### Proposal
To address the above challenges, this notebook proposes the following strategy:

#### Prepare documents
![Embeddings](./images/embeddings_lang.png)

Before the questions can be answered, the documents must be processed and stored in a document store index:
- Load the documents
- Process and split them into smaller chunks
- Create a numerical vector representation of each chunk using the Amazon Titan Embeddings model on Amazon Bedrock
- Create an index using the chunks and the corresponding embeddings
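As a rough illustration of the splitting step above, here is a minimal sketch that cuts text into overlapping character windows. The helper name `split_into_chunks` and the `chunk_size` and `overlap` values are assumptions made for this example and are not taken from the notebook; a real pipeline would more likely use a text splitter from a library and split on sentence or token boundaries.

```python
from typing import List


def split_into_chunks(text: str, chunk_size: int = 200, overlap: int = 50) -> List[str]:
    """Split text into overlapping character windows (hypothetical helper)."""
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")
    step = chunk_size - overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the current window already reached the end of the text
    return chunks


sample = "word " * 100  # 500 characters of stand-in text
chunks = split_into_chunks(sample, chunk_size=200, overlap=50)
print(len(chunks), [len(c) for c in chunks])
```

The overlap means the tail of one chunk is repeated at the head of the next, which helps keep a sentence that straddles a boundary retrievable from at least one chunk.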
  
#### Ask question
![Question](./images/chatbot_lang.png)

Once the document index is prepared, you are ready to ask questions, and the relevant documents are fetched based on the question being asked. The following steps are executed:
- Create an embedding of the input question
- Compare the question embedding with the embeddings in the index
- Fetch the (top N) relevant document chunks
- Add those chunks as part of the context in the prompt
- Send the prompt to the model hosted in SageMaker
- Get the contextual answer based on the documents retrieved
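The steps above can be sketched without any external service. The example below is only an illustration: the three-dimensional vectors are made-up stand-ins for real embeddings, the helper names are hypothetical, and the prompt layout is simplified. In the notebook itself, the vector store and the LangChain chain perform these steps.

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_n_chunks(question_vec: List[float],
                 index: Dict[str, List[float]],
                 n: int = 2) -> List[Tuple[str, float]]:
    """Score every chunk against the question and keep the n best."""
    scored = [(chunk, cosine_similarity(question_vec, vec)) for chunk, vec in index.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:n]


# Made-up chunks and toy "embeddings" (assumptions for illustration only)
toy_index = {
    "Toy Story: toys come to life.": [0.9, 0.1, 0.0],
    "Heat: a heist crew in Los Angeles.": [0.0, 0.2, 0.9],
    "Jumanji: a magical board game.": [0.6, 0.6, 0.1],
}
question_vec = [1.0, 0.2, 0.0]  # stand-in for the embedded question

hits = top_n_chunks(question_vec, toy_index, n=2)
context = "\n".join(chunk for chunk, _ in hits)
prompt = f"Context:\n{context}\n\nUser question: Which movies involve toys?"
print(prompt)
```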

## Use case
#### Dataset
To explain this architecture pattern, we use a few documents from the MovieLens dataset. These documents cover topics such as:
- Movie synopsis
- Release dates
- Cast members
  

#### Persona
Let's assume a persona of a user who is looking for information about movies/shows. 

The model will try to answer from the documents in plain language.

In [None]:
%pip install opensearch-py==2.4.2 langchain==0.1.9 boto3 lark -q

In [None]:
%pip install ipywidgets==8.0.4 -q

# LangChain Integration
<img src="images/langchain-logo.png" alt="langchain" style="width: 400px;"/>
LangChain is a framework for developing applications powered by LLMs. At a high level, LangChain enables applications that are:

* Data-aware: connect a language model to other sources of data
* Agentic: allow a language model to interact with its environment

The main advantages of using LangChain are:

* Provides framework abstractions for working with language models, along with a collection of implementations for each abstraction.
* Modular design promotes the flexibility to use any LangChain components to build an application.
* Provides many off-the-shelf chains that make it easy to get started.

LangChain also has robust SageMaker support. In this workshop, we'll use the following LangChain component to integrate with the LLM deployed in SageMaker and build a simple Q&A application.


* [Langchain SageMaker Endpoint](https://python.langchain.com/docs/integrations/providers/sagemaker_endpoint)

Setting up environment

In [None]:
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever

In [None]:
import boto3
import uuid
import json
import time
import os
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
import glob
from langchain.schema import Document
from langchain_community.vectorstores import OpenSearchVectorSearch
from typing import Any, Dict, List, Optional

In [None]:
random_id = str(uuid.uuid4().hex)[:5]
vectordb_name="sm-llm-vector-db"
vector_store_name = f'{vectordb_name}-{random_id}'
index_name = f"{vectordb_name}-index-{random_id}"
encryption_policy_name = f"{vectordb_name}-sp-{random_id}"
network_policy_name = f"{vectordb_name}-np-{random_id}"
access_policy_name = f"{vectordb_name}-ap-{random_id}"
kb_role_name = f"{vectordb_name}-role-{random_id}"
knowledge_base_name = f"{vectordb_name}-{random_id}"

## Data Preparation
In the following section, we prepare our knowledge base store using an Amazon OpenSearch Serverless collection. The dataset is provided in the 'data' folder of this project and is ready to be ingested. We'll use the LangChain framework to simplify the data ingestion process.
The main steps of the data ingestion workflow are:

1. Create an OpenSearch Serverless collection
2. Prepare documents from the data folder
3. Create an embedding model to convert the texts into vector embeddings
4. Ingest the embeddings into the collection
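Creating a collection (step 1) is asynchronous, so the code has to wait until the collection becomes active before ingesting anything. As a hedged sketch of that waiting pattern, the hypothetical helper below polls a status function, stops on `ACTIVE`, and raises on `FAILED` or on a timeout. The helper name, the timeout, and the fake status source are assumptions for illustration; in practice the status would come from the OpenSearch Serverless client.

```python
import time
from typing import Callable


def wait_for_status(get_status: Callable[[], str],
                    timeout_s: float = 600.0,
                    poll_s: float = 10.0) -> None:
    """Poll get_status() until it returns 'ACTIVE'; raise on 'FAILED' or timeout."""
    deadline = time.monotonic() + timeout_s
    while True:
        status = get_status()
        if status == "ACTIVE":
            return
        if status == "FAILED":
            raise RuntimeError("collection creation failed")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"collection still {status} after {timeout_s} seconds")
        time.sleep(poll_s)


# Stand-in for a real status lookup (assumption for this example)
fake_statuses = iter(["CREATING", "CREATING", "ACTIVE"])
wait_for_status(lambda: next(fake_statuses), timeout_s=5, poll_s=0.01)
print("collection is active")
```

Raising on `FAILED` surfaces a broken collection immediately instead of letting a later ingestion call fail with a less obvious error.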

In [None]:
def create_opensearch_serverless_collection(vector_store_name, 
                                            index_name, 
                                            encryption_policy_name, 
                                            network_policy_name, 
                                            access_policy_name):
    identity = boto3.client('sts').get_caller_identity()['Arn']

    aoss_client = boto3.client('opensearchserverless')

    security_policy = aoss_client.create_security_policy(
        name = encryption_policy_name,
        policy = json.dumps(
            {
                'Rules': [{'Resource': ['collection/' + vector_store_name],
                'ResourceType': 'collection'}],
                'AWSOwnedKey': True
            }),
        type = 'encryption'
    )

    network_policy = aoss_client.create_security_policy(
        name = network_policy_name,
        policy = json.dumps(
            [
                {'Rules': [{'Resource': ['collection/' + vector_store_name],
                'ResourceType': 'collection'}],
                'AllowFromPublic': True}
            ]),
        type = 'network'
    )

    collection = aoss_client.create_collection(name=vector_store_name,type='VECTORSEARCH')

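    # Wait for the collection to become active; stop early if creation fails
    while True:
        status = aoss_client.list_collections(collectionFilters={'name':vector_store_name})['collectionSummaries'][0]['status']
        if status == 'ACTIVE': break
        if status == 'FAILED': raise RuntimeError(f"Collection {vector_store_name} failed to create")
        time.sleep(10)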

    access_policy = aoss_client.create_access_policy(
        name = access_policy_name,
        policy = json.dumps(
            [
                {
                    'Rules': [
                        {
                            'Resource': ['collection/' + vector_store_name],
                            'Permission': [
                                'aoss:CreateCollectionItems',
                                'aoss:DeleteCollectionItems',
                                'aoss:UpdateCollectionItems',
                                'aoss:DescribeCollectionItems'],
                            'ResourceType': 'collection'
                        },
                        {
                            'Resource': ['index/' + vector_store_name + '/*'],
                            'Permission': [
                                'aoss:CreateIndex',
                                'aoss:DeleteIndex',
                                'aoss:UpdateIndex',
                                'aoss:DescribeIndex',
                                'aoss:ReadDocument',
                                'aoss:WriteDocument'],
                            'ResourceType': 'index'
                        }],
                    'Principal': [identity],
                    'Description': 'Easy data policy'}
            ]),
        type = 'data'
    )
    collection_id = collection['createCollectionDetail']['id']
    collection_arn = collection['createCollectionDetail']['arn']
    # Use the client's region so the host is valid even when AWS_DEFAULT_REGION is not set
    host = collection_id + '.' + aoss_client.meta.region_name + '.aoss.amazonaws.com'
    return host, collection_id, collection_arn

In [None]:
host, collection_id, collection_arn = create_opensearch_serverless_collection(vector_store_name,
                                                                              index_name,
                                                                              encryption_policy_name,
                                                                              network_policy_name,
                                                                              access_policy_name)

## Create documents and ingest them into the OpenSearch Serverless collection
First, we iterate through the files in the 'data' folder and create a Document object for each txt file.
Then we feed the documents to OpenSearch Serverless for ingestion using an `OpenSearchVectorSearch` object provided by the LangChain framework.
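To make the parsing step concrete, here is a minimal sketch of turning "label: value" lines into metadata. The sample text, its field labels, and the helper `parse_fields` are assumptions for illustration, since the actual files in the 'data' folder are not shown here. Splitting on the first colon only keeps values that themselves contain a colon intact.

```python
from typing import Dict

# Hypothetical file content; the real files may use different labels
sample_text = """movie_id: 1
title: Example Movie: The Sequel
genres: Animation,Comedy
rating: nan
"""


def parse_fields(text: str) -> Dict[str, str]:
    """Parse 'label: value' lines into a dict, splitting on the first colon only."""
    fields = {}
    for line in text.splitlines():
        if ":" not in line:
            continue  # skip lines that are not in label: value form
        key, value = line.split(":", 1)
        fields[key.strip()] = value.strip()
    return fields


fields = parse_fields(sample_text)
# Treat a missing rating ("nan") as 0.0 so the metadata stays numeric
rating = 0.0 if fields["rating"] == "nan" else float(fields["rating"])
metadata = {
    "movie_id": fields["movie_id"],
    "rating": rating,
    "genres": [g.strip() for g in fields["genres"].split(",")],
}
print(fields["title"], metadata)
```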

In [None]:
docs = []
for file in glob.glob("data/*.txt"): 
    with open(file, "r") as f:
        lines = f.readlines()
    # Each line is "label: value"; split on the first colon only so values containing ":" stay intact
    movie_id = lines[0].split(":", 1)[1].strip()
    title = lines[1].split(":", 1)[1].strip()
    genres = lines[2].split(":", 1)[1].strip()
    spoken_languages = lines[3].split(":", 1)[1].strip()
    release_date = lines[4].split(":", 1)[1].strip()
    rating = lines[5].split(":", 1)[1].strip()
    # Missing ratings appear as "nan"; store them as 0 so the metadata stays numeric
    if rating == "nan":
        rating = "0"
    cast = lines[6].split(":", 1)[1].strip()
    overview = lines[7].split(":", 1)[1].strip()
    doc = Document(
        page_content=f"{''.join(lines)}",
        metadata={
            "movie_id": movie_id,
            "rating": float(rating),
            "genres": genres.split(","),
            "spoken_languages": spoken_languages.split(","),
            "release_date": release_date,
            "cast" : cast.split(",")}
        )
    docs.append(doc)

In [None]:
boto3_credentials = boto3.Session().get_credentials() # needed for authenticating against opensearch cluster for index creation
region = boto3.client("sts").meta.region_name
service = "aoss"
auth = AWSV4SignerAuth(boto3_credentials, region, service)

Define an embedding model and an LLM. In our example, we use the Amazon Titan Embeddings model as the embedding model and the Llama2-7b chat model hosted in Amazon SageMaker as the LLM.

In [None]:
from langchain_community.embeddings import BedrockEmbeddings

embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")

In [None]:
vectorstore = OpenSearchVectorSearch.from_documents(
    docs,
    embeddings,
    index_name="opensearch-self-query-demo",
    opensearch_url=f"{host}:443",
    http_auth = auth,
    use_ssl = True,
    verify_certs = True,
    connection_class = RequestsHttpConnection,
    timeout = 100,
    engine="faiss")

Let's validate the vectorDB by using a vector store retriever.

In [None]:
relevant_documents = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3}).get_relevant_documents("I want to watch an action movie with friendship and murder")

In [None]:
for doc in relevant_documents:
    print(doc.page_content)

In [None]:
def format_instructions(instructions: List[Dict[str, str]]) -> str:
    """Format instructions where conversation roles must alternate user/assistant/user/assistant/..."""
    prompt: List[str] = []
    for user, answer in zip(instructions[::2], instructions[1::2]):
        prompt.extend(["<s>", "[INST] ", (user["content"]).strip(), " [/INST] ", (answer["content"]).strip(), "</s>"])
    prompt.extend(["<s>", "[INST] ", (instructions[-1]["content"]).strip(), " [/INST] "])
    return "".join(prompt)

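Define a ContentHandler class for the LangChain LLM integration. It converts a prompt into the request payload for the endpoint and extracts the generated text from the response.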
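To show what such a handler does, the sketch below runs the two transformations without LangChain or a live endpoint. It is a simplified illustration under stated assumptions: the function names are hypothetical, the prompt is passed directly as the `inputs` string, and the response shape (a list with a `generated_text` field) mirrors what the handler in this notebook reads. The actual request and response formats depend on the model container that is deployed.

```python
import io
import json
from typing import Dict


def build_request(prompt: str, model_kwargs: Dict) -> bytes:
    """Serialize the prompt and generation parameters into a JSON request body."""
    payload = {"inputs": prompt, "parameters": dict(model_kwargs)}
    return json.dumps(payload).encode("utf-8")


def parse_response(body) -> str:
    """Read a file-like response body and return the generated text."""
    response_json = json.loads(body.read().decode("utf-8"))
    return response_json[0]["generated_text"]


request = build_request("Who is Woody?", {"max_new_tokens": 64})

# Mocked endpoint response (assumption: list of objects with "generated_text")
fake_body = io.BytesIO(json.dumps([{"generated_text": "mocked answer text"}]).encode("utf-8"))

print(json.loads(request))
print(parse_response(fake_body))
```

The response body is read with `.read()` because the endpoint returns a stream rather than raw bytes.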

In [None]:
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler

class SMLLMContentHandler(LLMContentHandler):
    """Convert between LangChain prompts and the endpoint's JSON payloads."""
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Wrap the prompt in a system/user dialog and serialize it as the "inputs" string
        input_data = json.dumps([[{"role" : "system", "content" : "You are a movie assistant."},
                                  {"role" : "user", "content" : prompt}]])
        input_str = json.dumps({"inputs" : input_data, "parameters" : {**model_kwargs}})
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        # The response body is a stream: read it, decode it, and return the generated text
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]

TODO: Need to add screenshots for creating the SageMaker endpoint from Jumpstart

In [None]:
llm_endpoint_name = "jumpstart-llama2-7b-chat"
llm_inference_component_name = "meta-textgeneration-llama-2-7b-f-20240223-235028"

In [None]:
from langchain_community.llms import SagemakerEndpoint

region_name = "us-east-1"
model_params = { 
                    "do_sample": True,
                    "top_p": 0.9,
                    "temperature": 0.1,
                    "max_new_tokens": 1000,
                    "stop": ["<|endoftext|>", "</s>"],
                    "repetition_penalty": 1.1
               }

llm = SagemakerEndpoint(
    endpoint_name=llm_endpoint_name,
    region_name=region_name,
    content_handler = SMLLMContentHandler(),
    model_kwargs = model_params,
    endpoint_kwargs = {"InferenceComponentName" : llm_inference_component_name})

In [None]:
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory

prompt_template = """Given the following context and conversation history:

Context: 
{context}


Conversation History: 
{chat_history}

Answer the question as truthfully as possible. Your answer must come only from the context given above. If the answer is not found in the given context, say "I don't know".
Your answer must be a direct, concise summary. It is critical that you use only the context given to you when answering the question.

User question: {question}""".strip()

PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["chat_history", "context", "question"]
)

memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer')


qa = ConversationalRetrievalChain.from_llm(llm=llm,
    memory=memory,
    retriever=vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": 3}), 
    return_source_documents=True,
    combine_docs_chain_kwargs={"prompt": PROMPT})


In [None]:
query = "Which animation movies contain toys in the plot?"
result = qa({"question": query})
print(result['answer'])

In [None]:
query = "Who is Woody?"
result = qa({"question" : query})
print(result['answer'])

In [None]:
for idx, doc in enumerate(result['source_documents']):
    print(f"=== doc {idx+1} ====")
    print(doc.page_content.replace("\n", " "))

In [None]:
import ipywidgets as ipw
from IPython.display import display, clear_output

class ChatUX:
    """ A chat UX using IPWidgets
    """
    def __init__(self):
        memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer')
        self.qa = ConversationalRetrievalChain.from_llm(llm=llm,
            memory=memory,
            retriever=vectorstore.as_retriever(
                search_type="similarity", search_kwargs={"k": 3}), 
            return_source_documents=True,
            combine_docs_chain_kwargs={"prompt": PROMPT})

        self.name = None
        self.b=None
        self.out = ipw.Output()

    def start_chat(self):
        print("Let's chat!")
        display(self.out)
        self.chat(None)

    def chat(self, _):
        if self.name is None:
            prompt = ""
        else:
            prompt = self.name.value
        if prompt in ('q', 'Q', 'quit'):
            print("Thank you, that was a nice chat!")
            return
        elif len(prompt) > 0:
            with self.out:
                thinking = ipw.Label(value=f"Thinking...")
                display(thinking)
                try:
                    response = self.qa({"question" : prompt})
                    result = response['answer']

                except Exception as e:
                    print(e)
                    result = "No answer"
                thinking.value=""
                print(f"AI: {result}")
                self.name.disabled = True
                self.b.disabled = True
                self.name = None

        if self.name is None:
            with self.out:
                self.name = ipw.Text(description="You: ", placeholder='q to quit')
                self.b = ipw.Button(description="Send")
                self.b.on_click(self.chat)
                display(ipw.Box(children=(self.name, self.b)))

## Sample Questions
* What's the movie "Jumanji" all about?
* When was this movie released?
* Who were the actors in this movie?
* What movie would you recommend I watch after this one?
* What other movies were released in the same year?

In [None]:
chat = ChatUX()
chat.start_chat()