In [1]:
import os
import re
import glob
import sqlite3
import pandas as pd
from typing import Dict, List, Tuple, Any, Optional, Union, Callable
from collections import defaultdict
from sklearn.model_selection import train_test_split
import sqlparse
import json
from datetime import datetime
import logging
from tqdm import tqdm

import sys

# Project root. Defaults to the original author's checkout; set the
# MA_TEXT2SQL_ROOT environment variable to run the notebook from another
# machine without editing this cell.
# os.path.join(..., '') guarantees the trailing separator that the rest of the
# notebook relies on when it builds paths with `ROOT_PATH + '...'`.
ROOT_PATH = os.path.join(
    os.environ.get('MA_TEXT2SQL_ROOT')
    or '/Users/sinabehnam/Desktop/Projects/Polito/Thesis/MA_text2SQL/',
    '',
)

# Make the project's pipeline sources importable. The membership check keeps
# sys.path free of duplicate entries when this cell is re-run.
PIPELINE_SRC_PATH = ROOT_PATH + 'DataSampling/src/models/pipeline'
if PIPELINE_SRC_PATH not in sys.path:
    sys.path.append(PIPELINE_SRC_PATH)

In [2]:
def read_api_key(file_path: str) -> str:
    """
    Load an API key stored in a plain-text file.

    Args:
        file_path: Location of the file holding the key.

    Returns:
        The file content with leading and trailing whitespace removed.

    Raises:
        FileNotFoundError: If nothing exists at ``file_path``.
        ValueError: If the file is empty or contains only whitespace.
    """
    if os.path.exists(file_path):
        with open(file_path, 'r') as key_file:
            stripped_key = key_file.read().strip()

        if stripped_key:
            return stripped_key
        raise ValueError("API key is empty.")

    raise FileNotFoundError(f"API key file not found: {file_path}")

In [3]:
# Directories with the enriched (v2) instance files, one per benchmark.
_ENRICHED_V2_DIR = ROOT_PATH + 'DataSampling/data/enriched_dataset/v2/'
BIRD_DATA_PATH = _ENRICHED_V2_DIR + 'bird_set_stratified'
SPIDER_DATA_PATH = _ENRICHED_V2_DIR + 'spider_set_stratified'
SPIDER2_DATA_PATH = _ENRICHED_V2_DIR + 'spider2_lite_set'

# Key files are kept outside the notebook and read at cell execution time.
_AUTH_DIR = ROOT_PATH + 'Data/Auth/'
_ANTHROPIC_KEY_FILE = _AUTH_DIR + 'anthropic.api.key/text2sql.key'
_TOGETHER_KEY_FILE = _AUTH_DIR + 'together.ai.api.key/API.key'

# One entry per model; later cells pick an entry by its position in this list.
model_configs = [
    # [0] Anthropic Claude with extended thinking
    {
        "type": "anthropic",
        "name": "claude-3-7-sonnet-20250219",
        "api_key": read_api_key(_ANTHROPIC_KEY_FILE),
        "extended_thinking": True,
    },
    # [1] Together.ai model (Llama 3.3 70B)
    {
        "type": "together_ai",
        "name": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
        "api_key": read_api_key(_TOGETHER_KEY_FILE),
    },
    # [2] Together.ai model DeepSeek R1
    {
        "type": "together_ai",
        "name": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
        "api_key": read_api_key(_TOGETHER_KEY_FILE),
    },
]

# Loading Data

1. **Load each dataset**
2. **Form a uniform dataset**
3. **Split into train and test sets**
4. **Group the data instances by their schemas**: We group the instances by schema so that all questions about the same schema are placed in the same prompt, which avoids repeating the schema and reduces the token count.

In [4]:
def load_data(bird_path: str = BIRD_DATA_PATH, 
             spider_path: str = SPIDER_DATA_PATH,
             spider2_path: str = SPIDER2_DATA_PATH) -> List[Tuple[Dict, str]]:
    """Load and concatenate the BIRD, Spider and Spider2 instance files.

    Args:
        bird_path: Directory containing the BIRD ``instance_*.json`` files.
        spider_path: Directory containing the Spider ``instance_*.json`` files.
        spider2_path: Directory containing the Spider2 ``instance_*.json`` files.

    Returns:
        A single list of ``(instance_dict, file_path)`` pairs: all BIRD
        instances first, then Spider, then Spider2.

    Side effects:
        Prints the total number of loaded instances and the count per dataset.
    """
    bird_data = _load_json_files(bird_path)
    spider_data = _load_json_files(spider_path)
    spider2_data = _load_json_files(spider2_path)

    all_data = bird_data + spider_data + spider2_data

    print(f"Total data points: {len(all_data)}")
    print(f"Bird data points: {len(bird_data)}")
    print(f"Spider data points: {len(spider_data)}")
    print(f"Spider2 data points: {len(spider2_data)}")

    return all_data

def _load_json_files(dir_path: str) -> List[Tuple[Dict, str]]:
    """Load all JSON files from a directory"""
    data = []
    for filepath in glob.glob(os.path.join(dir_path, 'instance_*.json')):
        with open(filepath, 'r') as file:
            json_data = json.load(file)
            data.append((json_data, filepath))
    return data

def _group_instances_by_schema(train_data) -> Dict[str, List[Tuple[Dict, str]]]:
        """Group instances by their database schema"""
        schema_groups = defaultdict(list)
        
        for instance_data, file_path in train_data:
            # Create a unique key for the database schema
            db_name = instance_data['database']['name']
            dataset_type = instance_data['dataset']
            database_type = instance_data['database'].get('type', 'unknown')
            schema_key = f"{dataset_type}_{database_type}_{db_name}"
            
            schema_groups[schema_key].append((instance_data, file_path))
        
        return dict(schema_groups)

In [5]:
# Snowflake credentials are stored as JSON next to the Spider2-lite evaluation
# suite; parse them once so they can be handed to the pipeline constructor.
SNOWFLAKE_CREDENTIALS_PATH = ROOT_PATH + 'Data/Spider2/spider2-lite/evaluation_suite/snowflake_credential.json'
with open(SNOWFLAKE_CREDENTIALS_PATH, 'r') as credentials_file:
    snowflake_credentials = json.loads(credentials_file.read())

# Inferencing

In [None]:
# Load every instance first, then bucket the (instance, path) pairs by schema.
# `train_data` ends up as the schema-keyed dict that the later cells index into.
all_instances = load_data()
train_data = _group_instances_by_schema(all_instances)

# One row per schema group, for a quick tabular overview.
df = pd.DataFrame(list(train_data.items()), columns=['schema_key', 'instances'])

df.head()

Total data points: 604
Bird data points: 250
Spider data points: 250
Spider2 data points: 104
Training data points: 483
Testing data points: 121


Unnamed: 0,schema_key,instances
0,bird_sqlite_student_club,"[({'id': 1318, 'dataset': 'bird', 'database': ..."
1,spider_sqlite_cre_Doc_Template_Mgt,"[({'id': 369, 'dataset': 'spider', 'database':..."
2,spider2-lite_snowflake_IDC,"[({'id': 271, 'original_instance_id': 'sf_bq34..."
3,bird_sqlite_formula_1,"[({'id': 864, 'dataset': 'bird', 'database': {..."
4,spider_sqlite_world_1,"[({'id': 772, 'dataset': 'spider', 'database':..."


In [6]:
# Summary statistics of the group sizes (number of instances per schema key).
print("====== The number of instances in each schema group ======")
instances_per_schema = df.apply(lambda row: len(row['instances']), axis=1)
description = instances_per_schema.describe()

database_count = description['count']
mean_group_size = description['mean']
print("The number of total Databases accross the training sets : ", database_count)
print("The number of average instances per database accross the training sets : ", mean_group_size)

The number of total Databases accross the training sets :  75.0
The number of average instances per database accross the training sets :  6.44


In [8]:
# Use the first (instance, file path) pair of one schema group as the running example.
sample_schema_key = 'spider2-lite_sqlite_EntertainmentAgency'
sample_instance, sample_instance_path = train_data[sample_schema_key][0]

In [9]:
from pipeline.text2sql_enricher import OptimizedText2SQLPipeline

# The first entry of `model_configs` drives this run.
selected_model_config = model_configs[0]
print("This Model is :", selected_model_config['name'])

pipeline = OptimizedText2SQLPipeline(
    model_config=selected_model_config,
    snowflake_config=snowflake_credentials,
)

# Build (without sending) the schema introduction prompt for the sample instance
# so it can be inspected below.
sample_schema_intro_prompt = pipeline._create_schema_introduction_prompt(sample_instance)

print("Sample Schema Introduction Prompt:", sample_schema_intro_prompt, sep="\n")

  warn_incompatible_dep(
INFO:pipeline.text2sql_enricher:Initializing OptimizedText2SQLPipeline...
INFO:pipeline.text2sql_enricher:Schema understanding logging enabled. Logs will be saved to: schema_understanding_logs/schema_understanding_20250606_230308.log


This Model is : claude-3-7-sonnet-20250219
Sample Schema Introduction Prompt:
You are now working with the "EntertainmentAgency" database. 

Here's the complete database schema:

## Table: Agents
```sql
CREATE TABLE Agents (
    AgentID INT,
    AgtFirstName nvarchar (25),
    AgtLastName nvarchar (25),
    AgtStreetAddress nvarchar (50),
    AgtCity nvarchar (30),
    AgtState nvarchar (2),
    AgtZipCode nvarchar (10),
    AgtPhoneNumber nvarchar (15),
    DateHired date,
    Salary decimal(15, 2),
    CommissionRate float(24)
);
```

## Table: Customers
```sql
CREATE TABLE Customers (
    CustomerID INT,
    CustFirstName nvarchar (25),
    CustLastName nvarchar (25),
    CustStreetAddress nvarchar (50),
    CustCity nvarchar (30),
    CustState nvarchar (2),
    CustZipCode nvarchar (10),
    CustPhoneNumber nvarchar (15)
);
```

## Table: Engagements
```sql
CREATE TABLE Engagements (
    EngagementNumber INT,
    StartDate date,
    EndDate date,
    StartTime time,
    StopTime t

In [10]:
# Introduce the sample instance's schema to the model through the project's
# SequentialSchemaHandler and keep the final reply for inspection.
from pipeline.schemahandler import SequentialSchemaHandler

# NOTE(review): presumably clears the provider's conversation history so the
# schema introduction starts from a clean context; confirm in the provider.
pipeline.model_provider.start_new_conversation()

# NOTE(review): the two token limits below are tuning values whose exact
# semantics are defined by SequentialSchemaHandler; confirm before changing.
schema_handler = SequentialSchemaHandler(pipeline.model_provider,
                                         max_tokens_per_chunk=4000,
                                         token_threshold=6000,
                                         nlp_model='en_core_web_sm')

# System message handed to the handler together with the sample instance.
system_message = (
                "You are a database expert specializing in SQL query generation. "
                "You will be working with a specific database schema and answering "
                "multiple questions about it. Please pay careful attention to the "
                "schema structure and relationships between tables."
            )

# Only the second element of the returned pair is kept; the first is discarded.
_,final_response = schema_handler.handle_large_schema(sample_instance, system_message=system_message)


In [11]:
# Show the reply kept from the schema handler call, under a heading line.
print("Final Response from Schema Handler:", final_response, sep="\n")

Final Response from Schema Handler:



In [12]:
# Collect the table names declared in the sample instance's 'schemas' list.
table_names = []
for schema_entry in sample_instance['schemas']:
    table_names.append(schema_entry['table_name'])

print("Tables in the sample instance:")
for table_name in table_names:
    print("- {}".format(table_name))

Tables in the sample instance:
- Agents
- Customers
- Engagements
- Entertainer_Members
- Entertainer_Styles
- Entertainers
- Members
- Musical_Preferences
- Musical_Styles
- ztblDays
- ztblMonths
- ztblSkipLabels
- ztblWeeks


> The results above compare the list of tables in the model's understanding of the whole schema (built with the `token_windowing` strategy for large database schemas) against the tables of the sample instance. Notably, some tables are either not included in the schema the model describes or not used in the model's understanding. This is because the model missed some tables: the number of tokens in each table's schema is larger than the model's token limit.


In [13]:
# Build the question prompt for the sample instance, send it to the model and
# extract the SQL statement from the raw reply.
question_prompt = pipeline._create_question_prompt(sample_instance)

# Echo the prompt between two separator lines of 50 '=' characters.
print(50 * "=")
print("The question prompt was:")
print(question_prompt)
print(50 * "=")
# NOTE(review): the first argument is an empty string; presumably the schema
# context already lives in the provider's conversation. Confirm.
raw_response = pipeline.model_provider.generate_with_context("", question_prompt)
generated_sql = pipeline.extract_sql_query_from_text(raw_response)
print("Generated SQL Query:")
print(generated_sql)

The question prompt was:
Question: Could you list each musical style with the number of times it appears as a 1st, 2nd, or 3rd preference in a single row per style?

Please generate the SQL query to answer this question using the database schema we discussed.


INFO:httpx:HTTP Request: POST https://api.anthropic.com/v1/messages "HTTP/1.1 200 OK"


Generated SQL Query:
SELECT 
    style_name,
    COUNT(CASE WHEN preference_rank = 1 THEN 1 END) AS first_preference,
    COUNT(CASE WHEN preference_rank = 2 THEN 1 END) AS second_preference,
    COUNT(CASE WHEN preference_rank = 3 THEN 1 END) AS third_preference
FROM musical_preferences
GROUP BY style_name
ORDER BY style_name;


In [14]:
# Run the instance's reference query (its 'sql' field) against its database.
# `conn` and `cursor` stay open on purpose: the following cell closes them.
conn, _ = pipeline.get_db_connection(sample_instance, sample_instance_path)
cursor = conn.cursor()

try:
    reference_sql = sample_instance['sql']
    cursor.execute(reference_sql)
    results = cursor.fetchall()
    print("Query executed successfully. Results:")
    # Rows are shown as a set, i.e. de-duplicated and in no particular order.
    print(set(results))
except Exception as exc:
    print(f"Error executing query: {exc}")

Query executed successfully. Results:
{('Rhythm and Blues', 2, 0, 1), ('Folk', 0, 1, 0), ('Country Rock', 1, 0, 0), ('Top 40 Hits', 2, 0, 0), ('Modern Rock', 0, 1, 1), ('Variety', 1, 0, 0), ('Classic Rock & Roll', 1, 1, 0), ('Standards', 2, 2, 0), ('Motown', 0, 1, 0), ("70's Music", 0, 1, 0), ('Chamber Music', 1, 0, 0), ('Jazz', 2, 1, 0), ('Show Tunes', 1, 1, 0), ('Contemporary', 1, 2, 0), ("60's Music", 1, 0, 0), ("80's Music", 0, 0, 1), ("40's Ballroom Music", 0, 1, 1), ('Country', 0, 1, 0), ('Salsa', 0, 1, 1), ('Classical', 0, 1, 1)}


In [15]:
# Release the cursor and the connection opened in the previous cell
# (cursor first, then the connection it belongs to).
cursor.close()
conn.close()

In [19]:
# Run the reference query of one BIRD instance against its database.
bird_sqlite_sample, path = train_data['bird_sqlite_student_club'][0]

conn, _ = pipeline.get_db_connection(bird_sqlite_sample, path)

# The cursor and the connection are always released, even when the query fails,
# so re-running this cell does not accumulate open database handles.
try:
    cursor = conn.cursor()
    try:
        cursor.execute(bird_sqlite_sample['sql'])
        results = cursor.fetchall()
        print("Query executed successfully. Results:")
        # Rows are shown as a set, i.e. de-duplicated and in no particular order.
        print(set(results))
    except Exception as e:
        # Best-effort demo cell: report the failure instead of aborting the run.
        print(f"Error executing query: {e}")
    finally:
        cursor.close()
finally:
    conn.close()

Query executed successfully. Results:
{('Yearly Kickoff',)}


In [None]:
from pipeline.text2sql_enricher import OptimizedText2SQLPipeline

# Full run over every schema group with the first configured model.
selected_model_config = model_configs[0]
print("This Model is :", selected_model_config['name'])

pipeline = OptimizedText2SQLPipeline(
    model_config=selected_model_config,
    snowflake_config=snowflake_credentials,
)

output_dir = ROOT_PATH + 'DataSampling/data/enriched_dataset/enriched_v4'

results = pipeline.run_pipeline(
    schema_groups=train_data,
    save_updated_files=True,
    output_dir=output_dir,
)

# Keep only the headline metrics. The first five keys are required in `results`;
# the semantic-equivalence score and the optimization label have fallbacks.
summary_metrics = {
    **{
        metric: results[metric]
        for metric in (
            'num_evaluated',
            'num_with_prediction',
            'prediction_rate',
            'execution_accuracy',
            'exact_match_accuracy',
        )
    },
    'semantic_equivalent_accuracy': results.get('semantic_equivalent_accuracy', 0.0),
    'model': results['model'],
    'optimization': results.get('optimization_used', 'conversational_schema_context'),
}

In [17]:
# Display the summary metrics collected from the pipeline run above.
summary_metrics

{'num_evaluated': 483,
 'num_with_prediction': 373,
 'prediction_rate': 0.772256728778468,
 'execution_accuracy': 0.46648793565683644,
 'exact_match_accuracy': 0.013404825737265416,
 'semantic_equivalent_accuracy': 0.49865951742627346,
 'model': {'model_name': 'claude-3-7-sonnet-20250219',
  'model_type': 'anthropic',
  'timestamp': '2025-06-06T23:04:59.084928',
  'optimization': 'conversational_schema_context'},
 'optimization': 'conversational_schema_context'}

# Process the missing data

In [None]:
# Directory whose generated instance files are post-processed below.
# NOTE(review): this points at 'enriched_v3', while the pipeline run above
# wrote to 'enriched_v4'. Confirm which version is intended here.
output_dir = ROOT_PATH + 'DataSampling/data/enriched_dataset/enriched_v3'

In [None]:
# Glob patterns selecting each benchmark's generated files inside `output_dir`.
bird_instances = "/".join((output_dir, 'instance_bird_*.json'))
spider_instances = "/".join((output_dir, 'instance_spider_*.json'))
spider2_instances = "/".join((output_dir, 'instance_spider2-lite_*.json'))

# ! 
# Destination folders (one per benchmark) under the v2_Llama3 dataset copy.
v2_llama3 = ROOT_PATH + 'DataSampling/data/enriched_dataset/v2_Llama3'

bird_set_dir, spider_set_dir, spider2_set_dir = (
    v2_llama3 + subdir
    for subdir in ('/bird_set_stratified', '/spider_set_stratified', '/spider2_lite_set')
)

# copy the bird_instances to the v2_llama3
#/bird_set_dir
import shutil

def copy_files(src_pattern: str, dest_dir: str):
    """Copy every file matching ``src_pattern`` into the directory ``dest_dir``.

    Args:
        src_pattern: Glob pattern selecting the source files.
        dest_dir: Destination directory; created (with parents) when missing.

    Raises:
        FileExistsError: If ``dest_dir`` exists but is not a directory.
    """
    # Without an existing directory, shutil.copy would treat `dest_dir` as a
    # file name: every source would overwrite the same file, or the copy would
    # fail outright when the parent folder is missing as well.
    os.makedirs(dest_dir, exist_ok=True)
    # Sorted so files are copied in a reproducible order.
    for src_file in sorted(glob.glob(src_pattern)):
        shutil.copy(src_file, dest_dir)

# copy_files(bird_instances, bird_set_dir)
# copy_files(spider_instances, spider_set_dir)
# copy_files(spider2_instances, spider2_set_dir)

In [6]:
# Reload the instance files from the v2_Llama3 dataset copy, one folder per benchmark.
v2_llama3 = ROOT_PATH + 'DataSampling/data/enriched_dataset/v2_Llama3'

bird_set_dir, spider_set_dir, spider2_set_dir = (
    v2_llama3 + subdir
    for subdir in ('/bird_set_stratified', '/spider_set_stratified', '/spider2_lite_set')
)

output_generated_data = load_data(
    bird_path=bird_set_dir,
    spider_path=spider_set_dir,
    spider2_path=spider2_set_dir,
)

Total data points: 604
Bird data points: 250
Spider data points: 250
Spider2 data points: 104


In [7]:
def _lacks_prediction(instance: Dict) -> bool:
    """Return True when ``instance`` carries no usable prediction.

    An instance counts as missing when it has no ``'inference_results'`` entry
    (or a null one), or when that entry's ``'has_prediction'`` flag is absent
    or falsy. Absent keys are treated as "missing" instead of raising, so a
    partially written result file cannot abort the scan.
    """
    inference_results = instance.get('inference_results') or {}
    return not inference_results.get('has_prediction', False)


# Keep the (instance, file path) pairs that still need a generated SQL query.
missing_sql_data = [item for item in output_generated_data if _lacks_prediction(item[0])]

print(f"Total instances with missing SQL queries: {len(missing_sql_data)}")

# Group them by schema, exactly like the full data set, so they can be re-run.
missing_grouped = _group_instances_by_schema(missing_sql_data)

missing_df = pd.DataFrame(missing_grouped.items(), columns=['schema_key', 'instances'])

missing_df.head()

Total instances with missing SQL queries: 17


Unnamed: 0,schema_key,instances
0,spider2-lite_snowflake_GEO_OPENSTREETMAP_WORLDPOP,"[({'id': 101, 'original_instance_id': 'sf_bq25..."
1,spider2-lite_snowflake_TCGA_MITELMAN,"[({'id': 306, 'original_instance_id': 'sf_bq16..."
2,spider2-lite_snowflake_TCGA,"[({'id': 283, 'original_instance_id': 'sf_bq04..."
3,spider2-lite_sqlite_modern_data,"[({'id': 442, 'original_instance_id': 'local06..."
4,spider2-lite_snowflake_HUMAN_GENOME_VARIANTS,"[({'id': 122, 'original_instance_id': 'sf_bq03..."


In [8]:
# Summary statistics for the instances that still lack a prediction. The labels
# name that subset explicitly: the copied "training sets" wording described a
# different data set and was misleading here.
print("====== The number of instances in each schema group ======")
description = missing_df.apply(lambda x: len(x['instances']), axis=1).describe()
print("The number of total Databases with missing predictions : ",description['count'])
print("The number of average missing instances per database : ",description['mean'])

The number of total Databases accross the training sets :  13.0
The number of average instances per database accross the training sets :  1.3076923076923077


In [9]:
from pipeline.text2sql_enricher import OptimizedText2SQLPipeline

# Re-run only the schema groups with missing predictions, using the second
# configured model.
selected_model_config = model_configs[1]
print("This Model is :", selected_model_config['name'])

pipeline = OptimizedText2SQLPipeline(
    model_config=selected_model_config,
    snowflake_config=snowflake_credentials,
)

results = pipeline.run_pipeline(schema_groups=missing_grouped, save_updated_files=True)

# Keep only the headline metrics. The first five keys are required in `results`;
# the semantic-equivalence score and the optimization label have fallbacks.
summary_metrics = {
    **{
        metric: results[metric]
        for metric in (
            'num_evaluated',
            'num_with_prediction',
            'prediction_rate',
            'execution_accuracy',
            'exact_match_accuracy',
        )
    },
    'semantic_equivalent_accuracy': results.get('semantic_equivalent_accuracy', 0.0),
    'model': results['model'],
    'optimization': results.get('optimization_used', 'conversational_schema_context'),
}

  warn_incompatible_dep(
  from .autonotebook import tqdm as notebook_tqdm
INFO:pipeline.text2sql_enricher:Initializing OptimizedText2SQLPipeline...
INFO:pipeline.text2sql_enricher:Schema understanding logging enabled. Logs will be saved to: schema_understanding_logs/schema_understanding_20250608_004322.log


This Model is : meta-llama/Llama-3.3-70B-Instruct-Turbo-Free


INFO:pipeline.text2sql_enricher:
=== Processing schema group: spider2-lite_snowflake_GEO_OPENSTREETMAP_WORLDPOP ===
INFO:pipeline.text2sql_enricher:Number of instances: 1
INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 200 OK"
INFO:pipeline.text2sql_enricher:Schema introduction completed successfully
Processing spider2-lite_snowflake_GEO_OPENSTREETMAP_WORLDPOP:   0%|          | 0/1 [00:00<?, ?instance/s]INFO:pipeline.text2sql_enricher:Processing instance 101...
INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 200 OK"
INFO:pipeline.text2sql_enricher:Failed to extract SQL from model response
INFO:pipeline.text2sql_enricher:--------------------------------------------------
Processing spider2-lite_snowflake_GEO_OPENSTREETMAP_WORLDPOP: 100%|██████████| 1/1 [00:07<00:00,  7.06s/instance]
INFO:pipeline.text2sql_enricher:
=== Processing schema group: spider2-lite_snowflake_TCGA_MITELMAN ===
INFO:pipeline.text2sql_enriche

Large schema detected. Sending in 9 chunks...


INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 200 OK"


Sent schema chunk 1/9


INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 422 Unprocessable Entity"


OpenAI format conversation error: Error code: 422 - {'id': 'nxDrRtB-4Yz4kd-94c3aba5af47ee6d', 'error': {'message': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8193. Given: 7878 `inputs` tokens and 1024 `max_new_tokens`', 'type': 'invalid_request_error', 'param': None, 'code': None}}


INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 429 Too Many Requests"
INFO:openai._base_client:Retrying request to /chat/completions in 6.000000 seconds
INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 422 Unprocessable Entity"
ERROR:pipeline.text2sql_enricher:Error processing schema group spider2-lite_snowflake_TCGA_MITELMAN: Error code: 422 - {'id': 'nxDrTu5-4Yz4kd-94c3abcfa878ed94', 'error': {'message': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8193. Given: 7612 `inputs` tokens and 1024 `max_new_tokens`', 'type': 'invalid_request_error', 'param': None, 'code': None}}
INFO:pipeline.text2sql_enricher:
=== Processing schema group: spider2-lite_snowflake_TCGA ===
INFO:pipeline.text2sql_enricher:Number of instances: 1


Large schema detected. Sending in 8 chunks...


INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 429 Too Many Requests"
INFO:openai._base_client:Retrying request to /chat/completions in 3.000000 seconds
INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 429 Too Many Requests"
INFO:openai._base_client:Retrying request to /chat/completions in 0.828335 seconds
INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 422 Unprocessable Entity"


OpenAI format conversation error: Error code: 422 - {'id': 'nxDrVwF-3NKUce-94c3abfbbe48ed94', 'error': {'message': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8193. Given: 7633 `inputs` tokens and 1024 `max_new_tokens`', 'type': 'invalid_request_error', 'param': None, 'code': None}}


INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 422 Unprocessable Entity"
ERROR:pipeline.text2sql_enricher:Error processing schema group spider2-lite_snowflake_TCGA: Error code: 422 - {'id': 'nxDrW6q-4Yz4kd-94c3abfe5960ed94', 'error': {'message': 'Input validation error: `inputs` tokens + `max_new_tokens` must be <= 8193. Given: 7631 `inputs` tokens and 1024 `max_new_tokens`', 'type': 'invalid_request_error', 'param': None, 'code': None}}
INFO:pipeline.text2sql_enricher:
=== Processing schema group: spider2-lite_sqlite_modern_data ===
INFO:pipeline.text2sql_enricher:Number of instances: 2
INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 429 Too Many Requests"
INFO:openai._base_client:Retrying request to /chat/completions in 9.000000 seconds
INFO:httpx:HTTP Request: POST https://api.together.xyz/v1/chat/completions "HTTP/1.1 200 OK"
INFO:pipeline.text2sql_enricher:Schema introduction completed successfully
Processi

In [10]:
# Display the summary metrics of the re-run on the missing-prediction groups.
summary_metrics

{'num_evaluated': 13,
 'num_with_prediction': 12,
 'prediction_rate': 0.9230769230769231,
 'execution_accuracy': 0.08333333333333333,
 'exact_match_accuracy': 0.0,
 'semantic_equivalent_accuracy': 0.08333333333333333,
 'model': {'model_name': 'meta-llama/Llama-3.3-70B-Instruct-Turbo-Free',
  'model_type': 'together_ai',
  'timestamp': '2025-06-08T00:43:22.382162',
  'optimization': 'conversational_schema_context'},
 'optimization': 'conversational_schema_context'}