### GRPO Fine-Tuning with MLX LM

In this notebook, we'll walk through how to fine-tune an LLM with **Group Relative Policy Optimization (GRPO)** using MLX LM. GRPO is a reinforcement learning algorithm related to PPO that, instead of training a separate value model, scores each sampled response relative to the other responses in its group. We'll use the [HellaSwag](https://rowanzellers.com/hellaswag/) dataset for commonsense reasoning as an example. An outline:

1. Download the dataset and prepare it for the GRPO loop.
2. Set up and run GRPO training. We will implement the full RL loop, including rollout, reward calculation, and optimization with the PPO-clip objective.
3. Evaluate the final accuracy on the test set.
4. Fuse the resulting adapters into the base model.
5. Discuss tips for debugging accuracy and efficiency.

Note: This notebook does not yet implement a proper reward function. Instead, it uses a placeholder that checks whether the ground-truth answer appears anywhere in the response:

```python
reward = 1.0 if batch_answers[i] in response else 0.0
```

I will add a fuller reward function to this notebook later and post an update when it is done. In the meantime, please feel free to file a pull request if you would like to contribute in any way.
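
As one possible starting point, a stricter reward could extract the first `endingN` label from the response and compare it with the ground truth, rather than using a substring test. The sketch below is an assumption about what such a function might look like, not this notebook's implementation; `extract_label` and `exact_match_reward` are hypothetical names.

```python
import re

# Hypothetical helper: matches the first "endingN" label, with N between 1 and 4
LABEL_RE = re.compile(r"\bending([1-4])\b")

def extract_label(response):
    """Return the first label found in the response, or None if there is none."""
    match = LABEL_RE.search(response.lower())
    return match.group(0) if match else None

def exact_match_reward(response, answer):
    """Return 1.0 only if the first label in the response equals the answer."""
    return 1.0 if extract_label(response) == answer.strip().lower() else 0.0

print(exact_match_reward("ending2", "ending2"))             # 1.0
print(exact_match_reward("ending1 or ending2", "ending2"))  # 0.0
```

Unlike the substring check, this version gives no credit to a response that lists several endings and happens to include the right one.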

### Install dependencies

In [None]:
!pip install mlx-lm
!pip install matplotlib

### Preprocess Data
We'll start by downloading an already pre-processed version of the HellaSwag dataset from [LLM-Adapters](https://github.com/AGI-Edgerunners/LLM-Adapters).

In [None]:
import json
import numpy as np
from pathlib import Path
from urllib import request

save_dir = "/tmp/hellaswag"

def download_and_save(save_dir):
    base_url = "https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/main/dataset/hellaswag/"
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    for name in ["train.json", "test.json"]:
        out_file = save_dir / name
        if not out_file.exists():
            request.urlretrieve(base_url + name, out_file)

def load_json(dataset):
    download_and_save(save_dir)
    with open(f"{save_dir}/{dataset}.json", "r") as fid:
        return json.load(fid)

train_set, test_set = load_json("train"), load_json("test")
print(f"HellaSwag stats: {len(train_set)} training examples and {len(test_set)} test examples.")
print("An example:\n")
print(json.dumps(train_set[0], indent=4))

Next, let's split the training set into a training and a validation set. We'll pull out a randomly chosen 10% for validation.

In [None]:
# Seed for reproducibility
np.random.seed(43)
perm = np.random.permutation(len(train_set))
valid_size = int(0.1 * len(train_set))
valid_set = [train_set[i] for i in perm[:valid_size]]
train_set = [train_set[i] for i in perm[valid_size:]]

### Fine-Tune

For fine-tuning, we'll use Microsoft's [Phi-3 mini](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). At 3.8 billion parameters, Phi-3 mini is a high-quality model that is also fast to fine-tune on most Apple silicon machines. Also, it has a [permissive MIT License](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE).

First, import all the packages and functions we need.

In [None]:
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_unflatten, tree_map
from mlx_lm import load, generate
from mlx_lm.tuner.lora import LoRALinear
from mlx_lm.tuner import linear_to_lora_layers
import tqdm
import os
import time
import copy
os.environ["TOKENIZERS_PARALLELISM"] = "true"

Next, set up the LoRA parameters.

In [None]:
# Make a directory to save the adapter config and weights
adapter_path = Path("adapters")
adapter_path.mkdir(parents=True, exist_ok=True)

lora_config = {
 "num_layers": 8,
 "lora_parameters": {
    "rank": 8,
    "scale": 10.0, # This can be tuned
    "dropout": 0.0,
}}

# Save the LoRA config to the adapter path
with open(adapter_path / "adapter_config.json", "w") as fid:
    json.dump(lora_config, fid, indent=4)

Next, load the models. For GRPO, we need three models:
- `model` (π_θ): The model we are training with LoRA adapters.
- `model_old` (π_θold): A copy of `model` used for generating rollouts. Its weights are periodically synchronized with `model`.
- `model_ref` (π_ref): The original pretrained model, used as a reference for the KL-divergence penalty.

In [None]:
model_path = "microsoft/Phi-3-mini-4k-instruct"

# Load the main model and tokenizer
model, tokenizer = load(model_path)

# Load the reference model
model_ref, _ = load(model_path)
model_ref.freeze()

After loading the main model, freeze its base parameters and convert the specified linear layers to LoRA layers. The LoRA adapters will be the only trainable parameters.

In [None]:
# Freeze the base model
model.freeze()

# Convert linear layers to lora layers
linear_to_lora_layers(model, lora_config["num_layers"], lora_config["lora_parameters"])

# Create the old model for rollouts
model_old, _ = load(model_path)
linear_to_lora_layers(model_old, lora_config["num_layers"], lora_config["lora_parameters"])
model_old.update(model.parameters()) # Sync weights
model_old.freeze()

num_train_params = (
    sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
)
print(f"Number of trainable parameters: {num_train_params:,}")

### Define the GRPO loss and training loop

Here we define the core components for GRPO. This includes:
1. A helper function to calculate the log probabilities of a sequence.
2. The GRPO loss function, which computes the PPO-clip objective and KL penalty.
3. The main training loop that orchestrates the RL process.
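
For reference, the sequence-level objective computed by the code in this section can be written as follows. This is a transcription of the notebook's own loss rather than a quote from the GRPO paper, which states its objective per token. The symbols are labels chosen here: $y_i$ is a sampled answer for prompt $x$, $R_i$ its reward, $G$ the group size, and $N$ the number of sequences in one iteration (`batch_size * group_size`).

```latex
\begin{aligned}
A_i &= \frac{R_i - \operatorname{mean}(R_1, \dots, R_G)}{\operatorname{std}(R_1, \dots, R_G) + 10^{-8}}, \\
\rho_i &= \frac{\pi_\theta(y_i \mid x)}{\pi_{\theta_{\text{old}}}(y_i \mid x)}, \qquad
r_i = \frac{\pi_{\text{ref}}(y_i \mid x)}{\pi_\theta(y_i \mid x)}, \\
\mathcal{L}(\theta) &= -\frac{1}{N} \sum_{i=1}^{N} \Big[ \min\big(\rho_i A_i,\ \operatorname{clip}(\rho_i, 1 - \epsilon, 1 + \epsilon)\, A_i\big) - \beta\, \big(r_i - \log r_i - 1\big) \Big].
\end{aligned}
```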

In [None]:
from mlx import nn

def calculate_log_probs(model, sequences, a_toks):
    """Calculates the log probabilities of the generated answer tokens."""
    # Pass the full sequence (prompt + answer) to the model
    logits = model(sequences)

    # Convert to log probabilities
    log_probs_full = nn.log_softmax(logits, axis=-1)

    # Find the positions where the answer tokens should be scored.
    # This assumes a_toks holds the generated token IDs and that they occupy the
    # last ans_len positions of every row; right-padded rows of different
    # lengths break this assumption and would need a per-row offset or mask.
    batch_size, seq_len = sequences.shape
    _, ans_len = a_toks.shape

    # Calculate the starting position for answer tokens (assuming they're at the end)
    start_pos = seq_len - ans_len

    # The logits at position t predict the token at position t + 1, so the
    # scores for the answer tokens come from positions start_pos - 1 to seq_len - 2
    answer_log_probs = log_probs_full[:, start_pos - 1:seq_len - 1, :]

    # Create indices for gathering - ensure proper shape alignment
    indices = a_toks[:, :, None]

    # Extract log probabilities for the actual answer tokens
    selected_log_probs = mx.take_along_axis(answer_log_probs, indices, axis=-1).squeeze(-1)

    # Sum log probabilities across the answer sequence
    return mx.sum(selected_log_probs, axis=-1)

def grpo_loss_fn(model, model_ref, sequences, a_toks, advantages, old_log_probs, beta, epsilon):
    """The GRPO loss function."""
    # Get log probs from the trainable model (π_θ)
    log_probs = calculate_log_probs(model, sequences, a_toks)

    # Get log probs from the reference model (π_ref) for KL penalty
    log_probs_ref = calculate_log_probs(model_ref, sequences, a_toks)

    # PPO-clip objective
    ratio = mx.exp(log_probs - old_log_probs)
    clipped_ratio = mx.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)
    policy_reward = mx.minimum(ratio * advantages, clipped_ratio * advantages)

    # KL penalty
    # Step 1: Calculate log(r) where r = π_ref / π_θ
    # log(r) = log(π_ref) - log(π_θ)
    log_ratio_for_kl = log_probs_ref - log_probs

    # Step 2: Calculate r itself by exponentiating log(r)
    # r = exp(log(r))
    ratio_for_kl = mx.exp(log_ratio_for_kl)

    # Step 3: Apply the paper's full formula: r - log(r) - 1
    kl_div = ratio_for_kl - log_ratio_for_kl - 1

    # The objective is to maximize this, so we return the negative for minimization
    loss = -mx.mean(policy_reward - beta * kl_div)
    return loss, mx.mean(policy_reward), mx.mean(kl_div)
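
To make the two terms of the loss concrete, here is a small worked example in plain Python with made-up log probabilities. It mirrors the arithmetic of `grpo_loss_fn` for a single sequence; `clipped_term` and `kl_estimate` are hypothetical helper names, not part of MLX LM.

```python
import math

def clipped_term(log_prob, old_log_prob, advantage, epsilon=0.2):
    """PPO-clip term for one sequence: min(ratio * A, clip(ratio) * A)."""
    ratio = math.exp(log_prob - old_log_prob)
    clipped = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    return min(ratio * advantage, clipped * advantage)

def kl_estimate(log_prob, ref_log_prob):
    """The r - log(r) - 1 estimator with r = pi_ref / pi_theta; never negative."""
    log_r = ref_log_prob - log_prob
    return math.exp(log_r) - log_r - 1.0

# The ratio exp(0.5) is about 1.65, above 1 + epsilon, so a positive advantage is capped
print(clipped_term(-1.0, -1.5, advantage=1.0))
# With a negative advantage the unclipped, more pessimistic value is kept
print(clipped_term(-1.0, -1.5, advantage=-1.0))
# Identical policies give a zero penalty
print(kl_estimate(-2.0, -2.0))
```

The asymmetry in the first two calls is the point of the `minimum`: clipping removes the incentive to push the ratio further in the direction the advantage favors, but it never hides a move that makes the objective worse.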

In [None]:
# Pad sequences to the same length
def pad_sequences(sequences, pad_token_id):
    if not sequences:
        return mx.array([])

    # Find the maximum length
    max_len = max(len(seq) for seq in sequences)
    padded_sequences = []

    for seq in sequences:
        if len(seq) < max_len:
            padding = mx.array([pad_token_id] * (max_len - len(seq)))
            padded_seq = mx.concatenate([seq, padding])

        else:
            padded_seq = seq
        padded_sequences.append(padded_seq)

    return mx.stack(padded_sequences)
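
Because the sequences in a batch have different lengths, right-padding leaves pad tokens at positions that a later log-probability sum may include. One common remedy is to carry a mask alongside the padded batch and sum only over real tokens. The sketch below illustrates the idea with plain Python lists; `pad_with_mask` and `masked_sum` are hypothetical helpers and are not used elsewhere in this notebook.

```python
def pad_with_mask(sequences, pad_token_id):
    """Right-pad lists of token ids; the mask is 1 for real tokens and 0 for padding."""
    max_len = max((len(seq) for seq in sequences), default=0)
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded.append(list(seq) + [pad_token_id] * n_pad)
        mask.append([1] * len(seq) + [0] * n_pad)
    return padded, mask

def masked_sum(values, mask):
    """Sum per-token values over the real (unmasked) positions of each row."""
    return [
        sum(v * m for v, m in zip(row, row_mask))
        for row, row_mask in zip(values, mask)
    ]

padded, mask = pad_with_mask([[5, 6, 7], [8]], pad_token_id=0)
print(padded)  # [[5, 6, 7], [8, 0, 0]]
print(mask)    # [[1, 1, 1], [1, 0, 0]]
```

The same mask could be multiplied into per-token log probabilities before they are summed, so that padding contributes nothing to the sequence score.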

In [None]:
def grpo_train_loop(
    model, model_old, model_ref, tokenizer, optimizer, train_set,
    iters=200, group_size=4, batch_size=2, epsilon=0.2, beta=0.01,
    update_every=10, max_ans_len=4
):
    # Create a grad function for the trainable model
    loss_and_grad_fn = nn.value_and_grad(model, grpo_loss_fn)
    
    losses = []
    all_rewards = []
    
    # Start training
    pbar = tqdm.tqdm(range(iters))
    for it in pbar:
        batch_prompts = []
        batch_answers = []
        
        # 1. Sample a batch of prompts
        indices = np.random.randint(0, len(train_set), batch_size)
        for i in indices:
            # The last word of the output is the ground truth answer (e.g., "ending4")
            prompt_text, answer_text = train_set[i]["output"].rsplit(" ", maxsplit=1)
            full_prompt = [
                {"role": "user", "content": train_set[i]["instruction"]},
                {"role": "assistant", "content": prompt_text}
            ]
            batch_prompts.append(full_prompt)
            batch_answers.append(answer_text)
        
        # 2. Rollout: Generate G responses for each prompt using the old model
        rollout_sequences = []
        rollout_rewards = []
        rollout_log_probs = []
        rollout_a_toks = []

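        # Sample rollouts with a non-zero temperature. generate() decodes greedily
        # by default, which would make all G responses in a group identical and
        # every advantage zero. Assumes an mlx-lm version that provides
        # sample_utils.make_sampler and accepts a `sampler` argument.
        from mlx_lm.sample_utils import make_sampler
        sampler = make_sampler(temp=1.0)

        for i in range(batch_size):
            prompt_tokens = tokenizer.apply_chat_template(batch_prompts[i], continue_final_message=True)
            group_rewards = []

            for _ in range(group_size):
                # Generate a response
                response = generate(model_old, tokenizer, prompt_tokens, max_tokens=max_ans_len, sampler=sampler)
                answer_tokens = tokenizer.encode(response, add_special_tokens=False)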

                # 3. Get Reward
                reward = 1.0 if batch_answers[i] in response else 0.0
                group_rewards.append(reward)

                # Store data for the optimization step
                full_sequence = mx.array(prompt_tokens + answer_tokens)
                rollout_sequences.append(full_sequence)
                rollout_a_toks.append(mx.array(answer_tokens))

            all_rewards.extend(group_rewards)
            rollout_rewards.append(mx.array(group_rewards))
        
        # 4. Compute Advantages
        advantages = []
        for rewards in rollout_rewards:
            mean_reward = mx.mean(rewards)
            std_reward = mx.sqrt(mx.var(rewards)) + 1e-8 # Add epsilon for stability
            adv = (rewards - mean_reward) / std_reward
            advantages.append(adv)
        
        advantages = mx.concatenate(advantages)
        sequences = pad_sequences(rollout_sequences, tokenizer.pad_token_id)
        a_toks = pad_sequences(rollout_a_toks, tokenizer.pad_token_id)

        # Calculate log_probs with the old model for the ratio calculation
        old_log_probs = calculate_log_probs(model_old, sequences, a_toks)

        # 5. Optimization Step
        (loss, policy_reward, kl_div), grads = loss_and_grad_fn(
            model, model_ref, sequences, a_toks, advantages, old_log_probs, beta, epsilon
        )
        
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)

        losses.append(loss.item())
        pbar.set_description(f"Loss: {np.mean(losses[-10:]):.3f}, Mean Reward: {np.mean(all_rewards[-20:]):.3f}")
        
        # Sync old model weights
        if (it + 1) % update_every == 0:
            model_old.update(model.parameters())
            print(f"\nIter {it+1}: Synced old model weights.")
            
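    # Final save of the adapter weights only (the trainable LoRA parameters);
    # model.save_weights would also write the frozen base model weights
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(adapter_path / "adapters.safetensors"), adapter_weights)
    print("Saved final weights to adapters/adapters.safetensors.")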
    return losses, all_rewards
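
The advantage computation in step 4 is the "group relative" part of GRPO: each reward is normalized by the mean and standard deviation of its own group. A minimal plain-Python sketch of the same arithmetic, with made-up rewards and a hypothetical helper name `group_advantages`, is:

```python
import math

def group_advantages(rewards, eps=1e-8):
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)  # population variance
    std = math.sqrt(var) + eps
    return [(r - mean) / std for r in rewards]

print(group_advantages([1.0, 0.0, 0.0, 0.0]))  # one correct answer out of four
print(group_advantages([1.0, 1.0, 1.0, 1.0]))  # identical rewards give all zeros
```

When every response in a group earns the same reward, all advantages are zero and the policy term contributes no gradient, so rollouts only provide a learning signal when the responses within a group differ.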

Now we're ready to put it all together and actually train the model. We'll use `Adam` for the optimizer and run our custom GRPO loop.

In [None]:
# GRPO Hyperparameters
learning_rate = 1e-5
iters = 200
group_size = 4      # G in the paper, number of responses per prompt
batch_size = 4      # Number of prompts per iteration
epsilon = 0.2       # PPO clip parameter
beta = 0.02         # KL penalty coefficient
update_every = 10   # Sync model_old every N iterations
max_ans_len = 4     # Max tokens to generate for an answer

# Put the model in training mode:
model.train()

# Make the optimizer:
opt = optim.Adam(learning_rate=learning_rate)

print("Starting GRPO training...")
start_time = time.time()

# Run the custom GRPO training loop
losses, rewards = grpo_train_loop(
    model=model,
    model_old=model_old,
    model_ref=model_ref,
    tokenizer=tokenizer,
    optimizer=opt,
    train_set=train_set,
    iters=iters,
    group_size=group_size,
    batch_size=batch_size,
    epsilon=epsilon,
    beta=beta,
    update_every=update_every,
    max_ans_len=max_ans_len
)

end_time = time.time()
print(f"Training finished in {end_time - start_time:.2f}s")

The adapters are saved at the end of training in `adapters.safetensors`.

In [None]:
!ls adapters/

Next, let's plot the training loss and the moving average of the rewards to see how well the model learned.

In [None]:
def moving_average(a, n=10):
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

fig, ax1 = plt.subplots()

color = 'tab:red'
ax1.set_xlabel('Iteration')
ax1.set_ylabel('GRPO Loss', color=color)
ax1.plot(losses, color=color)
ax1.tick_params(axis='y', labelcolor=color)

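ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
color = 'tab:blue'
ax2.set_ylabel('Reward (Moving Avg)', color=color)
# rewards holds one entry per rollout, so average per iteration to match the loss x-axis
n = 10
reward_per_iter = np.array(rewards).reshape(-1, batch_size * group_size).mean(axis=1)
ax2.plot(np.arange(n - 1, len(reward_per_iter)), moving_average(reward_per_iter, n=n), color=color)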
ax2.tick_params(axis='y', labelcolor=color)

fig.tight_layout()
plt.title("GRPO Training Loss and Reward")
plt.show();

### Evaluate

The training loss and reward are only part of the story. For HellaSwag, we ultimately care about how good the model is at answering questions. To assess this, let's generate the actual `ending1`, `ending2`, `ending3`, or `ending4` responses with the fine-tuned model and measure the accuracy.

First, let's split the last word off of each output in the test set to create a prompt without the answer.

In [None]:
test_set_eval = [(t["instruction"], *t["output"].rsplit(" ", maxsplit=1)) for t in test_set]

Next, we'll generate the response for each example in the test set and compare it to the ground-truth answer to measure the accuracy.

In [None]:
def evaluate(model, tokenizer, num_test):
    num_correct = 0
    for prompt, completion, answer in tqdm.tqdm(test_set_eval[:num_test]):
        messages = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion}
        ]
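        # Apply the chat template, leaving the assistant turn open so the model
        # continues it (this matches the prompts used for rollouts during training)
        prompt = tokenizer.apply_chat_template(
            conversation=messages, continue_final_message=True
        )
        # generate() decodes greedily by default, which is what we want for evaluation
        response = generate(model, tokenizer, prompt=prompt, max_tokens=4)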
        num_correct += (answer in response)
    return num_correct / num_test

In [None]:
# Put model in eval mode for evaluation
model.eval()

# Increase this number to use more test examples
num_test = 100
test_acc = evaluate(model, tokenizer, num_test)
print(f"Approximate test accuracy {test_acc:.3f}")
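
With only 100 test examples the accuracy is a noisy estimate. As a rough guide, a normal-approximation interval can be computed from the accuracy and the sample size. The sketch below uses a made-up accuracy of 0.70 for illustration, and `accuracy_interval` is a hypothetical helper.

```python
import math

def accuracy_interval(accuracy, num_examples, z=1.96):
    """Normal-approximation 95% interval for an accuracy measured on num_examples."""
    stderr = math.sqrt(accuracy * (1.0 - accuracy) / num_examples)
    return max(0.0, accuracy - z * stderr), min(1.0, accuracy + z * stderr)

low, high = accuracy_interval(0.70, 100)  # 0.70 is an illustrative value, not a result
print(f"{low:.3f} to {high:.3f}")
```

At this sample size the interval is roughly plus or minus 0.09, so small differences between runs should not be over-interpreted unless `num_test` is increased.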

### Fuse Adapters

Sometimes it's convenient to fuse the adapters into the base model to create a single adapted model. MLX LM has a fuse script just for that.

To see more options for fusing the model, including how to upload to Hugging Face, [check the documentation](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#fuse).

In [None]:
!mlx_lm.fuse --model {model_path} --adapter-path {adapter_path}

Once the adapters are fused, we can rerun the evaluation using the fused model to make sure it worked. By default the fused model will be saved to `fused_model`.

In [None]:
model_fused, tokenizer_fused = load("fused_model")
test_acc_fused = evaluate(model_fused, tokenizer_fused, num_test)
print(f"Approximate fused model test accuracy {test_acc_fused:.3f}")

### Troubleshooting

#### Results

To figure out why your GRPO fine-tuning is not working well, it's critical to plot both the loss and the average reward. 

**Underfitting**: The average reward is not increasing significantly and remains low. The loss may be stagnant or decreasing very slowly. This means the model isn't learning the desired behavior. You have a few options to improve the results:

- **Increase the learning rate**: A higher learning rate might be needed to escape local minima.
- **Increase `group_size` (G)**: A larger group provides a more stable estimate of the advantage, which can improve the quality of the gradients.
- **Tune the KL penalty `beta`**: If `beta` is too high, it will prevent the model from learning, acting as an overly strong regularizer. Try decreasing it.
- **Increase adapter capacity**: Use a larger `num_layers` or a higher `rank` in `lora_config`.
- **Check your reward function**: Ensure the reward accurately reflects the desired outcome. For simple tasks like this, it's straightforward, but for complex tasks, this is often a source of error.

**Overfitting/Instability**: The reward increases initially but then crashes, or the loss fluctuates wildly. This means the policy updates are too large and are destabilizing the model.

- **Decrease the learning rate**: This is the most common fix for instability.
- **Tune the PPO clip `epsilon`**: A smaller `epsilon` (e.g., 0.1) will make the updates more conservative.
- **Increase the KL penalty `beta`**: A larger `beta` will pull the policy back towards the original reference model, preventing it from straying too far into unstable regions.
- **Update `model_old` less frequently**: Increasing `update_every` can sometimes add stability.
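
One diagnostic that is often logged in PPO-style training, and that may help when tuning `epsilon` or the learning rate, is the fraction of sequences whose probability ratio falls outside the clip range. The notebook does not compute this; the sketch below is a plain-Python illustration with made-up log probabilities, and `clip_fraction` is a hypothetical helper.

```python
import math

def clip_fraction(log_probs, old_log_probs, epsilon=0.2):
    """Fraction of sequences whose ratio falls outside [1 - epsilon, 1 + epsilon]."""
    ratios = [math.exp(lp - old) for lp, old in zip(log_probs, old_log_probs)]
    clipped = [r < 1.0 - epsilon or r > 1.0 + epsilon for r in ratios]
    return sum(clipped) / len(clipped)

# Toy values: the first and last ratios (about 1.65 and 0.61) are outside the range
print(clip_fraction([-1.0, -2.0, -3.0, -2.5], [-1.5, -2.0, -3.1, -2.0]))  # 0.5
```

A clip fraction that stays high suggests the policy is moving far from `model_old` between syncs, which is consistent with the instability symptoms described above.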

#### Memory Use

RL fine-tuning can be more memory-intensive than SFT due to storing rollouts and multiple models. Here are some tips to reduce memory use:

- **Reduce `batch_size` or `group_size`**: These directly control how many sequences are held in memory for each iteration.
- **Quantization (QLoRA)**: This is highly effective. You can use a quantized base model from Hugging Face or create one with `mlx_lm.convert`.
- **Gradient checkpointing**: This trades computation for memory. The custom loop in this notebook has no `grad_checkpoint` option, so support would have to be added to the training loop and passed down to the loss function.
- **Reduce `num_layers` or `rank`**: Fewer trainable parameters means a smaller memory footprint for gradients and optimizer states.
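
To get a feel for how `rank` affects the number of trainable parameters, note that a LoRA adapter on a linear layer adds two low-rank matrices, one of shape input by rank and one of shape rank by output. The sketch below does that arithmetic for a hypothetical 3072 by 3072 projection; the dimensions are chosen only for illustration, and `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(in_features, out_features, rank):
    """Trainable parameters added by one LoRA adapter: A is (in x rank), B is (rank x out)."""
    return rank * (in_features + out_features)

# Hypothetical layer size, used only to show how the count scales with rank
for rank in (4, 8, 16):
    print(rank, lora_param_count(3072, 3072, rank))
```

The count grows linearly with `rank` and with the number of adapted layers, so halving either roughly halves the memory needed for adapter gradients and optimizer state.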

### Next Steps

- To learn more about MLX, check out the [GitHub repo](http://github.com/ml-explore/mlx) and [documentation](https://ml-explore.github.io/mlx/).
- For more on MLX LM, check out the [MLX LM documentation](https://github.com/ml-explore/mlx-examples/tree/main/llms#readme).
- Check out the other [MLX Examples](https://github.com/ml-explore/mlx-examples/tree/main). These are great as a learning resource or to use as a starting point for a new project.