In [1]:
# --- Core Libraries ---
import os
import random
import json
import pandas as pd
import numpy as np
import torch
import time
from tqdm import tqdm
import re

# --- Hugging Face: Dataset, Tokenizer, Model ---
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer, 
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    pipeline
)

# --- LoRA & Parameter-Efficient Tuning ---
from peft import LoraConfig, get_peft_model, TaskType, PeftModel

# --- W&B Experiment Tracking ---
import wandb

# --- SQL Evaluation ---
import sqlite3
import sqlparse
from tabulate import tabulate
import evaluate  # for BLEU, ROUGE
import nltk
nltk.download('wordnet')
nltk.download('punkt')

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\sidpk\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\sidpk\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
torch.cuda.empty_cache()

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

if torch.cuda.is_available():
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("GPU not detected — will fall back to CPU.")

PyTorch version: 2.5.1+cu121
CUDA available: True
Using GPU: NVIDIA GeForce RTX 4050 Laptop GPU


In [3]:
# Load dataset
dataset = load_dataset("Clinton/Text-to-SQL-v1")
shuffled_dataset = dataset.shuffle(seed=42)

df = pd.DataFrame(shuffled_dataset["train"])
df.sample(5)

Unnamed: 0,instruction,input,response,source,text
121368,How many players did Boston Red Stockings have...,"CREATE TABLE college (\n college_id text,\n...",SELECT COUNT(*) FROM salary AS T1 JOIN team AS...,spider,Below are sql tables schemas paired with instr...
41538,"What is Label, when Format is Double CD?",CREATE TABLE table_name_33 (\n label VARCHA...,SELECT label FROM table_name_33 WHERE format =...,sql_create_context,Below are sql tables schemas paired with instr...
42828,What's the number of poles in the season where...,CREATE TABLE table_20016339_1 (\n poles VAR...,SELECT poles FROM table_20016339_1 WHERE motor...,sql_create_context,Below are sql tables schemas paired with instr...
4677,How many jockeys are listed running at the Sir...,"CREATE TABLE table_19755 (\n ""Result"" text,...","SELECT COUNT(""Jockey"") FROM table_19755 WHERE ...",wikisql,Below are sql tables schemas paired with instr...
261936,"Show me a bar chart, that simply displays the ...",CREATE TABLE locations (\n LOCATION_ID deci...,"SELECT LAST_NAME, LOCATION_ID FROM employees A...",nvbench,Below are sql tables schemas paired with instr...


In [4]:
df_clean = df[df["instruction"] != ""].reset_index(drop=True)
print(f"Filtered dataset size: {len(df_clean)}")

Filtered dataset size: 262206


In [5]:
formatted_dataset = Dataset.from_pandas(df_clean[["text"]])
formatted_dataset = formatted_dataset.train_test_split(test_size=0.1, seed=42)

print(formatted_dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 235985
    })
    test: Dataset({
        features: ['text'],
        num_rows: 26221
    })
})


In [6]:
# Load Tokenizer

model_name = "deepseek-ai/deepseek-coder-1.3b-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

In [7]:
#Smart Padding
def tokenize(examples):
    input_ids_list = []
    attention_mask_list = []
    labels_list = []
    
    max_length = 4096

    for full_text in examples["text"]:
        # Extract prompt and response
        prompt_text = full_text.split("### Response:")[0].strip() + "\n### Response:\n"
        response_text = full_text.split("### Response:")[1].strip()
        
        # Tokenize with truncation
        prompt_tokens = tokenizer(prompt_text, truncation=True, max_length=max_length)["input_ids"]
        response_tokens = tokenizer(response_text, truncation=True, max_length=max_length)["input_ids"]
        response_tokens.append(tokenizer.eos_token_id)
        
        # Combine tokens for input
        input_ids = prompt_tokens + response_tokens
        attention_mask = [1] * len(input_ids)
        
        # Create labels - keep prompt tokens, mask response tokens
        labels = input_ids.copy()  # Start with full sequence
        labels = [-100] * len(prompt_tokens) + response_tokens #mask prompt tokens

        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)
        labels_list.append(labels)

    return {
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list,
        "labels": labels_list
    }

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # because this is causal LM
    pad_to_multiple_of=16  # speeds up training on GPU
)

In [8]:
import sqlite3
import re

def fix_missing_semicolons(sql_code):
    """
    Inserts semicolons between multiple CREATE TABLE statements if missing.
    Looks for patterns like `) CREATE TABLE` and adds a semicolon between them.
    """
    return re.sub(r'\)\s*(?=CREATE TABLE)', r');\n', sql_code.strip())

def can_execute_sql(generated_sql, schema=None, verbose=True):
    """
    Check if a SQL query or script can be executed against a given schema.

    Args:
        generated_sql (str): The SQL query or script to test.
        schema (str, optional): The database schema to create before testing.
        verbose (bool, optional): Whether to print detailed errors.

    Returns:
        tuple: (bool, str) - (success status, message or error)
    """
    conn = None
    try:
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()

        # Create schema if provided
        if schema:
            try:
                schema = fix_missing_semicolons(schema)
                cursor.executescript(schema)
                conn.commit()
            except sqlite3.Error as e:
                if verbose:
                    print("Schema execution failed.")
                    print("Error:", e)
                return False

        # Execute the query or script
        try:
            if ';' in generated_sql.strip().rstrip(';'):
                cursor.executescript(generated_sql)
                return True
            else:
                cursor.execute(generated_sql)
                return True
        except sqlite3.Error as e:
            if verbose:
                print("Query execution failed.")
                print("Error:", e)
            return False

    except Exception as e:
        if verbose:
            print("General error.")
            print("Error:", e)
        return False

    finally:
        if conn:
            conn.close()

In [14]:
#computing the metrics for the baseline model based on similarilty of output, sql compilation and time

# Load metrics
meteor_metric = evaluate.load("meteor")

def extract_sql_from_output(output_text, prompt_text):
    """Extract SQL query from model output, handling various formats."""
    # Remove the prompt from the output
    sql_text = output_text[len(prompt_text):].strip()
    
    # Remove any markdown code blocks if present
    sql_text = re.sub(r'```sql\s*|\s*```', '', sql_text)
    sql_text = re.sub(r'```\s*|\s*```', '', sql_text)
    
    # Remove any trailing text after semicolon
    if ';' in sql_text:
        sql_text = sql_text.split(';')[0] + ';'
    
    return sql_text.strip()

def evaluate_model_on_dataset(
    model,
    tokenizer,
    dataset,
    max_new_tokens=2048
):
    predictions = []
    references = []
    compile_success = 0
    execution_times = []

    dataset_slice = dataset

    for example in tqdm(dataset_slice, desc="Evaluating"):
        # Extract prompt and response using the same format as tokenize function
        prompt_text = example["text"].split("### Response:")[0].strip() + "\n### Response:\n"
        ground_truth = example["text"].split("### Response:")[1].strip()
        schema = example["text"].split("### Input:")[1].split("### Response:")[0].strip()

        inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=2048,
                pad_token_id=tokenizer.eos_token_id
                )
        
        # Get the generated SQL - everything after the prompt
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        #generated_sql = extract_sql_from_output(decoded, prompt_text)
        generated_sql = decoded.split("### Response:")[-1].strip().split("###")[0]
        print("SQL Output:", generated_sql)

        # Add prediction for METEOR
        predictions.append(generated_sql)
        references.append([ground_truth])  # METEOR expects references as a list of lists

        # Compile SQL Query and measure time
        start_time = time.perf_counter()
        success = can_execute_sql(generated_sql, schema)
        end_time = time.perf_counter()

        if success:
            compile_success += 1
            execution_times.append(end_time - start_time)

    # Compute metrics
    meteor_score = meteor_metric.compute(predictions=predictions, references=references)["meteor"]
    sql_compilation_rate = compile_success / len(dataset_slice)
    
    # Calculate average execution time for successful queries
    avg_execution_time = sum(execution_times) / len(execution_times) if execution_times else 0

    metrics = {
        "meteor_score": round(meteor_score, 4),
        "sql_compilation_rate": round(sql_compilation_rate, 4),
        "avg_execution_time_ms": round(avg_execution_time * 1000, 2),  # Convert to milliseconds
        "num_eval_samples": len(dataset_slice),
        "num_successful_queries": compile_success
    }

    return metrics

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\sidpk\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\sidpk\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\sidpk\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [10]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

In [11]:
prompt = """	
Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What model has a launch of September 3, 2010? ### Input: CREATE TABLE table_28269 (
"Model" text,
"Launch" text,
"Code name" text,
"Transistors (million)" real,
"Die size (mm 2 )" real,
"Bus interface" text,
"Memory ( MB )" text,
"SM count" real,
"Core config 1,3" text,
"Core ( MHz )" real,
"Shader ( MHz )" real,
"Memory ( MHz )" text,
"Pixel ( GP /s)" text,
"Texture ( GT /s)" text,
"Bandwidth ( GB /s)" text,
"DRAM type" text,
"Bus width ( bit )" real,
"GFLOPS (FMA) 2" text,
"TDP (watts)" real,
"Release price (USD)" text
) ### Response:
"""

In [12]:
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = base_model.generate(
                **inputs,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=2048,
                pad_token_id=tokenizer.eos_token_id
                )

generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract only the response part (everything after "### Response:")
generated_sql = generated_sql.split("### Response:")[-1].strip()
print(generated_sql)




In [15]:
# Evaluate base model performance
print("Evaluating base model...")

# Create a test set
test_samples = formatted_dataset["test"].select(range(10))  # Using 10 samples for evaluation

# Evaluate base model
base_metrics = evaluate_model_on_dataset(
    model=base_model,  # Base model
    tokenizer=tokenizer,
    dataset=test_samples,
    max_new_tokens=256
)

# Print metrics
print("\nBase Model Performance:")
print(f"{'Metric':<25} {'Value':<15}")
print("-" * 40)

for metric in ['meteor_score', 'sql_compilation_rate', 'avg_execution_time_ms']:
    value = base_metrics[metric]
    print(f"{metric:<25} {value:<15.4f}")

print(f"\nNumber of samples evaluated: {base_metrics['num_eval_samples']}")
print(f"Number of successful queries: {base_metrics['num_successful_queries']}")

Evaluating base model...


Evaluating:  10%|█         | 1/10 [02:14<20:06, 134.10s/it]

SQL Output: ```sql
SELECT high_points, game
FROM table_13464416_27
WHERE game = 7
```


Query execution failed.
Error: near "```sql
SELECT high_points, game
FROM table_13464416_27
WHERE game = 7
```": syntax error


Evaluating:  20%|██        | 2/10 [04:48<19:29, 146.16s/it]

SQL Output: ```sql
SELECT COUNT(*)
FROM table_1007
WHERE Class = 'shasta h.s.' AND Position = 'guard' AND Height = '6\'11"' AND Weight = '200 lbs' AND Games_started = '10' AND Player = 'guard' AND Player = 'guard' AND Player = 'guard' AND Player = 'guard' AND Player = 'guard'
```


Query execution failed.
Error: near "```sql
SELECT COUNT(*)
FROM table_1007
WHERE Class = 'shasta h.s.' AND Position = 'guard' AND Height = '6\'11"' AND Weight = '200 lbs' AND Games_started = '10' AND Player = 'guard' AND Player = 'guard' AND Player = 'guard' AND Player = 'guard' AND Player = 'guard'
```": syntax error


Evaluating:  30%|███       | 3/10 [07:19<17:16, 148.08s/it]

SQL Output: SELECT
    icd9_code,
    short_title,
    long_title
FROM
    d_icd_procedures
WHERE
    icd9_code IN (
        SELECT
            icd9_code
        FROM
            procedures_icd
        GROUP BY
            icd9_code
        HAVING
            COUNT(*) > 1
    )

SELECT
    icd9_code,
    short_title,
    long_title
FROM
    d_icd_diagnoses
WHERE
    icd9_code IN (
        SELECT
            icd9_code
        FROM
            diagnoses_icd
        GROUP BY
            icd9_code
        HAVING
            COUNT(*) > 1
    )

SELECT
    icd9_code,
    short_title,
    long_title
FROM
    d_icd_diagnoses
WHERE
    icd9_code IN (
        SELECT
            icd9_code
        FROM
            diagnoses_icd
        GROUP BY
            icd9_code
        HAVING
            COUNT(*) > 1
    )

SELECT
    icd9_code,
    short_title,
    long_title
FROM
    d_icd_diagnoses
WHERE
    icd9_code IN (
        SELECT
            icd9_code
        FROM
            diagnoses_icd
        

Evaluating:  40%|████      | 4/10 [09:55<15:07, 151.29s/it]

SQL Output: ```sql
SELECT * FROM table_10263 WHERE "Player" = "Jordan Browne"
```


Query execution failed.
Error: near "```sql
SELECT * FROM table_10263 WHERE "Player" = "Jordan Browne"
```": syntax error


Evaluating:  50%|█████     | 5/10 [12:17<12:20, 148.10s/it]

SQL Output: ```sql
INSERT INTO table_1788 (Country, Population (2011), GDP (nominal) (2010, US$ millions), Military expenditures (2011, US$ millions), Military expenditures (2011, % of GDP), Defence expenditures, (2011, per capita), Deployable military (2011, thousands))
VALUES
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania", 11.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5),
    ("Romania

Evaluating:  60%|██████    | 6/10 [14:46<09:53, 148.26s/it]

SQL Output: 


Evaluating:  70%|███████   | 7/10 [17:16<07:26, 148.74s/it]

SQL Output: ```sql
SELECT table_8208.Label
FROM table_8208
INNER JOIN table_8209
ON table_8208.Label = table_8209.Label
WHERE table_8208.Label = "Original CD"
```


Query execution failed.
Error: near "```sql
SELECT table_8208.Label
FROM table_8208
INNER JOIN table_8209
ON table_8208.Label = table_8209.Label
WHERE table_8208.Label = "Original CD"
```": syntax error


Evaluating:  80%|████████  | 8/10 [19:52<05:02, 151.31s/it]

SQL Output: ```sql
INSERT INTO table_26425 (Stage (Winner), General classification, Sprint Classification, Mountains Classification, Youth Classification, Aggressive Rider, Team Classification)
VALUES ('1st', '1st', '1st', '1st', '1st', '1st', '1st');
```
Query execution failed.
Error: near "```sql
INSERT INTO table_26425 (Stage (Winner), General classification, Sprint Classification, Mountains Classification, Youth Classification, Aggressive Rider, Team Classification)
VALUES ('1st', '1st', '1st', '1st', '1st', '1st', '1st');
```": syntax error


Evaluating:  90%|█████████ | 9/10 [22:20<02:30, 150.03s/it]

SQL Output: SELECT
    airport.Country,
    COUNT(*)
FROM
    flight
    INNER JOIN airport ON flight.airport_id = airport.id
GROUP BY
    airport.Country
ORDER BY
    COUNT(*) DESC

SELECT
    operate_company.name,
    COUNT(*)
FROM
    flight
    INNER JOIN operate_company ON flight.company_id = operate_company.id
GROUP BY
    operate_company.name
ORDER BY
    COUNT(*) DESC

SELECT
    operate_company.name,
    COUNT(*)
FROM
    flight
    INNER JOIN operate_company ON flight.company_id = operate_company.id
GROUP BY
    operate_company.name
ORDER BY
    COUNT(*) DESC

SELECT
    operate_company.name,
    COUNT(*)
FROM
    flight
    INNER JOIN operate_company ON flight.company_id = operate_company.id
GROUP BY
    operate_company.name
ORDER BY
    COUNT(*) DESC

SELECT
    operate_company.name,
    COUNT(*)
FROM
    flight
    INNER JOIN operate_company ON flight.company_id = operate_company.id
GROUP BY
    operate_company.name
ORDER BY
    COUNT(*) DESC

SELECT
    operate_company.na

Evaluating: 100%|██████████| 10/10 [25:05<00:00, 150.58s/it]

SQL Output: ```sql
INSERT INTO table_22669044_8 (record, location_attendance)
VALUES ('18,838', 'united center');
```
Query execution failed.
Error: near "```sql
INSERT INTO table_22669044_8 (record, location_attendance)
VALUES ('18,838', 'united center');
```": syntax error






Base Model Performance:
Metric                    Value          
----------------------------------------
meteor_score              0.2901         
sql_compilation_rate      0.1000         
avg_execution_time_ms     4.0100         

Number of samples evaluated: 10
Number of successful queries: 1
