RuntimeError: Triton Error [CUDA]: device kernel image is invalid #1556

Closed
Lawrencium77 opened this issue Apr 20, 2023 · 7 comments · Fixed by #1593

Comments

@Lawrencium77

Lawrencium77 commented Apr 20, 2023

I'm encountering an error when running kernels on some machines.

It is very sensitive to the exact kernel code that's written: even trivial changes such as trimming whitespace or adding/removing comments can change whether or not it occurs. In addition to the message given in the title, there is a KeyError, which looks quite similar to #1512.

A minimal repro and detailed hardware & environment info are below.

Let me know if more details are required. Any help would be much appreciated!

Kernel Code

cast.py

import torch
import triton
import triton.language as tl

def cdiv(x: int, y: int):
	return (x + y - 1) // y

class TritonCast(torch.nn.Module):
	def forward(self, x):
		x, out, M, BLOCK_SIZE = _extract_meta_params(x)
		_triton_cast[(M,)](x, out, x.numel(), BLOCK_SIZE)
		return out

@triton.jit
def _triton_cast(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
	pid = tl.program_id(axis=0)
	block_start = pid * BLOCK_SIZE
	offsets = block_start + tl.arange(0, BLOCK_SIZE)
	mask = offsets < n_elements
	y = tl.load(x_ptr + offsets, mask=mask)
	tl.store(out_ptr + offsets, y, mask=mask)

def _extract_meta_params(x):
	x = x.flatten()
	out = torch.empty_like(x, dtype=torch.int8)
	BLOCK_SIZE = 1024
	M = cdiv(x.numel(), BLOCK_SIZE)
	return x, out, M, BLOCK_SIZE

test.py

import torch
from unittests.ops.cast import TritonCast

FEAT_DIM = 1920

def test_kernel(bsz=1, seq_len=30):
	kernel = TritonCast().cuda().half()
	x = torch.randn((seq_len, bsz, FEAT_DIM), requires_grad=True, device="cuda", dtype=torch.float16)
	kernel(x) 

if __name__ == "__main__":
	test_kernel()

Full Traceback

❯ python3 unittests/ops/test.py
Traceback (most recent call last):
  File "<string>", line 21, in _triton_cast
KeyError: ('2-.-0-.-0-1e8410f206c822547fb50e2ea86e45a6-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.int8, 'i32'), (1024,), (True, True, (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "unittests/ops/test.py", line 12, in <module>
    test_kernel()
  File "unittests/ops/test.py", line 9, in test_kernel
    kernel(x)
  File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lawrencea/git/aladdin3/unittests/ops/cast.py", line 11, in forward
    _triton_cast[(M,)](x, out, x.numel(), BLOCK_SIZE)
  File "<string>", line 43, in _triton_cast
  File "/venv/lib/python3.8/site-packages/triton/compiler.py", line 1679, in __getattribute__
    self._init_handles()
  File "/venv/lib/python3.8/site-packages/triton/compiler.py", line 1672, in _init_handles
    mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
RuntimeError: Triton Error [CUDA]: device kernel image is invalid

Environment

  • PyTorch 2.0.0a0+gitc263bd4
  • NVCC V11.6.124
  • CUDA Driver version 11.7
  • Triton 2.0.0

Hardware-Specifics

Whether this problem occurs on a specific device is very unpredictable. Even trivial changes, such as adding/removing comments or trimming whitespace, can cause it to appear or disappear. I've seen it on:

  • NVIDIA GeForce RTX 2080 Ti
  • NVIDIA GeForce RTX 3090
  • NVIDIA A100-SXM4-40GB

For the specific code above, this error:

  • Occurs on NVIDIA GeForce RTX 2080 Ti
  • Doesn't occur on NVIDIA GeForce RTX 3090
  • Occurs on NVIDIA A100-SXM4-40GB
ptillet pushed a commit that referenced this issue Apr 29, 2023
[FRONTEND] add architecture to hash to avoid invalid image on cubin load (#1593)

Closes #1556
#1512

The current hash used for caching the cubin does not include the
architecture. This leads to the following error when compiling against
one arch and running against another (with no code changes to trigger a
recompilation).
```
RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```
Was not sure what unit tests would be appropriate here (if any)

Co-authored-by: davidma <davidma@speechmatics.com>
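
In other words, the fix is to fold the target GPU architecture into the compilation cache key. Below is a minimal sketch of that idea only; the helper name is hypothetical and this is not Triton's actual key computation, just an illustration using `torch.cuda.get_device_capability`:

```python
# Sketch of the idea behind #1593 (hypothetical helper, not Triton's real code):
# include the target architecture in the cache key so a cubin built for one
# arch is never reused on another.
import hashlib

import torch


def cubin_cache_key(src: str, signature: str) -> str:
    # Requires a CUDA device; e.g. (8, 0) on an A100, (7, 5) on a 2080 Ti.
    major, minor = torch.cuda.get_device_capability()
    arch = f"sm_{major}{minor}"
    # Hash source, signature, and architecture together; omitting `arch`
    # reproduces the stale-cubin problem described above.
    payload = "-".join([src, signature, arch])
    return hashlib.sha256(payload.encode()).hexdigest()
```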
@egaebel

egaebel commented May 1, 2023

Hi, I'm seeing this issue even after reinstalling triton just this morning with:

pip install -U --pre --force-reinstall triton

I see the problem on two of my three GPUs:

RTX 6000 Ada: Error occurs
RTX A6000: Error occurs
Titan RTX: Code works

My code wraps torch.compile around one piece of the network I'm training.
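
For context, compiling only one piece of a network rather than the whole model typically looks like the sketch below; the model and module names are made up for illustration:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(64, 64)                          # left eager
        self.block = nn.Sequential(nn.Linear(64, 64), nn.GELU())  # compiled below

    def forward(self, x):
        return self.block(self.encoder(x))


net = Net().cuda()
# Compile only one submodule; Inductor generates Triton kernels for it.
net.block = torch.compile(net.block)
```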

@ptillet
Collaborator

ptillet commented May 1, 2023

Nightly wheels aren't back. You'll need to reinstall from source.

jayfurmanek added a commit to ROCm/triton that referenced this issue May 22, 2023
* [OPTIMIZER] simplified pipeline pass (triton-lang#1582)

Directly rematerialize the for loop with the right values, instead of
replacing unpipelined load uses a posteriori.

* [OPTIMIZER] Added kWidth attribute to DotOperandEncoding (triton-lang#1584)

This is a prerequisite for efficient mixed-precision matmul.

* [TEST] Fix test cache (triton-lang#1588)

To avoid puzzling segmentation fault problems caused by multiprocessing, this
PR:

- Uses "spawn" instead of "fork".
- Defines the `instance_descriptor` namedtuple globally.
- Makes the `kernel_sub` JITFunction defined only in the child process.
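
For reference, the "spawn" start method can be selected explicitly through a multiprocessing context; a minimal sketch (not the PR's actual test code):

```python
import multiprocessing as mp


def child(x):
    return x * x


if __name__ == "__main__":
    # "spawn" starts a fresh interpreter in each child, avoiding state
    # (e.g. CUDA contexts) inherited via "fork" that can cause crashes.
    ctx = mp.get_context("spawn")
    with ctx.Pool(2) as pool:
        print(pool.map(child, range(4)))
```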

* [BACKEND] Updated slice layout semantics, updated vectorization logic used for load/store ops. (triton-lang#1587)

* [FRONTEND][BACKEND] Add the `noinline` annotation for `triton.jit` (triton-lang#1568)

# Introducing the `noinline` Parameter for Triton JIT Decorator

We're excited to introduce a new parameter, `noinline`, that can be
added to the `jit` decorator in Triton. This parameter allows developers
to specify that a particular Triton function should not be inlined into
its callers. In this post, we'll dive into the syntax, purpose, and
implementation details of this new feature.

## Syntax

To use the `noinline` parameter, simply add `noinline=True` to the `jit`
decorator for the function that you don't want to be inlined. Here's an
example:

```python
@triton.jit(noinline=True)
def device_fn(x, y, Z):
    z = x + y
    tl.store(Z, z)

def test_noinline():
    @triton.jit
    def kernel(X, Y, Z):
        x = tl.load(X)
        y = tl.load(Y)
        device_fn(x, y, Z)
```

In this example, the `device_fn` function is decorated with
`@triton.jit(noinline=True)`, indicating that it should not be inlined
into its caller, `kernel`.

## Purpose

The `noinline` parameter serves several key purposes:

- Reducing code size: By preventing inlining, we can reduce the size of
the compiled code.
- Facilitating debugging: Keeping functions separate can make it easier
to debug the code.
- Avoiding common subexpression elimination (CSE) in certain cases: CSE
can sometimes be avoided by using the `noinline` parameter to reduce
register pressure.
- Enabling dynamic linking: This parameter makes it possible to
dynamically link Triton functions.

## Implementation

The implementation of the `noinline` parameter involves significant
changes to three analysis modules in Triton: *Allocation*, *Membar*, and
*AxisInfo*. Prior to this update, these modules assumed that all Triton
functions had been inlined into the root kernel function. With the
introduction of non-inlined functions, we've had to rework these
assumptions and make corresponding changes to the analyses.

### Call Graph and Limitations

![figure 1](https://user-images.githubusercontent.com/2306281/234663904-12864247-3412-4405-987b-6991cdf053bb.png)

To address the changes, we build a call graph and perform all the
analyses on the call graph instead of a single function. The call graph
is constructed by traversing the call edges and storing them in an edge
map. Roots are extracted by checking nodes with no incoming edges.
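
A toy sketch of that construction in plain Python (Triton's real implementation works on MLIR functions, not dictionaries): store the call edges in an edge map and take as roots the nodes with no incoming edges.

```python
from collections import defaultdict

# Toy call graph: function name -> list of callees (made up for illustration).
calls = {
    "kernel": ["A", "B"],
    "A": ["C"],
    "B": ["C"],
    "C": [],
}

# Build the edge map and count incoming edges per node.
edges = defaultdict(list)
indegree = {fn: 0 for fn in calls}
for caller, callees in calls.items():
    for callee in callees:
        edges[caller].append(callee)
        indegree[callee] += 1

# Roots are the nodes with no incoming edges (here: "kernel").
roots = [fn for fn, deg in indegree.items() if deg == 0]
print(roots)
```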

The call graph has certain limitations:

- It does not support recursive function calls, although this could be
implemented in the future.
- It does not support dynamic function calls, where the function name is
unknown at compilation time.

### Allocation

![figure 2](https://user-images.githubusercontent.com/2306281/234665110-bf6a2660-06fb-4648-85dc-16429439e72d.png)

In Triton, shared memory allocation is achieved through two operations:
`triton_gpu.convert_layout` and `triton_gpu.alloc_tensor`. The
`convert_layout` operation allocates an internal tensor, which we refer
to as a *scratch* buffer, while the `alloc_tensor` operation returns an
allocated tensor and is thus known as an *explicit* buffer.

To accommodate the introduction of function calls, we are introducing a
third type of buffer called a *virtual* buffer. Similar to scratch
buffers, virtual buffers are allocated internally within the scope of a
function call, and the buffers allocated by the called functions remain
invisible to subsequent operations in the calling function. However,
virtual buffers are distinct from scratch buffers in that the call
operation itself does not allocate memory—instead, it specifies the
total amount of memory required by all the child functions being called.
The actual allocation of buffers is performed by individual operations
within these child functions. For example, when invoking edge e1, no
memory is allocated, but the total amount of memory needed by function B
is reserved. Notably, the amount of shared memory used by function B
remains fixed across its call sites due to the consideration of dynamic
control flows within each function.

An additional challenge to address is the calculation of shared memory
offsets for functions within a call graph. While we can assume a shared
memory offset starting at 0 for a single root function, this is not the
case with a call graph, where we must determine each function's starting
offset based on the call path. Although each function has a fixed memory
consumption, the starting offset may vary. For instance, in Figure 2,
the starting offset of function C through edges e1->e2 differs from that
through edges e2->e4. To handle this, we accumulate the starting offset
at each call site and pass it as an argument to the called function.
Additionally, we amend both the function declaration and call sites by
appending an offset variable.
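
A simplified model of that offset bookkeeping, with a made-up call graph and sizes: each call site passes its accumulated offset down to the callee, so a function reached through different paths can start at different offsets.

```python
# Illustrative only; not Triton's allocation analysis. Sizes are made up.
calls = {"kernel": ["A", "B"], "A": ["C"], "B": ["C"], "C": []}
local_usage = {"kernel": 256, "A": 128, "B": 64, "C": 32}  # bytes per function


def visit(fn, offset, table):
    # Each function starts its buffers at the offset handed in by its caller.
    table.setdefault(fn, []).append(offset)
    for callee in calls[fn]:
        # The callee's buffers live after the caller's own usage.
        visit(callee, offset + local_usage[fn], table)
    return table


offsets = visit("kernel", 0, {})
print(offsets)  # {'kernel': [0], 'A': [256], 'C': [384, 320], 'B': [256]}
# "C" gets two different starting offsets, one per call path.
```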

### Membar

![figure 3](https://user-images.githubusercontent.com/2306281/234665157-844dd66f-5028-4ef3-bca2-4ca74b8f969d.png)

The membar pass is dependent on the allocation analysis. Once the offset
and size of each buffer are known, we conduct a post-order traversal of
the call graph and analyze each function on an individual basis. Unlike
previous analyses, we now return buffers that remain unsynchronized at
the end of functions, allowing the calling function to perform
synchronization in cases of overlap.
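
In toy form, the idea might look like the sketch below (illustrative only, not the real Membar pass): each function reports the buffers still unsynchronized at its end, and the caller inserts a barrier only when they overlap with what it touches next.

```python
# Toy model of the post-order membar idea; all names and data are made up.
calls = {"kernel": ["B"], "B": []}
unsynced_at_end = {"B": {"buf0"}, "kernel": set()}
touched_after_call = {("kernel", "B"): {"buf0", "buf1"}}


def analyze(fn):
    pending = set(unsynced_at_end[fn])
    for callee in calls[fn]:
        returned = analyze(callee)  # post-order: analyze the callee first
        if returned & touched_after_call.get((fn, callee), set()):
            print(f"{fn}: insert barrier after call to {callee}")
            returned = set()  # the barrier synchronizes those buffers
        pending |= returned
    return pending


analyze("kernel")
```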

### AxisInfo

![figure 4](https://user-images.githubusercontent.com/2306281/234665183-790a11ac-0ba1-47e1-98b1-e356220405a3.png)

The AxisInfo analysis operates differently from both membar and
allocation, as it traverses the call graph in topological order. This is
necessary because function arguments may contain axis information that
will be utilized by callee functions. As we do not implement
optimizations like function cloning, each function has a single code
base, and the axis information for an argument is determined as a
conservative result of all axis information passed by the calling
functions.
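
As a rough illustration of that conservative merge (not Triton's actual AxisInfo lattice), divisibility facts coming from different call sites can be joined with a gcd, so the callee only assumes what holds at every call site:

```python
from math import gcd

# Divisibility known for the same pointer argument at two call sites
# (made-up numbers). Without function cloning, the callee sees one merged fact.
divisibility_at_call_sites = [16, 8]

# Conservative join: keep only what is true everywhere.
merged = 0
for d in divisibility_at_call_sites:
    merged = gcd(merged, d)

print(merged)  # 8 -> the callee may only assume divisibility by 8
```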

---------

Co-authored-by: Philippe Tillet <phil@openai.com>

* [FRONTEND] add architecture to hash to avoid invalid image on cubin load (triton-lang#1593)

Closes triton-lang#1556
triton-lang#1512

The current hash used for caching the cubin does not include the
architecture. This leads to the following error when compiling against
one arch and running against another (with no code changes to trigger a
recompilation).
```
RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```
Was not sure what unit tests would be appropriate here (if any)

Co-authored-by: davidma <davidma@speechmatics.com>

* [FRONTEND] Fix calling local variables’ attribute functions in the if statement (triton-lang#1597)

If `node.func` is an `ast.Attribute`, it won't cause an early return.
(Not sure if I interpret it correctly)

triton-lang#1591

* [OPTIMIZER][BACKEND] Enabled elementwise ops (including casts) between ldmatrix and mma.sync (triton-lang#1595)

* [RUNTIME] Ensure we hold the GIL before calling into CPython API in cubin binding (triton-lang#1583)

Formatting of the diff is not the best. I only indented the whole
function, moved the creation of the py::bytes and the return out of the
scope and declared and assigned the cubin variable appropriately.
Everything else is unchanged.

Today it triggers the following error on CPython debug build:
```
Fatal Python error: _PyMem_DebugMalloc: Python memory allocator called without holding the GIL
Python runtime state: initialized

```

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Philippe Tillet <phil@openai.com>

* Merge branch `llvm-head` (triton-lang#1600)

* Zahi/slice reduce rebased (triton-lang#1594)

[BACKEND] Enable slice layout support for reduce op

* [OPTIMIZER] Fix crash in loop pipelining. (triton-lang#1602)

Fixes issue triton-lang#1601.

* [FRONTEND] make torch optional (triton-lang#1604)

make torch optional to fix circular dependency issue

* [OPTIMIZER] Clean-up Utility.cpp and fixed bug in RematerializeForward (triton-lang#1608)

ConvertLayoutOp can be folded in other ConvertLayoutOp

* [BACKEND] Fixed up ConvertLayout for slices (triton-lang#1616)

* [FRONTEND] Add `tl.expand_dims` (triton-lang#1614)

This exposes `semantic.expand_dims` in the public API and builds upon it
with support for expanding multiple dimensions at once. e.g.
```python
tl.expand_dims(tl.arange(0, N), (0, -1))  # shape = [1, N, 1]
```

Compared to indexing with `None`, this API is useful because the
dimensions can be constexpr values rather than hard-coded into the
source. As a basic example
```python
@triton.jit
def max_keepdim(value, dim):
    res = tl.max(value, dim)
    return tl.expand_dims(res, dim)
```

* [BACKEND] Modified store op thread masking (triton-lang#1605)

* [CI] no longer runs CI job on macos-10.15 (triton-lang#1624)

* [BACKEND] Allow noinline functions to return multiple values of primitive types (triton-lang#1623)

Fix triton-lang#1621

* [BACKEND] Updated predicate for atomic ops (triton-lang#1619)

* [TEST] Added convert layout test from/to sliced blocked/mma (triton-lang#1620)

* [BACKEND] fix typo in Membar class about WAR description and refine some code (triton-lang#1629)

Co-authored-by: Philippe Tillet <phil@openai.com>

* [SETUP] Removing `torch` as a test dependency (triton-lang#1632)

circular dependency is causing troubles now that our interpreter depends
on torch 2.0 ...

* [DOCS] Fix docstrings for sphinx docs (triton-lang#1635)

* [FRONTEND] Added interpreter mode (triton-lang#1573)

Simple mechanism to run Triton kernels on PyTorch for debugging purpose
(upstream from Kernl).

Todo:
- random grid iteration
- support of atomic ops
- more unit tests
- cover new APIs?

* [CI] Build wheels for musllinux (triton-lang#1638)

Ideally you would also build source distributions so that it is in
principle possible to build `triton` on other platforms, but building
`musllinux` wheels would at least help with openai/whisper#1328.

I suspect you will also get people showing up at some point asking for
`aarch64` wheels as well. It might be worth taking a look at the
[`cibuildwheel` output
matrix](https://cibuildwheel.readthedocs.io/en/stable/#what-does-it-do)
to see what you are comfortable with shipping (particularly if you
aren't shipping source distributions).

* [FRONTEND] Fix return op related control flow issues (triton-lang#1637)

- Case 1: Return after static control flow is taken. Peel off
instructions after the first `return` for each basic block.

```python
if static_condition:
    tl.store(...)
    return
return
```

- Case 2: Return exists in both `if` and `else` branches of an inlined
`JITFunction` function

```python
def foo():
    if dynamic_condition:
        return a
    else:
        return b
```

- Case 3: Return exists in a `JITFunction` from another module

```python
import module
if cond:
    a = module.func()
```

- Case 4: A chain of calls through undefined local variables

```python
import module
if cond:
    a = x
    a = a.to(tl.int32).to(tl.int32)
```

- Case 5: Call a function `func` without returning variables. `func` is
recognized as an `Expr` first instead of a `Call`.

```python
if cond:
    foo()
else:
    bar()
```

- Case 6: Call a `noinline` function. We don't need to check if the
function contains any return op.

* [CI] Upload CUDA test artifacts (triton-lang#1645)

* [FRONTEND] Add support for scalar conditions in `device_assert` (triton-lang#1641)

This sometimes happens in TorchInductor. See
pytorch/pytorch#100880.
More generally, it's useful to be able to write `tl.device_assert(False,
msg)`.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
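
A small hedged example of what a scalar-condition assert can look like in a kernel (the kernel and argument names are made up):

```python
import triton
import triton.language as tl


@triton.jit
def copy_kernel(X, Y, n_elements, BLOCK: tl.constexpr):
    # Scalar condition: one boolean for the whole program instance,
    # rather than a per-element tensor of booleans.
    tl.device_assert(n_elements > 0, "n_elements must be positive")
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)
```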

* [FRONTEND] Hotfix for `contains_return_op` (triton-lang#1651)

`noinline` can be None, False, or True, so we have to check the callee
in the first two cases.

* [TEST] Fixed and re-enabled reduce test (triton-lang#1644)

Re-enabled the reduce test after fixing the %cst stride in the ttgir and
modifying the sweep parameters to make sure the shape per CTA is less than
or equal to the tensor shape.

* [FRONTEND] Don't call set_device in tl.dot (triton-lang#1646)

This breaks multiprocess compilation

* [TESTS] Add regression test for issue triton-lang#1601. (triton-lang#1611)

Following up on triton-lang#1603, I am adding a new file meant to contain
functional regression tests to the repository.
Let me know if another folder would be a more appropriate place for
these tests.

Co-authored-by: Philippe Tillet <phil@openai.com>

* [BUILD] Move canonicalization patterns of Load/Store to Ops.cpp. (NFC) (triton-lang#1650)

This breaks a cyclic dependency between the TritonAnalysis and the
TritonIR libraries (see triton-lang#1649). It also follows the convention from
upstream (for example, see the AMDGPU, Affine, and Arith dialects).

* [FRONTEND] Better error messages for noinline functions (triton-lang#1657)

```
at 10:18:def val_multiplier_noinline(val, i):
    return val * i

           ^
Function val_multiplier_noinline is marked noinline, but was called with non-scalar argument val:fp32[constexpr[128]]
```

* [BUILD] Add missing CMake link-time dependencies. (triton-lang#1654)

* [BACKEND] Move isSharedEncoding to TritonGPUIR. (triton-lang#1655)

This breaks a cyclic dependency between TritonAnalysis and TritonGPUIR
(see triton-lang#1649).

* [FRONTEND] Do not use exceptions to guide control flow in compilation runtime (triton-lang#1663)

The Triton runtime currently relies on KeyError to check whether a kernel
has been compiled. This results in somewhat confusing backtraces when the
kernel crashes at runtime, because the stack trace includes not only the
actual crash but also the stack trace for the original KeyError that was
caught.
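
In generic Python terms (not the actual runtime code), the change amounts to replacing a KeyError-driven lookup with an explicit membership check:

```python
# Generic illustration: caching via a caught KeyError versus an explicit test.
cache = {}


def compile_kernel(key):
    return f"binary-for-{key}"  # stand-in for a real compilation step


def get_kernel_with_exception(key):
    try:
        return cache[key]
    except KeyError:
        # If compilation (or a later launch) crashes here, the traceback also
        # drags in this caught KeyError, which is the confusing part above.
        cache[key] = compile_kernel(key)
        return cache[key]


def get_kernel_explicit(key):
    # Same behaviour, but no exception is raised or caught on a cache miss.
    if key not in cache:
        cache[key] = compile_kernel(key)
    return cache[key]


print(get_kernel_explicit("add_kernel"))
```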

* [FRONTEND] Assert that for loop bounds must be ints (triton-lang#1664)

* [OPTIMIZER] Fix-up reduction cloning

* [DEPENDENCIES] Update LLVM to 17.0.0 (c5dede880d17) and port changes. (triton-lang#1668)

This depends on a [pending LLVM
release](ptillet/triton-llvm-releases#10).

* Implement setCalleeFromCallable in CallOp.
* Cast type to ShapedType for various getters.
* Improve TritonDialect::materializeConstant due to breaking change in
constructor of arith::ConstantOp.
* Add OpaqueProperties argument in inferReturnTypes.

Co-authored-by: Philippe Tillet <phil@openai.com>

* [OPTIMIZER] adjusted selection heuristics for when `mmaLayout.warpsPerTile[1] = 1` (triton-lang#1675)

this fixes fused attention with D_HEAD=128

* [BUILD] stop depending on dlfcn-win32 by implementing `dladdr` natively with WIN32 API (triton-lang#1674)

Co-authored-by: Philippe Tillet <phil@openai.com>

* [BUILD] minor fixes (triton-lang#1676)

Remove unused variables, fix member initializer list order.

* [FRONTEND] Differentiate between bool and int in the frontend (triton-lang#1678)

`bool` is a subclass of `int`, so `isinstance(bool_var, int) == True`,
and a `bool` constant will be converted to an `int` constant.

In Triton specifically, if a bool variable is treated as an integer, it
prevents us from using the `logical_and` operator, which requires both
operands to have the same bit length.

> Cannot bitcast data-type of size 32 to data-type of size 1

By differentiating between int and bool, we can make the syntax closer to
native Python. We can now use `if bool_var and condition` to check
truthiness, and `if bool_var is True` to check identity.
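
A quick demonstration of the Python behaviour driving this change:

```python
# bool is a subclass of int in Python, so naive isinstance checks conflate them.
flag = True

print(isinstance(flag, int))   # True -- a bool passes an int check
print(issubclass(bool, int))   # True
print(flag + 1)                # 2   -- True silently behaves like 1

# Distinguishing them requires checking for bool explicitly:
print(type(flag) is bool)      # True
print(isinstance(flag, bool))  # True
```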

* [BUILD] Add deduction guide for `Interval` (triton-lang#1680)

This avoids `ctad-maybe-unsupported` warning.

* [OPS] Remove duplicated function already defined in `triton` module. (triton-lang#1679)

* IFU 230517 Resolve merge conflicts

* Fix is_hip() check

* [ROCM] Fix hardcoded warpsize in getMask

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com>
Co-authored-by: David MacLeod <macleod.david@live.co.uk>
Co-authored-by: davidma <davidma@speechmatics.com>
Co-authored-by: albanD <desmaison.alban@gmail.com>
Co-authored-by: Christian Sigg <chsigg@users.noreply.github.com>
Co-authored-by: Benjamin Chetioui <3920784+bchetioui@users.noreply.github.com>
Co-authored-by: Michaël Benesty <pommedeterresautee@users.noreply.github.com>
Co-authored-by: peterbell10 <peterbell10@live.co.uk>
Co-authored-by: long.chen <lipracer@gmail.com>
Co-authored-by: q.yao <streetyao@live.com>
Co-authored-by: Paul Ganssle <1377457+pganssle@users.noreply.github.com>
Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Co-authored-by: Natalia Gimelshein <ngimel@fb.com>
Co-authored-by: Ingo Müller <github.com@ingomueller.net>
Co-authored-by: Ingo Müller <ingomueller@google.com>
Co-authored-by: George Karpenkov <cheshire@google.com>
Co-authored-by: Sophia Wisdom <sophia.wisdom1999@gmail.com>
Co-authored-by: cloudhan <cloudhan@outlook.com>
Co-authored-by: Daniil Fukalov <1671137+dfukalov@users.noreply.github.com>
@zy-fang

zy-fang commented Jan 19, 2024

I have encountered the same problem, how did you solve it?

@laomao0

laomao0 commented Jan 24, 2024

I have encountered the same problem, how did you solve it?

@xgbj

xgbj commented Jan 31, 2024

I encountered the same error. In my case, I have a local Docker environment with CUDA 11.4 installed, and I installed PyTorch 2.2.0+cu118 using pip. Here is part of the error log during runtime:

  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 887, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 600, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 425, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 630, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 295, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py", line 1100, in fw_compiler_base
    return inner_compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/debug.py", line 305, in inner
    return fn(*args, **kwargs)
  File "/usr/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py", line 320, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py", line 550, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/graph.py", line 1116, in compile_to_fn
    return self.compile_to_module().call
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/graph.py", line 1070, in compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/codecache.py", line 1892, in load_by_key_path
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_hadoop-perception/qw/cqwqiq3cpyyjdci7ohjub32wfio7swg4qpodckpyayneldgkgevh.py", line 1678, in <module>
    async_compile.wait(globals())
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/codecache.py", line 2471, in wait
    scope[key] = result.result()
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/codecache.py", line 2315, in result
    kernel = self.kernel = _load_kernel(self.kernel_name, self.source_code)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/codecache.py", line 2291, in _load_kernel
    kernel.precompile()
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/triton_heuristics.py", line 188, in precompile
    compiled_binary, launcher = self._precompile_config(
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/triton_heuristics.py", line 308, in _precompile_config
    binary._init_handles()
  File "/usr/local/lib/python3.8/dist-packages/triton/compiler/compiler.py", line 683, in _init_handles
    mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised: RuntimeError: Triton Error [CUDA]: device kernel image is invalid

@xingjinglu

xingjinglu commented Feb 7, 2024

Hope this helps:
#1955 (comment)

@xgbj

xgbj commented Feb 9, 2024

Thank you for your reply. It seems that the issue is caused by the incorrect version of CUDA used for compilation. I will check my own environment. @xingjinglu
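
For anyone else hitting this, a quick way to inspect the relevant versions from Python (just a diagnostic sketch, assuming PyTorch and Triton are importable):

```python
# Quick environment check: the CUDA toolkit PyTorch was built with, the
# devices' compute capabilities, and the Triton version in use.
import torch
import triton

print("torch:", torch.__version__)
print("torch built with CUDA:", torch.version.cuda)
print("triton:", triton.__version__)
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(i, torch.cuda.get_device_name(i), torch.cuda.get_device_capability(i))
```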

pingzhuu pushed a commit to siliconflow/triton that referenced this issue Apr 2, 2024
[FRONTEND] add architecture to hash to avoid invalid image on cubin load (triton-lang#1593)

Closes triton-lang#1556
triton-lang#1512

The current hash used for caching the cubin does not include the
architecture. This leads to the following error when compiling against
one arch and running against another (with no code changes to trigger a
recompilation).
```
RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```
Was not sure what unit tests would be appropriate here (if any)

Co-authored-by: davidma <davidma@speechmatics.com>