Closed
Labels
- has workaround
- module: autograd — Related to torch.autograd, and the autograd engine in general
- module: cuda — Related to torch.cuda, and CUDA support in general
- triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
- upstream triton — Upstream Triton Issue
Description
🐛 Describe the bug
```python
import torch
import triton
import triton.language as tl


class Rua(torch.autograd.Function):
    @staticmethod
    def rua(hidden):
        output = torch.empty_like(hidden)
        _rua_kernel[(1,)](hidden)
        return output

    @staticmethod
    def forward(ctx, hidden):
        return Rua.rua(hidden)

    @staticmethod
    def backward(ctx, output_grad):
        return Rua.rua(output_grad)


@triton.jit
def _rua_kernel(hidden_sh_ptr):
    return


if __name__ == "__main__":
    input_data = torch.empty(1, device='cuda', requires_grad=True)
    output_grad = torch.empty(1, device='cuda')
    output = Rua.apply(input_data)
    # torch.empty(1, device='cuda', requires_grad=True).backward()
    torch.autograd.backward(output, output_grad)
    torch.cuda.synchronize()
```
Running this prints:
```
Traceback (most recent call last):
  File "/home/X/a.py", line 33, in <module>
    torch.autograd.backward(output, output_grad)
  File "/home/X/.local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/X/.local/lib/python3.11/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/X/a.py", line 20, in backward
    return Rua.rua(output_grad)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/X/a.py", line 11, in rua
    _rua_kernel[(1,)](hidden)
  File "/home/X/.local/lib/python3.11/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
RuntimeError: Triton Error [CUDA]: invalid device context
```
Note that the kernel launch in forward succeeds; only the launch in backward fails. If I uncomment the commented-out line in the script above, it runs correctly.
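A workaround that avoids the dummy backward call is possible, assuming the failure is that the autograd engine's worker thread has no current CUDA context when Triton launches through the driver API (this diagnosis and the helper below are my sketch, not something confirmed in this report): entering a `torch.cuda.device` context around the launch binds a context on whatever thread performs it.

```python
# Sketch of a workaround, under the assumption that the error comes from a
# missing CUDA context on the autograd worker thread. The launch itself needs
# a CUDA GPU plus triton, so that part is guarded for CPU-only machines.
import torch


def launch_on_tensor_device(kernel, grid, tensor):
    # torch.cuda.device(...) makes the tensor's device current (initializing
    # a CUDA context if needed) for the duration of the block, even when
    # called from a thread that has never touched CUDA before.
    with torch.cuda.device(tensor.device):
        kernel[grid](tensor)


ok = True  # the helper is defined unconditionally; the launch runs only on GPU
if torch.cuda.is_available():
    import triton
    import triton.language as tl

    @triton.jit
    def _noop_kernel(ptr):  # hypothetical stand-in for _rua_kernel
        return

    x = torch.empty(1, device='cuda')
    launch_on_tensor_device(_noop_kernel, (1,), x)
    torch.cuda.synchronize()
```

In the repro, `Rua.rua` would wrap its `_rua_kernel[(1,)](hidden)` call the same way, so the backward launch no longer depends on the main thread having initialized CUDA first.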
Versions
PyTorch version: 2.2.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Arch Linux (x86_64)
GCC version: (GCC) 13.2.1 20230801
Clang version: 17.0.6
CMake version: version 3.29.2
Libc version: glibc-2.39
Python version: 3.11.8 (main, Feb 12 2024, 14:50:05) [GCC 13.2.1 20230801] (64-bit runtime)
Python platform: Linux-6.6.27-1-lts-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080
Nvidia driver version: 550.67
cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @Varal7 @ptrblck @msaroufim @bertmaher @int3 @davidberard98 @nmacchioni @lezcano