Gemma 3 is broken with fp16 #36822

Open
mobicham opened this issue Mar 19, 2025 · 2 comments · May be fixed by #36832
@mobicham (Contributor)

System Info

  • transformers version: 4.50.0.dev0
  • Platform: Linux-6.8.0-39-generic-x86_64-with-glibc2.35
  • Python version: 3.11.10
  • Huggingface_hub version: 0.29.3
  • Safetensors version: 0.5.3
  • Accelerate version: 1.5.2
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 4090

Who can help?

Gemma 3 works fine with bfloat16, but the output is empty with float16.
@amyeroberts, @qubvel, @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
device        = 'cuda:0'
compute_dtype = torch.float16  # bfloat16 works fine
cache_dir     = None
model_id      = 'google/gemma-3-4b-it'


from transformers import Gemma3ForConditionalGeneration, AutoProcessor

# Load the processor and the fp16 model on the GPU
processor = AutoProcessor.from_pretrained(model_id, cache_dir=cache_dir)
model     = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=compute_dtype, attn_implementation="sdpa",
    cache_dir=cache_dir, device_map='cuda',
)


messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
            {"type": "text", "text": "Describe this image in detail."}
        ]
    }
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(model.device, dtype=compute_dtype)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=128, do_sample=False)[0][input_len:]
    decoded    = processor.decode(generation, skip_special_tokens=True)

# With float16 the decoded string comes back empty
print(decoded)
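
For reference, the only change needed to hit the working bfloat16 path is the dtype passed at load time; a minimal comparison sketch (everything else stays as in the script above):

compute_dtype = torch.bfloat16  # bf16 instead of fp16

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=compute_dtype, attn_implementation="sdpa",
    cache_dir=cache_dir, device_map='cuda',
)
# Re-running the same generate()/decode steps should now print a non-empty description.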

Expected behavior

Gemma 3 should work with float16 weights too.

mobicham added the bug label on Mar 19, 2025
@Rocketknight1 (Member)

Hi @mobicham, I don't think we put much effort into supporting float16 inference anymore. In general, bfloat16 is the new standard for model training, and float16 is increasingly a legacy format.

If you (or anyone reading this) can find a simple fix that enables float16 inference for Gemma 3, we can probably accept a PR for that, but we won't be able to prioritize debugging it and adding the feature ourselves.

@mobicham (Contributor · Author) commented Mar 19, 2025

@Rocketknight1 that is not exactly true. The efficient low-bit kernels used to run quantized models faster mainly support fp16, not bf16 (because of limitations related to atomic addition with bf16); this includes gemlite and Marlin in vLLM.
I actually found the issue and will open a PR shortly.
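
In the meantime, a quick way to inspect the fp16 forward pass locally, reusing model and inputs from the reproduction above. It checks whether the logits contain inf/NaN values, which would be one possible symptom of an fp16 overflow; this is a minimal diagnostic sketch, not the fix itself, and the overflow explanation is only a hypothesis here:

import torch

with torch.inference_mode():
    logits = model(**inputs).logits  # single fp16 forward pass

print("logits dtype :", logits.dtype)
print("contains inf :", torch.isinf(logits).any().item())
print("contains NaN :", torch.isnan(logits).any().item())
print("max |logit|  :", logits.float().abs().max().item())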
