Bfloat16 tensor .numpy() support #90574
Comments
Would you share more detailed information, like what kind of use case motivated this feature?
I'm not sure that upcast is always the right thing to do. Arguably some users would want […]
I thought there could be a way to set the default conversion type; this could be the best of both worlds: raise an exception if there is no default type, and otherwise use it. This would allow frameworks that were built without bfloat16 in mind to work at all, like […]
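For illustration, a minimal sketch of what such a global default could look like (the names `set_numpy_fallback` and `to_numpy` are hypothetical, not part of PyTorch):

```python
import torch

# Hypothetical sketch, not a real PyTorch API: a process-wide table of
# fallback dtypes for tensors that numpy cannot represent.
_NUMPY_FALLBACK = {}

def set_numpy_fallback(src_dtype, dst_dtype):
    _NUMPY_FALLBACK[src_dtype] = dst_dtype

def to_numpy(t):
    # Upcast via the registered default; if no default is registered for
    # an unsupported dtype, .numpy() raises as it does today.
    dst = _NUMPY_FALLBACK.get(t.dtype)
    return (t.to(dst) if dst is not None else t).numpy()

set_numpy_fallback(torch.bfloat16, torch.float32)
print(to_numpy(torch.ones(3, dtype=torch.bfloat16)))  # [1. 1. 1.] as float32
```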
Again, the situation is that a number of frameworks are written assuming that you can do `tensor.numpy()` at any time, which is true for everything but bfloat16. Implementing something like […]
One way is to have some default that preserves precision (e.g. […])
That sounds good to me. Having a way to set global state for the conversion default (per dtype, or for bfloat16 specifically), and then allowing the specification of output dtypes in `.numpy()`, sounds reasonable. The global-state part is what sounds most important to me right now. @rgommers, I would be curious to know your thoughts on this.
For completeness: there are more important exceptions (non-cpu and requires-grad tensors). So it's code that is doing `a_tensor.clone().cpu().numpy()` (e.g. here for stable_baselines: https://github.com/DLR-RM/stable-baselines3/blob/6d55a09f810bc0d7d38ad04ade92f2b720308b58/stable_baselines3/common/buffers.py#L443). Related: gh-36560 wanted to make that easier by adding a `force` keyword, and `.numpy()` already has that: see the Tensor.numpy docs (https://pytorch.org/docs/stable/generated/torch.Tensor.numpy.html#torch.Tensor.numpy). An alternative here is to make `force=True` cast bfloat16 to float32. That would answer the request here without more special-case keywords.
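For reference, a short example of what `force=True` already smooths over (the Tensor.numpy docs describe it as equivalent to `t.detach().cpu().resolve_conj().resolve_neg().numpy()`):

```python
import torch

t = torch.arange(4.0, requires_grad=True)
# t.numpy() would raise: "Can't call numpy() on Tensor that requires grad"
a = t.numpy(force=True)  # detaches (and copies to CPU if needed) for you
print(a)  # [0. 1. 2. 3.]
```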
This might work, because `numpy()` isn't being called in a recursive way currently, so propagating this global state onto the user might hold up. Global state is pretty painful to manage though; you start having to worry about things like multi-threading/processing. If a user has to set anything at all, I don't think it matters whether it's global state, `.astype`, or `force=True`. WDYT about using `force=True`?
The idea with a global setting is to not have to modify the code of the other frameworks, which call `.numpy()` without considering that it may not work. So adding a parameter to the `.numpy()` call does not fix this. There already seems to be machinery to do something related to this, as there is a global for the default tensor dtype.

(Also, as of `1.13.0`, `.numpy(force=True)` doesn't work with bfloat16, by the way.)
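A minimal reproduction of that last point (error wording may differ between versions):

```python
import torch

t = torch.ones(2, dtype=torch.bfloat16)
for kwargs in ({}, {"force": True}):
    try:
        t.numpy(**kwargs)
    except TypeError as e:
        # as of 1.13.0 both raise, e.g. "Got unsupported ScalarType BFloat16"
        print(kwargs, e)
```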
That was my point - let's make it work by upcasting to `float32`. To me personally that's not enough of a reason to add global state (the code that doesn't work can be improved instead), so I'm -0.5 on this one. It's not my decision though, so perhaps @mruberry or someone else can weigh in here.
(For extra control, given […])
For any Googlers finding this discussion, the best workaround I've found so far is converting to fp32 manually:

```python
if embeddings.dtype == torch.bfloat16:
    embeddings = embeddings.float()
embeddings = embeddings.numpy()
```
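A slightly more general helper in the same spirit, combining this workaround with the `force=True` discussion above (`to_numpy` is just an illustrative name):

```python
import torch

def to_numpy(t: torch.Tensor):
    # numpy has no bfloat16, so upcast first (lossless: every bfloat16
    # value is exactly representable in float32)
    if t.dtype == torch.bfloat16:
        t = t.float()
    # force=True also handles non-cpu and requires-grad tensors
    return t.numpy(force=True)
```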
🚀 The feature, motivation and pitch

Numpy doesn't support bfloat16, and doesn't plan to. The effect of this is that any code that makes a `tensor.numpy()` call breaks when you make it use bfloat16. I was thinking that outputting bfloat16 as `np.float32` would make sense, as it just keeps the exponent and adds a few mantissa bits, so the conversion should be very fast. This would make all code that works with float32 or float16 compatible with bfloat16 out of the box, and it feels like reasonable behavior to me.
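To make the "keeps the exponent and adds mantissa bits" point concrete, here is a small sketch showing that widening bfloat16 to float32 is just appending 16 zero bits (values chosen to be exactly representable in bfloat16):

```python
import numpy as np
import torch

x = torch.tensor([1.0, -2.5, 3.0], dtype=torch.bfloat16)

# reinterpret the raw bfloat16 bit patterns as 16-bit integers
bits = x.view(torch.int16).numpy().view(np.uint16)
# widen to 32 bits and shift into the high half: same sign and exponent,
# mantissa padded with zeros
widened = (bits.astype(np.uint32) << 16).view(np.float32)

assert np.array_equal(widened, x.float().numpy())
```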
Additional context

The `to_numpy` function seems to be here: https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_numpy.cpp#L159, and the function that decides the output `np.dtype` seems to be here: https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_numpy.cpp#L267
cc @mruberry @rgommers