In [3]:
import os
from openai import AzureOpenAI

# Connection settings are read from the environment so no credentials live in
# the notebook; a variable that is not set is passed through as None.
client = AzureOpenAI(
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
    api_version=os.environ.get("AZURE_OPENAI_API_VERSION"),
)

In [6]:
# Smoke test: send one simple text-to-SQL request and print the model's reply.
smoke_test_messages = [
    {"role": "system", "content": "You are an AI assistant that writes SQL queries."},
    {"role": "user", "content": "Write a SQL query to select all users from the users table."},
]

response = client.chat.completions.create(model="o4-mini", messages=smoke_test_messages)
print(response.choices[0].message.content)

Here’s a simple query to fetch all rows from the users table:

SELECT *
FROM users;


In [27]:
def evaluate_sql_generation(input_text, llm_output, client, deployment_name):
    """
    Ask an LLM judge whether ``llm_output`` is a correct SQL query for ``input_text``.

    The judging instructions, the question and the candidate query are sent
    together as a single system message to ``client.chat.completions.create``.

    Args:
        input_text: Natural language question describing what the query should do.
        llm_output: SQL query text to be judged.
        client: Object exposing ``chat.completions.create(model=..., messages=...)``.
        deployment_name: Value passed as ``model`` to the completion call.

    Returns:
        str: The judge's reply with surrounding whitespace stripped. The prompt
        asks for a JSON object with ``isCorrect``, ``errorType`` and
        ``errorExplanation`` fields, but the reply is returned as raw text and is
        neither parsed nor validated here. An empty string is returned when the
        reply carries no text content.
    """
    prompt = (
        "You are an expert SQL evaluator. Your job is to judge whether a SQL query correctly answers a given natural language request.\n\n"
        "You will be given:\n"
        "- A natural language question (user's intent)\n"
        "- A SQL query generated by an LLM\n\n"
        "Your task is to evaluate ONLY whether the SQL query is a correct and complete answer to the question.\n\n"
        "Return your judgment in exactly one of the two following JSON formats:\n"
        '1. If the SQL query is correct and complete:\n'
        '   {"isCorrect": true, "errorType": null, "errorExplanation": null}\n'
        '2. If the SQL query is incorrect:\n'
        '   {"isCorrect": false, "errorType": "<incomplete|irrelevant|logic_error>", "errorExplanation": "<short explanation>"}\n\n'
        "**Return ONLY the JSON object. Do not add any explanation, comments, or formatting.**\n\n"
        f"Natural language question: {input_text}\n"
        f"SQL query: {llm_output}"
    )
    response = client.chat.completions.create(
        model=deployment_name,
        messages=[
            {"role": "system", "content": prompt}
        ]
    )
    # The message content is optional in the chat-completions response; guard
    # against None so a reply without text does not raise AttributeError and
    # abort a long evaluation run.
    content = response.choices[0].message.content
    if content is None:
        return ""
    return content.strip()

In [28]:
# Sanity-check the judge with one mismatched and one matching query for the
# same request.

# Negative case: the query reads from a different table than the one asked for.
wrong_input = "List all users in the database."
wrong_llm_output = "SELECT name FROM products;"
eval_result_false = evaluate_sql_generation(
    wrong_input, wrong_llm_output, client=client, deployment_name="o4-mini"
)
print("Wrong output result:", eval_result_false)

# Positive case: the query selects every row of the users table.
correct_input = "List all users in the database."
correct_llm_output = "SELECT * FROM users;"
eval_result_true = evaluate_sql_generation(
    correct_input, correct_llm_output, client=client, deployment_name="o4-mini"
)
print("Correct output result:", eval_result_true)

Wrong output result: {"isCorrect": false, "errorType": "irrelevant", "errorExplanation": "Selects from products instead of users table"}
Correct output result: {"isCorrect": true, "errorType": null, "errorExplanation": null}


In [None]:
from datasets import load_dataset

# Load the text-to-SQL dataset and keep its "test" split as the evaluation set.
ds = load_dataset("gretelai/synthetic_text_to_sql")
eval_set = ds["test"]

# Preview the first three records; zipping with range(3) stops the iteration
# after the third item without an explicit break.
for i, item in zip(range(3), eval_set):
    print(f"Example {i+1}:")
    print("Prompt:", item["sql_prompt"])
    print("Reference SQL:", item["sql"])

Example 1:
Prompt: What is the average explainability score of creative AI applications in 'Europe' and 'North America' in the 'creative_ai' table?
Reference SQL: SELECT AVG(explainability_score) FROM creative_ai WHERE region IN ('Europe', 'North America');
Example 2:
Prompt: Delete all records of rural infrastructure projects in Indonesia that have a completion date before 2010.
Reference SQL: DELETE FROM rural_infrastructure WHERE country = 'Indonesia' AND completion_date < '2010-01-01';
Example 3:
Prompt: How many accidents have been recorded for SpaceX and Blue Origin rocket launches?
Reference SQL: SELECT launch_provider, COUNT(*) FROM Accidents GROUP BY launch_provider;


In [29]:
from transformers import pipeline
import torch

# Choose where inference runs: prefer a CUDA GPU, then Apple Silicon (MPS),
# and fall back to the CPU when neither is available.
if not torch.cuda.is_available():
    # "mps" selects the Apple Silicon backend (M1/M2); -1 means CPU.
    device = "mps" if torch.backends.mps.is_available() else -1
else:
    device = 0  # CUDA device

# Text-generation pipeline around the fine-tuned SQL model.
generator = pipeline(
    "text-generation",
    model="thiborose/SmolLM2-FT-SQL",
    device=device,
)

Device set to use mps


In [30]:
def generate_sql_with_pipeline(prompt, generator, max_new_tokens=128):
    """
    Generate SQL text for ``prompt`` using a chat-style text-generation callable.

    Args:
        prompt: Natural language request, sent as a single user message.
        generator: Callable invoked as
            ``generator(messages, max_new_tokens=..., return_full_text=False)``;
            its first result must be a mapping with a ``"generated_text"`` string.
        max_new_tokens: Cap on the number of generated tokens. Defaults to 128,
            the value that was previously hard-coded; raise it if long queries
            come back cut off.

    Returns:
        str: The generated text with surrounding whitespace stripped.
    """
    messages = [{"role": "user", "content": prompt}]
    outputs = generator(
        messages,
        max_new_tokens=max_new_tokens,
        return_full_text=False,  # only the newly generated text, not the prompt
    )
    return outputs[0]["generated_text"].strip()


def evaluate_eval_set_with_pipeline(eval_set, generator, client, deployment_name, max_examples=5):
    """
    Generate SQL for items of ``eval_set`` and have the LLM judge grade each one.

    Each processed item is printed (prompt, generated SQL, reference SQL and
    the judge's reply) and recorded in the returned list.

    Args:
        eval_set: Iterable of mappings with ``"sql_prompt"`` and ``"sql"`` keys.
        generator: Passed through to ``generate_sql_with_pipeline``.
        client: Passed through to ``evaluate_sql_generation``.
        deployment_name: Passed through to ``evaluate_sql_generation``.
        max_examples: Maximum number of items to process. ``0`` or a negative
            value processes nothing; ``None`` processes every item.

    Returns:
        list[dict]: One dict per processed item with the keys ``"prompt"``,
        ``"generated_sql"``, ``"reference_sql"`` and ``"evaluation"``.
    """
    # Local import so this cell does not depend on another cell's imports.
    from itertools import islice

    # Apply the limit before touching any item: checking it only after an item
    # had been processed made max_examples=0 still run one generation and one
    # judge call.
    limit = None if max_examples is None else max(0, max_examples)

    results = []
    for i, item in enumerate(islice(eval_set, limit), start=1):
        prompt = item["sql_prompt"]
        reference_sql = item["sql"]
        generated_sql = generate_sql_with_pipeline(prompt, generator)
        eval_result = evaluate_sql_generation(
            input_text=prompt,
            llm_output=generated_sql,
            client=client,
            deployment_name=deployment_name
        )
        results.append({
            "prompt": prompt,
            "generated_sql": generated_sql,
            "reference_sql": reference_sql,
            "evaluation": eval_result
        })
        print(f"Example {i} | Prompt: {prompt}\nGenerated SQL: {generated_sql}\nReference SQL: {reference_sql}\nEvaluation: {eval_result}\n---")
    return results


In [31]:
# Trial run: generate and judge only the first five evaluation items.
pipeline_results = evaluate_eval_set_with_pipeline(
    eval_set,
    generator,
    client,
    deployment_name="o4-mini",
    max_examples=5,
)

Example 1 | Prompt: What is the average explainability score of creative AI applications in 'Europe' and 'North America' in the 'creative_ai' table?
Generated SQL: SELECT AVG(scoring) FROM creative_ai WHERE country IN ('Europe', 'North America') AND app_type = 'Creative AI' AND app_name IN ('AI Chatbot', 'AI Recommendation Engine', 'AI Chatbot Language Translation', 'AI Chatbot Text Summarization', 'AI Support Vector Machine');
Reference SQL: SELECT AVG(explainability_score) FROM creative_ai WHERE region IN ('Europe', 'North America');
Evaluation: {"isCorrect": false, "errorType": "logic_error", "errorExplanation": "The query unnecessarily filters by specific app_name and app_type and does not compute separate averages for each region"}
---
Example 2 | Prompt: Delete all records of rural infrastructure projects in Indonesia that have a completion date before 2010.
Generated SQL: DELETE FROM infrastructure_projects WHERE country = 'Indonesia' AND completion_date < 2010;
Reference SQL: D

In [None]:
# Full run: the limit equals the size of the evaluation set, so every item is
# generated and judged.
pipeline_results_full = evaluate_eval_set_with_pipeline(
    eval_set,
    generator,
    client,
    deployment_name="o4-mini",
    max_examples=len(eval_set),
)

Example 1 | Prompt: What is the average explainability score of creative AI applications in 'Europe' and 'North America' in the 'creative_ai' table?
Generated SQL: SELECT AVG(explainability_score) FROM creative_ai WHERE region = 'Europe' AND region = 'North America'
Reference SQL: SELECT AVG(explainability_score) FROM creative_ai WHERE region IN ('Europe', 'North America');
Evaluation: {"isCorrect":false,"errorType":"logic_error","errorExplanation":"Filtering on region uses AND instead of OR or IN, resulting in no rows returned"}
---
Example 2 | Prompt: Delete all records of rural infrastructure projects in Indonesia that have a completion date before 2010.
Generated SQL: DELETE FROM infrastructure WHERE country = 'Indonesia' AND completion_date < '2010-01-01';
Reference SQL: DELETE FROM rural_infrastructure WHERE country = 'Indonesia' AND completion_date < '2010-01-01';
Evaluation: {"isCorrect": false, "errorType": "incomplete", "errorExplanation": "Missing filter for rural infrastruc

In [None]:
# Write every per-example record from the full run to disk as indented JSON.
import json

with open("pipeline_evaluation_results.json", mode="w") as results_file:
    json.dump(pipeline_results_full, results_file, indent=2)

# Report how many examples the saved run contains.
evaluated_count = len(pipeline_results_full)
print(f"Total examples evaluated: {evaluated_count}")