# 09. LoRA for Linear Regression and GRPO on Verifiable Sorting

This tutorial extends the previous notebooks by combining **Low-Rank Adapters (LoRA)** with a lightweight reinforcement learning pipeline on verifiable tasks. We begin with a concrete linear regression example that highlights the memory advantages of LoRA on a single layer. We then move to reinforcement learning with verifiable rewards (RLVR): using the [PEFT](https://huggingface.co/docs/peft/index) library, we adapt a Qwen 3 0.6B model with **Group Relative Policy Optimization (GRPO)** to sort lists of integers, explicitly modelling `<think>` reasoning tokens and structured `<output>` answers.


## Roadmap

1. Refresh the intuition for LoRA and quantify how low-rank adapters reduce the memory footprint of a single linear layer.
2. Implement the adapter for a synthetic multi-target linear regression problem and compare full fine-tuning vs. LoRA.
3. Build a verifiable sorting reward, warm-start the policy with a small cold-start dataset of `<think>/<output>` exemplars, run a short supervised fine-tuning (SFT) stage, and drive a GRPO loop with PEFT to adapt Qwen 3 0.6B.


> 💡 **Dependencies**
>
> If you are running in a clean environment you may need to install a few extra packages such as `transformers`, `datasets`, `peft`, `trl`, and `accelerate`, plus `pandas` and `plotly` for the tables and plots. The cell below can be uncommented when necessary.


In [None]:
# %pip install -q torch transformers datasets accelerate peft trl evaluate pandas plotly


## 1. Revisiting LoRA on a Single Linear Layer

LoRA decomposes the weight update of a frozen matrix $W \in \mathbb{R}^{m \times d}$ into a product $BA$, where $A \in \mathbb{R}^{r \times d}$ and $B \in \mathbb{R}^{m \times r}$. Instead of storing gradients and optimizer states for the full $m \times d$ matrix, we only update the low-rank factors, which together hold $r(m + d)$ parameters. The effective weight during adaptation is

$$W_{\text{eff}} = W + \frac{\alpha}{r} BA,$$

where $\alpha$ rescales the update. For wide layers (large $m$ and $d$) and small rank $r$, this reduces the number of trainable parameters, and the accompanying optimizer state, by one to two orders of magnitude or more.
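To make the savings concrete, the sketch below counts trainable parameters for the 512 → 256 layer used in the next subsection with $r = 8$, and checks that running the low-rank path separately gives the same output as using the merged weight $W_{\text{eff}}$. The helper `lora_param_counts` is our own illustration, not a library function.

```python
from typing import Tuple

import torch

def lora_param_counts(m: int, d: int, r: int) -> Tuple[int, int]:
    """Trainable parameters for an m x d weight: (full fine-tuning, LoRA rank r)."""
    full = m * d            # every entry of W is trainable
    lora = r * d + m * r    # A is r x d, B is m x r
    return full, lora

m, d, r, alpha = 256, 512, 8, 8.0
full, lora = lora_param_counts(m, d, r)
print(full, lora, round(full / lora, 1))  # 131072 6144 21.3

# Applying the low-rank path separately matches the merged weight W + (alpha / r) BA.
torch.manual_seed(0)
W = torch.randn(m, d, dtype=torch.float64)
A = torch.randn(r, d, dtype=torch.float64)
B = torch.randn(m, r, dtype=torch.float64)
x = torch.randn(4, d, dtype=torch.float64)
scaling = alpha / r
y_merged = x @ (W + scaling * (B @ A)).T
y_split = x @ W.T + scaling * (x @ A.T) @ B.T
print(torch.allclose(y_merged, y_split))  # True
```

For this layer the saving is about 21×; it grows with width, since the full count scales with $md$ while LoRA scales with $r(m + d)$.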


### 1.1 Synthetic regression setup

We construct a multi-target linear regression task with a 512 → 256 linear layer. The ground-truth weight matrix is the sum of a frozen base matrix and a low-rank update, mirroring the scenario where LoRA is expected to shine.
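Because the update is built as `B @ A` with `A` of shape r × 512 and `B` of shape 256 × r, its rank is at most r. A quick numerical sanity check on freshly sampled factors (same shapes as the cell below, but an independent seed and float64 for a clean rank estimate):

```python
import torch

torch.manual_seed(1)
in_features, out_features, rank = 512, 256, 8
A = torch.randn(rank, in_features, dtype=torch.float64)
B = torch.randn(out_features, rank, dtype=torch.float64)
delta = B @ A  # 256 x 512 matrix, but built from rank-8 factors

delta_rank = torch.linalg.matrix_rank(delta).item()
dense_rank = torch.linalg.matrix_rank(torch.randn(out_features, in_features, dtype=torch.float64)).item()
print(delta_rank, dense_rank)  # 8 for the low-rank update, 256 for a dense Gaussian matrix
```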


In [None]:
import math
import random
from dataclasses import dataclass
from contextlib import nullcontext
from typing import Dict, Iterable, List, Sequence, Tuple

import pandas as pd
import plotly.graph_objects as go
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
in_features = 512
out_features = 256
lora_rank = 8
num_samples = 4096
batch_size = 256
noise_std = 0.05

# Construct a frozen base weight and a low-rank update that represents the target task.
base_weight = torch.randn(out_features, in_features)
adapter_A_true = torch.randn(lora_rank, in_features)
adapter_B_true = torch.randn(out_features, lora_rank)
delta_weight = adapter_B_true @ adapter_A_true
target_weight = base_weight + delta_weight

features = torch.randn(num_samples, in_features)
targets = features @ target_weight.T + noise_std * torch.randn(num_samples, out_features)

dataset = TensorDataset(features, targets)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


### 1.2 Implementing a LoRA-augmented linear layer

The class below mirrors the adapter structure used in larger language models. Only the low-rank matrices `A` and `B` are trainable; the base weight stays frozen.
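Two properties make this structure convenient, sketched here with raw tensors rather than the `LoRALinear` class itself: because `B` starts at zero, the adapted layer reproduces the frozen layer exactly at step 0, and after training the update can be folded into a single dense weight so inference costs nothing extra. The merge step is our own illustration, similar in spirit to how adapters are merged in PEFT.

```python
import math

import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
out_f, in_f, r, alpha = 16, 32, 4, 4.0
W = torch.randn(out_f, in_f)                 # frozen base weight
A = torch.empty(r, in_f)
nn.init.kaiming_uniform_(A, a=math.sqrt(5))  # same init as LoRALinear.A
B = torch.zeros(out_f, r)                    # zero init, as in LoRALinear.B
scaling = alpha / r
x = torch.randn(3, in_f)

# At initialisation the adapter contributes nothing because B @ A == 0.
y0 = F.linear(x, W + scaling * (B @ A))
print(torch.equal(y0, F.linear(x, W)))  # True

# Pretend training moved B, then merge the update into one weight for inference.
B = torch.randn(out_f, r)
W_merged = W + scaling * (B @ A)
y_adapter = F.linear(x, W) + scaling * F.linear(F.linear(x, A), B)
print(torch.allclose(F.linear(x, W_merged), y_adapter, atol=1e-5))  # True
```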


In [None]:
class LoRALinear(nn.Module):
    def __init__(self, base_weight: torch.Tensor, rank: int, alpha: float = 1.0, bias: bool = False):
        super().__init__()
        out_features, in_features = base_weight.shape
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha
        # Frozen base weight
        self.weight = nn.Parameter(base_weight.clone())
        self.weight.requires_grad = False
        # Trainable low-rank factors
        self.A = nn.Parameter(torch.zeros(rank, in_features))
        self.B = nn.Parameter(torch.zeros(out_features, rank))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.zeros_(self.B)
        self.scaling = alpha / max(rank, 1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    def effective_weight(self) -> torch.Tensor:
        return self.weight + (self.B @ self.A) * self.scaling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.effective_weight(), self.bias)


In [None]:
def count_trainable_parameters(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

def train_linear_module(module: nn.Module, loader: DataLoader, steps: int, lr: float) -> List[float]:
    module.to(DEVICE)
    module.train()
    optimizer = torch.optim.Adam([p for p in module.parameters() if p.requires_grad], lr=lr)
    history: List[float] = []
    iterator = iter(loader)
    for step in range(steps):
        try:
            batch = next(iterator)
        except StopIteration:
            iterator = iter(loader)
            batch = next(iterator)
        x, y = (tensor.to(DEVICE) for tensor in batch)
        optimizer.zero_grad()
        preds = module(x)
        loss = F.mse_loss(preds, y)
        loss.backward()
        optimizer.step()
        history.append(loss.item())
    return history

def evaluate_mse(module: nn.Module, features: torch.Tensor, targets: torch.Tensor) -> float:
    module.eval()
    with torch.no_grad():
        preds = module(features.to(DEVICE))
        loss = F.mse_loss(preds.cpu(), targets)
    return float(loss)

def relative_weight_error(module: nn.Module, target: torch.Tensor) -> float:
    if isinstance(module, LoRALinear):
        weight = module.effective_weight().detach()
    else:
        weight = module.weight.detach()
    # Move to the target's device (CPU) so the comparison also works after GPU training.
    weight = weight.to(target.device)
    return float(torch.norm(weight - target) / torch.norm(target))


In [None]:
full_linear = nn.Linear(in_features, out_features, bias=False)
full_linear.weight.data.copy_(base_weight.clone())

lora_linear = LoRALinear(base_weight=base_weight, rank=lora_rank, alpha=lora_rank)

full_history = train_linear_module(full_linear, loader, steps=200, lr=1e-3)
lora_history = train_linear_module(lora_linear, loader, steps=200, lr=5e-3)

full_mse = evaluate_mse(full_linear, features, targets)
lora_mse = evaluate_mse(lora_linear, features, targets)

full_error = relative_weight_error(full_linear, target_weight)
lora_error = relative_weight_error(lora_linear, target_weight)


In [None]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=full_history, name="Full fine-tuning"))
fig.add_trace(go.Scatter(y=lora_history, name="LoRA (rank=8)"))
fig.update_layout(title="Training loss comparison", xaxis_title="Step", yaxis_title="MSE loss")
fig.show()

def format_mb(params: int, dtype=torch.float32) -> float:
    bytes_per_param = torch.finfo(dtype).bits // 8
    return params * bytes_per_param / (1024 ** 2)

results_table = pd.DataFrame([
    {
        "Model": "Full fine-tuning",
        "Trainable params": count_trainable_parameters(full_linear),
        "Approx optimizer state (MB)": format_mb(count_trainable_parameters(full_linear) * 2),
        "Final MSE": full_mse,
        "Relative weight error": full_error,
    },
    {
        "Model": "LoRA (rank=8)",
        "Trainable params": count_trainable_parameters(lora_linear),
        "Approx optimizer state (MB)": format_mb(count_trainable_parameters(lora_linear) * 2),
        "Final MSE": lora_mse,
        "Relative weight error": lora_error,
    },
])
results_table


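LoRA updates only 6,144 parameters here, versus 131,072 for the full layer. Because the target update is exactly rank 8, a rank-8 adapter can in principle fit the data as well as full fine-tuning; compare the "Final MSE" and "Relative weight error" columns to see how close it gets within 200 steps. The optimizer state shrinks in the same proportion, which is critical when the base layer contains millions of parameters.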
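To see why this matters at scale, the sketch below applies the same back-of-the-envelope estimate as `format_mb` (two fp32 Adam moment buffers per trainable parameter, ignoring weights and gradients) to a hypothetical 4096 × 4096 projection, a size chosen only for illustration.

```python
def adam_state_mb(params: int, bytes_per_value: int = 4, states: int = 2) -> float:
    """Approximate Adam optimizer-state size: `states` fp32 buffers per parameter."""
    return params * states * bytes_per_value / 1024**2

m = d = 4096                 # hypothetical wide layer
r = 8
full_params = m * d          # 16,777,216
lora_params = r * (m + d)    # 65,536

print(f"full: {adam_state_mb(full_params):.1f} MB")  # full: 128.0 MB
print(f"LoRA: {adam_state_mb(lora_params):.1f} MB")  # LoRA: 0.5 MB
```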


## 2. RLVR with GRPO on a Sorting Task

We now move from a single layer to a causal language model. The goal is to sort a list of integers — a domain where the reward can be **verified automatically**. We will:

* Load a small cold-start dataset of prompts and structured `<think>/<output>` answers.
* Run a lightweight supervised warm-start so the LoRA adapter learns to emit the reasoning and output tags.
* Apply a LoRA adapter to Qwen 3 0.6B with the [PEFT](https://huggingface.co/docs/peft/index) library.
* Implement a GRPO-style policy gradient loop that samples multiple completions per prompt and shapes the reward with verifiable checks.


### 2.1 Cold-start data and prompt construction

The helper below loads a JSONL file of `prompt`/`response` pairs (easy to host on the Hugging Face Hub as well) and augments it with randomly generated integer lists so that the policy has a warm start before RL. Every prompt enforces the same structure: think inside `<think>...</think>`, answer inside `<output>...</output>`, and end with a strict "No tools." reminder.
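The loader only assumes that each JSONL line is an object with `prompt` and `response` fields; it derives `numbers` and `rationale` itself. A minimal standard-library sketch of writing and reading such a file. The two records and the temporary path are made up for illustration and are not the contents of `data/sorting_cold_start.jsonl`.

```python
import json
import re
import tempfile
from pathlib import Path

INSTRUCTIONS = (
    "First think between <think> and </think> tags and then provide a response "
    "as a sorted list and nothing else in <output> and </output> tags. No tools."
)

# Hypothetical cold-start records in the <think>/<output> format.
records = [
    {"prompt": f"Sort the numbers [3, -1, 2].\n{INSTRUCTIONS}",
     "response": "<think>Ascending order is [-1, 2, 3].</think><output>[-1, 2, 3]</output>"},
    {"prompt": f"Sort the numbers [10, 0, 7, 7].\n{INSTRUCTIONS}",
     "response": "<think>Ascending order is [0, 7, 7, 10].</think><output>[0, 7, 7, 10]</output>"},
]

path = Path(tempfile.mkdtemp()) / "sorting_cold_start.jsonl"
with path.open("w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")  # one JSON object per line

# Read it back and recover the integers with the same regex as add_metadata.
loaded = [json.loads(line) for line in path.read_text().splitlines()]
numbers = [[int(x) for x in re.findall(r"-?\d+", rec["prompt"])] for rec in loaded]
print(numbers)  # [[3, -1, 2], [10, 0, 7, 7]]
```

The regex approach only works because the instruction text contains no digits; any digit added to the instructions would leak into the extracted target list.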


In [None]:
import statistics
import re
from datasets import Dataset, concatenate_datasets, load_dataset

STRUCTURED_INSTRUCTIONS = (
    "First think between <think> and </think> tags and then provide a response as a sorted list and nothing else in <output> and </output> tags. No tools."
)

def render_numbers(numbers):
    return ', '.join(str(n) for n in numbers)

def build_prompt(numbers):
    return f"Sort the numbers [{render_numbers(numbers)}].\n{STRUCTURED_INSTRUCTIONS}"

def build_response(numbers):
    sorted_numbers = sorted(numbers)
    reasoning = (
        f"I sort the {len(numbers)} numbers [{render_numbers(numbers)}] into ascending order [{render_numbers(sorted_numbers)}]."
    )
    return f"<think>{reasoning}</think><output>[{render_numbers(sorted_numbers)}]</output>"

def generate_synthetic_sorting(num_examples: int = 128, seed: int = 0) -> Dataset:
    rng = random.Random(seed)
    samples = []
    for _ in range(num_examples):
        length = rng.randint(3, 7)
        numbers = [rng.randint(-20, 30) for _ in range(length)]
        prompt = build_prompt(numbers)
        response = build_response(numbers)
        samples.append(
            {
                "prompt": prompt,
                "response": response,
                "numbers": numbers,
                "rationale": response.split('<output>')[0].replace('<think>', '').replace('</think>', '').strip(),
            }
        )
    return Dataset.from_list(samples)

def load_sorting_cold_start(local_path: str = 'data/sorting_cold_start.jsonl') -> Dataset:
    dataset = load_dataset('json', data_files=local_path, split='train')

    def add_metadata(example):
        numbers = [int(x) for x in re.findall(r'-?\d+', example['prompt'])]
        return {
            "numbers": numbers,
            "rationale": f"Sorting yields [{render_numbers(sorted(numbers))}].",
        }

    dataset = dataset.map(add_metadata)
    return dataset

cold_start_ds = load_sorting_cold_start()
synthetic_ds = generate_synthetic_sorting(192, seed=42)
training_ds = concatenate_datasets([cold_start_ds, synthetic_ds]).shuffle(seed=0)
len(training_ds), training_ds[0]


### 2.2 Supervised warm-start with structured reasoning tokens

Before optimising with RL we align the policy to the desired format with a brief supervised fine-tuning pass on the combined cold-start and synthetic data. We mask the prompt tokens (label `-100`) so that only the completion, the `<think>` rationale plus the `<output>` answer, contributes to the loss, which teaches the LoRA adapter to emit the control tokens reliably. This cell uses `tokenizer` and `policy_model`, which are created in Section 2.4, so run that section's cell first.
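Why `-100`: PyTorch's cross-entropy, and therefore the Hugging Face causal-LM loss, skips targets equal to `ignore_index`, which defaults to -100. The toy sketch below uses made-up token ids and random logits to check that masking the prompt positions makes the loss depend only on the completion positions.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
vocab = 10
input_ids = torch.tensor([5, 6, 7, 1, 2, 3])  # 3 prompt + 3 completion tokens (made-up ids)
prompt_len = 3
labels = input_ids.clone()
labels[:prompt_len] = -100                    # mask the prompt

logits = torch.randn(len(input_ids), vocab)
# Causal-LM shift: the logits at position t score the token at t + 1.
shift_logits, shift_labels = logits[:-1], labels[1:]
loss = F.cross_entropy(shift_logits, shift_labels)  # ignore_index=-100 by default

# Same value computed only over positions whose target is a completion token.
keep = shift_labels != -100
manual = F.cross_entropy(shift_logits[keep], shift_labels[keep])
print(torch.allclose(loss, manual))  # True
```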


In [None]:
from transformers import DataCollatorForSeq2Seq, Trainer, TrainingArguments

# Uses `tokenizer` and `policy_model` from Section 2.4.
SFT_MAX_LENGTH = 256

def prepare_sft_example(example):
    prompt = example["prompt"].strip()
    response = example["response"].strip()
    # Append EOS so the policy also learns to stop after </output>.
    text = f"{prompt}\n{response}{tokenizer.eos_token}"
    tokenized = tokenizer(text, truncation=True, max_length=SFT_MAX_LENGTH)
    prompt_ids = tokenizer(prompt, add_special_tokens=False, truncation=True, max_length=SFT_MAX_LENGTH)["input_ids"]
    labels = list(tokenized["input_ids"])
    # Mask the prompt (-100 is ignored by the loss) so only the completion is learned.
    prompt_len = min(len(prompt_ids), len(labels))
    labels[:prompt_len] = [-100] * prompt_len
    tokenized["labels"] = labels
    return tokenized

sft_dataset = training_ds.map(
    prepare_sft_example,
    remove_columns=training_ds.column_names,
)

# DataCollatorForLanguageModeling(mlm=False) would overwrite these labels with a copy of
# input_ids; DataCollatorForSeq2Seq keeps them and pads them with -100.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=-100)

sft_args = TrainingArguments(
    output_dir="logs/qwen3_sorting_sft",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    max_steps=60,
    logging_steps=10,
    fp16=torch.cuda.is_available(),
    report_to="none",
)

sft_trainer = Trainer(
    model=policy_model,
    args=sft_args,
    train_dataset=sft_dataset,
    data_collator=data_collator,
)

sft_trainer.train()
policy_model.to(DEVICE)
policy_model.eval()


### 2.3 Reward shaping with verifiable checks

Sorting is verifiable because we can deterministically extract integers from the prompt and the model output. The reward below blends several signals:

* **Exact match** — full credit when the completion equals the sorted list.
* **Monotonicity** — partial credit if the answer is sorted but numbers differ.
* **Prefix accuracy** — rewards early correct numbers to stabilise learning.
* **Coverage** — encourages the model to reuse the original numbers.
* **Format compliance** — bonus for emitting both `<think>` and `<output>` blocks.
* **Length penalty** — discourages hallucinating or dropping numbers.

The function also returns diagnostic components so we can reason about learning progress.
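A worked example of how the weights combine, using a condensed re-derivation of the arithmetic in `sorting_reward`. The helper `toy_reward` is hypothetical: it takes already-parsed lists, so it skips the tag extraction and the -1.0 early returns. The prompt numbers are [3, 1, 2, 5].

```python
from collections import Counter

def toy_reward(target, predicted, has_think=True):
    """Same weighted sum as sorting_reward, applied to already-parsed integer lists."""
    target_sorted = sorted(target)
    exact = float(predicted == target_sorted)
    monotonic = float(predicted == sorted(predicted))
    prefix = 0
    for t, p in zip(target_sorted, predicted):
        if t != p:
            break
        prefix += 1
    prefix /= max(len(target_sorted), 1)
    coverage = sum((Counter(target_sorted) & Counter(predicted)).values()) / max(len(target_sorted), 1)
    length_penalty = -0.05 * abs(len(predicted) - len(target_sorted))
    reward = (0.55 * exact + 0.2 * monotonic + 0.1 * prefix + 0.1 * coverage
              + 0.05 * float(has_think) + length_penalty)
    return max(-1.0, min(reward, 1.0))

print(toy_reward([3, 1, 2, 5], [1, 2, 3, 5]))  # about 1.0: exact answer
print(toy_reward([3, 1, 2, 5], [1, 2, 3]))     # about 0.35: sorted, but drops the 5
print(toy_reward([3, 1, 2, 5], [5, 3, 2, 1]))  # about 0.15: right numbers, wrong order
```

The middle case breaks down as 0.2 (monotonic) + 0.075 (prefix 3/4) + 0.075 (coverage 3/4) + 0.05 (format) - 0.05 (one missing number).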


In [None]:
from collections import Counter

THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.IGNORECASE | re.DOTALL)
OUTPUT_PATTERN = re.compile(r"<output>(.*?)</output>", re.IGNORECASE | re.DOTALL)

def extract_numbers(text: str) -> List[int]:
    return [int(x) for x in re.findall(r'-?\d+', text)]

def get_section(text: str, pattern: re.Pattern) -> str:
    match = pattern.search(text)
    return match.group(1).strip() if match else ""

def sorting_reward(prompt: str, completion: str) -> Tuple[float, Dict[str, float]]:
    target_numbers = extract_numbers(prompt)
    target_sorted = sorted(target_numbers)
    output_section = get_section(completion, OUTPUT_PATTERN)
    if not output_section:
        return -1.0, {"exact": 0.0, "monotonic": 0.0, "prefix": 0.0, "coverage": 0.0, "format": 0.0}

    predicted_numbers = extract_numbers(output_section)
    if not predicted_numbers:
        return -1.0, {"exact": 0.0, "monotonic": 0.0, "prefix": 0.0, "coverage": 0.0, "format": 0.0}

    think_section = get_section(completion, THINK_PATTERN)
    format_score = 1.0 if think_section else 0.0

    length_penalty = -0.05 * abs(len(predicted_numbers) - len(target_sorted))

    target_counter = Counter(target_sorted)
    predicted_counter = Counter(predicted_numbers)
    coverage = sum((target_counter & predicted_counter).values()) / max(len(target_sorted), 1)

    monotonic = 1.0 if predicted_numbers == sorted(predicted_numbers) else 0.0

    prefix = 0.0
    for t, p in zip(target_sorted, predicted_numbers):
        if t == p:
            prefix += 1
        else:
            break
    prefix = prefix / max(len(target_sorted), 1)

    exact = 1.0 if predicted_numbers == target_sorted else 0.0

    reward = (
        0.55 * exact
        + 0.2 * monotonic
        + 0.1 * prefix
        + 0.1 * coverage
        + 0.05 * format_score
        + length_penalty
    )
    reward = float(max(-1.0, min(reward, 1.0)))
    return reward, {
        "exact": exact,
        "monotonic": monotonic,
        "prefix": prefix,
        "coverage": coverage,
        "format": format_score,
    }

example_prompt = training_ds[0]['prompt']
example_response = training_ds[0]['response']
sorting_reward(example_prompt, example_response)


### 2.4 Qwen 3 0.6B with PEFT LoRA

We attach a LoRA adapter to Qwen 3 0.6B so that only low-rank updates to the attention projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and the MLP projections (`gate_proj`, `up_proj`, `down_proj`) are trained while the base weights stay frozen. The tokenizer is augmented with `<think>`/`<output>` tags so the adapter can model the reasoning format, and a frozen copy of the base model serves as the reference for the KL penalty in the GRPO loss.
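`print_trainable_parameters()` reports the real count. For intuition, each targeted linear layer of shape out × in adds r · (in + out) trainable parameters. The sketch below applies that rule to a toy block whose module names mimic the `target_modules` list; the layer sizes are arbitrary and are not Qwen 3's.

```python
from torch import nn

TARGETS = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}

class ToyBlock(nn.Module):
    """Arbitrary sizes; only the module names matter for matching."""
    def __init__(self, hidden: int = 512, inter: int = 1536):
        super().__init__()
        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.k_proj = nn.Linear(hidden, hidden, bias=False)
        self.v_proj = nn.Linear(hidden, hidden, bias=False)
        self.o_proj = nn.Linear(hidden, hidden, bias=False)
        self.gate_proj = nn.Linear(hidden, inter, bias=False)
        self.up_proj = nn.Linear(hidden, inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)
        self.norm = nn.LayerNorm(hidden)  # not targeted, stays frozen

def lora_trainable(model: nn.Module, targets: set, r: int) -> int:
    """Sum r * (in + out) over every nn.Linear whose leaf name is in `targets`."""
    total = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name.split(".")[-1] in targets:
            total += r * (module.in_features + module.out_features)
    return total

model = ToyBlock()
base = sum(p.numel() for p in model.parameters())
added = lora_trainable(model, TARGETS, r=16)
print(added, base, f"{100 * added / base:.1f}%")
```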


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base_model_name = 'Qwen/Qwen3-0.6B'
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
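# Register the tags as regular (non-special) tokens: decoding later uses
# skip_special_tokens=True, which would strip special tokens and make every
# completion fail the <output> check in sorting_reward.
tokenizer.add_tokens(['<think>', '</think>', '<output>', '</output>'])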
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'

policy_model = AutoModelForCausalLM.from_pretrained(base_model_name, trust_remote_code=True)
policy_model.resize_token_embeddings(len(tokenizer))

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
    lora_dropout=0.05,
    bias='none',
    task_type='CAUSAL_LM',
)
policy_model = get_peft_model(policy_model, lora_config)
policy_model.print_trainable_parameters()

reference_model = AutoModelForCausalLM.from_pretrained(base_model_name, trust_remote_code=True)
reference_model.resize_token_embeddings(len(tokenizer))
reference_model.to(DEVICE)
reference_model.eval()
for param in reference_model.parameters():
    param.requires_grad = False


### 2.5 Tokenisation utilities and log-probabilities

GRPO needs log-probabilities for each sampled completion under both the policy and the frozen reference model. The helpers below concatenate prompts with completions, build an attention mask for padding plus a completion mask that isolates the generated tokens, and return per-sequence summed log-probs together with the number of completion tokens.
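The one subtle step is the shift: the logits at position t score token t + 1, so the completion mask is sliced with `[:, 1:]` to line up with the labels. A toy check with random logits and made-up token ids:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
vocab = 8
# One sequence: 2 prompt tokens, 3 completion tokens (incl. EOS), 1 pad (made-up ids).
input_ids = torch.tensor([[4, 5, 1, 2, 7, 0]])
completion_mask = torch.tensor([[0, 0, 1, 1, 1, 0]], dtype=torch.float)
logits = torch.randn(1, 6, vocab)

# Same shift as sequence_logprobs: logits at t score the token at t + 1.
log_probs = F.log_softmax(logits[:, :-1], dim=-1)
token_lp = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
mask = completion_mask[:, 1:]
seq_lp = (token_lp * mask).sum(-1)

# Manual check: sum of log p(token_t | prefix) over completion positions t = 2, 3, 4.
manual = sum(F.log_softmax(logits[0, t - 1], dim=-1)[input_ids[0, t]] for t in (2, 3, 4))
print(torch.allclose(seq_lp[0], manual), int(mask.sum()))  # True 3
```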


In [None]:
def build_batch(prompts: Sequence[str], completions: Sequence[str], max_length: int = 192) -> Dict[str, torch.Tensor]:
    input_ids: List[List[int]] = []
    attention_masks: List[List[int]] = []
    completion_masks: List[List[int]] = []
    for prompt, completion in zip(prompts, completions):
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
        completion_ids = tokenizer.encode(completion, add_special_tokens=False) + [tokenizer.eos_token_id]
        combined = prompt_ids + completion_ids
        combined = combined[:max_length]
        completion_mask = [0] * len(prompt_ids) + [1] * len(completion_ids)
        completion_mask = completion_mask[:max_length]
        attention = [1] * len(combined)
        pad = max_length - len(combined)
        if pad > 0:
            combined += [tokenizer.pad_token_id] * pad
            attention += [0] * pad
            completion_mask += [0] * pad
        input_ids.append(combined)
        attention_masks.append(attention)
        completion_masks.append(completion_mask)
    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long, device=DEVICE),
        'attention_mask': torch.tensor(attention_masks, dtype=torch.long, device=DEVICE),
        'completion_mask': torch.tensor(completion_masks, dtype=torch.float, device=DEVICE),
    }

def sequence_logprobs(model: AutoModelForCausalLM, prompts: Sequence[str], completions: Sequence[str], max_length: int = 192, detach: bool = True):
    batch = build_batch(prompts, completions, max_length=max_length)
    context = torch.no_grad() if detach else nullcontext()
    with context:
        outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
        logits = outputs.logits[:, :-1, :]
        labels = batch['input_ids'][:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        token_logprobs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    mask = batch['completion_mask'][:, 1:]
    seq_logprobs = (token_logprobs * mask).sum(dim=-1)
    token_counts = mask.sum(dim=-1).clamp(min=1.0)
    if detach:
        return seq_logprobs.detach(), token_counts.detach()
    return seq_logprobs, token_counts


### 2.6 Sampling groups of completions

GRPO gathers multiple completions per prompt and uses their relative rewards as baselines. The function below returns `num_generations` sampled completions for each prompt. It does not set or record random seeds, so call `torch.manual_seed` beforehand if you need reproducible samples.
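`generate` handles the sampling internally, but it can help to see what `temperature` and `top_p` do to the next-token distribution. A minimal nucleus-filtering sketch on a single logit vector; this is our own illustration, not the transformers implementation.

```python
import torch

def temperature_top_p(logits: torch.Tensor, temperature: float = 0.7, top_p: float = 0.9) -> torch.Tensor:
    """Return the filtered next-token distribution for a 1-D logit vector."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, order = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep the smallest prefix whose mass reaches top_p (the top token is always kept).
    keep_sorted = (cumulative - sorted_probs) < top_p
    keep = torch.zeros_like(probs, dtype=torch.bool)
    keep[order[keep_sorted]] = True
    filtered = torch.where(keep, probs, torch.zeros_like(probs))
    return filtered / filtered.sum()

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])
dist = temperature_top_p(logits)
print(dist)  # only the two most likely tokens survive at T=0.7, p=0.9
tokens = torch.multinomial(dist, num_samples=4, replacement=True)  # 4 draws, like num_generations
```

Lower temperature sharpens the distribution before the nucleus cut, so fewer tokens survive; both knobs trade diversity within a GRPO group against sample quality.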


In [None]:
def sample_completions(model: AutoModelForCausalLM, prompts: Sequence[str], *, num_generations: int, max_new_tokens: int, temperature: float = 0.7, top_p: float = 0.9) -> List[List[str]]:
    completions: List[List[str]] = []
    for prompt in prompts:
        encoded = tokenizer(prompt, return_tensors='pt').to(DEVICE)
        outputs = model.generate(
            **encoded,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            num_return_sequences=num_generations,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        prompt_len = encoded['input_ids'].size(1)
        generated = outputs[:, prompt_len:]
        texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
        completions.append([text.strip() for text in texts])
    return completions


### 2.7 GRPO update step

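The update uses per-group baselines, the mean reward of all completions for a prompt, to reduce variance. We normalise advantages inside each group by the group's standard deviation, add a KL penalty towards the frozen reference model, and clip gradients. Because each batch of samples is used for a single gradient step, this is a simplified on-policy variant: the importance ratio is 1, so the PPO-style ratio clipping of the full GRPO objective is omitted.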
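A small sketch of the group-relative advantage and of the loss structure used in `grpo_update`, with made-up rewards and made-up length-normalised log-probs. `group_advantages` mirrors the normalisation inside `grpo_update` but is a standalone helper written for illustration.

```python
import statistics

import torch

def group_advantages(rewards):
    """(r - mean) / std within one prompt's group; std falls back to 1 when all rewards tie."""
    mean = sum(rewards) / len(rewards)
    std = statistics.pstdev(rewards)
    denom = std if std > 1e-6 else 1.0
    return [(r - mean) / denom for r in rewards]

print(group_advantages([1.0, 0.0, 0.0, 1.0]))      # [1.0, -1.0, -1.0, 1.0]
print(group_advantages([0.35, 0.35, 0.35, 0.35]))  # all (numerically) zero: no signal from this group

# Toy loss with the same structure as grpo_update.
adv = torch.tensor(group_advantages([1.0, 0.0, 0.0, 1.0]))
norm_policy = torch.tensor([-0.2, -0.9, -1.1, -0.3], requires_grad=True)
norm_ref = torch.tensor([-0.25, -0.8, -1.0, -0.35])
kl_coef = 0.05
loss = -(adv * norm_policy).mean() + kl_coef * (norm_policy - norm_ref).mean()
loss.backward()
# Negative gradient for above-average completions, so gradient descent raises their log-prob.
print(norm_policy.grad)
```

Groups where every completion earns the same reward contribute nothing to the policy term, which is one reason the supervised warm-start matters: without it, most early groups would all score -1.0.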


In [None]:
@dataclass
class GRPOConfig:
    batch_size: int = 4
    num_generations: int = 4
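    # Room for the full <think> rationale plus the <output> list; shorter budgets
    # cut answers off before </output>, which sorting_reward scores as -1.0.
    max_new_tokens: int = 160
    max_length: int = 256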
    temperature: float = 0.7
    top_p: float = 0.9
    learning_rate: float = 5e-5
    kl_coef: float = 0.05
    gradient_clip: float = 1.0
    num_steps: int = 30

config = GRPOConfig()
optimizer = torch.optim.AdamW([p for p in policy_model.parameters() if p.requires_grad], lr=config.learning_rate)

def grpo_update(prompts: Sequence[str]) -> Dict[str, float]:
    # Sample in eval mode so LoRA dropout does not perturb generation.
    policy_model.eval()
    completions = sample_completions(
        policy_model,
        prompts,
        num_generations=config.num_generations,
        max_new_tokens=config.max_new_tokens,
        temperature=config.temperature,
        top_p=config.top_p,
    )
    flat_prompts: List[str] = []
    flat_completions: List[str] = []
    advantages: List[float] = []
    raw_rewards: List[float] = []
    for prompt, candidate_list in zip(prompts, completions):
        group_rewards: List[float] = []
        for completion in candidate_list:
            reward, _ = sorting_reward(prompt, completion)
            group_rewards.append(reward)
            flat_prompts.append(prompt)
            flat_completions.append(completion)
        # Group-relative advantages: centre by the group mean, scale by the group std.
        mean_reward = sum(group_rewards) / len(group_rewards)
        std_reward = statistics.pstdev(group_rewards)
        denom = std_reward if std_reward > 1e-6 else 1.0
        for reward in group_rewards:
            advantages.append((reward - mean_reward) / denom)
            raw_rewards.append(reward)
    if not flat_prompts:
        return {'loss': 0.0, 'avg_reward': 0.0}

    # Switch back to train mode for the gradient pass.
    policy_model.train()
    policy_logprobs, token_counts = sequence_logprobs(
        policy_model,
        flat_prompts,
        flat_completions,
        max_length=config.max_length,
        detach=False,
    )
    ref_logprobs, _ = sequence_logprobs(
        reference_model,
        flat_prompts,
        flat_completions,
        max_length=config.max_length,
        detach=True,
    )
    advantages_tensor = torch.tensor(advantages, dtype=torch.float, device=policy_logprobs.device)
    lengths = token_counts.to(policy_logprobs.device)
    norm_policy = policy_logprobs / lengths
    norm_ref = ref_logprobs.to(policy_logprobs.device) / lengths
    loss_policy = -(advantages_tensor * norm_policy).mean()
    kl_term = (norm_policy - norm_ref).mean()
    loss = loss_policy + config.kl_coef * kl_term
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_model.parameters(), config.gradient_clip)
    optimizer.step()
    return {
        'loss': float(loss.detach().cpu()),
        'policy_term': float(loss_policy.detach().cpu()),
        'kl_term': float(kl_term.detach().cpu()),
        'avg_reward': float(sum(raw_rewards) / max(len(raw_rewards), 1)),
    }


### 2.8 Training loop with evaluation hooks

We iterate over random prompts, run a GRPO update, and periodically measure success on held-out prompts using greedy decoding.
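For this task, held-out prompts are cheap to produce: sample fresh lists and drop any whose prompt also appears in the training set. A minimal sketch with plain Python lists; `sample_lists` is a hypothetical stand-in that uses the same length and value ranges as `generate_synthetic_sorting`.

```python
import random

def sample_lists(n: int, seed: int):
    """Random integer lists of length 3-7 with values in [-20, 30]."""
    rng = random.Random(seed)
    return [[rng.randint(-20, 30) for _ in range(rng.randint(3, 7))] for _ in range(n)]

def prompt_key(numbers):
    """The list as rendered inside the prompt, used to detect duplicates."""
    return ", ".join(str(x) for x in numbers)

train_lists = sample_lists(192, seed=42)
seen = {prompt_key(nums) for nums in train_lists}
heldout = [nums for nums in sample_lists(64, seed=2024) if prompt_key(nums) not in seen]
print(len(heldout))
```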


In [None]:
def generate_completion(model: AutoModelForCausalLM, prompt: str, max_new_tokens: int = 160) -> str:
    encoded = tokenizer(prompt, return_tensors='pt').to(DEVICE)
    output = model.generate(
        **encoded,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    prompt_len = encoded['input_ids'].size(1)
    completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return completion.strip()

def evaluate_success(model: AutoModelForCausalLM, dataset: Dataset, num_examples: int = 32) -> Dict[str, float]:
    model.eval()  # greedy decoding without dropout
    subset = dataset.shuffle(seed=1234).select(range(min(num_examples, len(dataset))))
    successes = []
    rewards = []
    for example in subset:
        completion = generate_completion(model, example['prompt'])
        reward, components = sorting_reward(example['prompt'], completion)
        rewards.append(reward)
        successes.append(1.0 if components['exact'] == 1.0 else 0.0)
    return {
        'avg_reward': float(sum(rewards) / max(len(rewards), 1)),
        'exact_match_rate': float(sum(successes) / max(len(successes), 1)),
    }

# Held-out prompts: freshly sampled lists with a seed not used for the training data
# (a few may still coincide with training prompts by chance).
heldout_ds = generate_synthetic_sorting(64, seed=2024)

rng = random.Random(0)
metrics_history: List[Dict[str, float]] = []
dataset_list = list(training_ds)
for step in range(config.num_steps):
    batch = rng.sample(dataset_list, k=min(config.batch_size, len(dataset_list)))
    prompts = [item['prompt'] for item in batch]
    metrics = grpo_update(prompts)
    if step % 5 == 0:
        eval_metrics = evaluate_success(policy_model, heldout_ds, num_examples=16)
        metrics.update({f'eval_{k}': v for k, v in eval_metrics.items()})
        print(f"Step {step}: {metrics}")
    metrics['step'] = step
    metrics_history.append(metrics)


### 2.9 Inspecting the tuned model

After training we can sample a few prompts and compare the base (reference) model with the LoRA-adapted policy.


In [None]:
sample_prompts = [example['prompt'] for example in training_ds.select(range(3))]
for prompt in sample_prompts:
    base_completion = generate_completion(reference_model, prompt)
    tuned_completion = generate_completion(policy_model, prompt)
    reward_base, _ = sorting_reward(prompt, base_completion)
    reward_tuned, _ = sorting_reward(prompt, tuned_completion)
    print('Prompt:', prompt)
    print('Reference model:', base_completion, f'(reward={reward_base:.2f})')
    print('LoRA + GRPO:', tuned_completion, f'(reward={reward_tuned:.2f})')
    print('-' * 80)


## Takeaways

* LoRA adapters collapse the fine-tuning footprint of a dense linear layer while maintaining accuracy when the required update is approximately low rank.
* Supervised warm-starting with structured `<think>/<output>` exemplars teaches the LoRA adapter to emit both reasoning tokens and the final sorted answer before RL.
* GRPO-style updates combined with PEFT adapters on Qwen 3 0.6B offer a practical starting point for reinforcement learning with verifiable rewards on consumer hardware.
