Gemma3 minimal fine tuning example? #36714

Closed
FredrikNoren opened this issue Mar 14, 2025 · 4 comments

@FredrikNoren
Contributor

I've been trying to fine-tune Gemma 3, but I can't seem to figure out exactly how to do it. Here's what I have now:

import tempfile
from typing import Any, cast
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from transformers import TrainingArguments, Trainer
from common import HF_TOKEN
from datasets import Dataset


def test_train_gemma3():

    model_id = "google/gemma-3-4b-it"
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", token=HF_TOKEN
    )

    processor = AutoProcessor.from_pretrained(model_id, padding_side="left", token=HF_TOKEN)

    def train_iterable_gen():
        messages = [
            {
                "role": "user",
                "content": [{ "type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" }],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "duck" }]
            }
        ]
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs["labels"] = inputs["input_ids"].clone()
        yield inputs
    train_ds = Dataset.from_generator(train_iterable_gen, cache_dir=tempfile.gettempdir())

    training_args = TrainingArguments(
        output_dir=tempfile.gettempdir(),
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        report_to="none",
        max_steps=1000
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=cast(Any, train_ds),
    )

    print("Training model...")
    trainer.train()
    print("Training complete.")

But it throws an exception:

  File "/home/ray/anaconda3/lib/python3.12/site-packages/transformers/models/siglip/modeling_siglip.py", line 310, in forward
    _, _, height, width = pixel_values.shape
    ^^^^^^^^^^^^^^^^^^^
ValueError: too many values to unpack (expected 4)

How is my code wrong and what can I do to fix it?

@Rocketknight1
Member

Hi @FredrikNoren, Gemma3ForConditionalGeneration is the vision-language model, and it will get confused because you're not passing any images. Try Gemma3ForCausalLM for the text-only model.
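
Something along these lines should exercise the text-only path (a rough, untested sketch, reusing the HF_TOKEN from your snippet):

from transformers import AutoTokenizer, Gemma3ForCausalLM

model_id = "google/gemma-3-1b-it"  # text-only checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto", token=HF_TOKEN)

messages = [
    {"role": "user", "content": "What animal says quack?"},
    {"role": "assistant", "content": "duck"},
]
# Tokenize the whole conversation; labels are just a copy of input_ids for plain LM loss
inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, return_dict=True, return_tensors="pt"
)
inputs["labels"] = inputs["input_ids"].clone()
loss = model(**inputs.to(model.device)).loss

From there, the same Dataset / Trainer setup as in your snippet should apply.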

@FredrikNoren
Contributor Author

@Rocketknight1 Hm, I was under the impression that apply_chat_template with images in the messages would automatically create the image inputs? It does return pixel_values.
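
For instance, a quick check like this (just a sketch of how I'm inspecting it) prints pixel_values among the keys:

processed = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
print(processed.keys())  # input_ids, attention_mask, ... plus pixel_values when images are in the messages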

I also tried this variant:

import io
import tempfile
from typing import Any, cast
import requests
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig
from transformers import TrainingArguments, Trainer
from common import HF_TOKEN
from datasets import Dataset
from PIL import Image

def load_image(url):
    response = requests.get(url)
    image = Image.open(io.BytesIO(response.content))
    return image


def test_train_gemma3_collate():

    model_id = "google/gemma-3-4b-it"

    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", token=HF_TOKEN
    )

    processor = AutoProcessor.from_pretrained(model_id, padding_side="left", token=HF_TOKEN)

    peak_mem = torch.cuda.max_memory_allocated()
    print(f"The model as is is holding: {peak_mem / 1024**3:.2f}GB of GPU RAM")


    def train_iterable_gen():
        yield {
            "image": load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg").resize((128, 128)),
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "image" }]
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": "duck" }]
                }
            ]
        }
    train_ds = Dataset.from_generator(train_iterable_gen, cache_dir=tempfile.gettempdir())
    def collate_fn(examples):
        texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
        image_inputs = [example["image"] for example in examples]

        batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100  # ignore padding tokens in the loss

        batch["labels"] = labels

        return batch
    training_args = TrainingArguments(
        output_dir=tempfile.gettempdir(),
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        report_to="none",
        remove_unused_columns=False,
        max_steps=1
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=cast(Any, train_ds),
    )

    print("Training model...")
    trainer.train()
    print("Training complete.")

This one instead crashes with a CUDA OOM. It's on an A100 (i.e. 40GB of GPU memory), but maybe that's not enough?
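
One thing I might try next is LoRA via peft, to shrink the optimizer and gradient memory (sketch only, not tested; the target_modules names are a guess for Gemma 3):

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # guessed attention projection names
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)  # wrap before handing the model to Trainer
model.print_trainable_parameters()  # only the adapter weights remain trainable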

I tried switching to "google/gemma-3-1b-it" but it gives me this exception:

  File "/home/ray/anaconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for Gemma3TextScaledWordEmbedding:
        size mismatch for weight: copying a param with shape torch.Size([262144, 1152]) from checkpoint, the shape in current model is torch.Size([262208, 2304]).

I tried adding quantization, but I suspect maybe the 1b doesn't have vision?
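
(For reference, by "adding quantization" I mean 4-bit loading roughly like this, using the BitsAndBytesConfig already imported above; just a sketch:)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", quantization_config=bnb_config, token=HF_TOKEN
)

As far as I understand, you can't train the raw 4-bit weights directly, so this would normally be combined with LoRA adapters (i.e. QLoRA).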

@ysdk2
ysdk2 commented Mar 16, 2025

When fine-tuning with Gemma3ForCausalLM, does it load only the text parameters, or does it load the vision branch as well? I'm trying to reduce memory usage.

@FredrikNoren
Contributor Author

The Gemma 3 team provided a fine-tuning example here: google-deepmind/gemma#175
