In [None]:
# %%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm
# Install latest Hugging Face for Gemma-3!
!pip install --no-deps git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

In [2]:

#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt


In [3]:
from unsloth import FastModel
import torch
max_seq_length = 1024

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",

    # Other popular models!
    "unsloth/Llama-3.1-8B",
    "unsloth/Llama-3.2-3B",
    "unsloth/Llama-3.3-70B",
    "unsloth/mistral-7b-instruct-v0.3",
    "unsloth/Phi-4",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-1b-it",
    max_seq_length = max_seq_length, # Choose any for long context!
    load_in_4bit = False,  # 4 bit quantization to reduce memory
    load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # Should leave on always!

    r = 8,           # Larger = higher accuracy, but might overfit
    lora_alpha = 8,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 03-21 08:14:02 [__init__.py:256] Automatically detected platform cuda.
==((====))==  Unsloth 2025.3.17: Fast Gemma3 patching. Transformers: 4.50.0.dev0. vLLM: 0.8.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

Unsloth: Making `model.base_model.model.model` require gradients


# Creating our Connect Four Games Dataset

Below is our method to generate lots of games of connect four. I collect games up to a point where one player wins. I can then ask the model what move it needs to take to win the game if its either X or O! I also then know what the answer should be which I can rewards a model on.

In [None]:
import numpy as np
import random
import pandas as pd

ROWS = 6
COLS = 7

def create_board():
    """Creates an empty Connect Four board."""
    return np.full((ROWS, COLS), '.', dtype=str)

def drop_piece(board, col, piece):
    """Drops a piece in the given column and returns the row where it landed."""
    for row in range(ROWS - 1, -1, -1):  # start from the bottom row
        if board[row, col] == '.':
            board[row, col] = piece
            return row
    return None  # Column is full

def is_winning_move(board, row, col, piece):
    """Checks if placing a piece at (row, col) wins the game.
       It checks vertical, horizontal, and both diagonal directions."""
    directions = [
        (1, 0),   # vertical (down)
        (0, 1),   # horizontal (right)
        (1, 1),   # diagonal ↘
        (1, -1)   # diagonal ↙
    ]
    for dr, dc in directions:
        count = 1  # count the piece just placed
        for d in [-1, 1]:  # check both directions along (dr, dc)
            r, c = row + d * dr, col + d * dc
            while 0 <= r < ROWS and 0 <= c < COLS and board[r, c] == piece:
                count += 1
                if count == 4:
                    return True
                r += d * dr
                c += d * dc
    return False

def get_almost_winning_boards(num_games=1000):
    """
    Simulates games and collects board states that are just one move away from winning.
    For each board state, the function returns a tuple with:
      (board state copy, winning player, winning column)
    """
    winning_positions = []

    # Run through multiple game simulations.
    for _ in range(num_games):
        board = create_board()
        current_piece = "X"

        # Simulate moves until board is full or a winning move is found.
        for _ in range(ROWS * COLS):
            available_columns = [c for c in range(COLS) if board[0, c] == '.']
            if not available_columns:
                break  # board is full

            col = random.choice(available_columns)
            row = drop_piece(board, col, current_piece)

            # Check if the last move is a winning move.
            if is_winning_move(board, row, col, current_piece):
                # Remove the winning move to get the board state just before the win.
                board[row, col] = '.'
                # Save the board state, the winning player, and the column to win.
                winning_positions.append((board.copy(), current_piece, col))
                break

            # Switch the player.
            current_piece = "O" if current_piece == "X" else "X"

    return winning_positions

# Generate board states that are one move away from winning.
almost_winning_boards = get_almost_winning_boards(num_games=1000)

# Convert board states into a DataFrame with one board per row.
# Each row will include a label for the game, the winning player, the winning column,
# and the board rows (from top row 0 to bottom row ROWS-1).
board_data = []
for idx, (board, winner, win_col) in enumerate(almost_winning_boards):
    # Convert each row (a numpy array) to a string for display.
    rows_as_str = ["".join(row) for row in board]
    board_data.append([f"Game {idx+1}", winner, win_col] + rows_as_str)

# Create DataFrame columns: Game, Winning Player, Winning Column, and one column for each board row.
columns = ["Game", "Winning Player", "Winning Column"] + [f"Row {i}" for i in range(ROWS)]
df = pd.DataFrame(board_data, columns=columns)
df.to_csv("games.csv")


# Training & Rewarding Gemma 1B

This section of the code focuses on preparing the Connect Four game data and structuring the interaction with the AI model. It begins by importing necessary libraries like Pandas for data handling and re for regular expressions. The Connect Four game data is loaded from a CSV file named "games.csv" into a Pandas DataFrame. To guide the AI, a system prompt is defined, instructing it to provide reasoning and solutions within specific tags.

The format_puzzle function then transforms game states from the DataFrame into prompts for the AI, including a visual representation of the board and instructions. To ensure the AI's responses adhere to the desired format and to extract the predicted move, regular expressions are utilized. Finally, a check_answer function, is sued as part of the reward system, it assesses the accuracy of the AI's predictions during training by comparing them to the correct moves stored in our CSV. This setup lays the groundwork for effectively training and evaluating the AI's performance in playing Connect Four over time.

The core of the training process relies on the GRPO algorithm. GRPO scores the AI's responses using a combination of reward functions which we just mentioned, assessing aspects like output format adherence and prediction accuracy. These scores are then used to calculate gradients that guide the model's learning process. Importantly, training is conducted using LORA (Low-Rank Adaptation), a parameter-efficient fine-tuning technique. LORA enables fine-tuning specific model layers while keeping the majority of the pre-trained weights frozen, leading to faster training and reduced memory requirements. Ultimately, this process aims to enhance the AI's ability to predict winning moves in Connect Four, progressively refining its performance through iterative training and feedback

In [7]:
import pandas as pd
import re

# Load the games CSV - this is a automated csv
df = pd.read_csv("games.csv")

# Format constants for responses
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<solution>"
solution_end = "</solution>"

system_prompt = \
f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""
system_prompt

def format_puzzle(row):
    """Formats a Connect 4 board from a DataFrame row into a prompt."""

    board_rows = [df.loc[row, f'Row {i}'] for i in range(6)]  # Display from bottom to top
    board_str = "1234567\n    " + "\n    ".join(board_rows)  # Add column numbers at the top
    winning_player = df.loc[row, f'Winning Player']
    winning_column = df.loc[row, f'Winning Column']
    # Create the prompt
    prompt = f"""
    You are a connect four master.

    board position:
     {board_str}

    It is {winning_player}'s turn to move. Find where they should move their piece.
    Columns are labelled from 1-7. Choose the best column to win the game.

    {reasoning_start}
    As {winning_player} which number column should you place your column to win the game and connect four?
    {reasoning_end}

    {solution_start}"""

    return prompt


# Compile regex to check that the AI output follows the expected format
match_format = re.compile(
    rf"^[\s]*"\
    rf"{reasoning_start}.+?{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end}"\
    rf"[\s]*$",
    flags = re.MULTILINE | re.DOTALL
)

def extract_move(response):
    """Extracts the move number from the AI response."""
    move_pattern = re.search(r"\b([1-7])\b", response)  # Look for a single-digit column number
    return move_pattern.group(1) if move_pattern else None

def check_answer(prompts, completions, answer, **kwargs):
    """Reward function comparing extracted move with true answer."""
    # Handle different completion formats
    responses = []
    for completion in completions:
        if isinstance(completion, list) and len(completion) > 0 and isinstance(completion[0], dict) and "content" in completion[0]:
            responses.append(completion[0]["content"])
        elif isinstance(completion, dict) and "content" in completion:
            responses.append(completion["content"])
        else:
            # Handle case where completion is already a string
            responses.append(completion)

    extracted_responses = []
    for r in responses:
        if not isinstance(r, str):
            extracted_responses.append(None)
            continue

        match = match_format.search(r)
        if match is not None:
            guess = match.group(1)
            # Further extract the number from the solution section
            num_match = re.search(r"\b([1-7])\b", guess)
            extracted_responses.append(num_match.group(1) if num_match else None)
        else:
            extracted_responses.append(None)

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        # Convert true_answer to int, add 1, then convert back to string
        adjusted_true_answer = str(int(true_answer) + 1)

        score = 0
        if guess is None:
            scores.append(0)
            continue
        # Correct answer gets 3 points!
        if guess == adjusted_true_answer:
            score += 3.0
        # If stripping whitespace makes them match, reward partially
        elif guess.strip() == adjusted_true_answer.strip():
            score += 1.5
        else:
            # Also check if the answer is close via ratio comparison
            try:
                ratio = float(guess) / float(adjusted_true_answer)
                if 0.9 <= ratio <= 1.1:
                    score += 0.5
                elif 0.8 <= ratio <= 1.2:
                    score += 0.25
                else:
                    score -= 1.0  # Penalize wrong answers
            except:
                score -= 0.5  # Penalize non-numeric or badly formatted answers
        scores.append(score)
    return scores

def match_format_exactly(prompts, completions, answer, **kwargs):
    """Reward full credit if the output strictly follows the expected format."""
    responses = [comp[0]["content"] if isinstance(comp, list) and isinstance(comp[0], dict) else comp for comp in completions]
    scores = []
    for r in responses:
        if match_format.fullmatch(r.strip()):
            scores.append(1.0)
        else:
            scores.append(0.0)
    return scores


def match_format_exactly(prompts, completions, answer, **kwargs):
    """Reward full credit if the output strictly follows the expected format."""
    responses = []
    for completion in completions:
        if isinstance(completion, list) and len(completion) > 0 and isinstance(completion[0], dict) and "content" in completion[0]:
            responses.append(completion[0]["content"])
        elif isinstance(completion, dict) and "content" in completion:
            responses.append(completion["content"])
        else:
            # Handle case where completion is already a string
            responses.append(completion)

    scores = []
    for r in responses:
        if isinstance(r, str) and match_format.search(r):
            scores.append(1.0)
        else:
            scores.append(0.0)
    return scores

def match_format_approximately(prompts, completions, answer, **kwargs):
    """Reward if key formatting tags are present in the output."""
    responses = []
    for completion in completions:
        if isinstance(completion, list) and len(completion) > 0 and isinstance(completion[0], dict) and "content" in completion[0]:
            responses.append(completion[0]["content"])
        elif isinstance(completion, dict) and "content" in completion:
            responses.append(completion["content"])
        else:
            # Handle case where completion is already a string
            responses.append(completion)

    scores = []
    for r in responses:
        if isinstance(r, str) and reasoning_start in r and reasoning_end in r and solution_start in r and solution_end in r:
            scores.append(0.5)
        else:
            scores.append(0.0)
    return scores

def check_numbers(prompts, completions, answer, **kwargs):
    """Reward if the extracted number from the output is correct."""
    responses = []
    for completion in completions:
        if isinstance(completion, list) and len(completion) > 0 and isinstance(completion[0], dict) and "content" in completion[0]:
            responses.append(completion[0]["content"])
        elif isinstance(completion, dict) and "content" in completion:
            responses.append(completion["content"])
        else:
            # Handle case where completion is already a string
            responses.append(completion)

    scores = []
    for r, true in zip(responses, answer):
        if not isinstance(r, str):
            scores.append(0.0)
            continue

        num = extract_move(r)
        # Convert true to int, add 1, then convert back to string
        adjusted_true = str(int(true) + 1)

        if num == adjusted_true:
            scores.append(1.0)
        else:
            scores.append(0.0)
    return scores
# Create a simple dataset that cycles through the games CSV
class ConnectFourDataset:
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        prompt = format_puzzle(idx)
        # The answer is the winning column (as a string)
        # Keep as 0-based index from CSV, conversion happens in check functions
        answer = str(self.dataframe.loc[idx, 'Winning Column'])
        return {'prompt': prompt, 'answer': answer}

dataset = ConnectFourDataset(df)

In [None]:
# Set max prompt length and import GRPO trainer components
max_prompt_length = 256

from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,  # Increase to 4 for smoother training if needed
    num_generations = 4,             # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1,          # Uncomment for a full training run
    max_steps = 50,
    save_steps = 50,
    max_grad_norm = 0.1,
    report_to = "none",              # Can use Weights & Biases if desired
    output_dir = "outputs",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset,
)

# Start training with GRPO!
trainer.train()

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4
Unsloth: Switching to float32 training since model cannot work with float16


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 996 | Num Epochs = 1 | Total steps = 50
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 1 x 1) = 4
 "-____-"     Trainable parameters = 6,522,880/1,006,408,832 (0.65% trained)


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / match_format_exactly,rewards / match_format_approximately,rewards / check_answer,rewards / check_numbers
1,0.0,0.75,0.866025,768.0,0.0,0.5,0.25,0.0,0.0
2,-0.0,0.125,0.25,768.0,0.0,0.0,0.125,0.0,0.0
3,0.0,0.0,0.0,768.0,2e-06,0.0,0.0,0.0,0.0
4,0.0,0.375,0.75,768.0,2e-06,0.25,0.125,0.0,0.0
5,0.0,0.25,0.5,768.0,3e-06,0.0,0.0,0.0,0.25
6,0.0,0.0,0.0,768.0,7e-06,0.0,0.0,0.0,0.0
7,0.0,0.5,0.707107,768.0,4e-06,0.5,0.25,-0.25,0.0
8,0.0,0.5,0.707107,768.0,2e-06,0.5,0.25,-0.25,0.0
9,0.0,0.0,0.0,768.0,3e-06,0.0,0.0,0.0,0.0
10,0.0,0.0,0.0,584.5,1.1e-05,0.0,0.0,0.0,0.0


Unsloth: Will smartly offload gradients to save VRAM!
