In [1]:
from unsloth import FastLanguageModel
import torch
from huggingface_hub import login
from datasets import load_dataset
from dotenv import load_dotenv
import os
import wandb
from datetime import datetime
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
from transformers import EarlyStoppingCallback
from transformers.trainer_callback import TrainerCallback
import json
from tqdm import tqdm
import numpy as np
from unsloth import is_bfloat16_supported
import gc

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth Zoo will now patch everything to make training faster!


In [2]:
# Load variables from a .env file into the environment, then authenticate with
# the Hugging Face Hub using the HF_TOKEN_WRITE variable so the token is never
# written into the notebook itself.
# NOTE(review): os.getenv returns None when HF_TOKEN_WRITE is unset; presumably
# login() then falls back to an interactive prompt — confirm before running unattended.
load_dotenv()
login(os.getenv("HF_TOKEN_WRITE"))

In [3]:
# Constants shared by the cells below.

# Model
BASE_MODEL = "unsloth/Qwen2.5-7B-instruct"  # checkpoint id handed to FastLanguageModel.from_pretrained
MAX_SEQ_LENGTH = 4096  # passed to the model loader, SFTConfig and SFTTrainer
DTYPE = torch.bfloat16
LOAD_IN_4BIT = True

# Project
HF_USER = "Yihim"  # Hub namespace used in the dataset and model repo ids
PROJECT_NAME = "sql_expert"  # Weights & Biases project name
RUN_NAME = f"{datetime.now():%Y-%m-%d_%H.%M.%S}"  # timestamp taken when this cell runs; re-running the cell changes it
PROJECT_RUN_NAME = f"{PROJECT_NAME}--{RUN_NAME}"  # used as the trainer's output_dir

# LoRA
LORA_R = 16
LORA_ALPHA = 32
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
# NOTE(review): the Unsloth log printed after get_peft_model says only dropout = 0
# gets fast patching and reports 0 patched QKV/O/MLP layers — confirm 0.1 is intended.
LORA_DROPOUT = 0.1

# Training
MAX_STEPS = 500  # NOTE(review): not referenced by any visible cell; SFTConfig sets num_train_epochs instead
TRAIN_BATCH_SIZE = 4
EVAL_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 2e-5
LR_SCHEDULER_TYPE = "cosine"
WARMUP_RATIO = 0.1
WARMUP_STEPS = 500  # NOTE(review): not referenced by any visible cell; SFTConfig uses WARMUP_RATIO
LOG_STEPS = 1
SAVE_STEPS = 2000  # the trainer evaluates every 500 steps, so a checkpoint is written every fourth evaluation

In [4]:
# Load the base checkpoint and its tokenizer through Unsloth; every setting
# comes from the constants cell.
model, tokenizer = FastLanguageModel.from_pretrained(
    load_in_4bit=LOAD_IN_4BIT,
    dtype=DTYPE,
    max_seq_length=MAX_SEQ_LENGTH,
    model_name=BASE_MODEL,
)

  GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f"cuda:{i}") for i in range(n_gpus)])


==((====))==  Unsloth 2025.3.8: Fast Qwen2 patching. Transformers: 4.49.0. vLLM: 0.7.3.
   \\   /|    NVIDIA GeForce RTX 4080 SUPER. Num GPUs = 1. Max memory: 15.992 GB. Platform: Windows.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 2/2 [00:08<00:00,  4.30s/it]


In [5]:
model.get_memory_footprint() / 1e9

7.075802112

In [6]:
model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584, padding_idx=151654)
    (layers): ModuleList(
      (0-1): 2 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear4bit(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear4bit(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear4bit(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)

In [7]:
# Padding setup.
# Fall back to EOS as the pad token only when the tokenizer has no pad token of
# its own. The language-modelling collator used for training sets the label of
# every token equal to pad_token_id to -100, so aliasing pad to EOS would also
# remove the end-of-turn token that closes each assistant reply from the loss.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Keep generation padding in sync with whichever pad token is in effect.
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [8]:
# Attach LoRA adapters to the loaded base model; rank, alpha, dropout and the
# target projection names come from the constants cell.
model = FastLanguageModel.get_peft_model(
    model,
    # adapter shape
    target_modules=TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    # optional LoRA variants, both left disabled
    use_rslora=False,
    loftq_config=None,
    # memory saving and reproducibility
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.1.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.3.8 patched 28 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


In [9]:
wandb.init(project=PROJECT_NAME, name=RUN_NAME)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: yihimchan (yihimchan-personal) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


In [9]:
sql_dataset = load_dataset("gretelai/synthetic_text_to_sql")

In [10]:
sql_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
        num_rows: 100000
    })
    test: Dataset({
        features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
        num_rows: 5851
    })
})

In [11]:
test = sql_dataset["test"]

In [12]:
train_size = int(0.95 * len(sql_dataset["train"]))

In [13]:
# Shuffle the row indices with a fixed seed so the train/validation split (and
# the curated dataset later pushed to the Hub) is identical on every run.
# 3407 matches the seed already used for the LoRA adapters and the trainer.
split_rng = np.random.default_rng(3407)
indices = split_rng.permutation(len(sql_dataset["train"]))
train_indices = indices[:train_size]  # first train_size shuffled rows -> training split
val_indices = indices[train_size:]  # remaining rows -> validation split

In [14]:
train = sql_dataset["train"].select(train_indices)
val = sql_dataset["train"].select(val_indices)

In [15]:
# System prompt used as the "system" message by both format_example and
# format_test_example. It contains one worked input/output pair. The text is
# sent to the model verbatim, so any edit here changes the training data.
INSTRUCTION = """You are a specialized SQL query generator that helps users write efficient SQL queries. 
Your role is to analyze the database schema in the `sql_context` and generate the appropriate SQL code with explanation that answers the `sql_prompt`.
Both `sql_context` and `sql_prompt` are given by the user.

### Input Example:
sql_context: CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');

sql_prompt: "What is the total volume of timber sold by each salesperson, sorted by salesperson?"

### Output Example:
SQL: SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;

Explanation: Joins timber_sales and salesperson tables, groups sales by salesperson, calculates total volume sold by each salesperson, and orders the results by total volume in descending order."""

In [16]:
# Template for the user message; filled via str.format with the row's
# sql_context and sql_prompt values.
INPUT_EXAMPLE = """sql_context: {sql_context}

sql_prompt: {sql_prompt}"""

In [17]:
# Template for the assistant message / reference answer; filled via str.format
# with the row's sql and sql_explanation values.
OUTPUT_EXAMPLE = """SQL: {sql}

Explanation: {sql_explanation}"""

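## Alpaca format (disabled)

The cells in this section are commented out and kept for reference only; the chat-template formatting further below is what the training run actually uses.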

In [18]:
# ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

# ### Instruction:
# {instruction}

# ### Input:
# {input_example}

# ### Response:
# {output_example}"""

In [19]:
# EOS_TOKEN = tokenizer.eos_token

# def format_example(examples):
#     sql_contexts = examples["sql_context"]
#     sql_prompts = examples["sql_prompt"]
#     sqls = examples["sql"]
#     sql_explanations = examples["sql_explanation"]
#     texts = []
#     for sql_context, sql_prompt, sql, sql_explanation in zip(sql_contexts, sql_prompts, sqls, sql_explanations):
#         instruction = INSTRUCTION
#         input_example = INPUT_EXAMPLE.format(sql_context=sql_context, sql_prompt=sql_prompt)
#         output_example = OUTPUT_EXAMPLE.format(sql=sql, sql_explanation=sql_explanation)
#         text = ALPACA_PROMPT.format(instruction=instruction, input_example=input_example, output_example=output_example) + EOS_TOKEN
#         texts.append(text)
#     return {"text": texts}

In [20]:
# train = train.map(format_example, batched=True)
# val = val.map(format_example, batched=True)
# test = test.map(format_example, batched=True)

In [21]:
# print(val["text"][0])

In [22]:
# response_template = "### Response:\n"
# collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

In [23]:
def format_example(examples):
    """Render one dataset row as a complete chat transcript for fine-tuning.

    Despite the plural name, ``examples`` is treated as a single row: its
    ``sql_context``, ``sql_prompt``, ``sql`` and ``sql_explanation`` values are
    substituted directly into the prompt templates.

    Returns:
        dict: ``{"text": ...}`` where the value is the system, user and
        assistant turns rendered by the tokenizer's chat template as a string
        (``tokenize=False``) with no trailing generation prompt.
    """
    user_turn = INPUT_EXAMPLE.format(
        sql_context=examples["sql_context"],
        sql_prompt=examples["sql_prompt"],
    )
    assistant_turn = OUTPUT_EXAMPLE.format(
        sql=examples["sql"],
        sql_explanation=examples["sql_explanation"],
    )
    conversation = [
        {"role": "system", "content": INSTRUCTION},
        {"role": "user", "content": user_turn},
        {"role": "assistant", "content": assistant_turn},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )
    return {"text": rendered}

In [24]:
# Add the rendered "text" column to each split, in train / val / test order.
train, val, test = (split.map(format_example) for split in (train, val, test))

In [25]:
val["text"][0]

'<|im_start|>system\nYou are a specialized SQL query generator that helps users write efficient SQL queries. \nYour role is to analyze the database schema in the `sql_context` and generate the appropriate SQL code with explanation that answers the `sql_prompt`.\nBoth `sql_context` and `sql_prompt` are given by the user.\n\n### Input Example:\nsql_context: CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, \'John Doe\', \'North\'), (2, \'Jane Smith\', \'South\'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, \'2021-01-01\'), (2, 1, 150, \'2021-02-01\'), (3, 2, 180, \'2021-01-01\');\n\nsql_prompt: "What is the total volume of timber sold by each salesperson, sorted by salesperson?"\n\n### Output Example:\nSQL: SELECT salesperson_id, name, SUM(volume) as total_volume FROM

In [33]:
from datasets import DatasetDict

In [34]:
# Bundle the three formatted splits under the names they will have on the Hub.
curated_dataset_dict = DatasetDict(dict(train=train, val=val, test=test))

In [35]:
curated_dataset_dict.push_to_hub(f"{HF_USER}/synthetic_text_to_sql-qwen2.5-instruct-curated")

Uploading the dataset shards:   0%|                                                              | 0/1 [00:00<?, ?it/s]
Creating parquet from Arrow format:   0%|                                                       | 0/95 [00:00<?, ?ba/s]
Creating parquet from Arrow format:  22%|█████████▉                                   | 21/95 [00:00<00:00, 209.90ba/s]
Creating parquet from Arrow format:  44%|███████████████████▉                         | 42/95 [00:00<00:00, 205.79ba/s]
Creating parquet from Arrow format:  66%|█████████████████████████████▊               | 63/95 [00:00<00:00, 207.60ba/s]
Creating parquet from Arrow format: 100%|█████████████████████████████████████████████| 95/95 [00:00<00:00, 209.01ba/s]

  0%|                                                                                            | 0/1 [00:00<?, ?it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:14<00:00, 14.22s/it]
Uploading the dataset shards: 100%|████

CommitInfo(commit_url='https://huggingface.co/datasets/Yihim/synthetic_text_to_sql-qwen2.5-instruct-curated/commit/92930713d1f10a6e2ade79914e84b0af2cb5a1d8', commit_message='Upload dataset', commit_description='', oid='92930713d1f10a6e2ade79914e84b0af2cb5a1d8', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/Yihim/synthetic_text_to_sql-qwen2.5-instruct-curated', endpoint='https://huggingface.co', repo_type='dataset', repo_id='Yihim/synthetic_text_to_sql-qwen2.5-instruct-curated'), pr_revision=None, pr_num=None)

In [86]:
# Marker that opens the assistant turn in the chat-formatted "text" column (the
# rendered sample shown earlier uses the same "<|im_start|>" turn prefix).
# NOTE(review): DataCollatorForCompletionOnlyLM presumably restricts the loss to
# the tokens that follow this marker — confirm against the installed TRL version.
response_template = "<|im_start|>assistant\n"
collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

In [66]:
# Query bf16 support once; exactly one of bf16 / fp16 is enabled below.
use_bf16 = is_bfloat16_supported()

train_parameters = SFTConfig(
    # --- run identity and reporting ---
    output_dir=PROJECT_RUN_NAME,
    run_name=RUN_NAME,
    report_to="wandb",
    seed=3407,
    # --- optimisation schedule ---
    num_train_epochs=2,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.01,
    # --- batching ---
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    group_by_length=True,
    max_seq_length=MAX_SEQ_LENGTH,
    # --- mixed precision ---
    bf16=use_bf16,
    fp16=not use_bf16,
    # --- logging, evaluation and checkpoints ---
    logging_steps=LOG_STEPS,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # --- Hub upload ---
    push_to_hub=True,
    hub_strategy="end",
    hub_model_id=f"{HF_USER}/unsloth-qwen2.5-7b-instruct-text-to-sql-v1",
    hub_private_repo=True,
)

In [29]:
class MemoryCleanupCallback(TrainerCallback):
    """Trainer callback whose ``on_epoch_end`` hook frees Python and CUDA memory.

    NOTE(review): the only SFTTrainer construction visible in this notebook has
    its ``callbacks`` argument commented out, so this callback is not registered there.
    """

    def on_epoch_end(self, args, state, control, **kwargs):
        """Run the garbage collector, then release PyTorch's cached CUDA memory.

        The hook arguments are accepted to match the callback signature but are
        not read. Returns None.
        """
        # Presumably ordered so that objects freed by the collector have their
        # cached GPU blocks released by the call that follows.
        gc.collect()
        torch.cuda.empty_cache()

In [30]:
memory_cleanup_callback = MemoryCleanupCallback()

In [31]:
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)

In [32]:
# Assemble the supervised fine-tuning trainer from the LoRA-wrapped model, the
# formatted train/validation splits and the completion-only collator.
fine_tuning = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=train_parameters,
    train_dataset=train,
    eval_dataset=val,
    data_collator=collator,
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
    dataset_num_proc=2,
    # callbacks=[early_stopping_callback, memory_cleanup_callback]
)

Tokenizing to ["text"] (num_proc=2): 100%|██████████████████████████████| 95000/95000 [01:02<00:00, 1524.62 examples/s]
Tokenizing to ["text"] (num_proc=2): 100%|█████████████████████████████████| 5000/5000 [00:11<00:00, 419.34 examples/s]


In [33]:
gc.collect()
torch.cuda.empty_cache()

In [34]:
fine_tuning_stats = fine_tuning.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 95,000 | Num Epochs = 2 | Total steps = 11,874
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 40,370,176/4,931,917,312 (0.82% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss
500,0.9793,0.871987
1000,1.1992,0.872043
1500,1.0486,0.872973
2000,1.0381,0.871487
2500,1.2212,0.871695
3000,0.9463,0.869801
3500,0.938,0.871059
4000,1.0797,0.7918
4500,0.6391,0.771342
5000,0.9032,0.752175


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Unsloth: Not an error, but Qwen2ForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead

In [35]:
fine_tuning_stats

TrainOutput(global_step=11874, training_loss=0.7826375707052805, metrics={'train_runtime': 79821.1322, 'train_samples_per_second': 2.38, 'train_steps_per_second': 0.149, 'total_flos': 4.4383198831549563e+18, 'train_loss': 0.7826375707052805})

In [37]:
fine_tuning.model.push_to_hub(repo_id=f"{HF_USER}/unsloth-qwen2.5-7b-instruct-text-to-sql-v1", private=True)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.91s/it]


Saved model to https://huggingface.co/Yihim/unsloth-qwen2.5-7b-instruct-text-to-sql-v1


In [38]:
fine_tuning.processing_class.push_to_hub(repo_id=f"{HF_USER}/unsloth-qwen2.5-7b-instruct-text-to-sql-v1", private=True)

  0%|                                                                                            | 0/1 [00:00<?, ?it/s]
tokenizer.json:   0%|                                                                      | 0.00/11.4M [00:00<?, ?B/s]
tokenizer.json:   3%|█▊                                                            | 328k/11.4M [00:00<00:03, 3.00MB/s]
tokenizer.json:  10%|█████▊                                                       | 1.10M/11.4M [00:00<00:01, 5.47MB/s]
tokenizer.json:  14%|████████▊                                                    | 1.65M/11.4M [00:00<00:05, 1.82MB/s]
tokenizer.json:  70%|██████████████████████████████████████████▌                  | 7.98M/11.4M [00:00<00:00, 12.2MB/s]
tokenizer.json: 100%|█████████████████████████████████████████████████████████████| 11.4M/11.4M [00:03<00:00, 3.60MB/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.46s/it]


In [39]:
# Write the trained model and its tokenizer into the same local folder.
local_export_dir = "unsloth-qwen2.5-7b-instruct-text-to-sql-v1"
fine_tuning.model.save_pretrained(local_export_dir)
fine_tuning.processing_class.save_pretrained(local_export_dir)

('unsloth-qwen2.5-7b-instruct-text-to-sql-v1\\tokenizer_config.json',
 'unsloth-qwen2.5-7b-instruct-text-to-sql-v1\\special_tokens_map.json',
 'unsloth-qwen2.5-7b-instruct-text-to-sql-v1\\vocab.json',
 'unsloth-qwen2.5-7b-instruct-text-to-sql-v1\\merges.txt',
 'unsloth-qwen2.5-7b-instruct-text-to-sql-v1\\added_tokens.json',
 'unsloth-qwen2.5-7b-instruct-text-to-sql-v1\\tokenizer.json')

In [40]:
# Hand the trained model to Unsloth's for_inference helper and reuse the
# trainer's tokenizer for the generation cells below.
# NOTE(review): for_inference presumably modifies fine_tuning.model in place as
# well as returning it — confirm before resuming training with the same trainer.
fine_tuned_model = FastLanguageModel.for_inference(fine_tuning.model)
fine_tuned_tokenizer = fine_tuning.processing_class

In [106]:
test

Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation', 'text'],
    num_rows: 5851
})

In [107]:
def format_test_example(examples):
    """Split one test row into an inference prompt and its reference answer.

    ``examples`` is a single row providing ``sql_context``, ``sql_prompt``,
    ``sql`` and ``sql_explanation``.

    Returns:
        dict: ``"messages"`` holds the system and user turns only (no assistant
        turn, so the model has to produce it); ``"expected_output"`` holds the
        reference answer rendered with ``OUTPUT_EXAMPLE``.
    """
    prompt_messages = [
        {"role": "system", "content": INSTRUCTION},
        {
            "role": "user",
            "content": INPUT_EXAMPLE.format(
                sql_context=examples["sql_context"],
                sql_prompt=examples["sql_prompt"],
            ),
        },
    ]
    reference_answer = OUTPUT_EXAMPLE.format(
        sql=examples["sql"],
        sql_explanation=examples["sql_explanation"],
    )
    return {"messages": prompt_messages, "expected_output": reference_answer}

In [108]:
test = test.map(format_test_example)

Map: 100%|███████████████████████████████████████████████████████████████| 5851/5851 [00:00<00:00, 12133.95 examples/s]


In [128]:
print(test["messages"][10])

[{'content': 'You are a specialized SQL query generator that helps users write efficient SQL queries. \nYour role is to analyze the database schema in the `sql_context` and generate the appropriate SQL code with explanation that answers the `sql_prompt`.\nBoth `sql_context` and `sql_prompt` are given by the user.\n\n### Input Example:\nsql_context: CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, \'John Doe\', \'North\'), (2, \'Jane Smith\', \'South\'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, \'2021-01-01\'), (2, 1, 150, \'2021-02-01\'), (3, 2, 180, \'2021-01-01\');\n\nsql_prompt: "What is the total volume of timber sold by each salesperson, sorted by salesperson?"\n\n### Output Example:\nSQL: SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber

In [123]:
from transformers import TextStreamer

In [124]:
text_streamer = TextStreamer(skip_prompt=True, tokenizer=fine_tuned_tokenizer)

In [125]:
input_ids = fine_tuned_tokenizer.apply_chat_template(test["messages"][10], add_generation_prompt=True, return_tensors="pt").to("cuda")

In [126]:
_ = fine_tuned_model.generate(input_ids, streamer=text_streamer, max_new_tokens=1024, pad_token_id=fine_tuned_tokenizer.eos_token_id)

SQL: SELECT COUNT(dapp_id) AS total_downloads FROM dapp_ranking WHERE dapp_region = 'Asia-Pacific';

Explanation: The SQL query counts the number of decentralized applications (dapps) that have been downloaded from the 'Asia-Pacific' region by filtering the `dapp_ranking` table where `dapp_region` equals 'Asia-Pacific'. The result will give you the total number of dapps downloaded from this region.<|im_end|>


In [127]:
test["expected_output"][10]

"SQL: SELECT SUM(dapp_downloads) FROM dapp_ranking WHERE dapp_region = 'Asia-Pacific';\n\nExplanation: The SQL query calculates the total number of downloads for all decentralized applications from the 'Asia-Pacific' region by summing the 'dapp_downloads' values for all records with the 'dapp_region' value of 'Asia-Pacific'."