
Triton Error [CUDA]: invalid device context when autograd.backward a triton kernel #124565

@ppwwyyxx

🐛 Describe the bug

```python
import torch
import triton
import triton.language as tl

class Rua(torch.autograd.Function):

    @staticmethod
    def rua(hidden):
        output = torch.empty_like(hidden)
        _rua_kernel[(1,)](hidden)
        return output

    @staticmethod
    def forward(ctx, hidden):
        return Rua.rua(hidden)

    @staticmethod
    def backward(ctx, output_grad):
        return Rua.rua(output_grad)

@triton.jit
def _rua_kernel(hidden_sh_ptr):
    return

if __name__ == "__main__":
    input_data = torch.empty(1, device='cuda', requires_grad=True)
    output_grad = torch.empty(1, device='cuda')

    output = Rua.apply(input_data)
    # torch.empty(1, device='cuda', requires_grad=True).backward()
    torch.autograd.backward(output, output_grad)
    torch.cuda.synchronize()
```

Prints:

```
Traceback (most recent call last):
  File "/home/X/a.py", line 33, in <module>
    torch.autograd.backward(output, output_grad)
  File "/home/X/.local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/X/.local/lib/python3.11/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/X/a.py", line 20, in backward
    return Rua.rua(output_grad)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/X/a.py", line 11, in rua
    _rua_kernel[(1,)](hidden)
  File "/home/X/.local/lib/python3.11/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
RuntimeError: Triton Error [CUDA]: invalid device context
```
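Note that the forward launch succeeds and only the launch inside `backward` fails, which is consistent with the autograd engine running `backward` on a thread other than the main one (for CUDA graphs the engine uses per-device worker threads), where no CUDA context has been made current. A small CPU-only probe (a diagnostic sketch, not part of the repro) shows how to record which thread each pass runs on:

```python
import threading
import torch

seen = {}

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Record the thread running the forward pass.
        seen["forward"] = threading.get_ident()
        return x.clone()

    @staticmethod
    def backward(ctx, g):
        # Record the thread running the backward pass.
        seen["backward"] = threading.get_ident()
        return g

x = torch.ones(1, requires_grad=True)
Probe.apply(x).backward()
# With CPU tensors the engine reuses the calling thread, so the idents
# match here; with CUDA tensors, backward runs on a device worker thread.
print(seen["forward"] == seen["backward"])
```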

If I uncomment the commented line in the code (which first runs a trivial backward on a fresh tensor), the script runs without error.
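The uncommented line presumably works by making the autograd worker thread initialize a CUDA context before the Triton launch. A more direct workaround (a sketch only, not verified against this exact issue; `launch_in_device_context` is a hypothetical helper name) is to make the tensor's device current around the kernel launch using the real `torch.cuda.device` context manager:

```python
import torch

def launch_in_device_context(launch, tensor):
    # Hypothetical helper: autograd runs backward on a worker thread
    # that may have no current CUDA context, so make the tensor's
    # device current before launching a Triton kernel from it.
    if tensor.is_cuda:
        with torch.cuda.device(tensor.device):
            return launch(tensor)
    return launch(tensor)

# Usage inside Rua.rua (hypothetical):
#   launch_in_device_context(lambda h: _rua_kernel[(1,)](h), hidden)
```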

Versions

PyTorch version: 2.2.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 13.2.1 20230801
Clang version: 17.0.6
CMake version: version 3.29.2
Libc version: glibc-2.39

Python version: 3.11.8 (main, Feb 12 2024, 14:50:05) [GCC 13.2.1 20230801] (64-bit runtime)
Python platform: Linux-6.6.27-1-lts-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080
Nvidia driver version: 550.67

cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @Varal7 @ptrblck @msaroufim @bertmaher @int3 @davidberard98 @nmacchioni @lezcano


Labels: has workaround, module: autograd, module: cuda, triaged, upstream triton
