System Info
- transformers version: 4.50.0.dev0
- Platform: Linux-6.8.0-39-generic-x86_64-with-glibc2.35
- Python version: 3.11.10
- Huggingface_hub version: 0.29.3
- Safetensors version: 0.5.3
- Accelerate version: 1.5.2
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA GeForce RTX 4090
Who can help?
Gemma 3 works fine with bfloat16, but with float16 the decoded generation is empty.
@amyeroberts, @qubvel @ArthurZucker
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
```python
import torch
from transformers import Gemma3ForConditionalGeneration, AutoProcessor

device = 'cuda:0'
compute_dtype = torch.float16  # bfloat16 works fine
cache_dir = None
model_id = 'google/gemma-3-4b-it'

processor = AutoProcessor.from_pretrained(model_id, cache_dir=cache_dir)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=compute_dtype,
    attn_implementation="sdpa",
    cache_dir=cache_dir,
    device_map='cuda',
)

messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image in detail."}
        ]
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=compute_dtype)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=128, do_sample=False)[0][input_len:]

decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
```
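If it helps triage, here is a minimal sketch (my assumption, not a confirmed diagnosis) to check whether the empty output comes from fp16 overflow producing inf/NaN in the forward pass; `model` and `inputs` are the objects from the reproduction above:

```python
# Assumption: fp16 activations overflow somewhere in the network, so the
# logits end up inf/NaN and generation degenerates to an empty string.
# A single forward pass makes that visible.
with torch.inference_mode():
    logits = model(**inputs).logits

print("any inf:", torch.isinf(logits).any().item())
print("any nan:", torch.isnan(logits).any().item())
print("max |logit|:", logits.abs().max().item())
```

Running the same check with `compute_dtype = torch.bfloat16` should report no inf/NaN if overflow is the culprit.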
Expected behavior
Gemma 3 should produce non-empty output with float16 weights as well, matching the bfloat16 behavior.
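For context, float16 and bfloat16 trade precision for range differently: bfloat16 keeps float32's exponent range, while float16 saturates at 65504, which is the usual reason a model is stable in bf16 but overflows in fp16. This can be verified directly:

```python
import torch

# float16 has a much smaller dynamic range than bfloat16, so activations
# that fit comfortably in bf16 can overflow to inf in fp16.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38
```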