# Lab 3.1 – Agent with watsonx.ai + Accelerator RAG API

This notebook implements the Lab 3.1 agent in the `simple-watsonx-environment`.

High-level flow:

1. The accelerator FastAPI service exposes a `/ask` endpoint for RAG.
2. We wrap `/ask` as a **tool** (`rag_service_tool`).
3. We add a small **calculator tool** for arithmetic.
4. Granite on watsonx.ai acts as a **planner + final answer generator**.

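You can adapt this to your own models and endpoints by setting the `ACCELERATOR_API_URL`, `WATSONX_URL`, `WATSONX_APIKEY`, `WATSONX_PROJECT_ID`, and `LLM_MODEL_ID` environment variables before running the configuration cell (missing credentials are prompted for interactively).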


In [None]:
# Install dependencies (run once per environment).
# NOTE: the leading `!` is IPython/Jupyter shell-escape syntax, so this line only
# works inside a notebook kernel, not in a plain Python script.
!pip install -q "ibm-watsonx-ai>=1.1.22" requests pydantic


In [None]:
import ast
import json
import operator as op
import os
import time
from getpass import getpass
from typing import Any, Dict

import requests
from pydantic import BaseModel, Field
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.foundation_models import Model, ModelInference
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from ibm_watsonx_ai.foundation_models.utils.enums import DecodingMethods

# --- Configuration ---
# Every setting can come from an environment variable; the two credentials
# fall back to an interactive prompt when the variable is unset.
ACCELERATOR_API_URL = os.getenv("ACCELERATOR_API_URL", "http://localhost:8000/ask")
WATSONX_URL = os.getenv("WATSONX_URL", "https://us-south.ml.cloud.ibm.com")
# getpass hides the key while it is typed; input() would echo the secret into
# the notebook's saved output.
WATSONX_APIKEY = os.getenv("WATSONX_APIKEY") or getpass("WATSONX_APIKEY: ")
WATSONX_PROJECT_ID = os.getenv("WATSONX_PROJECT_ID") or input("WATSONX_PROJECT_ID: ")
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "ibm/granite-3-3-8b-instruct")

# Fail fast with a clear message instead of an opaque authentication error
# surfacing later from the first model call.
if not WATSONX_APIKEY or not WATSONX_PROJECT_ID:
    raise ValueError("WATSONX_APIKEY and WATSONX_PROJECT_ID must both be set.")

creds = Credentials(url=WATSONX_URL, api_key=WATSONX_APIKEY)
params = {
    GenParams.DECODING_METHOD: DecodingMethods.GREEDY,
    GenParams.MAX_NEW_TOKENS: 512,
    GenParams.MIN_NEW_TOKENS: 1,
    # NOTE: watsonx.ai applies TEMPERATURE only in sampling mode, so it has no
    # effect with GREEDY decoding; kept so switching to SAMPLE is a one-line change.
    GenParams.TEMPERATURE: 0.2,
}

# One client serves both the planner call and the final-answer call.
# ModelInference is the supported text-generation client; the older `Model`
# class is deprecated in ibm-watsonx-ai. Both expose generate_text().
planner_model = ModelInference(
    model_id=LLM_MODEL_ID,
    credentials=creds,
    project_id=WATSONX_PROJECT_ID,
    params=params,
)


## Define Tools

We define two tools:

- `rag_service_tool(question)` – calls the accelerator `/ask` endpoint.
- `calculator_tool(expression)` – safely evaluates arithmetic expressions.


In [None]:
def rag_service_tool(question: str, timeout: float = 60) -> Dict[str, Any]:
    """Call the accelerator `/ask` RAG endpoint with the given question.

    Args:
        question: Natural-language question, sent as the JSON body
            ``{"question": question}`` to ``ACCELERATOR_API_URL``.
        timeout: Seconds to wait for the HTTP response (default 60).

    Returns:
        The decoded JSON response, with keys like `answer`, `citations`,
        `model_id`. The exact schema depends on your accelerator
        implementation. A `latency_ms` key (client-measured round-trip time)
        is added unless the service already supplied one. A response whose
        JSON is not an object (e.g. a bare string or list) is wrapped as
        ``{"answer": <value>}`` so callers always receive a dict.

    Raises:
        requests.HTTPError: if the service responds with a 4xx/5xx status.
        requests.RequestException: on connection failures or timeouts.
        ValueError: if the response body is not valid JSON.
    """
    payload = {"question": question}
    # perf_counter is monotonic, so the measurement is unaffected by
    # wall-clock adjustments (unlike time.time()).
    start = time.perf_counter()
    resp = requests.post(ACCELERATOR_API_URL, json=payload, timeout=timeout)
    latency_ms = int((time.perf_counter() - start) * 1000)
    resp.raise_for_status()
    data = resp.json()
    if not isinstance(data, dict):
        # setdefault below needs a mapping; keep the payload under "answer".
        data = {"answer": data}
    data.setdefault("latency_ms", latency_ms)
    return data


# --- Safe calculator implementation using AST ---
# Upper bound on the size (in bits) of an integer power result. Without it an
# expression such as "9**9**9" makes the evaluator hang and exhaust memory.
_MAX_POW_RESULT_BITS = 100_000


def _safe_pow(base, exponent):
    """Return ``base ** exponent``, refusing integer powers with enormous results.

    Raises:
        ValueError: if both operands are ints, ``|base| > 1``, ``exponent > 0``
            and the result could need more than ``_MAX_POW_RESULT_BITS`` bits.
    """
    if (
        isinstance(base, int)
        and isinstance(exponent, int)
        and exponent > 0
        and abs(base) > 1
        # bit_length(base) * exponent is an upper bound on bit_length(result).
        and abs(base).bit_length() * exponent > _MAX_POW_RESULT_BITS
    ):
        raise ValueError("Exponent too large")
    return op.pow(base, exponent)


# Whitelist of binary AST operator node types and the functions implementing them.
_allowed_operators = {
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.Pow: _safe_pow,
    ast.Mod: op.mod,
}


def _eval_ast(node):
    """Recursively evaluate an arithmetic expression AST node.

    Accepted nodes: numeric literals (int, float, complex; bool is rejected),
    binary operations whose operator is in ``_allowed_operators``, and unary
    ``+`` / ``-``.

    Raises:
        ValueError: for any other node (names, calls, attributes, strings, ...)
            or an integer power whose result would be too large.
    """
    # ast.Constant replaces ast.Num (deprecated since Python 3.8, removed in
    # 3.14). bool is excluded explicitly because it is a subclass of int.
    if (
        isinstance(node, ast.Constant)
        and isinstance(node.value, (int, float, complex))
        and not isinstance(node.value, bool)
    ):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _allowed_operators:
        return _allowed_operators[type(node.op)](_eval_ast(node.left), _eval_ast(node.right))
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
        value = _eval_ast(node.operand)
        return +value if isinstance(node.op, ast.UAdd) else -value
    raise ValueError("Unsupported expression")


def calculator_tool(expression: str) -> str:
    """Evaluate an arithmetic expression and return the result as text.

    The expression is parsed with ``ast`` and walked by ``_eval_ast``, which
    accepts only numeric literals, +, -, *, /, % and ** (plus unary +/-).
    Names, function calls and attribute access are rejected.

    Never raises: any failure (syntax error, unsupported construct, division
    by zero, ...) is returned as a string starting with
    "Error evaluating expression:" so the agent can hand it to the
    final-answer step.
    """
    try:
        tree = ast.parse(expression, mode="eval")
        # str() stays inside the try: converting a huge int can itself raise.
        return str(_eval_ast(tree.body))
    except Exception as exc:  # deliberate: every failure becomes tool output
        return f"Error evaluating expression: {exc}"


## Planner Schema & Prompt

We ask Granite to output **JSON** describing which tool to call and with what arguments.


In [None]:
class ToolPlan(BaseModel):
    """Tool choice parsed from the planner LLM's JSON reply."""

    # Tool name; run_agent_once dispatches on "rag_service" and "calculator".
    tool: str
    # Keyword arguments for the tool, e.g. {"question": ...} or {"expression": ...}.
    # Defaults to an empty dict so a reply that omits "arguments" still
    # validates; run_agent_once then falls back to the raw user input.
    arguments: Dict[str, Any] = Field(default_factory=dict)


# System instructions for the planner call. The argument keys listed here
# ("question", "expression") are the ones run_agent_once reads from the plan,
# so the model is told them explicitly instead of having to guess.
PLANNER_SYSTEM_PROMPT = (
    "You are a planner agent. You must choose exactly ONE tool per request.\n\n"
    "Available tools:\n"
    "- rag_service: Use this to answer enterprise questions using the /ask RAG API. "
    'Arguments: {"question": "<the question to look up>"}.\n'
    "- calculator: Use this to evaluate arithmetic expressions like '2 * (3 + 4)'. "
    'Arguments: {"expression": "<a pure arithmetic expression using numbers and + - * / % **>"}.\n\n'
    "Return a JSON object with keys 'tool' and 'arguments'. Do not include any extra text.\n"
    "If the user is clearly asking a math question, prefer 'calculator'. Otherwise, prefer 'rag_service'."
)


def plan_tool_call(user_input: str) -> ToolPlan:
    """Ask the planner LLM which tool to use and with what arguments.

    The prompt is ``PLANNER_SYSTEM_PROMPT`` followed by the user input and a
    JSON example. The reply is parsed as JSON. If it is not pure JSON (for
    example, wrapped in prose or a markdown code fence), the first complete
    JSON object embedded in it is used instead.

    Args:
        user_input: The raw user question.

    Returns:
        A ``ToolPlan`` built from the parsed JSON object.

    Raises:
        ValueError: if no JSON object can be extracted from the reply, or the
            reply is valid JSON but not an object.
        pydantic.ValidationError: if the object does not match ``ToolPlan``.
    """
    user_prompt = (
        f"User input: {user_input}\n\n"
        "Respond ONLY with JSON, for example:\n"
        '{"tool": "calculator", "arguments": {"expression": "2 + 2"}}'
    )
    prompt = f"{PLANNER_SYSTEM_PROMPT}\n\n{user_prompt}"
    raw = planner_model.generate_text(prompt=prompt)
    # generate_text() returns the generated string itself; only a raw response
    # is the full {"results": [{"generated_text": ...}]} dict. Accept both.
    text = (raw if isinstance(raw, str) else raw["results"][0]["generated_text"]).strip()

    try:
        plan_dict = json.loads(text)
    except json.JSONDecodeError:
        plan_dict = None
        decoder = json.JSONDecoder()
        # raw_decode stops at the end of the first complete value, so trailing
        # prose, closing code fences or stray "}" after the object are ignored.
        for start, char in enumerate(text):
            if char != "{":
                continue
            try:
                plan_dict, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                continue
            break
        if plan_dict is None:
            raise ValueError(f"Could not parse JSON from planner output: {text!r}") from None
    if not isinstance(plan_dict, dict):
        raise ValueError(f"Planner output is not a JSON object: {text!r}")
    return ToolPlan(**plan_dict)


## Final Answer Step

After tool execution, we call the LLM again to generate a user-friendly answer.


In [None]:
# Instructions for the answer-writing LLM call. They are placed ahead of the
# question / tool / tool-output context in the final prompt.
FINAL_ANSWER_SYSTEM = " ".join(
    [
        "You are a helpful assistant. You will be given: the original user question, "
        "the tool you used, and the tool output.",
        "Compose a clear and concise final answer.",
        "If the tool output indicates an error, explain the error and suggest what the user can try.",
    ]
)


def generate_final_answer(user_input: str, tool_name: str, tool_output: Any) -> str:
    """Call the LLM to turn a tool result into a user-facing answer.

    Args:
        user_input: The original user question.
        tool_name: Name of the tool the planner chose.
        tool_output: Whatever the tool returned; interpolated into the prompt
            via f-string formatting.

    Returns:
        The generated answer text with surrounding whitespace stripped.
    """
    context = (
        f"User question: {user_input}\n"
        f"Tool used: {tool_name}\n"
        f"Tool output: {tool_output}\n"
        "\nPlease write the final answer for the user."
    )
    prompt = f"{FINAL_ANSWER_SYSTEM}\n\n{context}"
    raw = planner_model.generate_text(prompt=prompt)
    # generate_text() returns the generated string itself; indexing it like the
    # raw {"results": [...]} response dict would raise TypeError. Accept both.
    answer = raw if isinstance(raw, str) else raw["results"][0]["generated_text"]
    return answer.strip()


## Orchestrate One Turn of the Agent

`run_agent_once` ties everything together:

1. Planner picks a tool.
2. Python executes the tool.
3. LLM generates a final answer.
4. We return a structured record that you can log or analyze later.


In [None]:
def run_agent_once(user_input: str) -> Dict[str, Any]:
    """Run one plan -> execute -> answer turn of the agent.

    1. ``plan_tool_call`` picks a tool and its arguments.
    2. The chosen tool runs in Python. A missing argument falls back to the raw
       user input. A failed RAG HTTP request is turned into an error dict, so
       the answer step can explain it (mirroring how ``calculator_tool``
       reports errors as strings).
    3. ``generate_final_answer`` writes the user-facing answer.

    Args:
        user_input: The user's question.

    Returns:
        A record with keys ``question``, ``tool``, ``tool_args``,
        ``tool_output`` and ``final_answer``, suitable for logging.

    Raises:
        ValueError: propagated from ``plan_tool_call`` when the planner reply
            cannot be parsed.
    """
    plan = plan_tool_call(user_input)
    tool_name = plan.tool
    args = plan.arguments or {}

    if tool_name == "rag_service":
        question = args.get("question") or user_input
        try:
            tool_output = rag_service_tool(question)
        except requests.RequestException as err:
            # Covers HTTP error statuses, connection failures, timeouts and
            # (requests >= 2.27) undecodable JSON bodies.
            tool_output = {"error": f"RAG service request failed: {err}"}
    elif tool_name == "calculator":
        expr = args.get("expression") or user_input
        tool_output = calculator_tool(expr)
    else:
        tool_output = f"Unknown tool: {tool_name}"

    final_answer = generate_final_answer(user_input, tool_name, tool_output)

    return {
        "question": user_input,
        "tool": tool_name,
        "tool_args": args,
        "tool_output": tool_output,
        "final_answer": final_answer,
    }


## Quick Smoke Tests

Run a couple of tests:

- A RAG-style question (should pick `rag_service`).
- A pure math question (should pick `calculator`).


In [None]:
# One retrieval-style question (expected tool: rag_service) and one pure
# arithmetic question (expected tool: calculator).
test_questions = [
    "What is RAG and why do we use it?",
    "What is 2 * (3 + 4)?",
]

for q in test_questions:
    print("=" * 80)
    print(f"Q: {q}")
    result = run_agent_once(q)
    print(f"Tool: {result['tool']} Args: {result['tool_args']}")
    print(f"Final answer:\n {result['final_answer']}")
