## SQL Retriever Tool 

### LLM Used - Granite3.1-8B

To work with this notebook, first deploy the example database from this repository:

```
kubectl create ns agentic-zone
kubectl apply -k bootstrap/database/
```

### 1. Setup and Import Libraries

To get started, you'll need to install and import a few Python libraries. Run the following command to install them:

In [1]:
!pip install -q psycopg2 tabulate langgraph==0.2.35 langchain_experimental==0.0.65 langchain-openai==0.1.25 termcolor==2.3.0 duckduckgo_search==7.1.0 openapi-python-client==0.12.3 langchain_community==0.2.19 wikipedia==1.4.0


In [20]:
# Imports
import os
import json
import getpass

from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.chains import LLMChain
from langchain_openai import ChatOpenAI
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.prompts import PromptTemplate
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage

from typing_extensions import TypedDict
from typing import Annotated
from langchain_core.tools import tool

from langchain_community.utilities import SQLDatabase

### 2. Configure the Database connection and test

In [17]:
# Database connection details
# TODO: Fetch this from the variables
dbname = 'agenticdb'
user = 'agenticdb'
password = 'agenticdb'
host = 'agenticdb.agentic-zone.svc.cluster.local'
port = '5432'

# Setup PostgreSQL URI
uri = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"

# Initialize the SQLDatabase connection
try:
    # Connect to the PostgreSQL database using LangChain's SQLDatabase utility
    db = SQLDatabase.from_uri(uri)

    tables_query = """
    SELECT table_name FROM information_schema.tables WHERE table_schema = 'agenticdb';
    """
    tables_result = db.run(tables_query)
    print("\n## List of Tables in 'agenticdb' schema:")

    # Ensure result is properly parsed
    if isinstance(tables_result, str):
        tables_result = tables_result.strip().split("\n")

    if tables_result:
        for row in tables_result:
            print(f"Table found: {row.strip()}")
    else:
        print("No tables found in 'agenticdb' schema.")

    schema_query = """
    SELECT column_name, data_type
    FROM information_schema.columns
    WHERE table_schema = 'agenticdb'
    AND table_name = 'transactions';
    """
    schema_result = db.run(schema_query)

    print("\n## Schema for 'agenticdb.transactions' table:")

    # Convert to list if returned as a string
    if isinstance(schema_result, str):
        schema_result = schema_result.strip().split("\n")

    if schema_result:
        for row in schema_result:
            print(f" Column: {row}")
    else:
        print("No schema found for 'agenticdb.transactions'.")

    # Example query to retrieve transactions
    example_query = """
    SELECT transaction_id, client_name, transaction_type, stock_symbol, shares, price_per_share,
           (shares * price_per_share) AS total_value
    FROM agenticdb.transactions
    WHERE transaction_type = 'BUY' AND price_per_share > 300;
    """

    query_result = db.run(example_query)
    print("\n## Query Execution Result:")

    # Convert to list if returned as a string
    if isinstance(query_result, str):
        query_result = query_result.strip().split("\n")

    if query_result:
        for txn in query_result:
            print(f"Transaction: {txn}")
    else:
        print(" No transactions found.")

except Exception as e:
    print(f" An error occurred while interacting with the PostgreSQL database: {e}")



## List of Tables in 'agenticdb' schema:
Table found: [('transactions',)]

## Schema for 'agenticdb.transactions' table:
 Column: [('id', 'integer'), ('transaction_id', 'text'), ('client_name', 'text'), ('transaction_type', 'text'), ('stock_symbol', 'text'), ('shares', 'integer'), ('price_per_share', 'numeric'), ('broker', 'text'), ('transaction_time', 'timestamp without time zone')]

## Query Execution Result:
Transaction: [('TXN1003', 'Emma Davis', 'BUY', 'MSFT', 200, Decimal('340.10'), Decimal('68020.00'))]
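
The `TODO` in the connection cell suggests reading the connection details from variables rather than hardcoding them. A minimal sketch of one way to do that is shown here; the environment variable names (`DB_USER`, `DB_PASSWORD`, `DB_HOST`, `DB_PORT`, `DB_NAME`) and the helper `build_postgres_uri` are assumptions for illustration and are not defined anywhere in this repository. The defaults fall back to the values used in this notebook.

```python
import os
from urllib.parse import quote_plus


def build_postgres_uri(env=os.environ):
    """Build a PostgreSQL URI from environment variables (hypothetical names)."""
    user = env.get("DB_USER", "agenticdb")
    password = env.get("DB_PASSWORD", "agenticdb")
    host = env.get("DB_HOST", "agenticdb.agentic-zone.svc.cluster.local")
    port = env.get("DB_PORT", "5432")
    dbname = env.get("DB_NAME", "agenticdb")
    # quote_plus escapes characters such as '@' or ':' that would break the URI
    return f"postgresql://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{dbname}"


uri = build_postgres_uri()
```

The resulting `uri` can be passed to `SQLDatabase.from_uri(uri)` exactly as in the cell above.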


### 3. Model Configuration

We start by creating an LLM instance, defined by the URL where the LLM API can be queried and a set of parameters that are applied to the model.


#### 3.1 Define the Inference Model Server specifics

In [4]:
INFERENCE_SERVER_URL = os.getenv('API_URL_GRANITE')
MODEL_NAME = "granite-3-8b-instruct"
API_KEY = os.getenv('API_KEY_GRANITE')

#### 3.2 Create the LLM instance

In [26]:
llm = ChatOpenAI(
    openai_api_key=API_KEY,
    openai_api_base=f"{INFERENCE_SERVER_URL}/v1",
    model_name=MODEL_NAME,
    top_p=0.92,
    temperature=0.01,
    max_tokens=512,
    presence_penalty=1.03,
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
    verbose=True
)

### 4. Use LLMs to generate and execute SQL queries against the DB

In [27]:
from tabulate import tabulate

## 1. Generate an SQL query using LLM
query = "Get all 'BUY' transactions where the price per share is greater than 300."

# Ensure the model only outputs the SQL query
print("## LLM Answer: ")
llm_response = llm.invoke(f"Generate a PostgreSQL query for this request. Return ONLY the SQL query, no explanations, no extra text:\n\n{query}")

## 2. Extract ONLY the SQL query from the AI response
# Handles AIMessage format
if hasattr(llm_response, "content"):
    raw_response = llm_response.content.strip()
# Handles plain string output
elif isinstance(llm_response, str):
    raw_response = llm_response.strip()
else:
    print("\n## ERROR: Invalid response format from LLM")
    raw_response = None

# Extract SQL query safely
# TODO: Handle with Structured Outputs
if raw_response:
    sql_query = raw_response.split("```sql")[-1].split("```")[0].strip() if "```sql" in raw_response else raw_response
else:
    print("\n## WARNING: No valid SQL query extracted")
    sql_query = None

## 3. Execute the SQL query
if sql_query:
    print("\n")  # Add a blank line before the query output
    print("## SQL Query Generated:")
    print(sql_query)
    
    print("\n## Running Query...")
    result = db.run(sql_query)

    print("\n## Query Results:")
    print(result)
        
else:
    print("\n## WARNING: No valid SQL query extracted. ##")


## LLM Answer: 
SELECT * FROM transactions
WHERE transaction_type = 'BUY' AND price_per_share > 300;

## SQL Query Generated:
SELECT * FROM transactions
WHERE transaction_type = 'BUY' AND price_per_share > 300;

## Running Query...

## Query Results:
[(3, 'TXN1003', 'Emma Davis', 'BUY', 'MSFT', 200, Decimal('340.10'), 'JPMorgan Chase', datetime.datetime(2024, 6, 15, 11, 45))]
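
The extraction step in the cell above can be pulled out into a small helper so it is easier to test in isolation. This is a sketch, not the notebook's method: `extract_sql` is a hypothetical name, and unlike the original one-liner (which takes the last `sql` fenced block) it takes the first fenced block, with or without a `sql` tag, and falls back to the whole reply when no fence is present.

```python
import re

# Build the fence marker programmatically to keep this example easy to embed
FENCE = "`" * 3
_FENCE_RE = re.compile(FENCE + r"(?:sql)?\s*(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)


def extract_sql(raw_response):
    """Return the SQL text from an LLM reply, or None if the reply is empty."""
    if not raw_response or not raw_response.strip():
        return None
    match = _FENCE_RE.search(raw_response)
    # Prefer the first fenced block; otherwise treat the whole reply as SQL
    sql = match.group(1) if match else raw_response
    return sql.strip() or None
```

With this helper, the cell could set `sql_query = extract_sql(raw_response)` and keep the existing `if sql_query:` check unchanged.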


### 5. Define tools to execute SQL queries

This section uses [Tool Calling / Function Calling](https://ai-on-openshift.io/odh-rhoai/enable-function-calling/#how-to-enable-function-calling-with-vllm-in-openshift-ai), which lets the LLM interact with external tools such as `execute_sql_query` in a structured way.

The LLM is given a set of functions (tools) that it can call to perform actions such as querying the database.
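
To make the idea concrete, here is a minimal, library-free sketch of what happens on the application side once the model returns a tool call: the call names a tool and carries arguments, and the application looks the tool up and runs it. The dictionary shape mirrors the tool-call output printed later in this notebook; `dispatch_tool_call`, `fake_sql_tool`, and `registry` are hypothetical names used only for illustration.

```python
def dispatch_tool_call(tool_call, registry):
    """Look up the requested tool by name and call it with the model-supplied argument."""
    func = registry.get(tool_call["name"])
    if func is None:
        return f"Unknown tool: {tool_call['name']}"
    args = tool_call.get("args", {})
    # The argument may arrive under "query" or under the positional key "__arg1"
    query = args.get("query") or args.get("__arg1")
    if query is None:
        return "Missing query argument"
    return func(query)


# Hypothetical stand-in for a real database tool
def fake_sql_tool(query):
    return f"would run: {query}"


registry = {"sql_db_query": fake_sql_tool}
call = {"name": "sql_db_query", "args": {"__arg1": "SELECT 1"}, "id": "example-id", "type": "tool_call"}
print(dispatch_tool_call(call, registry))  # prints "would run: SELECT 1"
```

The model only proposes the call; the application decides whether and how to execute it.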

In [24]:
## Define Custom Tools (SQL Query Tool Retriever)

from langchain_core.tools import Tool

# Define a tool that allows the LLM to execute SQL queries
def execute_sql_query(query: str):
    """Executes SQL query and returns results"""
    try:
        result = db.run(query)
        return result
    except Exception as e:
        return f"SQL Execution Error: {e}"

sql_query_tool = Tool(
    name="sql_db_query",
    func=execute_sql_query,
    description="Executes SQL queries and returns the results."
)

Next, we bind the list of tools to the LLM so that our AI agents can use them (in this case, only the SQL query tool):

In [19]:
llm_with_tools = llm.bind_tools([sql_query_tool], tool_choice="auto")

### 6. Invoke LLM to generate SQL Query

Process the user query by invoking the LLM, which generates a function call to the `sql_db_query` tool, then execute the query and print the result.

In [32]:
def handle_sql_query(user_query):
    """
    Processes a user query by invoking the LLM, extracting an SQL query, 
    correcting common errors, and executing it.

    Parameters:
        user_query (str): The user's natural language request for a database query.

    Returns:
        None (prints the generated SQL query and its execution result)
    """

    # Pass the query to LLM
    messages = [HumanMessage(user_query)]
    ai_msg = llm_with_tools.invoke(messages)

    # Check for tool calls from the LLM
    if hasattr(ai_msg, "tool_calls") and ai_msg.tool_calls:
        for tool_call in ai_msg.tool_calls:
            # Debugging the tool call structure
            print("\n## Tool Call Output:\n", tool_call)

            if tool_call["name"] == "sql_db_query":
                # Extract SQL query dynamically
                sql_args = tool_call["args"]
                sql_query = sql_args.get("query") or sql_args.get("__arg1") or str(sql_args)

                print(f"\n## Generated SQL Query:\n{sql_query}")

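                # Correct the SQL query before execution: rename a standalone `type`
                # column to `transaction_type`. The word boundary keeps an existing
                # `transaction_type` from becoming `transaction_transaction_type`.
                import re  # local import keeps this cell self-contained
                corrected_query = re.sub(r"\btype\b", "transaction_type", sql_query)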

                # Execute the corrected query
                print("\n## Running Query...")
                query_result = execute_sql_query(corrected_query)

                # Print the execution result
                print("\n## Query Execution Result:\n")
                print(query_result)
    else:
        print("\n## No tool calls detected. LLM did not generate an SQL query.")
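
The column-name correction step is worth looking at in isolation, because a plain substring replacement would also rewrite queries that already use the right column. The sketch below uses a word-boundary regular expression; `fix_type_column` is a hypothetical helper name. One known limitation: it would also rewrite the word `type` if it appeared on its own inside a string literal.

```python
import re


def fix_type_column(sql):
    """Rewrite a standalone `type` identifier to `transaction_type`."""
    # \b does not match between "_" and "t", so "transaction_type" is left alone
    return re.sub(r"\btype\b", "transaction_type", sql)


print(fix_type_column("SELECT * FROM transactions WHERE type = 'BUY'"))
# prints: SELECT * FROM transactions WHERE transaction_type = 'BUY'
print(fix_type_column("SELECT * FROM transactions WHERE transaction_type = 'BUY'"))
# prints: SELECT * FROM transactions WHERE transaction_type = 'BUY'
```

Including the table schema in the prompt may reduce the need for this kind of post-hoc correction, though that is not tested in this notebook.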

### 7. LLM and SQL Query Tool Calling in action!

In [33]:
# Example usage
handle_sql_query("Retrieve all BUY transactions where price per share is above 200")

<tool_call>
## Tool Call Output:
 {'name': 'sql_db_query', 'args': {'__arg1': "SELECT * FROM transactions WHERE type = 'BUY' AND price_per_share > 200"}, 'id': 'chatcmpl-tool-835db03070f647e3860447605040d827', 'type': 'tool_call'}

## Generated SQL Query:
SELECT * FROM transactions WHERE type = 'BUY' AND price_per_share > 200

## Running Query...

## Query Execution Result:

[(3, 'TXN1003', 'Emma Davis', 'BUY', 'MSFT', 200, Decimal('340.10'), 'JPMorgan Chase', datetime.datetime(2024, 6, 15, 11, 45))]


In [39]:
handle_sql_query("Retrieve all transactions where the client name contains 'Smith'")

<tool_call>
## Tool Call Output:
 {'name': 'sql_db_query', 'args': {'__arg1': "SELECT * FROM transactions WHERE client_name LIKE '%Smith%'"}, 'id': 'chatcmpl-tool-672452a3f07e4a64a06f2804610fbccf', 'type': 'tool_call'}

## Generated SQL Query:
SELECT * FROM transactions WHERE client_name LIKE '%Smith%'

## Running Query...

## Query Execution Result:

[(2, 'TXN1002', 'Michael Smith', 'SELL', 'TSLA', 100, Decimal('245.20'), 'Charles Schwab', datetime.datetime(2024, 6, 15, 10, 15))]
