In [1]:
from datasets import load_dataset, DatasetDict, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer, SFTConfig, GRPOConfig, GRPOTrainer
import duckdb
import re

  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 05-22 00:34:33 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-22 00:34:33 [__init__.py:239] Automatically detected platform cuda.


2025-05-22 00:34:34,861	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


### Load the dataset

In [2]:
DATASET_REPO_ID = "proton98/sql-distill-llama-3-1-70b-instruct-reasoning"
train_dataset = load_dataset(DATASET_REPO_ID, split="train[:500]")

### Prompt template

In [3]:
REASONING_START = "<think>"
REASONING_END = "</think>"
SOLUTION_START = "<sql>"
SOLUTION_END = "</sql>"

SYSTEM_PROMPT_TEMPLATE = \
f"""You are an expert in writing optimized SQL queries.
Think about the problem and provide your working out.
Place it between {REASONING_START} and {REASONING_END}.
Then, provide your solution between {SOLUTION_START} and {SOLUTION_END}.

Context:
{{context}}""".rstrip()

### Process the dataset

In [4]:
sql_match_format = re.compile(rf"{SOLUTION_START}(.*?){SOLUTION_END}", re.DOTALL)
reasoning_match_format = re.compile(rf"{REASONING_START}(.*?){REASONING_END}", re.DOTALL)


def extract_schema_tables(sql: str) -> str:
    """
    Extracts the SQL using regex from the generated text.
    :param sql:
    :return:
    """
    try:
        con = duckdb.connect(database=':memory:')
        sql_statements = [stmt.strip() for stmt in sql.strip().split(';') if stmt.strip()]
        for statement in sql_statements:
            con.execute(statement)

        tables_info = con.execute(
            "SELECT table_schema, table_name FROM information_schema.tables WHERE table_type='BASE TABLE';").fetchall() # noqa

        context_lines = []
        for schema, table in tables_info:
            context_lines.append(f"Table: {table}")
            columns = con.execute(f"PRAGMA table_info('{table}')").fetchall()
            for col in columns:
                col_name, col_type = col[1], col[2]
                context_lines.append(f"  - {col_name}: {col_type}")
            context_lines.append("")
        return "\n".join(context_lines).strip()
    except Exception as e:
        return str(e)


def extract_sql(text: str) -> str | None:
    """
    Extracts the SQL using regex from the generated text.
    :param text:
    :return:
    """
    sql_match = sql_match_format.search(text)
    if sql_match:
        return sql_match.group(1).strip()


def extract_think(text: str) -> str | None:
    """
    Extracts the think using regex from the generated text.
    :param text:
    :return:
    """
    think_match = reasoning_match_format.search(text)
    if think_match:
        return think_match.group(1).strip()


def correct_reasoning_format(content: str) -> str:
    """
    Formatting prompt for exact match.
    :param content:
    :return:
    """
    __think = extract_think(content)
    __sql = extract_sql(content)
    return f"""{REASONING_START}\n{__think}\n{REASONING_END}\n{SOLUTION_START}\n{__sql}\n{SOLUTION_END}"""


def conversations_formatting(dataset: DatasetDict) -> Dataset | DatasetDict:
    """
    Format the conversations for the model.
    :param dataset:
    :return:
    """
    dataset = dataset.map(lambda x: {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT_TEMPLATE.format(context=extract_schema_tables(x["sql_context"])).strip()},
            {"role": "user", "content": x["sql_prompt"]},
        ],
        "questions": x["sql_prompt"],
        "contexts": x["sql_context"],
        "answers": extract_sql(x["generation"]),
    }, remove_columns=[
        k for k in dataset.column_names if k not in ["questions", "contexts", "answers", "prompt"]
    ])
    return dataset


def conversations_supervised_fine_tuning(dataset: DatasetDict) -> Dataset | DatasetDict:
    """
    Format the conversations for supervised fine-tuning.
    :param dataset:
    :return:
    """
    dataset = dataset.map(lambda x: {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT_TEMPLATE.format(context=x["sql_context"]).strip()},
            {"role": "user", "content": x["sql_prompt"]},
            {"role": "assistant", "content": correct_reasoning_format(x["generation"])},
        ],
    }, remove_columns=[
        k for k in dataset.column_names if k not in ["messages"]
    ])
    return dataset


train_dataset_grpo = conversations_formatting(train_dataset)
train_dataset_sft = conversations_supervised_fine_tuning(train_dataset)

Map: 100%|██████████| 500/500 [00:13<00:00, 38.35 examples/s]


### Load the model and tokenizer

In [5]:
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
MAX_SEQ_LEN = 4096
MAX_LORA_RANK = 16
LORA_RANK = 16

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_NAME,
    max_seq_length = MAX_SEQ_LEN,
    load_in_4bit = True,
    fast_inference = True,
    max_lora_rank = MAX_LORA_RANK,
    gpu_memory_utilization = 0.7,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = LORA_RANK,
    lora_alpha = LORA_RANK,
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,
)

==((====))==  Unsloth 2025.5.6: Fast Llama patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA GeForce RTX 3060. Num GPUs = 1. Max memory: 12.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit with actual GPU utilization = 63.86%
Unsloth: Your GPU has CUDA compute capability 8.6 with VRAM = 12.0 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 4096. Num Sequences = 192.
Unsloth: vLLM's KV Cache can use up to 5.24 GB. Also swap space = 0 GB.
INFO 05-22 00:35:31 [config.py:717] This model supports multiple tasks: {'reward', 'score', 'generate', 'classify', 'embed'}. Defaulting to 'generate'.
INFO 05-22 00:35:32 [

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:02<00:00,  2.86s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:02<00:00,  2.87s/it]

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.06it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.06it/s]


INFO 05-22 00:35:43 [punica_selector.py:18] Using PunicaWrapperGPU.





INFO 05-22 00:35:43 [gpu_model_runner.py:1347] Model loading took 2.2969 GiB and 8.085926 seconds
INFO 05-22 00:35:58 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/a11df5d034/rank_0_0 for vLLM's torch.compile
INFO 05-22 00:35:58 [backends.py:430] Dynamo bytecode transform time: 14.76 s
INFO 05-22 00:36:06 [backends.py:118] Directly load the compiled graph(s) for shape None from the cache, took 7.505 s
INFO 05-22 00:36:10 [monitor.py:33] torch.compile takes 14.76 s in total
INFO 05-22 00:36:11 [kv_cache_utils.py:634] GPU KV cache size: 31,472 tokens
INFO 05-22 00:36:11 [kv_cache_utils.py:637] Maximum concurrency for 4,096 tokens per request: 7.68x
INFO 05-22 00:36:57 [gpu_model_runner.py:1686] Graph capturing finished in 47 secs, took 1.19 GiB
INFO 05-22 00:36:58 [core.py:159] init engine (profile, create kv cache, warmup model) took 74.83 seconds
Unsloth: Just some info: will skip parsing ['q_norm', 'pre_feedforward_layernorm', 'k_norm', 'post_feedforwa

Unsloth 2025.5.6 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


### Apply Chat Template

In [6]:
train_dataset_sft = train_dataset_sft.map(lambda x: {"text": tokenizer.apply_chat_template(x["messages"], batched=True, tokenize = False)})

In [7]:
train_dataset_sft[0]

{'messages': [{'content': "You are an expert in writing optimized SQL queries.\nThink about the problem and provide your working out.\nPlace it between <think> and </think>.\nThen, provide your solution between <sql> and </sql>.\n\nContext:\nCREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');",
   'role': 'system'},
  {'content': 'What is the total volume of timber sold by each salesperson, sorted by salesperson?',
   'role': 'user'},
  {'content': '<think>\nTo answer the question of what the total volume of timber sold by each salesperson is, sorted by salesperson, we need to follow a series of logical step

### Supervised fine-tuning

In [8]:
sft_trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset_sft,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 8,
        gradient_accumulation_steps = 1,
        warmup_steps = 5,
        num_train_epochs = 1,
        learning_rate = 2e-4,
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none",
    ),
)
sft_trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 500 | Num Epochs = 1 | Total steps = 63
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 24,313,856/3,000,000,000 (0.81% trained)


Step,Training Loss
5,1.521
10,1.1031
15,0.771
20,0.6316
25,0.6118
30,0.5308
35,0.5783
40,0.5654
45,0.5732
50,0.5432


Unsloth: Will smartly offload gradients to save VRAM!


TrainOutput(global_step=63, training_loss=0.7027162313461304, metrics={'train_runtime': 343.5147, 'train_samples_per_second': 1.456, 'train_steps_per_second': 0.183, 'total_flos': 3100441869926400.0, 'train_loss': 0.7027162313461304})

In [9]:
text = tokenizer.apply_chat_template(
    train_dataset_sft[0]["messages"][:2],
    tokenize = False,
    add_generation_prompt = True, # Must add for generation
)
from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 0,
    max_new_tokens = 1024,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)

<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 22 May 2025

You are an expert in writing optimized SQL queries.
Think about the problem and provide your working out.
Place it between <think> and </think>.
Then, provide your solution between <sql> and </sql>.

Context:
CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the total volume of timber sold by each salesperson, sorted by salesperson?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<think>
To solve this 

In [10]:
import torch
del train_dataset_sft
torch.cuda.empty_cache()
import gc
gc.collect()

1354

### Reward functions

In [11]:
COLORED_GREEN = "\033[92m"
COLORED_BLUE = "\033[94m"
COLORED_RESET = "\033[0m"
BOLD = "\033[1m"

response_match_format = re.compile(
    rf"^{REASONING_START}\n.*?\n{REASONING_END}\n{SOLUTION_START}\n.*?\n{SOLUTION_END}$",
    re.DOTALL | re.MULTILINE
)


def match_format_exactly(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: # noqa
    """
    Check if the completions match the expected format exactly.
    :param completions:
    :param kwargs:
    :return:
    """
    scores: list[float] = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        if response_match_format.search(response) is not None: score += 1.1
        scores.append(score)
    return scores


def match_format_approximately(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: # noqa
    """
    Check if the completions match the expected format approximately.
    :param completions:
    :param kwargs:
    :return:
    """
    scores: list[float] = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        score += 0.125 if response.count(f"{REASONING_START}\n")    == 1 else -1.0
        score += 0.125 if response.count(f"\n{REASONING_END}\n")    == 1 else -1.0
        score += 0.125 if response.count(f"\n{SOLUTION_START}\n")   == 1 else -1.0
        score += 0.125 if response.count(f"\n{SOLUTION_END}")       == 1 else -1.0
        scores.append(score)
    return scores


def validate_sql_query(sql_query: str, context: str) -> tuple[bool, str]:
    """
    Validate the SQL query against the provided context.

    Args:
        sql_query (str): The SQL query to validate.
        context (str): The context in which the SQL query is executed
            (e.g., table schema, data).
    Returns:
        tuple[bool, str]: Whether the SQL query is valid.
    """
    try:
        con = duckdb.connect(database=':memory:')

        context_statements = [stmt.strip() for stmt in context.strip().split(';') if stmt.strip()]
        for statement in context_statements:
            con.execute(statement)

        con.execute(sql_query)
        _ = con.fetchall()

        return True, "SQL query is valid and executed successfully."
    except Exception as e:
        return False, str(e)


def check_sql_reward(
    completions: list[list[dict[str, str]]],
    questions: list[str],
    answers: list[str],
    contexts: list[str],
    **kwargs # noqa
) -> list[float]:
    """
    Check the SQL reward for the given prompts and completions.
    :param completions:
    :param questions:
    :param answers:
    :param contexts:
    :param kwargs:
    :return:
    """
    responses = [completion[0]["content"] for completion in completions]
    pred_responses: list[str | None] = [extract_sql(r) for r in responses]

    print(
        '-'*20,
        f"{BOLD}Question:{COLORED_RESET} {questions[0]}{COLORED_RESET}",
        f"\n{BOLD}Response:{COLORED_RESET} \n{COLORED_GREEN}{responses[0]}{COLORED_RESET}",
        f"\n{BOLD}Real Answer:{COLORED_RESET} \n{COLORED_GREEN}{answers[0]}{COLORED_RESET}",
        f"\n{BOLD}Extracted:{COLORED_RESET} \n"
        f"{COLORED_BLUE}{pred_responses[0] if pred_responses[0] is not None else '-Invalid Format-'}{COLORED_RESET}",
    )

    scores: list[float] = []
    for pred, context, answer in zip(pred_responses, contexts, answers):
        score = 0
        if pred is None:
            scores.append(0)
            continue

        # Check if the SQL query is valid
        is_valid, _ = validate_sql_query(pred, context)

        # Check if the SQL query is similar to the true SQL
        # Correct answer gets 3 points
        if is_valid:
            # Exact match
            if pred == answer: score += 3.0
            # Match if spaces are seen
            elif pred.strip() == answer.strip(): score += 1.5
        else: score -= 1.5 # Penalty for incorrect SQL
        scores.append(score)
    return scores

### Training

In [12]:
MAX_PROMPT_LEN = 2048
MAX_COMPLETION_LEN = 1024

training_args = GRPOConfig(
    use_vllm = True,
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 1,
    num_generations = 4,
    max_prompt_length = MAX_PROMPT_LEN,
    max_completion_length = MAX_COMPLETION_LEN,
    max_steps = 50,
    save_steps = 10,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)

In [13]:
grpo_trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [ # type: ignore
        match_format_exactly,
        match_format_approximately,
        check_sql_reward,
    ],
    args = training_args,
    train_dataset = train_dataset_grpo,
)
grpo_trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 500 | Num Epochs = 1 | Total steps = 50
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 24,313,856/3,000,000,000 (0.81% trained)


-------------------- [1mQuestion:[0m Delete health equity metrics for Arizona in 2019[0m 
[1mResponse:[0m 
[92m<think>
To delete health equity metrics for Arizona in 2019, I need to consider the following factors:
1. I must filter the records to include only those from the state 'Arizona'. This is because the instruction explicitly states that I should delete metrics for Arizona, and I assume 'Arizona' is the relevant state.
2. I need to ensure that the records are from the year 2019. This is because the instruction specifically mentions 2019, and I need to exclude records from any other year.
3. The instruction does not mention any other criteria or conditions that could affect the deletion of records, so I will assume that the only relevant conditions are state and year.

Given these considerations, I will use a filter clause to exclude records from the state 'Arizona' and year 2019. This will effectively delete the relevant records from the database.
</think>
<sql>
DELETE FROM

Step,Training Loss,reward,reward_std,completion_length,kl,rewards / match_format_exactly,rewards / match_format_approximately,rewards / check_sql_reward
1,0.0092,1.0375,0.808013,311.625,0.229944,1.1,0.5,-0.5625
2,0.0082,0.759375,0.18125,405.5,0.205422,0.9625,0.359375,-0.5625
3,0.0099,3.1,1.5,308.125,0.248246,1.1,0.5,1.5
4,0.0107,0.85,0.866025,405.0,0.266733,1.1,0.5,-0.75
5,0.0099,0.009375,0.767529,397.625,0.24722,0.9625,0.359375,-1.3125
6,0.0089,1.4125,1.376742,375.5,0.223467,1.1,0.5,-0.1875
7,0.0109,1.4125,0.375,370.0,0.27279,1.1,0.5,-0.1875
8,0.0087,0.1,0.0,362.0,0.216333,1.1,0.5,-1.5
9,0.0099,1.6,0.0,366.125,0.248262,1.1,0.5,0.0
10,0.0087,0.196875,0.880363,386.75,0.216799,0.9625,0.359375,-1.125


-------------------- [1mQuestion:[0m Create a table named 'training_programs'[0m 
[1mResponse:[0m 
[92m<think>
To create a table named 'training_programs', we need to follow a series of steps. First, we identify the instruction which clearly states that we should create a table named 'training_programs'. This instruction does not provide any specific columns or data for the table, so we can infer that the objective is to create an empty table with the specified name.

The instruction does not provide any requirements for the table's structure or content, so we will assume that the table should have the following columns:

- id: This is likely the primary key for the table, as most databases use a unique identifier for each row.
- program_name: This column will store the name of the training program.
- start_date: This column will store the start date of the training program.
- end_date: This column will store the end date of the training program.
- department: This column will st

TrainOutput(global_step=50, training_loss=0.009207687238231302, metrics={'train_runtime': 2009.5061, 'train_samples_per_second': 0.199, 'train_steps_per_second': 0.025, 'total_flos': 0.0, 'train_loss': 0.009207687238231302})