Commit

Update
[ghstack-poisoned]
jansel committed May 23, 2024
1 parent 37c2e14 commit 00cf903
Showing 2 changed files with 68 additions and 44 deletions.
torch/_inductor/codegen/simd.py: 60 changes (17 additions, 43 deletions)
@@ -38,7 +38,7 @@
 from ..ir import TritonTemplateBuffer
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
-from ..runtime.runtime_utils import get_max_y_grid, green_text, yellow_text
+from ..runtime.runtime_utils import green_text, yellow_text
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
 from ..utils import (
     get_dtype_size,
@@ -247,54 +247,16 @@ def add(node):
         return list(reversed(index_vars)), list(reversed(sizes))
 
     def ranges_code(self):
-        assert self.tensor_dim is not None
-        size = self.kernel.indexing_size_str(self.tensor_dim)
-        index_dtype = self.kernel.index_dtype
-        convert = f".to({index_dtype})" if index_dtype != "tl.int32" else ""
-        return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}{convert}"
+        return self.kernel.iteration_ranges_ranges_code(self)
 
     def scalar_code(self, value):
-        index_dtype = self.kernel.index_dtype
-        ndim = self.kernel.triton_tensor_ndim()
-        size = [1] * ndim
-        return f"tl.full({size}, {value}, {index_dtype})"
+        return self.kernel.iteration_ranges_scalar_code(self, value)
 
     def get_pid(self):
-        assert self.grid_dim is not None
-        key = f"tl.program_id({self.grid_dim})"
-        # y_grid has a limit, so express it in terms of y and z in case of overflow.
-        # z grid is only exercised when max_tiles == 3 (off by default).
-        if (
-            self.grid_dim == 1
-            and not self.has_zdim
-            and not (isinstance(self.numel, int) and self.numel <= get_max_y_grid())
-        ):
-            key = f"{key} * (tl.program_id({self.grid_dim + 1}) + 1)"
-        pid = self.pid_cache.get(key, key)
-        if self.kernel.index_dtype != "tl.int32":
-            return f"{pid}.to({self.kernel.index_dtype})"
-        return pid
+        return self.kernel.iteration_ranges_get_pid(self)
 
     def codegen_header(self, code):
-        x = self.prefix
-        if self.is_loop:
-            code.writeline(f"{self.name} = {x}offset + {x}base")
-        elif self.grid_dim is None:
-            # no need to "{x}offset = "
-            code.writeline(f"{self.name} = {self.ranges_code()}")
-            code.writeline(f"{x}offset = 0")
-        else:
-            if self.tensor_dim is not None:
-                line = f"{x}offset + {self.ranges_code()}"
-            else:
-                line = self.scalar_code(f"{x}offset")
-            code.writelines(
-                [
-                    f"{x}offset = {self.get_pid()} * {x.upper()}BLOCK",
-                    f"{self.name} = {line}",
-                ]
-            )
-        code.writeline(f"{x}mask = {self.name} < {x}numel")
+        return self.kernel.iteration_ranges_codegen_header(self, code)
 
 
 class IterationRangesEntry(IterationRanges):
@@ -1090,6 +1052,18 @@ def codegen_body(self):
     def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
         raise NotImplementedError
 
+    def iteration_ranges_ranges_code(self, entry):
+        raise NotImplementedError
+
+    def iteration_ranges_scalar_code(self, entry, value):
+        raise NotImplementedError
+
+    def iteration_ranges_get_pid(self, entry):
+        raise NotImplementedError
+
+    def iteration_ranges_codegen_header(self, entry, code):
+        raise NotImplementedError
+
 
 class SIMDScheduling(BaseScheduling):
     kernel_type = SIMDKernel  # override in subclass
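
Taken together, the simd.py changes turn these IterationRanges helpers into thin wrappers that defer all Triton-specific code generation to the owning kernel: SIMDKernel declares the new iteration_ranges_* hooks, and TritonKernel (in the next file) implements them. A minimal, runnable sketch of that delegation pattern, using hypothetical Toy* names rather than the real Inductor classes:

class ToyKernel:
    """Backend-agnostic base: declares the hook but does not implement it."""

    def iteration_ranges_ranges_code(self, entry):
        raise NotImplementedError


class ToyTritonKernel(ToyKernel):
    """The Triton backend supplies the actual code generation."""

    def iteration_ranges_ranges_code(self, entry):
        return f"tl.arange(0, {entry.prefix.upper()}BLOCK)"


class ToyRanges:
    """The ranges object only keeps a kernel reference and forwards to it."""

    def __init__(self, prefix, kernel):
        self.prefix = prefix
        self.kernel = kernel

    def ranges_code(self):
        # Delegation: no Triton knowledge lives here anymore.
        return self.kernel.iteration_ranges_ranges_code(self)


print(ToyRanges("x", ToyTritonKernel()).ranges_code())  # tl.arange(0, XBLOCK)

This is the same shape the real classes now follow: ranges_code(), scalar_code(), get_pid(), and codegen_header() each forward to an iteration_ranges_* method on self.kernel.
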
torch/_inductor/codegen/triton.py: 52 changes (51 additions, 1 deletion)
@@ -25,7 +25,7 @@
 from ..ir import IRNode
 from ..metrics import is_metric_table_enabled, log_kernel_metadata
 from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
-from ..runtime.runtime_utils import do_bench_gpu, next_power_of_2
+from ..runtime.runtime_utils import do_bench_gpu, get_max_y_grid, next_power_of_2
 from ..utils import (
     cache_on_self,
     get_bounds_index_expr,
@@ -2160,6 +2160,56 @@ def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
             # lift non-reduction stores outside loop
             self.body.writeline(line)
 
+    def iteration_ranges_ranges_code(self, entry):
+        assert entry.tensor_dim is not None
+        size = self.indexing_size_str(entry.tensor_dim)
+        index_dtype = self.index_dtype
+        convert = f".to({index_dtype})" if index_dtype != "tl.int32" else ""
+        return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{convert}"
+
+    def iteration_ranges_scalar_code(self, entry, value):
+        index_dtype = self.index_dtype
+        ndim = self.triton_tensor_ndim()
+        size = [1] * ndim
+        return f"tl.full({size}, {value}, {index_dtype})"
+
+    def iteration_ranges_get_pid(self, entry):
+        assert entry.grid_dim is not None
+        key = f"tl.program_id({entry.grid_dim})"
+        # y_grid has a limit, so express it in terms of y and z in case of overflow.
+        # z grid is only exercised when max_tiles == 3 (off by default).
+        if (
+            entry.grid_dim == 1
+            and not entry.has_zdim
+            and not (isinstance(entry.numel, int) and entry.numel <= get_max_y_grid())
+        ):
+            key = f"{key} * (tl.program_id({entry.grid_dim + 1}) + 1)"
+        pid = entry.pid_cache.get(key, key)
+        if self.index_dtype != "tl.int32":
+            return f"{pid}.to({self.index_dtype})"
+        return pid
+
+    def iteration_ranges_codegen_header(self, entry, code):
+        x = entry.prefix
+        if entry.is_loop:
+            code.writeline(f"{entry.name} = {x}offset + {x}base")
+        elif entry.grid_dim is None:
+            # no need to "{x}offset = "
+            code.writeline(f"{entry.name} = {entry.ranges_code()}")
+            code.writeline(f"{x}offset = 0")
+        else:
+            if entry.tensor_dim is not None:
+                line = f"{x}offset + {entry.ranges_code()}"
+            else:
+                line = entry.scalar_code(f"{x}offset")
+            code.writelines(
+                [
+                    f"{x}offset = {entry.get_pid()} * {x.upper()}BLOCK",
+                    f"{entry.name} = {line}",
+                ]
+            )
+        code.writeline(f"{x}mask = {entry.name} < {x}numel")
+
 
 class TritonScheduling(SIMDScheduling):
     int32_type = "tl.int32"
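
For context, iteration_ranges_codegen_header is the hook that writes the index/offset/mask preamble at the top of a generated Triton kernel. Below is a small, runnable sketch of its grid-mapped (non-loop) branch, with hypothetical ToyBuffer and toy_codegen_header names, and with the program id and the [:] broadcast hard-coded rather than computed by the real get_pid/ranges_code helpers:

class ToyBuffer:
    """Stand-in for the code buffer passed to codegen_header: just collects lines."""

    def __init__(self):
        self.lines = []

    def writeline(self, line):
        self.lines.append(line)

    def writelines(self, lines):
        self.lines.extend(lines)


def toy_codegen_header(code, prefix="x"):
    # Mirrors the else-branch above: offset from the program id, per-block
    # index range, then a mask against the total number of elements.
    pid = "tl.program_id(0)"  # real code: iteration_ranges_get_pid(entry)
    ranges = f"tl.arange(0, {prefix.upper()}BLOCK)[:]"  # real code: iteration_ranges_ranges_code(entry)
    code.writelines(
        [
            f"{prefix}offset = {pid} * {prefix.upper()}BLOCK",
            f"{prefix}index = {prefix}offset + {ranges}",
        ]
    )
    code.writeline(f"{prefix}mask = {prefix}index < {prefix}numel")


buf = ToyBuffer()
toy_codegen_header(buf)
print("\n".join(buf.lines))
# xoffset = tl.program_id(0) * XBLOCK
# xindex = xoffset + tl.arange(0, XBLOCK)[:]
# xmask = xindex < xnumel

The three printed lines roughly match the xoffset/xindex/xmask preamble seen in Inductor-generated pointwise Triton kernels.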
