torch/_inductor/dependencies.py (10 additions, 10 deletions)
@@ -212,22 +212,22 @@ def reads_and_writes(self):
 class _RecordLoadStoreInner(V.MockHandler):  # type: ignore[name-defined]
     def __init__(self, var_ranges: VarRanges, normalize: bool):
         super().__init__()
-        self._reads: Set[MemoryDep] = set()
+        self._reads: Set[Dep] = set()
         self._writes: Set[MemoryDep] = set()
         self._index_exprs: Set[IndexExprDep] = set()
         self._var_ranges: VarRanges = var_ranges
         self._normalize: bool = normalize
 
     def canonicalize(
         self, index: sympy.Expr
-    ) -> Tuple[sympy.Expr, Tuple[sympy.Expr, ...]]:
+    ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
         if not self._normalize:
             sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
             var_names = tuple(
                 k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
             )
             sizes = tuple(v for v in sizes if v != 1)
-            return index, var_names, sizes  # type: ignore[return-value]
+            return index, var_names, sizes
 
         # Try to further simplify the indexes even if simplify_loops didn't
         # convert it to the simplest form because of the interference from
@@ -240,7 +240,7 @@ def canonicalize(
             # if k in free_symbols
         }
         index_vars = [*var_ranges.keys()]
-        sizes = [*var_ranges.values()]  # type: ignore[assignment]
+        sizes = tuple(var_ranges.values())
         new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
             index_vars,
             sizes,
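The list-to-tuple change above is the whole fix for the old `ignore[assignment]`: the removed line bound a `list` to a name the checker treats as tuple-typed, and building the `tuple` directly satisfies it. A hypothetical reduction of that mypy rule (illustrative only, not the PyTorch source):

```python
from typing import Tuple

# Once a name has a tuple type, re-binding it to a list is flagged
# with the [assignment] error code that the removed ignore suppressed.
sizes: Tuple[int, ...] = (1, 2, 3)
sizes = [4, 5, 6]        # mypy: error: Incompatible types in assignment
sizes = tuple([4, 5, 6])  # OK: builds the expected tuple directly
```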
@@ -261,25 +261,25 @@
             # downstream users won't. Normalize this away.
             new_vars.pop()
             new_sizes.pop()
-        return index, tuple(new_vars), tuple(new_sizes)  # type: ignore[return-value]
+        return index, tuple(new_vars), tuple(new_sizes)
 
     def load(self, name: str, index: sympy.Expr) -> str:
-        self._reads.add(MemoryDep(name, *self.canonicalize(index)))  # type: ignore[call-arg]
+        self._reads.add(MemoryDep(name, *self.canonicalize(index)))
         return f"load({name}, {sympy_str(index)})"
 
     def load_seed(self, name: str, index: int):
         assert isinstance(index, int)
         return self.load(name, sympy.Integer(index))
 
     def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
-        self._writes.add(MemoryDep(name, *self.canonicalize(index)))  # type: ignore[call-arg]
+        self._writes.add(MemoryDep(name, *self.canonicalize(index)))
         return f"store({name}, {sympy_str(index)}, {value}, {mode})"
 
     def store_reduction(self, name: str, index, value) -> str:
         return self.store(name, index, f"store_reduction({value})")
 
     def index_expr(self, index: sympy.Expr, dtype) -> str:
-        self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))  # type: ignore[call-arg]
+        self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
         return f"index_expr({sympy_str(index)}, {dtype})"
 
     def bucketize(
@@ -290,7 +290,7 @@ def bucketize(
         indexing_dtype: torch.dtype,
         right: bool,
     ):
-        self._reads.add(StarDep(offsets_name))  # type: ignore[arg-type]
+        self._reads.add(StarDep(offsets_name))
        return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"


@@ -376,7 +376,7 @@ def extract_read_writes(
     )
 
 
-def extract_input_node_reduction_ranges(  # noqa: F722
+def extract_input_node_reduction_ranges(
     input_node: "torch._inductor.ir.TensorBox",
 ) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
     """
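The thread running through the remaining hunks: `canonicalize` always returned three values, but its old annotation promised a 2-tuple, so mypy flagged both the `return` statements (`return-value`) and every `MemoryDep(name, *self.canonicalize(index))` splat (`call-arg`). Declaring the real 3-tuple clears all of those ignores at once. A hypothetical reduction of the mismatch (not the PyTorch source):

```python
from typing import Tuple
import sympy


def canonicalize_old(index: sympy.Expr) -> Tuple[sympy.Expr, Tuple[sympy.Expr, ...]]:
    var_names = (sympy.Symbol("x0"),)
    sizes = (sympy.Integer(8),)
    # Three items where the annotation says two: mypy reports
    # [return-value] here, and any `Consumer(name, *canonicalize_old(i))`
    # splat is mis-counted downstream, forcing the [call-arg] ignores.
    return index, var_names, sizes  # type: ignore[return-value]


def canonicalize_new(
    index: sympy.Expr,
) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
    var_names = (sympy.Symbol("x0"),)
    sizes = (sympy.Integer(8),)
    return index, var_names, sizes  # checks cleanly, no ignores needed
```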