## This notebook uses llama-stack-client to handle database queries

In [1]:
import asyncio
import os

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.client_tool import client_tool
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.tool_group import McpEndpoint
from rich.pretty import pprint
import rich
import json
import uuid
from pydantic import BaseModel
from typing import List
#from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

from dotenv import load_dotenv

# Populate os.environ from a local .env file (if present) before reading config.
load_dotenv()

# Required settings: a missing variable fails fast with KeyError, checked in
# the order HOST -> LLAMA_STACK_PORT -> INFERENCE_MODEL.
HOST, PORT, MODEL_NAME = (
    os.environ["HOST"],
    os.environ["LLAMA_STACK_PORT"],
    os.environ["INFERENCE_MODEL"],
)

# Optional search-provider keys (unused in this notebook):
# BRAVE_SEARCH_API_KEY = os.environ["BRAVE_SEARCH_API_KEY"]
# TAVILY_SEARCH_API_KEY = os.environ["TAVILY_API_KEY"]

In [2]:
# Connect to the Llama Stack server configured in the previous cell.
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")

# The MCP server must already be running, exposed over SSE by supergateway,
# e.g. one of:
#   npx -y supergateway --port 8000 --stdio 'npx -y @modelcontextprotocol/server-filesystem /tmp/content'
#   npx -y supergateway --port 8000 --stdio 'npx -y @modelcontextprotocol/server-postgres "postgresql://postgres:xxxx@localhost:5432/search"'
# Override the endpoint with the MCP_ENDPOINT_URI environment variable; the
# default matches the supergateway port used above.
# NOTE(review): presumably the Llama Stack server (not this notebook) opens
# this connection, so "localhost" is relative to the server -- confirm when the
# server runs in a container.
MCP_ENDPOINT_URI = os.environ.get("MCP_ENDPOINT_URI", "http://localhost:8000/sse")

# Register the MCP server as a toolgroup so agents can reference it by its
# toolgroup_id, "mcp::dbsearch".
client.toolgroups.register(
    toolgroup_id="mcp::dbsearch",
    provider_id="model-context-protocol",
    mcp_endpoint=McpEndpoint(uri=MCP_ENDPOINT_URI),
)

# Print every tool the server exposes. After registration the MCP tool should
# appear in the listing, for example:
# Tool(description='Run a read-only SQL query', identifier='query',
#      parameters=[Parameter(description='', name='sql', parameter_type='string', required=True, default=None)],
#      provider_id='model-context-protocol', provider_resource_id='query', tool_host='model_context_protocol',
#      toolgroup_id='mcp::dbsearch', type='tool', metadata={'endpoint': 'http://localhost:8000/sse'})
for tool in client.tools.list():
    print(tool)
    print('-----')

Tool(description='Execute code', identifier='code_interpreter', parameters=[Parameter(description='The code to execute', name='code', parameter_type='string', required=True, default=None)], provider_id='code-interpreter', provider_resource_id='code_interpreter', tool_host='distribution', toolgroup_id='builtin::code_interpreter', type='tool', metadata=None)
-----
Tool(description='Insert documents into memory', identifier='insert_into_memory', parameters=[], provider_id='rag-runtime', provider_resource_id='insert_into_memory', tool_host='distribution', toolgroup_id='builtin::rag', type='tool', metadata=None)
-----
Tool(description='Search for information in a database.', identifier='knowledge_search', parameters=[Parameter(description='The query to search for. Can be a natural language sentence or keywords.', name='query', parameter_type='string', required=True, default=None)], provider_id='rag-runtime', provider_resource_id='knowledge_search', tool_host='distribution', toolgroup_id='bu

In [6]:
from rich.markup import escape

# System prompt for the database agent. Each step is its own line, joined by
# implicit string concatenation with explicit newlines. A backslash
# continuation inside a single literal would copy the source indentation into
# the prompt and run the numbered steps together.
DB_AGENT_INSTRUCTIONS = (
    "You are a helpful assistant that can answer queries from the database.\n"
    "Use the following methodology to give an answer. "
    "Remember, the SQL query may not be given explicitly.\n"
    "1. Firstly, find out how many tables are in the database schema.\n"
    "2. Then examine the column name and datatype of each table.\n"
    "3. Then examine 3 rows from each table to get an idea of the data it has.\n"
    "4. Then figure out which table may have the answer.\n"
    "5. Then, and only then, formulate the real SQL query that answers the question.\n"
    "Always use the query tool that has been given to you."
)

# Agent that answers questions using the tools in the "mcp::dbsearch" toolgroup.
rag_agent = Agent(
    client,
    model=MODEL_NAME,
    # Deliberately prescriptive ("brute force") prompt that pushes the model to
    # explore the schema with the query tool before writing the final SQL.
    instructions=DB_AGENT_INSTRUCTIONS,
    tools=["mcp::dbsearch"],
    enable_session_persistence=True,
    max_infer_iters=10,
    # Safety shields applied to user input and agent output (optional).
    input_shields=["meta-llama/Llama-Guard-3-1B"],
    output_shields=["meta-llama/Llama-Guard-3-1B"],
)

examples = [
    "how many different kinds resources are there",
    # "how many different kinds resources are there in the resources table",
    "how many pods are there in the resources table ",
    # "can you get me the schema of the database",
]
for example in examples:
    # Each question runs in its own newly created session.
    rag_session_id = rag_agent.create_session(session_name=f"rag_session_{uuid.uuid4()}")
    response = rag_agent.create_turn(
        messages=[{"role": "user", "content": example}],
        session_id=rag_session_id,
        stream=False,
    )
    pprint(response)

    # Escape the answer so square brackets in model/tool output are printed
    # literally instead of being parsed as rich markup tags.
    answer = escape(str(response.output_message.content))
    rich.print(f"[bold yellow]Agent Answer:[/bold yellow] {answer}")

    # NOTE(review): EventLogger presumably expects a streamed turn
    # (stream=True above) -- confirm before re-enabling.
    # for log in EventLogger().log(response):
    #     log.print()

    # Fetch the stored session (turns and steps) for inspection.
    session_response = client.agents.session.retrieve(
        session_id=rag_session_id,
        agent_id=rag_agent.agent_id,
    )
    pprint(session_response)