# Retrieval Augmented Generation and Chatbot Application

Many use cases, such as building a chatbot, require text generation (text2text) models to respond to user questions with insightful answers. Leading LLMs pick up a lot of general knowledge during training, but we often need to ingest and use a large library of more specific information.

In this notebook we demonstrate how to use an LLM to answer questions using a library of documents as a reference, by means of document embeddings and retrieval. The embeddings are generated by a Hugging Face embedding model.



## 1. Quick introduction to LangChain and why it is useful for RAG-based applications

LangChain is a framework for developing applications powered by language models. 

**LLMs are powerful by themselves. Why do we need libraries like LangChain?**

While LLMs are powerful, they are also general in nature. They perform many tasks effectively, but they cannot give specific answers to questions or tasks that require deep domain knowledge or expertise. For example, imagine you want to use an LLM to answer questions about a specific field, like medicine or law. The LLM may be able to answer general questions about the field, but it may not be able to provide the more detailed or nuanced answers that require specialized knowledge.

To work around this limitation, LangChain offers a useful approach: the corpus of text is preprocessed by breaking it down into chunks or summaries, the chunks are embedded in a vector space, and similar chunks are searched for when a question is asked. LangChain also provides a level of abstraction that makes this pattern easy to use. It is an open source library that has grown quickly in popularity since it was first introduced, and it is constantly evolving.

**How do we solve this problem?**

RAG (Retrieval Augmented Generation) is a design pattern that enterprise customers can use to bring in their domain context and artifacts securely while leveraging LLMs to answer questions.

### 1.1 Key components of LangChain

Let us examine the key components of LangChain. These components are, in increasing order of complexity:

#### Models

Building blocks to interface with any language model. LangChain does not serve its own LLMs, but rather provides a standard interface for interacting with many different LLMs, including integration with LLMs hosted on a SageMaker endpoint.

<img src='./images/models.png' max-width ="1080"/>

    
#### Prompts

The new way of programming models is through prompts. A "prompt" refers to the input to the model. This input is rarely hard coded, but rather is often constructed from multiple components. A PromptTemplate is responsible for the construction of this input. LangChain provides several classes and functions to make constructing and working with prompts easy, such as:

- Prompt templates: Parameterized model inputs
- Example selectors: Dynamically select examples to include in prompts
    
<img src="images/prompt.png" max-width="1080"/>
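To make the two ideas above concrete, here is a minimal sketch of a parameterized prompt and a naive example selector. It uses only the Python standard library rather than LangChain's own classes, and the names `EXAMPLES`, `select_examples`, and `render_prompt`, as well as the example data, are hypothetical.

```python
from string import Template

# Hypothetical few-shot examples; in practice these would come from your own data.
EXAMPLES = [
    {"question": "What is Amazon S3?", "answer": "An object storage service."},
    {"question": "What is Amazon EC2?", "answer": "A service for virtual servers."},
    {"question": "What is a vector store?", "answer": "A database for embeddings."},
]

# The template is the fixed part of the prompt; $examples and $question are filled in per request.
PROMPT = Template("$examples\n\nQuestion: $question\nAnswer:")


def select_examples(question, examples, k=2):
    """Pick the k examples whose questions share the most words with the user question."""
    q_words = set(question.lower().split())

    def overlap(example):
        return len(q_words & set(example["question"].lower().split()))

    return sorted(examples, key=overlap, reverse=True)[:k]


def render_prompt(question, examples, k=2):
    """Fill the template with the selected examples and the user question."""
    shots = "\n\n".join(
        f"Question: {e['question']}\nAnswer: {e['answer']}"
        for e in select_examples(question, examples, k)
    )
    return PROMPT.substitute(examples=shots, question=question)


print(render_prompt("What is Amazon SageMaker?", EXAMPLES))
```

Word overlap is only a stand-in here; a real example selector would more likely rank examples by embedding similarity.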

#### Indexes

Many LLM applications require user-specific data that is not part of the model's training set. Indexes refer to ways to structure documents so that LLMs can best interact with them. LangChain gives you the building blocks to load, transform, store and query your data via:

- Document loaders: Load documents from many different sources
- Document transformers: Split documents, convert documents into Q&A format, drop redundant documents, and more
- Text embedding models: Take unstructured text and turn it into a list of floating point numbers
- Vector stores: Store and search over embedded data
- Retrievers: Query your data
        
The primary index and retrieval types supported by LangChain are currently centered around vector databases, but it can also interact with structured data (SQL tables, etc) or external APIs.

<img src="images/vectorstore.png" max-width="1080"/>
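The following sketch shows, in plain Python, what a vector store and retriever do conceptually: embed each text, keep the vectors, and return the stored texts closest to the query. The bag-of-words `embed` function is a toy stand-in for a trained embedding model, and `ToyVectorStore` is a hypothetical class, not part of LangChain.

```python
import math
from collections import Counter


def embed(text):
    """Toy embedding: a bag-of-words count vector (a real system would use a trained model)."""
    return Counter(text.lower().split())


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a if t in b)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


class ToyVectorStore:
    """Hypothetical in-memory store: keeps (text, vector) pairs and searches by cosine similarity."""

    def __init__(self):
        self.items = []

    def add(self, text):
        self.items.append((text, embed(text)))

    def search(self, query, k=1):
        q = embed(query)
        ranked = sorted(self.items, key=lambda item: cosine(q, item[1]), reverse=True)
        return [text for text, _ in ranked[:k]]


store = ToyVectorStore()
store.add("managed spot training works with all sagemaker instances")
store.add("sagemaker studio is a web based ide")
store.add("faiss is a library for similarity search")
top = store.search("which instances support managed spot training", k=1)
print(top)
```

Libraries such as FAISS implement the same idea with dense vectors and much faster search.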

#### Memory

Most LLM applications have a conversational interface. Memory is the concept of storing and retrieving data over the course of a conversation.
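As a rough illustration of the idea, the sketch below keeps every turn of a conversation and replays it in front of the next user input. `BufferMemory` and `build_prompt` are hypothetical names, and the stored answer is a placeholder rather than real model output.

```python
class BufferMemory:
    """Hypothetical minimal memory: stores every turn and replays it as text."""

    def __init__(self):
        self.turns = []

    def save(self, user_input, ai_output):
        self.turns.append((user_input, ai_output))

    def load(self):
        return "\n".join(f"Human: {u}\nAI: {a}" for u, a in self.turns)

    def clear(self):
        self.turns = []


def build_prompt(memory, user_input):
    """Prepend the stored history so the model can resolve references such as 'it'."""
    history = memory.load()
    prefix = history + "\n" if history else ""
    return f"{prefix}Human: {user_input}\nAI:"


memory = BufferMemory()
memory.save("What is Indonesia?", "A country in Southeast Asia.")
print(build_prompt(memory, "What is it famous for?"))
```

Because the full history is replayed on every call, the prompt grows with the conversation; summarizing or truncating older turns is a common way to keep it within the model's token limit.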

#### Chains

Using an LLM in isolation is fine for simple applications, but more complex applications require chaining LLMs - either with each other or with other components.

LangChain provides the Chain interface for such "chained" applications. LangChain provides a standard interface for chains, lots of integrations with other tools, and end-to-end chains for common applications. 

We define a Chain very generically as a sequence of calls to components, which can include other chains. Chains allow us to combine multiple components together to create a single, coherent application. For example, we can create a chain that takes user input, formats it with a PromptTemplate, and then passes the formatted response to an LLM. We can build more complex chains by combining multiple chains together, or by combining chains with other components.

<img src="images/chains.png" max-width="1080"/>
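The definition of a chain as a sequence of calls, which can itself contain other chains, can be sketched in a few lines of plain Python. Everything here is hypothetical: `fake_llm` stands in for a real model call, and `make_chain` is not a LangChain function.

```python
def format_prompt(inputs):
    """Step 1: build the prompt from the user input."""
    return {**inputs, "prompt": f"Question: {inputs['question']}\nAnswer:"}


def fake_llm(inputs):
    """Step 2: stand-in for a model call; it only reports the prompt length."""
    return {**inputs, "text": f"(model output for a {len(inputs['prompt'])}-character prompt)"}


def strip_output(inputs):
    """Step 3: post-process the model output by removing the surrounding parentheses."""
    return {**inputs, "answer": inputs["text"].strip("()")}


def make_chain(*steps):
    """Compose steps into one callable; the result can be used as a step in another chain."""
    def run(inputs):
        for step in steps:
            inputs = step(inputs)
        return inputs
    return run


llm_chain = make_chain(format_prompt, fake_llm)
full_chain = make_chain(llm_chain, strip_output)  # a chain nested inside another chain
result = full_chain({"question": "What is RAG?"})
print(result["answer"])
```

Passing a dictionary from step to step keeps every intermediate value available to later steps, which is also how the chains in this notebook exchange inputs such as `question` and `input_documents`.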

#### Agents

The core idea of agents is to use an LLM to choose a sequence of actions to take. In chains, a sequence of actions is hardcoded (in code). In agents, a language model is used as a reasoning engine to determine which actions to take and in which order.
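The difference from a chain can be sketched as a loop in which a reasoning function, rather than the code, picks the next action. In the sketch below `fake_reasoner` is a hard-coded stand-in for the LLM, and the tools and glossary are hypothetical.

```python
def calculator(expression):
    """Tool: add the whitespace-separated numbers in the expression."""
    return str(sum(float(x) for x in expression.split()))


def lookup(term):
    """Tool: look a term up in a tiny hypothetical glossary."""
    glossary = {"rag": "Retrieval Augmented Generation"}
    return glossary.get(term.lower(), "unknown")


TOOLS = {"calculator": calculator, "lookup": lookup}


def fake_reasoner(task, observations):
    """Stand-in for the LLM: decides the next action from the task and what it has seen so far."""
    if observations:
        return ("finish", observations[-1])
    if any(ch.isdigit() for ch in task):
        return ("calculator", task)
    return ("lookup", task)


def run_agent(task, max_steps=3):
    """Ask the reasoner for an action, run the chosen tool, and repeat until it finishes."""
    observations = []
    for _ in range(max_steps):
        action, argument = fake_reasoner(task, observations)
        if action == "finish":
            return argument
        observations.append(TOOLS[action](argument))
    return observations[-1]


print(run_agent("2 3 4"))
print(run_agent("RAG"))
```

In a real agent the reasoner is an LLM prompted with the tool descriptions and previous observations, so the sequence of tool calls is not known in advance.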
    
#### Callbacks

It can be difficult to track all that occurs inside a chain or agent. Callbacks help add a level of observability and introspection.
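A minimal sketch of the callback idea, assuming a hypothetical handler interface with `on_start` and `on_end` hooks (LangChain's own callback classes have a richer interface):

```python
import time


class TimingCallback:
    """Hypothetical handler that records when each step starts and ends."""

    def __init__(self):
        self.events = []

    def on_start(self, name, inputs):
        self.events.append(("start", name, time.perf_counter()))

    def on_end(self, name, output):
        self.events.append(("end", name, time.perf_counter()))


def run_step(name, func, inputs, callbacks=()):
    """Run one step, notifying every callback before and after the call."""
    for cb in callbacks:
        cb.on_start(name, inputs)
    output = func(inputs)
    for cb in callbacks:
        cb.on_end(name, output)
    return output


timer = TimingCallback()
result = run_step("uppercase", str.upper, "hello", callbacks=[timer])
print(result, [(kind, name) for kind, name, _ in timer.events])
```

The step itself does not know it is being observed, which is what makes callbacks useful for logging, timing, and streaming without changing the chain logic.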
 
    

### 1.2 Chat Bot key elements

The first process in a chatbot is generating embeddings. Typically, an ingestion process runs your documents through an embedding model and stores the resulting embeddings in a vector store. In this example, a GPT-J embeddings model is used for this step.

<img src="images/Embeddings_lang.png" max-width="1080"/>

The second process is user request orchestration: handling the interaction, invoking the model, and returning the results.

<img src="images/Chatbot_lang.png" max-width="1080"/>

For processes that need deeper analysis or conversation history, we need to summarize every interaction to keep it succinct. The flow below illustrates this, using Pinecone as an example of the various tools that are available.

<img src="images/chatbot_internet.jpg" width="1080"/>

### 1.3 Key points for consideration

1. If we have long documents that exceed the LLM token limit, consider using the Chains interface to process them: Map Reduce, Refine, Map-Rerank
2. To optimize cost per token, minimize the number of tokens and send only relevant tokens to the model
3. Which model to use --
    - Cohere, AI21, Huggingface Hub, Manifest, Goose AI, Writer, Banana, Modal, StochasticAI, Cerebrium, Petals, Forefront AI, Anthropic, DeepInfra, and self-hosted Models.
    - Example LLM Cohere = Cohere(model='command-xlarge')
    - Example LLM Flan = HuggingFaceHub(repo_id="google/flan-t5-xl")
4. Input data sources could be PDF, WebPages, CSV, S3, EFS
5. Orchestration with external tasks
    - External tasks - Agent SerpApi, SEARCH Engines
    - Math calculator
6. Conversation management and history
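To illustrate the first point, here is a rough sketch of the Map Reduce idea for a document that is too long for one prompt: summarize each chunk separately, then summarize the combined partial summaries. `fake_summarize` is a stand-in for an LLM call and the word-based splitter is an assumption; real limits are counted in tokens.

```python
def split_into_chunks(text, max_words):
    """Split text into pieces of at most max_words words."""
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]


def fake_summarize(text, max_words=5):
    """Stand-in for an LLM summarization call: keep only the first few words."""
    return " ".join(text.split()[:max_words])


def map_reduce_summary(text, chunk_words=20, summarize=fake_summarize):
    """Map: summarize each chunk. Reduce: summarize the joined partial summaries."""
    partial = [summarize(chunk) for chunk in split_into_chunks(text, chunk_words)]
    return summarize(" ".join(partial))


long_text = " ".join(f"word{i}" for i in range(100))
print(map_reduce_summary(long_text))
```

Each call only ever sees one chunk or the short partial summaries, so no single prompt has to hold the whole document.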

## 2. Pre-Requisites

There are a few prerequisites to complete before running this notebook. The key one is setting up the LLM to be used.

### 2.1 Install the libraries needed for this run

These are listed in requirements.txt, or you can run the cells below to control precisely which libraries are installed.

In [None]:
!apt update

In [None]:
!apt install wkhtmltopdf -y

In [None]:
!pip install --upgrade pip

In [None]:
# !pip install chromadb==0.3.21 --quiet
!pip install langchain==0.0.161 boto3 html2text jinja2 --quiet
!pip install faiss-cpu==1.7.4 --quiet
!pip install pypdf==3.8.1 --quiet
!pip install transformers==4.24.0 --quiet
!pip install sentence_transformers==2.2.2
print("all libraries installed")

In [None]:
import sentence_transformers 
sentence_transformers.__version__

### 2.2 Import statements for our chain and indexers

In [None]:
#from aws_langchain.kendra_index_retriever import KendraIndexRetriever
from langchain.chains import ConversationalRetrievalChain
from langchain import SagemakerEndpoint
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.prompts import PromptTemplate
import sys
import json
import os
import time
import boto3
import sagemaker
from sagemaker.session import Session
from sagemaker.model import Model
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base
from typing import Any, Dict, List, Optional
import jinja2


In [None]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
role

In [None]:
%store -r endpoint_name

In [None]:
os.environ["LLM_ENDPOINT"]=endpoint_name
os.environ["REGION"]='us-west-2' # change this if needed
print(os.environ["LLM_ENDPOINT"])
print(os.environ["REGION"])

## 3. Topics covered in this lab:

In this notebook we will be covering the below topics:

1. **LLM** &#8594; Examine asking an LLM without providing context
1. **Prompt Engineering** &#8594; Improving the answer by providing insightful context
1. **RAG based approach** &#8594; Use vector DB and prompt template to build question answering application with Retrieval Augmented Generation (RAG) approach
1. **Chatbot** &#8594; Build an interactive chatbot with memory

### 3.1 LLM

To illustrate why we need a retrieval-augmented generation (RAG) approach to solve the question answering problem, let's ask the model a question directly and see how it responds.

Make sure that you have run the notebook `1_deploy-falcon.ipynb` and deployed the Falcon-7B model to a SageMaker endpoint.

In [None]:
# These are inference parameters. They are sent along with every request and control
# how the model generates text (for example, the maximum response length and how
# much randomness is used when choosing tokens).
# For this workshop, the parameters have been chosen for you.
# If you like, you can change some of them in the code below;
# doing so will affect the behavior of your LLM response.

parameters = {
    "max_new_tokens": 300,
    "num_return_sequences": 1,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": False,
    "temperature": 0.2,
    "stop": ['\n'],
    "return_full_text": False,  # return only the newly generated text, not the prompt
}

In [None]:
boto3_kwargs = {}
session = boto3.Session()

boto3_sm_client = boto3.client("sagemaker-runtime")
print(boto3_sm_client)

question = "Which instances can I use with Managed Spot Training in SageMaker?"
prompt = f"Answer this question below, {question}"
print(f"Question being asked is -- > {prompt}")

payload = {"inputs": prompt, "parameters": parameters}

payload = json.dumps(payload).encode('utf-8')

boto3_sm_client.invoke_endpoint(
    EndpointName=os.environ["LLM_ENDPOINT"],
    Body=payload,
    ContentType="application/json",
)["Body"].read().decode("utf8")
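The call above returns the raw response body as a JSON string. Assuming the endpoint replies in the `[{"generated_text": ...}]` shape that the content handler later in this notebook also expects, a small helper can pull out just the text. The function name `extract_generated_text` and the sample body are hypothetical.

```python
import json


def extract_generated_text(raw_body):
    """Return the generated text from a JSON response shaped like [{"generated_text": "..."}]."""
    parsed = json.loads(raw_body)
    first = parsed[0] if isinstance(parsed, list) and parsed else None
    if not isinstance(first, dict) or "generated_text" not in first:
        raise ValueError(f"Unexpected response shape: {raw_body!r}")
    return first["generated_text"].strip()


# Made-up response body for illustration, not real model output.
sample_body = '[{"generated_text": " All supported instance types."}]'
print(extract_generated_text(sample_body))
```

Failing loudly on an unexpected shape makes it easier to notice when a different model container returns a different response format.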

### 3.2 Prompt Engineering

To answer the question better, we provide extra contextual information, combine it with a prompt, and send it to the model together with the question. Below is an example.


In [None]:
question = "Which instances can I use with Managed Spot Training in SageMaker?"
context = """Managed Spot Training can be used with all instances supported in Amazon SageMaker. Managed Spot Training is supported in all AWS Regions where Amazon SageMaker is currently available."""

prompt = f"""Context: {context}\n\nQuestion: {question}\n\nAnswer:"""
print(f"Question being asked is -- > {prompt}")

payload = {"inputs": prompt, "parameters": parameters}

payload = json.dumps(payload).encode('utf-8')

boto3_sm_client.invoke_endpoint(
    EndpointName=os.environ["LLM_ENDPOINT"],
    Body=payload,
    ContentType="application/json",
)["Body"].read().decode("utf8")

The output is already significantly better than asking without any context.

Now the question becomes: where do we find relevant context for a given user query? The answer is to use a pre-stored knowledge base with retrieval augmented generation, as shown below.

### 3.3 RAG

We plan to use document embeddings to fetch the most relevant documents in our document knowledge library and combine them with the prompt that we provide to LLM.

To achieve that, we will do the following:

1. Generate embeddings for each document in the knowledge library with the Hugging Face embedding model.
2. Identify the top K most relevant documents based on the user query.
    - 2.1 For a query of interest, generate the embedding of the query using the same embedding model.
    - 2.2 Search for the indexes of the top K most relevant documents in the embedding space using in-memory FAISS search.
    - 2.3 Use the indexes to retrieve the corresponding documents.
3. Combine the retrieved documents with the prompt and question and send them to the SageMaker LLM.



Note: The retrieved document/text should be large enough to contain the information needed to answer a question, but small enough to fit into the LLM prompt.

<img src='./images/rag.jpg' max-width ="1080"/>
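The size trade-off in the note above can be sketched as a simple budget check: keep the most relevant documents, in rank order, until adding another would exceed what the prompt can hold. The function `fit_to_budget` is hypothetical, and character count is used only as a rough proxy; the real limit is measured in tokens.

```python
def fit_to_budget(ranked_docs, max_chars):
    """Keep top-ranked documents, in order, while their combined length stays within max_chars."""
    selected = []
    used = 0
    for doc in ranked_docs:
        if used + len(doc) > max_chars:
            break  # stop at the first document that would overflow the budget
        selected.append(doc)
        used += len(doc)
    return selected


ranked = ["a" * 120, "b" * 150, "c" * 200]  # hypothetical documents, most relevant first
kept = fit_to_budget(ranked, max_chars=300)
print([len(d) for d in kept])
```

Smaller chunks make it more likely that several relevant pieces fit into the budget, at the cost of each piece carrying less surrounding context.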

First, let's prepare by wrapping up our LLM into `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`. 

In [None]:
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
import ast

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


content_handler = ContentHandler()

sm_llm = SagemakerEndpoint(
    endpoint_name=os.environ["LLM_ENDPOINT"],
    region_name=os.environ["REGION"],
    model_kwargs=parameters,
    content_handler=content_handler,
)

print(f"SageMaker LLM created at {sm_llm}::")

Now, let's download the example data and prepare it for demonstration. We will use [Amazon SageMaker FAQs](https://aws.amazon.com/sagemaker/faqs/) as the knowledge library. The data is formatted as a CSV file with two columns, Question and Answer. We use the Answer column as the documents of the knowledge library, from which relevant documents are retrieved based on a query.

In [None]:
original_data = "s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/"

!mkdir -p rag_data
!aws s3 cp --recursive $original_data rag_data

In case your data is saved in multiple subsets, the following code reads all files that end with .csv and concatenates them. Please ensure each CSV file has the same format.

In [None]:
import glob
import os
import pandas as pd

all_files = glob.glob(os.path.join("rag_data/", "*.csv"))

df_knowledge = pd.concat(
    (pd.read_csv(f, header=None, names=["Question", "Answer"]) for f in all_files),
    axis=0,
    ignore_index=True,
)

df_knowledge.head(5)

Drop the `Question` column as it is not used in this workshop.

In [None]:
df_knowledge.drop(["Question"], axis=1, inplace=True)
df_knowledge.to_csv("rag_data/processed.csv", header=False, index=False)

Use LangChain to read the CSV data. LangChain has multiple built-in loaders for different file formats such as txt, html, and pdf. For details, see [LangChain document loaders](https://python.langchain.com/en/latest/modules/indexes/document_loaders.html).

In [None]:
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import Chroma, AtlasDB, FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders.csv_loader import CSVLoader

loader = CSVLoader(file_path="rag_data/processed.csv")
documents = loader.load()

Next, let's generate embeddings for our docs and store them in a LangChain VectorStore. Embeddings are a way to represent words, phrases or any other discrete items as vectors in a continuous vector space. This allows machine learning models to perform mathematical operations on these representations and capture semantic relationships between them. We'll start by initializing a Hugging Face embeddings model.

In [None]:
# Initialize the Huggingface Embeddings Model
from langchain.embeddings import HuggingFaceEmbeddings

hf_embeddings = HuggingFaceEmbeddings()

We will store and match the embeddings using the VectorStore indexer. In this notebook, we showcase [FAISS](https://github.com/facebookresearch/faiss), which is transient and held in memory.

In [None]:
index_creator = VectorstoreIndexCreator(
    vectorstore_cls=FAISS,
    embedding=hf_embeddings,
    text_splitter=CharacterTextSplitter(chunk_size=300, chunk_overlap=0),
)

In [None]:
index = index_creator.from_loaders([loader])

Now it's easy to pull context from our data store to answer a prompt. We can simply use the query method on the created index and pass the user's question and the SageMaker endpoint LLM. LangChain selects the top four closest documents (K=4) and passes the relevant context extracted from them to the LLM to generate a response.

In [None]:
question="Which instances can I use with Managed Spot Training in SageMaker?"
index.query(question=question, llm=sm_llm)

The response looks more accurate than the responses from the earlier approaches, which used either no context or static context that may not always be relevant.

#### RAG alternative approach

Alternatively, we can take a manual approach that allows more customization. This approach offers the flexibility to configure the top K parameter for the relevancy search in the documents. It also allows you to use the LangChain feature of [prompt templates](https://python.langchain.com/en/latest/modules/prompts/prompt_templates.html), which let you parameterize prompt creation instead of hard coding the prompts.

First, we generate embeddings for each document in the knowledge library.

In [None]:
docsearch = FAISS.from_documents(documents, hf_embeddings)

We then identify the top K documents most relevant to the user query, where K = 3 in this setup.

In [None]:
question="Which instances can I use with Managed Spot Training in SageMaker?"

docs = docsearch.similarity_search(question, k=3)
docs

Finally, we use a prompt template and chain it with the SageMaker LLM.

In [None]:
prompt_template = """{context}\n\nGiven the above context, answer the following question:\n{question}\nAnswer: """
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
chain = load_qa_chain(llm=sm_llm, prompt=PROMPT)

In [None]:
result = chain({"input_documents": docs, "question": question}, return_only_outputs=True)[
    "output_text"
]

result

With this RAG implementation, we were able to take advantage of the additional flexibility of LangChain prompt templates and customize the number of documents searched for a relevancy match using the top K parameter.
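Conceptually, the question answering chain used above places all retrieved documents into the `{context}` slot of the template before calling the LLM. The sketch below reproduces that step in plain Python; `stuff_documents` is a hypothetical helper, and the blank-line separator between documents is an assumption.

```python
PROMPT_TEMPLATE = "{context}\n\nGiven the above context, answer the following question:\n{question}\nAnswer: "


def stuff_documents(docs, question, template=PROMPT_TEMPLATE, separator="\n\n"):
    """Concatenate all retrieved documents into one context block and fill the template."""
    context = separator.join(docs)
    return template.format(context=context, question=question)


docs = [
    "Managed Spot Training can be used with all instances supported in Amazon SageMaker.",
    "Managed Spot Training is supported in all AWS Regions where Amazon SageMaker is currently available.",
]
print(stuff_documents(docs, "Which instances can I use with Managed Spot Training in SageMaker?"))
```

Because every retrieved document ends up in a single prompt, a larger K gives the model more context but also brings the prompt closer to the token limit.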

### 3.4 Chatbot

Chatbots need to remember previous interactions. Conversational memory allows us to do that. There are several ways to implement conversational memory. In the context of LangChain, they are all built on top of the ConversationChain.

Let's start with learning how to make a simple chatbot.

In [None]:
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory()
conversation = ConversationChain(
    llm=sm_llm, memory=memory
)


In [None]:
print(conversation.predict(input="What is Indonesia?"))

The model has responded. Now let's ask a follow-up question.

In [None]:
print(conversation.predict(input="What is it famous for?"))

We can see that the model can understand the previous conversation. Now you can clear the memory if you want to.

In [None]:
memory.clear()

Now let's build on the QA with RAG capability that we implemented before.

For our chatbot, we will use the [ConversationalRetrievalChain](https://api.python.langchain.com/en/latest/chains/langchain.chains.conversational_retrieval.base.ConversationalRetrievalChain.html) chain, which takes in the chat history (a list of messages) and a new question, and then returns an answer to that question based on retrieved documents.
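The flow of such a chain can be sketched in plain Python: rewrite the follow-up question into a standalone one using the chat history, retrieve documents for the rewritten question, and answer from them. This is a simplified sketch under assumptions, not LangChain's implementation; the three `fake_*` functions stand in for the model and retriever, and the history format is hypothetical.

```python
CONDENSE_TEMPLATE = (
    "{chat_history}\n\n"
    "Answer only with the new question.\n"
    "How would you ask the question considering the previous conversation: {question}\n"
    "Question:"
)


def format_history(turns):
    """Render (question, answer) pairs as plain text."""
    return "\n".join(f"Human: {q}\nAssistant: {a}" for q, a in turns)


def conversational_answer(question, turns, condense_llm, retrieve, answer_llm):
    """Condense the question (only if there is history), retrieve, answer, and record the turn."""
    standalone = question
    if turns:
        prompt = CONDENSE_TEMPLATE.format(chat_history=format_history(turns), question=question)
        standalone = condense_llm(prompt).strip()
    docs = retrieve(standalone)
    answer = answer_llm("\n\n".join(docs) + f"\n\nQuestion: {standalone}\nAnswer:")
    turns.append((question, answer))
    return standalone, answer


# Stand-ins for the real model and retriever, so the flow can be run offline.
def fake_condense_llm(prompt):
    return " What does Managed Spot Training cost? "


def fake_retrieve(query):
    return [f"(document matching: {query})"]


def fake_answer_llm(prompt):
    return "(model answer)"


turns = []
first = conversational_answer("What is Managed Spot Training?", turns, fake_condense_llm, fake_retrieve, fake_answer_llm)
second = conversational_answer("What does it cost?", turns, fake_condense_llm, fake_retrieve, fake_answer_llm)
print(first[0])
print(second[0])
```

The rewriting step matters because the retriever only sees the question text: a follow-up such as "What does it cost?" carries no searchable subject until it is made standalone.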

In [None]:
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT


def create_prompt_template():
    _template = """{chat_history}

Answer only with the new question.
How would you ask the question considering the previous conversation: {question}
Question:"""
    CONVO_QUESTION_PROMPT = PromptTemplate.from_template(_template)
    return CONVO_QUESTION_PROMPT

memory_chain = ConversationBufferMemory(memory_key="chat_history", input_key="question", return_messages=True)
chat_history=[]

In [None]:
# Parameters for ConversationalRetrievalChain

# retriever: We use a VectorStoreRetriever, which is backed by a VectorStore. To retrieve text, you can choose between two search types: search_type="similarity" or search_type="mmr". search_type="similarity" selects the text chunk vectors that are most similar to the question vector.

# memory: Memory object that stores the chat history

# condense_question_prompt: Given a question from the user, we use the previous conversation and that question to form a standalone question

# chain_type: Controls how the retrieved documents are combined into the prompt. If they are too long to fit into the context, use an option other than "stuff". The options are "stuff", "refine", "map_reduce", and "map_rerank". See the LangChain documentation on document chains to learn the difference

# verbose: Set to True to see the full logs and documents

qa = ConversationalRetrievalChain.from_llm(
    llm=sm_llm, 
    retriever=docsearch.as_retriever(), 
    #retriever=docsearch.as_retriever(search_type='similarity', search_kwargs={"k": 8}),
    memory=memory_chain,
    #verbose=True,
    #condense_question_prompt=CONDENSE_QUESTION_PROMPT, # create_prompt_template(), 
    chain_type='stuff', # 'refine',
    #max_tokens_limit=100
)

qa.combine_docs_chain.llm_chain.prompt = PromptTemplate.from_template("""
{context}

Use at maximum 3 sentences to answer the question inside the <q></q> XML tags. 

<q>{question}</q>

Do not use any XML tags in the answer. If the answer is not in the context say "Sorry, I don't know, as the answer was not found in the context."

Answer:""")

Now let's build a utility interface for our chatbot. 

In [None]:
%pip install --quiet "ipywidgets>=7,<8"

In [None]:
import ipywidgets as ipw
from IPython.display import display, clear_output

class ChatUX:
    """ A chat UX using IPWidgets
    """
    def __init__(self, qa, retrievalChain = False):
        self.qa = qa
        self.name = None
        self.b=None
        self.retrievalChain = retrievalChain
        self.out = ipw.Output()


    def start_chat(self):
        print("Starting chat bot")
        display(self.out)
        self.chat(None)


    def chat(self, _):
        if self.name is None:
            prompt = ""
        else: 
            prompt = self.name.value
        if 'q' == prompt or 'quit' == prompt or 'Q' == prompt:
            print("Thank you , that was a nice chat !!")
            return
        elif len(prompt) > 0:
            with self.out:
                thinking = ipw.Label(value="Thinking...")
                display(thinking)
                try:
                    if self.retrievalChain:
                        result = self.qa.run({'question': prompt })
                    else:
                        result = self.qa.run({'input': prompt }) #, 'history':chat_history})
                except Exception:  # keep the chat widget usable if the chain call fails
                    result = "No answer"
                thinking.value=""
                print(f"AI:{result}")
                self.name.disabled = True
                self.b.disabled = True
                self.name = None

        if self.name is None:
            with self.out:
                self.name = ipw.Text(description="You:", placeholder='q to quit')
                self.b = ipw.Button(description="Send")
                self.b.on_click(self.chat)
                display(ipw.Box(children=(self.name, self.b)))

Run our chatbot.

In [None]:
chat = ChatUX(qa, retrievalChain=True)
chat.start_chat()