Description
System Info
There are more details in here: google-deepmind/gemma#193
But in short: it seems that multi-image training is not implemented yet.
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
The following code works:
import io
from typing import Any, cast

import requests
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from transformers import TrainingArguments
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from datasets import IterableDataset, Features
import datasets
from PIL import Image
import numpy as np

HF_TOKEN = "..."


def load_image(url):
    response = requests.get(url)
    image = Image.open(io.BytesIO(response.content))
    return image


def image_from_bytes(image_bytes):
    return Image.open(io.BytesIO(image_bytes))


def main():
    model_id = "google/gemma-3-4b-it"
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", token=HF_TOKEN
    )
    model.config.use_cache = False  # Disable caching for training

    processor = AutoProcessor.from_pretrained(model_id, padding_side="right", token=HF_TOKEN)
    processor.tokenizer.pad_token = processor.tokenizer.eos_token  # Use eos token as pad token
    processor.tokenizer.padding_side = "right"

    def train_iterable_gen():
        N_IMAGES = 1
        image = load_image(
            "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
        ).resize((896, 896))
        images = np.array([image] * N_IMAGES)
        print("IMAGES SHAPE", images.shape)
        yield {
            "images": images,
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "image"} for _ in range(images.shape[0])],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": "duck"}],
                },
            ],
        }

    train_ds = IterableDataset.from_generator(
        train_iterable_gen,
        features=Features({
            "images": [datasets.Image(mode=None, decode=True, id=None)],
            "messages": [{
                "content": [{
                    "text": datasets.Value(dtype="string", id=None),
                    "type": datasets.Value(dtype="string", id=None),
                }],
                "role": datasets.Value(dtype="string", id=None),
            }],
        }),
    )

    def collate_fn(examples):
        # Get the texts and images, and apply the chat template
        texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
        images = [example["images"] for example in examples]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
        print("collate_fn pixel_values", batch["pixel_values"].shape)
        print("collate_fn input_ids", batch["input_ids"].shape)

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        labels[labels == processor.image_token_id] = -100
        batch["labels"] = labels
        return batch

    # Set up LoRA configuration for causal language modeling
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Define training arguments
    training_args = SFTConfig(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        learning_rate=2e-4,
        logging_steps=1,
        save_steps=25,
        report_to="tensorboard",
        group_by_length=False,
        remove_unused_columns=False,
        dataset_kwargs={"skip_prepare_dataset": True},
        gradient_checkpointing_kwargs=dict(use_reentrant=False),
        max_steps=1,
    )

    # Create the SFTTrainer with LoRA parameters
    trainer = SFTTrainer(
        model=model,
        train_dataset=cast(Any, train_ds),
        peft_config=lora_config,
        args=training_args,
        data_collator=collate_fn,
        processing_class=processor.tokenizer,
    )

    print("Training model...")
    trainer.train()
    print("Training complete.")


if __name__ == "__main__":
    main()
But if I increase N_IMAGES to 2, it crashes with the following error:
File "/tmp/ray/session_2025-03-18_01-47-01_879621_311/runtime_resources/working_dir_files/_ray_pkg_95509c95a64411ba/.venv/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1333, in forward
raise ValueError(
ValueError: Number of images does not match number of special image tokens in the input text. Got 512 image tokens in the text but 256 tokens from image embeddings.
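For what it's worth, the mismatch can probably be reproduced without the trainer by calling the processor directly on a two-image sample. The following is only a minimal sketch: passing the images as a nested list (one inner list per sample) is my assumption about the intended multi-image input format, and the token counts in the comments are inferred from the error above rather than verified output:

# Standalone sketch: inspect what the processor produces for one sample with two images.
# Assumes `huggingface-cli login` has been run; images=[[...]] (nested list, one inner
# list per sample) is assumed to be the intended multi-image input format.
import io

import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

image = Image.open(io.BytesIO(requests.get(
    "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
).content)).resize((896, 896))

messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "image"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "duck"}]},
]
text = processor.apply_chat_template(messages, tokenize=False)

batch = processor(text=[text], images=[[image, image]], return_tensors="pt", padding=True)

# The chat template expands each {"type": "image"} entry into image placeholder tokens
# (apparently 256 per image, judging by the 512-vs-256 error), so the two counts below
# should agree if multi-image batching works.
print("pixel_values:", batch["pixel_values"].shape)
print("image placeholder tokens:", (batch["input_ids"] == processor.image_token_id).sum().item())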
Expected behavior
I'd expect either:
- The preprocessor to return pixel_values in the shape [batch, n_images, c, w, h], and the model to be able to handle that shape (see the shape sketch at the end of this section)
- Or, something else needs fixing, because when I add debug statements in the transformers library it looks like only the first image in the batch of images is being used
In the first case:
- get_image_features doesn't seem to handle multi-image batches
- the preprocessor seems to just output a single list of images
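To make the first expectation concrete, here is a hypothetical bit of shape bookkeeping. The helper name and the tokens_per_image=256 value are illustrative (inferred from the 512-vs-256 error above), not part of the actual transformers API:

# Hypothetical check: if pixel_values were returned as [batch, n_images, c, h, w], the
# number of image-embedding tokens would line up with the placeholders inserted by the
# chat template.
import torch

def expected_image_tokens(pixel_values: torch.Tensor, tokens_per_image: int = 256) -> int:
    # pixel_values assumed to be [batch, n_images, channels, height, width]
    batch, n_images = pixel_values.shape[:2]
    return batch * n_images * tokens_per_image

pixel_values = torch.zeros(1, 2, 3, 896, 896)   # 1 sample, 2 images
print(expected_image_tokens(pixel_values))      # 512, matching the 512 placeholders in the text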