
torch.compile doesn't work well with custom triton kernel from Mamba #128061

Closed

yanboliang opened this issue Jun 5, 2024 · 2 comments

Labels
module: pt2-dispatcher PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op), module: user triton related to ability to directly torch.compile triton kernels, oncall: pt2, triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@yanboliang (Contributor) commented Jun 5, 2024

🐛 Describe the bug

Run the Mamba benchmark from https://github.com/state-spaces/mamba/blob/main/benchmarks/benchmark_generation_mamba_simple.py.

Apply torch.compile to model.generate, and we see several graph breaks and failures. The major failure is:

WARNING:torch._dynamo:Encountered an exception in identify_mutated_tensors, assuming every input is mutated
Traceback (most recent call last):
  File "/home/ybliang/local/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 474, in identify_mutated_tensors
    ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
  File "/home/ybliang/local/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 131, in generate_ttir
    assert isinstance(kernel, JITFunction), f"{kernel}, {type(kernel)}"
AssertionError
WARNING:torch._dynamo:Encountered an exception in identify_mutated_tensors, assuming every input is mutated
Traceback (most recent call last):
  File "/home/ybliang/local/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 474, in identify_mutated_tensors
    ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
  File "/home/ybliang/local/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 131, in generate_ttir
    assert isinstance(kernel, JITFunction), f"{kernel}, {type(kernel)}"
AssertionError
Traceback (most recent call last):
  File "/data/users/ybliang/debug/empathy/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 83, in <module>
    out = fn()
  File "/data/users/ybliang/debug/empathy/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 57, in <lambda>
    fn = lambda: model_generate(
  File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 421, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 260, in generate
    output = decode(
  File "/home/ybliang/local/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 160, in decode
    model._decoding_cache = update_graph_cache(
  File "/home/ybliang/local/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 293, in update_graph_cache
    param_example = next(iter(model.parameters()))
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 293, in torch_dynamo_resume_in_update_graph_cache_at_293
    param_example = next(iter(model.parameters()))
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 293, in torch_dynamo_resume_in_update_graph_cache_at_293
    param_example = next(iter(model.parameters()))
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 305, in torch_dynamo_resume_in_update_graph_cache_at_293
    gc.collect()
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 321, in torch_dynamo_resume_in_update_graph_cache_at_305
    cache.callables[batch_size, decoding_seqlen] = capture_graph(
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 343, in capture_graph
    device = next(iter(model.parameters())).device
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 343, in torch_dynamo_resume_in_capture_graph_at_343
    device = next(iter(model.parameters())).device
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 343, in torch_dynamo_resume_in_capture_graph_at_343
    device = next(iter(model.parameters())).device
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/utils/generation.py", line 355, in torch_dynamo_resume_in_capture_graph_at_343
    logits = model(
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward
    hidden_states, residual = layer(
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/nn/modules/module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/modules/block.py", line 57, in forward
    hidden_states, residual = layer_norm_fn(
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py", line 875, in layer_norm_fn
    return LayerNormFn.apply(
  File "/home/ybliang/local/pytorch/torch/autograd/function.py", line 573, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py", line 748, in forward
    y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py", line 334, in _layer_norm_fwd
    with torch.cuda.device(x.device.index):
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py", line 335, in torch_dynamo_resume_in__layer_norm_fwd_at_334
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 1078, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 919, in _convert_frame
    result = inner_convert(
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 456, in _convert_frame_assert
    return _compile(
  File "/home/ybliang/local/pytorch/torch/_utils_internal.py", line 83, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/home/ybliang/local/pytorch/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 799, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/ybliang/local/pytorch/torch/_dynamo/utils.py", line 232, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 618, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/ybliang/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1184, in transform_code_object
    transformations(instructions, code_options)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 177, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/convert_frame.py", line 564, in transform
    tracer.run()
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2450, in run
    super().run()
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 892, in run
    while self.step():
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 804, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2641, in RETURN_VALUE
    self._return(inst)
  File "/home/ybliang/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2626, in _return
    self.output.compile_subgraph(
  File "/home/ybliang/local/pytorch/torch/_dynamo/output_graph.py", line 1122, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/ybliang/local/pytorch/torch/_dynamo/output_graph.py", line 1314, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/ybliang/local/pytorch/torch/_dynamo/utils.py", line 232, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/output_graph.py", line 1405, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/ybliang/local/pytorch/torch/_dynamo/output_graph.py", line 1386, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/ybliang/local/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/ybliang/local/pytorch/torch/__init__.py", line 1796, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/ybliang/local/miniconda3/envs/pt/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/ybliang/local/pytorch/torch/_inductor/compile_fx.py", line 1474, in compile_fx
    return aot_autograd(
  File "/home/ybliang/local/pytorch/torch/_dynamo/backends/common.py", line 65, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_functorch/aot_autograd.py", line 941, in aot_module_simplified
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/home/ybliang/local/pytorch/torch/_dynamo/utils.py", line 232, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_functorch/aot_autograd.py", line 674, in create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/home/ybliang/local/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 128, in aot_dispatch_base
    fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph(  # type: ignore[misc]
  File "/home/ybliang/local/pytorch/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 137, in aot_dispatch_base_graph
    fw_module = _create_graph(
  File "/home/ybliang/local/pytorch/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 45, in _create_graph
    fx_g = make_fx(
  File "/home/ybliang/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 1421, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/home/ybliang/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 1367, in trace
    return self._trace_inner(f, *args)
  File "/home/ybliang/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 1354, in _trace_inner
    t = dispatch_trace(
  File "/home/ybliang/local/pytorch/torch/_compile.py", line 30, in inner
    return disable_fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 642, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/ybliang/local/pytorch/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/home/ybliang/local/pytorch/torch/fx/_symbolic_trace.py", line 821, in trace
    (self.create_arg(fn(*args)),),
  File "/home/ybliang/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 660, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "/home/ybliang/local/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 549, in _functionalized_f_helper
    with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
  File "/home/ybliang/local/pytorch/torch/autograd/grad_mode.py", line 390, in __init__
    self.prev_version = tensor._version
  File "/home/ybliang/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 705, in __torch_function__
    return func(*args, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Inference tensors do not track version counter.

After adding some debug info, I found the assertion triggered because kernel is a triton.runtime.autotuner.Heuristics object rather than a JITFunction.
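The failure pattern can be illustrated without Triton installed: decorators such as @triton.heuristics wrap the underlying JITFunction in a Heuristics object that holds it in an attribute, so a bare isinstance(kernel, JITFunction) check fails unless wrapper layers are peeled off first. Below is a minimal pure-Python sketch of that pattern; the class definitions and the unwrap_kernel helper are stand-ins for illustration, not the real Triton or PyTorch APIs:

```python
class JITFunction:
    """Stand-in for triton.runtime.jit.JITFunction (illustrative only)."""
    def __init__(self, fn):
        self.fn = fn


class Heuristics:
    """Stand-in for triton.runtime.autotuner.Heuristics: it wraps a
    JITFunction (or another wrapper) rather than subclassing it."""
    def __init__(self, fn, values):
        self.fn = fn          # the wrapped kernel object
        self.values = values  # heuristic config, e.g. {"BLOCK_N": ...}


def unwrap_kernel(kernel):
    """Peel off wrapper layers until the underlying JITFunction is reached."""
    while not isinstance(kernel, JITFunction):
        if not hasattr(kernel, "fn"):
            # Mirrors the assert in generate_ttir when the object is opaque.
            raise AssertionError(f"{kernel}, {type(kernel)}")
        kernel = kernel.fn
    return kernel


base = JITFunction(lambda x: x)
wrapped = Heuristics(base, values={"BLOCK_N": lambda args: 128})

# A direct isinstance check on the decorated kernel fails -- this is
# exactly why the assertion in generate_ttir fired.
assert not isinstance(wrapped, JITFunction)

# Unwrapping recovers the underlying JITFunction.
assert unwrap_kernel(wrapped) is base
```

The point of the sketch is that the fix belongs in the tracing side (unwrapping before the isinstance check), not in the Mamba kernels themselves.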

Versions

N/A

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @zou3519 @oulgen @aakhundov

@oulgen (Contributor) commented Jun 5, 2024

@bdhirsh fixed this one (or created a PR and never landed it)

@oulgen oulgen added the module: user triton related to ability to directly torch.compile triton kernels label Jun 6, 2024
@soulitzer soulitzer added module: pt2-dispatcher PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op) and triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jun 6, 2024
@oulgen (Contributor) commented Jun 7, 2024

#124489 fixes this

@oulgen oulgen closed this as completed Jun 7, 2024