260 changes: 160 additions & 100 deletions aten/src/ATen/FunctionalInverses.cpp

Large diffs are not rendered by default.
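Since the 260-line FunctionalInverses.cpp diff is not rendered, the sketch below illustrates the general shape of the change for a view op that has no scatter counterpart (transpose). It is a hand-written approximation against the `InverseReturnMode` enum added to `FunctionalInverses.h` below, not code copied from this PR; the free-function name, argument order, and include path are assumptions.

```cpp
#include <ATen/ATen.h>
#include <ATen/FunctionalInverses.h>  // assumed location of the generated InverseReturnMode enum

// Inverse of y = x.transpose(dim0, dim1): transposing the mutated view back
// recovers an updated base. transpose has no *_scatter counterpart, so
// ViewOrScatterInverse falls back to the view path, just like AlwaysView.
at::Tensor transpose_inverse_sketch(
    const at::Tensor& /*base*/,  // unused: transpose's inverse only needs the mutated view
    const at::Tensor& mutated_view,
    at::functionalization::InverseReturnMode inverse_return_mode,
    int64_t dim0,
    int64_t dim1) {
  if (inverse_return_mode == at::functionalization::InverseReturnMode::NeverView) {
    // Materialize a fresh tensor rather than aliasing mutated_view.
    return at::transpose_copy(mutated_view, dim0, dim1);
  }
  // AlwaysView and ViewOrScatterInverse both return a genuine view here.
  return mutated_view.transpose(dim0, dim1);
}
```

For ops like this the new enum should reduce to the old `reapply_views` bool; the behavioral change in this PR is concentrated in ops that do have scatter inverses, as the updated test expectations below show.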

2 changes: 1 addition & 1 deletion aten/src/ATen/native/TestOps.cpp
@@ -116,7 +116,7 @@ Tensor _test_check_tensor(const Tensor& self) {
namespace at::functionalization {

// view_copy ops must have a functional inverse registered
Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) {
Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
TORCH_INTERNAL_ASSERT(false,
"Attempted to call _test_autograd_multiple_dispatch_view_copy_inverse() during the functionalization pass. ",
"This function is for testing only and should never be called.");
11 changes: 11 additions & 0 deletions aten/src/ATen/templates/FunctionalInverses.h
@@ -7,6 +7,17 @@
namespace at {
namespace functionalization {

enum class InverseReturnMode {
/// Specifies that functional inverses should always return a view.
AlwaysView,
/// Specifies that functional inverses should always return a non-view / copy.
NeverView,
/// Specifies that functional inverses should return a view unless a (copying) scatter
/// inverse exists, in which case that will be used instead.
/// This avoids as_strided() calls that can be difficult for subclasses to handle.
ViewOrScatterInverse,
};
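To make the third mode concrete, here is a minimal sketch of how an inverse for an op that does have a scatter counterpart (diagonal) could branch on the enum. It is an illustration only, written under the assumption that in `AlwaysView` mode the mutated view still aliases the base's storage (which is what makes the `as_strided()` reconstruction possible); it is not the generated code from this PR, and the function name and argument list are assumptions.

```cpp
#include <ATen/ATen.h>
#include <ATen/FunctionalInverses.h>  // assumed location of the InverseReturnMode enum

// Hypothetical inverse for y = x.diagonal(offset, dim1, dim2).
// ViewOrScatterInverse prefers the copying diagonal_scatter() inverse,
// which is exactly how it avoids the as_strided() call below.
inline at::Tensor diagonal_inverse_sketch(
    const at::Tensor& base,
    const at::Tensor& mutated_view,
    at::functionalization::InverseReturnMode inverse_return_mode,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  if (inverse_return_mode == at::functionalization::InverseReturnMode::AlwaysView) {
    // Assumes mutated_view still aliases base's storage, so re-expanding it
    // with base's geometry yields a view of the updated base.
    return mutated_view.as_strided(
        base.sizes(), base.strides(), base.storage_offset());
  }
  // NeverView and ViewOrScatterInverse: write the mutated diagonal back into
  // a copy of base via the dedicated scatter op.
  return at::diagonal_scatter(base, mutated_view, offset, dim1, dim2);
}
```

Combined with the codegen change further down, which maps `reapply_views=true` to `ViewOrScatterInverse`, this is why the updated test expectations still contain `diagonal_scatter` (and friends) even when views are reapplied.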

struct FunctionalInverses {

${view_inverse_declarations}
247 changes: 247 additions & 0 deletions test/test_functionalization.py
@@ -475,6 +475,21 @@ def forward(self, arg0_1):
as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1)
copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None
return as_strided_scatter
""")

# NB: even with reapply_views=True, we expect to see scatter op
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, arg0_1):
as_strided = torch.ops.aten.as_strided.default(arg0_1, [2], [2], 1)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None
as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1)
copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None
return as_strided_scatter
""")

def test_tensor_list_composite(self):
@@ -584,6 +599,22 @@ def forward(self, arg0_1):
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter)
copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None
return diagonal_scatter
""")

# NB: even with reapply_views=True, we expect to see scatter op
reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
diagonal = torch.ops.aten.diagonal.default(arg0_1)
add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None
diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter)
copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None
return diagonal_scatter
""")

def test_channels_last_contiguous(self):
@@ -635,6 +666,143 @@ def forward(self, arg0_1):
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
return diagonal_copy_1
""") # noqa: B950

# NB: even with reapply_views=True, we expect to see scatter op
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
split = torch.ops.aten.split.Tensor(arg0_1, 2)
getitem = split[0]
getitem_1 = split[1]; split = None
diagonal = torch.ops.aten.diagonal.default(getitem_1); getitem_1 = None
add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None
split_1 = torch.ops.aten.split.Tensor(arg0_1, 2)
getitem_2 = split_1[0]
getitem_3 = split_1[1]; split_1 = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None
slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None
split_2 = torch.ops.aten.split.Tensor(slice_scatter, 2)
getitem_4 = split_2[0]
getitem_5 = split_2[1]; split_2 = None
diagonal_1 = torch.ops.aten.diagonal.default(getitem_5); getitem_5 = None
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
return diagonal_1
""") # noqa: B950

def test_split_with_sizes(self):
def f(x):
# test: view ops that return multiple tensors (split_with_sizes)
tmp = torch.ones(2)
y1, y2 = x.split_with_sizes([2, 2])
y3 = y1.diagonal()
y3.add_(tmp)
z = x * x
return y3
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
split_with_sizes_copy = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
getitem = split_with_sizes_copy[0]
getitem_1 = split_with_sizes_copy[1]; split_with_sizes_copy = None
diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem); getitem = None
add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None
split_with_sizes_copy_1 = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2])
getitem_2 = split_with_sizes_copy_1[0]
getitem_3 = split_with_sizes_copy_1[1]; split_with_sizes_copy_1 = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None
slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None
split_with_sizes_copy_2 = torch.ops.aten.split_with_sizes_copy.default(slice_scatter, [2, 2])
getitem_4 = split_with_sizes_copy_2[0]
getitem_5 = split_with_sizes_copy_2[1]; split_with_sizes_copy_2 = None
diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_4); getitem_4 = None
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
return diagonal_copy_1
""") # noqa: B950

# NB: even with reapply_views=True, we expect to see scatter op
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False)
split_with_sizes = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
getitem = split_with_sizes[0]
getitem_1 = split_with_sizes[1]; split_with_sizes = None
diagonal = torch.ops.aten.diagonal.default(getitem); getitem = None
add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2])
getitem_2 = split_with_sizes_1[0]
getitem_3 = split_with_sizes_1[1]; split_with_sizes_1 = None
diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None
slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None
split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter, [2, 2])
getitem_4 = split_with_sizes_2[0]
getitem_5 = split_with_sizes_2[1]; split_with_sizes_2 = None
diagonal_1 = torch.ops.aten.diagonal.default(getitem_4); getitem_4 = None
mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter)
copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None
return diagonal_1
""") # noqa: B950

def test_slice(self):
def f(x):
tmp = torch.ones(4)
x.transpose_(1, 0)
y = x[0:2]
y.add_(tmp)
return x
self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
slice_copy = torch.ops.aten.slice_copy.Tensor(transpose_copy, 0, 0, 2); transpose_copy = None
add = torch.ops.aten.add.Tensor(slice_copy, ones); slice_copy = ones = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(transpose_copy_1, add, 0, 0, 2); transpose_copy_1 = add = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(slice_scatter, 1, 0); slice_scatter = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2); transpose_copy_3 = None
transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
return transpose_copy_4
""") # noqa: B950

# NB: even with reapply_views=True, we expect to see scatter op
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
slice_1 = torch.ops.aten.slice.Tensor(transpose, 0, 0, 2); transpose = None
add = torch.ops.aten.add.Tensor(slice_1, ones); slice_1 = ones = None
transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None
slice_scatter = torch.ops.aten.slice_scatter.default(transpose_1, add, 0, 0, 2); transpose_1 = add = None
transpose_2 = torch.ops.aten.transpose.int(slice_scatter, 1, 0); slice_scatter = None
transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2); transpose_3 = None
transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None
return transpose_4
""") # noqa: B950

def test_view_inplace(self):
@@ -663,6 +831,82 @@ def forward(self, arg0_1):
select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = None
transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
return transpose_copy_4
""") # noqa: B950

# NB: even with reapply_views=True, we expect to see scatter op
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
select = torch.ops.aten.select.int(transpose, 0, 0); transpose = None
add = torch.ops.aten.add.Tensor(select, ones); select = ones = None
transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None
select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None
transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None
transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
select_1 = torch.ops.aten.select.int(transpose_3, 0, 0); transpose_3 = None
transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None
return transpose_4
""") # noqa: B950

def test_unbind(self):
def f(x):
# test: view + inplace op (transpose_)
tmp = torch.ones(4)
x.transpose_(1, 0)
y, _ = x.unbind(0)
y.add_(tmp)
return x
self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True)
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0)
unbind_copy = torch.ops.aten.unbind_copy.int(transpose_copy); transpose_copy = None
getitem = unbind_copy[0]
getitem_1 = unbind_copy[1]; unbind_copy = None
add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None
select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None
transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None
transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0)
unbind_copy_1 = torch.ops.aten.unbind_copy.int(transpose_copy_3); transpose_copy_3 = None
getitem_2 = unbind_copy_1[0]
getitem_3 = unbind_copy_1[1]; unbind_copy_1 = None
transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None
return transpose_copy_4
""") # noqa: B950

# NB: even with reapply_views=True, we expect to see scatter op
reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False)
self.assertExpectedInline(reinplaced_logs, """\
def forward(self, arg0_1):
ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False)
transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0)
unbind = torch.ops.aten.unbind.int(transpose); transpose = None
getitem = unbind[0]
getitem_1 = unbind[1]; unbind = None
add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None
select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None
transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None
transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0)
unbind_1 = torch.ops.aten.unbind.int(transpose_3); transpose_3 = None
getitem_2 = unbind_1[0]
getitem_3 = unbind_1[1]; unbind_1 = None
transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None
return transpose_4
""") # noqa: B950

def test_optional_tensor_list(self):
@@ -1677,7 +1921,10 @@ def forward(self, arg0_1):
"test_diagonal_mutated_input",
"test_everything",
"test_fill_",
"test_slice",
"test_split",
"test_split_with_sizes",
"test_unbind",
"test_view_clone_view_inplace",
"test_view_inplace",
])
25 changes: 22 additions & 3 deletions torchgen/api/functionalization.py
@@ -2,6 +2,7 @@

from torchgen.api import dispatcher
from torchgen.api.types import (
BaseCppType,
BaseCType,
Binding,
boolT,
@@ -69,6 +70,20 @@
default=None,
)

InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
inverse_return_mode_binding = Binding(
name="inverse_return_mode",
nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)),
argument=Argument(
name="inverse_return_mode",
# NB: not actually a bool but it doesn't matter because this isn't used
type=BaseType(BaseTy.bool),
default=None,
annotation=None,
),
default=None,
)


# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
@@ -115,7 +130,11 @@ def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding
non_self_value_bindings = [
dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
]
all_bindings = [reapply_views_binding] + non_self_value_bindings

all_bindings = [
inverse_return_mode_binding if is_reverse else reapply_views_binding
]
all_bindings.extend(non_self_value_bindings)
return all_bindings


@@ -165,12 +184,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
return [
base_binding,
mutated_view_binding,
reapply_views_binding,
inverse_return_mode_binding,
index_binding,
] + non_self_bindings
else:
return [
base_binding,
mutated_view_binding,
reapply_views_binding,
inverse_return_mode_binding,
] + non_self_bindings
3 changes: 2 additions & 1 deletion torchgen/api/types/signatures.py
@@ -316,7 +316,8 @@ def captures(self) -> List[Expr]:
# We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
# and plumb it into the lambda.
outer_ctx = dispatcher.arguments(self.g.view.func) + [
functionalization.reapply_views_binding
functionalization.reapply_views_binding,
functionalization.inverse_return_mode_binding,
]
capture_bindings = functionalization.capture_arguments(
self.g.view.func, is_reverse=self.is_reverse
8 changes: 8 additions & 0 deletions torchgen/gen_functionalization_type.py
@@ -338,6 +338,10 @@ def emit_view_functionalization_body(
return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
}}
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
auto inverse_return_mode = (
reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
: at::functionalization::InverseReturnMode::NeverView
);
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
{forward_lambda.decl()} {{
if (reapply_views) {{
@@ -387,6 +391,10 @@ def emit_view_functionalization_body(
return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
}}
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
auto inverse_return_mode = (
reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
: at::functionalization::InverseReturnMode::NeverView
);
auto compute_reference_meta =
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);