Inductor respects strides for custom ops by default (#126986)
Previously, the default was that Inductor did not respect strides for
any op (builtin or custom) unless the op had the
"needs_fixed_stride_order" tag on it. This PR changes the defaults so that:

- Inductor does not respect strides for builtin ops. To change this
  behavior, add the "needs_fixed_stride_order" tag to the op.
- Inductor does respect strides for custom ops. To change this behavior,
  add the "does_not_need_fixed_stride_order" tag to the op (see the
  sketch after this list).
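As a rough illustration of the second bullet, a minimal sketch assuming a
hypothetical "mylib::foo" op (the registration pattern mirrors the new test
added in test_torchinductor.py below):

import torch

# Hypothetical custom op, for illustration only. Custom ops now default to
# the fixed-stride-order behavior; this tag opts the op back out of it.
lib = torch.library.Library("mylib", "FRAGMENT")
lib.define(
    "foo(Tensor x) -> Tensor",
    tags=[torch.Tag.does_not_need_fixed_stride_order],
)
lib.impl("foo", lambda x: x.clone(), "CompositeExplicitAutograd")
lib.impl("foo", lambda x: x.clone(), "Meta")

Builtin ops keep the opposite default, so an individual builtin op would opt
in by carrying the "needs_fixed_stride_order" tag on its native_functions.yaml
entry.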

Test Plan:
- new tests

Pull Request resolved: #126986
Approved by: https://github.com/ezyang, https://github.com/albanD
zou3519 authored and pytorchmergebot committed May 24, 2024
1 parent f14cdc5 commit dd64ca2
Showing 4 changed files with 84 additions and 4 deletions.
aten/src/ATen/native/tags.yaml (10 additions, 0 deletions)
@@ -46,6 +46,16 @@
   desc: |
     This tag indicates that the operator should be passed Tensors following
     the same stride permutation as observed in eager when compiled in inductor.
+    The default for custom ops (i.e. not torch._library.utils.is_builtin)
+    is that they do need a fixed stride order; add `does_not_need_fixed_stride_order`
+    to change the behavior.
+    The default for builtin ops is that they do not need a fixed stride order;
+    add `needs_fixed_stride_order` to change the behavior.
+- tag: does_not_need_fixed_stride_order
+  desc: |
+    This tag indicates that the operator doesn't need to be passed Tensors following
+    the same stride permutation as observed in eager when compiled in inductor.
+    See `needs_fixed_stride_order` for more details.

 # NOTE [Core ATen Ops]
 - tag: core
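The desc text above describes per-op defaults. A rough Python restatement,
for illustration only (not part of the diff; it mirrors the check added to
torch/_inductor/graph.py later in this commit, and assumes op is an
OpOverload):

import torch
import torch._library.utils

def inductor_keeps_eager_strides(op) -> bool:
    # Both tags present: pessimistically keep the eager stride order.
    if (
        torch.Tag.needs_fixed_stride_order in op.tags
        and torch.Tag.does_not_need_fixed_stride_order in op.tags
    ):
        return True
    # Builtin ops default to "no"; needs_fixed_stride_order opts them in.
    if torch._library.utils.is_builtin(op):
        return torch.Tag.needs_fixed_stride_order in op.tags
    # Custom ops default to "yes"; does_not_need_fixed_stride_order opts them out.
    return torch.Tag.does_not_need_fixed_stride_order not in op.tags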
test/inductor/test_torchinductor.py (57 additions, 2 deletions)
@@ -9731,7 +9731,6 @@ def bar_meta(x):
             bar_cuda,
             bar_xpu,
             bar_meta,
-            tags=[torch._C.Tag.needs_fixed_stride_order],
         )

         def fn(x):
@@ -9794,13 +9793,69 @@ def baz_meta(x):
             baz_cuda,
             baz_xpu,
             baz_meta,
-            tags=[torch._C.Tag.needs_fixed_stride_order],
         )

         with torch.no_grad():
             net = torch.compile(model)
             out = net(input_t)

+    @requires_gpu()
+    @config.patch(implicit_fallbacks=True)
+    def test_needs_fixed_stride_order(self):
+        with torch.library._scoped_library("prims", "FRAGMENT") as prims_lib:
+            with torch.library._scoped_library("custom", "FRAGMENT") as custom_lib:
+                strides = []
+
+                def foo_impl(x):
+                    strides.append(x.stride())
+                    return x.clone()
+
+                def foo_meta(x):
+                    return x.clone()
+
+                all_ops = []
+                for (
+                    needs_fixed_stride_order,
+                    does_not_need_fixed_stride_order,
+                ) in itertools.product([True, False], [True, False]):
+                    tags = []
+                    if needs_fixed_stride_order:
+                        tags.append(torch.Tag.needs_fixed_stride_order)
+                    if does_not_need_fixed_stride_order:
+                        tags.append(torch.Tag.does_not_need_fixed_stride_order)
+                    name = f"foo_{int(needs_fixed_stride_order)}{int(does_not_need_fixed_stride_order)}"
+                    for ns, lib in {"custom": custom_lib, "prims": prims_lib}.items():
+                        all_ops.append(ns + "::" + name)
+                        lib.define(f"{name}(Tensor x) -> Tensor", tags=tags)
+                        lib.impl(name, foo_impl, "CompositeExplicitAutograd")
+                        lib.impl(name, foo_meta, "Meta")
+
+                assert len(all_ops) == 8
+                expect_contig_strides = {
+                    "custom::foo_01",
+                    "prims::foo_00",
+                    "prims::foo_01",
+                }
+                print(all_ops)
+
+                for qualname in all_ops:
+                    ns, name = qualname.split("::")
+                    op = getattr(getattr(torch.ops, ns), name)
+
+                    @torch.compile(fullgraph=True)
+                    def f(x):
+                        y = x.t().contiguous().t()
+                        y = y.sin()
+                        return op(y)
+
+                    x = torch.randn(24, 24, device=self.device)
+                    f(x)
+                    stride = strides[-1]
+                    if qualname in expect_contig_strides:
+                        self.assertEqual(stride, (24, 1))
+                    else:
+                        self.assertEqual(stride, (1, 24))
+
     def test_buffer_use_after_remove(self):
         # https://github.com/pytorch/pytorch/issues/102857

torch/_inductor/graph.py (16 additions, 1 deletion)
@@ -930,7 +930,22 @@ def get_custom_op_layout_constraints(target, args, kwargs):
             # which run through implicit fallback must constrain their
             # arguments' fx strides
             layout_constraint = None
-            if torch._C.Tag.needs_fixed_stride_order in target.tags:
+
+            def needs_fixed_stride_order(target):
+                if (
+                    torch._C.Tag.needs_fixed_stride_order in target.tags
+                    and torch._C.Tag.does_not_need_fixed_stride_order in target.tags
+                ):
+                    # If both tags were specified, pessimistically assume that we do need it.
+                    return True
+                if torch._library.utils.is_builtin(target):
+                    return torch._C.Tag.needs_fixed_stride_order in target.tags
+                else:
+                    return (
+                        torch._C.Tag.does_not_need_fixed_stride_order not in target.tags
+                    )
+
+            if needs_fixed_stride_order(target):
                 # We have to set the current args because call_function will immediately
                 # evaluate this lowering after creating the fallback, without evaluating
                 # the layout constraint
torch/_library/custom_ops.py (1 addition, 1 deletion)
@@ -453,7 +453,7 @@ def _register_to_dispatcher(self) -> None:

         lib.define(
             schema_str,
-            tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order],
+            tags=[_C.Tag.pt2_compliant_tag],
         )
         self._opoverload = _library.utils.lookup_op(self._qualname)

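For context on the custom_ops.py change above, a rough sketch of the
resulting behavior under the new default; the "mylib::record_stride" op is
hypothetical and not part of the diff. An op defined via
torch.library.custom_op no longer carries an explicit
needs_fixed_stride_order tag, yet under torch.compile Inductor should still
hand it inputs with the strides observed in eager mode:

import torch

@torch.library.custom_op("mylib::record_stride", mutates_args=())
def record_stride(x: torch.Tensor) -> torch.Tensor:
    # Runs as a fallback under torch.compile; with the new default the input
    # should arrive with the same strides it would have in eager mode.
    print(x.stride())
    return x.clone()

@record_stride.register_fake
def _(x):
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x):
    y = x.t().contiguous().t()  # transposed layout, strides (1, 16)
    return record_stride(y.sin())

f(torch.randn(16, 16))  # expected to print (1, 16) rather than (16, 1)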
