torch_dtype is actually used now? #36567

Open
dakinggg opened this issue Mar 5, 2025 · 6 comments
@dakinggg
Contributor

dakinggg commented Mar 5, 2025

System Info

Different transformers versions; see description.

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Previously (v4.46.3, I didn't check all versions), torch_dtype in the config was ignored, meaning that model weights would get loaded in fp32 by default (the correct behavior for training). On the latest transformers version (v4.49.0), it seems it is now used, and so the weights get loaded in whatever dtype is in the checkpoint. Was this change intentional? I recall previously seeing somewhere in the code that you weren't going to actually use torch_dtype until v5, and at a glance I didn't see anything in the release notes, although maybe I missed it.

```
In [1]: import transformers

In [2]: llama1bcfg = transformers.AutoConfig.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')

In [3]: llama1b = transformers.AutoModelForCausalLM.from_config(llama1bcfg)

In [4]: next(llama1b.parameters()).dtype
Out[4]: torch.bfloat16
```
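(For anyone hitting this while the intended behavior is being sorted out, here is a minimal workaround sketch, not an official recommendation from this thread: either cast the model back to fp32 after construction, or override torch_dtype on the config object before calling from_config. The config-override path assumes from_config keeps honoring config.torch_dtype, which is what the repro above shows.)

```
import torch
import transformers

cfg = transformers.AutoConfig.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')

# Option 1: construct from the config as-is, then explicitly cast back to fp32.
model_fp32 = transformers.AutoModelForCausalLM.from_config(cfg).to(torch.float32)
assert next(model_fp32.parameters()).dtype == torch.float32

# Option 2 (assumes from_config keeps honoring config.torch_dtype, as shown above):
# override the dtype on the config object before constructing the model.
cfg.torch_dtype = torch.float32
model_fp32_too = transformers.AutoModelForCausalLM.from_config(cfg)
assert next(model_fp32_too.parameters()).dtype == torch.float32
```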

Expected behavior

Not actually sure, would like to confirm what you expect now.

dakinggg added the bug label Mar 5, 2025
@zucchini-nlp
Member

Yep, that was caused by one of the PRs that enabled setting a different dtype for each backbone in multimodal models. It is indeed a breaking change, but I opted not to revert it since I expect this is the intuitive behavior.

@ArthurZucker WDYT? If we revert, we either lose per-backbone dtypes for multimodal models or have to manually pass the dtype when initializing the backbone in each model file.

@dakinggg
Contributor Author

dakinggg commented Mar 7, 2025

FWIW, I think it's a pretty big change to make silently, with no warnings, no notice in the release notes, etc., especially because of this comment that was (and actually still is) in the code:

```
# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
# we currently don't use this setting automatically, but may start to use with v5
```
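(As a quick illustration of what that comment describes, sketched here rather than taken from the thread, and assuming to_dict reflects the serialized form: the config serializes torch_dtype as a string, while the in-memory attribute is a torch.dtype.)

```
import transformers

cfg = transformers.AutoConfig.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')

# In-memory attribute is a torch.dtype (e.g. torch.bfloat16)...
print(type(cfg.torch_dtype), cfg.torch_dtype)

# ...but the serialized form stores the string version, e.g. "bfloat16".
print(cfg.to_dict()['torch_dtype'])
```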

@dakinggg
Contributor Author

Also, the behavior seems to differ depending on which construction API is used:

```
In [1]: import transformers

In [2]: llama1bcfg = transformers.AutoConfig.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')

In [3]: llama1bcfg.torch_dtype
Out[3]: torch.bfloat16

In [4]: llama1b_frompt = transformers.AutoModelForCausalLM.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')

In [5]: next(llama1b_frompt.parameters()).dtype
Out[5]: torch.float32

In [6]: llama1b_fromcfg = transformers.AutoModelForCausalLM.from_config(llama1bcfg)

In [7]: next(llama1b_fromcfg.parameters()).dtype
Out[7]: torch.bfloat16
```
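(A sketch of how to make the two paths consistent if you specifically want fp32, under the assumption that pinning the dtype explicitly is acceptable: pass torch_dtype to from_pretrained, and cast the from_config model after construction.)

```
import torch
import transformers

name = 'meta-llama/Llama-3.2-1B-Instruct'
cfg = transformers.AutoConfig.from_pretrained(name)

# Pin the dtype explicitly instead of relying on the (version-dependent) default.
model_pt = transformers.AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float32)

# Cast after construction so the result matches regardless of whether
# config.torch_dtype is honored by the installed transformers version.
model_cfg = transformers.AutoModelForCausalLM.from_config(cfg).to(torch.float32)

assert next(model_pt.parameters()).dtype == torch.float32
assert next(model_cfg.parameters()).dtype == torch.float32
```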

dakinggg added a commit to databricks/compose-rl that referenced this issue Mar 14, 2025
Repeating mosaicml/llm-foundry#1734 here.

See huggingface/transformers#36567 for more context.

Also turns out that FSDP wrapping in Composer hangs if you mix dtypes on
different ranks.

Sorry for the type ignores; they came with the transformers upgrade. Not sure why pyright thinks all the special tokens are strings...
```
In [1]: import transformers

In [2]: llamat = transformers.AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')

In [3]: llamat.eos_token_id
Out[3]: 128009

In [4]: type(llamat.eos_token_id)
Out[4]: int

In [5]: transformers.__version__
Out[5]: '4.49.0'
```
@ArthurZucker
Collaborator

Sorry about that, I completely missed the notification... It is a big breaking change that should not have made it into the release.

@ArthurZucker
Collaborator

As we did not get many reports, we are just going to update the release notes and the code there! Thanks a lot @dakinggg

@dakinggg
Contributor Author

FWIW, I'd be a bit worried that people just haven't noticed. Sometimes this will error (e.g. different layers of a model get initialized differently and some part of the training stack doesn't like having different dtypes for different params), but if you don't have unit tests to check this and don't hit any error cases, you will just accidentally use bf16 weights instead of fp32 (likely fine if you are just doing inference, not fine if you are training).
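(As a concrete example of the kind of unit test mentioned above, a sketch with a hypothetical helper name: assert that every parameter of a freshly constructed model has the dtype you expect, so a silent torch_dtype change fails loudly.)

```
import torch
import transformers

def assert_uniform_dtype(model, expected=torch.float32):
    # Collect the set of parameter dtypes; a silent torch_dtype change shows up
    # as either a non-fp32 dtype or a mix of dtypes across submodules.
    dtypes = {p.dtype for p in model.parameters()}
    assert dtypes == {expected}, f'unexpected parameter dtypes: {dtypes}'

cfg = transformers.AutoConfig.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')
assert_uniform_dtype(transformers.AutoModelForCausalLM.from_config(cfg))
```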
