In [2]:
from datasets import load_dataset, DatasetDict, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer, SFTConfig, GRPOConfig, GRPOTrainer
import duckdb
import re

### Load the dataset

In [3]:
DATASET_REPO_ID = "proton98/sql-distill-llama-3-1-70b-instruct-reasoning"
train_dataset = load_dataset(DATASET_REPO_ID, split="train[:100]")

### Prompt template

In [4]:
REASONING_START = "<think>"
REASONING_END = "</think>"
SOLUTION_START = "<sql>"
SOLUTION_END = "</sql>"

SYSTEM_PROMPT_TEMPLATE = \
f"""You are an expert in writing optimized SQL queries.
Think about the problem and provide your working out.
Place it between {REASONING_START} and {REASONING_END}.
Then, provide your solution between {SOLUTION_START} and {SOLUTION_END}.

Context:
---
{{context}}
---""".rstrip()

### Process the dataset

In [5]:
sql_match_format = re.compile(rf"{SOLUTION_START}(.*?){SOLUTION_END}", re.DOTALL)
reasoning_match_format = re.compile(rf"{REASONING_START}(.*?){REASONING_END}", re.DOTALL)

def extract_sql(text: str) -> str | None:
    """
    Extracts the SQL using regex from the generated text.
    :param text:
    :return:
    """
    sql_match = sql_match_format.search(text)
    if sql_match:
        return sql_match.group(1).strip()


def extract_think(text: str) -> str | None:
    """
    Extracts the think using regex from the generated text.
    :param text:
    :return:
    """
    think_match = reasoning_match_format.search(text)
    if think_match:
        return think_match.group(1).strip()


def correct_reasoning_format(content: str) -> str:
    """
    Formatting prompt for exact match.
    :param content:
    :return:
    """
    __think = extract_think(content)
    __sql = extract_sql(content)
    return f"""{REASONING_START}\n{__think}\n{REASONING_END}\n{SOLUTION_START}\n{__sql}\n{SOLUTION_END}"""


def conversations_formatting(dataset: DatasetDict) -> Dataset | DatasetDict:
    """
    Format the conversations for the model.
    :param dataset:
    :return:
    """
    dataset = dataset.map(lambda x: {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT_TEMPLATE.format(context=x["sql_context"]).strip()},
            {"role": "user", "content": x["sql_prompt"]},
        ],
        "questions": x["sql_prompt"],
        "contexts": x["sql_context"],
        "answers": extract_sql(x["generation"]),
    }, remove_columns=[
        k for k in dataset.column_names if k not in ["questions", "contexts", "answers", "prompt"]
    ])
    return dataset


def conversations_supervised_fine_tuning(dataset: DatasetDict) -> Dataset | DatasetDict:
    """
    Format the conversations for supervised fine-tuning.
    :param dataset:
    :return:
    """
    dataset = dataset.map(lambda x: {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT_TEMPLATE.format(context=x["sql_context"]).strip()},
            {"role": "user", "content": x["sql_prompt"]},
            {"role": "assistant", "content": correct_reasoning_format(x["generation"])},
        ],
    }, remove_columns=[
        k for k in dataset.column_names if k not in ["messages"]
    ])
    return dataset


train_dataset_grpo = conversations_formatting(train_dataset)
train_dataset_sft = conversations_supervised_fine_tuning(train_dataset)

Map: 100%|██████████| 100/100 [00:00<00:00, 5529.66 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 6772.98 examples/s]


### Load the model and tokenizer

In [6]:
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
MAX_SEQ_LEN = 4096
MAX_LORA_RANK = 16
LORA_RANK = 16

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_NAME,
    max_seq_length = MAX_SEQ_LEN,
    load_in_4bit = True,
    fast_inference = True,
    max_lora_rank = MAX_LORA_RANK,
    gpu_memory_utilization = 0.7,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = LORA_RANK,
    lora_alpha = LORA_RANK,
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,
)

==((====))==  Unsloth 2025.5.6: Fast Llama patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA GeForce RTX 3060. Num GPUs = 1. Max memory: 12.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit with actual GPU utilization = 63.86%
Unsloth: Your GPU has CUDA compute capability 8.6 with VRAM = 12.0 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 4096. Num Sequences = 192.
Unsloth: vLLM's KV Cache can use up to 5.24 GB. Also swap space = 0 GB.
INFO 05-20 01:59:33 [config.py:717] This model supports multiple tasks: {'score', 'embed', 'generate', 'reward', 'classify'}. Defaulting to 'generate'.
INFO 05-20 01:59:33 [

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:02<00:00,  2.87s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:02<00:00,  2.87s/it]

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.27it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.26it/s]


INFO 05-20 01:59:45 [punica_selector.py:18] Using PunicaWrapperGPU.





INFO 05-20 01:59:45 [gpu_model_runner.py:1347] Model loading took 2.2969 GiB and 9.422128 seconds
INFO 05-20 01:59:59 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/a11df5d034/rank_0_0 for vLLM's torch.compile
INFO 05-20 01:59:59 [backends.py:430] Dynamo bytecode transform time: 13.82 s
INFO 05-20 02:00:07 [backends.py:118] Directly load the compiled graph(s) for shape None from the cache, took 6.871 s
INFO 05-20 02:00:11 [monitor.py:33] torch.compile takes 13.82 s in total
INFO 05-20 02:00:11 [kv_cache_utils.py:634] GPU KV cache size: 31,472 tokens
INFO 05-20 02:00:11 [kv_cache_utils.py:637] Maximum concurrency for 4,096 tokens per request: 7.68x
INFO 05-20 02:00:58 [gpu_model_runner.py:1686] Graph capturing finished in 46 secs, took 1.19 GiB
INFO 05-20 02:00:58 [core.py:159] init engine (profile, create kv cache, warmup model) took 72.86 seconds
Unsloth: Just some info: will skip parsing ['q_norm', 'pre_feedforward_layernorm', 'post_feedforward_layerno

Unsloth 2025.5.6 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


### Apply Chat Template

In [7]:
train_dataset_sft = train_dataset_sft.map(lambda x: {"text": tokenizer.apply_chat_template(x["messages"], batched=True, tokenize = False)})

Map: 100%|██████████| 100/100 [00:00<00:00, 3510.20 examples/s]


In [8]:
train_dataset_sft[0]

{'messages': [{'content': "You are an expert in writing optimized SQL queries.\nThink about the problem and provide your working out.\nPlace it between <think> and </think>.\nThen, provide your solution between <sql> and </sql>.\n\nContext:\n---\nCREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');\n---",
   'role': 'system'},
  {'content': 'What is the total volume of timber sold by each salesperson, sorted by salesperson?',
   'role': 'user'},
  {'content': '<think>\nTo answer the question of what the total volume of timber sold by each salesperson is, sorted by salesperson, we need to follow a series of lo

### Supervised fine-tuning

In [9]:
sft_trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset_sft,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 8,
        gradient_accumulation_steps = 1,
        warmup_steps = 5,
        num_train_epochs = 1,
        learning_rate = 2e-4,
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none",
    ),
)
sft_trainer.train()

Unsloth: Tokenizing ["text"] (num_proc=12): 100%|██████████| 100/100 [00:03<00:00, 28.10 examples/s]
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 100 | Num Epochs = 1 | Total steps = 13
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 24,313,856/3,000,000,000 (0.81% trained)


Step,Training Loss
5,1.5144
10,1.2108


Unsloth: Will smartly offload gradients to save VRAM!


TrainOutput(global_step=13, training_loss=1.2801625178410456, metrics={'train_runtime': 58.8642, 'train_samples_per_second': 1.699, 'train_steps_per_second': 0.221, 'total_flos': 556659810385920.0, 'train_loss': 1.2801625178410456})

In [10]:
text = tokenizer.apply_chat_template(
    train_dataset_sft[0]["messages"][:2],
    tokenize = False,
    add_generation_prompt = True, # Must add for generation
)
from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 0,
    max_new_tokens = 1024,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)

<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 20 May 2025

You are an expert in writing optimized SQL queries.
Think about the problem and provide your working out.
Place it between <think> and </think>.
Then, provide your solution between <sql> and </sql>.

Context:
---
CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');
---<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the total volume of timber sold by each salesperson, sorted by salesperson?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<think>
To sol

In [11]:
import torch
del train_dataset_sft
torch.cuda.empty_cache()
import gc
gc.collect()

373

### Reward functions

In [12]:
COLORED_GREEN = "\033[92m"
COLORED_BLUE = "\033[94m"
COLORED_RESET = "\033[0m"
BOLD = "\033[1m"

response_match_format = re.compile(
    rf"^{REASONING_START}\n.*?\n{REASONING_END}\n{SOLUTION_START}\n.*?\n{SOLUTION_END}$",
    re.DOTALL | re.MULTILINE
)


def match_format_exactly(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: # noqa
    """
    Check if the completions match the expected format exactly.
    :param completions:
    :param kwargs:
    :return:
    """
    scores: list[float] = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        if response_match_format.search(response) is not None: score += 1.1
        scores.append(score)
    return scores


def match_format_approximately(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: # noqa
    """
    Check if the completions match the expected format approximately.
    :param completions:
    :param kwargs:
    :return:
    """
    scores: list[float] = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        score += 0.125 if response.count(f"{REASONING_START}\n")    == 1 else -1.0
        score += 0.125 if response.count(f"\n{REASONING_END}\n")    == 1 else -1.0
        score += 0.125 if response.count(f"\n{SOLUTION_START}\n")   == 1 else -1.0
        score += 0.125 if response.count(f"\n{SOLUTION_END}")       == 1 else -1.0
        scores.append(score)
    return scores


def validate_sql_query(sql_query: str, context: str) -> tuple[bool, str]:
    """
    Validate the SQL query against the provided context.

    Args:
        sql_query (str): The SQL query to validate.
        context (str): The context in which the SQL query is executed
            (e.g., table schema, data).
    Returns:
        tuple[bool, str]: Whether the SQL query is valid.
    """
    try:
        con = duckdb.connect(database=':memory:')

        context_statements = [stmt.strip() for stmt in context.strip().split(';') if stmt.strip()]
        for statement in context_statements:
            con.execute(statement)

        con.execute(sql_query)
        _ = con.fetchall()

        return True, "SQL query is valid and executed successfully."
    except Exception as e:
        return False, str(e)


def check_sql_reward(
    completions: list[list[dict[str, str]]],
    questions: list[str],
    answers: list[str],
    contexts: list[str],
    **kwargs # noqa
) -> list[float]:
    """
    Check the SQL reward for the given prompts and completions.
    :param completions:
    :param questions:
    :param answers:
    :param contexts:
    :param kwargs:
    :return:
    """
    responses = [completion[0]["content"] for completion in completions]
    pred_responses: list[str | None] = [extract_sql(r) for r in responses]

    print(
        '-'*20,
        f"{BOLD}Question:{COLORED_RESET} {questions[0]}{COLORED_RESET}",
        f"\n{BOLD}Response:{COLORED_RESET} \n{COLORED_GREEN}{responses[0]}{COLORED_RESET}",
        f"\n{BOLD}Real Answer:{COLORED_RESET} \n{COLORED_GREEN}{answers[0]}{COLORED_RESET}",
        f"\n{BOLD}Extracted:{COLORED_RESET} \n"
        f"{COLORED_BLUE}{pred_responses[0] if pred_responses[0] is not None else '-Invalid Format-'}{COLORED_RESET}",
    )

    scores: list[float] = []
    for pred, context, answer in zip(pred_responses, contexts, answers):
        score = 0
        if pred is None:
            scores.append(0)
            continue

        # Check if the SQL query is valid
        is_valid, _ = validate_sql_query(pred, context)

        # Check if the SQL query is similar to the true SQL
        # Correct answer gets 3 points
        if is_valid:
            # Exact match
            if pred == answer: score += 3.0
            # Match if spaces are seen
            elif pred.strip() == answer.strip(): score += 1.5
        else: score -= 1.5 # Penalty for incorrect SQL
        scores.append(score)
    return scores

### Training

In [13]:
MAX_PROMPT_LEN = 2048
MAX_COMPLETION_LEN = 1024

training_args = GRPOConfig(
    use_vllm = True,
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 1,
    num_generations = 4,
    max_prompt_length = MAX_PROMPT_LEN,
    max_completion_length = MAX_COMPLETION_LEN,
    max_steps = 500,
    save_steps = 150,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)

In [None]:
grpo_trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [ # type: ignore
        match_format_exactly,
        match_format_approximately,
        check_sql_reward,
    ],
    args = training_args,
    train_dataset = train_dataset_grpo,
)
grpo_trainer.train()