UnboundLocalError: cannot access local variable 'images_list' when using Gemma 3 AutoProcessor with use_fast=True #36739

Closed
@Zebz13

Description

System Info

  • transformers version: 4.50.0.dev0
  • Platform: Linux-6.13.6-100.fc40.x86_64-x86_64-with-glibc2.39
  • Python version: 3.12.9
  • Huggingface_hub version: 0.29.1
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.4.0+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 4060 Laptop GPU

Who can help?

@ArthurZucker
Terribly sorry if you're the wrong person! I hope this passes as a text-model issue.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Used the base code given on HF for Gemma 3 (modified to a local path): https://huggingface.co/google/gemma-3-4b-it#running-the-model-on-a-singlemulti-gpu

Added use_fast=True to the AutoProcessor arguments, prompted by this warning:

Using a slow image processor as use_fast is unset and a slow processor was saved with this model. use_fast=True will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with use_fast=False.

from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch

model_id = "gemma-3-4b-it"  # local path to google/gemma-3-4b-it

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto"
).eval()

processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image in detail."}
        ]
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)

# **Overall Impression:** The image is a close-up shot of a vibrant garden scene, 
# focusing on a cluster of pink cosmos flowers and a busy bumblebee. 
# It has a slightly soft, natural feel, likely captured in daylight.
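
With use_fast=True the script never gets that far; it fails inside the fast image processor with the error from the title:

UnboundLocalError: cannot access local variable 'images_list' where it is not associated with a value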

Reason (I think): the images_list variable is only assigned under if do_pan_and_scan. When pan-and-scan is not enabled, images_list is never bound, so the else branch errors out.
The same variable is also used for group_images_by_shape on line 294.
Lines:

for image_list in images:
    if do_pan_and_scan:
        images_list, num_crops = self._process_images_for_pan_and_scan(
            images=image_list,
            do_pan_and_scan=do_pan_and_scan,
            pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
            pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
            pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
        )
    else:
        num_crops = [[0] for images in images_list]
    # Group images by size for batched processing
    processed_image_patches_grouped = {}
    grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list)
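
For clarity, here is a minimal standalone sketch of the same control-flow pattern (function and variable names are illustrative, not the actual transformers code): the only binding of images_list sits on the if branch, so taking the else branch reads an unbound local and raises the error.

def preprocess(image_list, do_pan_and_scan=False):
    if do_pan_and_scan:
        images_list = [image_list]  # the only binding of images_list
    else:
        # reads images_list before it is ever assigned on this path
        num_crops = [[0] for _ in images_list]
    return images_list

preprocess([1, 2, 3])
# UnboundLocalError: cannot access local variable 'images_list'
# where it is not associated with a value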

Since the images variable is passed in as List[List["torch.Tensor"]], assigning images_list = image_list in the else case fixes the issue. There might be a better way of fixing this. The local fix I'm using:

            else:
                # assign variable to bypass unbounded error
                images_list = image_list
                num_crops = [[0] for _ in images_list]
            # Group images by size for batched processing
            processed_image_patches_grouped = {}
            grouped_image_patches, grouped_image_patches_index = group_images_by_shape(images_list)

Got the idea from got_ocr2's fast image processor.
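
Until a fix lands, falling back to the slow processor (as the warning quoted above suggests) also avoids the broken code path entirely:

# Workaround: the slow image processor never enters the fast
# processor's pan-and-scan branch, so this runs without the error.
processor = AutoProcessor.from_pretrained(model_id, use_fast=False)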

Expected behavior

Returns output without failure.

Sample (cut off by the max_new_tokens limit):


**Overall Impression:**

The image is a close-up shot of a vibrant garden scene, focusing on a cluster of pink cosmos flowers and a busy bumblebee. It has a slightly soft, natural feel, likely due to the lighting and the focus on the foreground.

**Foreground:**

*   **Cosmos Flowers:** The main subject is a group of pink cosmos flowers. The flower in the center is the most prominent,
