In [1]:
import getpass
import os
import time
from dotenv import load_dotenv

# Load the .env file
_ = load_dotenv()

### Model

In [2]:
from agno.models.ollama import Ollama

ollama_model=Ollama(id="llama3.2:3b")


### Text 2 sql agent

In [3]:
from agno.agent import Agent

instructions="""
You are an expert in SQL. You can create sql queries from natural language following these schemas:

CREATE TABLE "Orders" (
	date DATETIME, 
	id TEXT, 
	quantity FLOAT
)

CREATE TABLE "Stock" (
	id TEXT, 
	current_stock_quantity FLOAT, 
	units TEXT, 
	avg_lead_time_days BIGINT, 
	maximum_lead_time_days BIGINT, 
	unit_price FLOAT
)

Return only the sql query. Don't explain the query and don't put the sql query in ```sql\n```.
"""
text2sql_agent = Agent(model=ollama_model, name="text2sql", instructions=instructions, show_tool_calls=True)

### Tests

In [4]:
def write_sql_query(query:str):
    """ Function to write sql query """
    start_time = time.perf_counter()

    response = text2sql_agent.run(query)

    end_time = time.perf_counter()
    
    latency = end_time - start_time
    
    print(f"Execution time: {latency} seconds")
    answer = response.content
    
    return answer
    

In [5]:
query = "How many orders are there in the Orders table?"
write_sql_query(query)

Execution time: 20.196944900002563 seconds


'SELECT COUNT id FROM Orders'

In [6]:
query = "Select the orders that have a quantity equal 475 in the Orders table."
write_sql_query(query)

Execution time: 4.552857100003166 seconds


'SELECT date, id, quantity FROM Orders WHERE quantity = 475'

In [7]:
query = "Get all records from the Stock table."
write_sql_query(query)

Execution time: 2.769525199983036 seconds


'SELECT * FROM "Stock"'

In [8]:
query = "List the IDs of products in stock with their current quantity"
write_sql_query(query)

Execution time: 5.24509909999324 seconds


"SELECT id, current_stock_quantity FROM Stock WHERE units = 'in stock'"

In [9]:
query = "Find products with more than 10,000 units in stock"
write_sql_query(query)

Execution time: 4.951010799995856 seconds


'SELECT * FROM Stock WHERE current_stock_quantity > 10000'

In [10]:
query = "Get the total quantity of products in stock"
write_sql_query(query)

Execution time: 3.463996500009671 seconds


'SELECT SUM(current_stock_quantity) FROM Stock'

In [11]:
query = "Get the product with the highest maximum lead time"
write_sql_query(query)

Execution time: 5.830955299985362 seconds


'SELECT unit_price FROM Stock WHERE maximum_lead_time_days = ( SELECT MAX(maximum_lead_time_days) FROM Stock );'

In [12]:
query = "List products in stock that have been ordered at least once"
write_sql_query(query)

Execution time: 6.122881899995264 seconds


'SELECT s.id, s.current_stock_quantity \nFROM Orders o \nJOIN Stock s ON o.id = s.id'

### Benchmark

In [13]:
import json
import glob

def load_benchmark(benchmark_path):
    """ This function loads the benchmark dataset"""
    questions = []
    expected_sql = []
    with open(benchmark_path, 'r', encoding='utf-8') as f:
        try:
            datos = json.load(f)
            questions.extend(datos.keys())
            expected_sql.extend(datos.values())
        except Exception as e:
            print(f,e)
    return questions, expected_sql

questions , expected_sql = load_benchmark("../tests/dataset_queries_en.json")

In [14]:
questions

['Get all records from the Stock table',
 'List the IDs of products in stock with their current quantity',
 'Get the average unit price of all products in stock',
 'Find products with more than 10,000 units in stock',
 'Get the total quantity of products in stock',
 'List orders placed in the year 2023',
 'Get the product with the highest maximum lead time',
 'List products whose unit price is above the average',
 'Get the total number of orders per product',
 'List products in stock that have been ordered at least once']

In [15]:
expected_sql

['SELECT * FROM Stock;',
 'SELECT id, current_stock_quantity FROM Stock;',
 'SELECT AVG(unit_price) AS avg_unit_price FROM Stock;',
 'SELECT id, current_stock_quantity FROM Stock WHERE current_stock_quantity > 10000;',
 'SELECT SUM(current_stock_quantity) AS total_stock FROM Stock;',
 "SELECT * FROM Orders WHERE strftime('%Y', date) = '2023';",
 'SELECT id, maximum_lead_time_days FROM Stock ORDER BY maximum_lead_time_days DESC LIMIT 1;',
 'SELECT id, unit_price FROM Stock WHERE unit_price > (SELECT AVG(unit_price) FROM Stock);',
 'SELECT id, COUNT(*) AS total_orders FROM Orders GROUP BY id;',
 'SELECT DISTINCT s.id, s.current_stock_quantity FROM Stock s JOIN Orders o ON s.id = o.id;']

In [16]:
import pandas as pd
import time

def eval_benchmark(benchmark_path:str)->pd.DataFrame:
	""" This function evaluates the benchmark dataset"""

	# Load the benchmark dataset
	questions , expected_sql = load_benchmark(benchmark_path)
	# Create a dataframe to store the results
	df = pd.DataFrame({"question":questions, "expected_sql":expected_sql})

	latencies = []
	answers = []
	for index, row in df.iterrows():
		try:
			question = row["question"]
			expected_sql = row["expected_sql"]
			
			# Measure the latency
			start_time = time.time()
			# Invoke the agent
			response = text2sql_agent.run(query)
			predicted_sql = response.content
			# Measure the latency
			latency = time.time() - start_time
			latencies.append(latency)
			# Store the predicted SQL query
			answers.append(predicted_sql)
		except Exception as e:
			answers.append("ERROR")
			latencies.append(-1)
			print(f"Error processing question: {question}")

	# Store the results in the dataframe
	df["predicted_sql"] = answers
	df["latency"] = latencies
	return df

In [17]:
df_eval = eval_benchmark( "../tests/dataset_queries_en.json")

In [18]:
df_eval.head(10)

Unnamed: 0,question,expected_sql,predicted_sql,latency
0,Get all records from the Stock table,SELECT * FROM Stock;,SELECT S.* FROM Stock S INNER JOIN Orders O ON...,5.070666
1,List the IDs of products in stock with their c...,"SELECT id, current_stock_quantity FROM Stock;",SELECT S.id FROM Stock AS S INNER JOIN Orders ...,5.972475
2,Get the average unit price of all products in ...,SELECT AVG(unit_price) AS avg_unit_price FROM ...,SELECT S.id FROM Stock AS S JOIN Orders AS O O...,4.506686
3,"Find products with more than 10,000 units in s...","SELECT id, current_stock_quantity FROM Stock W...",SELECT S.* FROM Stock AS S JOIN Orders AS O ON...,5.407918
4,Get the total quantity of products in stock,SELECT SUM(current_stock_quantity) AS total_st...,SELECT S.* FROM Stock AS S JOIN Orders AS O ON...,5.147977
5,List orders placed in the year 2023,"SELECT * FROM Orders WHERE strftime('%Y', date...",SELECT T1.id FROM Stock AS T1 JOIN Orders AS T...,6.076564
6,Get the product with the highest maximum lead ...,"SELECT id, maximum_lead_time_days FROM Stock O...",SELECT T2.id FROM Orders AS T1 JOIN Stock AS T...,6.01953
7,List products whose unit price is above the av...,"SELECT id, unit_price FROM Stock WHERE unit_pr...",SELECT s.* FROM Stock AS s JOIN Orders AS o ON...,4.888183
8,Get the total number of orders per product,"SELECT id, COUNT(*) AS total_orders FROM Order...",SELECT id FROM Stock WHERE units != 'Out of St...,4.148718
9,List products in stock that have been ordered ...,"SELECT DISTINCT s.id, s.current_stock_quantity...",SELECT S.id FROM Orders O JOIN Stock S ON O.id...,4.865434


In [19]:
model_id = "llama-32-3B"

df_eval.to_csv(f"../tests/df_benchmark_results_{model_id}.csv", index=False)

In [20]:
import pandas as pd

model_id = "llama-32-3B"

df_eval = pd.read_csv(f"../tests/df_benchmark_results_{model_id}.csv")

In [23]:
import plotly.graph_objects as go

# Crear figura
fig = go.Figure(data=[go.Bar(x=df_eval.question, y=df_eval.latency)])

# Configurar ejes
fig.update_layout(
    title=f'Our Text2SQL Benchmark dataset to evaluate the performance of : {model_id}',
    xaxis_title='Question',
    yaxis_title='Latency (s)'
)
fig.update_layout(title=dict(font_size=12, font_weight='bold', x=0.5))
fig.update_traces(marker_line_color='#409ef3', marker_line_width=2, marker_color="#077fed")
fig.update_layout(plot_bgcolor='rgba(0,0,0,0)', yaxis=dict(gridcolor='lightgrey'), height=600, width=1000)

# Mostrar figura
fig.show()