# LoRA Finetuning for NL2SQL with Qwen2.5-7B-Instruct

This notebook walks through finetuning `Qwen/Qwen2.5-7B-Instruct` for text-to-SQL generation using LoRA (Low-Rank Adaptation) on a 4-bit quantized base model (QLoRA).

## Model Info
- **Base Model**: `Qwen/Qwen2.5-7B-Instruct` (~7B parameters)
- **Training Data**: WikiSQL (warmup phase), then Spider (main phase)
- **Input Format**: ChatML conversation with a system message, the model input as the user turn, and the SQL query as the assistant turn

## Why LoRA?
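- **Memory Efficient**: Only the adapter weights are trained (about 0.26% of all parameters in this notebook), so gradients and optimizer state stay small
- **Fast Training**: Far fewer trainable parameters than full finetuning, which reduces the optimizer work per step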
- **Preserves Base Knowledge**: Original model weights remain frozen
- **Easy Deployment**: Small adapter files can be shared/deployed separately
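
To make these points concrete, here is a minimal pure-Python sketch of a LoRA-adapted linear layer. It is not the PEFT implementation; the tiny dimensions, the helper names `matvec` and `lora_forward`, and the random weights are illustrative assumptions. The frozen weight `W` is never modified, and only the small matrices `A` and `B` would receive gradients.

```python
import random


def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """Compute y = W x + (alpha / r) * B (A x), with W frozen."""
    base = matvec(W, x)                 # frozen base projection
    update = matvec(B, matvec(A, x))    # low-rank path: d_in -> r -> d_out
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


random.seed(0)
d_in, d_out, r, alpha = 6, 4, 2, 4      # toy sizes, chosen for illustration only

W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]  # r x d_in
B = [[0.0] * r for _ in range(d_out)]                                  # d_out x r, zero-initialized
x = [random.gauss(0, 1) for _ in range(d_in)]

y = lora_forward(W, A, B, x, alpha, r)

full_params = d_in * d_out              # parameters in the frozen weight
lora_params = r * (d_in + d_out)        # parameters in A and B
print(f"full: {full_params}, LoRA: {lora_params}")
```

Because `B` starts at zero, the adapted layer initially reproduces the base model's output exactly, which is one way to see why the base knowledge is preserved at the start of training. The parameter saving only appears at realistic sizes, where `r` is much smaller than `d_in` and `d_out`.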

## Data Formatting

### Generate Data

In [None]:
import getpass, os

# Prompt for the token at runtime; getpass keeps it out of the cell output
token = getpass.getpass("input Token: ")

# Fill these in for your own setup
username = "qiulinfan"
owner = "qiulinfan"  # Usually the same as your username
repo = "eecs-595-project-nl2sql"

os.environ["GITHUB_TOKEN"] = token
os.environ["GITHUB_USER"] = username
os.environ["GITHUB_OWNER"] = owner
os.environ["GITHUB_REPO"] = repo

input Token: ··········


In [None]:
!git clone https://${GITHUB_USER}:${GITHUB_TOKEN}@github.com/${GITHUB_OWNER}/${GITHUB_REPO}.git

Cloning into 'eecs-595-project-nl2sql'...
remote: Enumerating objects: 52, done.
remote: Counting objects: 100% (52/52), done.
remote: Compressing objects: 100% (43/43), done.
remote: Total 52 (delta 11), reused 36 (delta 7), pack-reused 0 (from 0)
Receiving objects: 100% (52/52), 714.12 KiB | 37.58 MiB/s, done.
Resolving deltas: 100% (11/11), done.


In [None]:
%cd eecs-595-project-nl2sql/

/content/eecs-595-project-nl2sql


In [None]:
!pip install sentence-transformers



In [65]:
!git pull

remote: Enumerating objects: 16, done.
remote: Counting objects: 100% (16/16), done.

In [None]:
!python download_data.py



NL2SQL Dataset Downloader

[1/2] WikiSQL Dataset
----------------------------------------
WikiSQL already downloaded.

[2/2] Spider Dataset
----------------------------------------
Downloading Spider dataset from Google Drive...
Downloading...
From (original): https://drive.google.com/uc?id=1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J
From (redirected): https://drive.google.com/uc?id=1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J&confirm=t&uuid=759f23ab-dd63-4234-a1b8-8b81f8b23a8c
To: /content/eecs-595-project-nl2sql/data/spider_data.zip
100% 206M/206M [00:01<00:00, 201MB/s]
Extracting Spider dataset...
Renamed spider_data/ -> spider/
Spider extracted to /content/eecs-595-project-nl2sql/data/spider

Dataset Verification
WikiSQL train: 56355 entries
WikiSQL dev: 8421 entries
WikiSQL test: 15878 entries
WikiSQL train_tables: 18585 entries
WikiSQL dev_tables: 2716 entries
WikiSQL test_tables: 5230 entries
Spider train: 7000 entries
Spider dev: 1034 entries
Spider tables: 166 entries

Done!


In [71]:
!python prepare_training_data.py --semantic  --semantic-threshold 0.8 --wikisql-balanced 5000 --spider


Processing WikiSQL train split...
Sampling 5000 balanced WikiSQL examples from train...
  Pattern distribution in dataset:
    where_only: 40598
    count: 5114
    min: 3231
    max: 3161
    avg: 2201
    sum: 2042
    select_only: 8
  Sampled 4293 examples
WikiSQL: 100% 4293/4293 [00:00<00:00, 39926.71it/s]
Processed 4293 WikiSQL examples
Processing WikiSQL dev split...
Sampling 1000 balanced WikiSQL examples from dev...
  Pattern distribution in dataset:
    where_only: 6017
    count: 779
    max: 507
    min: 468
    avg: 329
    sum: 321
    select_only: 0
  Sampled 999 examples
WikiSQL: 100% 999/999 [00:00<00:00, 44669.25it/s]
Processed 999 WikiSQL examples
Saved 4293 examples to /content/eecs-595-project-nl2sql/training_data/wikisql_train.jsonl
Saved 999 examples to /content/eecs-595-project-nl2sql/training_data/wikisql_dev.jsonl
Processing Spider train split...
Spider:   0% 0/7000 [00:00<?, ?it/s]2025-12-06 12:20:49.661781: E external/local_xla/xla/stream_executor/cuda/cuda_f
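
The log above shows fewer sampled examples than the requested 5000, which is what happens when a per-pattern quota meets a rare pattern such as `select_only`. The following sketch illustrates that effect with a simple equal-quota sampler. It is a hypothetical illustration, not the logic of `prepare_training_data.py` (which is not shown here and may distribute the remainder differently); the function name `balanced_sample` and the toy data are assumptions.

```python
import random
from collections import defaultdict


def balanced_sample(examples, key, target, seed=0):
    """Sample up to target // num_groups examples from each group."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for ex in examples:
        groups[key(ex)].append(ex)

    quota = target // len(groups)       # equal share per pattern
    sampled = []
    for name in sorted(groups):
        members = groups[name]
        take = min(quota, len(members)) # rare patterns contribute all they have
        sampled.extend(rng.sample(members, take))

    rng.shuffle(sampled)
    return sampled


# Toy dataset: one frequent pattern, one moderate, one rare
toy = (
    [{"pattern": "where_only"}] * 50
    + [{"pattern": "count"}] * 10
    + [{"pattern": "select_only"}] * 2
)
picked = balanced_sample(toy, key=lambda ex: ex["pattern"], target=15)
print(len(picked))  # fewer than 15 because select_only has only 2 examples
```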

In [72]:
import pandas as pd
pd.set_option("display.max_rows", None)   # Do not truncate rows
pd.set_option("display.max_columns", None)  # Do not truncate columns
pd.set_option("display.max_colwidth", None) # Do not truncate cell contents
df = pd.read_json("training_data/spider_train.jsonl", lines=True)
df['schema'][0:50]

Unnamed: 0,schema
0,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"
1,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"
2,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"
3,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"
4,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"
5,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"
6,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"
7,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"
8,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"
9,"[DATABASE]\ndepartment_management\n\n[TABLES]\ndepartment:\n Department_ID (PK)\n Name\n Creation\n Ranking\n Budget_in_Billions\n Num_Employees\nhead:\n head_ID (PK)\n name\n born_state\n age\nmanagement:\n department_ID (PK, FK)\n head_ID (FK)\n temporary_acting\n\n[FOREIGN KEYS]\nmanagement.head_ID -> head.head_ID\nmanagement.department_ID -> department.Department_ID\n\n[SEMANTIC LINKS]\ndepartment.Name ≈ head.name"


In [None]:
import pandas as pd
df = pd.read_json("training_data/wikisql_train.jsonl", lines=True)
df['schema'][0:50]

Unnamed: 0,schema
0,"Schema Graph:\nTable: table(District, s Barangay, Population (2010 census), Area ( has .), Pop. density (per km2))"
1,"Schema Graph:\nTable: table(Rural municipality (RM), RM No., SARM Div. No., Census Div. No., Population (2011), Population (2006), Change (%), Land area (km²), Population density (per km²))"
2,"Schema Graph:\nTable: table(Radical (variants), Stroke count, Pīnyīn, Hiragana - Romaji, Meaning, Frequency, Examples)"
3,"Schema Graph:\nTable: table(South West DFL, Wins, Byes, Losses, Draws, Against)"
4,"Schema Graph:\nTable: table(Pick #, Player, Position, Nationality, NHL team, College/junior/club team)"
5,"Schema Graph:\nTable: table(Player, Attempts, Yards, Average, Long, Touchdowns)"
6,"Schema Graph:\nTable: table(Episode Number, Broadcast Date, Title, Written by, Viewership (Millions))"
7,"Schema Graph:\nTable: table(Ballarat FL, Wins, Byes, Losses, Draws, Against)"
8,"Schema Graph:\nTable: table(Season, Champion, Motorcycle, Wins, 2nd pl., 3rd pl., Team)"
9,"Schema Graph:\nTable: table(HR no., HR name, CR no., LMS no., Built, Works, Withdrawn)"
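
The Spider schema strings above follow a sectioned text layout (`[DATABASE]`, `[TABLES]`, `[FOREIGN KEYS]`). As a rough sketch of how such a string could be produced from structured metadata, the function below serializes a small schema into that layout. The function name `serialize_schema` and its input structure are assumptions made for illustration; the project's own serializer is not shown in this notebook, and the `[SEMANTIC LINKS]` section is left out here.

```python
def serialize_schema(db_id, tables, foreign_keys):
    """Render a schema as sectioned text.

    tables: dict mapping table name -> list of (column, [tags]) pairs
    foreign_keys: list of (source, target) strings such as
                  ("management.head_ID", "head.head_ID")
    """
    lines = ["[DATABASE]", db_id, "", "[TABLES]"]
    for table, columns in tables.items():
        lines.append(f"{table}:")
        for name, tags in columns:
            suffix = f" ({', '.join(tags)})" if tags else ""
            lines.append(f"    {name}{suffix}")
    if foreign_keys:
        lines += ["", "[FOREIGN KEYS]"]
        lines += [f"{src} -> {dst}" for src, dst in foreign_keys]
    return "\n".join(lines)


schema_text = serialize_schema(
    "department_management",
    {
        "head": [("head_ID", ["PK"]), ("name", []), ("age", [])],
        "management": [("department_ID", ["PK", "FK"]), ("head_ID", ["FK"])],
    },
    [("management.head_ID", "head.head_ID")],
)
print(schema_text)
```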


## Training

### Install Dependencies

In [None]:
!pip install transformers datasets peft accelerate bitsandbytes torch wandb --quiet

In [None]:
!pip install -r requirements_train.txt

In [67]:
import transformers
import peft
import accelerate

print(f"transformers: {transformers.__version__}")
print(f"peft: {peft.__version__}")
print(f"accelerate: {accelerate.__version__}")

transformers: 4.57.2
peft: 0.18.0
accelerate: 1.12.0


In [78]:
import os
import json
from pathlib import Path
from datetime import datetime

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
)

import wandb

### GPU Check

In [74]:
import torch

print("=" * 60)
print("GPU INFORMATION")
print("=" * 60)

if torch.cuda.is_available():
    print(f"✓ CUDA available: True")
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    props = torch.cuda.get_device_properties(0)
    total_mem = props.total_memory / 1e9
    print(f"  Memory: {total_mem:.2f} GB")
    print(f"  CUDA Version: {torch.version.cuda}")

    # Recommend config based on VRAM
    if total_mem < 12:
        print(f"\n  Recommended config: small_gpu (4-bit quantization)")
    elif total_mem < 24:
        print(f"\n Recommended config: default")
    else:
        print(f"\n Recommended config: large_gpu (more LoRA capacity)")
else:
    print("✗ CUDA not available - Training will be VERY slow on CPU")

print("=" * 60)

GPU INFORMATION
✓ CUDA available: True
  GPU: NVIDIA A100-SXM4-80GB
  Memory: 85.17 GB
  CUDA Version: 12.6

 Recommended config: large_gpu (more LoRA capacity)


### Configuration

Define your hyperparameters and LoRA configuration.

In [94]:
# =============================================================================
# MODEL CONFIGURATION
# =============================================================================

# Model selection - choose one:
# - "Qwen/Qwen2.5-7B-Instruct"          (~7B params, recommended)
# - "microsoft/Phi-3.5-mini-instruct"   (~3.8B params, faster)
# - "deepseek-ai/deepseek-coder-6.7b-instruct"  (~6.7B, code-focused)
# - "codellama/CodeLlama-7b-Instruct-hf" (~7B, Meta's code model)
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

# Quantization (for memory efficiency)
LOAD_IN_4BIT = True    # Recommended for <16GB VRAM
LOAD_IN_8BIT = False   # Alternative: less compression, slightly better quality

# =============================================================================
# LORA CONFIGURATION
# =============================================================================

LORA_R = 32              # Rank: higher = more capacity, more memory (8, 16, 32, 64)
LORA_ALPHA = 64          # Scaling factor (typically 1-2x r)
LORA_DROPOUT = 0.05      # Dropout for regularization

# Target modules for LoRA
# Full: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
# Light: ["q_proj", "v_proj"]
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]

# =============================================================================
# TRAINING CONFIGURATION
# =============================================================================

# Training phases
WIKISQL_EPOCHS = 1       # Phase 1: WikiSQL warmup
SPIDER_EPOCHS = 3        # Phase 2: Spider main training

# Batch size
BATCH_SIZE = 2                    # Per-device batch size
GRADIENT_ACCUMULATION = 8         # Effective batch = BATCH_SIZE * GRADIENT_ACCUMULATION

# Learning rate
LEARNING_RATE = 2e-4
LR_SCHEDULER = "cosine"           # "cosine", "linear", "constant"
WARMUP_RATIO = 0.05

# Sequence length
MAX_SEQ_LENGTH = 1024             # Input + Output combined

# Checkpointing
SAVE_STRATEGY = "epoch"           # "epoch" or "steps"
SAVE_TOTAL_LIMIT = 5              # Keep last N checkpoints

# Mixed precision
USE_BF16 = True                   # Use bfloat16 (recommended for newer GPUs)
USE_FP16 = False                  # Use float16 (for older GPUs)

# Gradient checkpointing (saves memory, slightly slower)
GRADIENT_CHECKPOINTING = True

# =============================================================================
# DATA CONFIGURATION
# =============================================================================

DATA_DIR = "./training_data"
OUTPUT_DIR = "./checkpoints"

# Limit samples (None = use all)
MAX_TRAIN_SAMPLES = None          # Set to e.g. 1000 for quick testing
MAX_EVAL_SAMPLES = 500            # Limit eval for speed

# =============================================================================
# WANDB CONFIGURATION
# =============================================================================

USE_WANDB = True
WANDB_PROJECT = "nl2sql-finetuning"
WANDB_RUN_NAME = None             # Auto-generated if None

# Evaluation and checkpointing both run once per epoch: the training cells set
# load_best_model_at_end=True, which requires eval_strategy to match SAVE_STRATEGY.


print(" Configuration loaded!")
print(f"  Model: {MODEL_NAME}")
print(f"  LoRA rank: {LORA_R}, alpha: {LORA_ALPHA}")
print(f"  Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION} effective")
print(f"  Training: {WIKISQL_EPOCHS} epoch WikiSQL + {SPIDER_EPOCHS} epochs Spider")

 Configuration loaded!
  Model: Qwen/Qwen2.5-7B-Instruct
  LoRA rank: 32, alpha: 64
  Batch size: 2 x 8 = 16 effective
  Training: 1 epoch WikiSQL + 3 epochs Spider
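
The configuration combines `LEARNING_RATE = 2e-4`, `WARMUP_RATIO = 0.05`, and a cosine scheduler. As a sketch of what that means per optimizer step, the function below implements the usual linear-warmup-then-cosine-decay formula. The function name `lr_at_step` and the total of 1000 steps are assumptions for illustration; the actual number of steps depends on dataset size, batch size, and gradient accumulation.

```python
import math


def lr_at_step(step, total_steps, base_lr=2e-4, warmup_ratio=0.05):
    """Learning rate with linear warmup followed by cosine decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Cosine decay from base_lr to 0 over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total_steps = 1000  # hypothetical value, for illustration only
for step in (0, 25, 50, 525, 1000):
    print(step, f"{lr_at_step(step, total_steps):.2e}")
```

With these settings the rate peaks at the end of warmup (step 50 of 1000) and then decays smoothly toward zero.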


### Load Datasets

In [81]:
def load_jsonl(file_path: str) -> list:
    """Load data from JSONL file."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    return data

# Load datasets
data_dir = Path(DATA_DIR)

print("=" * 60)
print("DATASET INFORMATION")
print("=" * 60)

# WikiSQL
wikisql_train_path = data_dir / "wikisql_train.jsonl"
wikisql_dev_path = data_dir / "wikisql_dev.jsonl"

if wikisql_train_path.exists():
    wikisql_train_data = load_jsonl(str(wikisql_train_path))
    wikisql_dev_data = load_jsonl(str(wikisql_dev_path))
    print(f"\n WikiSQL Dataset:")
    print(f"   Train samples: {len(wikisql_train_data):,}")
    print(f"   Dev samples:   {len(wikisql_dev_data):,}")
    print(f"   Columns: {list(wikisql_train_data[0].keys())}")
else:
    print(f"\n  WikiSQL not found at {wikisql_train_path}")
    wikisql_train_data = []
    wikisql_dev_data = []

# Spider
spider_train_path = data_dir / "spider_train.jsonl"
spider_dev_path = data_dir / "spider_dev.jsonl"

if spider_train_path.exists():
    spider_train_data = load_jsonl(str(spider_train_path))
    spider_dev_data = load_jsonl(str(spider_dev_path))
    print(f"\n Spider Dataset:")
    print(f"   Train samples: {len(spider_train_data):,}")
    print(f"   Dev samples:   {len(spider_dev_data):,}")
    print(f"   Columns: {list(spider_train_data[0].keys())}")

    # Count multi-table and JOIN queries
    multi_table = sum(1 for ex in spider_train_data if ex.get('num_tables', 1) > 1)
    has_join = sum(1 for ex in spider_train_data if ex.get('has_join', False))
    print(f"   Multi-table: {multi_table:,} ({100*multi_table/len(spider_train_data):.1f}%)")
    print(f"   With JOIN:   {has_join:,} ({100*has_join/len(spider_train_data):.1f}%)")
else:
    print(f"\n  Spider not found at {spider_train_path}")
    spider_train_data = []
    spider_dev_data = []

print("\n" + "=" * 60)

DATASET INFORMATION

 WikiSQL Dataset:
   Train samples: 4,293
   Dev samples:   999
   Columns: ['input', 'schema', 'question', 'output', 'sql', 'dataset', 'table_id', 'split']

 Spider Dataset:
   Train samples: 7,000
   Dev samples:   1,034
   Columns: ['input', 'schema', 'question', 'output', 'sql', 'dataset', 'db_id', 'split', 'num_tables', 'has_join']
   Multi-table: 7,000 (100.0%)
   With JOIN:   2,783 (39.8%)



In [82]:
# Show sample examples
print("=" * 60)
print("SAMPLE EXAMPLES")
print("=" * 60)

if wikisql_train_data:
    print("\n WikiSQL Example:")
    example = wikisql_train_data[0]
    print(f"Question: {example.get('question', 'N/A')}")
    print(f"SQL: {example.get('sql', 'N/A')}")
    print(f"\nSchema preview:")
    schema = example.get('schema', 'N/A')
    print(schema[:500] + "..." if len(schema) > 500 else schema)

if spider_train_data:
    print("\n" + "-" * 60)
    print("\n Spider Example:")
    example = spider_train_data[0]
    print(f"Question: {example.get('question', 'N/A')}")
    print(f"SQL: {example.get('sql', 'N/A')}")
    print(f"Database: {example.get('db_id', 'N/A')}")
    print(f"\nSchema preview:")
    schema = example.get('schema', 'N/A')
    print(schema[:500] + "..." if len(schema) > 500 else schema)

SAMPLE EXAMPLES

 WikiSQL Example:
Question: What is the total of Barangay with an area larger than 865.13?
SQL: SELECT SUM("s Barangay") FROM "table" WHERE "Area ( has .)" > 865.13

Schema preview:
[TABLES]
Manila:
    District (PK)
    s Barangay
    Population (2010 census)
    Area ( has .)
    Pop. density (per km2)

------------------------------------------------------------

 Spider Example:
Question: How many heads of the departments are older than 56 ?
SQL: SELECT count(*) FROM head WHERE age  >  56
Database: department_management

Schema preview:
[DATABASE]
department_management

[TABLES]
department:
    Department_ID (PK)
    Name
    Creation
    Ranking
    Budget_in_Billions
    Num_Employees
head:
    head_ID (PK)
    name
    born_state
    age
management:
    department_ID (PK, FK)
    head_ID (FK)
    temporary_acting

[FOREIGN KEYS]
management.head_ID -> head.head_ID
management.department_ID -> department.Department_ID

[SEMANTIC LINKS]
department.Name ≈ head.name


### Load Model and Tokenizer

In [83]:
print("=" * 60)
print("LOADING MODEL")
print("=" * 60)
print(f"\nModel: {MODEL_NAME}")

# Configure quantization
quantization_config = None
if LOAD_IN_4BIT:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    print("Using 4-bit quantization (QLoRA)")
elif LOAD_IN_8BIT:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
    )
    print("Using 8-bit quantization")

# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right",
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Load model
print("Loading model (this may take a few minutes)...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16 if USE_BF16 else torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Prepare for k-bit training
if LOAD_IN_4BIT or LOAD_IN_8BIT:
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=GRADIENT_CHECKPOINTING
    )

print("\n✓ Model loaded!")

LOADING MODEL

Model: Qwen/Qwen2.5-7B-Instruct
Using 4-bit quantization (QLoRA)

Loading tokenizer...



Loading model (this may take a few minutes)...




✓ Model loaded!
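
To see why 4-bit loading matters for a model of this size, here is a back-of-the-envelope estimate of weight memory. It is a rough sketch only: the figure of 7.6 billion parameters is an approximation, and the estimate ignores quantization constants, layers that stay unquantized, activations, and optimizer state. The helper name `weight_memory_gb` is an assumption.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate storage for model weights in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


approx_params = 7.6e9  # approximate parameter count of a ~7B model

bf16_gb = weight_memory_gb(approx_params, 16)
nf4_gb = weight_memory_gb(approx_params, 4)
print(f"bf16 weights:  ~{bf16_gb:.1f} GB")
print(f"4-bit weights: ~{nf4_gb:.1f} GB")
```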


### Apply LoRA

In [89]:
print("=" * 60)
print("APPLYING LoRA")
print("=" * 60)

print(f"\n LoRA Configuration:")
print(f"   Rank (r):        {LORA_R}")
print(f"   Alpha:           {LORA_ALPHA}")
print(f"   Dropout:         {LORA_DROPOUT}")
print(f"   Target modules:  {LORA_TARGET_MODULES}")

lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, lora_config)

print("\n LoRA applied!")

APPLYING LoRA

 LoRA Configuration:
   Rank (r):        32
   Alpha:           64
   Dropout:         0.05
   Target modules:  ['q_proj', 'k_proj', 'v_proj', 'o_proj']





 LoRA applied!


### Model Statistics

In [90]:
print("=" * 60)
print("MODEL PARAMETERS (AFTER LoRA)")
print("=" * 60)

# Use PEFT's built-in method
model.print_trainable_parameters()

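# Helpers for the manual count
def count_parameters(model):
    """Return (total, trainable) parameter counts based on p.numel()."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


def format_params(n):
    """Format a parameter count as a short string such as 20.19M or 4.37B."""
    if n >= 1e9:
        return f"{n / 1e9:.2f}B"
    if n >= 1e6:
        return f"{n / 1e6:.2f}M"
    return f"{n:,}"


# Also compute manually for more detail.
# Note: 4-bit weights are stored packed (two values per element), so p.numel()
# undercounts them; this total is therefore lower than the figure PEFT prints.
total_params, trainable_params = count_parameters(model)
frozen_params = total_params - trainable_params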

print(f"\n📊 Detailed Parameter Count:")
print(f"   Total parameters:     {total_params:>15,} ({format_params(total_params)})")
print(f"   Trainable parameters: {trainable_params:>15,} ({format_params(trainable_params)})")
print(f"   Frozen parameters:    {frozen_params:>15,} ({format_params(frozen_params)})")
print(f"   Trainable %:          {100 * trainable_params / total_params:>14.4f}%")

# LoRA parameter breakdown
print(f"\n🔧 LoRA Adapter Size:")
lora_params = sum(p.numel() for n, p in model.named_parameters() if 'lora' in n.lower())
print(f"   LoRA parameters:      {lora_params:>15,} ({format_params(lora_params)})")

MODEL PARAMETERS (AFTER LoRA)
trainable params: 20,185,088 || all params: 7,635,801,600 || trainable%: 0.2643

📊 Detailed Parameter Count:
   Total parameters:       4,373,157,376 (4.37B)
   Trainable parameters:      20,185,088 (20.19M)
   Frozen parameters:      4,352,972,288 (4.35B)
   Trainable %:                  0.4616%

🔧 LoRA Adapter Size:
   LoRA parameters:           20,185,088 (20.19M)
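
The numbers in this output can be reproduced by hand. The sketch below assumes the layer dimensions from the published Qwen2.5-7B configuration (hidden size 3584, MLP size 18944, 28 layers, 4 key/value heads of dimension 128); those values are not printed anywhere in this notebook, so treat them as assumptions. It also assumes the gap between the two totals comes from 4-bit weights being stored two per element, so that `numel()` sees only half of each quantized linear layer.

```python
# Architecture values assumed from the published Qwen2.5-7B config
HIDDEN = 3584
INTERMEDIATE = 18944
NUM_LAYERS = 28
KV_DIM = 4 * 128      # 4 key/value heads x head dimension 128
R = 32                # LORA_R from the configuration cell

# (in_features, out_features) of each linear layer in one decoder block
attention = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}
mlp = {
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

# Each adapted layer adds A (r x in) and B (out x r): r * (in + out) parameters
lora_params = NUM_LAYERS * sum(R * (i + o) for i, o in attention.values())

# Weights held in 4-bit form (attention and MLP linear layers in every block)
quantized_weights = NUM_LAYERS * sum(i * o for i, o in {**attention, **mlp}.values())

peft_total = 7_635_801_600                        # "all params" reported by PEFT above
numel_total = peft_total - quantized_weights // 2 # packed storage halves the count

print(f"LoRA parameters: {lora_params:,}")
print(f"numel()-based total: {numel_total:,}")
```

Under these assumptions both figures match the cell output: 20,185,088 trainable LoRA parameters and a `numel()`-based total of 4,373,157,376. The 0.2643% reported by PEFT is therefore the more meaningful trainable fraction.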


### Prepare Dataset

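Each example is rendered as a ChatML conversation (the chat format used by Qwen2.5-Instruct): a system message, the example's `input` as the user turn, and the target SQL as the assistant turn.

Each dataset record should have:
- `input`: Model input text, used verbatim as the user turn
- `sql`: Target SQL query

The `question` and `schema` fields are kept in the files for inspection but are not read by the preprocessing code.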

In [91]:
def create_prompt(example: dict) -> str:
    """Create training prompt from example."""
    system_msg = (
        "You are a SQL expert. Given a database schema and a natural language question, "
        "generate the correct SQL query. Output only the SQL query."
    )

    user_input = example.get("input", "")
    sql_output = example.get("sql", example.get("output", ""))

    # Remove [SQL] prefix if present
    if sql_output.startswith("[SQL]\n"):
        sql_output = sql_output[6:]

    prompt = f"""<|im_start|>system
{system_msg}<|im_end|>
<|im_start|>user
{user_input}<|im_end|>
<|im_start|>assistant
{sql_output}<|im_end|>"""

    return prompt


def preprocess_function(examples, tokenizer, max_length):
    """Preprocess examples for training."""
    prompts = [create_prompt({"input": inp, "sql": sql})
               for inp, sql in zip(examples["input"], examples["sql"])]

    tokenized = tokenizer(
        prompts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors=None,
    )

    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized


def prepare_dataset(data: list, tokenizer, max_length: int, max_samples: int = None, desc: str = "data"):
    """Prepare dataset for training."""
    if max_samples:
        data = data[:max_samples]

    dataset = Dataset.from_list(data)

    processed = dataset.map(
        lambda x: preprocess_function(x, tokenizer, max_length),
        batched=True,
        remove_columns=dataset.column_names,
        desc=f"Tokenizing {desc}"
    )

    return processed

print("✓ Data preprocessing functions defined!")

✓ Data preprocessing functions defined!
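
To see what one training string looks like, the sketch below renders a single example in the same ChatML layout that `create_prompt` builds. It is a self-contained re-creation for illustration: the helper name `render_chatml` is hypothetical, and the content of `example_input` is an assumed shape for the `input` field, whose exact format is not shown in this notebook.

```python
SYSTEM_MSG = (
    "You are a SQL expert. Given a database schema and a natural language question, "
    "generate the correct SQL query. Output only the SQL query."
)


def render_chatml(user_input: str, sql_output: str) -> str:
    """Render one (input, SQL) pair as a three-turn ChatML string."""
    return (
        f"<|im_start|>system\n{SYSTEM_MSG}<|im_end|>\n"
        f"<|im_start|>user\n{user_input}<|im_end|>\n"
        f"<|im_start|>assistant\n{sql_output}<|im_end|>"
    )


# Assumed shape of an `input` value: schema text followed by the question
example_input = (
    "[TABLES]\nhead:\n    head_ID (PK)\n    age\n\n"
    "Question: How many heads of the departments are older than 56 ?"
)
text = render_chatml(example_input, "SELECT count(*) FROM head WHERE age  >  56")
print(text)
```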


### Initialize Wandb

In [92]:
if USE_WANDB:
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = WANDB_RUN_NAME or f"nl2sql_{timestamp}"

    wandb.init(
        project=WANDB_PROJECT,
        name=run_name,
        config={
            "model": MODEL_NAME,
            "lora_r": LORA_R,
            "lora_alpha": LORA_ALPHA,
            "lora_dropout": LORA_DROPOUT,
            "batch_size": BATCH_SIZE,
            "gradient_accumulation": GRADIENT_ACCUMULATION,
            "learning_rate": LEARNING_RATE,
            "max_seq_length": MAX_SEQ_LENGTH,
            "wikisql_epochs": WIKISQL_EPOCHS,
            "spider_epochs": SPIDER_EPOCHS,
        }
    )
    print(f"✓ WandB initialized! Run: {run_name}")
else:
    run_name = "nl2sql_training"
    print("ℹ️  WandB disabled")

Run history:
train/epoch          ▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train/global_step    ▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train/grad_norm      ▇█▇▄▃▃▃▂▂▂▂▂▂▂▁▂▂▁▁▁
train/learning_rate  ▁▂▃▃▄▅▆▆▇███████████
train/loss           ██▇▆▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁

Run summary:
train/epoch          0.3047
train/global_step    200.0
train/grad_norm      0.26317
train/learning_rate  0.00028
train/loss           0.2484


✓ WandB initialized! Run: nl2sql_20251206_125353


### Phase 1: WikiSQL Warmup Training

In [96]:
if wikisql_train_data and WIKISQL_EPOCHS > 0:
    print("=" * 60)
    print("PHASE 1: WikiSQL Warmup")
    print("=" * 60)

    # Prepare datasets
    print("\nPreparing WikiSQL data...")
    wikisql_train_dataset = prepare_dataset(
        wikisql_train_data, tokenizer, MAX_SEQ_LENGTH, MAX_TRAIN_SAMPLES, "WikiSQL train"
    )
    wikisql_eval_dataset = prepare_dataset(
        wikisql_dev_data, tokenizer, MAX_SEQ_LENGTH, MAX_EVAL_SAMPLES, "WikiSQL eval"
    )

    print(f"\n📊 WikiSQL Dataset Ready:")
    print(f"   Train: {len(wikisql_train_dataset):,} samples")
    print(f"   Eval:  {len(wikisql_eval_dataset):,} samples")

    # Training arguments
    phase1_output = f"{OUTPUT_DIR}/phase1_wikisql"

    training_args = TrainingArguments(
        output_dir=phase1_output,
        num_train_epochs=WIKISQL_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type=LR_SCHEDULER,
        warmup_ratio=WARMUP_RATIO,
        weight_decay=0.01,
        logging_steps=10,
        eval_strategy="epoch",
        eval_steps=200,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=SAVE_TOTAL_LIMIT,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=USE_BF16,
        fp16=USE_FP16,
        gradient_checkpointing=GRADIENT_CHECKPOINTING,
        gradient_checkpointing_kwargs={"use_reentrant": False} if GRADIENT_CHECKPOINTING else None,
        report_to="wandb" if USE_WANDB else "none",
        run_name=f"{run_name}_phase1",
        remove_unused_columns=False,
    )

    # Data collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=wikisql_train_dataset,
        eval_dataset=wikisql_eval_dataset,
        data_collator=data_collator,
        processing_class=tokenizer,  # replaces the deprecated tokenizer= argument
    )

    print("\n🚀 Starting Phase 1 training...")
    trainer.train()

    # Save
    print(f"\n💾 Saving Phase 1 model to {phase1_output}/final")
    trainer.save_model(f"{phase1_output}/final")
    tokenizer.save_pretrained(f"{phase1_output}/final")

    print("\n✓ Phase 1 complete!")
else:
    print("⏭️  Skipping Phase 1 (WikiSQL not available or epochs=0)")

PHASE 1: WikiSQL Warmup

Preparing WikiSQL data...


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.



📊 WikiSQL Dataset Ready:
   Train: 4,293 samples
   Eval:  500 samples

🚀 Starting Phase 1 training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 
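
Since every sequence is padded to `MAX_SEQ_LENGTH`, it is worth understanding what happens to padding positions in the loss. With `mlm=False`, `DataCollatorForLanguageModeling` rebuilds the labels from `input_ids` and sets every position equal to the pad token id to -100, which the loss ignores. The sketch below mimics that step in plain Python. The pad id 151643 comes from the warning in the output above; the other token ids are made-up placeholders, and `mask_padding` is a hypothetical helper name.

```python
IGNORE_INDEX = -100   # label value skipped by the cross-entropy loss
PAD_TOKEN_ID = 151643 # pad token id reported in the Trainer warning above


def mask_padding(input_ids, pad_token_id=PAD_TOKEN_ID):
    """Copy input_ids into labels, replacing padding with IGNORE_INDEX."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


# Placeholder ids for a short sequence padded on the right to length 6
input_ids = [101, 202, 303, 404, PAD_TOKEN_ID, PAD_TOKEN_ID]
labels = mask_padding(input_ids)
print(labels)
```

Note that the labels still cover the system and user turns, so the loss is computed over the whole conversation rather than over the SQL answer alone.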

### Phase 2: Spider Main Training

In [None]:
if spider_train_data and SPIDER_EPOCHS > 0:
    print("=" * 60)
    print("PHASE 2: Spider Main Training")
    print("=" * 60)

    # Prepare datasets
    print("\nPreparing Spider data...")
    spider_train_dataset = prepare_dataset(
        spider_train_data, tokenizer, MAX_SEQ_LENGTH, MAX_TRAIN_SAMPLES, "Spider train"
    )
    spider_eval_dataset = prepare_dataset(
        spider_dev_data, tokenizer, MAX_SEQ_LENGTH, MAX_EVAL_SAMPLES, "Spider eval"
    )

    print(f"\n Spider Dataset Ready:")
    print(f"   Train: {len(spider_train_dataset):,} samples")
    print(f"   Eval:  {len(spider_eval_dataset):,} samples")

    # Training arguments
    phase2_output = f"{OUTPUT_DIR}/phase2_spider"

    training_args = TrainingArguments(
        output_dir=phase2_output,
        num_train_epochs=SPIDER_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type=LR_SCHEDULER,
        warmup_ratio=WARMUP_RATIO,
        weight_decay=0.01,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=200,
        save_strategy=SAVE_STRATEGY,
        save_total_limit=SAVE_TOTAL_LIMIT,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=USE_BF16,
        fp16=USE_FP16,
        gradient_checkpointing=GRADIENT_CHECKPOINTING,
        gradient_checkpointing_kwargs={"use_reentrant": False} if GRADIENT_CHECKPOINTING else None,
        report_to="wandb" if USE_WANDB else "none",
        run_name=f"{run_name}_phase2",
        remove_unused_columns=False,
    )

    # Data collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=spider_train_dataset,
        eval_dataset=spider_eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    print("\n Starting Phase 2 training...")
    trainer.train()

    # Save
    print(f"\n Saving Phase 2 model to {phase2_output}/final")
    trainer.save_model(f"{phase2_output}/final")
    tokenizer.save_pretrained(f"{phase2_output}/final")

    print("\n Phase 2 complete!")
else:
    print("Skipping Phase 2 (Spider not available or epochs=0)")
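
The interaction between `per_device_train_batch_size`, `gradient_accumulation_steps`, `warmup_ratio`, and `eval_steps` above can be easier to reason about with a rough estimate. The sketch below is only an approximation (the exact step count computed by `Trainer` can differ slightly by library version), and every number in it is hypothetical rather than taken from this notebook's configuration.

```python
import math


def estimate_schedule(num_samples, batch_size, grad_accum, epochs,
                      warmup_ratio, eval_steps, num_devices=1):
    """Roughly estimate optimizer steps, warmup steps, and evaluation count."""
    # One optimizer step consumes this many samples.
    effective_batch = batch_size * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_samples / effective_batch)
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
        # Step-based evaluation fires once every `eval_steps` optimizer steps.
        "num_evals": total_steps // eval_steps,
    }


# Hypothetical values, not the notebook's actual configuration.
schedule = estimate_schedule(
    num_samples=7000, batch_size=4, grad_accum=4,
    epochs=3, warmup_ratio=0.05, eval_steps=200,
)
print(schedule)
```

If `total_steps` comes out smaller than `eval_steps`, step-based evaluation never runs during training, so a smaller `eval_steps` may be worth considering for short runs.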

In [None]:
if USE_WANDB:
    wandb.finish()
    print("✓ WandB run finished!")

print("\n" + "=" * 60)
print("🎉 TRAINING COMPLETE!")
print("=" * 60)
print(f"\nCheckpoints saved to: {OUTPUT_DIR}")
print(f"\nTo test the model, run:")
print(f"  python inference.py --model {OUTPUT_DIR}/phase2_spider/final")

### Finish Training

In [None]:
# Save the final model
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Model saved to {OUTPUT_DIR}")

# Finish the WandB run only if WandB logging was enabled
if USE_WANDB:
    wandb.finish()
    print("WandB run finished. Check your dashboard for loss graphs!")

Model saved to ./t5_wikisql_lora_finetuned


Run history:
  train/epoch        ▁
  train/global_step  ▁

Run summary:
  total_flos                9297442897920.0
  train/epoch               3.0
  train/global_step         3.0
  train_loss                3.20488
  train_runtime             3.245
  train_samples_per_second  4.622
  train_steps_per_second    0.924


WandB run finished. Check your dashboard for loss graphs!


### Download to Local

In [None]:
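import os
import shutil
from pathlib import Path

DOWNLOAD_PATH = "./checkpoints/phase2_spider/final"

def zip_and_download(model_path):
    """Zip the adapter directory, then download it (Colab) or print its path."""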
    model_path = Path(model_path)
    if not model_path.exists():
        print(f" Path not found: {model_path}")
        return

    zip_name = f"{model_path.name}_lora_adapter"
    zip_path = f"/tmp/{zip_name}"

    print(f" Creating zip: {model_path}")
    shutil.make_archive(zip_path, 'zip', model_path.parent, model_path.name)

    zip_file = f"{zip_path}.zip"
    print(f"   Size: {os.path.getsize(zip_file)/(1024*1024):.1f} MB")

    try:
        from google.colab import files
        print(" Downloading...")
        files.download(zip_file)
    except ImportError:
        print(f"\n Download manually: {zip_file}")

zip_and_download(DOWNLOAD_PATH)
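
To see what the archive produced by `shutil.make_archive` contains, the sketch below zips a throwaway directory that mimics the adapter layout and lists the entries. The directory names and the empty `adapter_config.json` are placeholders created only for this illustration.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Placeholder adapter directory: <tmp>/phase2_spider/final
    adapter_dir = Path(tmp) / "phase2_spider" / "final"
    adapter_dir.mkdir(parents=True)
    (adapter_dir / "adapter_config.json").write_text("{}")

    # root_dir is the parent and base_dir is the folder name, so every
    # entry in the zip is prefixed with "final/".
    archive_base = Path(tmp) / "final_lora_adapter"
    zip_file = shutil.make_archive(
        str(archive_base), "zip",
        root_dir=str(adapter_dir.parent), base_dir=adapter_dir.name,
    )

    with zipfile.ZipFile(zip_file) as zf:
        archived_names = zf.namelist()

print(archived_names)
```

Keeping the top-level folder inside the archive means that unzipping it recreates a `final/` directory that can be passed directly as an adapter path.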

## Quick Test

In [None]:
# Quick test with the trained model
def generate_sql(question: str, schema: str, max_new_tokens: int = 256):
    """Generate SQL from question and schema."""
    system_msg = (
        "You are a SQL expert. Given a database schema and a natural language question, "
        "generate the correct SQL query. Output only the SQL query."
    )

    instruction = (
        "Given the following database schema and question, "
        "generate the SQL query that answers the question."
    )
    user_input = f"{instruction}\n\n{schema}\n\n[QUESTION]\n{question}"

    prompt = f"""<|im_start|>system
{system_msg}<|im_end|>
<|im_start|>user
{user_input}<|im_end|>
<|im_start|>assistant
[SQL]
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

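    with torch.no_grad():
        # Greedy decoding: with do_sample=False, temperature has no effect,
        # so it is omitted to avoid a generation-config warning.
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )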

    generated = tokenizer.decode(outputs[0], skip_special_tokens=False)

    # Extract SQL
    if "[SQL]" in generated:
        sql_part = generated.split("[SQL]")[-1]
        if "<|im_end|>" in sql_part:
            sql_part = sql_part.split("<|im_end|>")[0]
        return sql_part.strip()

    return generated.strip()

# Test example
test_schema = """[TABLES]
student:
    id (PK)
    name
    age
    major"""

test_question = "How many students are majoring in Computer Science?"

print("=" * 60)
print("QUICK TEST")
print("=" * 60)
print(f"\nQuestion: {test_question}")
print(f"\nSchema:\n{test_schema}")
print(f"\nGenerated SQL:")
sql = generate_sql(test_question, test_schema)
print(sql)
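
The extraction step inside `generate_sql` can be checked without loading a model. The sketch below isolates the same logic (take the text after the last `[SQL]` marker and cut at `<|im_end|>`) into a standalone helper; the name `extract_sql` and the sample string are illustrative only.

```python
def extract_sql(generated: str, marker: str = "[SQL]",
                end_token: str = "<|im_end|>") -> str:
    """Return the text after the last marker, cut at the end token."""
    _, found, tail = generated.rpartition(marker)
    if not found:
        # No marker in the output: fall back to the whole string.
        return generated.strip()
    return tail.split(end_token, 1)[0].strip()


# Illustrative decoded output, written by hand rather than produced by a model.
sample = (
    "<|im_start|>assistant\n[SQL]\n"
    "SELECT count(*) FROM student<|im_end|>\n<|endoftext|>"
)
print(extract_sql(sample))
```

Because the decoded sequence includes the prompt, splitting on the last marker is what keeps the prompt text out of the returned query.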

## Loading the Finetuned Model and Testing

In [None]:
# =============================================================================
# CONFIGURATION - Modify this path to your checkpoint
# =============================================================================
CHECKPOINT_PATH = "./checkpoints/phase2_spider/final"

# =============================================================================

def load_finetuned_model(adapter_path, base_model_name=None):
    adapter_path = Path(adapter_path)

    if base_model_name is None:
        config_path = adapter_path / "adapter_config.json"
        if config_path.exists():
            with open(config_path) as f:
                base_model_name = json.load(f).get("base_model_name_or_path", MODEL_NAME)
        else:
            base_model_name = MODEL_NAME

    print(f"Loading base model: {base_model_name}")

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,
    ) if LOAD_IN_4BIT else None

    tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name, quantization_config=quant_config,
        torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True,
    )

    model = PeftModel.from_pretrained(base_model, str(adapter_path))
    model.eval()
    print("✓ Model loaded!")
    return model, tokenizer

if Path(CHECKPOINT_PATH).exists():
    eval_model, eval_tokenizer = load_finetuned_model(CHECKPOINT_PATH)
else:
    print(f"⚠️ Checkpoint not found: {CHECKPOINT_PATH}")
    print("Using currently loaded model...")
    eval_model, eval_tokenizer = model, tokenizer
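
The base-model lookup in `load_finetuned_model` relies on the `base_model_name_or_path` field that PEFT writes into `adapter_config.json`. The sketch below reproduces only that lookup against a temporary file, so it can be run without downloading any weights; the helper name and the model identifiers are placeholders.

```python
import json
import tempfile
from pathlib import Path


def resolve_base_model(adapter_path, default_name):
    """Read the base model name from adapter_config.json, else use the default."""
    config_path = Path(adapter_path) / "adapter_config.json"
    if not config_path.exists():
        return default_name
    with open(config_path) as f:
        # Fall back to the default if the key is missing or empty.
        return json.load(f).get("base_model_name_or_path") or default_name


with tempfile.TemporaryDirectory() as tmp:
    # Without a config file, the default is returned.
    missing = resolve_base_model(tmp, "example-org/default-model")

    # Placeholder config written only for this illustration.
    config = {"base_model_name_or_path": "example-org/base-model", "r": 16}
    (Path(tmp) / "adapter_config.json").write_text(json.dumps(config))
    found = resolve_base_model(tmp, "example-org/default-model")

print(missing, found)
```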

## Evaluate on Spider Dev Set

In [None]:
EVAL_MAX_SAMPLES = 100
EVAL_MAX_NEW_TOKENS = 256

def generate_sql_for_eval(model, tokenizer, question, schema):
    system_msg = "You are a SQL expert. Generate the correct SQL query."
    instruction = "Given the database schema and question, generate the SQL query."
    user_input = f"{instruction}\n\n{schema}\n\n[QUESTION]\n{question}"
    prompt = f"""<|im_start|>system\n{system_msg}<|im_end|>\n<|im_start|>user\n{user_input}<|im_end|>\n<|im_start|>assistant\n[SQL]\n"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Greedy decoding; temperature is ignored when do_sample=False, so it is omitted.
        outputs = model.generate(
            **inputs,
            max_new_tokens=EVAL_MAX_NEW_TOKENS,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    generated = tokenizer.decode(outputs[0], skip_special_tokens=False)
    if "[SQL]" in generated:
        sql_part = generated.split("[SQL]")[-1]
        if "<|im_end|>" in sql_part:
            sql_part = sql_part.split("<|im_end|>")[0]
        return sql_part.strip()
    return generated.strip()

def evaluate_model(model, tokenizer, eval_data, max_samples=None):
    if max_samples: eval_data = eval_data[:max_samples]
    results, correct = [], 0

    print(f"Evaluating on {len(eval_data)} samples...")
    for i, ex in enumerate(eval_data):
        pred_sql = generate_sql_for_eval(model, tokenizer, ex["question"], ex["schema"])
        gold_norm = " ".join(ex["sql"].lower().split())
        pred_norm = " ".join(pred_sql.lower().split())
        is_match = gold_norm == pred_norm
        if is_match: correct += 1
        results.append({"question": ex["question"], "gold": ex["sql"], "pred": pred_sql, "match": is_match})
        if (i+1) % 10 == 0:
            print(f"  [{i+1}/{len(eval_data)}] Accuracy: {100*correct/(i+1):.1f}%")

    return {"accuracy": 100*correct/len(eval_data), "correct": correct, "total": len(eval_data), "results": results}

if spider_dev_data:
    eval_results = evaluate_model(eval_model, eval_tokenizer, spider_dev_data, EVAL_MAX_SAMPLES)
    print(f"\n Exact Match Accuracy: {eval_results['accuracy']:.2f}%")
    print(f"   Correct: {eval_results['correct']} / {eval_results['total']}")
else:
    print("Spider dev data not loaded")
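
The exact-match metric above compares lowercased, whitespace-collapsed strings. The sketch below shows that normalization in isolation, with one extra assumption that is not in the notebook: a trailing semicolon is also stripped. The helper names and the sample pairs are illustrative.

```python
def normalize_sql(sql: str) -> str:
    """Lowercase, drop a trailing semicolon, and collapse whitespace."""
    sql = sql.strip().rstrip(";")
    return " ".join(sql.lower().split())


def exact_match_accuracy(pairs) -> float:
    """Percentage of (gold, predicted) pairs that match after normalization."""
    if not pairs:
        return 0.0
    hits = sum(normalize_sql(gold) == normalize_sql(pred) for gold, pred in pairs)
    return 100 * hits / len(pairs)


# Hand-written pairs, not model outputs.
pairs = [
    ("SELECT count(*) FROM student", "select  count(*)\nfrom student;"),
    ("SELECT name FROM student", "SELECT age FROM student"),
]
print(exact_match_accuracy(pairs))
```

String-level exact match is strict: semantically equivalent queries with different aliases or clause order count as misses, and lowercasing also alters string literals, so the reported accuracy is best read as a lower bound.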

## Interactive Testing

In [None]:
TEST_SCHEMA = """[DATABASE]\nuniversity\n\n[TABLES]\nstudent:\n    student_id (PK)\n    name\n    age\n    department_id (FK)\ncourse:\n    course_id (PK)\n    title\n    credits"""

TEST_QUESTIONS = [
    "How many students are there?",
    "What are the names of students in Computer Science?",
    "List all courses with more than 3 credits.",
]

print("=" * 60)
print("INTERACTIVE TESTING")
print("=" * 60)

for i, q in enumerate(TEST_QUESTIONS):
    print(f"\n[{i+1}] Q: {q}")
    sql = generate_sql_for_eval(eval_model, eval_tokenizer, q, TEST_SCHEMA)
    print(f"    SQL: {sql}")
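
Writing the schema string by hand with `\n` escapes is easy to get wrong. The sketch below builds the same `[DATABASE]` / `[TABLES]` layout used by `TEST_SCHEMA` from a dictionary; `format_schema` is a hypothetical helper, not part of the notebook or the project repository.

```python
def format_schema(db_name, tables):
    """Render a schema dict in the [DATABASE]/[TABLES] text layout."""
    lines = ["[DATABASE]", db_name, "", "[TABLES]"]
    for table, columns in tables.items():
        lines.append(f"{table}:")
        # Columns are indented by four spaces under their table name.
        lines.extend(f"    {column}" for column in columns)
    return "\n".join(lines)


schema_text = format_schema(
    "university",
    {
        "student": ["student_id (PK)", "name", "age", "department_id (FK)"],
        "course": ["course_id (PK)", "title", "credits"],
    },
)
print(schema_text)
```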

## Tips for Better Results

1. **Model Comparison**:
   - `mrm8488/t5-base-finetuned-wikiSQL`: Smaller (~220M), faster, good for simple queries
   - `gaussalgo/T5-LM-Large-text2sql-spider`: Larger (~770M), better for complex queries

2. **Input Format**: The WikiSQL model expects `translate English to SQL: <question> context: <columns>`.

3. **More Data**: The sample dataset is tiny. Use hundreds or thousands of examples for real training.

4. **LoRA Rank**: Start with r=16. Increase to 32 or 64 if underfitting, decrease to 8 if overfitting.

5. **Learning Rate**: 1e-4 to 3e-4 typically works well for LoRA.

6. **Batch Size**: T5-base can use larger batches (8-16) than T5-large (4-8).
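
To make the rank trade-off in tip 4 more concrete, the sketch below counts the parameters LoRA adds at several ranks. It assumes LoRA is applied only to the query and value projections of T5-base (768×768 matrices, 12 encoder layers with self-attention and 12 decoder layers with self- and cross-attention, so 72 matrices in total); the target modules actually configured in this notebook may differ.

```python
def lora_added_params(rank, d_in, d_out, num_matrices):
    """Each adapted matrix gains A (rank x d_in) and B (d_out x rank)."""
    return num_matrices * rank * (d_in + d_out)


# Assumed T5-base layout: 12 encoder + 24 decoder attention blocks,
# with LoRA on the q and v projections only.
NUM_ATTENTION_BLOCKS = 12 + 2 * 12
NUM_MATRICES = 2 * NUM_ATTENTION_BLOCKS  # q and v per block
D_MODEL = 768

counts = {
    rank: lora_added_params(rank, D_MODEL, D_MODEL, NUM_MATRICES)
    for rank in (8, 16, 32, 64)
}
for rank, count in counts.items():
    print(f"r={rank:>2}: {count:,} trainable LoRA parameters")
```

Under these assumptions the adapter size grows linearly with the rank, and even r=64 stays at a few percent of the roughly 220M base parameters.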