# Retrieval Augmented Generation with SageMaker

Text-to-text generative AI models have a well-documented limitation: they only have information up to the date their training data was collected. This notebook shows how to use retrieval augmented generation (RAG), also known as data augmented generation, to supplement text generation models with up-to-date information via document search. We will use two different models to do this. First, we will use the HuggingFace FLAN T5 model for document and question embedding. Second, we will use AI21 Labs' Jurassic Jumbo Instruct model for text generation.

**Please note: this notebook requires access to the foundation models in SageMaker JumpStart, which were in private preview at the time of writing.**

# Setup Environment

We will install a few libraries and import the necessary packages for the notebook. We will use the `transformers` library to produce our embeddings and the `ai21` library to interact with the Jurassic model.

In [None]:
!pip install setuptools~=46.0.0 --quiet
!pip install "ai21[SM]" --quiet
!pip install torch transformers --quiet

In [None]:
import transformers
from transformers import AutoTokenizer, T5EncoderModel
import torch
import ai21
import pandas as pd
import numpy as np

from sagemaker import ModelPackage
from sagemaker import get_execution_role
import sagemaker

import json
import boto3
import requests
import logging


region = boto3.Session().region_name
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())
url = "https://nm3yyjazj1.execute-api.us-east-1.amazonaws.com/Prod/invoke"

role = get_execution_role()
sagemaker_session = sagemaker.Session()
runtime_sm_client = boto3.client("runtime.sagemaker")

In [None]:
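def query_endpoint_with_json_payload(url, payload):
    # POST the JSON payload to the API Gateway URL that fronts the Jurassic endpoint.
    response = requests.post(
        url,
        json=payload,
    )
    # Fail fast on HTTP errors instead of trying to parse an error page as JSON.
    response.raise_for_status()
    return response

def parse_response_multiple_texts(query_response):
    # The API returns a JSON body whose 'message' field holds the generated text.
    model_predictions = query_response.json()
    generated_text = model_predictions['message']
    return generated_text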

# Deploy Jurassic Model to SageMaker Endpoint

## Optional - Retrieve & deploy the Jurassic model package
The first step is to set up a SageMaker session and collect the Jurassic Jumbo Instruct model package ARN. Use the cells below to deploy the model. In this particular use case, we have already deployed a Jurassic model and exposed it via Amazon API Gateway, so you can skip the cells below.

In [None]:
model_package_map = {
    "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-jumbo-instruct-v1-0-20-8b2be365d1883a15b7d78da7217cdeab",
}
region = boto3.Session().region_name
model_package_arn = model_package_map[region]


## Deploy Endpoint

You can now deploy the Jurassic model to a SageMaker endpoint in order to send text into the model in real time.

In [None]:
model_name = "j2-jumbo-instruct"
content_type = "application/json"
real_time_inference_instance_type = (
    "ml.g5.48xlarge"
)

In [None]:
# create a deployable model from the model package.
model = ModelPackage(
    role=role, model_package_arn=model_package_arn, sagemaker_session=sagemaker_session
)

In [None]:
# Uncomment the lines below to deploy the model. In this use case we have already deployed a Jurassic Mid model and exposed it via API Gateway.
# Deploy the model
#predictor = model.deploy(
#    1, real_time_inference_instance_type, endpoint_name=model_name, 
#    model_data_download_timeout=3600,
#    container_startup_health_check_timeout=600,
#)

# The Hallucination Issue

Now that we have an endpoint up and running, the example below shows how the Jurassic model "hallucinates" that France won the 2022 World Cup. In fact, Argentina won in 2022 and France won in 2018. Herein lies the problem we need to fix with RAG.

In [None]:

prompt = f'''Answer the following question.
Question: Who won the 2022 world cup?
Answer:
'''

payload = {"prompt":prompt, "max_token":200, "temperature": 0}
response = query_endpoint_with_json_payload(url, payload)

generated_text = parse_response_multiple_texts(response)
logger.info(f'Generated text: \n{generated_text}')

#response = ai21.Completion.execute(
#    sm_endpoint="j2-jumbo-instruct",
#    prompt=prompt,
#    maxTokens=200,
#    temperature=0,
#    numResults=1
#)
#print(response['completions'][0]['data']['text'])



In this cell, prompt engineering (adding the line `If you do not have the information to answer the question, say "I don't know".` to the prompt) makes the model answer "I don't know", which is better than producing a wrong answer.

In [None]:
prompt = f'''Answer the following question. If you do not have the information to answer the question, say "I don't know".
Question: Who won the 2022 world cup?
Answer:
'''

payload = {"prompt":prompt, "max_token":200, "temperature": 0}
response = query_endpoint_with_json_payload(url, payload)

generated_text = parse_response_multiple_texts(response)
logger.info(f'Generated text: \n{generated_text}')
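Both prompts above share the same question scaffold and differ only in the fallback instruction. A small helper keeps that instruction in one place so it can be toggled consistently. This is a minimal sketch; `build_prompt`, `FALLBACK`, and the `allow_unknown` flag are hypothetical names introduced here, not part of the notebook or the AI21 API.

```python
# Hypothetical helper: reproduces the two question-only prompts used above.
FALLBACK = 'If you do not have the information to answer the question, say "I don\'t know".'


def build_prompt(question: str, allow_unknown: bool = True) -> str:
    """Assemble the question-only prompt, optionally with the "I don't know" fallback."""
    header = "Answer the following question."
    if allow_unknown:
        # The fallback instruction nudges the model to abstain instead of guessing.
        header += " " + FALLBACK
    return f"{header}\nQuestion: {question}\nAnswer:\n"


print(build_prompt("Who won the 2022 world cup?"))
```

The resulting string can be passed as the `"prompt"` field of the same payload used in the cells above.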

# Get a HuggingFace Model for Embeddings

Load in the [FLAN T5 large model](https://huggingface.co/google/flan-t5-large) from HuggingFace. This will be the model we use to create our document search embeddings.

In [None]:
# Download the FLAN T5 tokenizer and encoder used for embeddings. This may take a few moments.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
model = T5EncoderModel.from_pretrained("google/flan-t5-large").to(DEVICE)

# Create Embedding Database

Now use the HuggingFace model to create embeddings for each of the three documents which have been provided. The documents used here are only illustrative. These documents could be extended to any collection of text to help supplement your use case.

In [None]:
def get_embedding(text, model, tokenizer):
    with torch.no_grad():
        input_ids = tokenizer(
            text, return_tensors="pt", max_length=512, truncation=True
        ).input_ids.to(DEVICE)
        outputs = model(input_ids=input_ids)
        last_hidden_states = outputs.last_hidden_state
        e = last_hidden_states.mean(dim=1)
    return e

# helper function to create the document embeddings from a document
def create_doc_database(docs, model, tokenizer):
    database = []
    for i in range(docs.shape[0]):
        text = docs['title'].values[i] + ' - ' + docs['document'].values[i]
        e = get_embedding(text, model, tokenizer)
        database.append(e)
    database = torch.cat(database)
    return database

In [None]:
docs = pd.read_csv('document-corpus.txt', delimiter="::: ", engine='python')
docs
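The cell above expects `document-corpus.txt` to have a header row naming the `title` and `document` columns, with fields separated by the delimiter `::: ` (that is why the Python parser engine is used). The file itself is not shown in this notebook, so the sketch below parses a hypothetical one-row example with the standard library only. It can be used to sanity-check a corpus file before loading it with pandas. `read_corpus` and the sample text are placeholders, not the notebook's data.

```python
import io

# Hypothetical sample in the format that pd.read_csv(..., delimiter="::: ") expects.
SAMPLE = """title::: document
Example title::: Example document text.
"""


def read_corpus(fh, sep="::: "):
    """Parse a header row plus data rows split on `sep` into a list of dicts."""
    lines = [line.rstrip("\n") for line in fh if line.strip()]
    header = lines[0].split(sep)
    rows = []
    for line in lines[1:]:
        fields = line.split(sep)
        # A row with the wrong field count usually means a missing or extra delimiter.
        if len(fields) != len(header):
            raise ValueError(f"expected {len(header)} fields, got {len(fields)}: {line!r}")
        rows.append(dict(zip(header, fields)))
    return rows


print(read_corpus(io.StringIO(SAMPLE)))
```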

### Embeddings 

So far we have loaded the document corpus, a CSV-style file containing three documents with results from the 2022 World Cup, the Champions League, and the Ballon d'Or. For each document, we concatenate its title and text, pass the result through the tokenizer and the FLAN T5 encoder, and average the encoder's final hidden states over all tokens to get an embedding.

The end result is a vector representation of each document with 1,024 floating-point values. We will load these embeddings into a database for search.
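The pooling step described above turns a variable number of per-token hidden states into one fixed-size vector. The framework-free sketch below shows that arithmetic on toy 2-dimensional vectors instead of FLAN T5's 1,024 dimensions. `mean_pool` is a hypothetical helper: with no mask it matches `last_hidden_state.mean(dim=1)` for a single unpadded input, as in `get_embedding`. The optional mask is an assumption for the case where several texts are batched with padding, which the notebook does not do.

```python
from typing import List, Optional


def mean_pool(token_vectors: List[List[float]], mask: Optional[List[int]] = None) -> List[float]:
    """Average per-token hidden states into a single vector.

    With mask=None every token counts; a 0 in `mask` excludes that (padding) token.
    """
    if mask is None:
        mask = [1] * len(token_vectors)
    kept = [vec for vec, m in zip(token_vectors, mask) if m]
    if not kept:
        raise ValueError("mask removes every token")
    dim = len(kept[0])
    # Average each hidden dimension across the kept tokens.
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


# Three tokens with 2-dim hidden states; the last one stands in for padding.
hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
print(mean_pool(hidden[:2]))         # [2.0, 3.0]
print(mean_pool(hidden, [1, 1, 0]))  # [2.0, 3.0]: the padding token is ignored
```

Without the mask, the padding token would pull the average toward `[100.0, 100.0]`, which is why batched embedding code usually weights the mean by the attention mask.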

In [None]:
database = create_doc_database(docs, model, tokenizer)

In [None]:
database.shape

# Create Search Ability

Now that we have a database of embeddings, we can search it with the text input `"Who won the 2022 world cup?"` to see which document is most relevant to the question by comparing the dot products of the embeddings. Behind the scenes, we convert the query into a vector embedding using the same tokenizer and encoder model that produced the document embeddings, and then search for the most similar document.

In [None]:
def search_database(search_embedding, database):
    similarities = []
    for i in range(database.shape[0]):
        similarities.append(
            float(torch.dot(search_embedding[0], database[i]))
        )
    return np.argmax(similarities), similarities

In [None]:
search = 'Who won the 2022 world cup?'
search_embedding = get_embedding(search, model, tokenizer)
doc_index, similarities = search_database(search_embedding, database)
print(f"Input: {search}\nWas matched with document #{doc_index} which is titled \"{docs.loc[doc_index]['title']}\"")
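`search_database` scores documents by raw dot product and keeps only the single best match. Raw dot products also grow with vector length, so a document with a large-norm embedding can win even when its direction is less similar; cosine similarity divides out the lengths. The pure-Python sketch below compares the two metrics and returns the top `k` matches instead of only the argmax. `dot`, `cosine`, and `top_k` are hypothetical helpers on toy 2-dimensional vectors; with PyTorch tensors, `torch.nn.functional.cosine_similarity` and `torch.topk` could play the same roles.

```python
import math
from typing import Callable, List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    # Normalize by both vector lengths; guard against zero vectors.
    na, nb = math.sqrt(dot(a, a)), math.sqrt(dot(b, b))
    return dot(a, b) / (na * nb) if na and nb else 0.0


def top_k(query, database, k: int = 2,
          metric: Callable[[Sequence[float], Sequence[float]], float] = dot) -> List[Tuple[int, float]]:
    """Return (index, score) pairs for the k best-scoring documents, best first."""
    scored = [(i, metric(query, doc)) for i, doc in enumerate(database)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


db = [[1.0, 0.0], [10.0, 10.0], [0.0, 1.0]]
q = [1.0, 0.1]
print(top_k(q, db, k=1, metric=dot))     # [(1, 11.0)]: the long vector wins on raw dot product
print(top_k(q, db, k=1, metric=cosine))  # index 0 wins once lengths are normalized
```

Whether cosine similarity retrieves better than the dot product for mean-pooled FLAN T5 embeddings is worth checking on your own documents rather than assuming.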

# Dynamically Engineer the Prompt

Now that we have a user input matched with a relevant document, we can engineer a prompt which includes both the question and context from the document.

In [None]:
prompt_eng_base = '''Answer the following question with the following context. If you do not have the information to answer the question, say "I don't know".

Context: [PLACE DOC HERE]

Question: [PLACE QUESTION HERE]
Answer:
'''

In [None]:
def make_prompt(search, context, prompt_eng_base):
    prompt = prompt_eng_base.replace('[PLACE DOC HERE]', context)
    prompt = prompt.replace('[PLACE QUESTION HERE]', search)
    return prompt

In [None]:
prompt_custom = make_prompt(search, docs.loc[doc_index]['document'], prompt_eng_base)
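`make_prompt` places a single retrieved document into the `[PLACE DOC HERE]` slot. If retrieval returns several ranked candidates, one option is to join them in rank order until a size budget is reached, so the prompt stays within the model's context window. This is a sketch under stated assumptions: `build_context` and `max_chars` are hypothetical, and a character count is only a rough stand-in for the model's real token limit. `make_prompt` is copied from the notebook (with a default template added) so the block runs on its own.

```python
from typing import List

PROMPT_ENG_BASE = '''Answer the following question with the following context. If you do not have the information to answer the question, say "I don't know".

Context: [PLACE DOC HERE]

Question: [PLACE QUESTION HERE]
Answer:
'''


def build_context(ranked_docs: List[str], max_chars: int = 2000, sep: str = "\n\n") -> str:
    """Join documents in rank order, stopping before the character budget is exceeded."""
    parts, used = [], 0
    for doc in ranked_docs:
        extra = len(doc) + (len(sep) if parts else 0)
        if used + extra > max_chars:
            break
        parts.append(doc)
        used += extra
    return sep.join(parts)


def make_prompt(search: str, context: str, prompt_eng_base: str = PROMPT_ENG_BASE) -> str:
    # Same placeholder substitution as the notebook's make_prompt.
    prompt = prompt_eng_base.replace("[PLACE DOC HERE]", context)
    return prompt.replace("[PLACE QUESTION HERE]", search)


ctx = build_context(["first doc", "second doc", "third doc"], max_chars=25)
print(make_prompt("Who won?", ctx))
```

With `max_chars=25`, the first two documents fit (21 characters including the separator) and the third is dropped.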

# Wrap the RAG Flow into a Function

In [None]:
base_prompt = f'''Answer the following question. If you do not have the information to answer the question, say "I don't know".

Question: [SEARCH HERE]
Answer:
'''

def rag_demo(search, use_search=True):
    search_embedding = get_embedding(search, model, tokenizer)
    doc_index, similarities = search_database(search_embedding, database)
    if use_search:
        prompt_custom = make_prompt(search, docs.loc[doc_index]['document'], prompt_eng_base)
    else:
        prompt_custom = base_prompt.replace('[SEARCH HERE]', search)
        
    payload = {"prompt":prompt_custom, "max_token":200, "temperature": 0}
    response = query_endpoint_with_json_payload(url, payload)

    generated_text = parse_response_multiple_texts(response)

    return generated_text
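`rag_demo` always embeds and searches the question, even when `use_search=False` discards the result, and it calls the live endpoint directly. The sketch below expresses the same control flow with retrieval and generation passed in as functions, which skips the search when it is not used and lets the flow be exercised without an endpoint. `rag_answer`, `fake_retrieve`, and `fake_generate` are hypothetical, and the two templates are shortened stand-ins for `prompt_eng_base` and `base_prompt`.

```python
from typing import Callable, List

# Shortened stand-ins for the notebook's prompt_eng_base and base_prompt.
RAG_TEMPLATE = "Context: [PLACE DOC HERE]\n\nQuestion: [PLACE QUESTION HERE]\nAnswer:\n"
PLAIN_TEMPLATE = "Question: [SEARCH HERE]\nAnswer:\n"


def rag_answer(question: str,
               retrieve: Callable[[str], str],
               generate: Callable[[str], str],
               rag_template: str = RAG_TEMPLATE,
               plain_template: str = PLAIN_TEMPLATE,
               use_search: bool = True) -> str:
    """Build the prompt (with or without retrieved context) and return the model output."""
    if use_search:
        # Only run retrieval when its result is actually placed in the prompt.
        context = retrieve(question)
        prompt = rag_template.replace("[PLACE DOC HERE]", context)
        prompt = prompt.replace("[PLACE QUESTION HERE]", question)
    else:
        prompt = plain_template.replace("[SEARCH HERE]", question)
    return generate(prompt)


seen_prompts: List[str] = []


def fake_retrieve(question: str) -> str:
    # Stand-in for get_embedding + search_database + the docs lookup.
    return "Placeholder context document."


def fake_generate(prompt: str) -> str:
    # Stand-in for the HTTP call to the model; records the prompt it received.
    seen_prompts.append(prompt)
    return "stub answer"


print(rag_answer("Who won the 2022 world cup?", fake_retrieve, fake_generate))
print(seen_prompts[-1])
```

In the notebook, `retrieve` could wrap the embedding search and `docs` lookup, and `generate` could wrap `query_endpoint_with_json_payload` plus `parse_response_multiple_texts`.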

# Example Outputs

The outputs below show how passing relevant information to the model lets it give informed responses back to the user!

In [None]:
out = rag_demo('Who won the 2022 world cup?', use_search=False)
print(out)

In [None]:
out = rag_demo('Who won the 2022 world cup?', use_search=True)
print(out)

# Suggested Next Steps

* Explore libraries which can help with this kind of workflow. See: [LangChain](https://github.com/hwchase17/langchain)
* Bring your own documents or information to this workflow to explore creating RAG based systems.
* Look into fine-tuning your embedding model to improve search quality.
* Integrate this RAG flow with your own search capabilities.
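When bringing your own documents, note that `get_embedding` tokenizes with `max_length=512, truncation=True`, so any text beyond the first 512 tokens of a document does not affect its embedding. One common workaround is to split long documents into overlapping chunks and embed each chunk as its own row. The sketch below is a hypothetical `chunk_words` helper that splits on whitespace; word counts only approximate token counts, so `max_words` should be set conservatively.

```python
from typing import List


def chunk_words(text: str, max_words: int = 300, overlap: int = 50) -> List[str]:
    """Split text into windows of at most max_words words, sharing `overlap` words."""
    if not 0 <= overlap < max_words:
        raise ValueError("overlap must be in [0, max_words)")
    words = text.split()
    step = max_words - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + max_words]))
        # Stop once this window reaches the end, to avoid a redundant tail chunk.
        if start + max_words >= len(words):
            break
    return chunks


print(chunk_words("a b c d e f g", max_words=3, overlap=1))
# ['a b c', 'c d e', 'e f g']
```

Each chunk could then be stored with its source title before calling `create_doc_database`, so a match can still be traced back to the original document.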

# Cleanup 

Uncomment the lines in the cell below to delete the model and endpoint if you deployed one.

In [None]:
#j2 = sagemaker.predictor.Predictor('j2-jumbo-instruct')
#j2.delete_model()
#j2.delete_endpoint()