
Gemma3 can't be fine-tuned on multi-image examples #36816

Open
FredrikNoren opened this issue Mar 19, 2025 · 8 comments

@FredrikNoren
Contributor

FredrikNoren commented Mar 19, 2025

System Info

There are more details here: google-deepmind/gemma#193

In short: it seems like multi-image training is not implemented yet.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The following code works:

import io
from typing import Any, cast
import requests
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from transformers import TrainingArguments
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from datasets import IterableDataset, Features
import datasets
from PIL import Image
import numpy as np

HF_TOKEN = "..."

def load_image(url):
    response = requests.get(url)
    image = Image.open(io.BytesIO(response.content))
    return image

def image_from_bytes(image_bytes):
    return Image.open(io.BytesIO(image_bytes))

def main():

    model_id = "google/gemma-3-4b-it"

    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", token=HF_TOKEN
    )
    model.config.use_cache = False  # Disable caching for training

    processor = AutoProcessor.from_pretrained(model_id, padding_side="right", token=HF_TOKEN)
    processor.tokenizer.pad_token = processor.tokenizer.eos_token  # Use eos token as pad token
    processor.tokenizer.padding_side = "right"

    def train_iterable_gen():
        N_IMAGES = 1
        image = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg").resize((896, 896))
        images = np.array([image] * N_IMAGES)
        print("IMAGES SHAPE", images.shape)
        yield {
                "images": images,
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "image" } for _ in range(images.shape[0])]
                    },
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": "duck" }]
                    }
                ]
            }
    train_ds = IterableDataset.from_generator(
        train_iterable_gen,
        features=Features({
            'images': [datasets.Image(mode=None, decode=True, id=None)],
            'messages': [{'content': [{'text': datasets.Value(dtype='string', id=None), 'type': datasets.Value(dtype='string', id=None) }], 'role': datasets.Value(dtype='string', id=None)}]
            } )
    )

    def collate_fn(examples):
        # Get the texts and images, and apply the chat template
        texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
        images = [example["images"] for example in examples]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        print("collate_fn pixel_values", batch["pixel_values"].shape)
        print("collate_fn input_ids", batch["input_ids"].shape)

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        labels[labels == processor.image_token_id] = -100
        batch["labels"] = labels

        return batch

    # Set up LoRA configuration for causal language modeling
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM"
    )

    # Define training arguments
    training_args = SFTConfig(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        learning_rate=2e-4,
        logging_steps=1,
        save_steps=25,
        report_to="tensorboard",
        group_by_length=False,
        remove_unused_columns=False,
        dataset_kwargs={"skip_prepare_dataset": True},
        gradient_checkpointing_kwargs=dict(use_reentrant=False),
        max_steps=1
    )

    # Create the SFTTrainer with LoRA parameters
    trainer = SFTTrainer(
        model=model,
        train_dataset=cast(Any, train_ds),
        peft_config=lora_config,
        args=training_args,
        data_collator=collate_fn,
        processing_class=processor.tokenizer,
    )

    print("Training model...")
    trainer.train()
    print("Training complete.")

if __name__ == "__main__":
    main()

But if I increase N_IMAGES to 2, it crashes with the following error:

  File "/tmp/ray/session_2025-03-18_01-47-01_879621_311/runtime_resources/working_dir_files/_ray_pkg_95509c95a64411ba/.venv/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1333, in forward
    raise ValueError(
ValueError: Number of images does not match number of special image tokens in the input text. Got 512 image tokens in the text but 256 tokens from image embeddings.

Expected behavior

I'd expect either:

  1. The preprocessor to return pixel_values with shape [batch, n_images, c, w, h], and the model to handle that shape
  2. Or something else needs fixing, because when I add debug statements in the transformers library, it looks like only the first image in the batch of images is used

In the first case:

@zucchini-nlp
Member

@FredrikNoren Gemma3 works with multi-image inputs at inference, so I'd assume training should be no different.

The expected input format with several images is as follows (can you verify that the training script follows it?):

images = [[im1, im2], [im3]]
texts  = ["Are these identical images? <image> <image>", "Describe this: <image>"]
inputs = processor(images=images, text=texts, return_tensors='pt')
print(inputs['pixel_values'].shape) # 3, 3, 896, 896
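
As a side note (not part of the original comment), a minimal sketch of the shape bookkeeping this format implies: the processor flattens the nested image list, so the first dimension of pixel_values is the total image count across the batch rather than the batch size:

n_images_per_example = [2, 1]        # [[im1, im2], [im3]]
print(sum(n_images_per_example))     # 3 -> pixel_values.shape[0], total images
print(len(n_images_per_example))     # 2 -> batch size, i.e. input_ids.shape[0]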

@FredrikNoren
Contributor Author

@zucchini-nlp Yup, that's the format I'm giving the processor, and the output I'm getting is:

collate_fn pixel_values torch.Size([2, 3, 896, 896])
collate_fn input_ids torch.Size([1, 649])

But the problem is that the trainer seems to treat the first dimension of pixel_values as the batch size, so it pairs only the first image with the first input_ids and tries to train on that. That is why it fails with the "Number of images does not match ..." error; the Gemma forward pass only receives a single image.
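
A minimal sanity check that could be dropped into collate_fn to confirm the collated batch itself is consistent; a hedged sketch that reuses processor.image_token_id from the script above and assumes 256 image tokens per image (if this passes in the collator but the forward pass still fails, the mismatch is introduced after collation):

def check_image_alignment(batch, processor, tokens_per_image=256):
    # Image placeholder tokens the text expects the model to fill in.
    n_text_tokens = int((batch["input_ids"] == processor.image_token_id).sum())
    # Soft tokens the vision tower should produce from pixel_values.
    n_embedding_tokens = batch["pixel_values"].shape[0] * tokens_per_image
    print("text image tokens:", n_text_tokens, "| embedding tokens:", n_embedding_tokens)
    return n_text_tokens == n_embedding_tokens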

@zucchini-nlp
Member

Oh, I see now. Actually, Gemma3 is not the first model where the first dimension doesn't match the batch dimension; earlier we had Qwen2-VL, where the first dimension was the image sequence length. Let me check how it worked, or whether it worked.

@zucchini-nlp
Member

I would guess the training here is also multi-GPU, so I'll link it to #33666.

@FredrikNoren
Contributor Author

@zucchini-nlp I'm planning to do multi-GPU as well, but I haven't yet. So far this is just single-GPU.

@zucchini-nlp
Member

@FredrikNoren I asked the team internally and found that training was tested and works with Gemma3 in multi-image cases. So I believe there is something wrong in the way your training is set up. Could you please open the issue in the TRL repo instead?

@FredrikNoren
Contributor Author

@zucchini-nlp Hmm, I suspect it's not a TRL issue, but I've created an issue for them here: huggingface/trl#3121

Would you mind asking the team that was able to train internally if they can share their code?

Also, what is the expected output format of collate_fn in the case of multi-image training?

@zucchini-nlp
Member

Yes, sure. The team is aware of the issue, and someone from TRL will take a look soon.
