Description
System Info
- transformers version: 4.50.0.dev0
- Platform: Linux-6.13.6-100.fc40.x86_64-x86_64-with-glibc2.39
- Python version: 3.12.9
- Huggingface_hub version: 0.29.1
- Safetensors version: 0.4.5
- Accelerate version: 0.34.2
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (GPU?): 2.4.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA GeForce RTX 4060 Laptop GPU
Who can help?
@ArthurZucker
Terribly sorry if it's the wrong person! I hope this counts as a text-model issue.
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Used the base code given on HF for Gemma 3 (modified to a local path): https://huggingface.co/google/gemma-3-4b-it#running-the-model-on-a-singlemulti-gpu
Added `use_fast=True` to the `AutoProcessor` arguments; without it, the following warning is printed:

> Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
```python
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch

model_id = "gemma-3-4b-it"

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto"
).eval()

processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image in detail."}
        ]
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)

# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.
```
Reason (I think): the `images_list` variable is only assigned under `if do_pan_and_scan:`. If `do_pan_and_scan` is not enabled, `images_list` is never bound, so the later reference fails with an unbound-local error. The same variable is used for `group_images_by_shape` on line 294.
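For clarity, here is a minimal standalone sketch of the pitfall; this is not the transformers source, and the names only mirror the description above:

```python
# Minimal sketch of the unbound-local pitfall described above.
# NOT the actual transformers code; names only mirror the issue description.
def preprocess(image_list, do_pan_and_scan=False):
    if do_pan_and_scan:
        images_list = [[img] for img in image_list]  # pan-and-scan would populate this
        num_crops = [[1] for _ in images_list]
    else:
        num_crops = [[0] for _ in image_list]
    # `images_list` is only bound inside the `if` branch, so this reference fails
    # with UnboundLocalError whenever do_pan_and_scan is False:
    return images_list, num_crops

preprocess(["img0", "img1"])  # UnboundLocalError: 'images_list' referenced before assignment
```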
Relevant lines: `transformers/src/transformers/models/gemma3/image_processing_gemma3_fast.py`, lines 280 to 294 at commit 6f3e0b6.
Since the images variable is passed in as `List[List["torch.Tensor"]]`, assigning `images_list = image_list` in the `else` case fixes the issue. There might be a better way to fix this. The local fix I'm using:
```python
else:
    # assign variable to bypass the unbound-local error
    images_list = image_list
    num_crops = [[0] for _ in images_list]

# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list)
```
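For what it's worth, this is how I would smoke-test just the image-processing path with pan-and-scan off; it is a hedged sketch that assumes `Gemma3ImageProcessorFast` can be imported from the file cited above and constructed with its defaults, and that torchvision (required by the fast processors) is installed:

```python
# Hedged smoke test for the fast image-processing path only (with the local fix applied).
# Assumptions: Gemma3ImageProcessorFast is importable from the module cited above,
# default construction works, and torchvision is installed.
import numpy as np
from PIL import Image
from transformers.models.gemma3.image_processing_gemma3_fast import Gemma3ImageProcessorFast

image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
image_processor = Gemma3ImageProcessorFast()
outputs = image_processor(images=image, do_pan_and_scan=False, return_tensors="pt")
print(outputs["pixel_values"].shape)  # hit the unbound-local error here before the fix
```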
Got the idea from got_ocr2:
Expected behavior
Returns output without failure.
Sample output (truncated due to the `max_new_tokens` limit):
```
**Overall Impression:**

The image is a close-up shot of a vibrant garden scene, focusing on a cluster of pink cosmos flowers and a busy bumblebee. It has a slightly soft, natural feel, likely due to the lighting and the focus on the foreground.

**Foreground:**

*   **Cosmos Flowers:** The main subject is a group of pink cosmos flowers. The flower in the center is the most prominent,
```