In [1]:
from datasets import load_dataset, DatasetDict, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import GRPOConfig, GRPOTrainer
import duckdb
import re

  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 05-19 23:00:06 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-19 23:00:07 [__init__.py:239] Automatically detected platform cuda.


2025-05-19 23:00:08,527	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


### Load the dataset

In [2]:
# Repository id of the dataset passed to `load_dataset`.
DATASET_REPO_ID = "proton98/sql-distill-llama-3-1-70b-instruct-reasoning"
# Only the "train" split is loaded. Later cells read its `sql_context`,
# `sql_prompt` and `generation` columns.
train_dataset = load_dataset(DATASET_REPO_ID, split="train")

### Prompt template

In [3]:
# Tags that delimit the model's reasoning and its final SQL answer. They are
# reused by the extraction regexes and by the reward functions.
REASONING_START = "<think>"
REASONING_END = "</think>"
SOLUTION_START = "<sql>"
SOLUTION_END = "</sql>"

# System prompt. This is an f-string, so the tag constants above are
# interpolated immediately, while the doubled braces leave a literal
# `{context}` placeholder that is filled in later with `str.format`.
# NOTE(review): `{context}` is the only placeholder, so any other keyword
# passed to `.format(...)` (e.g. the question) is silently ignored.
SYSTEM_PROMPT_TEMPLATE = \
f"""You are an expert in writing optimized SQL queries.
Think about the problem and provide your working out.
Place it between {REASONING_START} and {REASONING_END}.
Then, provide your solution between {SOLUTION_START} and {SOLUTION_END}.

Respond in the following format:
{REASONING_START}
{REASONING_END}
{SOLUTION_START}
{SOLUTION_END}

Context:
{{context}}""".rstrip()

### Process the dataset

In [4]:
# Patterns capturing the text between the solution tags and between the
# reasoning tags. The tags go through re.escape so they are always matched
# literally, even if they are later changed to contain regex metacharacters
# (for the current tags the escaped and raw forms are identical).
# re.DOTALL lets `.` span newlines; the lazy `.*?` stops at the first closing tag.
sql_match_format = re.compile(
    rf"{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}",
    re.DOTALL,
)
reasoning_match_format = re.compile(
    rf"{re.escape(REASONING_START)}(.*?){re.escape(REASONING_END)}",
    re.DOTALL,
)

def extract_sql(text: str) -> str | None:
    """
    Extracts the SQL using regex from the generated text.
    :param text:
    :return:
    """
    sql_match = sql_match_format.search(text)
    if sql_match:
        return sql_match.group(1).strip()


def extract_think(text: str) -> str | None:
    """
    Extracts the think using regex from the generated text.
    :param text:
    :return:
    """
    think_match = reasoning_match_format.search(text)
    if think_match:
        return think_match.group(1).strip()


def _reasoning_format(content: str) -> str:
    """
    Re-render a generation in the canonical tag layout, one tag or section per
    line: reasoning start tag, reasoning, reasoning end tag, solution start
    tag, SQL, solution end tag.

    A section that cannot be extracted from ``content`` is rendered as an
    empty string, so the literal text "None" never ends up in the output.
    :param content: Raw generated text containing the tagged sections.
    :return: The normalised, newline-separated representation.
    """
    # extract_* return None when their tag pair is absent; fall back to "".
    reasoning = extract_think(content) or ""
    sql = extract_sql(content) or ""
    return "\n".join([
        REASONING_START,
        reasoning,
        REASONING_END,
        SOLUTION_START,
        sql,
        SOLUTION_END,
    ])


def conversations_formatting(dataset: Dataset | DatasetDict) -> Dataset | DatasetDict:
    """
    Add the columns used for GRPO training to every example.

    Each example gets:
      * ``prompt``: a chat made of the system prompt (with the example's
        ``sql_context`` filled in) followed by a user turn holding the
        natural-language question from ``sql_prompt``. The reference
        generation is intentionally not part of the prompt, so the policy
        has to produce the answer itself instead of reading it.
      * ``questions`` / ``contexts``: the raw ``sql_prompt`` / ``sql_context``.
      * ``answers``: the SQL extracted from the reference ``generation``
        (``None`` when the generation has no solution block).
    :param dataset: Dataset (or dict of splits) with ``sql_context``,
        ``sql_prompt`` and ``generation`` columns.
    :return: The mapped dataset with the extra columns.
    """
    def _format_example(example: dict) -> dict:
        # The template only has a {context} placeholder, so the question is
        # delivered as its own user message rather than via .format().
        system_prompt = SYSTEM_PROMPT_TEMPLATE.format(context=example["sql_context"]).strip()
        return {
            "prompt": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": example["sql_prompt"]},
            ],
            "questions": example["sql_prompt"],
            "contexts": example["sql_context"],
            "answers": extract_sql(example["generation"]),
        }

    return dataset.map(_format_example)


# Add the `prompt`, `questions`, `contexts` and `answers` columns to the training split.
train_dataset = conversations_formatting(train_dataset)

Map: 100%|██████████| 5000/5000 [00:00<00:00, 9004.38 examples/s]


### Load the model and tokenizer

In [6]:
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
MAX_SEQ_LEN = 4096  # passed as max_seq_length when loading the model
MAX_LORA_RANK = 16  # passed as max_lora_rank when loading; same value as LORA_RANK
LORA_RANK = 16  # used below for both the adapter rank `r` and `lora_alpha`

# Load the base model and tokenizer through Unsloth, with 4-bit loading and
# fast inference enabled (the load log shows vLLM being initialised).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_NAME,
    max_seq_length = MAX_SEQ_LEN,
    load_in_4bit = True,
    fast_inference = True,
    max_lora_rank = MAX_LORA_RANK,
    gpu_memory_utilization = 0.7,
)

# Attach LoRA adapters. Only the MLP projections (gate/up/down) are listed as
# targets, so the attention projections get no adapters; the Unsloth log
# confirms "0 QKV layers, 0 O layers and 28 MLP layers" were patched.
model = FastLanguageModel.get_peft_model(
    model,
    r = LORA_RANK,
    lora_alpha = LORA_RANK,
    target_modules = ["gate_proj", "up_proj", "down_proj",],
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,
)

==((====))==  Unsloth 2025.5.6: Fast Llama patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA GeForce RTX 3060. Num GPUs = 1. Max memory: 12.0 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit with actual GPU utilization = 63.86%
Unsloth: Your GPU has CUDA compute capability 8.6 with VRAM = 12.0 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 4096. Num Sequences = 192.
Unsloth: vLLM's KV Cache can use up to 5.24 GB. Also swap space = 0 GB.
INFO 05-19 23:01:34 [config.py:717] This model supports multiple tasks: {'classify', 'embed', 'score', 'generate', 'reward'}. Defaulting to 'generate'.
INFO 05-19 23:01:34 [

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.79s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.79s/it]

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.23it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.23it/s]


INFO 05-19 23:01:41 [punica_selector.py:18] Using PunicaWrapperGPU.





INFO 05-19 23:01:42 [gpu_model_runner.py:1347] Model loading took 2.2969 GiB and 5.153025 seconds
INFO 05-19 23:01:56 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/a11df5d034/rank_0_0 for vLLM's torch.compile
INFO 05-19 23:01:56 [backends.py:430] Dynamo bytecode transform time: 13.97 s
INFO 05-19 23:02:04 [backends.py:118] Directly load the compiled graph(s) for shape None from the cache, took 7.330 s
INFO 05-19 23:02:08 [monitor.py:33] torch.compile takes 13.97 s in total
INFO 05-19 23:02:09 [kv_cache_utils.py:634] GPU KV cache size: 31,472 tokens
INFO 05-19 23:02:09 [kv_cache_utils.py:637] Maximum concurrency for 4,096 tokens per request: 7.68x
INFO 05-19 23:02:56 [gpu_model_runner.py:1686] Graph capturing finished in 47 secs, took 1.19 GiB
INFO 05-19 23:02:56 [core.py:159] init engine (profile, create kv cache, warmup model) took 74.50 seconds
Unsloth: Just some info: will skip parsing ['k_norm', 'pre_feedforward_layernorm', 'q_norm', 'post_feedforwa

Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Unsloth 2025.5.6 patched 28 layers with 0 QKV layers, 0 O layers and 28 MLP layers.


### Reward functions

In [7]:
# ANSI escape sequences used to colour the debug print-out in check_sql_reward.
COLORED_GREEN = "\033[92m"
COLORED_BLUE = "\033[94m"
COLORED_RESET = "\033[0m"
BOLD = "\033[1m"

# Expected response layout: each tag on its own line, reasoning first, SQL second.
# The tags go through re.escape so they are always matched literally (for the
# current tags the escaped and raw forms are identical).
# Because of re.MULTILINE, `^` and `$` anchor at line boundaries rather than
# at the start/end of the whole string, so `search` also accepts a well-formed
# block that is preceded or followed by other lines.
response_match_format = re.compile(
    rf"^{re.escape(REASONING_START)}\n.*?\n{re.escape(REASONING_END)}\n"
    rf"{re.escape(SOLUTION_START)}\n.*?\n{re.escape(SOLUTION_END)}$",
    re.DOTALL | re.MULTILINE
)


def match_format_exactly(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: # noqa
    """
    Reward completions whose text contains the full expected tag layout.
    :param completions: One list of chat messages per completion; only the
        ``content`` of the first message is inspected.
    :param kwargs: Ignored.
    :return: One score per completion: 3.0 when ``response_match_format``
        matches, otherwise 0.
    """
    return [
        3.0 if response_match_format.search(candidate[0]["content"]) is not None else 0
        for candidate in completions
    ]


def match_format_approximately(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: # noqa
    """
    Give partial credit per tag marker found in each completion.
    :param completions: One list of chat messages per completion; only the
        ``content`` of the first message is inspected.
    :param kwargs: Ignored.
    :return: One score per completion. Each of the four newline-delimited
        markers adds 0.125 when it occurs exactly once and subtracts 1.0
        otherwise.
    """
    def _score(text: str) -> float:
        # Markers include the surrounding newlines the expected layout requires.
        markers = (
            f"{REASONING_START}\n",
            f"\n{REASONING_END}\n",
            f"\n{SOLUTION_START}\n",
            f"\n{SOLUTION_END}",
        )
        total = 0
        for marker in markers:
            if text.count(marker) == 1:
                total += 0.125
            else:
                total += -1.0
        return total

    return [_score(candidate[0]["content"]) for candidate in completions]


def validate_sql_query(sql_query: str, context: str) -> tuple[bool, str]:
    """
    Validate the SQL query by running it against the provided context.

    A fresh in-memory DuckDB database is created, every statement of
    ``context`` is executed to set it up, then ``sql_query`` is executed and
    its result fetched. The connection is always closed before returning.

    Args:
        sql_query (str): The SQL query to validate.
        context (str): ``;``-separated setup statements
            (e.g., table schema, data).
    Returns:
        tuple[bool, str]: ``(True, <success message>)`` when everything ran,
            otherwise ``(False, <text of the raised exception>)``.
    """
    con = None
    try:
        con = duckdb.connect(database=':memory:')

        # NOTE: the context is split naively on ';', so a semicolon inside a
        # string literal would cut a statement in two; the resulting error is
        # reported as an invalid query.
        for statement in context.strip().split(';'):
            statement = statement.strip()
            if statement:
                con.execute(statement)

        con.execute(sql_query)
        _ = con.fetchall()

        return True, "SQL query is valid and executed successfully."
    except Exception as e:
        # Any failure (setup or query) counts as "invalid"; the message is returned.
        return False, str(e)
    finally:
        # Release the connection on both the success and failure paths so that
        # repeated reward evaluations do not accumulate open databases.
        if con is not None:
            con.close()


def check_sql_reward(
    completions: list[list[dict[str, str]]],
    questions: list[str],
    answers: list[str],
    contexts: list[str],
    **kwargs # noqa
) -> list[float]:
    """
    Check the SQL reward for the given prompts and completions.
    :param completions:
    :param questions:
    :param answers:
    :param contexts:
    :param kwargs:
    :return:
    """
    responses = [completion[0]["content"] for completion in completions]
    pred_responses: list[str | None] = [extract_sql(r) for r in responses]

    print(
        '-'*20,
        f"{BOLD}Question:{COLORED_RESET} {questions[0]}{COLORED_RESET}",
        f"\n{BOLD}Response:{COLORED_RESET} \n{COLORED_GREEN}{responses[0]}{COLORED_RESET}",
        f"\n{BOLD}Real Answer:{COLORED_RESET} \n{COLORED_GREEN}{answers[0]}{COLORED_RESET}",
        f"\n{BOLD}Extracted:{COLORED_RESET} \n"
        f"{COLORED_BLUE}{pred_responses[0] if pred_responses[0] is not None else '-Invalid Format-'}{COLORED_RESET}",
    )

    scores: list[float] = []
    for pred, context, answer in zip(pred_responses, contexts, answers):
        score = 0
        if pred is None:
            scores.append(0)
            continue

        # Check if the SQL query is valid
        is_valid, _ = validate_sql_query(pred, context)

        # Check if the SQL query is similar to the true SQL
        # Correct answer gets 3 points
        if is_valid:
            # Exact match
            if pred == answer: score += 3.0
            # Match if spaces are seen
            elif pred.strip() == answer.strip(): score += 1.5
        else: score -= 1.5 # Penalty for incorrect SQL
        scores.append(score)
    return scores


### Training

In [8]:
# Token budgets: prompt + completion (2048 + 1024 = 3072) stays within the
# 4096-token MAX_SEQ_LEN the model was loaded with.
MAX_PROMPT_LEN = 2048
MAX_COMPLETION_LEN = 1024

# GRPO hyper-parameters.
training_args = GRPOConfig(
    use_vllm = True,
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    # Exactly one of bf16 / fp16 is enabled, depending on hardware support.
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 1,
    num_generations = 4,
    max_prompt_length = MAX_PROMPT_LEN,
    max_completion_length = MAX_COMPLETION_LEN,
    max_steps = 500,  # the training banner reports "Total steps = 500"
    save_steps = 150,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)

In [9]:
# Build the GRPO trainer with the three reward functions defined above. The
# training log reports each one in its own "rewards / <name>" column; in the
# logged rows the `reward` column matches their sum (e.g. step 1).
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [ # type: ignore
        match_format_exactly,
        match_format_approximately,
        check_sql_reward,
    ],
    args = training_args,
    train_dataset = train_dataset,
)
# Start training; runs for the number of steps configured in training_args.
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 5,000 | Num Epochs = 1 | Total steps = 500
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 15,138,816/3,000,000,000 (0.50% trained)


-------------------- [1mQuestion:[0m Delete health equity metrics for Arizona in 2019[0m 
[1mResponse:[0m 
[92m<think>
To delete health equity metrics for Arizona in 2019, I need to make sure that it doesn't delete any data that doesn't exist for the specified combination of state and year. This is because there might be a few states and years with missing data, and we should avoid deleting those.

However, since there is no guarantee that rows with NULL values will be deleted automatically by default, we should explicitly use the instruction to do so or use the ON CONFLICT (or WHERE) clause to specify exactly which rows to delete.

Additionally, to ensure that we're deleting all rows where state = 'AZ' and year = 2019, I'll be using the WHERE clause along with a WHERE NOT EXISTS or WHERE IN combination.

</think>
<sql>
DELETE 
FROM health_equity_metrics 
WHERE state = 'AZ' AND year = 2019;
</sql>[0m 
[1mReal Answer:[0m 
[92mDELETE FROM health_equity_metrics WHERE state = 'AZ

Step,Training Loss,reward,reward_std,completion_length,kl,rewards / match_format_exactly,rewards / match_format_approximately,rewards / check_sql_reward
1,0.0,0.03125,1.726624,257.75,0.0,0.375,-0.34375,0.0
2,-0.0,-0.15625,0.9375,217.125,0.0,0.0,0.21875,-0.375
3,0.0,-0.859375,0.761541,228.625,0.000552,0.375,-0.484375,-0.75
4,0.0,0.265625,1.710652,173.75,0.000643,0.75,-0.484375,0.0
5,0.0,-0.34375,1.264521,253.875,0.000498,0.75,-0.34375,-0.75
6,0.0,-0.390625,1.744051,235.375,0.000537,0.75,-0.765625,-0.375
7,0.0,-0.4375,0.917136,231.0,0.000535,0.375,-0.0625,-0.75
8,0.0,-0.484375,1.358356,203.875,0.000641,0.0,-0.484375,0.0
9,0.0,-0.484375,1.739139,254.0,0.000496,0.375,-0.484375,-0.375
10,0.0,-1.046875,2.226107,224.125,0.0006,0.375,-1.046875,-0.375


-------------------- [1mQuestion:[0m List all gas wells in the South Pars/North Dome field and their production figures[0m 
[1mResponse:[0m 
[92m<think>
To solve this problem, we need to filter the gas wells to those located in the South Pars/North Dome field and retrieve their gas production figures.
We can achieve this using a simple SELECT statement that joins the gas_wells table with a WHERE clause to filter the wells.

We should consider using INNER JOIN to only retrieve the wells in the desired field, or LEFT JOIN if we also want to include wells with no production figure. However, in this case, since we're interested in wells with production figures, INNER JOIN seems more suitable.

Additionally, we might want to specify the field explicitly to make the query more explicit, but since the wells in the same field have the same location, we could use the WHERE clause to filter for that.

We could also consider using a more precise query to filter by multiple conditions, but i

KeyboardInterrupt: 