BFloat16 on cuda: add triu/tril support #101932

Closed
Maykeye opened this issue May 20, 2023 · 4 comments

Comments

@Maykeye

Maykeye commented May 20, 2023

🚀 The feature, motivation and pitch

Right now, if you try to use torch.triu on a bfloat16 tensor (I hit it while training a simple network with AMP), you'll get an error that triu is not supported:

In [8]: torch.__version__
Out[8]: '2.0.1+cu117'

In [9]: torch.arange(4).reshape(2,2).bfloat16().cuda().triu()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 1
----> 1 torch.arange(4).reshape(2,2).bfloat16().cuda().triu()

RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'

It would be nice to have it.

Alternatives

It can be replaced with an elementwise multiplication by a torch.ones(n, n).triu() mask:

   self.register_buffer("triu0", torch.ones(n, n).triu()) # __init__
   ...
   y = x * self.triu0 # forward()
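
For reference, here is a more self-contained sketch of the same workaround; the TriuMask module name, the explicit dtype cast, and the usage lines are illustrative additions rather than part of the original snippet:

    # Illustrative sketch of the ones().triu() mask workaround (not from the issue).
    import torch
    import torch.nn as nn

    class TriuMask(nn.Module):
        def __init__(self, n: int):
            super().__init__()
            # Precompute the 0/1 upper-triangular mask once; register_buffer
            # lets it move with the module across .cuda()/.to(device) calls.
            self.register_buffer("triu0", torch.ones(n, n).triu())

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Elementwise multiplication zeroes the strictly lower triangle,
            # matching x.triu(). Casting the mask to x's dtype keeps the
            # output in bfloat16 instead of promoting it to float32.
            return x * self.triu0.to(x.dtype)

    mask = TriuMask(2).cuda()
    x = torch.arange(4).reshape(2, 2).bfloat16().cuda()
    y = mask(x)  # same values as x.triu(), without the unsupported kernel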
   

Additional context

No response

@Maykeye
Author

Maykeye commented May 21, 2023

Oh, this was fixed by the merged #101414: if I try my example in a freshly built docker container, it works:

>>> torch.__version__
'2.1.0a0+git22ca1a1'
>>> torch.arange(4).reshape(2,2).bfloat16().cuda().triu()
tensor([[0., 1.],
        [0., 3.]], device='cuda:0', dtype=torch.bfloat16)

@Maykeye Maykeye closed this as completed May 21, 2023
@18335100284

Very nice idea, but are there any potential problems with this implementation? @Maykeye

@guotong1988

guotong1988 commented Jun 4, 2024

So torch==2.1.0 has fixed this?

@guotong1988

How do I install 2.1.0a0+git22ca1a1? Thank you @Maykeye
