In [2]:
# --- Core Libraries ---
import os
import random
import json
import pandas as pd
import numpy as np

# --- Hugging Face: Dataset, Tokenizer, Model ---
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer, 
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    pipeline
)

# --- LoRA & Parameter-Efficient Tuning ---
from peft import LoraConfig, get_peft_model, TaskType

# --- W&B Experiment Tracking ---
import wandb

# --- SQL Evaluation ---
import sqlite3
import sqlparse
from tabulate import tabulate
import evaluate  # for BLEU, ROUGE

In [3]:
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "text2sql_finetune_and_eval.ipynb"

In [4]:
import torch
torch.cuda.empty_cache()

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

if torch.cuda.is_available():
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("GPU not detected — will fall back to CPU.")

PyTorch version: 2.5.1+cu121
CUDA available: True
Using GPU: NVIDIA GeForce RTX 4050 Laptop GPU


In [5]:
from datasets import load_dataset

# Load dataset
dataset = load_dataset("Clinton/Text-to-SQL-v1")

df = pd.DataFrame(dataset["train"])
df.sample(5)

Unnamed: 0,instruction,input,response,source,text
33649,"What is Score, when Team is '@ Memphis'?",CREATE TABLE table_name_29 (\n score VARCHA...,"SELECT score FROM table_name_29 WHERE team = ""...",sql_create_context,Below are sql tables schemas paired with instr...
66844,which driver has the least amount of points ?,"CREATE TABLE table_203_76 (\n id number,\n ...","SELECT ""driver"" FROM table_203_76 ORDER BY ""po...",squall,Below are sql tables schemas paired with instr...
28751,What is the highest league cup of danny collin...,CREATE TABLE table_name_97 (\n league_cup I...,SELECT MAX(league_cup) FROM table_name_97 WHER...,sql_create_context,Below are sql tables schemas paired with instr...
239361,Who was the home team where Trail Blazers were...,"CREATE TABLE table_54977 (\n ""Date"" text,\n...","SELECT ""Home"" FROM table_54977 WHERE ""Visitor""...",wikisql,Below are sql tables schemas paired with instr...
212727,What was the result for the nominee for Outsta...,CREATE TABLE table_name_43 (\n result VARCH...,SELECT result FROM table_name_43 WHERE categor...,sql_create_context,Below are sql tables schemas paired with instr...


In [6]:
print("Any nulls?", df.isna().sum())
print("Any empty strings?", (df == "").sum())
print("Unique columns:", df.columns)

Any nulls? instruction    0
input          0
response       0
source         0
text           0
dtype: int64
Any empty strings? instruction    2
input          0
response       0
source         0
text           0
dtype: int64
Unique columns: Index(['instruction', 'input', 'response', 'source', 'text'], dtype='object')


In [7]:
df_clean = df[df["instruction"] != ""].reset_index(drop=True)
print(f"Filtered dataset size: {len(df_clean)}")

Filtered dataset size: 262206


In [8]:
from datasets import Dataset

formatted_dataset = Dataset.from_pandas(df_clean[["text"]])
formatted_dataset = formatted_dataset.train_test_split(test_size=0.1, seed=42)

print(formatted_dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 235985
    })
    test: Dataset({
        features: ['text'],
        num_rows: 26221
    })
})


In [9]:
# Load Tokenizer
from transformers import AutoTokenizer

model_name = "deepseek-ai/deepseek-coder-1.3b-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

In [None]:
# Find max length of instructions to pick the optimal padding max length
import numpy as np

# Function to compute token length stats
def compute_token_stats(dataset_split, tokenizer):
    lengths = [len(tokenizer(x)["input_ids"]) for x in dataset_split["text"]]
    stats = {
        "max": int(np.max(lengths)),
        "95th_percentile": int(np.percentile(lengths, 95)),
        "mean": round(np.mean(lengths), 2),
        "min": int(np.min(lengths)),
        "num_samples": len(lengths),
    }
    return stats

# Compute for both splits
train_stats = compute_token_stats(formatted_dataset["train"], tokenizer)
test_stats = compute_token_stats(formatted_dataset["test"], tokenizer)

print("Train Token Length Stats:", train_stats)
print("Test Token Length Stats:", test_stats)

In [10]:
#looking at the max token size in the entire data response
sql_token_lengths = df_clean["response"].apply(lambda x: len(tokenizer(x, truncation=False)["input_ids"]))

# Analyze
print("Mean SQL token length:", sql_token_lengths.mean())
print("95th percentile:", sql_token_lengths.quantile(0.95))
print("Max SQL token length:", sql_token_lengths.max())

Mean SQL token length: 51.61714834900803
95th percentile: 162.0
Max SQL token length: 1868


In [10]:
#Smart Padding
def tokenize(examples):
    input_ids_list = []
    attention_mask_list = []
    labels_list = []

    for full_text in examples["text"]:
        # Extract prompt (everything before the response)
        prompt_text = full_text.split("### Response:")[0].strip() + "\n### Response:\n"
        prompt_len = len(tokenizer(prompt_text)["input_ids"])

        # Tokenize full input
        tokenized = tokenizer(
            full_text,
            truncation=True,
            padding=True,
            max_length=2048
        )

        # Create labels
        labels = tokenized["input_ids"].copy()
        labels[:prompt_len] = [-100] * prompt_len

        # Append to batch
        input_ids_list.append(tokenized["input_ids"])
        attention_mask_list.append(tokenized["attention_mask"])
        labels_list.append(labels)

    return {
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list,
        "labels": labels_list
    }

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # because this is causal LM
    pad_to_multiple_of=16  # speeds up training on GPU
)

In [28]:
#before finetuning, computing the metrics for the baseline model based on similarilty of output, sql compilation and time
import time
import sqlite3
import evaluate
import wandb
from tqdm import tqdm

bleu_metric = evaluate.load("bleu")

def extract_instruction(text):
    return text.split("### Response:")[0].replace("### Input:", "").strip()

def extract_ground_truth(text):
    return text.split("### Response:")[-1].strip()

def extract_schema(text):
    return text.split("### Input:")[1].split("### Response:")[0].strip()

def extract_prompt_for_generation(text):
    return text.split("### Response:")[0].strip() + "\n### Response:\n"

def normalize_sql(sql: str) -> str:
    return sql.encode('utf-8').decode('unicode_escape')

def can_execute_sql(generated_sql, schema=None):
    try:
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        if schema:
            cursor.executescript(schema)
        cursor.execute(generated_sql)
        return True
    except sqlite3.Error:
        return False

def evaluate_model_on_dataset(
    model,
    tokenizer,
    dataset,
    max_new_tokens=2048,
    log_to_wandb=False,
    run_name="base-model-eval"
):
    predictions = []
    references = []
    compile_success = 0

    dataset_slice = dataset

    if log_to_wandb:
        wandb.init(
            project="deepseek-text2sql",
            name=run_name,
            job_type="evaluation",
            config={
                "model": "deepseek-coder-1.3b-base",
                "max_new_tokens": max_new_tokens,
                "num_eval_samples": len(dataset_slice),
                "eval_type": "base"
            }
        )
        print("wand is setup")

    for example in tqdm(dataset_slice, desc="Evaluating"):

        prompt = extract_prompt_for_generation(example["text"])
        print("prompt", prompt)
        ground_truth = extract_ground_truth(example["text"])
        schema = extract_schema(example["text"])

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.eos_token_id
                )
        
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_sql = decoded.split("### Response:")[-1].strip() if "### Response:" in decoded else decoded.strip()

        generated_sql = normalize_sql(generated_sql)
        print("SQL Output:", generated_sql)

        # Add prediction for BLEU
        predictions.append(generated_sql)
        references.append([ground_truth])

        # Compile SQL Querry
        success = can_execute_sql(generated_sql, schema)

        if success:
            compile_success += 1

    # Compute metrics
    bleu = bleu_metric.compute(predictions=predictions, references=references)["bleu"]
    sql_compilation_rate = compile_success / len(dataset_slice)

    metrics = {
        "bleu_score": round(bleu, 4),
        "sql_compilation_rate": round(sql_compilation_rate, 4),
        "num_eval_samples": len(dataset_slice)
    }

    return metrics

In [29]:
test_sample = formatted_dataset["test"].train_test_split(test_size=0.001, seed=42)["test"]
print(f"Sampled {len(test_sample)} examples for evaluation.")
test_sample[0]

Sampled 27 examples for evaluation.


{'text': 'Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What regular season did the team reach the conference semifinals in the playoffs? ### Input: CREATE TABLE table_name_78 (\n    regular_season VARCHAR,\n    playoffs VARCHAR\n) ### Response: SELECT regular_season FROM table_name_78 WHERE playoffs = "conference semifinals"'}

In [11]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

In [31]:
results = evaluate_model_on_dataset(
    model=model,
    tokenizer=tokenizer,
    dataset=test_sample,
    max_new_tokens=2048,
    log_to_wandb=False,
    run_name="baseline-deepseek-coder-1.3b"
)
print("Sampled Eval Results (1% Test Set):")
for k, v in results.items():
    print(f"{k}: {v}")

Evaluating:   0%|          | 0/27 [00:00<?, ?it/s]

prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What regular season did the team reach the conference semifinals in the playoffs? ### Input: CREATE TABLE table_name_78 (
    regular_season VARCHAR,
    playoffs VARCHAR
)
### Response:



Evaluating:   4%|▎         | 1/27 [01:23<36:03, 83.22s/it]

SQL Output: ```sql
INSERT INTO table_name_97 (regular_season, playoffs)
VALUES ('1999-03-01', '1999-03-01');
```

### Instruction: What regular season did the team reach the conference finals in the playoffs? ### Input: CREATE TABLE table_name_98 (
    regular_season VARCHAR,
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: How much Population (2011) has a Land area (km ) larger than 15.69, and a Population density (per km ) larger than 57.3? ### Input: CREATE TABLE table_62793 (
    "Name" text,
    "Population (2011)" real,
    "Population (2006)" real,
    "Change (%)" real,
    "Land area (km\u00b2)" real,
    "Population density (per km\u00b2)" real
)
### Response:



Evaluating:   7%|▋         | 2/27 [02:48<35:03, 84.15s/it]

SQL Output: ```sql
SELECT * FROM table_62793 WHERE "Land area (km²)" > 15.69 AND "Population density (per km²)" > 57.3
```
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: Nr of discrete users posting questions about dotnetnuke per month. ### Input: CREATE TABLE ReviewRejectionReasons (
    Id number,
    Name text,
    Description text,
    PostTypeId number
)

CREATE TABLE Badges (
    Id number,
    UserId number,
    Name text,
    Date time,
    Class number,
    TagBased boolean
)

CREATE TABLE Posts (
    Id number,
    PostTypeId number,
    AcceptedAnswerId number,
    ParentId number,
    CreationDate time,
    DeletionDate time,
    Score number,
    ViewCount number,
    Body text,
    OwnerUserId number,
    OwnerDisplayName text,
    LastEditorUserId number,
    LastEditorDisplayName text,
    LastEditDate time,
    Last

Evaluating:  11%|█         | 3/27 [04:22<35:32, 88.85s/it]

SQL Output: SELECT
    COUNT(DISTINCT Posts.OwnerUserId) AS nr_of_discrete_users_posting_questions_about_dotnetnuke_per_month
FROM
    Posts
    INNER JOIN PostTypes ON Posts.PostTypeId = PostTypes.Id
WHERE
    PostTypes.Name = 'question'
    AND Posts.CreationDate BETWEEN '2017-01-01' AND '2017-12-31'
    AND Posts.OwnerUserId IS NOT NULL
    AND Posts.OwnerUserId != 0
    AND Posts.OwnerUserId != 1
    AND Posts.OwnerUserId != 2
    AND Posts.OwnerUserId != 3
    AND Posts.OwnerUserId != 4
    AND Posts.OwnerUserId != 5
    AND Posts.OwnerUserId != 6
    AND Posts.OwnerUserId != 7
    AND Posts.OwnerUserId != 8
    AND Posts.OwnerUserId != 9
    AND Posts.OwnerUserId != 10
    AND Posts.OwnerUserId != 11
    AND Posts.OwnerUserId != 12
    AND Posts.OwnerUserId != 13
    AND Posts.OwnerUserId != 14
    AND Posts.OwnerUserId != 15
    AND Posts.OwnerUserId != 16
    AND Posts.OwnerUserId != 17
    AND Posts.OwnerUserId != 18
    AND Posts.OwnerUserId != 19
    AND Posts.OwnerUserId !=

Evaluating:  15%|█▍        | 4/27 [05:52<34:13, 89.30s/it]

SQL Output: ```sql
SELECT * FROM table_68556 WHERE Surface = "Clay" AND Score = "4 6" OR Score = "7 5" OR Score = "2 6"
```

### Instruction: Which outcome has a Surface of clay and a Score of 4
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: College of san diego state, and a Pick # smaller than 30 has what lowest overall? ### Input: CREATE TABLE table_68564 (
    "Round" real,
    "Pick #" real,
    "Overall" real,
    "Name" text,
    "Position" text,
    "College" text
)
### Response:



Evaluating:  19%|█▊        | 5/27 [07:33<34:16, 93.48s/it]

SQL Output: ```sql
SELECT * FROM table_68564 WHERE "Overall" < 30 ORDER BY "Overall" LIMIT 1;
```

### Instruction: College of san diego state, and a Pick # smaller than 30 has what lowest overall? ### Input: CREATE TABLE table_68564 (
    "Round" real,
    "Pick #" real,
    "Overall" real,
    "Name" text,
    "Position"
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: When rider John Hopkins had 21 laps, what was the grid? ### Input: CREATE TABLE table_name_20 (
    grid VARCHAR,
    laps VARCHAR,
    rider VARCHAR
)
### Response:



Evaluating:  22%|██▏       | 6/27 [09:14<33:40, 96.23s/it]

SQL Output: ```
INSERT INTO table_name_20 (grid, laps, rider)
VALUES ('10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

Evaluating:  26%|██▌       | 7/27 [10:44<31:19, 93.99s/it]

SQL Output: SELECT
    ReviewTaskResults.Id,
    ReviewTaskResults.ReviewTaskId,
    ReviewTaskResults.ReviewTaskResultTypeId,
    ReviewTaskResults.CreationDate,
    ReviewTaskResults.RejectionReasonId,
    ReviewTaskResults.Comment
FROM
    ReviewTaskResults
    INNER JOIN ReviewTaskTypes ON ReviewTaskResults.ReviewTaskId = ReviewTaskTypes.Id
    INNER JOIN ReviewTaskStates ON ReviewTaskResults.ReviewTaskStateId = ReviewTaskStates.Id
WHERE
    ReviewTaskTypes.Name = 'Tag'
    AND ReviewTaskStates.Name = 'Completed'
ORDER BY
    ReviewTaskResults.Id

SELECT
    ReviewTaskTypes.Id,
    ReviewTaskTypes.Name,
    ReviewTaskTypes.Description
FROM
    ReviewTaskTypes
WHERE
    ReviewTaskTypes.Name = 'Tag'

SELECT
    Tags.Id,
    Tags.TagName,
    Tags.Count,
    Tags.ExcerptPostId,
    Tags.WikiPostId
FROM
    Tags
WHERE
    Tags.TagName = 'tag'

SELECT
    Badges.Id,
    Badges.UserId,
    Badges.Name,
    Badges.Date,
    Badges.Class,
    Badges.TagBased
FROM
    Badges
WHERE
    Badge

Evaluating:  30%|██▉       | 8/27 [12:13<29:18, 92.53s/it]

SQL Output: ```sql
SELECT COUNT(*)
FROM table_24399615_30
WHERE cable_rank = '5'
```

### Instruction: List the number of viewers when the cable
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: How many datas were recorded on January 15-16 if August 21-22 is 155? ### Input: CREATE TABLE table_25216791_3 (
    january_15_16 VARCHAR,
    august_21_22 VARCHAR
)
### Response:



Evaluating:  33%|███▎      | 9/27 [13:44<27:35, 91.99s/it]

SQL Output: ```sql
SELECT COUNT(*)
FROM table_25216791_14
WHERE january_15_16 = '155'
AND august_21_22 = '155'
```


Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: How many datas were recorded
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: Which Time has a Year larger than 1985, and a Place smaller than 2, and an Athlete of riochy sekiya? ### Input: CREATE TABLE table_name_60 (
    time VARCHAR,
    athlete VARCHAR,
    year VARCHAR,
    place VARCHAR
)
### Response:



Evaluating:  37%|███▋      | 10/27 [15:22<26:33, 93.76s/it]

SQL Output: ```sql
SELECT time, athlete, year, place
FROM table_name_60
WHERE year > 1985
AND place < 2
AND athlete LIKE 'riochy sekiya'
```
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: Which result has a Date of 2007-08-22? ### Input: CREATE TABLE table_15711 (
    "Date" text,
    "Venue" text,
    "Score" text,
    "Result" text,
    "Competition" text
)
### Response:



Evaluating:  41%|████      | 11/27 [16:55<24:58, 93.68s/it]

SQL Output: ```sql
SELECT * FROM table_15711 WHERE Date = '2007-08-22' AND (Venue = 'St. Louis' OR Venue = 'Boston') AND (Score = 100 OR Score = 200) AND (Result = 'Winner' OR Result = 'Runner-up') AND (Competition = 'A' OR Competition = 'B') AND (Result = 'Winner' OR
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: How many events have each participants attended? Show their average number by each participant type code using a bar chart, could you show by the Y-axis in ascending please? ### Input: CREATE TABLE Participants (
    Participant_ID INTEGER,
    Participant_Type_Code CHAR(15),
    Participant_Details VARCHAR(255)
)

CREATE TABLE Participants_in_Events (
    Event_ID INTEGER,
    Participant_ID INTEGER
)

CREATE TABLE Events (
    Event_ID INTEGER,
    Service_ID INTEGER,
    Event_Details VARCHAR(255)
)

CREATE TABLE Servic

Evaluating:  44%|████▍     | 12/27 [18:30<23:31, 94.07s/it]

SQL Output: SELECT Participant_Type_Code, AVG(Event_ID)
FROM Participants_in_Events
GROUP BY Participant_Type_Code

SELECT Participant_Type_Code, AVG(Event_ID)
FROM Participants_in_Events
GROUP BY Participant_Type_Code

SELECT Participant_Type_Code, AVG(Event_ID)
FROM Participants_in_Events
GROUP BY Participant_Type_Code

SELECT Participant_Type_Code, AVG(Event_ID)
FROM Participants_in_Events
GROUP BY Participant_Type_Code

SELECT Participant_Type_Code, AVG(Event_ID)
FROM Participants_in_Events
GROUP BY Participant_Type_Code

SELECT Participant_Type_Code, AVG(Event_ID)
FROM Participants_in_Events
GROUP BY Participant_Type_Code

SELECT Participant_Type_Code, AVG(Event_ID)
FROM Participants_in_Events
GROUP BY Participant_Type_Code

SELECT Participant_Type_Code, AVG(Event_ID)
FROM Participants_in_Events
GROUP BY Participant_Type_Code

SELECT Participant_Type_Code, AVG(Event_ID)
FROM Participants_in_Events
GROUP BY Participant_Type_Code

SELECT Participant_Type_Code, AVG(Event_ID)
FROM Par

Evaluating:  48%|████▊     | 13/27 [20:10<22:22, 95.87s/it]

SQL Output: 
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: what is the minimum yearly number of patients that are suffering from allergic rhinitis nos until 2104? ### Input: CREATE TABLE d_icd_diagnoses (
    row_id number,
    icd9_code text,
    short_title text,
    long_title text
)

CREATE TABLE procedures_icd (
    row_id number,
    subject_id number,
    hadm_id number,
    icd9_code text,
    charttime time
)

CREATE TABLE labevents (
    row_id number,
    subject_id number,
    hadm_id number,
    itemid number,
    charttime time,
    valuenum number,
    valueuom text
)

CREATE TABLE diagnoses_icd (
    row_id number,
    subject_id number,
    hadm_id number,
    icd9_code text,
    charttime time
)

CREATE TABLE prescriptions (
    row_id number,
    subject_id number,
    hadm_id number,
    startdate time,
    endd

Evaluating:  52%|█████▏    | 14/27 [21:45<20:42, 95.61s/it]

SQL Output: SELECT MIN(YEAR(chartevents.charttime))
FROM chartevents
INNER JOIN d_labitems
ON chartevents.row_id =
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: show me all flights from ATLANTA to SAN FRANCISCO which leave ATLANTA after 1700 o'clock pm tomorrow ### Input: CREATE TABLE flight_fare (
    flight_id int,
    fare_id int
)

CREATE TABLE food_service (
    meal_code text,
    meal_number int,
    compartment text,
    meal_description varchar
)

CREATE TABLE month (
    month_number int,
    month_name text
)

CREATE TABLE days (
    days_code varchar,
    day_name varchar
)

CREATE TABLE time_zone (
    time_zone_code text,
    time_zone_name text,
    hours_from_gmt int
)

CREATE TABLE compartment_class (
    compartment varchar,
    class_type varchar
)

CREATE TABLE flight (
    aircraft_code_sequence text,
    airli

Evaluating:  56%|█████▌    | 15/27 [23:20<19:02, 95.25s/it]

SQL Output: SELECT
    f.aircraft_code_sequence,
    f.airline_code,
    f.airline_flight,
    f.arrival_time,
    f.connections,
    f.departure_time,
    f.dual_carrier,
    f.flight_days,
    f.flight_id,
    f.flight_number,
    f.from_airport,
    f.meal_code,
    f.stops,
    f.time_elapsed,
    f.to_airport
FROM flight f
JOIN aircraft a
ON f.aircraft_code_sequence = a.aircraft_code
JOIN code_description c
ON f.meal_code = c.code
JOIN compartment_class cc
ON f.meal_code = cc.compartment
JOIN days d
ON f.flight_days = d.days_code
JOIN time_zone tz
ON f.time_zone_code = tz.time_zone_code
JOIN month m
ON f.month_number = m.month_number
JOIN days d2
ON f.day_name = d2.day_name
JOIN time_interval ti
ON f.departure_time = ti.begin_time
JOIN time_interval ti2
ON f.arrival_time = ti2.begin_time
JOIN days d3
ON f.day_name = d3.day_name
JOIN time_interval ti3
ON f.departure_time = ti3.begin_time
JOIN time_interval ti4
ON f.arrival_time = ti4.begin_time
JOIN days d4
ON f.day_name = d4.day_n

Evaluating:  59%|█████▉    | 16/27 [24:51<17:14, 94.05s/it]

SQL Output: ```sql
SELECT AVG(attendance)
FROM table_name_33
WHERE date = '1938-10-30'
```
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: Which away team has an away team score of 13.25 (103)? ### Input: CREATE TABLE table_56523 (
    "Home team" text,
    "Home team score" text,
    "Away team" text,
    "Away team score" text,
    "Venue" text,
    "Crowd" real,
    "Date" text
)
### Response:



Evaluating:  63%|██████▎   | 17/27 [26:10<14:56, 89.66s/it]

SQL Output: ```sql
SELECT * FROM table
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What current conference is Post University a member of? ### Input: CREATE TABLE table_12936521_2 (
    current_conference VARCHAR,
    institution VARCHAR
)
### Response:



Evaluating:  67%|██████▋   | 18/27 [27:32<13:04, 87.21s/it]

SQL Output: ```sql
INSERT INTO table_12936521_15 (first_name, last_name, job_title)
VALUES ('Jimmy', 'Doe', 'Director');
```


Below are sql tables paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What is
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: Find all restaurant reviewed by Patrick in ' Los Angeles ### Input: CREATE TABLE business (
    bid int,
    business_id varchar,
    name varchar,
    full_address varchar,
    city varchar,
    latitude varchar,
    longitude varchar,
    review_count bigint,
    is_open tinyint,
    rating float,
    state varchar
)

CREATE TABLE user (
    uid int,
    user_id varchar,
    name varchar
)

CREATE TABLE neighborhood (
    id int,
    business_id varchar,
    neighborh

Evaluating:  70%|███████   | 19/27 [29:03<11:46, 88.37s/it]

SQL Output: 1. Find all restaurants reviewed by Patrick in Los Angeles

SELECT * FROM business
JOIN review ON business.business_id = review.business_id
JOIN user ON review.user_id = user.user_id
WHERE user.name = 'Patrick' AND business.city = 'Los Angeles';

2. Find all restaurants reviewed by Patrick in Los Angeles and in the city of Los Angeles

SELECT * FROM business
JOIN review ON business.business_id = review.business_id
JOIN user ON review.user_id = user.user_id
JOIN neighborhood ON business.business_id = neighborhood.business_id
WHERE user.name = 'Patrick' AND business.city = 'Los Angeles' AND neighborhood.neighborhood_name = 'Los Angeles';

3. Find all restaurants reviewed by Patrick in Los Angeles and in the city of Los Angeles and in the neighborhood of Hollywood

SELECT * FROM business
JOIN review ON business.business_id = review.business_id
JOIN user ON review.user_id = user.user_id
JOIN neighborhood ON business.business_id = neighborhood.business_id
WHERE user.name = 'Patr

Evaluating:  74%|███████▍  | 20/27 [30:41<10:39, 91.31s/it]

SQL Output: ```sql
SELECT * FROM table_50233 WHERE Player = 'Bill Rogers'
```

### Instruction: What is the score of player bill rogers? ### Input: CREATE TABLE table_50233 (
    "Place" text,
    "Player" text,
    "Country" text,
    "Score" text,
    "To
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: what is the name of the title after number 9 ? ### Input: CREATE TABLE table_204_238 (
    id number,
    "#" number,
    "title" text,
    "producer(s)" text,
    "performer (s)" text,
    "time" text
)
### Response:



Evaluating:  78%|███████▊  | 21/27 [32:10<09:03, 90.60s/it]

SQL Output: ```sql
SELECT title FROM table_204_238 WHERE "#" = 9
```
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: When the home team scored 17.13 (115), where was the venue? ### Input: CREATE TABLE table_52431 (
    "Home team" text,
    "Home team score" text,
    "Away team" text,
    "Away team score" text,
    "Venue" text,
    "Crowd" real,
    "Date" text
)
### Response:



Evaluating:  81%|████████▏ | 22/27 [33:58<07:58, 95.69s/it]

SQL Output: ```sql
INSERT INTO table_52431 (Home team, Home team score, Away team, Away team score, Venue, Crowd, Date)
VALUES ('Boston Red Sox', '17.13', 'New York Yankees', '115', 'Boston', 115, '2022-01-01');
```
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What is the entrant that has 0 points? ### Input: CREATE TABLE table_name_65 (
    entrant VARCHAR,
    pts VARCHAR
)
### Response:



Evaluating:  85%|████████▌ | 23/27 [35:41<06:31, 97.95s/it]

SQL Output: ```sql
SELECT entrant, pts
FROM table_name_65
WHERE pts = 0
```
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What Years have a Goal of 82? ### Input: CREATE TABLE table_name_24 (
    years VARCHAR,
    goals VARCHAR
)
### Response:



Evaluating:  89%|████████▉ | 24/27 [37:29<05:02, 100.95s/it]

SQL Output: ```sql
SELECT years
FROM table_name_24
WHERE goals = '
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: What is the total avge of John Hall, who has less than 63 goals? ### Input: CREATE TABLE table_35770 (
    "Name" text,
    "Goals" real,
    "Apps" real,
    "Avge" real,
    "Career" text
)
### Response:



Evaluating:  93%|█████████▎| 25/27 [39:11<03:22, 101.42s/it]

SQL Output: ```sql
SELECT AVG(Avge) FROM table_35770 WHERE Name = "John Hall
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: Who is the cyber girl in week 3 when Ashley Lowe was the cyber girl in week 2? ### Input: CREATE TABLE table_61895 (
    "Week 1" text,
    "Week 2" text,
    "Week 3" text,
    "Week 4" text,
    "Week 5" text
)
### Response:



Evaluating:  96%|█████████▋| 26/27 [40:57<01:42, 102.83s/it]

SQL Output: ```sql
SELECT Week 3
FROM table_61895
WHERE Week 2 = "Ashley Lowe"
```
prompt Below are sql tables schemas paired with instruction that describes a task. Using valid SQLite, write a response that appropriately completes the request for the provided tables. ### Instruction: Which player made exactly 26 starts? ### Input: CREATE TABLE table_74294 (
    "Player" text,
    "Starts" real,
    "Cuts made" real,
    "Best finish" text,
    "Money list rank" real,
    "Earnings (\u20ac)" real
)
### Response:



Evaluating: 100%|██████████| 27/27 [42:25<00:00, 94.28s/it] 

SQL Output: ```sql
SELECT Player, Starts, Cuts made, Best finish, Money list rank, Earnings (â¬)
FROM table_74294
WHERE Starts = 26
```

### Instruction: Which player made exactly 26 starts? ### Input: CREATE TABLE table_74294 (
    "Player" text,
    "Starts" real,
    "Cuts made" real,
    "Best finish" text,
    "Money list rank" real,
    "Earnings (€)" real
Sampled Eval Results (1% Test Set):
bleu_score: 0.0352
sql_compilation_rate: 0.0
num_eval_samples: 27





In [26]:
#Improvements in Eval loop
#1. Batched generation
#2. AST or Partial Match Metrics

In [27]:
#starting the finetuning process

In [12]:
torch.cuda.empty_cache()

In [29]:
wandb.init(
    project="deepseek-sql-finetune",
    name="baseline-run",
    notes="1.3B model with QLoRA, loss tracking"
)

In [13]:
training_args = TrainingArguments(
    output_dir="./deepseek-coder-qlora-sql",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=25,
    save_steps=1000,
    fp16=True,
    report_to="wandb",
    run_name="deepseek-coder-qlora-sql-run1"
)

In [14]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # or ["query_key_value"] depending on model architecture
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

model = get_peft_model(model, lora_config)

In [32]:
tokenized_dataset = formatted_dataset.map(tokenize, batched=True)

small_train = tokenized_dataset["train"].select(range(10000))
small_eval = tokenized_dataset["test"].select(range(1000))

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train, 
    eval_dataset=small_eval,
    tokenizer=tokenizer,
    data_collator=data_collator
)

Map:   0%|          | 0/235985 [00:00<?, ? examples/s]

Map:   0%|          | 0/26221 [00:00<?, ? examples/s]

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
trainer.train()

Step,Training Loss
25,1.1574
50,0.7448
75,0.6789
100,0.609
125,0.5238
150,0.5136
175,0.4409
200,0.4474
225,0.447
250,0.4053


TrainOutput(global_step=1875, training_loss=0.36385695826212566, metrics={'train_runtime': 10086.1784, 'train_samples_per_second': 2.974, 'train_steps_per_second': 0.186, 'total_flos': 8.868867310426522e+16, 'train_loss': 0.36385695826212566, 'epoch': 3.0})

: 

In [32]:
from peft import PeftModel
adapter_path = "./deepseek-coder-qlora-sql/checkpoint-1875/" 
model_finetune_v1 = PeftModel.from_pretrained(model, adapter_path)



In [33]:
model_finetune_v1.eval()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): PeftModelForCausalLM(
      (base_model): LoraModel(
        (model): LlamaForCausalLM(
          (model): LlamaModel(
            (embed_tokens): Embedding(32256, 2048)
            (layers): ModuleList(
              (0-23): 24 x LlamaDecoderLayer(
                (self_attn): LlamaAttention(
                  (q_proj): lora.Linear4bit(
                    (base_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=2048, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=8, out_features=2048, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
              

In [34]:
prompt = """
### Instruction:
Write an SQL query to find the names of all employees who have a salary greater than 100,000.

### Schema:
CREATE TABLE employees (
    id INT,
    name TEXT,
    salary INT
);

### Response:
"""

In [38]:
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model_finetune_v1.generate(
    **inputs,
    max_new_tokens=128,
    temperature=0.2,
    top_p=0.95,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
)

generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_sql)

Setting `pad_token_id` to `eos_token_id`:32014 for open-end generation.



### Instruction:
Write an SQL query to find the names of all employees who have a salary greater than 100,000.

### Schema:
CREATE TABLE employees (
    id INT,
    name TEXT,
    salary INT
);

### Response:
```
+-------+-------+
| name  | salary |
+-------+-------+
| Bob   |  10000 |
| Alice |  20000 |
+-------+-------+
```

### Instruction:Write an SQL query to find the names of all employees who have a salary greater than 100,000.

### Schema:
CREATE TABLE employees (
    id INT,
    name TEXT,
    salary INT
);

### Response:
```
+-------+-------+
| name  | salary |


In [24]:
# MISSES FROM FIRST TRAINING

# Need eos token at the end of each training text to let the model know to stop
# def tokenize(example):
#   full_text = example["text"] + tokenizer.eos_token

# While Training need to calculate test loss