Convnext breaks torch.compile #97018

Closed
FrancescoSaverioZuppichini opened this issue Mar 17, 2023 · 2 comments
Labels: oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

FrancescoSaverioZuppichini commented Mar 17, 2023

🐛 Describe the bug

Hi 👋

Compiling convnext_base with torch.compile results in a Triton error.

It works for other models, e.g. ViT and ResNet.

I would also like help reading the error. I am not experienced with this, but I'd like to understand what the error is telling me.
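
As a reading aid for long torch.compile tracebacks such as the one in this report, it can help to collapse the frames into the compiler stages they pass through. The sketch below does this with the standard library only. The stage labels and the `stage_of` and `summarize` helpers are hypothetical names made up for illustration, and the sample frames are copied from the traceback in this issue.

```python
import re

# Matches the 'File "<path>", line N, in <name>' lines of a Python traceback.
FRAME_RE = re.compile(r'^\s*File "([^"]+)", line (\d+), in (\S+)')

# Path fragments mapped to informal stage labels (labels are assumptions
# chosen for this sketch, not official PyTorch terminology).
STAGES = [
    ("/torch/_dynamo/", "dynamo"),
    ("/torch/_functorch/", "aot_autograd"),
    ("/torch/_inductor/", "inductor"),
    ("/torchinductor_", "generated_code"),
]

def stage_of(path):
    """Return the stage label for a file path, or 'other' if none matches."""
    for needle, label in STAGES:
        if needle in path:
            return label
    return "other"

def summarize(traceback_text):
    """Return the stages in call order, with consecutive repeats collapsed."""
    stages = []
    for line in traceback_text.splitlines():
        match = FRAME_RE.match(line)
        if match is None:
            continue  # skip source lines and the final exception message
        label = stage_of(match.group(1))
        if not stages or stages[-1] != label:
            stages.append(label)
    return stages

# A handful of frames copied from the traceback in this report.
SAMPLE = """\
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 99, in __call__
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 697, in call_user_compiler
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2987, in aot_module_simplified
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 182, in compile_fx_inner
  File "/tmp/torchinductor_zuppif/kx/ckxfxxrav4eqlbxhp4vwjdwmi7wieclub23mynkqzi5d3lolcl6f.py", line 3106, in <module>
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 841, in wait
"""

print(summarize(SAMPLE))
```

For these sample frames the script prints `['dynamo', 'aot_autograd', 'inductor', 'generated_code', 'inductor']`. Read this way, the failure surfaces at the very end of the pipeline, while the generated kernels are being compiled, so the message at the bottom of the traceback is likely more informative than the frames above it.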

import torch
from torchvision.models import convnext_base


model = convnext_base()
model = model.cuda().half()

model = torch.compile(model, mode="max-autotune")

x = torch.randn((1, 3, 224, 224), device="cuda").half()

with torch.no_grad():
    model(x)

Error

[2023-03-17 14:18:03,439] torch._inductor.utils: [WARNING] using triton random, expect difference from eager
Traceback (most recent call last):
  File "/home/zuppif/Documents/medium/pytorch-2.0-compile/convnext.py", line 13, in <module>
    model(x)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 99, in __call__
    return self.dynamo_ctx(self._orig_mod.__call__)(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 235, in _fn
    return fn(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 372, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 405, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 105, in _fn
    return fn(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 263, in _convert_frame_assert
    return _compile(
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 325, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 530, in transform_code_object
    transformations(instructions, code_options)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 312, in transform
    tracer.run()
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1842, in run
    super().run()
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in run
    and self.step()
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 561, in step
    getattr(self, inst.opname)(inst)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1921, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 545, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 615, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 701, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 697, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 1064, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/__init__.py", line 1527, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 411, in compile_fx
    return compile_fx(
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 515, in compile_fx
    return aot_autograd(
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 59, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2987, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2668, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1770, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1943, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1252, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 483, in fw_compiler
    return inner_compile(
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 598, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 182, in compile_fx_inner
    compiled_fn = graph.compile_to_fn()
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/graph.py", line 648, in compile_to_fn
    return self.compile_to_module().call
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/graph.py", line 626, in compile_to_module
    mod = PyCodeCache.load(code)
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 654, in load
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_zuppif/kx/ckxfxxrav4eqlbxhp4vwjdwmi7wieclub23mynkqzi5d3lolcl6f.py", line 3106, in <module>
    async_compile.wait(globals())
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 841, in wait
    scope[key] = result.result()
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 699, in result
    self.future.result()
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/concurrent/futures/_base.py", line 439, in result
    return self.__get_result()
  File "/home/zuppif/miniconda3/envs/dl/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='debug_wrapper' raised:
CompilationError: at 65:39:
def triton_(arg_A, arg_B, in_ptr2, in_ptr3, seed4, in_ptr5, out_ptr1):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 64
    BLOCK_N : tl.constexpr = 32
    BLOCK_K : tl.constexpr = 32

    A = arg_A
    B = arg_B

    M = 3136
    N = 128
    K = 512
    stride_am = 512
    stride_ak = 1
    stride_bk = 1
    stride_bn = 512

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + (128*idx_m)
    tmp0 = tl.load(in_ptr2 + (idx_n + tl.zeros(mask.shape, tl.int32)), mask).to(tl.float32)
    tmp1 = tl.load(in_ptr3 + (idx_n + tl.zeros(mask.shape, tl.int32)), mask).to(tl.float32)
    tmp4_load = tl.load(seed4 + (0))
    tmp4 = tl.broadcast_to(tmp4_load, [XBLOCK])
                                       ^


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
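
The `at 65:39:` prefix of the CompilationError can be read as a line and column inside the kernel source printed after it. The sketch below is a minimal illustration of that lookup using only the standard library. `error_location` and `token_at` are hypothetical helpers written for this example, and the column is assumed to be 0-based because that is what lines up with the caret in the report.

```python
import re

def error_location(message):
    """Extract (line, column) from a message containing 'at LINE:COL:'."""
    match = re.search(r"\bat (\d+):(\d+):", message)
    if match is None:
        raise ValueError("no 'at LINE:COL:' marker found")
    return int(match.group(1)), int(match.group(2))

def token_at(source_line, column):
    """Return the identifier covering the given 0-based column, if any."""
    for match in re.finditer(r"[A-Za-z_]\w*", source_line):
        if match.start() <= column < match.end():
            return match.group(0)
    return None

message = "CompilationError: at 65:39:"
# Line 65 of the kernel as printed in the report, counting `def triton_` as line 1.
line_65 = "    tmp4 = tl.broadcast_to(tmp4_load, [XBLOCK])"

line_no, col = error_location(message)
print(line_no, col, token_at(line_65, col))
```

This prints `65 39 XBLOCK`. In the kernel as printed, `XBLOCK` does not appear among the parameters or the `tl.constexpr` definitions, so an undefined name in the generated suffix is a plausible reading of the error, although nothing in this thread confirms the root cause.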

Cheers,

Fra

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230316+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.10 (x86_64)
GCC version: (Ubuntu 12.2.0-3ubuntu1) 12.2.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.36

Python version: 3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-35-generic-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 525.60.11
cuDNN version: Probably one of the following:
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 2700X Eight-Core Processor
CPU family: 23
Model: 8
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 2
Frequency boost: enabled
CPU(s) scaling MHz: 78%
CPU max MHz: 3700,0000
CPU min MHz: 2200,0000
BogoMIPS: 7400.05
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb hw_pstate ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt sha_ni xsaveopt xsavec xgetbv1 xsaves clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 512 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 16 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT vulnerable
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.1
[pip3] pytorch-lightning==2.0.0
[pip3] pytorch-triton==2.1.0+2c32f43999
[pip3] torch==2.1.0.dev20230316+cu117
[pip3] torchaudio==0.12.1+cu116
[pip3] torchmetrics==0.11.4
[pip3] torchtriton==2.0.0+0d7e753227
[pip3] torchvision==0.16.0.dev20230316+cu117
[conda] numpy 1.24.1 pypi_0 pypi
[conda] pytorch-lightning 2.0.0 pypi_0 pypi
[conda] pytorch-triton 2.1.0+2c32f43999 pypi_0 pypi
[conda] torch 2.1.0.dev20230316+cu117 pypi_0 pypi
[conda] torchaudio 0.12.1+cu116 pypi_0 pypi
[conda] torchmetrics 0.11.4 pypi_0 pypi
[conda] torchtriton 2.0.0+0d7e753227 pypi_0 pypi
[conda] torchvision 0.16.0.dev20230316+cu117 pypi_0 pypi
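
The list above shows two Triton distributions installed side by side (`pytorch-triton` 2.1.0+2c32f43999 and `torchtriton` 2.0.0+0d7e753227). Whether this has anything to do with the failure is not established in this report, but it is easy to check for. The sketch below uses only `importlib.metadata` from the standard library; `triton_like` and `installed_triton_distributions` are hypothetical helper names for this example.

```python
from importlib import metadata

def triton_like(names):
    """Return the sorted, de-duplicated names that contain 'triton'."""
    return sorted({name for name in names if name and "triton" in name.lower()})

def installed_triton_distributions():
    """List installed distributions whose name contains 'triton'."""
    names = (dist.metadata["Name"] for dist in metadata.distributions())
    return triton_like(names)

if __name__ == "__main__":
    found = installed_triton_distributions()
    print(found)
    if len(found) > 1:
        # Only a hint: several Triton-like packages may or may not conflict.
        print("more than one Triton-like distribution is installed")
```

If more than one name is reported, a clean environment with a single Triton build is a cheap way to rule this out before digging further.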

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh

Chillee (Contributor) commented Mar 21, 2023

This looks like an issue with using the Triton-based convolution. I can't actually repro it on master, but I suspect this will fix it: #95556

Chillee added the triaged label Mar 21, 2023

williamwen42 (Member) commented

Can no longer repro - closing as fixed.
