torch.compile doesn't work well with custom triton kernel from Mamba #128061
Labels
- module: pt2-dispatcher
  PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op,
- module: user triton
  related to the ability to directly torch.compile triton kernels
- oncall: pt2
- triaged
  This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
Run the Mamba benchmark from https://github.com/state-spaces/mamba/blob/main/benchmarks/benchmark_generation_mamba_simple.py and apply torch.compile to model.generate.
We then see several graph breaks and failures; the major failure is an assertion error. After adding some debug info, it turns out the assertion fires because kernel is a triton.runtime.autotuner.Heuristics object (rather than the underlying JIT-compiled kernel).

Versions
N/A
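To illustrate the failure mode without needing torch or triton installed, here is a minimal pure-Python sketch. It assumes Triton's decorator layout, where decorators such as @triton.heuristics wrap the base kernel in a wrapper object that stores the wrapped callable on a .fn attribute; the class names mirror triton.runtime.jit.JITFunction and triton.runtime.autotuner.Heuristics but are simplified stand-ins, not the real classes.

```python
class JITFunction:
    """Stand-in for triton.runtime.jit.JITFunction (the base compiled kernel)."""
    def __init__(self, fn):
        self.fn = fn


class Heuristics:
    """Stand-in for triton.runtime.autotuner.Heuristics.

    Decorators like @triton.heuristics return a wrapper object, so code
    that asserts isinstance(kernel, JITFunction) fails when handed one.
    """
    def __init__(self, fn, values):
        self.fn = fn          # the wrapped JITFunction (or another wrapper)
        self.values = values  # heuristic value functions


def unwrap(kernel):
    """Peel wrapper layers until we reach the base JITFunction."""
    while not isinstance(kernel, JITFunction):
        kernel = kernel.fn
    return kernel


base = JITFunction(lambda x: x)
wrapped = Heuristics(base, values={"BLOCK": lambda args: 128})

# This is what the failing assertion in the compile path observes:
assert not isinstance(wrapped, JITFunction)
# Unwrapping the decorator stack recovers the base kernel:
assert unwrap(wrapped) is base
```

A possible direction for a fix would be for the handling code to unwrap such wrapper objects (as unwrap does above) instead of asserting on the outermost type; whether that matches the actual fix in PyTorch is an assumption here.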
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @zou3519 @oulgen @aakhundov