# Demo of LoRA adapter for query rewrite

This notebook shows the usage of the IO processor for the Granite query rewrite
intrinsic, also known as the [LoRA Adapter for Query Rewrite](
    https://huggingface.co/ibm-granite/granite-3.2-8b-lora-rag-query-rewrite
).

This notebook can run its own vLLM server to perform inference, or you can host the
models on your own server. To use your own server, set the `run_server` variable below
to `False` and set appropriate values for the constants `openai_base_url`,
`openai_api_key`, `openai_base_model_name` and `openai_lora_model_name`.

In [1]:
# Imports go here
from granite_io.io.query_rewrite import QueryRewriteIOProcessor
from granite_io.io.granite_3_3.input_processors.granite_3_3_input_processor import (
    Granite3Point3Inputs,
)
from granite_io.backend.vllm_server import LocalVLLMServer
from granite_io import make_backend
from granite_io.io.rag_agent_lib import obtain_lora

In [3]:
# Notebook-wide settings
# Set to False to connect to an already-running server instead of launching one.
run_server = True
# Base model that the LoRA adapter is applied to.
base_model_name = "ibm-granite/granite-3.3-8b-instruct"
# TEMPORARY: Load LoRA adapter locally
lora_model_name = "query_rewrite"

In [None]:
if run_server:
    # Launch a local vLLM server that serves the base model together with the
    # LoRA adapter, then create one backend per model.
    # Download and cache the model's LoRA adapter first.
    lora_model_path = obtain_lora(lora_model_name)
    print(f"Local path to LoRA adapter: {lora_model_path}")
    lora_adapters = [(lora_model_name, lora_model_path)]
    server = LocalVLLMServer(base_model_name, lora_adapters=lora_adapters)
    # Block until the server reports that it is up (timeout value: 200).
    server.wait_for_startup(200)
    lora_backend = server.make_lora_backend(lora_model_name)
    backend = server.make_backend()
else:
    # run_server is False: talk to a server that is already running.
    # Adjust the four constants below to match that server.
    openai_base_url = "http://localhost:55555/v1"
    # NOTE(review): looks like a placeholder key for a local server; do not
    # commit a real secret here.
    openai_api_key = "granite_intrinsics_1234"
    openai_base_model_name = base_model_name
    openai_lora_model_name = lora_model_name

    # Connection settings common to both backends. Each make_backend() call
    # below still receives its own freshly built dict.
    openai_connection = {
        "openai_base_url": openai_base_url,
        "openai_api_key": openai_api_key,
    }
    backend = make_backend(
        "openai", {"model_name": openai_base_model_name, **openai_connection}
    )
    lora_backend = make_backend(
        "openai", {"model_name": openai_lora_model_name, **openai_connection}
    )

In [None]:
# Create an example chat completion with a short conversation.
# The final user turn uses the pronoun "he", which can only be resolved from the
# earlier turns; this is the query the rewrite adapter is asked to rewrite.
chat_input = Granite3Point3Inputs.model_validate(
    {
        "messages": [
            {"role": "assistant", "content": "Welcome to pet questions!"},
            {
                "role": "user",
                "content": "I have two pets, a dog named Rex and a cat named Lucy.",
            },
            {
                "role": "assistant",
                "content": "Great, what would you like to share about them?",
            },
            {
                "role": "user",
                # The cat's name must match the one introduced above ("Lucy").
                "content": "Rex spends a lot of time in the backyard and outdoors, "
                "and Lucy is always inside.",
            },
            {
                "role": "assistant",
                "content": "Sounds good! Rex must love exploring outside, while Lucy "
                "probably enjoys her cozy indoor life.",
            },
            {
                "role": "user",
                "content": "But is he more likely to get fleas because of that?",
            },
        ],
        # Temperature 0.0 is passed through as a generation input.
        "generate_inputs": {"temperature": 0.0},
    }
)
# Display the validated input object as the cell output.
chat_input

In [None]:
# Instantiate the I/O processor for the LoRA adapter
io_proc = QueryRewriteIOProcessor(lora_backend)

# Pass our example input through the I/O processor and retrieve the result
chat_result = await io_proc.acreate_chat_completion(chat_input)
print(chat_result.results[0].next_message.model_dump_json(indent=2))

In [7]:
# Shut the local vLLM server down to free up GPU resources.
# `server` is only defined when this notebook launched its own server.
server_was_started = "server" in locals()
if server_was_started:
    server.shutdown()