triton package that ships with torch is not compatible with triton.ops.matmul #125539

Open
jeromeku opened this issue May 4, 2024 · 6 comments

Labels
matrix multiplication module: build Build system issues module: cuda Related to torch.cuda, and CUDA support in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

jeromeku commented May 4, 2024

UPDATE from @oulgen:

import triton
import torch

from triton import ops

def test_triton_matmul_compile():
    dtype = torch.float16
    A = torch.randn(4096, 4096, dtype=dtype, device="cuda")
    B = torch.randn(4096, 4096, dtype=dtype, device="cuda")
    acc_dtype = None
    input_precision = "ieee"
    fp8_fast_accum = False
    output_dtype = None

    out = ops.matmul(A, B, acc_dtype, input_precision, fp8_fast_accum, output_dtype)

if __name__ == "__main__":
    test_triton_matmul_compile()

The snippet above does not work with the triton distributed via torch, even without torch.compile.

🐛 Describe the bug

@oulgen

Per the title, torch.dynamo errors out with "No module named 'nvi'" when compiling a user-defined kernel.

A minimal repro is below, where the "user-defined kernel" is simply triton.ops.matmul, the official triton matmul implementation:

import triton
import torch

def test_triton_matmul_compile():
    dtype = torch.float16
    A = torch.randn(4096, 4096, dtype=dtype, device="cuda")
    B = torch.randn(4096, 4096, dtype=dtype, device="cuda")
    acc_dtype = None
    input_precision = "ieee"
    fp8_fast_accum = False
    output_dtype = None

    # Eager call: runs without error.
    out = triton.ops.matmul(A, B, acc_dtype, input_precision, fp8_fast_accum, output_dtype)

    # Compiled call: raises InternalTorchDynamoError ("No module named 'nvi'").
    triton_compiled = torch.compile(triton.ops.matmul)
    compiled_out = triton_compiled(A, B, acc_dtype, input_precision, fp8_fast_accum, output_dtype)

if __name__ == "__main__":
    test_triton_matmul_compile()

torch version: 2.4.0.dev20240427+cu118
triton version: 3.0.0

Full version details are in the Versions section below.

Full error:

Traceback (most recent call last):
  File "/notebooks/experiments/test/test_dora_fusion.py", line 204, in <module>
    test_triton_matmul_compile()
  File "/notebooks/experiments/test/test_dora_fusion.py", line 201, in test_triton_matmul_compile
    compiled_out = triton_compiled(A, B, acc_dtype, input_precision, fp8_fast_accum, output_dtype)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/autograd/function.py", line 571, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/triton/ops/matmul.py", line 213, in forward
    @staticmethod
  File "envs/lib/python3.11/site-packages/triton/ops/matmul.py", line 202, in _call
    _kernel[grid](
  File "envs/lib/python3.11/site-packages/triton/runtime/jit.py", line 207, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 140, in run
    ret = self.fn.run(
          ^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 282, in run
    return self.fn.run(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/triton/runtime/jit.py", line 381, in run
    def run(self, *args, grid, warmup, **kwargs):
  File "envs/lib/python3.11/site-packages/triton/runtime/jit.py", line 388, in torch_dynamo_resume_in_run_at_388
    device = driver.active.get_current_device()
  File "envs/lib/python3.11/site-packages/triton/runtime/jit.py", line 389, in torch_dynamo_resume_in_run_at_389
    stream = driver.active.get_current_stream(device)
  File "envs/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 977, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 818, in _convert_frame
    result = inner_convert(
             ^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 729, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "envs/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 700, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 268, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 568, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1116, in transform_code_object
    transformations(instructions, code_options)
  File "envs/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 173, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 515, in transform
    tracer.run()
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2241, in run
    super().run()
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 875, in run
    while self.step():
          ^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 790, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 492, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1852, in CALL
    self.call_function(fn, args, kwargs)
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 730, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 339, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 736, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2422, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2538, in inline_call_
    tracer.run()
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 875, in run
    while self.step():
          ^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 790, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1010, in LOAD_GLOBAL
    source = self.get_global_source(name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 974, in get_global_source
    self.import_source(self.f_globals["__name__"]), name
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1036, in import_source
    value = importlib.import_module(module_name)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "envs/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1140, in _find_and_load_unlocked
torch._dynamo.exc.InternalTorchDynamoError: No module named 'nvi'

Versions

Collecting environment information...
PyTorch version: 2.4.0.dev20240427+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0
Clang version: Could not collect
CMake version: version 3.27.20230625-g7e38674
Libc version: glibc-2.35

Python version: 3.11.8 (main, Feb 26 2024, 21:39:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-45-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A6000
Nvidia driver version: 525.116.04
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 5315Y CPU @ 3.20GHz
CPU family: 6
Model: 106
Thread(s) per core: 1
Core(s) per socket: 8
Socket(s): 1
Stepping: 6
BogoMIPS: 6405.68
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti ibpb fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves umip pku ospke vaes vpclmulqdq rdpid
Hypervisor vendor: Xen
Virtualization type: full
L1d cache: 384 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 10 MiB (8 instances)
L3 cache: 96 MiB (8 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-7
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4 (envs/lib/python3.11/site-packages)
[pip3] pytorch-triton==3.0.0+45fff310c8 (envs/lib/python3.11/site-packages)
[pip3] torch==2.4.0.dev20240427+cu118 (envs/lib/python3.11/site-packages)
[pip3] torchao==0.1 (/notebooks/Triton/ao)
[pip3] triton==3.0.0 (envs/lib/python3.11/site-packages)
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch-triton 3.0.0+45fff310c8 pypi_0 pypi
[conda] torch 2.4.0.dev20240427+cu118 pypi_0 pypi
[conda] torchao 0.1 dev_0
[conda] triton 3.0.0 pypi_0 pypi

cc @malfet @seemethere @ptrblck @msaroufim

oulgen (Contributor) commented May 4, 2024

"No module named 'nvi'" means that there's a version mismatch. Are you able to run your triton code in eager mode (without torch.compile)?

jeromeku (Author) commented May 4, 2024

Yes: in the minimal repro I provided, the code first runs the kernel without compile; it only errors out when compiling.

oulgen (Contributor) commented May 4, 2024

@jeromeku Looking at your relevant libraries, I see you have both triton and pytorch-triton installed. Could you pip uninstall triton so that there's only one triton?
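
After the uninstall, a short check like the one below (a sketch; it assumes the remaining pytorch-triton package still provides the triton module) can confirm which triton survives:

# Sketch: verify the surviving triton after `pip uninstall triton`.
import torch
import triton

print("torch :", torch.__version__)
print("triton:", triton.__version__, "loaded from", triton.__file__)
# For the nightly in this report, the remaining distribution should be
# pytorch-triton 3.0.0+45fff310c8 (see the Versions section above).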

jeromeku (Author) commented May 5, 2024

I tried to do that, but uninstalling triton breaks my installation. When I try to import triton.ops, I get:

  "No module named 'triton.ops'"

oulgen (Contributor) commented May 5, 2024

That's because ops is likely not defined in the __init__.py. If you add from triton import ops and then change triton.ops.matmul to ops.matmul, it makes progress, but then it runs into another issue in triton's autotuner.

import triton
import torch

from triton import ops

def test_triton_matmul_compile():
    dtype = torch.float16
    A = torch.randn(4096, 4096, dtype=dtype, device="cuda")
    B = torch.randn(4096, 4096, dtype=dtype, device="cuda")
    acc_dtype = None
    input_precision = "ieee"
    fp8_fast_accum = False
    output_dtype = None

    out = ops.matmul(A, B, acc_dtype, input_precision, fp8_fast_accum, output_dtype)

    #triton_compiled = torch.compile(ops.matmul)
    #compiled_out = triton_compiled(A, B, acc_dtype, input_precision, fp8_fast_accum, output_dtype)

if __name__ == "__main__":
    test_triton_matmul_compile()

Notice that there's no torch.compile anymore, yet it still runs into the same autotuner issue. I suspect the triton package that ships with torch is not compatible with triton.ops.matmul.

I'll update the issue for proper triage.
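
For reference, the import mechanics behind the "No module named 'triton.ops'" symptom: a submodule is only reachable as an attribute of its package after it has been imported, which a package normally arranges in its __init__.py. A minimal sketch, assuming the ops submodule actually exists on disk for the installed triton:

import triton

# If triton/__init__.py does not import .ops, `triton.ops` is not bound,
# even when the files are present on disk.
from triton import ops         # same effect as `import triton.ops`
assert hasattr(triton, "ops")  # importing also binds it on the package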

@oulgen oulgen changed the title torch.compile error with user-defined kernel: "No module named 'nvi'" triton package that ships with torch is not compatible with triton.ops.matmul May 5, 2024
@jbschlosser jbschlosser added module: build Build system issues module: cuda Related to torch.cuda, and CUDA support in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module matrix multiplication labels May 6, 2024
@arthursunbao

I have the same issue and will update with more details later.
