# End-to-end demo with MTRAG benchmark data

This notebook shows several examples of end-to-end RAG use cases that use the retrieval
IO processor in conjunction with the IO processors for other Granite-based LoRA 
adapters.

This notebook can run its own vLLM server to perform inference, or you can host the 
models on your own server. To use your own server, set the `run_server` variable below
to `False` and set appropriate values for the constants in the cell marked
`# Constants go here`.
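When you point the notebook at your own server, the base model and the three LoRA adapters share one endpoint and one API key; only the model name changes between backends. The sketch below shows one way to build those per-model configurations in a loop. The helper `openai_backend_configs` is hypothetical (it is not part of `granite_io`); the key names and example values are the ones this notebook passes to `make_backend` in the server-setup cell.

```python
def openai_backend_configs(model_names, base_url, api_key):
    """Build one config dict per model name, all sharing one endpoint and key.

    Hypothetical helper for illustration only; not part of granite_io.
    """
    return {
        name: {
            "model_name": name,
            "openai_base_url": base_url,
            "openai_api_key": api_key,
        }
        for name in model_names
    }


configs = openai_backend_configs(
    [
        "ibm-granite/granite-3.2-8b-instruct",
        "ibm-granite/granite-3.2-8b-lora-rag-query-rewrite",
    ],
    base_url="http://localhost:55555/v1",
    api_key="granite_intrinsics_1234",
)
for name, config in configs.items():
    print(name, "->", config["openai_base_url"])
```

Each resulting dict has the same shape as the second argument that the setup cell passes to `make_backend("openai", ...)` by hand.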

In [None]:
import pathlib
from granite_io.io.granite_3_2.input_processors.granite_3_2_input_processor import (
    Granite3Point2Inputs,
)
from granite_io import make_io_processor, make_backend
from granite_io.io.base import RewriteRequestProcessor
from granite_io.io.retrieval.util import download_mtrag_embeddings
from granite_io.io.retrieval import InMemoryRetriever, RetrievalRequestProcessor
from granite_io.io.answerability import (
    AnswerabilityIOProcessor,
    AnswerabilityCompositeIOProcessor,
)
from granite_io.io.query_rewrite import QueryRewriteIOProcessor
from granite_io.io.citations import CitationsCompositeIOProcessor
from granite_io.backend.vllm_server import LocalVLLMServer
from IPython.display import display, Markdown
import pandas as pd
import os

In [None]:
# Constants go here
temp_data_dir = "../data/test_retrieval_temp"
corpus_name = "govt"
embeddings_data_file = pathlib.Path(temp_data_dir) / f"{corpus_name}_embeds.parquet"
embedding_model_name = "multi-qa-mpnet-base-dot-v1"
model_name = "ibm-granite/granite-3.2-8b-instruct"
query_rewrite_lora_name = "ibm-granite/granite-3.2-8b-lora-rag-query-rewrite"
citations_lora_name = "ibm-granite/granite-3.2-8b-lora-rag-citation-generation"
answerability_lora_name = "ibm-granite/granite-3.2-8b-lora-rag-answerability-prediction"

run_server = False

In [None]:
if run_server:
    # Start by firing up a local vLLM server and connecting a backend instance to it.
    server = LocalVLLMServer(
        model_name,
        lora_adapters=[
            (lora_name, lora_name)
            for lora_name in (
                query_rewrite_lora_name,
                citations_lora_name,
                answerability_lora_name,
            )
        ],
    )
    server.wait_for_startup(200)
    query_rewrite_lora_backend = server.make_lora_backend(query_rewrite_lora_name)
    citations_lora_backend = server.make_lora_backend(citations_lora_name)
    answerability_lora_backend = server.make_lora_backend(answerability_lora_name)
    backend = server.make_backend()
else:  # if not run_server
    # Use an existing server.
    # The constants here are for the server that local_vllm_server.ipynb starts.
    # Modify as needed.
    openai_base_url = "http://localhost:55555/v1"
    openai_api_key = "granite_intrinsics_1234"
    backend = make_backend(
        "openai",
        {
            "model_name": model_name,
            "openai_base_url": openai_base_url,
            "openai_api_key": openai_api_key,
        },
    )
    query_rewrite_lora_backend = make_backend(
        "openai",
        {
            "model_name": query_rewrite_lora_name,
            "openai_base_url": openai_base_url,
            "openai_api_key": openai_api_key,
        },
    )
    citations_lora_backend = make_backend(
        "openai",
        {
            "model_name": citations_lora_name,
            "openai_base_url": openai_base_url,
            "openai_api_key": openai_api_key,
        },
    )
    answerability_lora_backend = make_backend(
        "openai",
        {
            "model_name": answerability_lora_name,
            "openai_base_url": openai_base_url,
            "openai_api_key": openai_api_key,
        },
    )

In [None]:
# Download the indexed corpus if it hasn't already been downloaded.
# This notebook uses a subset of the government corpus from the MTRAG benchmark.
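# Reuse the path defined with the constants so that the download location and the
# file the retriever loads later cannot drift apart.
embeddings_location = str(embeddings_data_file)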
if not os.path.exists(embeddings_location):
    download_mtrag_embeddings(embedding_model_name, corpus_name, embeddings_location)
embeddings_location

In [None]:
# Spin up an IO processor for the base model
io_proc = make_io_processor(model_name, backend=backend)
io_proc

We start by creating an example chat completion request. The user is chatting with an
agent for the California Appellate Courts help desk and is currently asking for
the address of the law library at the Solano County Hall of Justice in Fairfield, CA.

In [None]:
# Create an example chat completions request
chat_input = Granite3Point2Inputs.model_validate(
    {
        "messages": [
            {
                "role": "assistant",
                "content": "Welcome to the California Appellate Courts help desk.",
            },
            {
                "role": "user",
                "content": "I have a court appearance next Tuesday at the Solano "
                "County Hall of Justice. Can I use the law library to prepare?",
            },
            {
                "role": "assistant",
                "content": "Yes, you can visit the law library. Law libraries are a "
                "great resource for self-representation or learning about the law.",
            },
            {"role": "user", "content": "What's its address?"},
        ],
        "generate_inputs": {
            "temperature": 0.0,
            "max_tokens": 4096,
        },
    }
)
chat_input

Let's start by passing the chat completion request directly to the language model,
without using retrieval-augmented generation.

In [None]:
non_rag_result = io_proc.create_chat_completion(chat_input)
display(Markdown(non_rag_result.results[0].next_message.content))

The model's output contains a hallucinated address. 
See [this website](https://solanolibrary.com/hours-and-locations/law-library/) for 
the actual location of the Solano County courthouse's law library.

Now let's spin up an in-memory vector database, using embeddings that we've precomputed
offline from the MTRAG benchmark's government corpus.
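Conceptually, a retriever of this kind scores the query embedding against every stored snippet embedding and keeps the best matches. The sketch below illustrates that idea with toy 3-dimensional vectors. It is not the `InMemoryRetriever` implementation: the function `top_k_by_dot_product`, the snippet ids, and the vectors are made up for illustration, and scoring by dot product is an assumption based on the embedding model's name (`multi-qa-mpnet-base-dot-v1`).

```python
import heapq


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k_by_dot_product(query_vec, corpus, k):
    """Return the k (score, snippet_id) pairs with the highest dot product."""
    scored = ((dot(query_vec, vec), snippet_id) for snippet_id, vec in corpus.items())
    return heapq.nlargest(k, scored)


# Toy "embeddings"; real ones have hundreds of dimensions.
toy_corpus = {
    "solano_law_library": [0.9, 0.1, 0.0],
    "jury_duty_faq": [0.1, 0.8, 0.1],
    "traffic_fines": [0.0, 0.2, 0.9],
}
toy_query = [1.0, 0.0, 0.0]

top_hits = top_k_by_dot_product(toy_query, toy_corpus, k=2)
for score, snippet_id in top_hits:
    print(snippet_id, score)
```

The quality of the result depends entirely on the query vector, which is why the wording of the query text matters so much in the steps that follow.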

In [None]:
# Spin up an in-memory vector database
retriever = InMemoryRetriever(embeddings_data_file, embedding_model_name)

We attach a RequestProcessor to our vector database so that we can augment the chat 
completion request with retrieved document snippets.

In [None]:
retrieval_request_proc = RetrievalRequestProcessor(retriever, top_k=3)
rag_chat_input = retrieval_request_proc.process(chat_input)[0]
rag_chat_input.documents

The retriever operates over the last user turn, and this turn does not mention 
Fairfield or Solano County, so the snippets retrieved will not be specific to the 
Solano County courthouse.
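To make the problem concrete, here is a small sketch that extracts the last user turn from the example conversation and checks whether it names the courthouse. The helper `last_user_turn` is hypothetical and only mimics the behavior described above; it is not taken from `granite_io`.

```python
def last_user_turn(messages):
    """Return the content of the most recent message whose role is 'user'."""
    for message in reversed(messages):
        if message["role"] == "user":
            return message["content"]
    raise ValueError("conversation contains no user turn")


messages = [
    {
        "role": "assistant",
        "content": "Welcome to the California Appellate Courts help desk.",
    },
    {
        "role": "user",
        "content": "I have a court appearance next Tuesday at the Solano "
        "County Hall of Justice. Can I use the law library to prepare?",
    },
    {
        "role": "assistant",
        "content": "Yes, you can visit the law library. Law libraries are a "
        "great resource for self-representation or learning about the law.",
    },
    {"role": "user", "content": "What's its address?"},
]

query = last_user_turn(messages)
print(query)
print("mentions Solano:", "Solano" in query)
```

The query that reaches the retriever contains only the pronoun "its", so nothing in it ties the search to Solano County.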

Let's see what happens if we run our request through the model using the low-quality 
RAG snippets from the previous cell.


In [None]:
rag_result = io_proc.create_chat_completion(rag_chat_input)
display(Markdown(rag_result.results[0].next_message.content))

This address is also incorrect due to the low-quality retrieved snippets.

We can use the [LoRA Adapter for Answerability Classification](
    https://huggingface.co/ibm-granite/granite-3.2-8b-lora-rag-answerability-prediction)
to detect this kind of problem. Here's what happens if we run the chat completion
request with faulty document snippets through the answerability model, using the
`granite_io` IO processor for the model to handle input and output:

In [None]:
answerability_proc = AnswerabilityIOProcessor(answerability_lora_backend)
answerability_result = answerability_proc.create_chat_completion(rag_chat_input)
answerability_result.results[0].next_message.content

The answerability model detects that the documents we have retrieved cannot be used to
answer the user's question. We can use a composite IO processor to wrap this check in
a flow that falls back on a canned response.
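The control flow behind such a composite can be sketched in a few lines: ask the answerability check first, and call the generator only if the check passes. Everything in this sketch is an illustrative stand-in rather than the `AnswerabilityCompositeIOProcessor` implementation. The function names, the dict-based request, the toy check, and the canned text are all assumptions.

```python
# Hypothetical fallback text; the real processor's canned response may differ.
CANNED_RESPONSE = "Sorry, I don't have the information needed to answer that."


def answer_with_gate(request, is_answerable, generate, fallback=CANNED_RESPONSE):
    """Call `generate` only when `is_answerable` approves the request."""
    if is_answerable(request):
        return generate(request)
    return fallback


def toy_is_answerable(request):
    """Toy stand-in for the answerability model: looks for a keyword."""
    return any("Solano" in doc for doc in request["documents"])


def toy_generate(request):
    """Toy stand-in for the base model: echoes the first document."""
    return "Answer based on: " + request["documents"][0]


bad_request = {"question": "What's its address?", "documents": ["Jury duty FAQ"]}
good_request = {
    "question": "What's its address?",
    "documents": ["Solano County law library page"],
}

print(answer_with_gate(bad_request, toy_is_answerable, toy_generate))
print(answer_with_gate(good_request, toy_is_answerable, toy_generate))
```

Gating before generation means the base model is never asked to answer from snippets that cannot support an answer.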

In [None]:
answerability_composite_proc = AnswerabilityCompositeIOProcessor(
    io_proc, answerability_proc
)
composite_result = answerability_composite_proc.create_chat_completion(
    rag_chat_input
).results[0]
composite_result.next_message.content

We can use the [LoRA Adapter for Query Rewrite](
    https://huggingface.co/ibm-granite/granite-3.2-8b-lora-rag-query-rewrite) to rewrite
the last user turn into a string that is more useful for retrieving document snippets.
Here's what we get if we call this model directly on the original request:

In [None]:
rewrite_io_proc = QueryRewriteIOProcessor(query_rewrite_lora_backend)
rewrite_io_proc.create_chat_completion(chat_input).results[0].next_message.content

We can wrap the IO processor for this model in a request processor that rewrites
the last turn of the chat completion request, then chain that request processor to
the request processor for retrieval that we used earlier:
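The chaining pattern itself is simple: each processor's `process` method returns a list of requests, and the first element is fed to the next processor. The sketch below imitates that pattern with toy classes. The class names, the dict-based request, the word-overlap "retrieval", and the rewritten query string are all invented for illustration and do not reflect how `granite_io` implements these steps.

```python
import copy


class ToyRewriteProcessor:
    """Replaces the last turn with a fixed, self-contained query."""

    def __init__(self, rewritten_query):
        self.rewritten_query = rewritten_query

    def process(self, request):
        new_request = copy.deepcopy(request)
        new_request["messages"][-1]["content"] = self.rewritten_query
        return [new_request]


class ToyRetrievalProcessor:
    """Attaches every corpus snippet that shares a word with the last turn."""

    def __init__(self, corpus):
        self.corpus = corpus

    def process(self, request):
        query_words = set(request["messages"][-1]["content"].lower().split())
        new_request = copy.deepcopy(request)
        new_request["documents"] = [
            doc for doc in self.corpus if query_words & set(doc.lower().split())
        ]
        return [new_request]


def run_chain(request, processors):
    """Feed the first output of each processor into the next one."""
    for processor in processors:
        request = processor.process(request)[0]
    return request


corpus = [
    "Solano County law library hours and location",
    "Jury duty frequently asked questions",
]
toy_request = {"messages": [{"role": "user", "content": "What's its address?"}]}
rewrite = ToyRewriteProcessor("What is the address of the Solano County law library?")
retrieve = ToyRetrievalProcessor(corpus)

direct = run_chain(toy_request, [retrieve])
chained = run_chain(toy_request, [rewrite, retrieve])
print("without rewrite:", direct["documents"])
print("with rewrite:   ", chained["documents"])
```

Because every step takes a request and returns requests, further processors can be added to the chain without changing the ones already in it.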

In [None]:
rewrite_request_proc = RewriteRequestProcessor(rewrite_io_proc)

request = rewrite_request_proc.process(chat_input)[0]
request = retrieval_request_proc.process(request)[0]
rag_rewrite_result = io_proc.create_chat_completion(request).results[0]
display(Markdown(rag_rewrite_result.next_message.content))

Finally we get the correct address!

We can also augment the above result with citations back to the supporting documents
by using the [LoRA Adapter for Citation Generation](
    https://huggingface.co/ibm-granite/granite-3.2-8b-lora-rag-citation-generation
).

In [None]:
citations_composite_proc = CitationsCompositeIOProcessor(
    io_proc, citations_lora_backend
)
result_with_citations = citations_composite_proc.create_chat_completion(
    request
).results[0]

print("Assistant response:")
display(Markdown(result_with_citations.next_message.content))
print("Citations:")
pd.set_option("display.max_colwidth", 200)
pd.DataFrame.from_records(
    [c.model_dump() for c in result_with_citations.next_message.citations]
)