In [1]:
from datasets import load_dataset, DatasetDict, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import GRPOConfig, GRPOTrainer
import duckdb
import re

  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 05-19 21:37:57 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-19 21:37:57 [__init__.py:239] Automatically detected platform cuda.


2025-05-19 21:37:58,976	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


### Load the dataset

In [2]:
# Repository id handed to `load_dataset`; later cells read the
# "sql_context", "sql_prompt" and "generation" fields of each row.
DATASET_REPO_ID = "proton98/sql-distill-llama-3-1-70b-instruct-reasoning"

# Only the "train" split is requested.
dataset_train = load_dataset(path=DATASET_REPO_ID, split="train")

### Prompt template

In [3]:
# Tags that delimit the reasoning section and the final SQL in a response.
REASONING_START, REASONING_END = "<think>", "</think>"
SOLUTION_START, SOLUTION_END = "<sql>", "</sql>"

# System prompt. The tags are baked in here; the literal "{context}"
# placeholder is left for a later `.format(context=...)` call.
REASONING_SYSTEM_PROMPT_TEMPLATE = "\n".join([
    "You are an expert in writing optimized SQL queries.",
    "Think about the problem and provide your working out.",
    f"Place it between {REASONING_START} and {REASONING_END}.",
    f"Then, provide your solution between {SOLUTION_START} and {SOLUTION_END}.",
    "",
    "Context:",
    "{context}",
]).rstrip()

### Process the dataset

In [4]:
# Non-greedy capture of everything between the opening and closing tag.
# re.DOTALL lets the captured text span several lines. The tags go through
# re.escape so they are always matched literally, even if they are later
# changed to something containing regex metacharacters.
sql_match_format = re.compile(
    rf"{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}",
    re.DOTALL,
)
reasoning_match_format = re.compile(
    rf"{re.escape(REASONING_START)}(.*?){re.escape(REASONING_END)}",
    re.DOTALL,
)

def extract_sql(text: str) -> str | None:
    """
    Extracts the SQL using regex from the generated text.
    :param text:
    :return:
    """
    sql_match = sql_match_format.search(text)
    if sql_match:
        return sql_match.group(1).strip()


def extract_think(text: str) -> str | None:
    """
    Extracts the think using regex from the generated text.
    :param text:
    :return:
    """
    think_match = reasoning_match_format.search(text)
    if think_match:
        return think_match.group(1).strip()


def _reasoning_format(content: str) -> str:
    """
    Rebuild a response in the canonical tag layout, one tag or section per line.

    The reasoning and SQL sections are extracted with ``extract_think`` and
    ``extract_sql``. When a section is missing its extractor returns ``None``,
    and the text ``None`` is written in its place.

    :param content: Raw generated text containing the tagged sections.
    :return: The six-line string: reasoning tag, reasoning, closing tag,
        solution tag, SQL, closing tag.
    """
    reasoning = extract_think(content)
    sql = extract_sql(content)
    layout = (
        REASONING_START,
        str(reasoning),
        REASONING_END,
        SOLUTION_START,
        str(sql),
        SOLUTION_END,
    )
    return "\n".join(layout)


def conversations_formatting(dataset: Dataset | DatasetDict) -> Dataset | DatasetDict:
    """
    Add the columns needed for GRPO training to every row of the dataset.

    Each row gains:

    * ``prompt``: a system message (the prompt template filled with the row's
      ``sql_context``) followed by a user message (the row's ``sql_prompt``).
    * ``questions``: the row's ``sql_prompt``.
    * ``contexts``: the row's ``sql_context``.
    * ``answers``: the SQL extracted from the row's ``generation``, or ``None``
      when no solution block is found.

    The prompt deliberately contains no assistant turn. The reference
    generation must not be part of the model's input, otherwise the expected
    answer is handed to the model it is supposed to be trained to produce.

    :param dataset: Dataset (or dataset dict) whose rows have ``sql_context``,
        ``sql_prompt`` and ``generation`` fields.
    :return: The result of ``dataset.map`` with the extra columns.
    """
    def _to_training_row(row: dict) -> dict:
        # Build the per-row columns; the prompt ends on the user turn.
        return {
            "prompt": [
                {"role": "system", "content": REASONING_SYSTEM_PROMPT_TEMPLATE.format(context=row["sql_context"])},
                {"role": "user", "content": row["sql_prompt"]},
            ],
            "questions": row["sql_prompt"],
            "contexts": row["sql_context"],
            "answers": extract_sql(row["generation"]),
        }

    return dataset.map(_to_training_row)


# Attach the "prompt", "questions", "contexts" and "answers" columns.
dataset_train = conversations_formatting(dataset=dataset_train)

### Load the model and tokenizer

In [5]:
MODEL_PATH_NAME = "unsloth/gemma-3-1b-it"
MAX_SEQ_LEN = 4096
# Single source of truth for the LoRA rank: it is declared when the base model
# is loaded and reused as the adapter rank below, so the two cannot drift apart.
MAX_LORA_RANK = 16

# Load the base model and its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_PATH_NAME,
    max_seq_length = MAX_SEQ_LEN,
    load_in_4bit = True,
    fast_inference = True,
    max_lora_rank = MAX_LORA_RANK,
    gpu_memory_utilization = 0.7,
)

# Wrap the base model with LoRA adapters.
model = FastLanguageModel.get_peft_model(
    model,
    r = MAX_LORA_RANK,
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,
)

==((====))==  Unsloth 2025.5.6: Fast Gemma3 patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA GeForce RTX 3060. Num GPUs = 1. Max memory: 12.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Making `model.base_model.model.model` require gradients


### Reward functions

In [None]:
# ANSI escape sequences used to colour the debug output of the SQL reward.
COLORED_GREEN = "\033[92m"
COLORED_BLUE = "\033[94m"
COLORED_RESET = "\033[0m"
BOLD = "\033[1m"

# Expected layout: each tag on its own line, reasoning block first, SQL block
# second. The tags go through re.escape so they are always matched literally.
# re.DOTALL lets `.*?` span several lines; re.MULTILINE lets `^` and `$`
# anchor at any line boundary, not only at the very start/end of the text.
response_match_format = re.compile(
    rf"^{re.escape(REASONING_START)}\n.*?\n{re.escape(REASONING_END)}\n"
    rf"{re.escape(SOLUTION_START)}\n.*?\n{re.escape(SOLUTION_END)}$",
    re.DOTALL | re.MULTILINE
)


def match_format_exactly(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: # noqa
    """
    Reward completions whose text matches ``response_match_format``.

    Only the ``"content"`` of the first message of each completion is checked.

    :param completions: One list of messages per generated completion.
    :param kwargs: Ignored; accepted so extra columns can be passed through.
    :return: One score per completion: 3.0 when the pattern is found, else 0.
    """
    return [
        3.0 if response_match_format.search(messages[0]["content"]) is not None else 0
        for messages in completions
    ]


def match_format_approximately(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: # noqa
    """
    Give partial credit for each format marker that appears exactly once.

    Four markers are checked in the first message of each completion: the
    reasoning start tag followed by a newline, the reasoning end tag and the
    solution start tag each wrapped in newlines, and the solution end tag
    preceded by a newline. A marker occurring exactly once adds 0.125; any
    other count (zero or several) subtracts 1.0.

    :param completions: One list of messages per generated completion.
    :param kwargs: Ignored; accepted so extra columns can be passed through.
    :return: One score per completion, between -4.0 and 0.5.
    """
    markers = (
        f"{REASONING_START}\n",
        f"\n{REASONING_END}\n",
        f"\n{SOLUTION_START}\n",
        f"\n{SOLUTION_END}",
    )
    scores: list[float] = []
    for messages in completions:
        text = messages[0]["content"]
        total = 0
        for marker in markers:
            total += 0.125 if text.count(marker) == 1 else -1.0
        scores.append(total)
    return scores


def validate_sql_query(sql_query: str, context: str) -> tuple[bool, str]:
    """
    Validate the SQL query by executing it against the provided context.

    A fresh in-memory DuckDB database is created, every statement of
    ``context`` is executed in order, then ``sql_query`` is executed and its
    result fetched. The connection is always closed before returning, so
    repeated calls do not accumulate open connections.

    NOTE(review): ``sql_query`` is executed as-is. When it comes from a model
    it is untrusted input; confirm that running it in-process is acceptable.

    Args:
        sql_query (str): The SQL query to validate.
        context (str): Setup statements separated by ``;`` (e.g., table
            schema, data). The split is purely textual, so a ``;`` inside a
            string literal also splits the statement.
    Returns:
        tuple[bool, str]: ``(True, <success message>)`` when everything ran,
        otherwise ``(False, <text of the raised exception>)``.
    """
    con = None
    try:
        con = duckdb.connect(database=':memory:')

        # Run the setup statements one by one, skipping empty fragments.
        for fragment in context.strip().split(';'):
            statement = fragment.strip()
            if statement:
                con.execute(statement)

        con.execute(sql_query)
        _ = con.fetchall()

        return True, "SQL query is valid and executed successfully."
    except Exception as e:
        # Deliberately broad: any failure just means "not valid" to the caller.
        return False, str(e)
    finally:
        # Release the connection on both the success and the failure path.
        if con is not None:
            con.close()


def _collapse_sql_whitespace(sql: str) -> str:
    """Return ``sql`` with every run of whitespace replaced by a single space."""
    return " ".join(sql.split())


def check_sql_reward(
    prompts: list[list[dict[str, str]]],
    completions: list[list[dict[str, str]]],
    questions: list[str],
    answers: list[str],
    contexts: list[str],
    **kwargs # noqa
) -> list[float]:
    """
    Score the SQL extracted from each completion against the reference answer.

    Per completion:

    * no SQL block could be extracted: 0.0
    * the SQL fails ``validate_sql_query`` for its context: -1.5
    * the SQL is valid and identical to the reference: 3.0
    * the SQL is valid and differs from the reference only in whitespace: 1.5
    * the SQL is valid but different, or there is no reference (``None``): 0.0

    The first sample of the batch is also printed for inspection.

    :param prompts: Prompts of the batch (unused).
    :param completions: One list of messages per generated completion; the
        ``"content"`` of the first message is scored.
    :param questions: Question text per sample (used for the debug print only).
    :param answers: Reference SQL per sample; an entry may be ``None``.
    :param contexts: Setup SQL per sample, passed to ``validate_sql_query``.
    :param kwargs: Ignored; accepted so extra columns can be passed through.
    :return: One score per completion.
    """
    responses = [completion[0]["content"] for completion in completions]
    pred_responses: list[str | None] = [extract_sql(r) for r in responses]

    # Debug print of the first sample; skipped when the batch is empty.
    if responses and questions and answers:
        print(
            '-'*20,
            f"{BOLD}Question:{COLORED_RESET} {questions[0]}{COLORED_RESET}",
            f"\n{BOLD}Response:{COLORED_RESET} \n{COLORED_GREEN}{responses[0]}{COLORED_RESET}",
            f"\n{BOLD}Real Answer:{COLORED_RESET} \n{COLORED_GREEN}{answers[0]}{COLORED_RESET}",
            f"\n{BOLD}Extracted:{COLORED_RESET} \n"
            f"{COLORED_BLUE}{pred_responses[0] if pred_responses[0] is not None else '-Invalid Format-'}{COLORED_RESET}",
        )

    scores: list[float] = []
    for pred, context, answer in zip(pred_responses, contexts, answers):
        # Nothing to score when the completion has no SQL block.
        if pred is None:
            scores.append(0.0)
            continue

        # Check if the SQL query is valid
        is_valid, _ = validate_sql_query(pred, context)
        if not is_valid:
            scores.append(-1.5)  # Penalty for incorrect SQL
            continue

        # A missing reference cannot be matched; valid SQL then earns nothing.
        if answer is None:
            scores.append(0.0)
        # Exact match: correct answer gets 3 points
        elif pred == answer:
            scores.append(3.0)
        # Same query up to whitespace (spaces, tabs, line breaks)
        elif _collapse_sql_whitespace(pred) == _collapse_sql_whitespace(answer):
            scores.append(1.5)
        else:
            scores.append(0.0)
    return scores


### Training

In [None]:
MAX_PROMPT_LEN = 2048
MAX_COMPLETION_LEN = 1024
# Total number of optimisation steps. The checkpoint interval below is tied to
# it so that a checkpoint is actually written within the run.
MAX_STEPS = 100

# Query the hardware capability once and derive both precision flags from it.
USE_BF16 = is_bfloat16_supported()

training_args = GRPOConfig(
    use_vllm = True,
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = USE_BF16,
    fp16 = not USE_BF16,
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 1,
    num_generations = 4,
    max_prompt_length = MAX_PROMPT_LEN,
    max_completion_length = MAX_COMPLETION_LEN,
    max_steps = MAX_STEPS,
    save_steps = MAX_STEPS,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)

In [None]:
# Build the GRPO trainer from the objects prepared in the earlier cells.
trainer = GRPOTrainer(
    args = training_args,
    model = model,
    processing_class = tokenizer,
    train_dataset = dataset_train,
    # Every function returns one score per completion.
    reward_funcs = [ # type: ignore
        match_format_exactly,
        match_format_approximately,
        check_sql_reward,
    ],
)

# Start training.
trainer.train()