In [13]:
import sys
import os
from dotenv import load_dotenv
import pandas as pd

# Add the parent directory to the sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))

from sqltoolkit.connectors import PostgreSQLConnector
from sqltoolkit.client import DatabaseClient
from sqltoolkit.entities import Table, TableColumn
from sqltoolkit.compiler import SQLQueryChecker

load_dotenv()

# Define the connection parameters
server = os.getenv('POSTGRES_HOST')
database = 'debit_card_specializing'
username = os.getenv('POSTGRES_USER')
password = os.getenv('POSTGRES_PWD')
port = '5432'

psql_connector = PostgreSQLConnector(server, database, username, password, port)
sql_client = DatabaseClient(psql_connector)

In [3]:
from openai import AzureOpenAI

aoai_endpoint = os.getenv('OPENAI_ENDPOINT')
aoai_key = os.getenv('OPENAI_API_KEY')
aoai_deployment = os.getenv('OPENAI_4o_DEPLOYMENT') # should be GPT-4o
aoai_embedding_deployment = os.getenv('OPENAI_EMBEDDING_DEPLOYMENT') # should be text-embedding-small

openai_client = AzureOpenAI(azure_endpoint=aoai_endpoint,
    api_key=aoai_key,
    api_version='2024-10-21')

In [3]:
import os
import pandas as pd

directory = 'debit_card_specializing_doc'
concatenated_content = ""

for filename in os.listdir(directory):
    if filename.endswith(".csv"):
        table_name = filename.split('.')[0]
        df = pd.read_csv(os.path.join(directory, filename))
        table_markdown = df.to_markdown(index=False)
        concatenated_content += f"\n### Table: {table_name}\n\n{table_markdown}\n\n"

print(concatenated_content)


### Table: transactions_1k

| original_column_name   | column_name    | column_description   | data_format   | value_description            |
|:-----------------------|:---------------|:---------------------|:--------------|:-----------------------------|
| TransactionID          | Transaction ID | Transaction ID       | integer       | nan                          |
| Date                   | nan            | Date                 | date          | nan                          |
| Time                   | nan            | Time                 | text          | nan                          |
| CustomerID             | Customer ID    | Customer ID          | integer       | nan                          |
| CardID                 | Card ID        | Card ID              | integer       | nan                          |
| GasStationID           | Gas Station ID | Gas Station ID       | integer       | nan                          |
| ProductID              | Product ID     | Product ID     

In [4]:
from sqltoolkit.indexer import DatabaseIndexer

context = concatenated_content

indexer = DatabaseIndexer(sql_client, 
                          openai_client, 
                          aoai_deployment,
                          aoai_embedding_deployment,
                          extra_context=context)

manifest = indexer.fetch_and_describe_tables()
indexer.generate_table_embeddings()

tables_dict = indexer.export_json_manifest()

2024-12-18 13:06:15,248 - DatabaseIndexer - INFO - Fetching tables from the database.
2024-12-18 13:06:15,433 - DatabaseIndexer - INFO - Processing table: public.customers
2024-12-18 13:06:20,312 - DatabaseIndexer - INFO - Completed processing table: public.customers
2024-12-18 13:06:20,314 - DatabaseIndexer - INFO - Processing table: public.gasstations
2024-12-18 13:06:27,721 - DatabaseIndexer - INFO - Completed processing table: public.gasstations
2024-12-18 13:06:27,723 - DatabaseIndexer - INFO - Processing table: public.products
2024-12-18 13:06:31,181 - DatabaseIndexer - INFO - Completed processing table: public.products
2024-12-18 13:06:31,182 - DatabaseIndexer - INFO - Processing table: public.transactions_1k


Object of type date is not JSON serializable


2024-12-18 13:06:44,404 - DatabaseIndexer - INFO - Completed processing table: public.transactions_1k
2024-12-18 13:06:44,405 - DatabaseIndexer - INFO - Processing table: public.yearmonth
2024-12-18 13:06:49,301 - DatabaseIndexer - INFO - Completed processing table: public.yearmonth
2024-12-18 13:06:49,302 - DatabaseIndexer - INFO - Completed fetching and processing all tables.


In [9]:
indexer.tables[0].model_dump()

{'name': 'public.customers',
 'business_readable_name': 'Customer Segments',
 'description': 'The customers table contains information about individual customers, including a unique identifier, their client segment classification, and the currency they use for transactions. This table helps in distinguishing customers and categorizes them into different groups such as Key Account Management (KAM), Large Account Management (LAM), and Small and Medium Enterprises (SME). The currency column indicates the currency codes like CZK and EUR used by customers in their transactions. This information is used for tailoring services and performing financial analyses.',
 'columns': [{'name': 'customerid',
   'type': 'bigint',
   'description': None,
   'definition': 'Customer ID is a unique identifier assigned to each customer in the database. It is represented as a bigint integer and serves as the primary key for the customers table. This column helps in distinguishing individual customers and is u

In [4]:
search_endpoint = os.getenv('AI_SEARCH_ENDPOINT')
search_key = os.getenv('AI_SEARCH_API_KEY')
embedding_deployment = os.getenv('OPENAI_EMBEDDING_DEPLOYMENT')
index_name = database


In [None]:

indexer.create_azure_ai_search_index(
    search_endpoint=search_endpoint,
    search_credential=search_key,
    index_name=index_name,
    embedding_deployment=aoai_embedding_deployment,
    openai_endpoint=aoai_endpoint,
    openai_key=aoai_key,
)

# write to AI Search
indexer.push_to_ai_search()

In [5]:
import time
from azure.search.documents import SearchClient
from azure.search.documents.models import QueryType, QueryCaptionType, QueryAnswerType
from azure.search.documents.models import VectorizableTextQuery
from azure.core.credentials import AzureKeyCredential
import json

search_client = SearchClient(endpoint=search_endpoint, 
                             credential=AzureKeyCredential(search_key), 
                             index_name=index_name)

tables = search_client.search(search_text="*")
tables = [{'name':table['name'],
           'description':table['description']} for table in tables]

def rewrite_query(openai_client: AzureOpenAI, 
                  deployment_name:str,
                  user_question:str,
                  evidence:str=None)->str:
    rewriter_prompt = f"""
    You are a helpful AI SQL Agent. Your role is to help people translate their questions 
    into better questions that can be translated into SQL queries.
    You should disambiguate the question and provide a more specific version of the question to the best of your ability.
    You should rewrite the question in a way that is more likely to be answered by the database.

    ## YOU MUST USE THE FOLLOWING CONTEXT ABOUT THE DATABASE TO REWRITE THE QUESTION:
    {tables}

    ## You should rewrite the question using as many terms from the context as possible without using specific table names

    You must only return a single sentence that is a better version of the user's question.
    """

    messages = [
        {
            "role": "system",
            "content": rewriter_prompt
        },
        {
            "role": "user",
            "content": user_question
        }
    ]

    rewriter_response = openai_client.chat.completions.create(
        model=deployment_name,
        messages=messages,
    )

    rewriter_response = rewriter_response.choices[0].message.content
    return rewriter_response

def search_tables(search_client: SearchClient, user_question:str)->list:
    vector_query = VectorizableTextQuery(text=user_question, k_nearest_neighbors=50, fields="embedding")

    results = search_client.search(
        search_text=user_question,
        vector_queries=[vector_query],
        query_type=QueryType.SEMANTIC,
        semantic_configuration_name='my-semantic-config',
        query_caption=QueryCaptionType.EXTRACTIVE,
        query_answer=QueryAnswerType.EXTRACTIVE,
        top=10
    )

    candidate_tables = []
    for result in results:
        candidate_tables.append(
            {key: result[key] for key in ['name', 'business_readable_name', 'description', 'columns'] if key in result}
        )

    return candidate_tables

def table_re_ranker(openai_client: AzureOpenAI, 
                    deployment:str,
                    user_question:str, 
                    candidate_tables:list)->str:
    
    reranker_prompt = f"""You are a helpful AI SQL Agent.
    Your role is to identify the most relevant tables that can be used to answer the user's question.
    You are provided with a list of tables and their schema.

    ## Important instructions:
    - Your role is to select the tables that might be relevant to the user's question.
    - Don't make assumptions on the SQL query that will be generated, only select the tables that might be relevant to the user's question.
    - You must also select a subset of the columns that contain the information needed to answer the user's question.
    - You must return 3 tables
    - You must rank the tables based on their relevance to the user's question

    You should return the 3 tables in the following JSON format:
    {{ "tables":[
        {{
            "table": "table_name",
            "columns": ["column1", "column2", "column3"]
        }}
    ]}}

    Do not return anything other than the json

    """
    user_prompt = f"""
    <User Question>: {user_question} </User Question>
    <Candidate Tables>: {candidate_tables} </Candidate Tables>"""

    messages = [
        {
            "role": "system",
            "content": reranker_prompt
        },
        {
            "role": "user",
            "content": user_prompt
        }
    ]
    
    response = openai_client.chat.completions.create(
        model=deployment,
        messages=messages,
        temperature=0.3,
        response_format={"type": "json_object"}
    )

    response_message = response.choices[0].message.content
    return response_message

def filter_candidate_tables(candidate_tables, reranked_tables_json):
    filtered_tables = []
    for table in candidate_tables:
        for reranked_table in reranked_tables_json['tables']:
            if table['name'] == reranked_table['table']:
                filtered_columns = [col for col in table['columns'] if col['name'] in reranked_table['columns']]
                filtered_tables.append({
                    'name': table['name'],
                    'description': table['description'],
                    'columns': filtered_columns
                })
    return filtered_tables


In [6]:
import json

query = "In 2012, who had the least consumption in LAM??"
evidence = "Year 2012 can be presented as Between 201201 And 201212; The first 4 strings of the Date values in the yearmonth table can represent year."

rewritten_query = rewrite_query(openai_client, 'gpt-4o-mini', query, evidence)
print(rewritten_query)

candidate_tables = search_tables(search_client, rewritten_query)
print(len(candidate_tables))

reranked_tables = table_re_ranker(openai_client, 'gpt-4o-mini', rewritten_query, candidate_tables)
print(reranked_tables)

# Parse the reranked tables JSON
reranked_tables_json = json.loads(reranked_tables)

filtered_tables = filter_candidate_tables(candidate_tables, reranked_tables_json)

print(filtered_tables)

Who were the customers in the Large Account Management (LAM) segment with the lowest product consumption in the year 2012?
5
{
    "tables": [
        {
            "table": "public.customers",
            "columns": ["customerid", "segment"]
        },
        {
            "table": "public.yearmonth",
            "columns": ["customerid", "date", "consumption"]
        },
        {
            "table": "public.transactions_1k",
            "columns": ["customerid", "date", "amount"]
        }
    ]
}
[{'name': 'public.customers', 'description': 'The customers table contains information about individual customers, including a unique identifier, their client segment classification, and the currency they use for transactions. This table helps in distinguishing customers and categorizes them into different groups such as Key Account Management (KAM), Large Account Management (LAM), and Small and Medium Enterprises (SME). The currency column indicates the currency codes like CZK and EUR u

In [25]:

from pydantic import BaseModel

def check_query(sql_query:str, tables:list, openai_client, deployment)->bool:
    query_checker = SQLQueryChecker(openai_client, deployment,'Postgres', tables)
    query_checks = query_checker.validate_query(sql_query)
    print(f'Query checks: {query_checks}')
    return query_checks

def get_tools():
    return [{
            "type": "function",
            "function": {
                "name": "check_query",
                "description": "Validate a SQL query for common mistakes, always call this before returning the query, pass in the full query",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The SQL query to validate",
                        }
                    },
                    "required": ["query"],
                },
            }
        }]


def generate_sql_query(openai_client: AzureOpenAI,
                          deployment:str,
                          user_question:str,
                          evidence:str,
                          candidate_tables:list)->str:
    
    candidate_tables_min = [{'table': table['name'], 'columns': [col['name'] for col in table['columns']]} for table in candidate_tables]
    system_prompt = f"""
     You are a helpful AI data analyst assistant, 
     You can execute SQL queries to retrieve information from a SQL database,
     The database is SQL server, use the right syntax to generate queries
    
     ### You must use the POSTGRES SQL syntax to query the database
     ### You must use double quotes around the column names but not around table names
     ### Your output must be a single SQL query
     ### You must use the column sample values to infer the value in your WHERE clause
    
     ### These are the available tables in the database alongside their descriptions:
     <table_descriptions>
     {candidate_tables}
     </table_descriptions>
    
     When asked a question that could be answered with a SQL query: 
     - Break down the question into smaller parts that can be answered with SQL queries
     - Identify the tables that contain the information needed to answer the question
     - Only once this is done, create a SQL query to retrieve the information based on your understanding of the table
    
     # You must use the exact table names and column names provided in the table descriptions
     # Do not rewrite any column names in a way that is different from the table descriptions
     # You must write the column names and table names exactly as they are in the table descriptions
     # if the table name provided in the table descriptions is of the format schema.table, you must use the schema.table format in your query
     # You must produce a valid SQL query that can be executed, the query should not have any syntax errors
     # When calculating ratios, you must prevent division by zero errors 

     # Call the check_query function to validate your query before returning it
     # You must provide your entire query to the check_query function, do not split it into parts
    
     Think step by step, before doing anything, share the different steps you'll execute to get the answer
    
    ## Your response should be in JSON format with the following structure:
    
    {{"chain_of_thought": "Your reasoning",
    "sql_query": "The generated SQL query, this can only be a single query"}}

    ## You must return a valid JSON
    ## VERY IMPORTANT: You cannot return anything other than the JSON object

    """
     
    user_prompt = f"""<User Question>: {user_question} </User Question>
                <Additional Context>: {evidence} </Additional Context>"""
     
    messages = [
        {
            "role": "system",
            "content": system_prompt
        },
        {
            "role": "user",
            "content": user_prompt
        }
    ]
    
    response = openai_client.chat.completions.create(
        model=deployment,
        messages=messages,
        response_format={"type": "json_object"},
    )
    response_message = response.choices[0].message
    messages.append(response_message)

    # Check the query for common mistakes
    response_json = json.loads(response_message.content)['sql_query']
    query_checks = check_query(response_json, candidate_tables_min, openai_client, deployment)
    
    counter = 0
    while counter < 2:
        if not query_checks.get('query_valid'):
            messages.append({
                "role": "user",
                "content": f"Query validation failed, please fix your query: {query_checks.get('error')}"
            })

            response = openai_client.chat.completions.create(
                model=deployment,
                messages=messages,
                response_format={"type": "json_object"},
            )
            response_message = response.choices[0].message
            response_json = json.loads(response_message.content)['sql_query']
            query_checks = check_query(response_json, candidate_tables_min, openai_client, deployment)
            counter += 1
        else:
            break

    return response




def execute_query_pipeline(user_question, evidence=None):
    rewriter_response = rewrite_query(openai_client, 'gpt-4o-mini', user_question, evidence)
    candidate_tables = search_tables(search_client, rewriter_response)
    reranked_tables = table_re_ranker(openai_client, 'gpt-4o-mini', rewriter_response, candidate_tables)
    reranked_tables_json = json.loads(reranked_tables)
    filtered_tables = filter_candidate_tables(candidate_tables, reranked_tables_json)
    response = generate_sql_query(openai_client, aoai_deployment, user_question, evidence, filtered_tables)
    response_json = json.loads(response.choices[0].message.content, strict=False)
    candidate_table_names = [table['name'] for table in filtered_tables]
    
    return {'sql_query':response_json['sql_query'],
            'chain_of_thought':response_json['chain_of_thought'],
            'tables':candidate_table_names,
            'rewritten_question':rewriter_response,
            'table_context':candidate_tables,
            }
    
    

# Example usage
user_question = "What is the ratio of customers who pay in EUR against customers who pay in CZK?"
hint = "ratio of customers who pay in EUR against customers who pay in CZK = count(Currency = 'EUR') / count(Currency = 'CZK')."
start_time = time.time()
result = execute_query_pipeline(user_question, evidence=hint)
end_time = time.time()

latency = end_time - start_time
print("Result:", result)
print("Latency:", latency, "seconds")


Query checks: {'query_valid': True}
Result: {'sql_query': 'SELECT CASE WHEN czk_count = 0 THEN NULL ELSE eur_count::float / czk_count END AS eur_to_czk_ratio FROM (SELECT COUNT(DISTINCT CASE WHEN "currency" = \'EUR\' THEN "customerid" END) AS eur_count, COUNT(DISTINCT CASE WHEN "currency" = \'CZK\' THEN "customerid" END) AS czk_count FROM public.customers) AS subquery;', 'chain_of_thought': 'To determine the ratio of customers who pay in EUR to those who pay in CZK, I\'ll need to count the number of customers who use each currency. The relevant information can be found in the \'public.customers\' table. Specifically, I\'ll count the number of distinct "customerid" entries where the "currency" is \'EUR\' and \'CZK\'. Then, I\'ll compute the ratio of these counted values. To avoid division by zero, I will include a CASE statement ensuring that division is safe. I\'ll perform this in a single aggregate SQL query using conditional aggregation to compute the counts and the ratio.', 'tables'

In [16]:
sql_client.query(result['sql_query'])

'[{"eur_to_czk_ratio": 0.06572769953051644}]'

In [17]:
import pandas as pd

df = pd.read_json('mini_dev_postgresql.json')
eval = df[df['db_id']=='debit_card_specializing']
eval.head()

Unnamed: 0,question_id,db_id,question,evidence,SQL,difficulty
0,1471,debit_card_specializing,What is the ratio of customers who pay in EUR ...,ratio of customers who pay in EUR against cust...,SELECT CAST(SUM(CASE WHEN Currency = 'EUR' THE...,simple
1,1472,debit_card_specializing,"In 2012, who had the least consumption in LAM?",Year 2012 can be presented as Between 201201 A...,SELECT T1.CustomerID FROM customers AS T1 INNE...,moderate
2,1473,debit_card_specializing,What was the average monthly consumption of cu...,Average Monthly consumption = AVG(Consumption)...,"SELECT AVG(T2.Consumption) / NULLIF(12, 0) FRO...",moderate
3,1476,debit_card_specializing,What was the difference in gas consumption bet...,cast the consumption into float when perform c...,SELECT SUM(CASE WHEN T1.Currency = 'CZK' THEN ...,challenging
4,1479,debit_card_specializing,Which year recorded the most consumption of ga...,The first 4 strings of the Date values in the ...,"SELECT SUBSTR(T2.Date, 1, 4) FROM customers AS...",moderate


In [None]:
from tqdm import tqdm

tqdm.pandas()

eval['llm_result'] = eval.progress_apply(lambda row: execute_query_pipeline(row['question'], row['evidence']), axis=1)
eval['generated_sql_query'] = eval['llm_result'].apply(lambda x: x['sql_query'])
eval['rewritten_question'] = eval['llm_result'].apply(lambda x: x['rewritten_question'])
eval['generated_context_tables'] = eval['llm_result'].apply(lambda x: x['tables'])
eval['generated_chain_of_thought'] = eval['llm_result'].apply(lambda x: x['chain_of_thought'])

  0%|          | 0/30 [00:00<?, ?it/s]

['public.customers', 'public.yearmonth', 'public.transactions_1k']


  7%|▋         | 2/30 [00:06<01:36,  3.43s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT CASE WHEN COUNT(CASE WHEN "currency" = 'CZK' THEN 1 END) = 0 THEN NULL ELSE CAST(COUNT(CASE WHEN "currency" = 'EUR' THEN 1 END) AS FLOAT) / COUNT(CASE WHEN "currency" = 'CZK' THEN 1 END) END AS eur_to_czk_ratio FROM public.customers
['public.customers', 'public.yearmonth', 'public.transactions_1k']


 10%|█         | 3/30 [00:15<02:29,  5.53s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT "customerid", MIN(consumption) AS min_consumption FROM public.yearmonth ym JOIN public.customers c ON ym.customerid = c.customerid WHERE ym.date BETWEEN '201201' AND '201212' AND c.segment = 'LAM' GROUP BY "customerid" ORDER BY min_consumption LIMIT 1;
['public.yearmonth', 'public.customers', 'public.transactions_1k']


 13%|█▎        | 4/30 [00:26<03:20,  7.70s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT AVG("consumption") / 12 AS average_monthly_consumption FROM public.yearmonth INNER JOIN public.customers ON public.yearmonth."customerid" = public.customers."customerid" WHERE "date" BETWEEN '201301' AND '201312' AND "segment" = 'SME';
['public.customers', 'public.yearmonth', 'public.transactions_1k']


 17%|█▋        | 5/30 [00:33<03:00,  7.24s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT SUM(CASE WHEN currency = 'CZK' THEN consumption::float ELSE 0 END) - SUM(CASE WHEN currency = 'EUR' THEN consumption::float ELSE 0 END) AS consumption_difference FROM public.yearmonth INNER JOIN public.customers ON public.yearmonth.customerid = public.customers.customerid WHERE date BETWEEN '201201' AND '201212';
['public.customers', 'public.yearmonth', 'public.transactions_1k']


 20%|██        | 6/30 [00:39<02:44,  6.86s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT LEFT(y."date", 4) AS year, SUM(y."consumption") AS total_consumption FROM public.yearmonth y JOIN public.customers c ON y."customerid" = c."customerid" WHERE c."currency" = 'CZK' GROUP BY LEFT(y."date", 4) ORDER BY total_consumption DESC LIMIT 1
['public.yearmonth', 'public.customers', 'public.transactions_1k']


 23%|██▎       | 7/30 [00:48<02:53,  7.54s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT SUBSTRING(y.date, 5, 2) AS month, SUM(y.consumption) AS total_consumption FROM public.yearmonth y INNER JOIN public.transactions_1k t ON y.date = TO_CHAR(t.date, 'YYYYMM') INNER JOIN public.customers c ON t.customerid = c.customerid WHERE c.segment = 'SME' AND y.date BETWEEN '201301' AND '201312' GROUP BY SUBSTRING(y.date, 5, 2) ORDER BY total_consumption DESC LIMIT 1;
['public.customers', 'public.yearmonth', 'public.transactions_1k']


 27%|██▋       | 8/30 [00:58<03:03,  8.33s/it]

Query checks: {'query_valid': True}
SQL Query: WITH RankedConsumption AS (SELECT "customers"."segment", "yearmonth"."customerid", SUM("yearmonth"."consumption") AS total_consumption, ROW_NUMBER() OVER (PARTITION BY "customers"."segment" ORDER BY SUM("yearmonth"."consumption") ASC) AS rn FROM public.customers INNER JOIN public.yearmonth ON "customers"."customerid" = "yearmonth"."customerid" WHERE "customers"."currency" = 'CZK' AND "yearmonth"."date" BETWEEN '201301' AND '201312' GROUP BY "customers"."segment", "yearmonth"."customerid"), AverageConsumption AS (SELECT "segment", total_consumption AS annual_average FROM RankedConsumption WHERE rn = 1) SELECT (SELECT annual_average FROM AverageConsumption WHERE "segment" = 'SME') - (SELECT annual_average FROM AverageConsumption WHERE "segment" = 'LAM') AS difference_sme_lam, (SELECT annual_average FROM AverageConsumption WHERE "segment" = 'LAM') - (SELECT annual_average FROM AverageConsumption WHERE "segment" = 'KAM') AS difference_lam_kam,

 30%|███       | 9/30 [01:07<02:59,  8.53s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT "segment", (coalesce(consumption_2013, 0) - coalesce(consumption_2012, 0))::real / NULLIF(coalesce(consumption_2013, 0), 0) * 100 AS percentage_increase FROM ( SELECT c."segment", SUM(CASE WHEN SUBSTRING(y."date", 1, 4) = '2012' THEN y."consumption" ELSE 0 END) AS consumption_2012, SUM(CASE WHEN SUBSTRING(y."date", 1, 4) = '2013' THEN y."consumption" ELSE 0 END) AS consumption_2013 FROM public.customers c INNER JOIN public.yearmonth y ON c."customerid" = y."customerid" WHERE c."currency" = 'EUR' AND c."segment" IN ('KAM', 'LAM', 'SME') GROUP BY c."segment" ) AS segment_consumption ORDER BY percentage_increase ASC LIMIT 1;
['public.yearmonth', 'public.transactions_1k', 'public.customers']


 33%|███▎      | 10/30 [01:11<02:27,  7.38s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT SUM("consumption") AS total_consumption FROM public.yearmonth WHERE "customerid" = 6 AND "date" BETWEEN '201308' AND '201311';
['public.gasstations', 'public.transactions_1k', 'public.products']


 37%|███▋      | 11/30 [01:18<02:15,  7.12s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT COALESCE(SUM(CASE WHEN "country" = 'CZE' AND "segment" = 'Discount' THEN 1 ELSE 0 END), 0) - COALESCE(SUM(CASE WHEN "country" = 'SVK' AND "segment" = 'Discount' THEN 1 ELSE 0 END), 0) AS discount_difference FROM public.gasstations;
['public.customers', 'public.transactions_1k', 'public.gasstations']


 40%|████      | 12/30 [01:26<02:10,  7.26s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT COUNT(CASE WHEN "currency" = 'CZK' THEN 1 END) AS count_czk, COUNT(CASE WHEN "currency" = 'EUR' THEN 1 END) AS count_eur, COUNT(CASE WHEN "currency" = 'CZK' THEN 1 END) - COUNT(CASE WHEN "currency" = 'EUR' THEN 1 END) AS difference FROM public.customers WHERE "segment" = 'SME';
['public.customers', 'public.yearmonth', 'public.transactions_1k']


 43%|████▎     | 13/30 [01:31<01:55,  6.79s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT (CAST(COUNT(CASE WHEN y."consumption" > 46.73 THEN 1 END) AS float) / NULLIF(COUNT(c."customerid"), 0)) * 100 AS percentage_consumed_more_than_46_73 FROM public.customers c JOIN public.yearmonth y ON c."customerid" = y."customerid" WHERE c."segment" = 'LAM';
['public.transactions_1k', 'public.yearmonth', 'public.customers']


 47%|████▋     | 14/30 [01:36<01:40,  6.30s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT COALESCE((COUNT(DISTINCT CASE WHEN "consumption" > 528.3 THEN "customerid" END) * 100.0 / NULLIF(COUNT(DISTINCT "customerid"), 0)), 0) AS percentage FROM public.yearmonth WHERE "date" = '201202';
['public.yearmonth', 'public.transactions_1k', 'public.products']


 50%|█████     | 15/30 [01:44<01:39,  6.64s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT MAX(CAST("consumption" AS float)) AS max_consumption FROM public.yearmonth WHERE "date" LIKE '2012%';
['public.yearmonth', 'public.transactions_1k', 'public.products']


 53%|█████▎    | 16/30 [01:50<01:32,  6.61s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT DISTINCT p."description" FROM public.transactions_1k t JOIN public.yearmonth ym ON to_char(t."date", 'YYYYMM') = ym."date" JOIN public.products p ON t."productid" = p."productid" WHERE ym."date" = '201309'
['public.gasstations', 'public.transactions_1k', 'public.yearmonth']


 57%|█████▋    | 17/30 [02:06<02:02,  9.40s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT DISTINCT g."country" FROM public.transactions_1k t JOIN public.gasstations g ON t."gasstationid" = g."gasstationid" JOIN public.yearmonth y ON to_char(t."date", 'YYYYMM') = y."date" WHERE y."date" = '201306';
['public.customers', 'public.yearmonth', 'public.transactions_1k']


 60%|██████    | 18/30 [02:12<01:39,  8.32s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT COUNT(DISTINCT c."customerid") AS number_of_customers FROM public.customers c JOIN public.yearmonth ym ON c."customerid" = ym."customerid" WHERE c."currency" = 'EUR' AND ym."consumption" > 1000;
['public.transactions_1k', 'public.gasstations', 'public.products']


 63%|██████▎   | 19/30 [02:18<01:23,  7.58s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT DISTINCT products."description" FROM public.transactions_1k JOIN public.gasstations ON transactions_1k."gasstationid" = gasstations."gasstationid" JOIN public.products ON transactions_1k."productid" = products."productid" WHERE gasstations."country" = 'CZE';
['public.yearmonth', 'public.transactions_1k', 'public.gasstations']


 67%|██████▋   | 20/30 [02:26<01:16,  7.69s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT DISTINCT t."time" FROM public.transactions_1k t JOIN public.gasstations g ON t.gasstationid = g.gasstationid WHERE g.chainid = 11;
['public.yearmonth', 'public.transactions_1k', 'public.gasstations']


 70%|███████   | 21/30 [02:34<01:11,  7.96s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT COUNT(*) FROM public.transactions_1k WHERE "date" > '2012-01-01' AND "gasstationid" IN (SELECT "gasstationid" FROM public.gasstations WHERE "country" = 'CZE')
['public.customers', 'public.yearmonth', 'public.transactions_1k']


 73%|███████▎  | 22/30 [02:42<01:01,  7.73s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT c."currency" FROM public.transactions_1k t JOIN public.customers c ON t."customerid" = c."customerid" WHERE t."date" = '2012-08-24' AND t."time" = '16:25:00';
['public.customers', 'public.transactions_1k', 'public.gasstations']


 77%|███████▋  | 23/30 [02:49<00:52,  7.53s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT c."segment" FROM public.transactions_1k t JOIN public.customers c ON t."customerid" = c."customerid" WHERE t."date" = '2012-08-23' AND t."time" = '21:20:00'
['public.yearmonth', 'public.transactions_1k', 'public.gasstations']


 80%|████████  | 24/30 [02:54<00:41,  6.94s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT COUNT(*) FROM public.transactions_1k t JOIN public.gasstations g ON t."gasstationid" = g."gasstationid" WHERE t."date" = '2012-08-26' AND t."time" BETWEEN '08:00:00' AND '09:00:00' AND g."country" = 'CZE';
['public.customers', 'public.transactions_1k', 'public.yearmonth']


 83%|████████▎ | 25/30 [03:13<00:52, 10.47s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT c."currency" FROM public.transactions_1k t JOIN public.customers c ON t."customerid" = c."customerid" WHERE t."amount" = 548.4 AND t."date" = '2012-08-24';
['public.customers', 'public.yearmonth', 'public.transactions_1k']


 87%|████████▋ | 26/30 [03:19<00:36,  9.02s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT (CASE WHEN COUNT(DISTINCT t."customerid") = 0 THEN 0 ELSE (COUNT(DISTINCT CASE WHEN c."currency" = 'EUR' THEN t."customerid" END)::float / COUNT(DISTINCT t."customerid")) * 100 END) AS eur_customer_percentage FROM public.transactions_1k t JOIN public.customers c ON t."customerid" = c."customerid" WHERE t."date" = '2012-08-25'
['public.yearmonth', 'public.transactions_1k', 'public.customers']


 90%|█████████ | 27/30 [03:28<00:27,  9.00s/it]

Query checks: {'query_valid': True}
SQL Query: WITH target_customer AS (
  SELECT "customerid"
  FROM public.transactions_1k
  WHERE date = '2012-08-25' AND 634.8 = (SELECT amount) -- Assuming '634.8' is the exact transaction amount
  LIMIT 1
),
consumption_2012 AS (
  SELECT SUM("consumption") AS total_2012
  FROM public.yearmonth
  WHERE "customerid" IN (SELECT "customerid" FROM target_customer) AND "date" LIKE '2012%'
),
consumption_2013 AS (
  SELECT SUM("consumption") AS total_2013
  FROM public.yearmonth
  WHERE "customerid" IN (SELECT "customerid" FROM target_customer) AND "date" LIKE '2013%'
)
SELECT 
  (c2012.total_2012 - c2013.total_2013) / NULLIF(c2012.total_2012, 0) AS consumption_decrease_rate
FROM 
  consumption_2012 c2012
  CROSS JOIN consumption_2013 c2013;
['public.customers', 'public.gasstations', 'public.transactions_1k']


 93%|█████████▎| 28/30 [03:42<00:21, 10.77s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT CASE WHEN total_gas_stations_SVK > 0 THEN (premium_gas_stations / total_gas_stations_SVK::float) * 100 ELSE 0 END AS premium_percentage FROM (SELECT COUNT(*) FILTER (WHERE "segment" = 'Premium') AS premium_gas_stations, COUNT(*) AS total_gas_stations_SVK FROM public.gasstations WHERE "country" = 'SVK') AS gas_station_counts;
['public.yearmonth', 'public.transactions_1k', 'public.customers']


 97%|█████████▋| 29/30 [03:49<00:09,  9.62s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT SUM(t."amount" * t."price") AS total_spent, SUM(CASE WHEN y."date" = '201201' THEN t."amount" * t."price" ELSE 0 END) AS spent_in_january_2012 FROM public.transactions_1k t JOIN public.yearmonth y ON CAST(EXTRACT(YEAR FROM t."date") AS TEXT) || LPAD(CAST(EXTRACT(MONTH FROM t."date") AS TEXT), 2, '0') = y."date" WHERE t."customerid" = 38508;
['public.customers', 'public.yearmonth', 'public.transactions_1k']


100%|██████████| 30/30 [03:58<00:00,  9.21s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT t1."customerid", (SUM(t1."amount" * t1."price")) AS total_spending, (SUM(t1."amount") > 0) AS has_valid_amount, CASE WHEN SUM(t1."amount") > 0 THEN SUM(t1."price") / SUM(t1."amount") ELSE 0 END AS average_price_per_item, c."currency" FROM public.transactions_1k t1 JOIN public.customers c ON t1."customerid" = c."customerid" GROUP BY t1."customerid", c."currency" ORDER BY total_spending DESC LIMIT 1
['public.yearmonth', 'public.transactions_1k', 'public.products']


100%|██████████| 30/30 [04:11<00:00,  8.40s/it]

Query checks: {'query_valid': True}
SQL Query: SELECT ym."customerid", ym."consumption" FROM public.yearmonth ym JOIN (SELECT t."customerid" FROM public.transactions_1k t WHERE t."productid" = 5 AND t."price" > 29) AS high_price_customers ON ym."customerid" = high_price_customers."customerid" WHERE ym."date" = '201208';





In [28]:
eval.head(7)

Unnamed: 0,question_id,db_id,question,evidence,SQL,difficulty,llm_result,generated_sql_query,generated_query_result,ground_truth_result,IsValidSQL,SQLExecutionAccuracy
0,1471,debit_card_specializing,What is the ratio of customers who pay in EUR ...,ratio of customers who pay in EUR against cust...,SELECT CAST(SUM(CASE WHEN Currency = 'EUR' THE...,simple,{'sql_query': 'SELECT COUNT(DISTINCT CASE WHEN...,"SELECT COUNT(DISTINCT CASE WHEN ""currency"" = '...","[{""eur_to_czk_ratio"": 0.06572769953051644}]","[{""?column?"": 0.06572769953051644}]",1,"{\n ""is_correct"": 1,\n ""justification"": ..."
1,1472,debit_card_specializing,"In 2012, who had the least consumption in LAM?",Year 2012 can be presented as Between 201201 A...,SELECT T1.CustomerID FROM customers AS T1 INNE...,moderate,{'sql_query': 'SELECT customerid FROM public.y...,SELECT customerid FROM public.yearmonth WHERE ...,"[{""customerid"": 47273, ""total_consumption"": 0....","[{""customerid"": 47273}]",1,"{\n ""is_correct"": 1,\n ""justification"": ..."
2,1473,debit_card_specializing,What was the average monthly consumption of cu...,Average Monthly consumption = AVG(Consumption)...,"SELECT AVG(T2.Consumption) / NULLIF(12, 0) FRO...",moderate,"{'sql_query': 'SELECT AVG(""consumption"") / 12 ...","SELECT AVG(""consumption"") / 12 AS average_mont...","[{""average_monthly_consumption"": 459.956264211...","[{""?column?"": 459.95626421124325}]",1,"{\n ""is_correct"": 1,\n ""justification"": ..."
3,1476,debit_card_specializing,What was the difference in gas consumption bet...,cast the consumption into float when perform c...,SELECT SUM(CASE WHEN T1.Currency = 'CZK' THEN ...,challenging,{'sql_query': 'SELECT (CAST(SUM(CASE WHEN cust...,SELECT (CAST(SUM(CASE WHEN customers.currency ...,"[{""consumption_difference"": 402524570.0228404}]","[{""?column?"": 402524570.0228404}]",1,"{\n ""is_correct"": 1,\n ""justification"": ..."
4,1479,debit_card_specializing,Which year recorded the most consumption of ga...,The first 4 strings of the Date values in the ...,"SELECT SUBSTR(T2.Date, 1, 4) FROM customers AS...",moderate,"{'sql_query': 'SELECT LEFT(y.""date"", 4) AS yea...","SELECT LEFT(y.""date"", 4) AS year, SUM(y.""consu...","[{""year"": ""2013"", ""total_consumption"": 2992772...","[{""substr"": ""2013""}]",1,"{\n ""is_correct"": 1,\n ""justification"": ..."
5,1480,debit_card_specializing,What was the gas consumption peak month for SM...,Year 2013 can be presented as Between 201301 A...,"SELECT SUBSTR(T2.Date, 5, 2) FROM customers AS...",moderate,{'sql_query': 'SELECT 'The required calculatio...,SELECT 'The required calculation cannot be per...,"[{""date"": ""201304""}]","[{""substr"": ""04""}]",1,"{\n ""is_correct"": 1,\n ""justification"": ..."
6,1481,debit_card_specializing,What is the difference in the annual average c...,annual average consumption of customer with th...,SELECT CAST(SUM(CASE WHEN T1.Segment = 'SME' T...,challenging,{'sql_query': 'WITH Least_Consumption_CZK AS (...,"WITH Least_Consumption_CZK AS ( SELECT c.""segm...","[{""sme_lam_difference"": -12438.070068359375, ""...","[{""?column?"": -582092.875}]",1,"{\n ""is_correct"": 0,\n ""justification"": ..."


In [19]:
# execute the queries against the database and output the results

def run_query(query):
    try:
        result = sql_client.query(query)
        return result
    except Exception as e:
        return str(e)
    
eval['generated_query_result'] = eval['generated_sql_query'].progress_apply(run_query)
eval['ground_truth_result'] = eval['SQL'].progress_apply(run_query)

100%|██████████| 30/30 [00:05<00:00,  5.54it/s]
100%|██████████| 30/30 [00:04<00:00,  6.01it/s]


In [20]:
class isValidSQL:
    def __init__(self, db_client: DatabaseClient):
        self.db_client = db_client

    def __call__(self,*, query,**kwds):
        try:
            result = self.db_client.query(query)
            return 1
        except Exception as e:
            return 0

class SQLExecutionAccuracy:
    def __init__(self, db_client: DatabaseClient, openai_client: AzureOpenAI):
        self.db_client = db_client

    def __call__(self,*, question, query, expected_result, **kwds):
        evaluator_prompt = f"""You are a helpful AI SQL expert.
        Your are given a question, the expected result and the results generated by a SQL query.
        Your role is to determine if the ground truth result and the generated result are the same.
        These are query results so the columns may not be in the same order or have the same names but contain the right information.


        ## Specifically you must:
        - Compare the expected result with the generated result
        - If the generated result is not the same as the expected result but contains the same information or numerical data, you return 1
        - If the generated result is exactly the same as the expected result, you return 1
        - If the generated result is different from the expected result and does not contain the same information, you return 0
        - If two numeric values are rounded differently, you should consider them the same

        - If the generated result is not exactly the same as the expected result but the expected result can be derived from the generated 
        results by applying post processing like (dividing, adding, extracting or multiplying, you should return 1). Pay very close attention to the question and the expected result.

        ## Additional instructions:
        Questions about months might have an answer formatted differently but having the same meaning, e.g 202403, March 2024, 03/2024, 2024-03, etc.,

        Your response should be formatted as a JSON object with the following structure:
        {{
            "is_correct": 1 or 0,
            "justification": "Your reasoning"
        }}        
        ## Examples
        - Question: What is the average amount spent by customer 234?
        - Generated Result: [{{"average_spend": 0.06572769953051644}}]
        - Expecter Result: [{{"?column?": 0.06572769953051644}}]

        Answer: {{"is_correct": 1, "justification": "The generated result is the same as the expected result"}}
        """
        try:
            result = self.db_client.query(query)
            user_prompt = f"""
                    Question: {question}
                    Generated Result: {result}
                    Expected Result: {expected_result}
                    """
            messages = [{"role": "system", "content": evaluator_prompt},
                        {"role": "user", "content": user_prompt},
                        ]
            response = openai_client.chat.completions.create(
                model="gpt-4o-global",
                messages=messages,
                response_format={"type": "json_object"}
            )
            response = response.choices[0].message.content
            return response
        except Exception as e:
            return None
                        

In [21]:

eval['IsValidSQL']=eval.progress_apply(lambda row: isValidSQL(sql_client)(query=row['generated_sql_query']), axis=1)
eval['SQLExecutionAccuracy'] = eval.progress_apply(lambda row: SQLExecutionAccuracy(sql_client, openai_client)(question=row['question'],
                                                                                                                query=row['generated_sql_query'],
                                                                                                                expected_result=row['ground_truth_result']), axis=1)


100%|██████████| 30/30 [00:02<00:00, 12.68it/s]
100%|██████████| 30/30 [00:56<00:00,  1.90s/it]


In [22]:
execution_accuracy = eval['SQLExecutionAccuracy'].apply(lambda x: json.loads(x).get('is_correct') if x else None).mean()
valid_sql = eval['IsValidSQL'].mean()

print("SQL Execution Accuracy:", execution_accuracy)
print("Valid SQL Queries:", valid_sql)

SQL Execution Accuracy: 0.5862068965517241
Valid SQL Queries: 0.9666666666666667


In [79]:
eval.to_csv('mini_dev_postgresql_evaluated.csv', index=False)

In [74]:
## Langchain 

from langchain_community.utilities.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain_openai import AzureChatOpenAI
import psycopg2

# Construct the database URI
pg_db_uri = f"postgresql+psycopg2://{username}:{password}@{server}:5432/{database}"

# Establish database connections
pg_db = SQLDatabase.from_uri(pg_db_uri)

# Initialize the Azure OpenAI language model
azure_llm = AzureChatOpenAI(
  azure_endpoint = aoai_endpoint,
  api_key=aoai_key,
  api_version="2024-10-21",
  deployment_name='gpt-4o-global',
)

context = pg_db.get_context()
print(list(context))
print(context["table_info"])

['table_info', 'table_names']

CREATE TABLE customers (
	customerid BIGINT NOT NULL, 
	segment TEXT, 
	currency TEXT, 
	CONSTRAINT idx_47287_customers_pkey PRIMARY KEY (customerid)
)

/*
3 rows from customers table:
customerid	segment	currency
3	SME	EUR
5	LAM	EUR
6	SME	EUR
*/


CREATE TABLE gasstations (
	gasstationid BIGINT NOT NULL, 
	chainid BIGINT, 
	country TEXT, 
	segment TEXT, 
	CONSTRAINT idx_47292_gasstations_pkey PRIMARY KEY (gasstationid)
)

/*
3 rows from gasstations table:
gasstationid	chainid	country	segment
44	13	CZE	Value for money
45	6	CZE	Premium
46	23	CZE	Other
*/


CREATE TABLE products (
	productid BIGINT NOT NULL, 
	description TEXT, 
	CONSTRAINT idx_47297_products_pkey PRIMARY KEY (productid)
)

/*
3 rows from products table:
productid	description
1	Rucní zadání
2	Nafta
3	Special
*/


CREATE TABLE transactions_1k (
	transactionid BIGSERIAL NOT NULL, 
	date DATE, 
	time TEXT, 
	customerid BIGINT, 
	cardid BIGINT, 
	gasstationid BIGINT, 
	productid BIGINT, 
	amount

In [91]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain.callbacks.base import BaseCallbackHandler


class SQLHandler(BaseCallbackHandler):
    def __init__(self):
        self.sql_result = []

    def on_agent_action(self, action, **kwargs):
        """Run on agent action. if the tool being used is sql_db_query,
         it means we're submitting the sql and we can 
         record it as the final sql"""

        if action.tool in ["sql_db_query_checker","sql_db_query"]:
            self.sql_result.append(action.tool_input)

agent_executor = create_sql_agent(azure_llm, db=pg_db, agent_type="openai-tools", verbose=False)

def execute_sql_agent(user_query):
    handler = SQLHandler()
    try:
        response = agent_executor.invoke({'input': user_query}, {'callbacks': [handler]})
        sql_queries = handler.sql_result
        return sql_queries
    except Exception as e:
        return None

eval['langchain_result'] = eval['question'].progress_apply(execute_sql_agent)


  0%|          | 0/30 [00:00<?, ?it/s]

100%|██████████| 30/30 [03:17<00:00,  6.58s/it]


In [100]:
eval['langchain_query'] = eval['langchain_result'].apply(lambda x: x[-1].get('query') if x else None)
eval['langchain_execution'] = eval.progress_apply(lambda row: SQLExecutionAccuracy(sql_client, openai_client)(question=row['question'], 
                                                                                                            query=row['langchain_query'],
                                                                                                            expected_result=row['ground_truth_result']), axis=1)
execution_accuracy = eval['langchain_execution'].apply(lambda x: json.loads(x).get('is_correct') if x else None).mean()
print("Langchain SQL Execution Accuracy:", execution_accuracy)

100%|██████████| 30/30 [00:32<00:00,  1.08s/it]

Langchain SQL Execution Accuracy: 0.27586206896551724





In [96]:
eval['langchain_response'] = eval['langchain_query'].progress_apply(run_query)

100%|██████████| 30/30 [00:00<00:00, 32.85it/s]


In [98]:
eval['langchain_response']

0             [{"eur_count": 2002, "czk_count": 30459}]
1      [{"customerid": 6056, "total_consumption": 0.0}]
2                                                    []
3     [{"currency": "CZK", "total_gas_consumption": ...
4     [{"year": 2012.0, "total_consumption_czk": 175...
5     [{"date": "201304", "peak_consumption": 786530...
6                                                    []
7     [{"segment": "KAM", "year": 2012.0, "total_amo...
8                         [{"total_consumption": null}]
9     [{"country": "SVK", "discount_gas_stations": 5...
10            [{"czk_count": 25134, "eur_count": 1629}]
11    [{"percent_consumed_more_than_46_73": 81.33333...
12      Query must be a string unless using sqlalchemy.
13                  [{"highest_consumption": 445279.7}]
14                                                   []
15                                                   []
16    [{"customer_count": 1}, {"customer_count": 1},...
17    [{"description": "Dalnic.popl."}, {"descri

[{'name': 'public.products',
  'description': 'The products table contains information regarding various products offered within the database. Each product is uniquely identified by a Product ID, which serves as the primary key. The table also includes descriptions of these products, providing specific information that helps in identifying and differentiating them. This table is essential for linking product-related information across different tables in the database.'},
 {'name': 'public.customers',
  'description': 'The customers table contains information about individual customers, including a unique identifier, their client segment classification, and the currency they use for transactions. This table helps in distinguishing customers and categorizes them into different groups such as Key Account Management (KAM), Large Account Management (LAM), and Small and Medium Enterprises (SME). The currency column indicates the currency codes like CZK and EUR used by customers in their tran