### Retriever API Usage

This notebook shows how to use the NVIDIA RAG retriever APIs to fetch relevant document passages for a user query, and how to generate responses with the end-to-end RAG APIs.


- Ensure the rag-server container is running before executing the notebook by following the steps in [Get Started](../docs/deploy-docker-self-hosted.md).
- Run the [ingestion notebook](./ingestion_api_usage.ipynb) before using this notebook.
- Replace the value of `IPADDRESS` with the address of the server if the API is hosted on another system.

You can now execute each cell in sequence to test the API.

#### 1. Install Dependencies

In [None]:
!pip install aiohttp
import json
import os

import aiohttp

#### 2. Setup Base Configuration

In [None]:
IPADDRESS = (
    "rag-server" if os.environ.get("AI_WORKBENCH", "false") == "true" else "localhost"
)  # Replace this with the correct IP address
RAG_SERVER_PORT = "8081"
BASE_URL = f"http://{IPADDRESS}:{RAG_SERVER_PORT}"  # Replace with your server URL


async def print_response(response):
    """Helper to print API response."""
    try:
        response_json = await response.json()
        print(json.dumps(response_json, indent=2))
    except (aiohttp.ContentTypeError, json.JSONDecodeError):
        # Body is not JSON (for example, a streamed response): print it as text
        print(await response.text())
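
If the server runs on another system, one option is to read its address from the environment instead of editing the cell. The sketch below is only an illustration: the variable names `RAG_SERVER_HOST` and `RAG_SERVER_PORT` and the helper `resolve_base_url` are hypothetical and are not defined by the RAG server. The fallback logic mirrors the `AI_WORKBENCH` check used above.

```python
import os


def resolve_base_url(env=os.environ):
    """Build the base URL, preferring explicit overrides from the environment."""
    # RAG_SERVER_HOST / RAG_SERVER_PORT are hypothetical names chosen for this sketch
    host = env.get("RAG_SERVER_HOST")
    if not host:
        # Same default as the configuration cell in this notebook
        host = "rag-server" if env.get("AI_WORKBENCH", "false") == "true" else "localhost"
    port = env.get("RAG_SERVER_PORT", "8081")
    return f"http://{host}:{port}"


print(resolve_base_url({}))
print(resolve_base_url({"AI_WORKBENCH": "true"}))
print(resolve_base_url({"RAG_SERVER_HOST": "10.0.0.5", "RAG_SERVER_PORT": "9000"}))
```

Passing a plain dictionary instead of `os.environ` makes the helper easy to try without changing your shell environment.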

#### 3. Health Check Endpoint

**Purpose:**
This endpoint performs a health check on the server. It returns a 200 status code if the server is operational. It also returns the status of the dependent services.

In [None]:
async def fetch_health_status():
    """Fetch health status asynchronously."""
    url = f"{BASE_URL}/v1/health"
    params = {"check_dependencies": "True"}  # Check health of dependencies as well
    async with aiohttp.ClientSession() as session:
        async with session.get(url, params=params) as response:
            await print_response(response)


# Run the async function
await fetch_health_status()

#### 4. Generate Answer Endpoint

**Purpose:**
This endpoint generates a streaming AI response to a given user message. The system message is specified in the [prompt.yaml](../src/nvidia_rag/rag_server/prompt.yaml) file. The API retrieves the chunks relevant to the query from the knowledge base, adds them to the LLM prompt, and returns a streaming response. It supports parameters such as temperature, top_p, and knowledge base usage, and generates answers from the specified vector collections.

If a cited source is an image, the endpoint also returns its base64-encoded data as part of the returned document chunks. The citations field is always populated in the first chunk of the streaming response.

In [None]:
url = f"{BASE_URL}/v1/generate"
payload = {
    "messages": [
        {
            "role": "user",
            "content": "How does the price of bluetooth speaker compare with hammer?",
        }
    ],
    "use_knowledge_base": True,
    "temperature": 0.2,
    "top_p": 0.7,
    "max_tokens": 1024,
    "reranker_top_k": 2,
    "vdb_top_k": 10,
    "vdb_endpoint": "http://milvus:19530",
    "collection_names": ["multimodal_data"],
    "enable_query_rewriting": True,
    "enable_reranker": True,
    "enable_citations": True,
    "stop": [],
    "filter_expr": "",
    # Override model endpoints and details if needed
    #"model": "nvidia/llama-3.3-nemotron-super-49b-v1.5",
    #"reranker_model": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
    #"embedding_model": "nvidia/llama-3.2-nv-embedqa-1b-v2",
    #"llm_endpoint": "",
    #"embedding_endpoint": "",
    #"reranker_endpoint": "",
}


async def generate_answer(payload):
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(url=url, json=payload) as response:
                await print_response(response)
        except aiohttp.ClientError as e:
            print(f"Error: {e}")


await generate_answer(payload)
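
The cells in this notebook reuse the global names `url` and `payload`, so a later cell can silently change what an earlier function sends. A minimal sketch of one way around this is to build each request body from a fresh copy of the defaults. The helper `build_generate_payload` and the constant `DEFAULT_GENERATE_PAYLOAD` are hypothetical names; the keys and values are copied from the request body shown above.

```python
import copy

# Defaults copied from the /generate request body used in this section
DEFAULT_GENERATE_PAYLOAD = {
    "use_knowledge_base": True,
    "temperature": 0.2,
    "top_p": 0.7,
    "max_tokens": 1024,
    "reranker_top_k": 2,
    "vdb_top_k": 10,
    "vdb_endpoint": "http://milvus:19530",
    "collection_names": ["multimodal_data"],
    "enable_query_rewriting": True,
    "enable_reranker": True,
    "enable_citations": True,
    "stop": [],
    "filter_expr": "",
}


def build_generate_payload(question, **overrides):
    """Return a fresh request body for one user question."""
    # Deep copy so nested lists are not shared between requests
    payload = copy.deepcopy(DEFAULT_GENERATE_PAYLOAD)
    payload["messages"] = [{"role": "user", "content": question}]
    payload.update(overrides)  # Per-request settings win over the defaults
    return payload


precise = build_generate_payload(
    "How does the price of bluetooth speaker compare with hammer?",
    temperature=0.0,
    enable_citations=False,
)
print(precise["temperature"], precise["enable_citations"])
```

The result can be passed to `generate_answer` in place of the shared `payload` dictionary.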

#### 5. Parse RAG Metrics from the last chunk of the streaming response

**Purpose:**
This cell extends the generate endpoint call: the streaming response is parsed chunk by chunk, and both the response content and the RAG metrics are extracted from these chunks.

In [None]:
async def parse_metrics_from_response(response):
    buffer = b""
    async for chunk in response.content.iter_chunked(8192):
        if chunk:
            # Buffer raw bytes so that a multi-byte UTF-8 character split
            # across two chunks is never decoded in two halves
            buffer += chunk

            # Process complete lines in the buffer
            lines = buffer.split(b"\n")
            buffer = lines[-1]  # Keep the incomplete last line in buffer

            for raw_line in lines[:-1]:  # Process all complete lines
                line = raw_line.decode("utf-8").strip()
                if line.startswith("data: "):
                    json_str = line[6:]  # Remove "data: " prefix
                    if json_str:  # Skip empty data lines
                        try:
                            data = json.loads(json_str)

                            # Process the message content; tolerate a missing or empty "choices" list
                            choice = (data.get("choices") or [{}])[0]
                            message = choice.get("message", {}).get("content", "")
                            if message:
                                print(message, end="")  # Stream the message content

                            # Check if it's the last chunk based on the "finish_reason"
                            finish_reason = choice.get("finish_reason")
                            if finish_reason == "stop":
                                print("\nMetrics:")
                                metrics = data.get("metrics", {})
                                print(f"LLM Generation Time: {metrics.get('llm_generation_time_ms', 'N/A')} ms")
                                print(f"LLM TTFT Time: {metrics.get('llm_ttft_ms', 'N/A')} ms")
                                print(f"Context Reranker Time: {metrics.get('context_reranker_time_ms', 'N/A')} ms")
                                print(f"Retrieval Time: {metrics.get('retrieval_time_ms', 'N/A')} ms")
                                print(f"RAG TTFT Time: {metrics.get('rag_ttft_ms', 'N/A')} ms")
                                return
                        except json.JSONDecodeError:
                            # Skip malformed JSON chunks
                            continue


async def generate_answer(payload):
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(url=url, json=payload) as response:
                await parse_metrics_from_response(response)
        except aiohttp.ClientError as e:
            print(f"Error: {e}")


# Call the async function
await generate_answer(payload)
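
To experiment with the line buffering without a running server, the parsing logic can be isolated in a small class. This is a sketch under stated assumptions: `SSEJsonParser` is a hypothetical helper, and the sample events are made up for illustration and only mirror the keys that the cell above reads (`choices`, `message`, `finish_reason`, `metrics`). It buffers raw bytes and decodes only complete lines, so a multi-byte UTF-8 character that straddles two chunks is still decoded correctly.

```python
import json


class SSEJsonParser:
    """Incrementally turn raw byte chunks into decoded `data:` JSON payloads."""

    def __init__(self):
        self._buffer = b""

    def feed(self, chunk):
        """Add one chunk and return the JSON objects completed by it."""
        self._buffer += chunk
        # Everything before the last newline is complete; the rest stays buffered
        *complete, self._buffer = self._buffer.split(b"\n")
        events = []
        for raw_line in complete:
            line = raw_line.decode("utf-8").strip()
            if not line.startswith("data:"):
                continue  # Blank separator lines and other fields are ignored
            body = line[len("data:"):].strip()
            if not body:
                continue
            try:
                events.append(json.loads(body))
            except json.JSONDecodeError:
                continue  # Skip malformed payloads, as the cell above does
        return events


# Made-up events for illustration only
sample_events = [
    {"choices": [{"message": {"content": "Café "}}]},
    {
        "choices": [{"message": {"content": "ouvert"}, "finish_reason": "stop"}],
        "metrics": {"rag_ttft_ms": 12.5},
    },
]
raw = "".join(
    f"data: {json.dumps(event, ensure_ascii=False)}\n\n" for event in sample_events
).encode("utf-8")

parser = SSEJsonParser()
parsed = []
for start in range(0, len(raw), 7):  # Deliberately tiny chunks to split lines
    parsed.extend(parser.feed(raw[start:start + 7]))

text = "".join(event["choices"][0]["message"]["content"] for event in parsed)
print(text)
print(parsed[-1]["metrics"])
```

In the notebook, the same class could be fed from `response.content.iter_chunked(...)` inside the existing `async for` loop.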

#### 6. Document Search Endpoint

**Purpose:**
This endpoint searches for the most relevant documents in the vector store based on a query. You can specify the maximum number of documents to retrieve using `reranker_top_k`.

The `content` of each document is returned as well. For images that represent charts or tables, `content` is a base64 representation, which developers can use to render multimodal citations for end users. The textual representation of this content is available in the `description` field of `metadata`.

In [None]:
url = f"{BASE_URL}/v1/search"
payload = {
    "query": "Tell me about robert frost's poems",
    "reranker_top_k": 2,
    "vdb_top_k": 10,
    "vdb_endpoint": "http://milvus:19530",
    "collection_names": [
        "multimodal_data"
    ],  # Pass several names to retrieve from multiple collections
    "messages": [],
    "enable_query_rewriting": False,
    "enable_reranker": True,
    # Override model endpoints and details if needed
    #"reranker_model": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
    #"embedding_model": "nvidia/llama-3.2-nv-embedqa-1b-v2",
    #"embedding_endpoint": "",
    #"reranker_endpoint": "",
}


async def document_search(payload):
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(url=url, json=payload) as response:
                await print_response(response)
        except aiohttp.ClientError as e:
            print(f"Error: {e}")


await document_search(payload)
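
When rendering results, a client has to decide whether `content` is plain text or a base64-encoded image. The sketch below is one possible approach, not the server's documented contract: it assumes each result is a dictionary with `content` and `metadata` keys, assumes images are PNG or JPEG, and uses hypothetical helper names (`decode_image`, `citation_view`). The sample results are placeholders, not real server output.

```python
import base64
import binascii

PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
JPEG_MAGIC = b"\xff\xd8\xff"


def decode_image(content):
    """Return (format, raw_bytes) if content is a base64 PNG or JPEG, else None."""
    try:
        raw = base64.b64decode(content, validate=True)
    except (binascii.Error, ValueError):
        return None  # Not base64 at all, so treat it as plain text
    # Short plain words can also be valid base64, so check the file signature
    if raw.startswith(PNG_MAGIC):
        return "png", raw
    if raw.startswith(JPEG_MAGIC):
        return "jpeg", raw
    return None


def citation_view(result):
    """Pick what to display for one search result (assumed result shape)."""
    content = result.get("content", "")
    image = decode_image(content)
    if image is None:
        return {"kind": "text", "text": content}
    fmt, raw = image
    description = result.get("metadata", {}).get("description", "")
    return {"kind": "image", "format": fmt, "size_bytes": len(raw), "alt_text": description}


# Placeholder results for illustration only
sample_results = [
    {"content": "Whose woods these are I think I know.", "metadata": {}},
    {
        "content": base64.b64encode(PNG_MAGIC + b"placeholder").decode("ascii"),
        "metadata": {"description": "Bar chart of quarterly sales"},
    },
]
views = [citation_view(result) for result in sample_results]
for view in views:
    print(view)
```

Using the `description` as alternative text keeps the citation readable when the image itself cannot be displayed.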

#### 7. [Optional] Document Search Endpoint with metadata filtering

**Purpose:** Search results can be filtered on the custom metadata provided during ingestion. The same `filter_expr` field can be passed to the `/generate` endpoint to filter the chunks retrieved for RAG.

Before using custom-metadata filtering, make sure the custom metadata was added at the ingestion stage. Filters are written as Milvus filtering expressions (reference: [Milvus Filtering](https://milvus.io/docs/boolean.md)). An example is shown below:

In [None]:
url = f"{BASE_URL}/v1/search"
payload = {
    "query": "What is lion doing?",
    "reranker_top_k": 10,
    "vdb_top_k": 100,
    "vdb_endpoint": "http://milvus:19530",
    "collection_names": [
        "multimodal_data"
    ],  # Pass several names to retrieve from multiple collections
    "messages": [],
    "enable_query_rewriting": False,
    "enable_reranker": True,
    # Example filter expression on a custom-metadata field
    "filter_expr": 'content_metadata["meta_field_1"] == "multimodal document"',
}


async def document_search(payload):
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(url=url, json=payload) as response:
                await print_response(response)
        except aiohttp.ClientError as e:
            print(f"Error: {e}")


await document_search(payload)
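
Writing filter strings by hand makes quoting mistakes easy. As a rough sketch, the expression can be assembled from field names and values. The helpers `metadata_equals` and `all_of` are hypothetical, `meta_field_2` is an invented field used only for illustration, and combining clauses with `and` assumes the server passes the expression to Milvus unchanged.

```python
import json


def metadata_equals(field, value):
    """Build one equality clause on a custom-metadata field."""
    if isinstance(value, bool) or not isinstance(value, (str, int, float)):
        raise TypeError("this sketch only handles string and numeric values")
    # json.dumps adds double quotes and escapes embedded quotes or backslashes
    key = json.dumps(field, ensure_ascii=False)
    literal = json.dumps(value, ensure_ascii=False)
    return f"content_metadata[{key}] == {literal}"


def all_of(*clauses):
    """Join clauses with `and`, parenthesizing each one."""
    return " and ".join(f"({clause})" for clause in clauses)


single = metadata_equals("meta_field_1", "multimodal document")
combined = all_of(single, metadata_equals("meta_field_2", "report"))
print(single)
print(combined)
```

The first printed expression is identical to the `filter_expr` value used in the cell above.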

#### 8. [Optional] Retrieve documents summary

You can execute this cell if summary generation was enabled during document upload with the `generate_summary` flag (a boolean).

In [None]:
async def fetch_summary():
    url = f"{BASE_URL}/v1/summary"
    params = {
        "collection_name": "multimodal_data",
        "file_name": "woods_frost.pdf",
        "blocking": "false",
        "timeout": 20,
    }
    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(url, params=params) as response:
                await print_response(response)
        except aiohttp.ClientError as e:
            print(f"Error: {e}")


await fetch_summary()
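
With `"blocking": "false"`, the summary may not be ready on the first call, so a client might need to ask again. The sketch below shows a generic polling loop under stated assumptions: `poll_until_ready` is a hypothetical helper, and the `status` values `PENDING` and `SUCCESS` are placeholders because the response schema is not shown in this notebook. A fake fetch function stands in for the real request so the example runs without a server.

```python
import asyncio


async def poll_until_ready(fetch, is_ready, attempts=5, delay_s=2.0):
    """Call `fetch()` until `is_ready(result)` is true or attempts run out."""
    result = None
    for attempt in range(attempts):
        result = await fetch()
        if is_ready(result):
            return result
        if attempt < attempts - 1:
            await asyncio.sleep(delay_s)  # Wait before asking again
    return result  # Last result seen, even if it never became ready


# Fake responses for illustration; the field names and values are assumptions
fake_responses = iter(
    [
        {"status": "PENDING"},
        {"status": "PENDING"},
        {"status": "SUCCESS", "summary": "placeholder summary text"},
    ]
)
calls = []


async def fake_fetch():
    response = next(fake_responses)
    calls.append(response["status"])
    return response


final = asyncio.run(
    poll_until_ready(fake_fetch, lambda r: r.get("status") == "SUCCESS", delay_s=0)
)
print(final["status"], "after", len(calls), "calls")
```

Inside a notebook cell, where an event loop is already running, replace `asyncio.run(...)` with `await poll_until_ready(...)`.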