# End-to-end demo with MTRAG benchmark data

This notebook shows several examples of end-to-end RAG use cases that use the retrieval
IO processor in conjunction with the IO processors for other Granite-based LoRA 
adapters.

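This notebook can run its own vLLM server to perform inference, or you can host the
models on your own server. By default (`run_server = False`) the notebook connects to
an existing server; set `run_server` to `True` to have the notebook start its own vLLM
server instead. When using your own server, adjust the model names in the cell marked
`# Constants go here` and the server URL and API key in the cell that creates the
backends.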

In [None]:
import pathlib
from granite_io.io.granite_3_2.input_processors.granite_3_2_input_processor import (
    Granite3Point2Inputs,
)
from granite_io import make_io_processor, make_backend
from granite_io.io.base import RewriteRequestProcessor
from granite_io.io.retrieval.util import download_mtrag_embeddings
from granite_io.io.retrieval import InMemoryRetriever, RetrievalRequestProcessor
from granite_io.io.answerability import (
    AnswerabilityIOProcessor,
    AnswerabilityCompositeIOProcessor,
)
from granite_io.io.query_rewrite import QueryRewriteIOProcessor
from granite_io.io.citations import CitationsCompositeIOProcessor
from granite_io.io.hallucinations import HallucinationsCompositeIOProcessor
from granite_io.backend.vllm_server import LocalVLLMServer
from IPython.display import display, Markdown
import pandas as pd
import os

In [None]:
# Constants go here

# Where the precomputed MTRAG corpus embeddings live on local disk.
temp_data_dir = "../data/test_retrieval_temp"
corpus_name = "govt"
embeddings_data_file = pathlib.Path(temp_data_dir, f"{corpus_name}_embeds.parquet")

# Embedding model used for the precomputed corpus embeddings and for queries.
embedding_model_name = "multi-qa-mpnet-base-dot-v1"

# Base model plus the LoRA adapters used by the intrinsics in this notebook.
model_name = "ibm-granite/granite-3.2-8b-instruct"
query_rewrite_lora_name = "ibm-granite/granite-3.2-8b-lora-rag-query-rewrite"
citations_lora_name = "ibm-granite/granite-3.2-8b-lora-rag-citation-generation"
answerability_lora_name = "ibm-granite/granite-3.2-8b-lora-rag-answerability-prediction"
hallucination_lora_name = "ibm-granite/granite-3.2-8b-lora-rag-hallucination-detection"

# False: connect to an already-running server; True: start a local vLLM server.
run_server = False

In [None]:
if run_server:
    # Start by firing up a local vLLM server and connecting a backend instance to it.
    # The server hosts the base model plus every LoRA adapter used below, each one
    # registered under its own model name.
    server = LocalVLLMServer(
        model_name,
        lora_adapters=[
            (lora_name, lora_name)
            for lora_name in (
                query_rewrite_lora_name,
                citations_lora_name,
                answerability_lora_name,
                hallucination_lora_name,
            )
        ],
    )
    server.wait_for_startup(200)
    query_rewrite_lora_backend = server.make_lora_backend(query_rewrite_lora_name)
    citations_lora_backend = server.make_lora_backend(citations_lora_name)
    answerability_lora_backend = server.make_lora_backend(answerability_lora_name)
    # The hallucination-detection cells further down need this backend as well;
    # without it they fail with a NameError when run_server is True.
    hallucination_lora_backend = server.make_lora_backend(hallucination_lora_name)
    backend = server.make_backend()
else:  # if not run_server
    # Use an existing server.
    # The constants here are for the server that local_vllm_server.ipynb starts.
    # Modify as needed.
    openai_base_url = "http://localhost:55555/v1"
    openai_api_key = "granite_intrinsics_1234"

    def _make_openai_backend(served_model_name):
        """Return an OpenAI-compatible backend for ``served_model_name``, served at
        ``openai_base_url`` and authenticated with ``openai_api_key``."""
        return make_backend(
            "openai",
            {
                "model_name": served_model_name,
                "openai_base_url": openai_base_url,
                "openai_api_key": openai_api_key,
            },
        )

    backend = _make_openai_backend(model_name)
    query_rewrite_lora_backend = _make_openai_backend(query_rewrite_lora_name)
    citations_lora_backend = _make_openai_backend(citations_lora_name)
    answerability_lora_backend = _make_openai_backend(answerability_lora_name)
    hallucination_lora_backend = _make_openai_backend(hallucination_lora_name)

In [None]:
# Download the indexed corpus if it hasn't already been downloaded.
# This notebook uses a subset of the government corpus from the MTRAG benchmark.
# Derive the download target from the constants cell so that it cannot drift away
# from the file that InMemoryRetriever loads further down.
embeddings_location = str(embeddings_data_file)
if not os.path.exists(embeddings_location):
    # Make sure the target directory exists before downloading into it
    # (harmless if download_mtrag_embeddings also creates it).
    os.makedirs(os.path.dirname(embeddings_location), exist_ok=True)
    download_mtrag_embeddings(embedding_model_name, corpus_name, embeddings_location)
embeddings_location

We start by creating an example chat completion request. 

*TODO: Insert description of the chat*

In [None]:
# Create an example chat completions request.
# The last user turn ("Cool, how do I sign up?") only makes sense in light of the
# earlier turns, which makes it a difficult retrieval query on its own.
chat_input = Granite3Point2Inputs.model_validate(
    dict(
        messages=[
            dict(
                role="assistant",
                content="Welcome to the California State Parks help desk.",
            ),
            dict(role="user", content="I'm a student. Do you have internships?"),
            dict(
                role="assistant",
                content=(
                    "The California State Parks hires Student Assistants "
                    "to perform a variety of tasks that require limited or no "
                    "previous work experience."
                ),
            ),
            dict(role="user", content="Cool, how do I sign up?"),
        ],
        generate_inputs=dict(temperature=0.0, max_tokens=4096),
    )
)

# Alternate conversation that can also be used with the code here:
# chat_input = Granite3Point2Inputs.model_validate(
#     {
#         "messages": [
#             {
#                 "role": "assistant",
#                 "content": "Welcome to the Alameda County Tourism help desk.",
#             },
#             {
#                 "role": "user",
#                 "content": "I'm in downtown Dublin, and I like to visit old houses. "
#                 "Is there a good one to visit?",
#             },
#             {
#                 "role": "assistant",
#                 "content": "You might want to visit the Kolb House.",
#             },
#             {"role": "user", "content": "Where is it?"},
#         ],
#         "generate_inputs": {
#             "temperature": 0.0,
#             "max_tokens": 4096,
#         },
#     }
# )

Let's start by passing the chat completion request directly to the language model,
without using retrieval-augmented generation.

In [None]:
# Spin up an IO processor for the base model, then generate a chat completion from
# the conversation alone (no retrieved documents attached).
io_proc = make_io_processor(model_name, backend=backend)
non_rag_result = io_proc.create_chat_completion(chat_input)
non_rag_response_text = non_rag_result.results[0].next_message.content
display(Markdown(non_rag_response_text))

*TODO: Insert description of hallucinations in the response*

*TODO: Insert link to web page with correct information*

Now let's spin up an in-memory vector database, using embeddings that we've precomputed
offline from the MTRAG benchmark's government corpus.

In [None]:
# Spin up an in-memory vector database over the precomputed embeddings file, using
# the same embedding model name that the embeddings were downloaded for.
retriever = InMemoryRetriever(
    embeddings_data_file,
    embedding_model_name,
)

In [None]:
# The vector database fetches document snippets that match a given query.
# Here the query is the user's final question in the conversation above.
last_user_turn = chat_input.messages[-1].content
print(f"Query is: '{last_user_turn}'")
print("Matching document snippets:")
pd.set_option("display.max_colwidth", 120)
retriever.retrieve(last_user_turn, top_k=3).to_pandas()

We attach a RequestProcessor to our vector database so that we can augment the chat 
completion request with retrieved document snippets.

In [None]:
# Attach the top 3 retrieved snippets to the request, then show only the fields
# that matter here: the conversation and the attached documents.
retrieval_request_proc = RetrievalRequestProcessor(retriever, top_k=3)
chat_input_with_docs = retrieval_request_proc.process(chat_input)[0]
{
    field_name: field_value
    for field_name, field_value in chat_input_with_docs.model_dump().items()
    if field_name in ("messages", "documents")
}

The retriever operates over the last user turn.

*TODO: Describe critical information that is not in the last turn*

The snippets retrieved are not specific to the user's intended question.

Let's see what happens if we run our request through the model using the low-quality 
RAG snippets from the previous cell.


In [None]:
# Generate a response grounded in the (low-quality) retrieved snippets.
rag_result = io_proc.create_chat_completion(chat_input_with_docs)
rag_response_text = rag_result.results[0].next_message.content
display(Markdown(rag_response_text))

*TODO: Describe what is wrong about this response*

We can use the [LoRA Adapter for Answerability Classification](
    https://huggingface.co/ibm-granite/granite-3.2-8b-lora-rag-answerability-prediction)
to detect this kind of problem. Here's what happens if we run the chat completion
request with faulty document snippets through the answerability model, using the
`granite_io` IO processor for the model to handle input and output:

In [None]:
# Ask the answerability LoRA whether the attached documents can answer the
# user's last turn.
answerability_proc = AnswerabilityIOProcessor(answerability_lora_backend)
answerability_result = answerability_proc.create_chat_completion(chat_input_with_docs)
answerability_result.results[0].next_message.content

The answerability model detects that the documents we have retrieved cannot be used to
answer the user's question. We can use a composite IO processor to wrap this check in
a flow that falls back on a canned response when the check fails.

In [None]:
# Wrap the base model's IO processor so that the answerability check runs first.
answerability_composite_proc = AnswerabilityCompositeIOProcessor(
    io_proc, answerability_proc
)
composite_response = answerability_composite_proc.create_chat_completion(
    chat_input_with_docs
)
composite_result = composite_response.results[0]
print(composite_result.next_message.content)

We can use the [LoRA Adapter for Query Rewrite](
    https://huggingface.co/ibm-granite/granite-3.2-8b-lora-rag-query-rewrite) to rewrite
the last user turn into a string that is more useful for retrieving document snippets.
Here's what we get if we call this model directly on the original request:

In [None]:
# Call the query rewrite LoRA directly on the original request.
rewrite_io_proc = QueryRewriteIOProcessor(query_rewrite_lora_backend)
rewrite_completion = rewrite_io_proc.create_chat_completion(chat_input)
rewrite_completion.results[0].next_message.content

We can wrap the IO processor for this model in a request processor that rewrites
the last turn of the chat completion request.

In [None]:
# Replace the last user turn with the rewritten version and show the result.
rewrite_request_proc = RewriteRequestProcessor(rewrite_io_proc)
rewritten_chat_input = rewrite_request_proc.process(chat_input)[0]
print("Messages after rewrite:")
[
    {"role": message.role, "content": message.content}
    for message in rewritten_chat_input.messages
]

In [None]:
# Retrieve with the rewritten query...
rewritten_chat_input_with_docs = retrieval_request_proc.process(rewritten_chat_input)[0]
# ...but hand the answerability model the user's original messages together with
# the documents that the rewrite retrieved.
chat_input_with_docs_from_rewrite = rewritten_chat_input_with_docs.model_copy(
    update={"messages": chat_input.messages}
)
rewrite_answerability_result = answerability_proc.create_chat_completion(
    chat_input_with_docs_from_rewrite
)
rewrite_answerability_result.results[0].next_message.content

We can chain all of these request processors together with the IO processor for
the answerability model to create a single flow that processes requests in multiple
steps:
1. Rewrite the last user message for retrieval
1. Retrieve documents and attach them to the request
1. Restore the user's original last message, keeping the retrieved documents
1. Check for answerability with the retrieved documents
1. If the answerability check passes, then send the request to the base model


In [None]:
# Rewrite the last user turn, then retrieve documents using the rewritten turn
# as the query.
rewrite_request_proc = RewriteRequestProcessor(rewrite_io_proc)
rewritten_request = rewrite_request_proc.process(chat_input)[0]
request = retrieval_request_proc.process(rewritten_request)[0]

# Restore the user's original wording while keeping the retrieved documents.
request = request.model_copy(update={"messages": chat_input.messages})

# Check for answerability, and generate only if the documents pass the check.
response = answerability_composite_proc.create_chat_completion(request)
rag_rewrite_result = response.results[0]
display(Markdown(rag_rewrite_result.next_message.content))

*TODO: Insert description of how this response is better.*

We can use the [LoRA Adapter for Citation Generation](
    https://huggingface.co/ibm-granite/granite-3.2-8b-lora-rag-citation-generation
) to explain exactly how this response is grounded in the documents that the rewritten
user query retrieves.

In [None]:
# Check for answerability, generate a response, then add citations.
citations_composite_proc = CitationsCompositeIOProcessor(
    answerability_composite_proc, citations_lora_backend
)
citations_response = citations_composite_proc.create_chat_completion(request)
result_with_citations = citations_response.results[0]
cited_message = result_with_citations.next_message

print("Assistant response:")
display(Markdown(cited_message.content))
print("Citations:")
pd.set_option("display.max_colwidth", 1500)
pd.DataFrame.from_records(
    [citation.model_dump() for citation in cited_message.citations]
)

In [None]:
# Show every document attached to the request, for reference.
for position, document in enumerate(request.documents):
    print(f"Document {position}:")
    display(Markdown(document.text))

In [None]:
# The citations model's raw output is kept on the message; it is useful for
# debugging.
raw_citations_output = result_with_citations.next_message.raw
raw_citations_output

In [None]:
# Generate a response to `request` and run the hallucination-detection LoRA over it.
# NOTE: this composite wraps `io_proc` directly, so it presumably generates its own
# response (skipping the answerability check) rather than reusing the response from
# the preceding cells. With temperature 0.0 the text is expected to match the
# earlier response -- confirm if exact agreement matters.
hallucinations_composite_proc = HallucinationsCompositeIOProcessor(
    io_proc, hallucination_lora_backend
)
result_with_hallucinations = hallucinations_composite_proc.create_chat_completion(
    request
).results[0]
result_with_hallucinations.next_message.model_dump()

We can wrap all of the functionality we've shown so far in a single class that
inherits from the `InputOutputProcessor` interface in `granite-io`. Packaging things
this way lets applications treat this multi-step flow as if it were a single chat
completion request to a base model.

In [None]:
from granite_io.io.base import InputOutputProcessor, RequestProcessor
from granite_io.backend import Backend
from granite_io.io.base import ChatCompletionInputs, ChatCompletionResults


class GraniteRAGCompositeIOProcessor(InputOutputProcessor):
    """
    Multi-step RAG flow exposed through the ``InputOutputProcessor`` interface.

    Each chat completion request goes through these steps:

    1. Rewrite the last user turn with ``rewrite_request_proc``.
    2. Retrieve documents for the rewritten request with ``retrieval_request_proc``.
    3. Restore the original messages, keeping the retrieved documents.
    4. Run the answerability check, generate a response with ``io_proc``, add
       citations, and check for hallucinations.
    """

    def __init__(
        self,
        io_proc: InputOutputProcessor,
        rewrite_request_proc: RequestProcessor,
        retrieval_request_proc: RequestProcessor,
        answerability_proc: InputOutputProcessor,
        citations_lora_backend: Backend,
        hallucination_lora_backend: Backend,
    ):
        """
        :param io_proc: IO processor for the base model that generates responses
        :param rewrite_request_proc: Request processor that rewrites the last user
            turn for retrieval
        :param retrieval_request_proc: Request processor that attaches retrieved
            documents to a request
        :param answerability_proc: IO processor for the answerability model
        :param citations_lora_backend: Backend for the citation generation LoRA
        :param hallucination_lora_backend: Backend for the hallucination detection
            LoRA
        """
        # NOTE(review): the base-class __init__ is not called here; confirm that
        # InputOutputProcessor does not need initialization.
        self.rewrite_request_proc = rewrite_request_proc
        self.retrieval_request_proc = retrieval_request_proc

        # Build up a chain of IO processors:
        # answerability -> Granite -> citations -> hallucinations
        chain = AnswerabilityCompositeIOProcessor(io_proc, answerability_proc)
        chain = CitationsCompositeIOProcessor(chain, citations_lora_backend)
        chain = HallucinationsCompositeIOProcessor(chain, hallucination_lora_backend)
        self.io_proc_chain = chain

    async def acreate_chat_completion(
        self, inputs: ChatCompletionInputs
    ) -> ChatCompletionResults:
        """
        Chat completions API inherited from the ``InputOutputProcessor`` base class.

        :param inputs: Structured representation of the inputs to a chat completion
            request, possibly including additional fields that only this input-output
            processor can consume

        :returns: The next message that the model produces when fed the specified
            inputs, plus additional information about the low-level request.
        """
        original_inputs = inputs

        # Rewrite and retrieve, using the request processors this instance was
        # constructed with rather than any same-named notebook globals.
        inputs = (await self.rewrite_request_proc.aprocess(inputs))[0]
        inputs = (await self.retrieval_request_proc.aprocess(inputs))[0]

        # Switch back to original version of last turn, keeping retrieved documents
        inputs = inputs.model_copy(update={"messages": original_inputs.messages})

        # Perform answerability check, generate a response, add citations, and check
        # for hallucinations.
        return await self.io_proc_chain.acreate_chat_completion(inputs)

In [None]:
# Package the whole flow as one IO processor and run the original request through it.
rag_io_proc = GraniteRAGCompositeIOProcessor(
    io_proc=io_proc,
    rewrite_request_proc=rewrite_request_proc,
    retrieval_request_proc=retrieval_request_proc,
    answerability_proc=answerability_proc,
    citations_lora_backend=citations_lora_backend,
    hallucination_lora_backend=hallucination_lora_backend,
)
rag_response = rag_io_proc.create_chat_completion(chat_input)
rag_result = rag_response.results[0]

In [None]:
# Show the final response along with its citations and hallucination checks.
rag_message = rag_result.next_message

print("Assistant response:")
display(Markdown(rag_message.content))

print("Citations:")
pd.set_option("display.max_colwidth", 1500)
citation_records = [citation.model_dump() for citation in rag_message.citations]
display(pd.DataFrame.from_records(citation_records))

print("Hallucination Checks:")
hallucination_records = [check.model_dump() for check in rag_message.hallucinations]
display(pd.DataFrame.from_records(hallucination_records))