Skip to content

Commit

Permalink
[Inductor] Properly package target info for triton.compile (#125553)
Browse files Browse the repository at this point in the history
Triton updated the interface for `triton.compile` triton-lang/triton@5162346

The `target` argument to compile needs to be wrapped in a `GPUTarget` object. Without proper wrapping, we hit an assert in `compile`. If that assert is removed, Triton attempts to read device info from Torch while inside a torch thread, which hits an in bad fork assert. This change is required for compatibility with latest commits in Triton. The implementation is backwards compatible, so existing versions of Triton that work now continue to work.

Re-submitting this after #125241 was reverted due to an unrelated CI issue.

Pull Request resolved: #125553
Approved by: https://github.com/huydhn
  • Loading branch information
alexbaden authored and pytorchmergebot committed May 6, 2024
1 parent 1dd42e4 commit fc183f0
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions torch/_inductor/runtime/triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,17 @@
from triton.compiler.compiler import ASTSource
except ImportError:
ASTSource = None

try:
from triton.backends.compiler import GPUTarget
except ImportError:
GPUTarget = None
else:
Config = object
KernelInterface = object
OutOfResources = object
ASTSource = None
GPUTarget = None

try:
autograd_profiler = torch.autograd.profiler
Expand Down Expand Up @@ -334,11 +340,22 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool):
else:
rocm_warp_size = 64

target = (
(compile_meta["device_type"], compile_meta["cc"])
if not torch.version.hip
else [compile_meta["device_type"], compile_meta["cc"], rocm_warp_size]
)
if GPUTarget:
target = GPUTarget(
compile_meta["device_type"],
compile_meta["cc"],
rocm_warp_size if torch.version.hip else 32,
)
else:
target = (
(compile_meta["device_type"], compile_meta["cc"])
if not torch.version.hip
else [
compile_meta["device_type"],
compile_meta["cc"],
rocm_warp_size,
]
)

options = {
"num_warps": compile_meta["num_warps"],
Expand Down

0 comments on commit fc183f0

Please sign in to comment.