
Component loading incorrect dtype #36686

Closed
@dsocek

Description


System Info

  • transformers version: 4.50.0.dev0
  • Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • Huggingface_hub version: 0.29.3
  • Safetensors version: 0.4.5
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.6.0a0+df5bbc09d1.nv24.12 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@zucchini-nlp @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Some components are loaded with an incorrect dtype. @zucchini-nlp: a git bisect analysis shows that this bug was introduced in 84a6789

84a6789145c3d728f2e405d31e9a35df5d74f05c is the first bad commit

Please install diffusers for a simple reproducer:

pip install diffusers

Reproducer script (we instantiate the SDXL pipeline with torch.bfloat16 and expect every component to be loaded in that dtype):

import torch
from diffusers import StableDiffusionXLPipeline

# Request BF16 for the whole pipeline at load time.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
)
print(pipe.text_encoder.dtype)
print(pipe.text_encoder_2.dtype)

Output (correct) before 84a6789:

torch.bfloat16
torch.bfloat16

Output (incorrect) after 84a6789 was merged:

torch.bfloat16
torch.float16

The 2nd text encoder is now incorrectly loaded as torch.float16, which can cause a series of downstream issues (e.g. see huggingface/optimum-habana#1815).
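Until the regression is fixed, a possible stopgap is to cast the affected component back manually. This is only a sketch of a workaround, not a fix; nn.Module.to(dtype) is plain PyTorch, not a transformers-specific API:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
)
# Stopgap: force the mis-loaded encoder back to the requested dtype.
pipe.text_encoder_2 = pipe.text_encoder_2.to(torch.bfloat16)
print(pipe.text_encoder_2.dtype)  # torch.bfloat16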

@zucchini-nlp, can you please help address this issue?

Expected behavior

The 2nd text encoder should be loaded as torch.bfloat16 (BF16), not torch.float16 (FP16), matching the dtype requested at load time. This was the behavior before the commit above.
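For completeness, a quick check that surfaces any dtype mismatch across all loaded components. A minimal sketch: pipe.components is the standard diffusers accessor for a pipeline's sub-components, and the first parameter's dtype is used here as a proxy for the module's dtype:

import torch
from diffusers import StableDiffusionXLPipeline

expected = torch.bfloat16
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=expected,
)
# Tokenizers and the scheduler are not torch modules, so skip them.
for name, component in pipe.components.items():
    if isinstance(component, torch.nn.Module):
        dtype = next(component.parameters()).dtype
        print(f"{name}: {dtype} {'OK' if dtype == expected else 'MISMATCH'}")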
