Bfloat16 tensor .numpy() support #90574

Open · JulesGM opened this issue Dec 9, 2022 · 11 comments
Labels: module: bfloat16 · module: numpy · onnx-needs-info · triaged

Comments

JulesGM commented Dec 9, 2022

🚀 The feature, motivation and pitch

NumPy doesn't support bfloat16 and doesn't plan to. As a result, any code that calls tensor.numpy() breaks as soon as you switch it to bfloat16. I was thinking that converting bfloat16 to np.float32 would make sense: it keeps the exponent as-is and only adds mantissa bits, so the conversion is lossless and should be very fast. This would make all code that works with float32 or float16 compatible with bfloat16 out of the box, which feels like reasonable behavior to me.
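
A minimal illustration of the current behavior and of the suggested conversion (the exact error message may differ between PyTorch versions):

import torch

t = torch.ones(3, dtype=torch.bfloat16)

# t.numpy() currently raises TypeError, because NumPy has no bfloat16 dtype
# (something along the lines of "Got unsupported ScalarType BFloat16").
arr = t.to(torch.float32).numpy()  # lossless: bfloat16 is a truncated float32
print(arr.dtype)  # float32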

Additional context

The to_numpy function seems to be here https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_numpy.cpp#L159

and the function that decides the output np.dtype seems to be here:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_numpy.cpp#L267

cc @mruberry @rgommers

jingxu10 (Collaborator) commented:

Could you share more detailed information, such as what kind of use case motivated this feature?

samdow added the triaged and module: numpy labels on Dec 12, 2022
rgommers (Collaborator) commented:

I'm not sure that such an upcast is always the right thing to do. Arguably some users would want np.float32 for precision, and others would want np.float16 for memory use. So raising an exception with an informative error message for bfloat16 is probably the appropriate thing to do.

JulesGM (Author) commented Dec 12, 2022

I thought there could be a way to set a default conversion type. That could be the best of both worlds: raise an exception if no default type is set, and otherwise use it.

This would allow frameworks that were built without bfloat16 in mind, like stable_baselines3, to work at all.

JulesGM (Author) commented Dec 12, 2022

Again, the situation is that a number of frameworks are written assuming that you can call tensor.numpy() at any time, which is true for everything except bfloat16.

Implementing something like torch.default_bfloat16_numpy_type(torch.float32) would solve this problem in a reasonably clean way.
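
A sketch of how such a setting might be used. torch.default_bfloat16_numpy_type does not exist in PyTorch; the name is just the one proposed above:

import torch

# Hypothetical, proposed API (not implemented): set the NumPy dtype that
# bfloat16 tensors should be converted to, once at program start.
torch.default_bfloat16_numpy_type(torch.float32)

t = torch.ones(3, dtype=torch.bfloat16)
arr = t.numpy()  # would then transparently return a float32 ndarray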

vadimkantorov (Contributor) commented Dec 12, 2022

One way is to have a precision-preserving default (e.g. .numpy(upcast=True, downcast=False), or .numpy(dtype=None) / .numpy(dtype=torch.float32) / .numpy(dtype=torch.float16)) and then have the user maintain and pass any global state for numpy() themselves. This might work because numpy() isn't currently called recursively, so the burden of propagating that state can be left to the user. Enabling existing libraries to work with bfloat16 might be an argument for adding global state to torch, but in my experience it's better in the mid term to require some code changes than to add more global state.
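
For illustration, the keyword-based variant might look like this. None of these .numpy() arguments exist in current PyTorch; they are only the spellings suggested above:

import torch

t = torch.ones(3, dtype=torch.bfloat16)

# Hypothetical keywords (not implemented):
arr32 = t.numpy(dtype=torch.float32)  # explicit, precision-preserving upcast
arr16 = t.numpy(dtype=torch.float16)  # explicit, memory-saving downcast
arr = t.numpy(dtype=None)             # no dtype given: would raise for bfloat16, as today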

JulesGM (Author) commented Dec 12, 2022

That sounds good to me. Having a way to set a global default for the conversion (per dtype, or for bfloat16 specifically), combined with allowing an explicit output dtype in .numpy(), sounds reasonable. The global-state part is the more important one to me right now. @rgommers, I'd be curious to hear your thoughts on this.

rgommers (Collaborator) commented:

> Again, the situation is that a number of frameworks are written assuming that you can call tensor.numpy() at any time, which is true for everything except bfloat16.

For completeness: there are more important exceptions (non-CPU and requires-grad tensors), so such code typically does a_tensor.clone().cpu().numpy() (e.g. here in stable_baselines).

Related: gh-36560 wanted to make that easier by adding a force keyword, and .numpy() already has that: Tensor.numpy docs.

An alternative here is to make force=True cast bfloat16 to float32. That would answer the request here without more special-case keywords.

> This might work because numpy() isn't currently called recursively, so the burden of propagating that state can be left to the user.

Global state is pretty painful to manage, though; you start having to worry about things like multi-threading and multiprocessing.

If a user has to set anything at all, I don't think it matters whether that's global state, an explicit dtype conversion, or force=True. WDYT about using force=True?
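
For reference, a rough sketch of what the suggested force=True behavior would boil down to. Today the conversion has to be spelled out manually; the idea that force=True would also upcast bfloat16 is only the proposal in this thread, not current behavior:

import torch

t = torch.ones(3, dtype=torch.bfloat16, requires_grad=True)

# force=True already handles the device and autograd cases, but still errors
# on bfloat16. If it also upcast, t.numpy(force=True) would be roughly:
arr = t.detach().cpu().to(torch.float32).numpy()
print(arr.dtype)  # float32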

JulesGM (Author) commented Dec 13, 2022 via email

rgommers (Collaborator) commented:

> (Also, as of 1.13.0, .numpy(force=True) doesn't work with bfloat16, by the way)

That was my point - let's make it work by upcasting to float32.

> The idea with a global call is to not have to modify the code of the other frameworks

To me personally that's not enough of a reason to add global state (the code that doesn't work can be improved instead), so I'm -0.5 on this one. It's not my decision, though, so perhaps @mruberry or someone else can weigh in here.

vadimkantorov (Contributor) commented:

> That was my point - let's make it work by upcasting to float32.

(For extra control, given force=True, one might add another argument such as backup_dtype, but that should probably only be introduced if there are user requests for it.)

crypdick commented Mar 7, 2024

For any Googlers finding this discussion, the best workaround I've found so far is converting to fp32 manually:

# NumPy can't represent bfloat16, so upcast to float32 before calling .numpy():
if embeddings.dtype == torch.bfloat16:
    embeddings = embeddings.float()
embeddings = embeddings.numpy()
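
A slightly more general helper in the same spirit (a sketch combining the workarounds mentioned in this thread, also covering the non-CPU and requires-grad cases):

import numpy as np
import torch

def to_numpy(t: torch.Tensor) -> np.ndarray:
    # NumPy has no bfloat16 dtype, so upcast losslessly to float32 first.
    if t.dtype == torch.bfloat16:
        t = t.float()
    # detach() and cpu() cover the requires-grad and non-CPU cases.
    return t.detach().cpu().numpy()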

thiagocrepaldi added the onnx-needs-info label on Apr 25, 2024