10 changes: 5 additions & 5 deletions test/inductor/test_loop_ordering.py
@@ -526,7 +526,7 @@ def f(x):
"triton.unique_kernel_names": True,
"loop_ordering_after_fusion": True,
"triton.max_tiles": 3,
"test_configs.global_tiling_analysis": True,
"triton.coalesce_tiling_analysis": True,
}
)
@instantiate_parametrized_tests
@@ -798,13 +798,14 @@ def fn(nodes):
# coalesces twice as many bytes as the first dimension if not downcasted;
# if downcasted, the byte counts would match because of the larger dtype size
+ # we also weight writes x2, which gives the 3x and 1.5x ratios below
cont_reads = coalesce_analysis.coalesced_by_var[i_vars[1]]
t_reads = coalesce_analysis.coalesced_by_var[i_vars[0]]

if not downcast_transposed_v:
- self.assertEqual(cont_reads, t_reads * 2)
+ self.assertEqual(cont_reads, t_reads * 3)
else:
- self.assertEqual(cont_reads, t_reads)
+ self.assertEqual(cont_reads, t_reads * 1.5)

return nodes
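
For intuition, the new expected ratios fall out of simple byte accounting. A minimal sketch (the fp32/fp16 itemsizes are illustrative assumptions, not read from the test):

def score(read_bytes=0, write_bytes=0):
    # mirrors the x2 write weighting added in analyze_memory_coalescing
    return read_bytes + 2 * write_bytes

# not downcasted: all fp32; the contiguous var coalesces a read and the
# write, the transposed var only its read
assert score(read_bytes=4, write_bytes=4) == 3 * score(read_bytes=4)
# downcasted: the contiguous read/write are fp16 while the transposed
# read stays fp32
assert score(read_bytes=2, write_bytes=2) == 1.5 * score(read_bytes=4)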

@@ -908,8 +909,7 @@ def forward(permute):
{
"triton.unique_kernel_names": True,
"loop_ordering_after_fusion": True,
"test_configs.global_tiling_analysis": True,
"triton.max_tiles": 3,
"triton.coalesce_tiling_analysis": True,
}
)
@instantiate_parametrized_tests
3 changes: 3 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -14133,6 +14133,8 @@ def f(x, mask):
# it does not move the tensor constructor to cuda and keeps it on CPU.
self.assertFalse("empty_strided_cuda(()" in code)

+ # these loads are only uncoalesced when coalesce tiling analysis is off :)
+ @config.patch("triton.coalesce_tiling_analysis", False)
@config.patch("triton.use_block_ptr", False)
def test_evict_last_non_coalesced_loads(self):
@torch.compile
@@ -14183,6 +14185,7 @@ def f(a, b):
)

@config.patch("triton.use_block_ptr", True)
@config.patch("triton.coalesce_tiling_analysis", False)
def test_evict_last_non_coalesced_loads_block_ptr(self):
@torch.compile
def f(a, b):
2 changes: 2 additions & 0 deletions test/inductor/test_triton_cpu_backend.py
@@ -13,10 +13,12 @@
if HAS_CPU and TRITON_HAS_CPU:

@config.patch(cpu_backend="triton")
@config.patch("triton.coalesce_tiling_analysis", False)
class SweepInputsCpuTritonTest(test_torchinductor.SweepInputsCpuTest):
pass

@config.patch(cpu_backend="triton")
@config.patch("triton.coalesce_tiling_analysis", False)
class CpuTritonTests(test_torchinductor.TestCase):
common = test_torchinductor.check_model
device = "cpu"
31 changes: 20 additions & 11 deletions torch/_inductor/codegen/simd.py
@@ -86,6 +86,11 @@
all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"])


+ def get_max_tiles(default: int = 2) -> int:
+ max_tiles = torch._inductor.config.triton.max_tiles
+ return max_tiles if max_tiles is not None else default
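
A usage sketch (editorial illustration, not part of the diff): with triton.max_tiles left at None, each call site picks its own fallback, while an explicit setting overrides both.

get_max_tiles(2)            # generic tiling paths: falls back to 2
get_max_tiles(default=3)    # coalesce-analysis paths: falls back to 3
# with torch._inductor.config.triton.max_tiles = 1, both calls return 1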


@dataclasses.dataclass
class IterationRanges:
"""
@@ -1354,7 +1359,7 @@ def codegen_node(

nodes: list[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment]

- if torch._inductor.config.test_configs.global_tiling_analysis:
+ if torch._inductor.config.triton.coalesce_tiling_analysis:
coalesce_analysis = analyze_memory_coalescing(node)
else:
coalesce_analysis = None
@@ -1993,7 +1998,7 @@ def get_nd_tilings(

# Flatten leading dimensions, assigning labels to each dim.
for node_tiling in node_tilings:
- num_leading_dims = max(0, len(node_tiling) - config.triton.max_tiles)
+ num_leading_dims = max(0, len(node_tiling) - get_max_tiles(2))
first_trailing_dim = num_leading_dims + 1
collapsed_leading_dim = sympy_product(node_tiling[:first_trailing_dim])
collapsed_splits = (collapsed_leading_dim,) + tuple(
@@ -2165,7 +2170,7 @@ def process_node_vars(
)
)

- if torch._inductor.config.triton.max_tiles == 3 and reduction_numel == 1:
+ if get_max_tiles(default=3) == 3 and reduction_numel == 1:
for vars_to_use in itertools.combinations(overlapping_iter_vars, 2):
score_split.append(
(
@@ -2187,13 +2192,16 @@

# add a slight penalty for longer tilings that don't increase the score
# much and use poor sizes
- additional_tiling_penalty = 1.025
+ bad_size_additional_tiling_penalty = 1.025
+ good_size_tiling_penalty = 1.005

def score_mod(t):
score_factor = 1.0
for tile_size in t[0].tiling.values():
if not CandidateTiling.is_good_size(tile_size):
- score_factor = score_factor / additional_tiling_penalty
+ score_factor = score_factor / bad_size_additional_tiling_penalty
+ else:
+ score_factor = score_factor / good_size_tiling_penalty

return -t[0].score * score_factor
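
A worked example of the penalty (tile sizes hypothetical): a candidate whose two tile sizes are both "good" keeps 1 / 1.005**2, about 99.0% of its raw score, while one good and one bad size keeps 1 / (1.005 * 1.025), about 97.1%. A longer tiling therefore has to beat a shorter one by a few percent of score before it is chosen.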

@@ -2204,7 +2212,7 @@ def score_mod(t):
):
# the default reduction numel == 1 is always included, so don't count it
tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0)
- if tiling_len > torch._inductor.config.triton.max_tiles:
+ if tiling_len > get_max_tiles(default=3):
perf_hint_log.info(
"Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles "
"set to %s. Consider increasing",
@@ -2289,16 +2297,17 @@ def get_tiling_and_scores(

# TODO: enable by default
if (
- torch._inductor.config.test_configs.global_tiling_analysis
+ torch._inductor.config.triton.coalesce_tiling_analysis
and coalesce_analysis
and not config.triton.prefer_nd_tiling
):
return cls.compute_tiling_strategy(
node_schedule, numel, reduction_numel, coalesce_analysis
)
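# when any of the conditions above fails, tiling falls through to the
# legacy candidate-scoring path below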

- if (
- not is_pointwise and not config.triton.tile_reductions
- ) or config.triton.max_tiles <= 1:
+ if (not is_pointwise and not config.triton.tile_reductions) or get_max_tiles(
+ default=2
+ ) <= 1:
# Emit a perf hint in case we miss an opportunity to tile a reduction.
if perf_hint_log.level <= logging.WARNING:
for node in EnableReduction.filter(node_schedule):
@@ -2333,7 +2342,7 @@ def get_tiling_and_scores(
for candidate_tiling, score in candidate_tiles.most_common()
]

- if config.triton.max_tiles >= 3 and is_pointwise:
+ if get_max_tiles(default=2) >= 3 and is_pointwise:
# Consider adding a third dimension of tiling, but only
# when a1 is a multiple of b1; otherwise, you have a lot
# of stragglers which is annoying to generate code for.
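
A quick arithmetic illustration with hypothetical sizes: a third split of a1=128 over an existing b1=32 divides evenly (128 % 32 == 0), while a1=96 over b1=64 leaves a remainder of 32, the stragglers the comment refers to.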
18 changes: 13 additions & 5 deletions torch/_inductor/config.py
@@ -1115,12 +1115,23 @@ class triton:
# Always load full blocks (rather than broadcasting inside the block)
dense_indexing = False

+ # TODO - enable by default
+ coalesce_tiling_analysis: bool = (
+ os.environ.get(
+ "TORCHINDUCTOR_COALESCE_TILING_ANALYSIS", "1" if not is_fbcode() else "0"
+ )
+ == "1"
+ )

# limit tiling dimensions
# - max_tiles=1 disables tiling
- # - max_tiles=2 is the default
+ # - max_tiles=2
# - max_tiles=3 is experimental and may have bugs
# higher values are unsupported
- max_tiles = 2

+ # We use a max of 3 if coalesce_tiling_analysis is True, and 2 otherwise.
+ # Note - coalesce_tiling_analysis does not yet apply to dynamic shapes.
+ max_tiles: Optional[int] = None
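
How the two knobs combine (an illustrative summary of the comment above and of get_max_tiles in codegen/simd.py):

# max_tiles=None, coalesce_tiling_analysis=True  -> up to 3 tiling dims
# max_tiles=None, coalesce_tiling_analysis=False -> up to 2 tiling dims
# max_tiles=N (explicit)                         -> N in either case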

# Prefer higher dimensional tilings. This simplifies indexing expressions, making
# it easier to identify block pointers.
@@ -1681,9 +1692,6 @@ class test_configs:

graphsafe_rng_func_ignores_fallback_random = False

- # TODO - temporary config before enabled by default
- global_tiling_analysis: bool = False


if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
11 changes: 10 additions & 1 deletion torch/_inductor/tiling_utils.py
@@ -621,6 +621,9 @@ class VarTiling:

@dataclasses.dataclass(frozen=True)
class CoalesceVarAnalysis:
+ # Var -> memory score - not strictly the amount of memory, because
+ # writes are weighted x2
+ # TODO: separate into a dataclass that holds mem, dtype, is_write
coalesced_by_var: dict[sympy.Expr, int]

norm_read_writes: FusedNormalizedReadsWrites
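
For intuition, on a hypothetical pointwise kernel that reads and writes one contiguous fp32 tensor, the fastest-moving var would be credited 4 (read) + 2 * 4 (weighted write) = 12 bytes per element.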
@@ -656,7 +659,10 @@ def analyze_memory_coalescing(
coalesced_by_var: dict[sympy.Symbol, int] = Counter()
uncoalesced_addrs: dict[sympy.Expr, int] = Counter()

- for memory_expr, buf_names in itertools.chain(reads.items(), writes.items()):
+ for is_read, (memory_expr, buf_names) in itertools.chain(
+ ((True, item) for item in reads.items()),
+ ((False, item) for item in writes.items()),
+ ):
# skip memory deps with indirect vars - todo: better handling
indirect_expr = bool(
memory_expr.free_symbols - norm_read_writes.var_ranges.keys()
@@ -676,6 +682,9 @@
if buf := V.graph.try_get_buffer(buf_name):
byte_multipler += buf.dtype.itemsize

+ # coalesced writes are more important than reads, so weight them x2
+ byte_multipler *= 1 if is_read else 2

if maybe_coalesced_var:
coalesced_by_var[maybe_coalesced_var] += size * byte_multipler
else: