
torchvision manywheel py 3.11, cuda failure #8441

Open
atalman opened this issue May 27, 2024 · 6 comments

@atalman
Contributor

atalman commented May 27, 2024

🐛 Describe the bug

The following nightly job is failing:
https://github.com/pytorch/vision/actions/runs/9111531485/job/25048879022

The failures started on 5/16 and were most likely introduced by one of the changes in:
pytorch/pytorch@a86434a

+ /__w/vision/vision/3/bin/conda run -p /__w/_temp/conda_environment_9111531485 python pytorch/vision/test/smoke_test.py
ERROR conda.cli.main_run:execute(41): `conda run python pytorch/vision/test/smoke_test.py` failed. (See above for error)
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /github/home/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
torchvision: 0.19.0.dev20240516+cu121
torch.cuda.is_available: True
/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/compile_fx.py:133: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
torch.ops.image._jpeg_version() = 62
Is torchvision usable? True
German shepherd (cpu): 37.6%
German shepherd (cuda): 37.6%
Traceback (most recent call last):
  File "/__w/vision/vision/pytorch/vision/test/smoke_test.py", line 103, in <module>
    main()
  File "/__w/vision/vision/pytorch/vision/test/smoke_test.py", line 96, in main
    smoke_test_compile()
  File "/__w/vision/vision/pytorch/vision/test/smoke_test.py", line 42, in smoke_test_compile
    out = model(x)
          ^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 414, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1085, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 926, in _convert_frame
    result = inner_convert(
             ^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 464, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_utils_internal.py", line 74, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 807, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 626, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 572, in transform
    tracer.run()
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
          ^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2423, in RETURN_VALUE
    self._return(inst)
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2408, in _return
    self.output.compile_subgraph(
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1084, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1301, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1392, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1373, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/__init__.py", line 1747, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1492, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 65, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 965, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 686, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(
                  ^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 445, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1397, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/debug.py", line 304, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 522, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 818, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1736, in compile_to_fn
    return self.compile_to_module().call
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 210, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1680, in compile_to_module
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2576, in load_by_key_path
    mod = _reload_python_module(key, path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 44, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_root/fk/cfka6mgq6jdwbss6b6p7iptbdbl4bt3qmhxjfnhaxkgaqknx6hu7.py", line 2330, in <module>
    async_compile.wait(globals())
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3137, in wait
    scope[key] = result.result()
                 ^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2939, in result
    self.future.result()
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/__w/_temp/conda_environment_9111531485/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
SyntaxError: unterminated string literal (detected at line 1) (<unknown>, line 1)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Versions

2.4.0

@atalman
Contributor Author

atalman commented May 28, 2024

Strangely enough, I can't seem to repro it:

pip install torchvision==0.19.0.dev20240515+cu121 --pre --index-url https://download.pytorch.org/whl/nightly/cu121
pip install torch==2.4.0.dev20240517+cu121 --pre --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall
....
python smoke_test.py 
torchvision: 0.19.0.dev20240515+cu121
torch.cuda.is_available: True
torch: 2.4.0.dev20240517+cu121
torch.ops.image._jpeg_version() = 62
Is torchvision usable? True
German shepherd (cpu): 37.6%
German shepherd (cuda): 37.6%
/home/atalman/miniconda3/envs/py311/lib/python3.11/site-packages/torch/_inductor/compile_fx.py:133: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
torch.compile model output: torch.Size([1, 1000])

@jamesjwu

jamesjwu commented May 28, 2024

Tagging @eellison maybe? The logs show a "triton compilation failed" with fn.src = defb844cc147', which contains a quotation mark that Triton can't deal with:

Cleaned up paste: (https://gist.github.com/jamesjwu/a76490e55e27641697cc60f172cab908)

@jamesjwu

jamesjwu commented May 28, 2024

I think it's trying to output the big decorator at the top, but the output is somehow cut off, because defb844cc147' is the last part of the "backend_hash" value:

@triton_heuristics.pointwise(
    size_hints=[262144], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*i8', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=80), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_max_pool2d_with_indices_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 9, 'num_reduction': 0, 'backend_hash': '4a97da3b799ff6c044be9f5292baf267c385f6396e77240569b5
#------------ (where the failing output starts)-------------
defb844cc147', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
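The truncation described above can be illustrated in plain Python (a hypothetical sketch, not the actual Triton code path): a source fragment that begins mid-string, like the defb844cc147' tail above, carries an unbalanced quote, so compiling it fails with a SyntaxError. In the CI log this surfaced as "unterminated string literal (detected at line 1)".

```python
# Hypothetical sketch: feeding Python a source fragment that was cut off
# mid-string (as the generated kernel source apparently was) fails to
# compile because of the unbalanced quote left behind by the truncation.
truncated_src = "defb844cc147', 'are_deterministic_algorithms_enabled': False"

try:
    compile(truncated_src, "<unknown>", "exec")
except SyntaxError as err:
    # err.msg/err.lineno carry the same details shown in the CI traceback.
    print(f"SyntaxError: {err.msg} (line {err.lineno})")
```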

@atalman
Contributor Author

atalman commented May 28, 2024

@manman-ren

Given the error message of
SyntaxError: unterminated string literal (detected at line 1)

As James mentioned, it looks like Triton doesn't see the first few lines, which is why it reports the error at line 1?

If there is a reproducer for Triton, it would help us triage.

@NicolasHug
Member

@atalman these jobs seem to be passing on the latest commits from the main branch: https://github.com/pytorch/vision/actions/runs/9402293448/job/25897155214

Should we close this issue?
