# Let's make Gemma 3 1b think! 🍎

This is another notebook to make Gemma 3 think, this time focusing on the smallest variant, 1B. You should be able to run this notebook on Apple silicon.

![logo](https://storage.googleapis.com/gweb-uniblog-publish-prod/images/Gemma3_KeywordBlog_RD3_V01b.width-1200.format-webp.webp)

👩‍🎓 If you want to learn more about making models think and reason, check out [The Reasoning Course](https://huggingface.co/reasoning-course)

### Installation

In [42]:
# # install this release tag of transformers
# !pip install -qqq git+https://github.com/huggingface/trl.git@main \
#                   bitsandbytes

# !pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

# !pip install git+https://github.com/huggingface/peft.git

In [1]:
# Set your Hugging Face token (`creds` is a local module that holds the access tokens)
import os
from creds import all_creds
os.environ["HUGGING_FACE_HUB_TOKEN"] = all_creds['HUGGINGFACE_ACCESS_TOKEN_Gemma']

In [2]:
from huggingface_hub import notebook_login
notebook_login()


In [3]:
import torch
from transformers import Gemma3ForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

torch_dtype = torch.bfloat16

model = Gemma3ForCausalLM.from_pretrained(
    pretrained_model_name_or_path="google/gemma-3-1b-it",
    device_map=torch.device("mps") if torch.backends.mps.is_available() else "auto",  # use MPS on Apple silicon, otherwise place the model automatically
    #attn_implementation="sdpa",
    attn_implementation="eager",
    torch_dtype=torch_dtype
)

# Load LoRA
peft_config = LoraConfig(
    lora_alpha=4,
    lora_dropout=0.05,
    r=4,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],  # make sure to save the lm_head and embed_tokens as you train the special tokens
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # prints the summary itself and returns None

processor = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'
trainable params: 607,241,216 || all params: 1,607,127,168 || trainable%: 37.7843
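The parameter summary can be sanity-checked by hand. The sketch below assumes the Gemma 3 1B dimensions from the published model config (vocabulary 262,144, hidden size 1,152, MLP size 6,912, 26 layers, 4 query heads and 1 key/value head of dimension 256). None of these values appear in this notebook, so treat them as assumptions. Under them, the rank-4 adapters account for only about 3.3M parameters, and nearly all of the trainable total comes from the full copies of `embed_tokens` and `lm_head` kept by `modules_to_save`.

```python
# Assumed Gemma 3 1B dimensions (from the published config, not from this notebook)
VOCAB_SIZE = 262_144
HIDDEN = 1_152
INTERMEDIATE = 6_912
NUM_LAYERS = 26
Q_OUT = 4 * 256    # 4 query heads x head_dim 256
KV_OUT = 1 * 256   # 1 key/value head x head_dim 256
RANK = 4           # matches r=4 in the LoraConfig above

# (in_features, out_features) of each linear layer in one decoder block
linear_shapes = {
    "q_proj": (HIDDEN, Q_OUT),
    "k_proj": (HIDDEN, KV_OUT),
    "v_proj": (HIDDEN, KV_OUT),
    "o_proj": (Q_OUT, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

# A LoRA adapter adds two matrices per layer: (in x r) and (r x out)
lora_per_block = sum(RANK * (n_in + n_out) for n_in, n_out in linear_shapes.values())
lora_total = lora_per_block * NUM_LAYERS

# modules_to_save keeps a full trainable copy of embed_tokens and of lm_head
saved_module_params = 2 * VOCAB_SIZE * HIDDEN

trainable = lora_total + saved_module_params
print(f"LoRA: {lora_total:,}  saved modules: {saved_module_params:,}  total: {trainable:,}")
```

Under these assumptions the total agrees with the `trainable params` figure reported above, which suggests that removing `modules_to_save` would leave only the LoRA matrices trainable.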


### Process data to create reasoning chains

Borrowing from [Will Brown's gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb), we'll build reasoning-chain prompts from GSM8K.

In [1]:
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>"""

XML_COT_FORMAT = """<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Build chat-style prompts from GSM8K; for 1-shot prompting, add an example exchange between the system and user messages
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()
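To see what the two extraction helpers return, here is a small standalone sketch. The helper bodies are copied from the cell above; the sample strings are made up for illustration in the GSM8K style and are not taken from the dataset.

```python
def extract_xml_answer(text: str) -> str:
    # Keep whatever sits between the last <answer> tag and the next </answer> tag
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str):
    # GSM8K reference solutions end with "#### <final answer>"
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Illustrative strings (assumptions, not real dataset rows or model outputs)
reference_solution = "48 / 2 = 24 clips in May.\n48 + 24 = 72 clips in total.\n#### 72"
tagged_completion = "<reasoning>\n48 / 2 = 24, and 48 + 24 = 72.\n</reasoning>\n<answer>\n72\n</answer>"
untagged_completion = "I think it is 72"

gold = extract_hash_answer(reference_solution)
predicted = extract_xml_answer(tagged_completion)
fallback = extract_xml_answer(untagged_completion)

print(repr(gold), repr(predicted), repr(fallback))
```

Note that `extract_xml_answer` has no failure value: when the tags are missing it returns the whole stripped text, so an untagged completion only earns the correctness reward if the entire text equals the reference answer.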

In [None]:
# Optionally load data from CSV
from datasets import Dataset
import pandas as pd

def csv_to_gsm8k_format(csv_path):
    # Read the CSV file
    df = pd.read_csv(csv_path)
    
    # Convert to the required format
    formatted_data = {
        'question': [],
        'answer': [],
        'prompt': []
    }
    
    SYSTEM_PROMPT = """Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>"""
    
    for _, row in df.iterrows():
        formatted_data['question'].append(row['question'])
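        # pandas may parse numeric answers as ints; keep them as strings so they compare equal to the extracted text
        formatted_data['answer'].append(str(row['answer']))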
        formatted_data['prompt'].append([
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['question']}
        ])
    
    # Create HuggingFace dataset
    dataset = Dataset.from_dict(formatted_data)
    return dataset

# Example usage:
# dataset = csv_to_gsm8k_format('your_csv_file.csv')
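The function above expects a CSV with `question` and `answer` columns. As a rough sketch of the record shape it produces, the snippet below parses an in-memory CSV with the standard library only. The two rows are hypothetical, and `csv.DictReader` stands in for `pd.read_csv` so the example runs without extra dependencies.

```python
import csv
import io

SYSTEM_PROMPT = """Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>"""

# Hypothetical file content; only the column names come from the function above
csv_text = 'question,answer\n"What is 2 + 3?",5\n"What is 10 - 4?",6\n'

records = []
for row in csv.DictReader(io.StringIO(csv_text)):
    records.append({
        "question": row["question"],
        "answer": row["answer"],  # the csv module yields strings
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": row["question"]},
        ],
    })

print(records[0]["answer"], records[0]["prompt"][1]["content"])
```

Answers stay strings here, which matters because the correctness reward compares them with text extracted from the completion.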

# Reward Functions

Now, let's define the reward functions. Each one scores every completion in a batch, and the trainer combines the scores into the reward signal used for training.

| Reward Function | Purpose |
|---|---|
| `correctness_reward_func` | Rewards the model when its answer matches the correct answer |
| `int_reward_func` | Rewards the model for providing a numeric answer |
| `strict_format_reward_func` and `soft_format_reward_func` | Reward the model for following the specified format |
| `xmlcount_reward_func` | Rewards proper XML tag usage and penalizes extra content after the closing tags |

In [5]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact line-by-line tag layout."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets "." span newlines, so multi-line reasoning can match
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that open with a reasoning block followed by an answer block, with flexible whitespace."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets "." span the newlines inside each block
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
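To make the scoring concrete, the sketch below runs `count_xml` (copied from the cell above) on a hand-written completion, with and without trailing text. It also compares the soft-format pattern with and without `re.DOTALL`, since `.` does not match newlines by default in Python regular expressions. The sample completion is an assumption written for illustration.

```python
import re

def count_xml(text) -> float:
    # Copied from the cell above: 0.125 per correctly placed tag,
    # minus 0.001 per character that follows the closing answer tag
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

SOFT_PATTERN = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"

# Illustrative completion in the requested format
well_formed = (
    "<reasoning>\n48 / 2 = 24, and 48 + 24 = 72.\n</reasoning>\n"
    "<answer>\n72\n</answer>\n"
)
with_trailing_text = well_formed + "extra"

clean_score = count_xml(well_formed)          # all four tags, nothing after </answer>
noisy_score = count_xml(with_trailing_text)   # same tags, small penalty for the 5 extra characters

# The reasoning block contains newlines, so the flag decides whether the pattern matches
soft_default = re.match(SOFT_PATTERN, well_formed) is not None
soft_dotall = re.match(SOFT_PATTERN, well_formed, flags=re.DOTALL) is not None

print(clean_score, round(noisy_score, 3), soft_default, soft_dotall)
```

The trailing-text penalty is applied in two of the branches, so five stray characters cost about 0.01 in total.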

In [2]:
dataset.features

{'question': Value(dtype='string', id=None),
 'answer': Value(dtype='string', id=None),
 'prompt': [{'content': Value(dtype='string', id=None),
   'role': Value(dtype='string', id=None)}]}

In [3]:
len(dataset)

7473

In [7]:
print(dataset[0])

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'answer': '72', 'prompt': [{'content': 'Respond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>', 'role': 'system'}, {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'role': 'user'}]}


In [8]:
dataset.data[0][0]

<pyarrow.StringScalar: 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?'>

# Train with GRPOTrainer

Now we'll configure training with `GRPOConfig`.

In [11]:
# May need to do this later
#    if torch.backends.mps.is_available():
#        # MPS setup
#        training_args = GRPOConfig(
#            # other arguments
#            optim = "adamw_torch",  # Standard PyTorch optimizer
#            # rest of config
#        )
#    else:
#        # CUDA setup
#        training_args = GRPOConfig(
#            # other arguments
#            optim = "adamw_8bit",  # 8-bit optimizer for CUDA
#            # rest of config
#        )

In [12]:
from trl import GRPOConfig, GRPOTrainer
from transformers import GenerationConfig

max_prompt_length = 1024
max_seq_length = 2048


training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "constant",
    #optim = "adamw_8bit",
    optim = "adamw_torch",
    logging_steps = 1,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 4,
    num_generations = 2,
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    num_train_epochs = 1,
    max_steps = 5,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none",
    cache_implementation="hybrid"
)
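The `num_generations` setting controls how many completions are sampled per prompt. GRPO then scores each completion relative to the others in its group. The sketch below illustrates that idea in plain Python. It is not TRL's code: the function name, the use of the population standard deviation, and the `eps` value are assumptions made for this example.

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, num_generations, eps=1e-4):
    """Normalize rewards within each group of completions sampled for the same prompt."""
    if len(rewards) % num_generations != 0:
        raise ValueError("rewards must contain a whole number of groups")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mu = mean(group)
        sigma = pstdev(group)  # population std; an assumption for this sketch
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages

# Two prompts with num_generations = 2, as configured above (reward values are made up)
rewards = [2.5, 0.5,   # prompt 1: one completion clearly better than the other
           1.0, 1.0]   # prompt 2: both completions scored the same
advantages = group_relative_advantages(rewards, num_generations=2)
print([round(a, 4) for a in advantages])
```

In this formulation, a group whose completions all receive the same reward gets zero advantage everywhere and contributes no learning signal. With only two generations per prompt, that case is more likely than with larger groups.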

# Start trainer

In [13]:
from trl.trainer.utils import pad
import torch

trainer = GRPOTrainer(
    model = model,
    processing_class = processor,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)

trainer.train()

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


-------------------- Question:
Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade Ahmed needs to get to beat Emily if all grades are whole numbers? 
Answer:
100 
Response:
<reasoning>
This is a classic problem involving optimization. Ahmed's success is the key. Emily’s higher grade on the final assignment gives her an advantage, but Ahmed’s already high score is a significant hurdle. We need to determine the smallest grade Ahmed needs to surpass Emily's final score to achieve a tie.  We can think of this as a game of trying to close the gap between their scores.

</reasoning>
<answer>
The minimum grade Ahmed needs to get is 91. If Ahmed gets 91, he will overcome Emily’s 92.做到یی திறன் невозможно ..."..."..."..."..."..."..."..."..."...".

Step,Training Loss
1,0.0
2,0.0473
3,0.0001
4,0.0
5,0.0001


-------------------- Question:
Marie has 98 unread messages on her phone. She decides to clear them by reading 20 messages a day. However, she also gets 6 new messages a day. How many days will it take her to read all her unread messages? 
Answer:
7 
Response:
<reasoning>
Marie’s daily progress is determined by the rate at which she reads new messages and the number of unread messages she’s already reading. She reads 20 messages and gets 6 new messages, so her net progress each day is 20 - 6 = 14 messages.  We need to find the number of days it takes to clear 98 unread messages.

<answer>
It will take her 98 / 14 = 6.86 days. Since she can’t have a fraction of a day, we need to round up to the nearest whole number, which is 7 days.jneeedmoretime做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做到做

TrainOutput(global_step=5, training_loss=0.009502559900283813, metrics={'train_runtime': 3059.3499, 'train_samples_per_second': 0.026, 'train_steps_per_second': 0.002, 'total_flos': 0.0, 'train_loss': 0.009502559900283813})

In [14]:
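# Note: the first positional argument of `push_to_hub` is the commit message, not the
# repo ID, so the upload goes to the repo named after `output_dir` (see the output below).
# Set `hub_model_id` in `GRPOConfig` to choose the target repo.
trainer.push_to_hub("voxmenthe/gemma3-1b-thinking")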

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

adapter_model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/6.01k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/voxmenthe/trainer_output/commit/ac1c0048c5e59ad96397f504db2899fe6201fb13', commit_message='voxmenthe/gemma3-1b-thinking', commit_description='', oid='ac1c0048c5e59ad96397f504db2899fe6201fb13', pr_url=None, repo_url=RepoUrl('https://huggingface.co/voxmenthe/trainer_output', endpoint='https://huggingface.co', repo_type='model', repo_id='voxmenthe/trainer_output'), pr_revision=None, pr_num=None)

In [19]:
from transformers import pipeline

question = "The school principal decided that she wanted every class to have an equal number of boys and girls in each first-grade classroom. There are 4 classrooms. There are 56 boys and 44 girls. How many total students are in each classroom?"
generator = pipeline("text-generation", model=trainer.model, tokenizer=processor)

# Get the raw string by setting tokenize=False
input_text = processor.apply_chat_template([{"role": "user", "content": question}], tokenize=False)

# Now you can append "<reasoning>" to the string
input_text = input_text + "<reasoning>"

output = generator(input_text, max_new_tokens=1024)

Device set to use mps
The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'Gemma3ForCausalLM', 'Gemma3ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalL

In [22]:
print(output[0]['generated_text'])

<bos><start_of_turn>user
The school principal decided that she wanted every class to have an equal number of boys and girls in each first-grade classroom. There are 4 classrooms. There are 56 boys and 44 girls. How many total students are in each classroom?<end_of_turn>
<reasoning>
Let $B$ be the number of boys in each classroom and $G$ be the number of girls in each classroom.
We are given that there are 4 classrooms.
The total number of boys is 56 and the total number of girls is 44.
The total number of students in each classroom is equal.
The number of boys in each classroom is $B$, and the number of girls in each classroom is $G$.
Since there are 4 classrooms, the total number of boys is $4B$ and the total number of girls is $4G$.
We are given that the number of boys is 56 and the number of girls is 44.
So, $4B = 56$ and $4G = 44$.
From $4B = 56$, we have $B = \frac{56}{4} = 14$.
From $4G = 44$, we have $G = \frac{44}{4} = 11$.
The number of students in each classroom is equal, so 

In [18]:
# Baseline for comparison: the same prompt without the "<reasoning>" prefill,
# reusing `question` and `generator` from the cell above.
# `apply_chat_template` returns token IDs by default, and appending a string to that
# list raises `TypeError: can only concatenate list (not "str") to list`,
# so request the raw string with tokenize=False.
baseline_text = processor.apply_chat_template([{"role": "user", "content": question}], tokenize=False)
output = generator(baseline_text, max_new_tokens=1024)

In [None]:
output

[{'generated_text': '<bos><start_of_turn>user\nThe school principal decided that she wanted every class to have an equal number of boys and girls in each first-grade classroom. There are 4 classrooms. There are 56 boys and 44 girls. How many total students are in each classroom?<end_of_turn>\n* * *\n**Solution**\n\n1.  **Find the total number of students:** 56 boys + 44 girls = 100 students\n2.  **Divide the total students by the number of classrooms:** 100 students / 4 classrooms = 25 students per classroom\n\n**Answer:** There are 25 students in each classroom.'}]

In [None]:
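# Instead of using the pipeline, call generate() on the model directly.
# `input_text` is the chat-formatted prompt string built above; it already contains <bos>.
inputs = processor(input_text, return_tensors="pt", add_special_tokens=False).to(trainer.model.device)
outputs = trainer.model.generate(**inputs, max_new_tokens=1024)
response = processor.decode(outputs[0], skip_special_tokens=True)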

# Next Steps!

Check out [The Reasoning Course](https://huggingface.co/reasoning-course) for more info on GRPO.

In the coming days we'll release a version of this notebook with Unsloth!

<a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>