Gemma3 minimal fine tuning example? #36714
Comments
Hi @FredrikNoren,
@Rocketknight1 Hm, I was under the impression that apply_chat_template with images in the messages would automatically create the image inputs? It does return pixel_values. I also tried this variant:

```python
import io
import tempfile
from typing import Any, cast

import requests
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, BitsAndBytesConfig
from transformers import TrainingArguments, Trainer
from common import HF_TOKEN
from datasets import Dataset
from PIL import Image


def load_image(url):
    response = requests.get(url)
    image = Image.open(io.BytesIO(response.content))
    return image


def test_train_gemma3_collate():
    model_id = "google/gemma-3-4b-it"
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", token=HF_TOKEN
    )
    processor = AutoProcessor.from_pretrained(model_id, padding_side="left", token=HF_TOKEN)

    peak_mem = torch.cuda.max_memory_allocated()
    print(f"The model as is is holding: {peak_mem / 1024**3:.2f}GB of GPU RAM")

    def train_iterable_gen():
        yield {
            "image": load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg").resize((128, 128)),
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "image"}],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": "duck"}],
                },
            ],
        }

    train_ds = Dataset.from_generator(train_iterable_gen, cache_dir=tempfile.gettempdir())

    def collate_fn(examples):
        texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
        image_inputs = [example["image"] for example in examples]
        batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100  # mask padding tokens so they are ignored by the loss
        batch["labels"] = labels
        return batch

    training_args = TrainingArguments(
        output_dir=tempfile.gettempdir(),
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        report_to="none",
        remove_unused_columns=False,
        max_steps=1,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=cast(Any, train_ds),
    )
    print("Training model...")
    trainer.train()
    print("Training complete.")
```

This variant instead crashes with a CUDA OOM. This is on an A100 (i.e. 40 GB of GPU memory), but maybe that's not enough? I tried switching to "google/gemma-3-1b-it" but it gives me this exception:
I tried adding quantization, but I suspect maybe the 1b model doesn't have vision?
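For reference, a rough sketch of how quantized loading is usually wired up with bitsandbytes. The config values are common defaults rather than anything Gemma3-specific, and whether 4-bit weights plus gradient checkpointing fit the 4B model into 40 GB is not guaranteed. Also, as far as I know the 1B checkpoint is text-only, so there is no vision tower for Gemma3ForConditionalGeneration to load there.

```python
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Gemma3ForConditionalGeneration

# Common 4-bit setup: NF4 quantization with bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
```

Keep in mind that 4-bit weights are frozen, so training on top of them normally goes through a PEFT/LoRA adapter rather than full fine-tuning.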
When fine-tuning Gemma3ForCausalLM, does it only load the text parameters, or will it load the vision branch as well? I'm trying to reduce memory.
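As far as I understand, Gemma3ForCausalLM is the text-only class, so it should not instantiate the vision tower, but the quickest way to see where the parameters (and memory) go is to load the multimodal model once and group its parameters by top-level submodule. A small sketch, making no assumptions about the exact submodule names:

```python
from collections import Counter

import torch
from transformers import Gemma3ForConditionalGeneration

# Load in bfloat16 just to inspect parameter counts.
model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-4b-it", torch_dtype=torch.bfloat16
)

# Group parameter counts by the first component of each parameter name,
# e.g. the language model vs. the vision branch vs. the projector.
counts = Counter()
for name, param in model.named_parameters():
    counts[name.split(".")[0]] += param.numel()

for prefix, n in sorted(counts.items(), key=lambda kv: -kv[1]):
    print(f"{prefix}: {n / 1e9:.2f}B parameters")
```

That makes it easy to see how much would actually be saved by dropping the vision branch versus quantizing or switching to a smaller checkpoint.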
The Gemma3 team provided a fine-tuning example here: google-deepmind/gemma#175
I've been trying to fine-tune Gemma3, but I can't seem to figure out exactly how to do it. Here's what I have now:
But it throws an exception:
How is my code wrong and what can I do to fix it?