From 00cf90375ca5aa4822f993cb2c97bdcd8ad33b10 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Wed, 22 May 2024 19:02:44 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torch/_inductor/codegen/simd.py   | 60 +++++++++----------------------
 torch/_inductor/codegen/triton.py | 52 ++++++++++++++++++++++++++-
 2 files changed, 68 insertions(+), 44 deletions(-)

diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py
index a1a1ddb6a286d..1d9ebce334f24 100644
--- a/torch/_inductor/codegen/simd.py
+++ b/torch/_inductor/codegen/simd.py
@@ -38,7 +38,7 @@
 from ..ir import TritonTemplateBuffer
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
-from ..runtime.runtime_utils import get_max_y_grid, green_text, yellow_text
+from ..runtime.runtime_utils import green_text, yellow_text
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
 from ..utils import (
     get_dtype_size,
@@ -247,54 +247,16 @@ def add(node):
         return list(reversed(index_vars)), list(reversed(sizes))
 
     def ranges_code(self):
-        assert self.tensor_dim is not None
-        size = self.kernel.indexing_size_str(self.tensor_dim)
-        index_dtype = self.kernel.index_dtype
-        convert = f".to({index_dtype})" if index_dtype != "tl.int32" else ""
-        return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}{convert}"
+        return self.kernel.iteration_ranges_ranges_code(self)
 
     def scalar_code(self, value):
-        index_dtype = self.kernel.index_dtype
-        ndim = self.kernel.triton_tensor_ndim()
-        size = [1] * ndim
-        return f"tl.full({size}, {value}, {index_dtype})"
+        return self.kernel.iteration_ranges_scalar_code(self, value)
 
     def get_pid(self):
-        assert self.grid_dim is not None
-        key = f"tl.program_id({self.grid_dim})"
-        # y_grid has a limit, so express it in terms of y and z in case of overflow.
-        # z grid is only exercised when max_tiles == 3 (off by default).
-        if (
-            self.grid_dim == 1
-            and not self.has_zdim
-            and not (isinstance(self.numel, int) and self.numel <= get_max_y_grid())
-        ):
-            key = f"{key} * (tl.program_id({self.grid_dim + 1}) + 1)"
-        pid = self.pid_cache.get(key, key)
-        if self.kernel.index_dtype != "tl.int32":
-            return f"{pid}.to({self.kernel.index_dtype})"
-        return pid
+        return self.kernel.iteration_ranges_get_pid(self)
 
     def codegen_header(self, code):
-        x = self.prefix
-        if self.is_loop:
-            code.writeline(f"{self.name} = {x}offset + {x}base")
-        elif self.grid_dim is None:
-            # no need to "{x}offset = "
-            code.writeline(f"{self.name} = {self.ranges_code()}")
-            code.writeline(f"{x}offset = 0")
-        else:
-            if self.tensor_dim is not None:
-                line = f"{x}offset + {self.ranges_code()}"
-            else:
-                line = self.scalar_code(f"{x}offset")
-            code.writelines(
-                [
-                    f"{x}offset = {self.get_pid()} * {x.upper()}BLOCK",
-                    f"{self.name} = {line}",
-                ]
-            )
-        code.writeline(f"{x}mask = {self.name} < {x}numel")
+        return self.kernel.iteration_ranges_codegen_header(self, code)
 
 
 class IterationRangesEntry(IterationRanges):
@@ -1090,6 +1052,18 @@ def codegen_body(self):
     def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
         raise NotImplementedError
 
+    def iteration_ranges_ranges_code(self, entry):
+        raise NotImplementedError
+
+    def iteration_ranges_scalar_code(self, entry, value):
+        raise NotImplementedError
+
+    def iteration_ranges_get_pid(self, entry):
+        raise NotImplementedError
+
+    def iteration_ranges_codegen_header(self, entry, code):
+        raise NotImplementedError
+
 
 class SIMDScheduling(BaseScheduling):
     kernel_type = SIMDKernel  # override in subclass
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 0617c21c8eaa1..c4b85472a6eb4 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -25,7 +25,7 @@
 from ..ir import IRNode
 from ..metrics import is_metric_table_enabled, log_kernel_metadata
 from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
-from ..runtime.runtime_utils import do_bench_gpu, next_power_of_2
+from ..runtime.runtime_utils import do_bench_gpu, get_max_y_grid, next_power_of_2
 from ..utils import (
     cache_on_self,
     get_bounds_index_expr,
@@ -2160,6 +2160,56 @@ def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
             # lift non-reduction stores outside loop
             self.body.writeline(line)
 
+    def iteration_ranges_ranges_code(self, entry):
+        assert entry.tensor_dim is not None
+        size = self.indexing_size_str(entry.tensor_dim)
+        index_dtype = self.index_dtype
+        convert = f".to({index_dtype})" if index_dtype != "tl.int32" else ""
+        return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{convert}"
+
+    def iteration_ranges_scalar_code(self, entry, value):
+        index_dtype = self.index_dtype
+        ndim = self.triton_tensor_ndim()
+        size = [1] * ndim
+        return f"tl.full({size}, {value}, {index_dtype})"
+
+    def iteration_ranges_get_pid(self, entry):
+        assert entry.grid_dim is not None
+        key = f"tl.program_id({entry.grid_dim})"
+        # y_grid has a limit, so express it in terms of y and z in case of overflow.
+        # z grid is only exercised when max_tiles == 3 (off by default).
+        if (
+            entry.grid_dim == 1
+            and not entry.has_zdim
+            and not (isinstance(entry.numel, int) and entry.numel <= get_max_y_grid())
+        ):
+            key = f"{key} * (tl.program_id({entry.grid_dim + 1}) + 1)"
+        pid = entry.pid_cache.get(key, key)
+        if self.index_dtype != "tl.int32":
+            return f"{pid}.to({self.index_dtype})"
+        return pid
+
+    def iteration_ranges_codegen_header(self, entry, code):
+        x = entry.prefix
+        if entry.is_loop:
+            code.writeline(f"{entry.name} = {x}offset + {x}base")
+        elif entry.grid_dim is None:
+            # no need to "{x}offset = "
+            code.writeline(f"{entry.name} = {entry.ranges_code()}")
+            code.writeline(f"{x}offset = 0")
+        else:
+            if entry.tensor_dim is not None:
+                line = f"{x}offset + {entry.ranges_code()}"
+            else:
+                line = entry.scalar_code(f"{x}offset")
+            code.writelines(
+                [
+                    f"{x}offset = {entry.get_pid()} * {x.upper()}BLOCK",
+                    f"{entry.name} = {line}",
+                ]
+            )
+        code.writeline(f"{x}mask = {entry.name} < {x}numel")
+
 
 class TritonScheduling(SIMDScheduling):
     int32_type = "tl.int32"
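The sketch below is not part of the patch; it is a minimal, self-contained illustration of the delegation pattern the diff introduces, using hypothetical class names (SIMDKernelSketch, IterationRangesSketch, TritonKernelSketch) rather than the real torch._inductor classes. The shared range object forwards code generation to hooks on its kernel, the backend-neutral base declares those hooks as NotImplementedError stubs, and only the Triton subclass emits tl.* syntax.

# Minimal sketch (hypothetical names, not the actual PyTorch classes) of the
# delegation pattern above: the range object forwards codegen to its kernel,
# and only the Triton subclass knows Triton syntax.


class SIMDKernelSketch:
    """Backend-neutral base: declares the hooks without Triton details."""

    def iteration_ranges_ranges_code(self, entry):
        raise NotImplementedError

    def iteration_ranges_get_pid(self, entry):
        raise NotImplementedError


class IterationRangesSketch:
    """Stands in for IterationRanges: bookkeeping only, no codegen."""

    def __init__(self, kernel, prefix, grid_dim):
        self.kernel = kernel
        self.prefix = prefix
        self.grid_dim = grid_dim

    # After the refactor these are one-line delegations to the kernel.
    def ranges_code(self):
        return self.kernel.iteration_ranges_ranges_code(self)

    def get_pid(self):
        return self.kernel.iteration_ranges_get_pid(self)


class TritonKernelSketch(SIMDKernelSketch):
    """Triton backend: supplies the actual tl.* strings."""

    def iteration_ranges_ranges_code(self, entry):
        return f"tl.arange(0, {entry.prefix.upper()}BLOCK)"

    def iteration_ranges_get_pid(self, entry):
        return f"tl.program_id({entry.grid_dim})"


if __name__ == "__main__":
    entry = IterationRangesSketch(TritonKernelSketch(), prefix="x", grid_dim=0)
    print(entry.ranges_code())  # -> tl.arange(0, XBLOCK)
    print(entry.get_pid())      # -> tl.program_id(0)

With this split, the shared simd.py helpers carry no Triton-specific strings, and another backend could implement the same iteration_ranges_* hooks on its own kernel class.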