Commit cf003e9

Update (base update)
[ghstack-poisoned]
jansel committed Jun 19, 2024
1 parent a0e1e20 commit cf003e9
Showing 5 changed files with 72 additions and 32 deletions.

test/inductor/test_torchinductor.py (10 additions, 6 deletions)

@@ -9821,7 +9821,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         # Inductor specializes on the (unguarded) alignment of the initial input.
         # Make sure that for different configurations, nothing breaks.
         for offset in (0, 1, 2, 3, 4):
-            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
             inp = torch.as_strided(base, (64, 64), (64, 1), offset)
             torch._dynamo.reset()
             fn_c = torch.compile(fn)
@@ -9831,8 +9831,10 @@ def fn(x: torch.Tensor) -> torch.Tensor:
             self.assertEqual(ref, res)
 
             for offset2 in (0, 1, 2, 3, 4):
-                base2 = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
-                inp2 = torch.as_strided(base, (64, 64), (64, 1), offset2)
+                base2 = torch.randn(
+                    64 * 64 + 64, dtype=torch.float32, device=self.device
+                )
+                inp2 = torch.as_strided(base2, (64, 64), (64, 1), offset2)
                 ref2 = fn(inp2)
                 res2 = fn_c(inp2)
                 self.assertEqual(ref2, res2)
@@ -9853,7 +9855,7 @@ def fail(guard):
         def fn(x: torch.Tensor) -> torch.Tensor:
             return x.sin() + x.cos()
 
-        base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+        base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
 
         inp1 = torch.as_strided(base, (32, 32), (32, 1), 4)
         inp2 = torch.as_strided(base, (64, 64), (64, 1), 4)
@@ -9898,9 +9900,11 @@ def fn(x):
             ((64, 64), (64, 1), 5),
         ):
             torch.manual_seed(42)
-            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
             torch.manual_seed(42)
-            base_ref = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+            base_ref = torch.randn(
+                64 * 64 + 64, dtype=torch.float32, device=self.device
+            )
 
             inp = torch.as_strided(base, size, stride, offset)
             inp_ref = torch.as_strided(base_ref, size, stride, offset)

torch/_inductor/codegen/common.py (1 addition)

@@ -146,6 +146,7 @@ class BackendFeature(Enum):
     MASKED_SCATTER_WITH_INDEX = auto()
     SCAN = auto()
     TUPLE_REDUCTION = auto()
+    PREFER_STORE_LOOP_ORDER = auto()
 
 
 def get_backend_features(device: Union[torch.device, str]):
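
Note: `PREFER_STORE_LOOP_ORDER` is consumed in torch/_inductor/ir.py below, where
loop reordering restricts itself to store addresses when the backend reports the
feature. A minimal sketch of how a backend might advertise it; `MyScheduling` is
hypothetical, only `BackendFeature` and the `get_backend_features` entry point
appear in this commit:

    from torch._inductor.codegen.common import BackendFeature

    class MyScheduling:  # hypothetical backend scheduling class
        @classmethod
        def get_backend_features(cls, device):
            # Ask Inductor to order loops by store (output) addresses only.
            return {BackendFeature.PREFER_STORE_LOOP_ORDER}
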

torch/_inductor/codegen/simd.py (40 additions, 7 deletions)

@@ -159,16 +159,17 @@ def cache_clear(self):
         for node in self.nodes.values():
             node.cache_clear()
 
+    def index_sym(self):
+        return sympy_index_symbol(f"{self.prefix}index")
+
     def lookup(self, divisor, length):
         """
         Lookup a given RangeTreeEntry, creating it if needed
         """
         if V.graph.sizevars.statically_known_equals(divisor * length, self.numel):
-            expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor)
+            expr = FloorDiv(self.index_sym(), divisor)
         else:
-            expr = ModularIndexing(
-                sympy_index_symbol(f"{self.prefix}index"), divisor, length
-            )
+            expr = ModularIndexing(self.index_sym(), divisor, length)
 
         if expr not in self.nodes:
             node = IterationRangesEntry(
@@ -384,6 +385,13 @@ def initialize_range_tree(self, pid_cache):
             )
         )
 
+    def finalize_indexing(self, indices: Sequence[sympy.Expr]):
+        """
+        Hook called right before codegen with every index that will be
+        used in the fused kernel.
+        """
+        pass
+
     def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
         prior = self.inside_reduction
         self.inside_reduction = False
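
`finalize_indexing` is deliberately a no-op here: it is a hook (invoked from the
scheduler change further down) that lets a kernel see every index expression of
the fused kernel before any code is emitted. A hedged sketch of an override, with
the subclass and attribute names invented for illustration:

    class MyKernel(SIMDKernel):  # hypothetical backend kernel
        def finalize_indexing(self, indices):
            # All indices are known up front, so global decisions such as
            # picking a loop order can be made before per-node codegen.
            self._planned_indices = list(indices)
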
@@ -431,7 +439,16 @@ def combine_modular_indexing_pairs(self, index):
         if (tree_node := self.range_tree_nodes.get(x)) is None:
             return index
         new_index = sympy_subs(index, {x: tree_node.expr})
-        return V.graph.sizevars.combine_modular_indexing_pairs(new_index)
+        new_index = V.graph.sizevars.combine_modular_indexing_pairs(new_index)
+        # the index now contains xindex/etc, which is nonstandard, fix it up
+        return sympy_subs(
+            new_index,
+            {
+                tree_node.root.index_sym(): tree_node.root.lookup(
+                    sympy.Integer(1), tree_node.root.numel
+                ).symbol()
+            },
+        )
 
     def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
         if expand_res := V.graph.sizevars.expand_floor_div(index):
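
Previously this function returned the combined expression still written in terms
of the raw `xindex`-style symbol, which the rest of codegen does not expect; the
added `sympy_subs` rewrites it in terms of the canonical variable produced by
`lookup(1, numel)`. Conceptually, as a standalone sympy illustration (not
Inductor code):

    import sympy

    xindex, x0 = sympy.symbols("xindex x0", integer=True, nonnegative=True)
    combined = sympy.floor(xindex / 8)  # simplified index, still using xindex
    print(combined.subs({xindex: x0}))  # floor(x0/8): canonical variable restored
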
@@ -895,7 +912,7 @@ def codegen_body(self):
         pass
 
     def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
-        raise NotImplementedError
+        pass
 
 
 class SIMDScheduling(BaseScheduling):
@@ -1371,10 +1388,26 @@ def current_reduction_nodes(nodes):
         with kernel:
             stack = contextlib.ExitStack()
             kernel.set_last_usage(current_reduction_nodes(node_schedule))
+            all_indexing = {}
 
+            # First pass to collect indexing and decide inplace updates
             for node in node_schedule:
-                if node not in (EnableReduction, DisableReduction):
+                if node is DisableReduction:
+                    stack.enter_context(kernel.disable_reduction())
+                elif node is EnableReduction:
+                    stack.close()
+                else:
                     node.decide_inplace_update()
+                    index_vars = kernel.split_and_set_ranges(node.get_ranges())
+                    all_indexing.update(
+                        dict.fromkeys(
+                            node._body.indexing_from_args(index_vars).values()
+                        )
+                    )
+
+            kernel.finalize_indexing(all_indexing.keys())
 
+            # Second pass to do codegen
             for i, node in enumerate(node_schedule):
                 if node is DisableReduction:
                     stack.enter_context(kernel.disable_reduction())
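
The first pass accumulates into `all_indexing` with `dict.fromkeys` rather than a
`set`, which drops duplicate index expressions while keeping insertion order
deterministic; a plain-Python illustration:

    exprs = ["x0 + 8*x1", "x2", "x0 + 8*x1"]  # duplicates possible across nodes
    ordered_unique = dict.fromkeys(exprs)
    print(list(ordered_unique.keys()))  # ['x0 + 8*x1', 'x2']
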

torch/_inductor/ir.py (9 additions, 17 deletions)

@@ -3416,16 +3416,9 @@ def simplify_and_reorder(
         ]
         index_formulas += extra_indexing_expr
 
-        reads_bufs = [
-            V.graph.name_to_buffer[reads_name]
-            if reads_name in V.graph.name_to_buffer.keys()
-            else None
-            for reads_name in body.reads_name2expr.keys()
-        ]
-        memory_addrs = [
-            *body.reads_name2expr.values(),
-            *body.writes_name2expr.values(),
-        ]
+        memory_addrs = [*body.writes_name2expr.values()]
+        if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER):
+            memory_addrs.extend(body.reads_name2expr.values())
 
         def simplify_and_reorder(x_vars, support_vars, sizes):
             sizes, reindex0, reindex1 = self._apply_loop_reordering(
@@ -3438,10 +3431,6 @@ def simplify_and_reorder(x_vars, support_vars, sizes):
                 sizes,
                 index_prevent_reordering(index_formulas, x_vars, sizes),
             )
             x_vars = prune(x_vars)
-            # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas)
-            # x_vars = prune(x_vars)
-            # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs)
             reindex = fuse_reindexing(reindex1, reindex2)
             return sizes, reindex, reindex1

@@ -6404,15 +6393,18 @@ def get_index(self, name):
         assert self.indexing is not None
         return self.indexing[name]
 
-    def __call__(self, *indices):
-        index = list(itertools.chain.from_iterable(indices))
+    def indexing_from_args(self, indices):
+        index = [*itertools.chain.from_iterable(indices)]
         assert len(index) == len(self.var_ranges), (index, self.var_ranges)
         assert all(v not in self.var_ranges for v in index)
         replacements = dict(zip(self.var_ranges.keys(), index))
-        self.indexing = {
+        return {
             name: sympy_subs(expr, replacements)
             for name, expr in self.indexing_exprs.items()
         }
+
+    def __call__(self, *indices):
+        self.indexing = self.indexing_from_args(indices)
         result = self.root_block()
         self.indexing = None
         return result
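
Extracting `indexing_from_args` out of `__call__` turns the index substitution
into a side-effect-free helper, which is what allows the scheduler's new first
pass (simd.py above) to call `node._body.indexing_from_args(index_vars)` without
executing the body or mutating `self.indexing`. A toy stand-in for the pattern;
`ToyBody` and its string templates are invented for illustration, this is not
Inductor's real class:

    import itertools

    class ToyBody:  # illustrative only
        def __init__(self):
            self.var_ranges = {"x0": 64, "x1": 64}            # var -> extent
            self.indexing_exprs = {"load": "{x0} + 64*{x1}"}  # name -> template
            self.indexing = None

        def indexing_from_args(self, indices):
            # Pure: build and return the substituted indexing dict.
            index = [*itertools.chain.from_iterable(indices)]
            assert len(index) == len(self.var_ranges)
            repl = dict(zip(self.var_ranges.keys(), index))
            return {k: v.format(**repl) for k, v in self.indexing_exprs.items()}

        def __call__(self, *indices):
            # Stateful wrapper keeps the old install-run-reset behavior.
            self.indexing = self.indexing_from_args(indices)
            result = self.indexing["load"]  # stand-in for self.root_block()
            self.indexing = None
            return result

    print(ToyBody()(["i", "j"]))  # i + 64*j
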

torch/_inductor/sizevars.py (12 additions, 2 deletions)

@@ -411,8 +411,18 @@ def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
 
     def evaluate_min(self, left: Expr, right: Expr) -> Expr:
         """return the smaller of left and right, and guard on that choice"""
-        lv = self.size_hint(left)
-        rv = self.size_hint(right)
+        try:
+            lv = self.size_hint(left)
+            rv = self.size_hint(right)
+        except TypeError:  # unbacked symints
+            gcd = sympy.gcd(left, right)
+            if left == gcd:  # handle `min(10*u0, u0)` etc
+                return left
+            if right == gcd:
+                return right
+            raise TypeError(
+                f"evaluate_min({left}, {right}) with unbacked symints"
+            ) from None
         if lv <= rv:
             self.guard_leq(left, right)
             return left
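
With unbacked symints there is no size hint to compare, but the gcd fallback
covers the common case where one operand divides the other, per the inline
comment about `min(10*u0, u0)`. A standalone sympy check of that identity:

    import sympy

    u0 = sympy.Symbol("u0", positive=True, integer=True)
    left, right = 10 * u0, u0
    print(sympy.gcd(left, right) == right)  # True: u0 divides 10*u0,
                                            # so min(10*u0, u0) == u0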
