# End-to-end demo with MTRAG benchmark data

This notebook shows several examples of end-to-end RAG use cases that use the retrieval IO processor in conjunction with the IO processors for other Granite-based LoRA adapters. More information about the models used here can be found in our [technical report](https://arxiv.org/html/2504.11704).

This notebook requires a hosted Elasticsearch server for retrieval and a hosted vLLM server to perform inference. Change these variables in the `Constants` cell below.

```
elasticsearch_host = "https://localhost:32765"
openai_base_url = "http://localhost:55555/v1"
```

In [None]:
# Imports go in this cell
import pathlib
import os
import json
import openai

from IPython.display import display, Markdown

import granite_common
from granite_common.base.types import (
    AssistantMessage,
    ChatCompletion,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    UserMessage,
    VLLMExtraBody,
)

from granite_common.retrievers.util import download_mtrag_embeddings
from granite_common.retrievers import (
    ElasticsearchRetriever,
    InMemoryRetriever,
    Retriever,
)

In [None]:
# Notebook-wide configuration. Edit the values in this cell to point the demo
# at your own inference server, retriever backend, and corpus.

# Short corpus name -> name of the corresponding Elasticsearch corpus.
CORPUS_NAMES_MAPPINGS = dict(
    banking="mt-rag-banking-elser-512-100-20250205",
    clapnq="mt-rag-clapnq-elser-512-100-20240503",
    fiqa="mt-rag-fiqa-beir-elser-512-100-20240501",
    govt="mt-rag-govt-elser-512-100-20240611",
    ibmcloud="mt-rag-ibmcloud-elser-512-100-20240502",
    scifact="mt-rag-scifact-beir-elser-512-100-20240501",
    telco="mt-rag-telco-elser-512-100-20241210",
)

# Reply shown when the retrieved documents cannot answer the question.
DEFAULT_CANNED_RESPONSE = (
    "Sorry, but I am unable to answer this question from the documents retrieved."
)

# Inference server reached through the OpenAI client, and the model names.
openai_base_url = "http://localhost:55555/v1"
openai_api_key = "rag_intrinsics_1234"

target_model_name = "granite-3.3-8b-instruct"
base_model_name = f"ibm-granite/{target_model_name}"

# Intrinsics whose IO configurations are loaded later in the notebook.
intrinsic_names = [
    "citations",
    "query_rewrite",
    "answerability",
    "hallucination_detection",
    "uncertainty",
]

# Retrieval backend; supported values are "elasticsearch" and "embeddings".
retriever_name = "embeddings"
corpus_name = "govt"

if retriever_name == "embeddings":
    # Embeddings retriever: works from a local parquet file of embeddings.
    temp_data_dir = "../data/test_retrieval_temp"
    embedding_model_name = "multi-qa-mpnet-base-dot-v1"
    embeddings_data_file = pathlib.Path(temp_data_dir) / f"{corpus_name}_embeds.parquet"

    # Fetch the indexed corpus only on the first run.
    # This notebook uses a subset of the government corpus from the MTRAG benchmark.
    embeddings_location = f"{temp_data_dir}/{corpus_name}_embeds.parquet"
    already_downloaded = os.path.exists(embeddings_location)
    if not already_downloaded:
        download_mtrag_embeddings(
            embedding_model_name, corpus_name, embeddings_location
        )
elif retriever_name == "elasticsearch":
    # Elasticsearch retriever: only the server address is needed here.
    elasticsearch_host = "https://localhost:32765"

In [None]:
# Intrinsics
# For every configured intrinsic, load its IO configuration file and build the
# rewriter / result-processor pair from it.

# Connect to the inference server
client = openai.OpenAI(base_url=openai_base_url, api_key=openai_api_key)

intrinsic_rewriters = {}
intrinsic_result_processors = {}
for intrinsic_name in intrinsic_names:
    # Both objects for one intrinsic are created from the same config file.
    io_yaml_file = granite_common.intrinsics.util.obtain_io_yaml(
        intrinsic_name, base_model_name
    )
    intrinsic_rewriters[intrinsic_name] = granite_common.IntrinsicsRewriter(
        config_file=io_yaml_file
    )
    intrinsic_result_processors[intrinsic_name] = (
        granite_common.IntrinsicsResultProcessor(config_file=io_yaml_file)
    )

In [None]:
# Retriever

if retriever_name == "elasticsearch":
    # Connect to the Elasticsearch server.
    # Due to the setup, we have to open a retriever connection for each corpus.
    #
    # The comprehension keeps its loop variables local, so the `corpus_name`
    # chosen in the constants cell is not overwritten while iterating over the
    # corpus mapping.
    #
    # NOTE(review): certificate verification and TLS warnings are disabled
    # below -- presumably for a local demo server; confirm before pointing
    # this at a real deployment.
    retrievers = {
        short_name: ElasticsearchRetriever(
            corpus_name=index_name,
            host=elasticsearch_host,
            verify_certs=False,
            ssl_show_warn=False,
        )
        for short_name, index_name in CORPUS_NAMES_MAPPINGS.items()
    }
elif retriever_name == "embeddings":
    # Nothing to connect to here; the in-memory retriever is constructed in a
    # later cell from the embeddings file.
    pass

In [None]:
# Functions


def call_intrinsic(
    intrinsic_name: str,
    chat_completion_request: "ChatCompletion | dict",
    **kwargs,
) -> openai.types.chat.ChatCompletion:
    """
    Call an intrinsic and return its processed output as an OpenAI Python API object.

    The caller's request is never modified: all rewriting happens on a private
    deep copy.

    :param intrinsic_name: Name of intrinsic to invoke; must be a key of
        ``intrinsic_rewriters`` and ``intrinsic_result_processors``
    :param chat_completion_request: Chat completion request to make; either a
        ``ChatCompletion`` model or a dict that validates as one
    :param kwargs: Optional named argument(s) for intrinsic, forwarded to the
        rewriter

    :returns: OpenAI Python API chat completion containing processed intrinsic outputs
    :raises KeyError: if ``intrinsic_name`` has not been loaded
    """
    # Accept plain dicts by converting them into the request model first.
    request = chat_completion_request
    if isinstance(request, dict):
        request = ChatCompletion.model_validate(request)

    # Some intrinsics modify the chat object, so work on a deep copy.
    working_request = request.model_copy(deep=True)

    rewriter = intrinsic_rewriters[intrinsic_name]
    result_processor = intrinsic_result_processors[intrinsic_name]
    rewritten_request = rewriter.transform(working_request, **kwargs)

    # Set model name manually for now, because vLLM does not maintain any kind of
    # metadata that would allow us to determine the right model name.
    rewritten_request.model = intrinsic_name

    response = client.chat.completions.create(**rewritten_request.model_dump())
    transformed_response = result_processor.transform(response, rewritten_request)

    # Convert to same type as OpenAI API
    return openai.types.chat.ChatCompletion.model_validate(
        transformed_response.model_dump()
    )


def retrieve_snippets(retriever: Retriever, query: str, top_k: int = 3):
    """
    Run a query against a retriever.

    :param retriever: Retriever to query
    :param query: Query string handed to the retriever
    :param top_k: Passed through to the retriever as its ``top_k`` argument

    :returns: Whatever ``retriever.retrieve`` returns for the query
    """
    snippets = retriever.retrieve(query, top_k=top_k)
    return snippets

We start by creating an example chat completion request. 

This chat completion request simulates a scenario where the user is chatting with the automated help desk agent of the California State Parks and is asking about internship opportunities. The agent is about to respond to the user's question, "Cool, how do I sign up?"

In [None]:
# Build the example request: a four-turn help desk conversation that ends with
# the user's question. No documents are attached at this point.
chat_input = ChatCompletion.model_validate(
    dict(
        messages=[
            dict(
                role="assistant",
                content="Welcome to the California State Parks help desk.",
            ),
            dict(
                role="user",
                content="I'm a student. Do you have internships?",
            ),
            dict(
                role="assistant",
                content=(
                    "The California State Parks hires Student Assistants to perform "
                    "a variety of tasks that require limited or no previous work "
                    "experience."
                ),
            ),
            dict(role="user", content="Cool, how do I sign up?"),
        ],
        temperature=0.0,
        max_tokens=4096,
    )
)
# Show the request as formatted JSON.
print(chat_input.model_dump_json(indent=2))

Let's start by passing the chat completion request directly to the language model, without using retrieval-augmented generation.

In [None]:
# Baseline without retrieval: address the request to the base model and
# render its answer.
chat_input.model = base_model_name
non_rag_request = chat_input.model_dump()
non_rag_completion = client.chat.completions.create(**non_rag_request)

non_rag_answer = non_rag_completion.choices[0].message.content
display(Markdown(non_rag_answer))

This result is a hallucination. The actual correct answer can be found [here](
    https://www.parks.ca.gov/?page_id=848
).

We can use the 
[Uncertainty LoRA](
    https://huggingface.co/generative-computing/core-intrinsics-lib/blob/main/uncertainty/README.md)
adapter to flag cases such as this one that are not covered by the base model's 
training data.

In [None]:
# Ask the uncertainty intrinsic about the request; the intrinsic's message
# content is JSON with a "certainty" field.
response = call_intrinsic("uncertainty", chat_input)
uncertainty_payload = json.loads(response.choices[0].message.content)
certainty_score = round(uncertainty_payload["certainty"], 2)

print(f"Certainty score is {certainty_score} out of 1.0")

The low certainty score indicates that the model's training data does not align closely with this question.

To answer this question properly, we need to provide the model with domain-specific information. In this case, the relevant information can be found in the Government corpus of the [MTRAG (multi-turn RAG) benchmark](https://github.com/IBM/mt-rag-benchmark).

In [None]:
# Pick the retriever used for the rest of the notebook.
if retriever_name == "elasticsearch":
    # Reuse the connection opened earlier for the configured corpus.
    retriever = retrievers[corpus_name]
elif retriever_name == "embeddings":
    # Build an in-memory retriever from the embeddings file and model name.
    retriever = InMemoryRetriever(embeddings_data_file, embedding_model_name)
else:
    # Fail here with a clear message instead of leaving `retriever` undefined.
    raise ValueError(
        f"Unsupported retriever_name {retriever_name!r}; "
        "expected 'elasticsearch' or 'embeddings'."
    )

We can send string queries against this vector database to retrieve relevant documents. Here we query the database with the user's last turn from our example conversation, "Cool, how do I sign up?"

In [None]:
# Use the final message of the conversation above as the retrieval query and
# fetch the snippets that match it.
last_turn = chat_input.messages[-1]
query = last_turn.content
print(f"Query is: '{query}'")
print("Matching document snippets:")
documents = retrieve_snippets(retriever, query, top_k=3)
documents

We augment the original chat completion request with the document snippets that the retriever fetches when fed the last user turn as a query.

In [None]:
# Work on a deep copy so the original request stays document-free, then attach
# the retrieved snippets through the request's extra body.
retrieved_docs_body = VLLMExtraBody(documents=documents)
chat_input_with_docs = chat_input.model_copy(deep=True)
chat_input_with_docs.extra_body = retrieved_docs_body
chat_input_with_docs.model_dump()

Note that the retriever here operates over the last user turn. In this particular conversation, the last user turn is the phrase, "Cool, how do I sign up?", which is missing crucial information for retrieving relevant documents -- specifically, what is the user attempting to sign up for?

As a result, the snippets retrieved are not specific to the user's intended question. Instead, they cover the general topic of signing up for things.

Let's see what happens if we run our request through the model using these low-quality document snippets.

In [None]:
# Generate an answer from the request that carries the retrieved snippets and
# render it.
rag_request = chat_input_with_docs.model_dump()
rag_completion = client.chat.completions.create(**rag_request)
rag_answer = rag_completion.choices[0].message.content
display(Markdown(rag_answer))

In this case, the model correctly refuses to answer the question.

Unfortunately, the training data for most LLMs is biased against 
producing this type of result, leading to frequent hallucinations in the presence of faulty retrieved documents. For example, if the last user turn in our example conversation is "How do I sign up?", instead of "*Cool,* how do I sign up?", the model produces an entirely different response:

In [None]:
# Change the last user turn from "Cool, how do I sign up?" to "How do I sign up?"
#
# The request is deep-copied before editing. A shallow copy of the message list
# would still share the message objects, so editing the last turn would also
# change `chat_input_with_docs`, which later cells reuse.
chat_input_no_cool = chat_input_with_docs.model_copy(deep=True)
messages_no_cool = chat_input_no_cool.messages
messages_no_cool[-1].content = "How do I sign up?"

rag_result_no_cool = client.chat.completions.create(**chat_input_no_cool.model_dump())
display(Markdown(rag_result_no_cool.choices[0].message.content))

The [LoRA Adapter for Answerability Classification](
    https://huggingface.co/generative-computing/rag-intrinsics-lib/blob/main/answerability/README.md
)
provides a more robust way to detect this kind of problem. Here's what happens if we run the chat completion request with faulty document snippets through the answerability model:

In [None]:
# Run the answerability intrinsic on the request with documents; its message
# content is JSON with an "answerability_likelihood" field.
response = call_intrinsic("answerability", chat_input_with_docs)
answerability_payload = json.loads(response.choices[0].message.content)
answerability_likelihood = answerability_payload["answerability_likelihood"]
answerability_likelihood

The answerability model detects that the documents we have retrieved cannot be used to answer the user's question. We can wrap this check in a flow that falls back on a canned response.

In [None]:
# Generate with the model only when the likelihood reaches 0.5; otherwise
# build a response object holding the canned reply.
is_answerable = answerability_likelihood >= 0.5
if is_answerable:
    rag_completion = client.chat.completions.create(**chat_input_with_docs.model_dump())
else:
    canned_message = AssistantMessage(content=DEFAULT_CANNED_RESPONSE)
    canned_choice = ChatCompletionResponseChoice(index=0, message=canned_message)
    rag_completion = ChatCompletionResponse(choices=[canned_choice])
display(Markdown(rag_completion.choices[0].message.content))

At this point we've improved our model output from a hallucinated response to a refusal to answer the question. This result is an improvement, but we can do better if we can retrieve document snippets that are relevant to the user's intent as expressed in the *entire* conversation, not just the last turn.

We can use the [LoRA Adapter for Query Rewrite](
    https://huggingface.co/generative-computing/rag-intrinsics-lib/blob/main/query_rewrite/README.md
) to rewrite the last user turn into a string that is more useful for retrieving document snippets.
Here's what we get if we call this model directly on the original request:

In [None]:
# Run the query-rewrite intrinsic; its message content is JSON with a
# "rewritten_question" field.
response = call_intrinsic("query_rewrite", chat_input_with_docs)
rewrite_payload = json.loads(response.choices[0].message.content)
rewritten_question = rewrite_payload["rewritten_question"]
rewritten_question

Since the LoRA Adapter for Query Rewrite is a language model, we can ask it to generate multiple different rewrites. We'll use this capability later on to improve end-to-end result quality further.

In [None]:
# Request ten completions at a higher temperature so the rewrites vary.
chat_input_10_rewrites = chat_input.model_copy(deep=True)
chat_input_10_rewrites.temperature = 0.8
chat_input_10_rewrites.n = 10

response = call_intrinsic("query_rewrite", chat_input_10_rewrites)
# Print the rewritten question from each returned choice.
for choice in response.choices:
    rewrite_payload = json.loads(choice.message.content)
    rewritten_question = rewrite_payload["rewritten_question"]
    print(rewritten_question)

We can wrap the IO processor for this model in a request processor that rewrites
the last turn of the chat completion request.

In [None]:
# Obtain a rewrite of the last user turn.
response = call_intrinsic("query_rewrite", chat_input)
rewrite_payload = json.loads(response.choices[0].message.content)
rewritten_question = rewrite_payload["rewritten_question"]

# Replace the last turn with the rewrite on a deep copy of the request.
rewritten_chat_input = chat_input.model_copy(deep=True)
rewritten_chat_input.messages[-1] = UserMessage(content=rewritten_question)

print("Messages after rewrite:")
[dict(role=m.role, content=m.content) for m in rewritten_chat_input.messages]

We can fetch documents with the rewritten query, then use the answerability LoRA to check that the fetched documents answer the rewritten query.

In [None]:
# Retrieve with the rewritten last turn as the query.
rewritten_query = rewritten_chat_input.messages[-1].content
documents = retrieve_snippets(retriever, rewritten_query, top_k=3)

# Attach the new snippets to a deep copy of the rewritten request.
rewritten_chat_input_with_docs = rewritten_chat_input.model_copy(deep=True)
rewritten_chat_input_with_docs.extra_body = VLLMExtraBody(documents=documents)

# Score the answerability of the rewritten request with its documents.
response = call_intrinsic("answerability", rewritten_chat_input_with_docs)
answerability_payload = json.loads(response.choices[0].message.content)
answerability_likelihood = answerability_payload["answerability_likelihood"]
answerability_likelihood

We can also verify that the fetched documents answer the *original* query prior to the rewrite.

In [None]:
# Pair the request built before the rewrite with the documents that the
# rewritten query retrieved, then score answerability again.
docs_from_rewrite = rewritten_chat_input_with_docs.extra_body.documents
chat_input_with_docs_from_rewrite = chat_input_with_docs.model_copy(deep=True)
chat_input_with_docs_from_rewrite.extra_body.documents = docs_from_rewrite

response = call_intrinsic("answerability", chat_input_with_docs_from_rewrite)
answerability_payload = json.loads(response.choices[0].message.content)
answerability_likelihood = answerability_payload["answerability_likelihood"]
answerability_likelihood

We can chain all of these request processors together with the IO processor for 
the answerability model to create a single flow that processes requests in multiple
steps:
1. Rewrite the last user message for retrieval
1. Retrieve documents and attach them to the request
1. Check for answerability with the retrieved documents
1. If the answerability check passes, then send the request to the base model


In [None]:
# End-to-end flow: rewrite -> retrieve -> answerability gate -> generate.

# 1. Rewrite
# Ask the query-rewrite intrinsic for a rewrite of the last user turn.
response = call_intrinsic("query_rewrite", chat_input)
rewritten_question = json.loads(response.choices[0].message.content)[
    "rewritten_question"
]

# Substitute the rewrite for the last turn on a deep copy of the request.
new_chat_input = chat_input.model_copy(deep=True)
new_chat_input.messages[-1] = UserMessage(content=rewritten_question)

# 2. Retrieve
# Query the retriever with the rewritten turn and attach the results.
documents = retrieve_snippets(
    retriever,
    new_chat_input.messages[-1].content,
    top_k=3,
)
new_chat_input.extra_body = VLLMExtraBody(documents=documents)

# 3. Answerability
response = call_intrinsic("answerability", new_chat_input)
answerability_likelihood = json.loads(response.choices[0].message.content)[
    "answerability_likelihood"
]

# 4. Answerable -> base model to generate
# Otherwise fall back to DEFAULT_CANNED_RESPONSE as set in the constants cell;
# it is deliberately not redefined here, so a customized value stays in effect.
if answerability_likelihood >= 0.5:
    rag_completion = client.chat.completions.create(**new_chat_input.model_dump())
else:
    rag_completion = ChatCompletionResponse(
        choices=[
            ChatCompletionResponseChoice(
                index=0, message=AssistantMessage(content=DEFAULT_CANNED_RESPONSE)
            )
        ]
    )
display(Markdown(rag_completion.choices[0].message.content))

Unlike the responses we've seen so far, this response provides information that is both relevant to the user's intended question and grounded in documents retrieved from the corpus.

We can use the [LoRA Adapter for Citation Generation](https://huggingface.co/generative-computing/rag-intrinsics-lib/blob/main/citations/README.md
) to explain exactly how this response is grounded in the documents that the rewritten user query retrieves.

In [None]:
# Generate an answer with the base model from a deep copy of the request.
chat_input_citations = new_chat_input.model_copy(deep=True)
chat_input_citations.model = base_model_name

citations_request = chat_input_citations.model_dump()
chat_completion = client.chat.completions.create(**citations_request)

# Append the generated message to the conversation before calling the
# citations intrinsic on it.
generated_message = chat_completion.choices[0].message
chat_input_citations.messages.append(generated_message)

response = call_intrinsic("citations", chat_input_citations)
citations = json.loads(response.choices[0].message.content)

print("Assistant response:")
display(Markdown(chat_input_citations.messages[-1].content))
print("Citations:")
print(json.dumps(citations, indent=2))

We can also use the [LoRA Adapter for Hallucination Detection](
    https://huggingface.co/generative-computing/rag-intrinsics-lib/blob/main/hallucination_detection/README.md
) to further verify that each sentence of the assistant response is consistent with the information in the retrieved documents.

In [None]:
# Generate an answer with the base model from a deep copy of the request.
chat_input_hallucinations = new_chat_input.model_copy(deep=True)
chat_input_hallucinations.model = base_model_name

hallucinations_request = chat_input_hallucinations.model_dump()
chat_completion = client.chat.completions.create(**hallucinations_request)

# Append the generated message to the conversation before calling the
# hallucination-detection intrinsic on it.
generated_message = chat_completion.choices[0].message
chat_input_hallucinations.messages.append(generated_message)

response = call_intrinsic("hallucination_detection", chat_input_hallucinations)
hallucinations = json.loads(response.choices[0].message.content)

print("Assistant response:")
display(Markdown(chat_input_hallucinations.messages[-1].content))
print("Hallucination Checks:")
print(json.dumps(hallucinations, indent=2))

## TODO: Composite IO Processor class