Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions test/inductor/test_torchinductor_strided_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,64 @@ def foo(x, y, z):
# Singleton splits should be discarded.
self._assert_pointwise_ndims(triton_code, 2)

# Integration test to ensure that matched dims & strides from match_mod_div_expr
# are unsigned and signed integers respectively. This test case has the following
# index:=(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2))
# and the match below is a candidate that is invalid:
# match={
# dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16,
# dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0
# }
# This is now fixed by ensuring that that wild symbols only match integers
def test_ensure_integral_dims_and_strides(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the test case! This is a good one.

def model(data, *args):
return torch.nn.functional.unfold(data, *args)

data = torch.zeros(
[2, 3, 5, 5], dtype=torch.float16, requires_grad=True, device=self.device
)
args = [2, 1, 0, 1]
run_and_compare(
self,
model,
data,
*args,
expected_num_triton_kernels=2,
expected_num_block_pointers=4,
compile_kwargs={"fullgraph": True},
)

# Integration test to test block analysis with index expressions using
# negative strides.
# This test case has the following index:
# index_relative_to_xyr_index = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8))
# - 16*(ModularIndexing(xindex, 8, 8)) + 1911
# subexpr = -256*((xindex//64)) - (ModularIndexing(xindex, 1, 8)) - 16*(ModularIndexing(xindex, 8, 8))
# Block analysis should produce the following:
# BlockParameters(
# shape=[8, 8, 8],
# block_shape=[((XBLOCK + 63)//64), Min(8, ((XBLOCK + 7)//8)), Min(8, XBLOCK) ],
# strides=[-256, -16, -1],
# offsets=[(xoffset//64), ModularIndexing(xoffset, 8, 8), ModularIndexing(xoffset, 1, 8)]
# )
# constant_offset = 1911
def test_negative_strides(self):
def model(x, y):
# Slice in reverse order via a negative stride
return torch.flip(x, [0, 1, 2]) + y

x, y = (
self._discontiguous_tensor((8, 8, 8), device=self.device) for _ in range(2)
)
run_and_compare(
self,
model,
x,
y,
expected_num_triton_kernels=1,
expected_num_block_pointers=3,
)

@config.patch("triton.prefer_nd_tiling", True)
@config.patch("triton.max_tiles", 3)
@parametrize(
Expand Down
27 changes: 22 additions & 5 deletions torch/_inductor/codegen/block_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ class BlockPatternMatcher:
Matches block indexing expressions.
"""

_indexing_wild_signed_int = functools.partial(
sympy.Wild, properties=[lambda x: x.is_integer]
)
_indexing_wild_unsigned_int = functools.partial(
sympy.Wild, properties=[lambda x: x.is_integer and x.is_nonnegative]
)

@classmethod
def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
"""
Expand Down Expand Up @@ -63,9 +70,18 @@ def match_mod_div_block_expr(
index = cls._preprocess(index)

# Pattern match to find the strides and offset.
wild = functools.partial(sympy.Wild, exclude=[index_var])
dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)]
strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)]
wild_unsigned_int = functools.partial(
cls._indexing_wild_unsigned_int, exclude=[index_var]
)
wild_signed_int = functools.partial(
cls._indexing_wild_signed_int, exclude=[index_var]
)
dims: list[Expr] = [
wild_unsigned_int(f"dim_mod{idx}") for idx in range(num_dims)
]
strides: list[Expr] = [
wild_signed_int(f"stride_mod{idx}") for idx in range(num_dims)
]

# The first dimension's index is computed by division.
# The remaining are computed by modulo.
Expand All @@ -83,7 +99,8 @@ def match_mod_div_block_expr(
# for more details. In short, here we check that each subexpression in sympy.Add contains
# only FloorDiv or ModularIndexing expressions.
if num_dims >= 5:
stride, denom, other = sympy.symbols("stride denominator other", cls=wild)
stride = sympy.symbols("stride", cls=wild_signed_int)
denom, other = sympy.symbols("denominator other", cls=wild_unsigned_int)
mod_div_pattern = stride * ModularIndexing(index_var, denom, other)
floor_div_pattern = stride * FloorDiv(index_var, denom)
first_dim_floor_div_matched = False
Expand Down Expand Up @@ -167,7 +184,7 @@ def match_affine_block_expr(
stride.
"""
index = cls._preprocess(index)
stride = sympy.Wild("stride", exclude=[index_var])
stride = cls._indexing_wild_signed_int(name="stride", exclude=[index_var])
m = index.match(index_var * stride)
if m is None:
return None
Expand Down
Loading