In [50]:
import sys
import os
from dotenv import load_dotenv
import pandas as pd

# Add the parent directory to the sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))

from sqltoolkit.connectors import PostgreSQLConnector
from sqltoolkit.client import DatabaseClient
from sqltoolkit.entities import Table, TableColumn

load_dotenv()

# Define the connection parameters
server = os.getenv('POSTGRES_HOST')
database = 'debit_card_specializing'
username = os.getenv('POSTGRES_USER')
password = os.getenv('POSTGRES_PWD')
port = '5432'

psql_connector = PostgreSQLConnector(server, database, username, password, port)
sql_client = DatabaseClient(psql_connector)

In [51]:
from openai import AzureOpenAI

aoai_endpoint = os.getenv('OPENAI_ENDPOINT')
aoai_key = os.getenv('OPENAI_API_KEY')
aoai_deployment = os.getenv('OPENAI_4o_DEPLOYMENT') # should be GPT-4o
aoai_embedding_deployment = os.getenv('OPENAI_EMBEDDING_DEPLOYMENT') # should be text-embedding-small

openai_client = AzureOpenAI(azure_endpoint=aoai_endpoint,
    api_key=aoai_key,
    api_version='2024-10-21')

In [3]:
import os
import pandas as pd

directory = 'debit_card_specializing_doc'
concatenated_content = ""

for filename in os.listdir(directory):
    if filename.endswith(".csv"):
        table_name = filename.split('.')[0]
        df = pd.read_csv(os.path.join(directory, filename))
        table_markdown = df.to_markdown(index=False)
        concatenated_content += f"\n### Table: {table_name}\n\n{table_markdown}\n\n"

print(concatenated_content)


### Table: transactions_1k

| original_column_name   | column_name    | column_description   | data_format   | value_description            |
|:-----------------------|:---------------|:---------------------|:--------------|:-----------------------------|
| TransactionID          | Transaction ID | Transaction ID       | integer       | nan                          |
| Date                   | nan            | Date                 | date          | nan                          |
| Time                   | nan            | Time                 | text          | nan                          |
| CustomerID             | Customer ID    | Customer ID          | integer       | nan                          |
| CardID                 | Card ID        | Card ID              | integer       | nan                          |
| GasStationID           | Gas Station ID | Gas Station ID       | integer       | nan                          |
| ProductID              | Product ID     | Product ID     

In [4]:
from sqltoolkit.indexer import DatabaseIndexer

context = concatenated_content

indexer = DatabaseIndexer(sql_client, 
                          openai_client, 
                          aoai_deployment,
                          aoai_embedding_deployment,
                          extra_context=context)

manifest = indexer.fetch_and_describe_tables()
indexer.generate_table_embeddings()

tables_dict = indexer.export_json_manifest()

2024-12-18 13:06:15,248 - DatabaseIndexer - INFO - Fetching tables from the database.
2024-12-18 13:06:15,433 - DatabaseIndexer - INFO - Processing table: public.customers
2024-12-18 13:06:20,312 - DatabaseIndexer - INFO - Completed processing table: public.customers
2024-12-18 13:06:20,314 - DatabaseIndexer - INFO - Processing table: public.gasstations
2024-12-18 13:06:27,721 - DatabaseIndexer - INFO - Completed processing table: public.gasstations
2024-12-18 13:06:27,723 - DatabaseIndexer - INFO - Processing table: public.products
2024-12-18 13:06:31,181 - DatabaseIndexer - INFO - Completed processing table: public.products
2024-12-18 13:06:31,182 - DatabaseIndexer - INFO - Processing table: public.transactions_1k


Object of type date is not JSON serializable


2024-12-18 13:06:44,404 - DatabaseIndexer - INFO - Completed processing table: public.transactions_1k
2024-12-18 13:06:44,405 - DatabaseIndexer - INFO - Processing table: public.yearmonth
2024-12-18 13:06:49,301 - DatabaseIndexer - INFO - Completed processing table: public.yearmonth
2024-12-18 13:06:49,302 - DatabaseIndexer - INFO - Completed fetching and processing all tables.


In [9]:
indexer.tables[0].model_dump()

{'name': 'public.customers',
 'business_readable_name': 'Customer Segments',
 'description': 'The customers table contains information about individual customers, including a unique identifier, their client segment classification, and the currency they use for transactions. This table helps in distinguishing customers and categorizes them into different groups such as Key Account Management (KAM), Large Account Management (LAM), and Small and Medium Enterprises (SME). The currency column indicates the currency codes like CZK and EUR used by customers in their transactions. This information is used for tailoring services and performing financial analyses.',
 'columns': [{'name': 'customerid',
   'type': 'bigint',
   'description': None,
   'definition': 'Customer ID is a unique identifier assigned to each customer in the database. It is represented as a bigint integer and serves as the primary key for the customers table. This column helps in distinguishing individual customers and is u

In [10]:
search_endpoint = os.getenv('AI_SEARCH_ENDPOINT')
search_key = os.getenv('AI_SEARCH_API_KEY')
embedding_deployment = os.getenv('OPENAI_EMBEDDING_DEPLOYMENT')
index_name = database

indexer.create_azure_ai_search_index(
    search_endpoint=search_endpoint,
    search_credential=search_key,
    index_name=index_name,
    embedding_deployment=aoai_embedding_deployment,
    openai_endpoint=aoai_endpoint,
    openai_key=aoai_key,
)

# write to AI Search
indexer.push_to_ai_search()

2024-12-18 13:11:11,603 - DatabaseIndexer - INFO - Creating Azure AI Search Index.
2024-12-18 13:11:11,745 - DatabaseIndexer - INFO - SearchIndexClient created.
2024-12-18 13:11:11,983 - DatabaseIndexer - INFO - Index 'debit_card_specializing' does not exist. Creating a new one.
2024-12-18 13:11:11,985 - DatabaseIndexer - INFO - Fields for the index defined.
2024-12-18 13:11:11,985 - DatabaseIndexer - INFO - Vector search configuration defined.
2024-12-18 13:11:11,986 - DatabaseIndexer - INFO - Semantic configuration defined.
2024-12-18 13:11:11,988 - DatabaseIndexer - INFO - Semantic search configuration defined.
2024-12-18 13:11:12,515 - DatabaseIndexer - INFO - Index 'debit_card_specializing' created successfully.
2024-12-18 13:11:12,519 - DatabaseIndexer - INFO - Pushing metadata for 5 tables to Azure AI Search.


Pushing data for table public.customers to the index.


2024-12-18 13:11:12,796 - DatabaseIndexer - INFO - Data for table public.customers pushed to the index.
2024-12-18 13:11:12,866 - DatabaseIndexer - INFO - Data for table public.gasstations pushed to the index.
2024-12-18 13:11:12,935 - DatabaseIndexer - INFO - Data for table public.products pushed to the index.


Pushing data for table public.gasstations to the index.
Pushing data for table public.products to the index.
Pushing data for table public.transactions_1k to the index.


2024-12-18 13:11:13,004 - DatabaseIndexer - INFO - Data for table public.transactions_1k pushed to the index.
2024-12-18 13:11:13,067 - DatabaseIndexer - INFO - Data for table public.yearmonth pushed to the index.


Pushing data for table public.yearmonth to the index.


In [None]:
import time
from azure.search.documents import SearchClient
from azure.search.documents.models import QueryType, QueryCaptionType, QueryAnswerType
from azure.search.documents.models import VectorizableTextQuery
from azure.core.credentials import AzureKeyCredential
import json

search_client = SearchClient(endpoint=search_endpoint, 
                             credential=AzureKeyCredential(search_key), 
                             index_name=index_name)

def execute_query_pipeline(user_question, evidence=None):
    # Query rewriting
    rewriter_prompt = f"""
    You are a helpful AI SQL Agent. Your role is to help people translate their questions 
    into better questions that can be translated into SQL queries.
    You should disambiguate the question and provide a more specific version of the question to the best of your ability.
    You should rewrite the question in a way that is more likely to be answered by the database.

    ## You can use the following context to help you rewrite the question:
    {evidence}

    You must only return a single sentence that is a better version of the user's question.
    """

    messages = [
        {
            "role": "system",
            "content": rewriter_prompt
        },
        {
            "role": "user",
            "content": user_question
        }
    ]

    rewriter_response = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
    )

    rewriter_response = rewriter_response.choices[0].message.content
    print("Rewritten Question:", rewriter_response)

    # AI Search
    vector_query = VectorizableTextQuery(text=rewriter_response, k_nearest_neighbors=50, fields="embedding")

    results = search_client.search(
        search_text=user_question,
        vector_queries=[vector_query],
        query_type=QueryType.SEMANTIC,
        semantic_configuration_name='my-semantic-config',
        query_caption=QueryCaptionType.EXTRACTIVE,
        query_answer=QueryAnswerType.EXTRACTIVE,
        top=5
    )

    candidate_tables = []
    candidate_table_names = []
    for result in results:
        candidate_tables.append(
            json.dumps({key: result[key] for key in ['name', 'business_readable_name', 'description', 'columns'] if key in result}, indent=4)
        )
        candidate_table_names.append(result['name'])

    # LLM Call
    print("Candidate Tables:", candidate_tables)
    tables_prompt = candidate_tables

    system_prompt = f"""
    You are a helpful AI data analyst assistant, 
    You can execute SQL queries to retrieve information from a SQL database,
    The database is SQL server, use the right syntax to generate queries

    ### You must use the POSTGRES SQL syntax to query the database
    ### You must use double quotes around the column names but not around table names
    ### Your output must be a single SQL query
    ### You must use the column sample values to infer the value in your WHERE clause

    ### Here's additional context to help you generate the SQL query:
    {evidence}

    ### These are the available tables in the database alongside their descriptions:
    {tables_prompt}

    When asked a question that could be answered with a SQL query: 
    - Break down the question into smaller parts that can be answered with SQL queries
    - Identify the tables that contain the information needed to answer the question
    - Only once this is done, create a SQL query to retrieve the information based on your understanding of the table

    # you must use the exact table names and column names provided in the table descriptions

    Think step by step, before doing anything, share the different steps you'll execute to get the answer

    Your response should be in JSON format with the following structure:
    "chain_of_thought": "Your reasoning",
    "sql_query": "The generated SQL query, this can only be a single query"
    """

    messages = [
        {
            "role": "system",
            "content": system_prompt
        },
        {
            "role": "user",
            "content": user_question
        }
    ]

    response = openai_client.chat.completions.create(
        model="gpt-4o-global",
        messages=messages,
        response_format={"type": "json_object"}
    )

    response_message = response.choices[0].message.content
    response_json = json.loads(response_message)
    print("Generated Chain of Thought:", response_json['chain_of_thought'])
    print("Generated SQL Query:", response_json['sql_query'])

    # Execute SQL Query
    # query_result = sql_client.query(response_json['sql_query'])
    return {'sql_query':response_json['sql_query'],
            'chain_of_thought':response_json['chain_of_thought'],
            'tables':candidate_table_names,
            'rewritten_question':rewriter_response}

# Example usage
user_question = "For all the transactions happened during 8:00-9:00 in 2012/8/26, how many happened in CZE?"

start_time = time.time()
result = execute_query_pipeline(user_question)
end_time = time.time()

latency = end_time - start_time
print("Result:", result)
print("Latency:", latency, "seconds")

Rewritten Question: What is the total count of transactions that occurred in the CZE location between 8:00 AM and 9:00 AM on August 26, 2012?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n    "business_readable_name": "Customer Monthly Consumption",\n    "description": "The transactions_1k table records detailed information about individual transactions at gas stations, including the transaction date and time, customer identification, card used, gas station, and product details. It also includes the amount and price of products purchased, with the total price being the product of amount and price. This table is essential for analyzing sales activities, customer purchasing habits, and financial performance in the context of fuel and related product sales.",\n    "columns": [\n        {\n            "name": "customerid",\n            "type": "bigint",\n            "description": null,\n            "definition": "This column contains the unique identification numbers assigned to

In [15]:
import pandas as pd

df = pd.read_json('mini_dev_postgresql.json')
eval = df[df['db_id']=='debit_card_specializing']
eval.head()

Unnamed: 0,question_id,db_id,question,evidence,SQL,difficulty
0,1471,debit_card_specializing,What is the ratio of customers who pay in EUR ...,ratio of customers who pay in EUR against cust...,SELECT CAST(SUM(CASE WHEN Currency = 'EUR' THE...,simple
1,1472,debit_card_specializing,"In 2012, who had the least consumption in LAM?",Year 2012 can be presented as Between 201201 A...,SELECT T1.CustomerID FROM customers AS T1 INNE...,moderate
2,1473,debit_card_specializing,What was the average monthly consumption of cu...,Average Monthly consumption = AVG(Consumption)...,"SELECT AVG(T2.Consumption) / NULLIF(12, 0) FRO...",moderate
3,1476,debit_card_specializing,What was the difference in gas consumption bet...,cast the consumption into float when perform c...,SELECT SUM(CASE WHEN T1.Currency = 'CZK' THEN ...,challenging
4,1479,debit_card_specializing,Which year recorded the most consumption of ga...,The first 4 strings of the Date values in the ...,"SELECT SUBSTR(T2.Date, 1, 4) FROM customers AS...",moderate


In [39]:
from tqdm import tqdm

tqdm.pandas()

eval['llm_result'] = eval.progress_apply(lambda row: execute_query_pipeline(row['question'], row['evidence']), axis=1)
eval['generated_sql_query'] = eval['llm_result'].apply(lambda x: x['sql_query'])

  0%|          | 0/30 [00:00<?, ?it/s]

Rewritten Question: What is the ratio of the number of customers who paid in EUR to the number of customers who paid in CZK?
Candidate Tables: ['{\n    "name": "public.customers",\n    "business_readable_name": "Customer Segments",\n    "description": "The customers table contains information about individual customers, including a unique identifier, their client segment classification, and the currency they use for transactions. This table helps in distinguishing customers and categorizes them into different groups such as Key Account Management (KAM), Large Account Management (LAM), and Small and Medium Enterprises (SME). The currency column indicates the currency codes like CZK and EUR used by customers in their transactions. This information is used for tailoring services and performing financial analyses.",\n    "columns": [\n        {\n            "name": "customerid",\n            "type": "bigint",\n            "description": null,\n            "definition": "Customer ID is a un

  7%|▋         | 2/30 [00:04<01:01,  2.18s/it]

Generated Chain of Thought: To find the ratio of customers who pay in EUR against customers who pay in CZK, we need to count the number of customers using each currency. This information is located in the 'public.customers' table in the 'currency' column. We can use a SQL query to count the occurrences of each currency type and then compute the ratio of these counts.
Generated SQL Query: SELECT (SUM(CASE WHEN "currency" = 'EUR' THEN 1 ELSE 0 END)::decimal / NULLIF(SUM(CASE WHEN "currency" = 'CZK' THEN 1 ELSE 0 END), 0)) AS eur_to_czk_ratio FROM public.customers;
Rewritten Question: Which individual or entity had the lowest consumption in LAM during the year 2012, based on data from the yearmonth table?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n    "business_readable_name": "Customer Monthly Consumption",\n    "description": "The transactions_1k table records detailed information about individual transactions at gas stations, including the transaction date and time, custom

 10%|█         | 3/30 [00:08<01:24,  3.15s/it]

Generated Chain of Thought: To determine who had the least consumption in the LAM segment in 2012, we need to follow these steps:
1. Identify the records for the year 2012 from the 'yearmonth' table.
2. Filter the customers who belong to the LAM segment from the 'customers' table.
3. Join the 'yearmonth' and 'customers' tables on 'customerid'.
4. Aggregate the consumption of each customer in the LAM segment for the year 2012.
5. Identify the customer with the least consumption.
Generated SQL Query: SELECT c.customerid, MIN(y.consumption) AS least_consumption
FROM public.yearmonth y
JOIN public.customers c ON y.customerid = c.customerid
WHERE c.segment = 'LAM' AND y.date BETWEEN '201201' AND '201212'
GROUP BY c.customerid
ORDER BY least_consumption ASC
LIMIT 1;
Rewritten Question: What was the average monthly consumption of customers in the SME category for the year 2013, calculated as total consumption for the year divided by 12?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n

 13%|█▎        | 4/30 [00:13<01:31,  3.53s/it]

Generated Chain of Thought: To find the average monthly consumption of SME customers for the year 2013, we need to perform the following steps: 
1. Filter the customers to only include those from the SME segment.
2. Filter the consumption records to include only those within the year 2013 (dates between 201301 and 201312).
3. Join the filtered customers with their corresponding consumption records (matching on customerid).
4. Calculate the average monthly consumption for the filtered set by averaging the total consumption and dividing by 12.
Generated SQL Query: SELECT AVG("consumption") / 12 AS average_monthly_consumption
FROM yearmonth y
JOIN customers c
ON y.customerid = c.customerid
WHERE c.segment = 'SME' AND y.date BETWEEN '201301' AND '201312';
Rewritten Question: What is the difference in gas consumption, calculated as float, between CZK-paying and EUR-paying customers for the year 2012?
Candidate Tables: ['{\n    "name": "public.customers",\n    "business_readable_name": "Cust

 17%|█▋        | 5/30 [00:18<01:41,  4.04s/it]

Generated Chain of Thought: To calculate the difference in gas consumption between CZK-paying customers and EUR-paying customers in 2012, follow these steps:
1. Identify customers who pay in CZK and EUR from the 'customers' table.
2. Retrieve consumption data for the year 2012 (dates between 201201 and 201212) from the 'yearmonth' table.
3. Filter the consumption data to include only the records for customers who pay in CZK and EUR.
4. Sum the total consumption for CZK-paying customers and for EUR-paying customers separately.
5. Calculate the difference between the two sums. The final query combines these steps using a JOIN and appropriate WHERE clauses to filter by year and currency.
Generated SQL Query: SELECT (SUM(CAST(czk_consumption AS FLOAT)) - SUM(CAST(eur_consumption AS FLOAT))) AS consumption_difference
FROM (
  SELECT y."consumption" AS czk_consumption, NULL AS eur_consumption
  FROM public.yearmonth y
  JOIN public.customers c ON y."customerid" = c."customerid"
  WHERE "date

 20%|██        | 6/30 [00:23<01:50,  4.62s/it]

Generated Chain of Thought: To determine which year recorded the most consumption of gas paid in CZK, we'll need to follow these steps: 
1. Identify the relevant tables that contain 'customerid', 'date', 'consumption', and 'currency'. This information is found in the 'public.yearmonth' table (for 'customerid', 'date', and 'consumption') and 'public.customers' table (for 'customerid' and 'currency'). 
2. We'll join these tables on the 'customerid'. 
3. Filter the results to only include records where 'currency' is 'CZK'. 
4. Extract the year from the 'date' column. 
5. Group the data by year and sum the 'consumption' for each year. 
6. Order the results by the total consumption in descending order and select the top result.
Generated SQL Query: SELECT SUBSTRING("date", 1, 4) AS year, SUM("consumption") AS total_consumption 
FROM public.yearmonth 
JOIN public.customers ON public.yearmonth."customerid" = public.customers."customerid" 
WHERE public.customers."currency" = 'CZK' 
GROUP BY ye

 23%|██▎       | 7/30 [00:28<01:48,  4.70s/it]

Generated Chain of Thought: To determine the peak gas consumption month for SME customers in 2013, we need to take several steps:
1. Identify the customers who belong to the 'SME' segment from the customers table.
2. Filter the consumption data from the yearmonth table for the year 2013 (the date should be between 201301 and 201312) and for SME customers.
3. Sum the consumption for each month and identify the month with the highest total consumption.
4. Join the tables on customerid to get the required data.

Generated SQL Query: SELECT "date", SUM("consumption") as total_consumption
FROM yearmonth
JOIN customers ON yearmonth.customerid = customers.customerid
WHERE customers.segment = 'SME' AND yearmonth.date BETWEEN '201301' AND '201312'
GROUP BY "date"
ORDER BY total_consumption DESC
LIMIT 1;
Rewritten Question: What are the differences in the annual average consumption in CZK for the customers with the lowest consumption in 2013 between SME and LAM, LAM and KAM, and KAM and SME?
Can

 27%|██▋       | 8/30 [00:36<02:01,  5.53s/it]

Generated Chain of Thought: To find the difference in the annual average consumption of the customers with the least amount of consumption paid in CZK for 2013 between different segments (SME and LAM, LAM and KAM, and KAM and SME), we need to perform the following steps:
Step 1: Filter the data for the year 2013 (date between 201301 and 201312) and the currency CZK.
Step 2: Calculate the total annual consumption for each customer.
Step 3: Identify the customer with the least consumption in each segment (SME, LAM, KAM).
Step 4: Calculate the annual average consumption of the customer with the lowest consumption in each segment.
Step 5: Compute the differences in the annual average consumption between SME and LAM, LAM and KAM, and KAM and SME.
Generated SQL Query: WITH annual_consumption AS (
    SELECT c."segment", y."customerid", SUM(y."consumption") AS total_consumption
    FROM public.customers c
    JOIN public.yearmonth y ON c."customerid" = y."customerid"
    WHERE y."date" BETWEE

 30%|███       | 9/30 [00:43<02:09,  6.19s/it]

Generated Chain of Thought: To address the question, the following steps need to be executed: 
1. Filter the customer segments which are 'SME', 'LAM' and 'KAM' from the customers table where the currency is EUR. 
2. For the selected customers, retrieve consumption data for the periods corresponding to 2012 and 2013 from the yearmonth table. 
3. Aggregate the consumption over each year for each customer. 
4. Calculate the increase/decrease and the percentage change between the years 2012 and 2013. 
5. Group by segment and compute the average percentage increase for each segment. 
6. Identify the segments with the highest and lowest percentage increases. 

The required data is in the customers and yearmonth tables. We will filter customers with EUR currency and belonging to SME, LAM, and KAM segments, aggregate their consumption data from yearmonth for 2012 and 2013, and calculate the required percentage changes.
Generated SQL Query: WITH yearly_consumption AS ( 
SELECT c.segment, 
     

 33%|███▎      | 10/30 [00:47<01:47,  5.36s/it]

Generated Chain of Thought: To find the total consumption for customer 6 between August and November 2013, we need to filter the records in the 'yearmonth' table by the 'customerid' being 6 and the 'date' column representing months between 201308 and 201311. We then sum the 'consumption' values for these filtered records.
Generated SQL Query: SELECT SUM("consumption") AS total_consumption FROM yearmonth WHERE "customerid" = 6 AND "date" >= '201308' AND "date" <= '201311';
Rewritten Question: What is the difference in the total number of discount gas stations between the Czech Republic and Slovakia?
Candidate Tables: ['{\n    "name": "public.gasstations",\n    "business_readable_name": "Gas Station Information",\n    "description": "The gasstations table contains detailed information about each gas station. It includes unique identifiers for each station, its affiliation with a chain, and geographic location by country. Additionally, the table categorizes stations based on their market 

 37%|███▋      | 11/30 [00:50<01:27,  4.60s/it]

Generated Chain of Thought: To determine the number of additional 'Discount' gas stations in the Czech Republic compared to Slovakia, I need to complete the following steps: 1) Calculate the total number of 'Discount' gas stations in the Czech Republic. 2) Calculate the total number of 'Discount' gas stations in Slovakia. 3) Find the difference between the two values. The 'gasstations' table contains information about gas stations, including their 'country' and 'segment'. I will use this table to generate the required counts and compute the difference.
Generated SQL Query: SELECT (SELECT COUNT(*) FROM gasstations WHERE "country" = 'CZE' AND "segment" = 'Discount') - (SELECT COUNT(*) FROM gasstations WHERE "country" = 'SVK' AND "segment" = 'Discount') AS difference_in_discount_gas_stations;
Rewritten Question: How many more SMEs pay using Czech koruna compared to those that pay using euros?
Candidate Tables: ['{\n    "name": "public.customers",\n    "business_readable_name": "Customer S

 40%|████      | 12/30 [00:53<01:15,  4.21s/it]

Generated Chain of Thought: To determine if more SMEs pay in Czech koruna (CZK) than in euros (EUR) and to find out how many more, we need to follow these steps: 
1. Filter the customers to get only those classified as SMEs. 
2. Group these SMEs by the currency they use for transactions. 
3. Count the number of SMEs paying in CZK and EUR. 
4. Calculate the difference between the two counts. Both steps 2 and 3 will be combined in one query.
Generated SQL Query: SELECT (COUNT(CASE WHEN "currency" = 'CZK' THEN 1 END) - COUNT(CASE WHEN "currency" = 'EUR' THEN 1 END)) AS difference FROM public.customers WHERE "segment" = 'SME';
Rewritten Question: What is the percentage of LAM customers who have consumed more than 46.73 compared to the total number of LAM customers?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n    "business_readable_name": "Customer Monthly Consumption",\n    "description": "The transactions_1k table records detailed information about individual transactions at g

 43%|████▎     | 13/30 [00:57<01:11,  4.22s/it]

Generated Chain of Thought: To determine the percentage of LAM customers who consumed more than 46.73, we need to complete the following steps:
Step 1: Identify the LAM customers.
Step 2: Calculate the total number of LAM customers.
Step 3: Calculate the number of LAM customers who consumed more than 46.73.
Step 4: Compute the percentage by dividing the number of LAM customers who consumed more than 46.73 by the total number of LAM customers and then multiply by 100.
Generated SQL Query: SELECT (COUNT(CASE WHEN yc."consumption" > 46.73 THEN 1 END)::float / COUNT(yc."customerid"::float * 100) AS percentage_exceeding_46_73
FROM public.yearmonth yc
JOIN public.customers c ON yc."customerid" = c."customerid"
WHERE c."segment" = 'LAM'
Rewritten Question: What was the percentage of customers who had consumption greater than 528.3 in the yearmonth.date value for February 2012?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n    "business_readable_name": "Customer Monthly Consumption",

 47%|████▋     | 14/30 [01:01<01:06,  4.14s/it]

Generated Chain of Thought: To determine the percentage of customers who consumed more than 528.3 in February 2012:
1. Identify the relevant table: the 'yearmonth' table, since it contains consumption data per customer per month.
2. Filter records in the 'yearmonth' table where the date is '201202' (February 2012) and consumption is greater than 528.3.
3. Count the number of customers who meet this criterion.
4. Count the total number of customers in February 2012.
5. Calculate the percentage by dividing the number of customers who consumed more than 528.3 by the total number of customers, and then multiplying by 100.
Generated SQL Query: SELECT (COUNT(CASE WHEN "consumption" > 528.3 THEN 1 END) * 100.0 / COUNT(*)) AS percentage FROM yearmonth WHERE "date" = '201202'
Rewritten Question: What is the maximum monthly consumption recorded in the year 2012?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n    "business_readable_name": "Customer Monthly Consumption",\n    "description

 50%|█████     | 15/30 [01:04<00:57,  3.81s/it]

Generated Chain of Thought: To determine the highest monthly consumption in the year 2012, we need to follow these steps: 1. Filter records from the yearmonth table where the date column corresponds to the year 2012. 2. Extract the consumption values and cast them to float for accurate comparison. 3. Identify the maximum consumption value from these records. The 'yearmonth' table contains the necessary 'date' and 'consumption' columns. The date is in the YYYYMM format, so we need to filter for dates starting with '2012'.
Generated SQL Query: SELECT MAX(CAST("consumption" AS float)) AS highest_monthly_consumption FROM yearmonth WHERE "date" LIKE '2012%'
Rewritten Question: What are the product descriptions for products consumed in September 2013?
Candidate Tables: ['{\n    "name": "public.products",\n    "business_readable_name": "Product Descriptions",\n    "description": "The products table contains information regarding various products offered within the database. Each product is un

 53%|█████▎    | 16/30 [01:09<00:58,  4.20s/it]

Generated Chain of Thought: To answer this question, I need to follow these steps: 1. Identify the table that contains the consumption data along with the associated date information. This is the 'yearmonth' table. 2. Filter the records by the specified date, which is September 2013 (201309). 3. Identify the product IDs from the 'transactions_1k' table, which would help link products consumed to the 'products' table. 4. Join the filtered records with the 'products' table based on the product IDs to get the product descriptions.
Generated SQL Query: SELECT p."description" FROM yearmonth y JOIN transactions_1k t ON y."customerid" = t."customerid" JOIN products p ON t."productid" = p."productid" WHERE y."date" = '201309'
Rewritten Question: Could you provide a list of countries for gas stations that processed transactions during June 2013?
Candidate Tables: ['{\n    "name": "public.gasstations",\n    "business_readable_name": "Gas Station Information",\n    "description": "The gasstations

 57%|█████▋    | 17/30 [01:13<00:53,  4.12s/it]

Generated Chain of Thought: To get the countries of the gas stations with transactions in June 2013, we need to follow these steps:
1. Filter the transactions that occurred in June 2013 from the 'transactions_1k' table. We will use the 'date' column to filter the transactions.
2. Join the filtered transactions with the 'gasstations' table using the 'gasstationid' to get the country of each gas station.
3. Select the distinct countries from the result to list each country only once.
Generated SQL Query: SELECT DISTINCT gasstations."country" 
FROM transactions_1k 
JOIN gasstations ON transactions_1k."gasstationid"=gasstations."gasstationid" 
WHERE DATE_TRUNC('month', transactions_1k."date") = '2013-06-01'
Rewritten Question: How many customers who paid in euro (Currency = 'EUR') have a monthly consumption greater than 1000?
Candidate Tables: ['{\n    "name": "public.customers",\n    "business_readable_name": "Customer Segments",\n    "description": "The customers table contains informati

 60%|██████    | 18/30 [01:16<00:43,  3.67s/it]

Generated Chain of Thought: To answer the question, I need to perform the following steps: 1) Identify customers who use the EUR currency from the customers table. 2) Join this list of customers with the yearmonth table to get their monthly consumption. 3) Filter the records where the monthly consumption is greater than 1000. 4) Count the unique customers who meet these criteria.
Generated SQL Query: SELECT COUNT(DISTINCT yearmonth.customerid) FROM customers JOIN yearmonth ON customers.customerid = yearmonth.customerid WHERE customers.currency = 'EUR' AND yearmonth.consumption > 1000;
Rewritten Question: What are the product descriptions of transactions that occurred at gas stations in the Czech Republic (Country value 'CZE')?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n    "business_readable_name": "Customer Monthly Consumption",\n    "description": "The transactions_1k table records detailed information about individual transactions at gas stations, including the transact

 63%|██████▎   | 19/30 [01:19<00:38,  3.51s/it]

Generated Chain of Thought: To list the product descriptions of the transactions that took place in the gas stations in the Czech Republic, we need to follow these steps:
1. We need to join the transactions_1k and gasstations tables using the gasstationid column to filter transactions that happened in the Czech Republic.
2. We need to join the resulting table with the products table using the productid column to get the product descriptions.
3. Finally, we will select the description of the products.
Generated SQL Query: SELECT products.description
FROM transactions_1k
JOIN gasstations ON transactions_1k.gasstationid = gasstations.gasstationid
JOIN products ON transactions_1k.productid = products.productid
WHERE gasstations.country = 'CZE';
Rewritten Question: Could you provide a list of all unique transaction times for gas station transactions associated with chain number 11?
Candidate Tables: ['{\n    "name": "public.transactions_1k",\n    "business_readable_name": "Gas Station Trans

 67%|██████▋   | 20/30 [01:22<00:33,  3.31s/it]

Generated Chain of Thought: To answer this question, we need to retrieve the times of transactions that took place in gas stations belonging to chain number 11. This requires looking into two tables: 'public.transactions_1k' for transaction times and 'public.gasstations' for gas station chain information. We will join these two tables using 'gasstationid' and filter the result based on 'chainid' being 11. Finally, we will select distinct times of transactions.
Generated SQL Query: SELECT DISTINCT t."time" FROM public.transactions_1k t INNER JOIN public.gasstations g ON t."gasstationid" = g."gasstationid" WHERE g."chainid" = 11;
Rewritten Question: How many transactions at gas stations in the Czech Republic occurred after January 1, 2012?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n    "business_readable_name": "Customer Monthly Consumption",\n    "description": "The transactions_1k table records detailed information about individual transactions at gas stations, including t

 70%|███████   | 21/30 [01:25<00:28,  3.21s/it]

Generated Chain of Thought: To answer this question, we need to follow these steps: First, we need to identify transactions that occurred in gas stations located in the Czech Republic. This information is available in the 'gasstations' table, where the country code 'CZE' represents the Czech Republic. Secondly, we need to filter transactions that occurred after '2012-01-01'. The relevant information regarding transactions is available in the 'transactions_1k' table. We'll perform a join between 'transactions_1k' and 'gasstations' on the 'gasstationid', and then apply the date filter to count the transactions.
Generated SQL Query: SELECT COUNT(*) FROM transactions_1k t JOIN gasstations g ON t.gasstationid = g.gasstationid WHERE g.country = 'CZE' AND t.date > '2012-01-01';
Rewritten Question: What currency did the customer use to pay at 16:25:00 on 2012-08-24?
Candidate Tables: ['{\n    "name": "public.customers",\n    "business_readable_name": "Customer Segments",\n    "description": "T

 73%|███████▎  | 22/30 [01:28<00:26,  3.34s/it]

Generated Chain of Thought: To answer this question, I will take the following steps: 
1. Identify the transactions that occurred at the specific time '16:25:00' on '2012-08-24' from the 'public.transactions_1k' table.
2. Filter these transactions to get the 'customerid'.
3. Use the 'customerid' to query the 'public.customers' table to retrieve the type of currency used by the customer.
4. Combine these steps into a single SQL query.
Generated SQL Query: SELECT c.currency 
FROM public.transactions_1k t 
JOIN public.customers c ON t.customerid = c.customerid 
WHERE t.date = '2012-08-24' AND t.time = '16:25:00';
Rewritten Question: What was the customer's segment on 2012-08-23 at 21:20:00?
Candidate Tables: ['{\n    "name": "public.customers",\n    "business_readable_name": "Customer Segments",\n    "description": "The customers table contains information about individual customers, including a unique identifier, their client segment classification, and the currency they use for transact

 77%|███████▋  | 23/30 [01:34<00:27,  3.91s/it]

Generated Chain of Thought: To determine the segment of the customer who made a transaction at 2012-08-23 21:20:00: - I need to search for the transaction at 2012-08-23 21:20:00 in the transactions_1k table to find the customer ID. - Next, I need to find the segment of this customer from the customers table using the customer ID.
Generated SQL Query: SELECT c."segment" 
FROM public.transactions_1k t 
JOIN public.customers c ON t."customerid" = c."customerid" 
WHERE t."date" = '2012-08-23' AND t."time" = '21:20:00';
Rewritten Question: What is the total number of transactions that occurred in the Czech Republic (CZE) on August 26, 2012, during the time period between 08:00:00 and 09:00:00?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n    "business_readable_name": "Customer Monthly Consumption",\n    "description": "The transactions_1k table records detailed information about individual transactions at gas stations, including the transaction date and time, customer identificat

 80%|████████  | 24/30 [01:41<00:29,  4.87s/it]

Generated Chain of Thought: To determine the number of transactions that happened in CZE during 8:00-9:00 on 2012-08-26, we need to follow these steps: 1) Identify the relevant tables. The transactions_1k table contains the date, time, and gas station ID associated with each transaction. The gasstations table contains the country information associated with each gas station ID. 2) Filter transactions within the specified time range (08:00:00 to 09:00:00) on the given date (2012-08-26). 3) Join the transactions_1k table with the gasstations table on the gasstationid column. 4) Count the number of transactions that happened in the Czech Republic (CZE).
Generated SQL Query: SELECT COUNT(*) FROM transactions_1k t JOIN gasstations g ON t.gasstationid = g.gasstationid WHERE t.date = '2012-08-26' AND t.time BETWEEN '08:00:00' AND '09:00:00' AND g.country = 'CZE';
Rewritten Question: What is the nationality of the customer who made a purchase of 548.4 on August 24, 2012?
Candidate Tables: ['{\

 83%|████████▎ | 25/30 [01:44<00:21,  4.34s/it]

Generated Chain of Thought: To determine the nationality of the customer who spent 548.4 on 2012-08-24, we will follow these steps: 1) Identify the customer who made the specific transaction on the specified date and amount from the transactions_1k table. 2) Once we have the customer ID, we need to find this customer's nationality from the customers table by joining it with the transactions_1k table.
Generated SQL Query: SELECT gasstations.country FROM transactions_1k INNER JOIN customers ON transactions_1k.customerid = customers.customerid INNER JOIN gasstations ON transactions_1k.gasstationid = gasstations.gasstationid WHERE transactions_1k.date = '2012-08-24' AND transactions_1k.price = 548.4 LIMIT 1;
Rewritten Question: What percentage of customers used EUR on '2012-08-25'?
Candidate Tables: ['{\n    "name": "public.customers",\n    "business_readable_name": "Customer Segments",\n    "description": "The customers table contains information about individual customers, including a un

 87%|████████▋ | 26/30 [01:49<00:18,  4.73s/it]

Generated Chain of Thought: To determine the percentage of customers who used EUR on 2012/8/25, we need to follow these steps:
1. Identify the total number of customers who used EUR.
2. Identify the total number of customers who performed transactions on 2012-08-25.
3. Identify the customers who used EUR and performed transactions on this specific date.
4. Calculate the percentage of these filtered customers with respect to the total customers who performed transactions on that date.

The relevant tables are 'public.customers' to filter customers by currency (EUR) and 'public.transactions_1k' to filter transactions by date (2012-08-25). We will join these tables based on the 'customerid'. Finally, we will calculate the percentage.
Generated SQL Query: SELECT (CAST(COUNT(DISTINCT c.customerid) AS FLOAT) / total.total_customers) * 100 AS percentage_of_customers_who_used_EUR
FROM public.customers c
JOIN public.transactions_1k t ON c.customerid = t.customerid
CROSS JOIN (SELECT COUNT(DISTI

 90%|█████████ | 27/30 [01:56<00:15,  5.30s/it]

Generated Chain of Thought: To find the consumption decrease rate for the customer who paid 634.8 on 2012/8/25, we need to take the following steps: 
1. Identify the customer who made a transaction of 634.8 on 2012/8/25 from the transactions_1k table. 
2. Find the total consumption for that customer in Year 2012 and 2013 from the yearmonth table. 
3. Calculate the consumption decrease rate as (consumption_2012 - consumption_2013) / consumption_2012. We need to join the yearmonth and transactions_1k tables using the customerid and filter the data accordingly.
Generated SQL Query: WITH customer_payment AS (SELECT customerid FROM transactions_1k WHERE price = 634.8 AND date = '2012-08-25'), 
consumption_2012 AS (SELECT customerid, SUM(consumption) as consumption_2012 FROM yearmonth WHERE date >= '201201' AND date <= '201212' GROUP BY customerid), 
consumption_2013 AS (SELECT customerid, SUM(consumption) as consumption_2013 FROM yearmonth WHERE date >= '201301' AND date <= '201312' GROUP B

 93%|█████████▎| 28/30 [02:06<00:13,  6.64s/it]

Generated Chain of Thought: To calculate the percentage of 'Premium' segment gas stations against the overall segments in Country 'SVK', we need to perform the following steps: 1. Filter gas stations located in 'SVK'. 2. Count the total number of gas stations in 'SVK'. 3. Filter the gas stations with the 'Premium' segment in 'SVK'. 4. Count the number of 'Premium' segment gas stations in 'SVK'. 5. Calculate the percentage by dividing the count of 'Premium' segment gas stations by the total gas stations in 'SVK' and then multiply by 100.
Generated SQL Query: SELECT (COUNT(CASE WHEN segment = 'Premium' THEN 1 END) * 100.0 / COUNT(*)) AS premium_percentage FROM gasstations WHERE country = 'SVK';
Rewritten Question: What was the total amount spent by customer "38508" at gas stations in January 2012?
Candidate Tables: ['{\n    "name": "public.yearmonth",\n    "business_readable_name": "Customer Monthly Consumption",\n    "description": "The transactions_1k table records detailed information

 97%|█████████▋| 29/30 [02:10<00:05,  5.98s/it]

Generated Chain of Thought: To determine the total amount spent by customer 38508 at the gas stations, we need to query the transactions_1k table. We will filter the transactions by the customer ID 38508. Next, to find out how much the customer spent in January 2012, we will need to use the same table but add a date filter to restrict the transactions to that specific month. The total amount spent can be calculated using the amount and price columns. We can combine these operations into a single query.
Generated SQL Query: SELECT SUM("amount" * "price") AS total_spent, SUM(CASE WHEN TO_CHAR("date", 'YYYYMM') = '201201' THEN "amount" * "price" ELSE 0 END) AS jan_2012_spent FROM transactions_1k WHERE "customerid" = 38508;
Rewritten Question: Who is the top spending customer, what is the average price per single item purchased by this customer, and what currency was used for these transactions?
Candidate Tables: ['{\n    "name": "public.customers",\n    "business_readable_name": "Customer

100%|██████████| 30/30 [02:16<00:00,  5.88s/it]

Generated Chain of Thought: To identify the top spending customer and calculate the average price per item purchased by this customer along with the currency used, we need to follow these steps:

1. Identify the total spending per customer to determine the top spender.
2. Once we identify the top spender, calculate the average price per item for this customer.
3. Get the currency used by this customer from the customers table.

We will need to join the transactions_1k table with the customers table to get the customer's currency.

Query 1: Calculate total spending per customer to identify the top spender.
Query 2: Calculate the average price per single item purchased by this customer.
Query 3: Identify the currency used by this customer.
Generated SQL Query: WITH total_spending AS (SELECT "customerid", SUM("price" * "amount") AS total_spent FROM transactions_1k GROUP BY "customerid" ORDER BY total_spent DESC LIMIT 1), 
 top_customer AS (SELECT ts."customerid", ts.total_spent, c."curren

100%|██████████| 30/30 [02:21<00:00,  4.73s/it]

Generated Chain of Thought: To answer this question, we need to follow these steps: 1) Identify the table that contains the transaction information including product ID, price, and customer ID. This information is in the 'transactions_1k' table. 2) Filter the transactions to get records where the price per unit is greater than 29.00 and product ID is 5. 3) From these records, extract the unique customer IDs. 4) Use these customer IDs to get the consumption status in August 2012 from the 'yearmonth' table.
Generated SQL Query: SELECT y."customerid", y."consumption" FROM public.yearmonth y WHERE y."date" = '201208' AND y."customerid" IN (SELECT t."customerid" FROM public.transactions_1k t WHERE t."price" / t."amount" > 29.00 AND t."productid" = 5)





In [40]:
# execute the queries against the database and output the results

def run_query(query):
    try:
        result = sql_client.query(query)
        return result
    except Exception as e:
        return str(e)
    
eval['generated_query_result'] = eval['generated_sql_query'].progress_apply(run_query)
eval['ground_truth_result'] = eval['SQL'].progress_apply(run_query)

100%|██████████| 30/30 [00:03<00:00,  8.57it/s]
100%|██████████| 30/30 [00:04<00:00,  7.30it/s]


In [56]:
class isValidSQL:
    def __init__(self, db_client: DatabaseClient):
        self.db_client = db_client

    def __call__(self,*, query,**kwds):
        try:
            result = self.db_client.query(query)
            return 1
        except Exception as e:
            return 0

class SQLExecutionAccuracy:
    def __init__(self, db_client: DatabaseClient, openai_client: AzureOpenAI):
        self.db_client = db_client

    def __call__(self,*, question, query, expected_result, **kwds):
        evaluator_prompt = f"""You are a helpful AI SQL expert.
        Your are given a question, the expected result and the results generated by a SQL query.
        Your role is to determine if the ground truth result and the generated result are the same.
        These are query results so the columns may not be in the same order or have the same names but contain the right information.


        ## Specifically you must:
        - Compare the expected result with the generated result
        - If the generated result is not the same as the expected result but contains the same information or numerical data, you return 1
        - If the generated result is exactly the same as the expected result, you return 1
        - If the generated result is different from the expected result and does not contain the same information, you return 0
        - If two numeric values are rounded differently, you should consider them the same

        You cannot return anything other than 1 or 0. Your answer must be an integer.
        
        ## Examples
        - Question: What is the average amount spent by customer 234?
        - Generated Result: [{{"average_spend": 0.06572769953051644}}]
        - Expecter Result: [{{"?column?": 0.06572769953051644}}]

        Answer: 1
        """
        try:
            result = self.db_client.query(query)
            user_prompt = f"""
                    Question: {question}
                    Generated Result: {result}
                    Expected Result: {expected_result}
                    """
            messages = [{"role": "system", "content": evaluator_prompt},
                        {"role": "user", "content": user_prompt},
                        ]
            response = openai_client.chat.completions.create(
                model="gpt-4o-global",
                messages=messages,
            )
            response = response.choices[0].message.content
            return int(response)
        except Exception as e:
            return None
                        

In [57]:

eval['IsValidSQL']=eval.progress_apply(lambda row: isValidSQL(sql_client)(query=row['generated_sql_query']), axis=1)
eval['SQLExecutionAccuracy'] = eval.progress_apply(lambda row: SQLExecutionAccuracy(sql_client, openai_client)(question=row['question'],
                                                                                                                query=row['generated_sql_query'],
                                                                                                                expected_result=row['ground_truth_result']), axis=1)

execution_accuracy = eval['SQLExecutionAccuracy'].mean()
valid_sql = eval['IsValidSQL'].mean()

print("SQL Execution Accuracy:", execution_accuracy)
print("Valid SQL Queries:", valid_sql)

100%|██████████| 30/30 [00:03<00:00,  8.61it/s]
100%|██████████| 30/30 [00:13<00:00,  2.27it/s]

SQL Execution Accuracy: 0.5925925925925926
Valid SQL Queries: 0.9



