Skip to content

Commit

Permalink
[Inductor] Properly package target info for triton.compile
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbaden committed Apr 30, 2024
1 parent ab80a59 commit aeabcb2
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions torch/_inductor/runtime/triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,17 @@
from triton.compiler.compiler import ASTSource
except ImportError:
ASTSource = None

try:
from triton.backends.compiler import GPUTarget
except ImportError:
GPUTarget = None
else:
Config = object
KernelInterface = object
OutOfResources = object
ASTSource = None
GPUTarget = None

try:
autograd_profiler = torch.autograd.profiler
Expand Down Expand Up @@ -343,11 +349,22 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: Optional[int]
else:
rocm_warp_size = 64

target = (
(compile_meta["device_type"], compile_meta["cc"])
if not torch.version.hip
else [compile_meta["device_type"], compile_meta["cc"], rocm_warp_size]
)
if GPUTarget:
target = GPUTarget(
compile_meta["device_type"],
compile_meta["cc"],
rocm_warp_size if torch.version.hip else 32,
)
else:
target = (
(compile_meta["device_type"], compile_meta["cc"])
if not torch.version.hip
else [
compile_meta["device_type"],
compile_meta["cc"],
rocm_warp_size,
]
)

options = {
"num_warps": compile_meta["num_warps"],
Expand Down

0 comments on commit aeabcb2

Please sign in to comment.