In [None]:
from unsloth import FastLanguageModel
import torch
import re

In [None]:
# Model and token-budget configuration.
MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct"
MAX_SEQ_LEN = 1024  # total token budget per sample (prompt + completion)
# Prompt budget; the completion budget is derived as MAX_SEQ_LEN - MAX_PROMPT_LENGTH.
# NOTE(review): the system prompt embeds the table schema — confirm typical prompts fit.
MAX_PROMPT_LENGTH = 256

In [None]:
# Load the base model (4-bit weights, 8-bit loading disabled) together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    load_in_8bit = False,
    load_in_4bit = True,
    max_seq_length = MAX_SEQ_LEN,
    model_name = MODEL_NAME,
)

In [None]:
# Attach LoRA adapters to the q/k/v/o and gate/up/down projection modules.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    lora_alpha = 16,  # alpha == r
    lora_dropout = 0,
    bias = "none",
    target_modules = [f"{name}_proj" for name in ("q", "k", "v", "o", "gate", "up", "down")],
    use_rslora = False,
    loftq_config = None,
    use_gradient_checkpointing = "unsloth",
    max_seq_length = MAX_SEQ_LEN,
    random_state = 3407,
)

In [None]:
from datasets import load_dataset
# Load the training split of the SQL reasoning dataset.
dataset = load_dataset(
    path="proton98/sql-gpt-4.1-nano-distill-reasoning-data",
    split="train",
)

In [None]:
# Peek at the "generation" field of the first example after a seeded shuffle.
dataset.shuffle(seed=42)[0]["generation"]

In [None]:
def extract_answer(text: str) -> str | None:
    sql_match = re.search(r"<sql>(.*?)</sql>", text, re.DOTALL)
    if sql_match:
        return sql_match.group(1).strip()

In [None]:
def _split_top_level_commas(block: str) -> list[str]:
    """Split ``block`` on commas that are not nested inside parentheses."""
    parts: list[str] = []
    depth = 0
    start = 0
    for i, ch in enumerate(block):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth = max(depth - 1, 0)
        elif ch == "," and depth == 0:
            parts.append(block[start:i])
            start = i + 1
    parts.append(block[start:])
    return parts


def extract_schema(sql: str) -> str:
    """Condense ``CREATE TABLE`` statements into a compact schema listing.

    Args:
        sql: SQL text that may contain one or more ``CREATE TABLE ... (...);``
            statements (optionally with ``IF NOT EXISTS``).

    Returns:
        A string starting with ``"Tables:"``. Each table follows as a
        ``"- <name>"`` entry preceded by a blank line, with one
        ``"  - <column> (<TYPE>)"`` line per column. Table-level constraint
        clauses (PRIMARY KEY, FOREIGN KEY, UNIQUE, CHECK, CONSTRAINT) are omitted.
    """
    table_pattern = re.compile(
        r"CREATE TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)\s*\((.*?)\);",
        re.DOTALL | re.IGNORECASE,
    )
    # Column name followed by its type; a parenthesised size/precision such as
    # VARCHAR(255) or DECIMAL(10,2) is kept as part of the type.
    column_pattern = re.compile(r"\s*(\w+)\s+(\w+(?:\([^)]*\))?)")
    # Constraint clauses (e.g. "PRIMARY KEY (id)") also look like "<word> <word>"
    # to column_pattern; skip them so they are not reported as columns.
    constraint_keywords = {"PRIMARY", "FOREIGN", "CONSTRAINT", "UNIQUE", "CHECK"}

    schema_lines: list[str] = ["Tables:"]

    for match in table_pattern.finditer(sql):
        table_name, columns_block = match.groups()
        schema_lines.append(f"\n- {table_name}")
        # Split only on top-level commas so "DECIMAL(10,2)" stays one definition.
        for col_line in _split_top_level_commas(columns_block.strip()):
            col_match = column_pattern.match(col_line.strip())
            if not col_match:
                continue
            col_name, col_type = col_match.groups()
            if col_name.upper() in constraint_keywords:
                continue
            schema_lines.append(f"  - {col_name} ({col_type.upper()})")

    return "\n".join(schema_lines)

In [None]:
# Sanity-check extract_answer on the first generation of the same shuffled example.
extract_answer(dataset.shuffle(seed=42)[0]["generation"][0])

In [None]:
# Tags delimiting the model's reasoning and its final SQL answer.
REASONING_START, REASONING_END = "<think>", "</think>"
SOLUTION_START, SOLUTION_END = "<sql>", "</sql>"

# System prompt template. "{context}" is a literal placeholder that is filled
# per example via SYSTEM_PROMPT.format(context=...).
SYSTEM_PROMPT = "\n".join([
    "You are an expert in writing optimized SQL queries.",
    "Think about the problem and provide your working out.",
    f"Place it between {REASONING_START} and {REASONING_END}.",
    f"Then, provide your solution between {SOLUTION_START} and {SOLUTION_END}.",
    "",
    "Context:",
    "{context}",
])
SYSTEM_PROMPT

In [None]:
# Build GRPO training rows. Each row gets a chat-format "prompt" (a system
# message with the condensed schema, plus the user's question) and a reference
# "answer" taken from the <sql> block of the first distilled generation.
dataset = dataset.map(
    lambda x: {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT.format(context=extract_schema(x["sql_context"]))},
            {"role": "user", "content": x["sql_prompt"]},
        ],
        # Default to "" so an empty generation list yields answer=None instead
        # of raising StopIteration inside map.
        "answer": extract_answer(next(iter(x["generation"]), "")),
    },
)
# Drop rows without a reference SQL answer: the answer reward would have
# nothing to compare against for them.
dataset = dataset.filter(lambda x: x["answer"] is not None)

In [None]:
# Expected completion layout: reasoning between the REASONING_* tags, then SQL
# between the SOLUTION_* tags. The named groups "thinking" and "sql" expose the
# two parts. The pattern is unanchored, so callers using .search() also accept
# text before or after the structure. Tags are escaped so that changing them to
# strings with regex metacharacters (e.g. "[SQL]") cannot silently alter the pattern.
match_format = re.compile(
    rf"\s*{re.escape(REASONING_START)}(?P<thinking>.+?){re.escape(REASONING_END)}\s*"
    rf"{re.escape(SOLUTION_START)}(?P<sql>.+?){re.escape(SOLUTION_END)}\s*",
    flags=re.DOTALL,
)

In [None]:
# Demonstrate that a well-formed response matches the expected layout.
match_format.search(
    "\n"
    + "\n".join([
        REASONING_START,
        "Hello",
        REASONING_END,
        SOLUTION_START,
        "SELECT * FROM table WHERE condition;",
        SOLUTION_END,
    ])
    + "\n"
)

In [None]:
def match_format_exactly(completions: list[list[dict]], **kwargs) -> list[float]:
    """Reward completions that contain the full reasoning + solution structure.

    Args:
        completions: One entry per sampled completion. Each entry is an
            iterable of message dicts; the first message's ``"content"`` is scored.
        **kwargs: Extra columns passed by the trainer; ignored.

    Returns:
        3.0 for each completion where ``match_format`` finds the expected
        layout, 0.0 otherwise.
    """
    return [
        3.0 if match_format.search(next(iter(completion))["content"]) is not None else 0.0
        for completion in completions
    ]

In [None]:
def match_format_approximately(completions, **kwargs):
    """Give partial credit for each format tag used exactly once.

    For every completion, the first message's ``"content"`` is inspected. Each
    of the four tags (reasoning start/end, solution start/end) adds +0.5 when
    it occurs exactly once and -0.5 otherwise, so a score lies in [-2.0, 2.0].
    """
    markers = (REASONING_START, REASONING_END, SOLUTION_START, SOLUTION_END)
    rewards: list[float] = []
    for messages in completions:
        text: str = next(iter(messages))["content"]
        rewards.append(sum(0.5 if text.count(tag) == 1 else -0.5 for tag in markers))
    return rewards

In [None]:
def check_answer(prompts, completions, answer, **kwargs):
    """Score each completion's SQL against the reference answer.

    Args:
        prompts: Batch prompts; unused, but part of the reward-function signature.
        completions: One entry per completion; the first message's
            ``"content"`` holds the generated text.
        answer: Reference SQL per completion (presumably the dataset's
            ``"answer"`` column); entries may be ``None``.
        **kwargs: Other columns passed by the trainer; ignored.

    Returns:
        3.0 when the extracted SQL equals the reference exactly, 1.5 when they
        are equal after stripping surrounding whitespace, and 0.0 otherwise.
        The 0.0 case includes a missing format match and a missing reference.
    """
    responses = [next(iter(completion))["content"] for completion in completions]

    # Use the named "sql" group: group(1) is the "thinking" group, which holds
    # the reasoning text rather than the SQL answer.
    extracted_responses = [
        guess.group("sql") if (guess := match_format.search(r)) is not None else None
        for r in responses
    ]

    scores: list[float] = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None or true_answer is None:
            scores.append(0.0)
        elif guess == true_answer:
            scores.append(3.0)
        elif guess.strip() == true_answer.strip():
            scores.append(1.5)
        else:
            scores.append(0.0)
    return scores

In [None]:
from trl import GRPOConfig, GRPOTrainer

In [None]:
# GRPO hyper-parameters, grouped by concern.
# NOTE(review): upstream TRL requires the generation batch to be divisible by
# num_generations; confirm the installed Unsloth/TRL adjusts batch size 1 here.
training_args = GRPOConfig(
    # Output and logging
    output_dir = "outputs",
    report_to = "none", # Can use Weights & Biases
    logging_steps = 1,
    save_steps = 50,
    # Optimizer
    optim = "adamw_torch_fused",
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    max_grad_norm = 0.1,
    # Schedule; max_steps takes precedence over num_train_epochs when set.
    lr_scheduler_type = "cosine",
    warmup_ratio = 0.1,
    max_steps = 50,
    # num_train_epochs = 1, # Set to 1 for a full training run
    # Batching and sampling
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    # Token budgets: prompt + completion add up to MAX_SEQ_LEN.
    max_prompt_length = MAX_PROMPT_LENGTH,
    max_completion_length = MAX_SEQ_LEN - MAX_PROMPT_LENGTH,
)

In [None]:
# Wire the model, tokenizer, data and the three reward functions into GRPO.
trainer = GRPOTrainer(
    args = training_args,
    model = model,
    processing_class = tokenizer,
    train_dataset = dataset,
    reward_funcs = [
        match_format_exactly,       # +3.0 for the full reasoning/solution layout
        match_format_approximately, # +/-0.5 per tag appearing exactly once
        check_answer,               # up to +3.0 for matching the reference SQL
    ],
)

In [None]:
# Run the GRPO training loop; the returned object is shown as the cell output.
train_output = trainer.train()
train_output