# End-to-end demo with MTRAG benchmark data

This notebook shows several examples of end-to-end RAG use cases that use the retrieval
IO processor in conjunction with the IO processors for other Granite-based LoRA 
adapters. More information about the models used here can be found in our [technical
report](https://arxiv.org/html/2504.11704v1).

This notebook can run its own vLLM server to perform inference, or you can host the 
models on your own server. To use your own server, set the `run_server` variable below
to `False` and set appropriate values for the constants in the cell marked
`# Constants go here`.

In [None]:
import pathlib
from granite_io.io.granite_3_3.input_processors.granite_3_3_input_processor import (
    Granite3Point3Inputs,
)
from granite_io import make_io_processor, make_backend
from granite_io.io.base import RewriteRequestProcessor
from granite_io.io.voting import MBRDMajorityVotingProcessor
from granite_io.io.retrieval.util import download_mtrag_embeddings
from granite_io.io.retrieval import InMemoryRetriever, RetrievalRequestProcessor
from granite_io.io.answerability import (
    AnswerabilityIOProcessor,
    AnswerabilityCompositeIOProcessor,
)
from granite_io.io.query_rewrite import QueryRewriteIOProcessor
from granite_io.io.citations import CitationsCompositeIOProcessor
from granite_io.io.hallucinations import HallucinationsCompositeIOProcessor
from granite_io.backend.vllm_server import LocalVLLMServer
from granite_io.io.certainty import CertaintyIOProcessor
from IPython.display import display, Markdown
from granite_io.io.rag_agent_lib import obtain_loras
import pandas as pd
import os

In [None]:
# Constants go here

# --- Retrieval corpus (MTRAG government subset) and its precomputed embeddings ---
corpus_name = "govt"
temp_data_dir = "../data/test_retrieval_temp"
embeddings_data_file = pathlib.Path(temp_data_dir).joinpath(
    f"{corpus_name}_embeds.parquet"
)
embedding_model_name = "multi-qa-mpnet-base-dot-v1"

# --- Base model ---
model_name = "ibm-granite/granite-3.3-8b-instruct"

# --- LoRA adapter names (also used as served-model names on the inference server) ---
query_rewrite_lora_name = "query_rewrite"
citations_lora_name = "citation_generation"
answerability_lora_name = "answerability_prediction"
hallucination_lora_name = "hallucination_detection"
certainty_lora_name = "certainty"
all_lora_names = list(
    (
        query_rewrite_lora_name,
        citations_lora_name,
        answerability_lora_name,
        hallucination_lora_name,
        certainty_lora_name,
    )
)

# True: start a local vLLM server in this notebook.
# False: connect to an existing server (see the next cell).
run_server = True

In [None]:
if run_server:
    # Start by firing up a local vLLM server and connecting a backend instance to it.
    # Download and cache the LoRA adapters first, so the server can load them.
    lora_model_paths = obtain_loras(all_lora_names)
    server = LocalVLLMServer(
        model_name,
        lora_adapters=lora_model_paths,
    )
    # NOTE(review): 200 is presumably a startup timeout in seconds -- confirm
    # against LocalVLLMServer.wait_for_startup.
    server.wait_for_startup(200)
    # One backend per LoRA adapter hosted by the server, plus one for the base model.
    query_rewrite_lora_backend = server.make_lora_backend(query_rewrite_lora_name)
    citations_lora_backend = server.make_lora_backend(citations_lora_name)
    answerability_lora_backend = server.make_lora_backend(answerability_lora_name)
    hallucination_lora_backend = server.make_lora_backend(hallucination_lora_name)
    certainty_lora_backend = server.make_lora_backend(certainty_lora_name)
    backend = server.make_backend()
else:  # if not run_server
    # Use an existing server.
    # The constants here are for the server that local_vllm_server.ipynb starts.
    # Modify as needed.
    openai_base_url = "http://localhost:55555/v1"
    # Key expected by the demo server above. For any real deployment, read the key
    # from the environment instead of hardcoding it in the notebook.
    openai_api_key = "granite_intrinsics_1234"

    def make_openai_backend(served_model_name: str):
        """
        Connect an OpenAI-compatible backend to one model served at
        ``openai_base_url``. Pass a LoRA adapter name to target that adapter
        instead of the base model.

        :param served_model_name: Name under which the server exposes the model
            or LoRA adapter.
        :returns: The backend object created by ``make_backend("openai", ...)``.
        """
        return make_backend(
            "openai",
            {
                "model_name": served_model_name,
                "openai_base_url": openai_base_url,
                "openai_api_key": openai_api_key,
            },
        )

    backend = make_openai_backend(model_name)
    query_rewrite_lora_backend = make_openai_backend(query_rewrite_lora_name)
    citations_lora_backend = make_openai_backend(citations_lora_name)
    answerability_lora_backend = make_openai_backend(answerability_lora_name)
    hallucination_lora_backend = make_openai_backend(hallucination_lora_name)
    certainty_lora_backend = make_openai_backend(certainty_lora_name)

In [None]:
# Download the indexed corpus if it hasn't already been downloaded.
# This notebook uses a subset of the government corpus from the MTRAG benchmark.
# Reuse the path from the constants cell, so the download target and the file that
# the retriever loads later cannot drift apart.
embeddings_location = str(embeddings_data_file)
if not embeddings_data_file.exists():
    download_mtrag_embeddings(embedding_model_name, corpus_name, embeddings_location)
embeddings_location

We start by creating an example chat completion request. 

This chat completion request simulates a scenario where the user is chatting with the
automated help desk agent of the California State Parks and is asking about internship
opportunities. The agent is about to respond to the user's question, "Cool, how do I
sign up?"

In [None]:
# Create an example chat completions request.
# Conversation so far: the help desk greets the user, the user asks about
# internships, and the agent replies. The last user turn is the one to answer.
example_messages = [
    {
        "role": "assistant",
        "content": "Welcome to the California State Parks help desk.",
    },
    {"role": "user", "content": "I'm a student. Do you have internships?"},
    {
        "role": "assistant",
        "content": "The California State Parks hires Student Assistants "
        "to perform a variety of tasks that require limited or no previous "
        "work experience.",
    },
    {"role": "user", "content": "Cool, how do I sign up?"},
]
chat_input = Granite3Point3Inputs.model_validate(
    {
        "messages": example_messages,
        "generate_inputs": {"temperature": 0.0, "max_tokens": 4096},
    }
)

Let's start by passing the chat completion request directly to the language model,
without using retrieval-augmented generation.

In [None]:
# IO processor that talks to the base model with no adapters and no retrieval
io_proc = make_io_processor(model_name, backend=backend)

# Answer the conversation using only what the base model already knows
non_rag_result = io_proc.create_chat_completion(chat_input)
non_rag_message = non_rag_result.results[0].next_message
display(Markdown(non_rag_message.content))

This result is a hallucination. The actual correct answer can be found [here](
    https://www.parks.ca.gov/?page_id=848
).

We can use the 
[Granite 3.3 8b Instruct - Uncertainty LoRA](
    https://huggingface.co/ibm-granite/granite-3.3-8b-rag-agent-lib/blob/main/certainty_lora/README.md)
adapter to flag cases like this one, which are not covered by the base model's
training data.

In [None]:
# Ask the certainty LoRA how confident the base model is about this conversation;
# the score comes back as the text of the generated message.
certainty_io_proc = CertaintyIOProcessor(certainty_lora_backend)
certainty_output = certainty_io_proc.create_chat_completion(chat_input)
certainty_score = certainty_output.results[0].next_message.content
print(f"Certainty score is {certainty_score} out of 1.0")

The low certainty score indicates that the model's training data does not align closely
with this question.

To answer this question properly, we need to provide the model with domain-specific 
information. In this case, the relevant information can be found in the Government 
corpus of the [MTRAG multi-turn RAG benchmark](https://github.com/IBM/mt-rag-benchmark).
Let's spin up an in-memory vector database, using embeddings that we've precomputed
offline from this corpus.

In [None]:
# Load the precomputed corpus embeddings into an in-memory vector database,
# using the same embedding model that produced them.
retriever = InMemoryRetriever(embeddings_data_file, embedding_model_name)

We can send string queries against this vector database to retrieve relevant documents.
Here we query the database with the user's last turn from our example conversation, 
"Cool, how do I sign up?"

In [None]:
# Query the vector database directly with the last user turn of the conversation.
last_user_turn = chat_input.messages[-1].content
print(f"Query is: '{last_user_turn}'")
print("Matching document snippets:")
pd.set_option("display.max_colwidth", 120)
retriever.retrieve(last_user_turn, top_k=3).to_pandas()

If we attach a GraniteIO RetrievalRequestProcessor to our vector database, we can use
this RequestProcessor to augment the original chat completion request with the document
snippets that the retriever fetches when fed the last user turn as a query.

In [None]:
# Request processor that attaches the top 3 snippets for the last user turn
retrieval_request_proc = RetrievalRequestProcessor(retriever, top_k=3)
augmented_requests = retrieval_request_proc.process(chat_input)
chat_input_with_docs = augmented_requests[0]
chat_input_with_docs.model_dump()

Note that the retriever here operates over the last user turn. In this particular 
conversation, the last user turn is the phrase, "Cool, how do I sign up?", which is 
missing crucial information for retrieving relevant documents -- specifically, what is
the user attempting to sign up for? 

As a result, the snippets retrieved are not specific to the user's intended question.
Instead, they cover the general topic of signing up for things.

Let's see what happens if we run our request through the model using these low-quality 
document snippets.


In [None]:
# Generate with the base model, grounded on the snippets attached above
rag_result = io_proc.create_chat_completion(chat_input_with_docs)
first_rag_message = rag_result.results[0].next_message
display(Markdown(first_rag_message.content))

In this case, the model correctly refuses to answer the question.

Unfortunately, the training data for most LLMs is biased against 
producing this type of result, leading to frequent hallucinations in the presence of
faulty retrieved documents. For example, if the last user turn in our example 
conversation is "How do I sign up?", instead of "*Cool,* how do I sign up?", the model
produces an entirely different response:

In [None]:
# Change the last user turn from "Cool, how do I sign up?" to "How do I sign up?".
# ``list.copy()`` is shallow, so assigning to ``.content`` on the last element would
# also rewrite the message inside ``chat_input_with_docs``. If retrieval shares
# message objects, it would rewrite ``chat_input`` too. Later cells reuse both
# requests, so build a new message object instead of mutating the shared one.
messages_no_cool = chat_input_with_docs.messages.copy()
messages_no_cool[-1] = messages_no_cool[-1].model_copy(
    update={"content": "How do I sign up?"}
)
chat_input_no_cool = chat_input_with_docs.model_copy(
    update={"messages": messages_no_cool}
)
rag_result_no_cool = io_proc.create_chat_completion(chat_input_no_cool)
display(Markdown(rag_result_no_cool.results[0].next_message.content))

The [LoRA Adapter for Answerability Classification](
    https://huggingface.co/ibm-granite/granite-3.3-8b-rag-agent-lib/blob/main/answerability_prediction_lora/README.md
)
provides a more robust way to detect this kind of problem. Here's what happens if we 
run the chat completion request with faulty document snippets through the 
answerability model, using the
`granite_io` IO processor for the model to handle input and output:

In [None]:
# Ask the answerability LoRA whether the attached snippets can answer the last turn
answerability_proc = AnswerabilityIOProcessor(answerability_lora_backend)
answerability_result = answerability_proc.create_chat_completion(chat_input_with_docs)
answerability_result.results[0].next_message.content

The answerability model detects that the documents we have retrieved cannot be used to
answer the user's question. We can use a composite IO processor to wrap this check in
a flow that falls back on a canned response.

In [None]:
# Gate the base model behind the answerability check: unanswerable requests get a
# canned reply instead of a generated one.
answerability_composite_proc = AnswerabilityCompositeIOProcessor(
    io_proc, answerability_proc
)
composite_results = answerability_composite_proc.create_chat_completion(
    chat_input_with_docs
)
composite_result = composite_results.results[0]
print(composite_result.next_message.content)

At this point we've improved our model output from a hallucinated response to a refusal
to answer the question. This result is an improvement, but we can do better if we can
retrieve document snippets that are relevant to the user's intent as expressed in the
*entire* conversation, not just the last turn.

We can use the [LoRA Adapter for Query Rewrite](
    https://huggingface.co/ibm-granite/granite-3.2-8b-lora-rag-query-rewrite
) to rewrite
the last user turn into a string that is more useful for retrieving document snippets.
Here's what we get if we call this model directly on the original request:

In [None]:
# Ask the query rewrite LoRA for a retrieval-friendly version of the last user turn
rewrite_io_proc = QueryRewriteIOProcessor(query_rewrite_lora_backend)
rewrite_result = rewrite_io_proc.create_chat_completion(chat_input)
rewrite_result.results[0].next_message.content

Since the LoRA Adapter for Query Rewrite is a language model, we can ask it to generate
multiple different rewrites. We'll use this capability later on to improve end-to-end
result quality further.

In [None]:
# Sample 10 alternative rewrites at a higher temperature
rewrite_sampling_params = {"n": 10, "temperature": 0.8}
multiple_rewrites = rewrite_io_proc.create_chat_completion(
    chat_input.with_addl_generate_params(rewrite_sampling_params)
).results
[rewrite.next_message.content for rewrite in multiple_rewrites]

We can wrap the IO processor for this model in a request processor that rewrites
the last turn of the chat completion request.

In [None]:
# Request processor that produces requests whose last user turn has been rewritten
rewrite_request_proc = RewriteRequestProcessor(rewrite_io_proc)
rewritten_chat_input = rewrite_request_proc.process(chat_input)[0]
print("Messages after rewrite:")
[dict(role=msg.role, content=msg.content) for msg in rewritten_chat_input.messages]

We can fetch documents with the rewritten query, then use the answerability IO processor to check that the fetched documents answer the rewritten query.

In [None]:
# Retrieve using the rewritten query, then check answerability against that query
rewritten_chat_input_with_docs = retrieval_request_proc.process(rewritten_chat_input)[0]
rewritten_answerability = answerability_proc.create_chat_completion(
    rewritten_chat_input_with_docs
)
rewritten_answerability.results[0].next_message.content

We can also verify that the fetched documents answer the *original* query prior to the rewrite.

In [None]:
# Retrieve using the rewritten query, but judge answerability against the
# conversation exactly as the user wrote it.
rewritten_chat_input_with_docs = retrieval_request_proc.process(rewritten_chat_input)[0]
original_messages = chat_input.messages
chat_input_with_docs_from_rewrite = rewritten_chat_input_with_docs.model_copy(
    update={"messages": original_messages}
)
original_query_answerability = answerability_proc.create_chat_completion(
    chat_input_with_docs_from_rewrite
)
original_query_answerability.results[0].next_message.content

We can chain all of these request processors together with the IO processor for 
the answerability model to create a single flow that processes requests in multiple
steps:
1. Rewrite the last user message for retrieval
1. Retrieve documents and attach them to the request
1. Check for answerability with the retrieved documents
1. If the answerability check passes, then send the request to the base model


In [None]:
rewrite_request_proc = RewriteRequestProcessor(rewrite_io_proc)

# Steps 1-2: rewrite the last user turn, then retrieve documents with the rewrite
rewritten_request = rewrite_request_proc.process(chat_input)[0]
request = retrieval_request_proc.process(rewritten_request)[0]

# Keep the retrieved documents, but restore the user's original wording
request = request.model_copy(update={"messages": chat_input.messages})

# Steps 3-4: answerability check, then generation if the documents pass
response = answerability_composite_proc.create_chat_completion(request)
rag_rewrite_result = response.results[0]
display(Markdown(rag_rewrite_result.next_message.content))

Unlike the responses we've seen so far, this response provides information that is both
relevant to the user's intended question and grounded in documents retrieved from the 
corpus.

We can use the [LoRA Adapter for Citation Generation](https://huggingface.co/ibm-granite/granite-3.3-8b-rag-agent-lib/blob/main/citation_generation_lora/README.md
) to explain exactly how this response is grounded in the documents that the rewritten
user query retrieves.

In [None]:
# Check for answerability, generate a response, then add citations
citations_composite_proc = CitationsCompositeIOProcessor(
    answerability_composite_proc, citations_lora_backend
)
result_with_citations = citations_composite_proc.create_chat_completion(
    request
).results[0]

print("Assistant response:")
display(Markdown(result_with_citations.next_message.content))
print("Citations:")
pd.set_option("display.max_colwidth", 1500)
pd.DataFrame.from_records(
    [c.model_dump() for c in result_with_citations.next_message.citations]
)

We can also use the [LoRA Adapter for Hallucination Detection in RAG outputs](
    https://huggingface.co/ibm-granite/granite-3.3-8b-rag-agent-lib/blob/main/hallucination_detection_lora/README.md
) to further verify that each sentence of the assistant response is consistent with the
information in the retrieved documents.

In [None]:
# Run a hallucination check over the preceding response
hallucinations_composite_proc = HallucinationsCompositeIOProcessor(
    io_proc, hallucination_lora_backend
)
result_with_hallucinations = hallucinations_composite_proc.create_chat_completion(
    request
).results[0]
print("Hallucination Checks:")
display(
    pd.DataFrame.from_records(
        [h.model_dump() for h in result_with_hallucinations.next_message.hallucinations]
    )
)

We can wrap all of the functionality we've shown so far in a single class that 
inherits from the `InputOutputProcessor` interface in `granite-io`. Packaging things
this way lets applications treat this multi-step flow as if it were a single chat 
completion request to a base model.

In [None]:
from granite_io.io.base import InputOutputProcessor, RequestProcessor
from granite_io.backend import Backend
from granite_io.io.base import ChatCompletionInputs, ChatCompletionResults
import asyncio


class MyRAGCompositeIOProcessor(InputOutputProcessor):
    """
    End-to-end RAG flow packaged as a single ``InputOutputProcessor``.

    For each request the flow runs these steps:

    1. Score the request with the certainty IO processor. If the score is above
       ``certainty_threshold``, the base model answers directly, with no retrieval
       and therefore no citations or hallucination checks.
    2. Otherwise, rewrite the last user turn (the request's ``n`` is interpreted as
       the number of rewrite/retrieval/generation passes).
    3. For each rewrite, retrieve documents, then restore the original conversation.
    4. Run the chain answerability -> base model -> citations -> hallucinations on
       each request in parallel, and merge the first result of each pass.
    """

    def __init__(
        self,
        io_proc: InputOutputProcessor,
        rewrite_request_proc: RequestProcessor,
        retrieval_request_proc: RequestProcessor,
        answerability_proc: InputOutputProcessor,
        certainty_io_proc: InputOutputProcessor,
        citations_lora_backend: Backend,
        hallucination_lora_backend: Backend,
        certainty_threshold: float = 0.8,
    ):
        """
        :param io_proc: IO processor for the base model
        :param rewrite_request_proc: Request processor that rewrites the last user
            turn of a request
        :param retrieval_request_proc: Request processor that attaches retrieved
            documents to a request
        :param answerability_proc: IO processor for the answerability LoRA
        :param certainty_io_proc: IO processor for the certainty LoRA; its output
            message content must parse as a float
        :param citations_lora_backend: Backend serving the citations LoRA
        :param hallucination_lora_backend: Backend serving the hallucination LoRA
        :param certainty_threshold: Certainty scores strictly greater than this value
            skip retrieval and go straight to the base model. Defaults to 0.8.
        """
        self.io_proc = io_proc
        self.certainty_io_proc = certainty_io_proc
        self.certainty_threshold = certainty_threshold

        self.rewrite_request_proc = rewrite_request_proc
        self.retrieval_request_proc = retrieval_request_proc

        # Build up a chain of IO processors:
        # answerability -> base model -> citations -> hallucinations
        chain = AnswerabilityCompositeIOProcessor(io_proc, answerability_proc)
        chain = CitationsCompositeIOProcessor(chain, citations_lora_backend)
        chain = HallucinationsCompositeIOProcessor(chain, hallucination_lora_backend)
        self.io_proc_chain = chain

    async def acreate_chat_completion(
        self, inputs: ChatCompletionInputs
    ) -> ChatCompletionResults:
        """
        Chat completions API inherited from the ``InputOutputProcessor`` base class.

        :param inputs: Structured representation of the inputs to a chat completion
            request, possibly including additional fields that only this input-output
            processor can consume

        :returns: The next message that the model produces when fed the specified
            inputs, plus additional information about the low-level request. On the
            retrieval path there is one result per rewritten query.
        """
        # Start by checking whether retrieval is necessary to answer this query.
        # Always score with a single temperature-0 sample, whatever sampling
        # parameters the caller asked for.
        certainty_result = await self.certainty_io_proc.acreate_chat_completion(
            inputs.with_addl_generate_params({"n": 1, "temperature": 0.0})
        )
        certainty_score = float(certainty_result.results[0].next_message.content)
        if certainty_score > self.certainty_threshold:
            # The base model can respond with high certainty without RAG, so skip
            # retrieval and go directly to generation.
            return await self.io_proc.acreate_chat_completion(inputs)

        # The base model needs additional context. Perform query rewrite, then
        # retrieval-augmented generation. The "n" parameter in the request is
        # interpreted as "number of times to run retrieval and generation".
        # Use the processors given to the constructor, not same-named globals.
        rewritten_requests = await self.rewrite_request_proc.aprocess(inputs)

        # Do the remaining workflow steps once per rewritten query.
        coroutines = []
        for rewritten_request in rewritten_requests:
            # Retrieve documents for this rewritten version of the query.
            single_request = rewritten_request.with_addl_generate_params({"n": 1})
            request_with_docs = (
                await self.retrieval_request_proc.aprocess(single_request)
            )[0]

            # Keep the retrieved documents but switch back to the original last turn.
            restored_request = request_with_docs.with_messages(inputs.messages)

            # Answerability check, generation, citations and hallucination checks
            # run concurrently across the rewritten queries.
            coroutines.append(
                self.io_proc_chain.acreate_chat_completion(restored_request)
            )

        # Merge results from the parallel invocations.
        sub_results = await asyncio.gather(*coroutines)
        return ChatCompletionResults(
            results=[sub_result.results[0] for sub_result in sub_results]
        )

In [None]:
# Assemble the end-to-end flow from the components built earlier in this notebook
rag_io_proc = MyRAGCompositeIOProcessor(
    io_proc,
    rewrite_request_proc=rewrite_request_proc,
    retrieval_request_proc=retrieval_request_proc,
    answerability_proc=answerability_proc,
    certainty_io_proc=certainty_io_proc,
    citations_lora_backend=citations_lora_backend,
    hallucination_lora_backend=hallucination_lora_backend,
)

rag_results = rag_io_proc.create_chat_completion(chat_input)
rag_result = rag_results.results[0]

In [None]:
rag_message = rag_result.next_message

print("Assistant response:")
display(Markdown(rag_message.content))

print("Citations:")
pd.set_option("display.max_colwidth", 1500)
rag_citation_records = [citation.model_dump() for citation in rag_message.citations]
display(pd.DataFrame.from_records(rag_citation_records))

print("Hallucination Checks:")
rag_hallucination_records = [
    check.model_dump() for check in rag_message.hallucinations
]
display(pd.DataFrame.from_records(rag_hallucination_records))

Since the end-to-end flow is an `InputOutputProcessor`, we can generate multiple 
alternative responses by changing the generation parameters.

In [None]:
# Request 10 sampled responses from the end-to-end flow
rag_sampling_inputs = chat_input.with_addl_generate_params(
    {"n": 10, "temperature": 0.8}
)
many_rag_results = rag_io_proc.create_chat_completion(rag_sampling_inputs).results
many_rag_results

We can also use inference-time scaling techniques to rank these results and pick the 
"best" one. 
Here we use Minimum Bayes Risk (MBR) decoding to choose the result that agrees most
closely, on average, with all of the other results.

In [None]:
# Sample 10 candidates from the RAG flow and let MBR decoding choose one
voting_io_proc = MBRDMajorityVotingProcessor(rag_io_proc)
voting_inputs = chat_input.with_addl_generate_params({"n": 10, "temperature": 0.8})
mbrd_results = voting_io_proc.create_chat_completion(voting_inputs)
mbrd_result = mbrd_results.results[0]

In [None]:
mbrd_message = mbrd_result.next_message

print("Assistant response with Minimum Bayesian Risk decoding:")
display(Markdown(mbrd_message.content))

print("Citations:")
pd.set_option("display.max_colwidth", 1500)
mbrd_citation_records = [citation.model_dump() for citation in mbrd_message.citations]
display(pd.DataFrame.from_records(mbrd_citation_records))

print("Hallucination Checks:")
mbrd_hallucination_records = [
    check.model_dump() for check in mbrd_message.hallucinations
]
display(pd.DataFrame.from_records(mbrd_hallucination_records))

In [None]:
# Release the GPU held by the vLLM server, but only if this notebook started one
# (``server`` is only defined when run_server was True). At notebook top level,
# globals() is the same namespace that locals() returns.
if "server" in globals():
    server.shutdown()