# Demo of retrieval IO processor with MTRAG benchmark data

This notebook shows how to use the retrieval IO processor to implement the retrieval 
phase of Retrieval-Augmented Generation (RAG) on top of Granite 3.3.

This notebook can either start its own vLLM server to perform inference or connect
to a model that you host on your own server.

To use your own server, set the `run_server` variable in the cell marked
`# Constants go here` to `False`, then set `openai_base_url` and `openai_api_key`
in the cell that creates the backend so that they point at your server.

In [None]:
import pathlib
from granite_io.io.granite_3_3.input_processors.granite_3_3_input_processor import (
    Granite3Point3Inputs,
)
from granite_io import make_io_processor, make_backend
from granite_io.io.retrieval import InMemoryRetriever, RetrievalRequestProcessor
from granite_io.io.retrieval.util import download_mtrag_embeddings
from granite_io.backend.vllm_server import LocalVLLMServer
from IPython.display import display, Markdown
import pandas as pd
import os

In [None]:
# Constants go here
# Scratch directory for downloaded retrieval data.
temp_data_dir = "../data/test_retrieval_temp"
# MTRAG corpus whose precomputed embeddings this notebook uses.
corpus_name = "govt"
# Local path of the corpus embeddings. The file name uses the "_embeds" suffix so
# that it matches the path the download cell writes to and the retriever reads.
embeddings_data_file = pathlib.Path(temp_data_dir) / f"{corpus_name}_embeds.parquet"
# Embedding model name; passed to both the embeddings download and the retriever.
embedding_model_name = "multi-qa-mpnet-base-dot-v1"
# Generation model served by vLLM (or by your own OpenAI-compatible server).
model_name = "ibm-granite/granite-3.3-8b-instruct"

# True: start a local vLLM server from this notebook.
# False: connect to an existing server (see the backend cell).
run_server = True

In [None]:
if not run_server:
    # Connect to a vLLM server that is already running elsewhere.
    # These defaults match the server that local_vllm_server.ipynb starts;
    # edit them to point at your own deployment.
    openai_base_url = "http://localhost:55555/v1"
    openai_api_key = "granite_intrinsics_1234"
    backend = make_backend(
        "openai",
        dict(
            model_name=model_name,
            openai_base_url=openai_base_url,
            openai_api_key=openai_api_key,
        ),
    )
else:
    # Launch a local vLLM server for `model_name`, wait for it to come up
    # (the 200 passed to wait_for_startup is presumably a timeout in seconds),
    # then build a backend that talks to it.
    server = LocalVLLMServer(model_name)
    server.wait_for_startup(200)
    backend = server.make_backend()

In [None]:
# Fetch the precomputed embeddings for a subset of the MTRAG government corpus,
# skipping the download when something already exists at the target path.
embeddings_location = f"{temp_data_dir}/{corpus_name}_embeds.parquet"
already_downloaded = os.path.exists(embeddings_location)
if not already_downloaded:
    download_mtrag_embeddings(embedding_model_name, corpus_name, embeddings_location)
embeddings_location  # echo the path as the cell's output

In [None]:
# Wrap the base Granite model in an IO processor bound to the backend above.
io_proc = make_io_processor(
    model_name,
    backend=backend,
)
io_proc  # show the processor as the cell's output

In [None]:
# Assemble a sample chat completion request: an assistant greeting followed by a
# user question about visiting the law library.
greeting = {
    "role": "assistant",
    "content": "Welcome to the California Appellate Courts help desk.",
}
question = {
    "role": "user",
    "content": (
        "I need to do some legal research to be prepared for my "
        "oral argument. Can I visit the law library?"
    ),
}
# Generation settings: temperature 0.0 and a max_tokens budget of 4096.
request = {
    "messages": [greeting, question],
    "generate_inputs": {"temperature": 0.0, "max_tokens": 4096},
}
chat_input = Granite3Point3Inputs.model_validate(request)
chat_input

In [None]:
# Send the request to the base model with no retrieved documents attached.
# Expected (but not guaranteed) output: a refusal that begins with
# "As an AI, I don't have physical locations or resources."
non_rag_result = io_proc.create_chat_completion(chat_input)
non_rag_message = non_rag_result.results[0].next_message
display(Markdown(non_rag_message.content))

In [None]:
# Build an in-memory retriever over the downloaded embeddings, passing along the
# name of the embedding model they were produced with.
retriever = InMemoryRetriever(
    embeddings_location,
    embedding_model_name,
)

In [None]:
# Augment the chat request with retrieved documents. process() returns an
# indexable collection of requests; this notebook keeps the first one.
rag_processor = RetrievalRequestProcessor(retriever)
augmented_requests = rag_processor.process(chat_input)
rag_chat_input = augmented_requests[0]

# Widen pandas' column display so longer document text is visible.
pd.set_option("display.max_colwidth", 200)
print("Documents:")
document_records = [doc.model_dump() for doc in rag_chat_input.documents]
pd.DataFrame.from_records(document_records)

In [None]:
# Re-send the same conversation, this time with the retrieved documents attached.
rag_result = io_proc.create_chat_completion(rag_chat_input)
rag_message = rag_result.results[0].next_message
display(Markdown(rag_message.content))

In [None]:
# Free up GPU resources: stop the local vLLM server if this notebook started one
# (`server` is only bound when run_server was True).
server_was_started = "server" in locals()
if server_was_started:
    server.shutdown()