Skip to content
Closed
14 changes: 13 additions & 1 deletion aten/src/ATen/FunctionalTensorWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
),
value_(value)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
set_constructor_metadata();
}

Expand Down Expand Up @@ -130,6 +132,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
),
value_(view_value)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
set_constructor_metadata();
// Copy the original tensor's ViewMeta vector and push the current one.
if (!base->view_metas_.empty()) {
Expand Down Expand Up @@ -168,7 +172,9 @@ void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta m
// So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()`
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
at::AutoDispatchSkipFunctionalize guard;
value_ = meta.forward_fn(value_, meta.out_index);
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}

// Note [Functionalization: Mutation Removal]
Expand Down Expand Up @@ -200,15 +206,20 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
// TODO: going to need to change this if we want nested functionalize() transforms.
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
value_ = other;
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
// out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
// We need to propagate that metadata mutation to the wrapper (new size).
set_sizes_and_strides(value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());
auto sizes_ = value_.sym_sizes();
auto strides_ = value_.sym_strides();
auto storage_offset_ = value_.sym_storage_offset();
set_sizes_and_strides(sizes_, strides_, storage_offset_);
if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
// .to() should not re-entrantly go through functionalization.
at::AutoDispatchSkipFunctionalize guard;
// and we want _to_copy() to show up in the graph, not the composite .to() operator
// (this can happen if autograd has already run by the time we enter this code)
value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}
}

Expand Down Expand Up @@ -243,6 +254,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
// Then it's safe to throw out the old storage and replace it with the new, larger one.
storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
value_ = other;
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
generation_ = 0;
// And update the metadata on the wrapper to reflect the new sizes and strides
set_sizes_and_strides(value_.sizes(), value_.strides());
Expand Down
14 changes: 10 additions & 4 deletions aten/src/ATen/native/CPUFallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ c10::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args
}


void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) {
auto& schema_args = op.schema().arguments();
const auto num_arguments = schema_args.size();
auto arguments = torch::jit::last(stack, num_arguments);
Expand Down Expand Up @@ -176,9 +176,15 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
} else {
dev_str << "<none>";
}
TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
if (error_on_views) {
TORCH_CHECK(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
} else {
TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
}
}
// Case (2): copy case. Copy the cpu output tensor to the original device.

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/CPUFallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace at { namespace native {

// This function implements a boxed fallback to CPU.
// External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);

// This is a helper function that backends can use to directly call their boxed CPU fallback
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5062,7 +5062,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: slice_scatter
CompositeExplicitAutogradNonFunctional: slice_scatter
autogen: slice_scatter.out
tags: core

Expand All @@ -5071,23 +5071,23 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: select_scatter_symint
CompositeExplicitAutogradNonFunctional: select_scatter_symint
autogen: select_scatter.out

- func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
variants: function, method
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: diagonal_scatter
CompositeExplicitAutogradNonFunctional: diagonal_scatter
autogen: diagonal_scatter.out

- func: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
variants: function, method
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: as_strided_scatter_symint
CompositeExplicitAutogradNonFunctional: as_strided_scatter_symint
autogen: as_strided_scatter.out

- func: smm(Tensor self, Tensor mat2) -> Tensor
Expand Down
1 change: 0 additions & 1 deletion torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1799,7 +1799,6 @@ def meta_select(self, dim, index):

check(
not (-index > size or index >= size),
lambda: f"select(): index {index} out of range for tensor of size "
f"{self.size()} at dimension {dim}",
IndexError,
)
Expand Down
2 changes: 1 addition & 1 deletion torch/_subclasses/meta_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def __call__(
or (ignore_subclass and isinstance(t, torch.Tensor))
or isinstance(t, FakeTensor)
):
if any(
if t.device.type != "xla" and any(
[
t.is_sparse_csr,
t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
Expand Down
15 changes: 9 additions & 6 deletions torch/csrc/lazy/core/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

#include <ATen/AccumulateType.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
Expand Down Expand Up @@ -1304,7 +1305,7 @@ std::vector<Shape> compute_shape_select_scatter(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::select_scatter(
auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
self_meta, src_meta, dim, index);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
Expand All @@ -1329,7 +1330,7 @@ std::vector<Shape> compute_shape_diagonal_scatter(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::diagonal_scatter(
auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
self_meta, src_meta, offset, dim1, dim2);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
Expand All @@ -1355,8 +1356,9 @@ std::vector<Shape> compute_shape_slice_scatter_symint(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::slice_scatter_symint(
self_meta, src_meta, dim, start, end, step);
auto out_meta =
at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
self_meta, src_meta, dim, start, end, step);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

Expand All @@ -1380,8 +1382,9 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::as_strided_scatter_symint(
self_meta, src_meta, size, stride, storage_offset);
auto out_meta =
at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
self_meta, src_meta, size, stride, storage_offset);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

Expand Down
1 change: 1 addition & 0 deletions torch/csrc/utils/tensor_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ Tensor internal_new_from_data(
// to dispatch to it.
// TODO: arguably it should have an autograd implementation that noops
at::AutoDispatchBelowADInplaceOrView guard;

return at::lift_fresh(tensor);
}

Expand Down
13 changes: 12 additions & 1 deletion torchgen/dest/register_dispatch_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,21 @@ def gen_out_inplace_wrapper(
for i, ret_name in enumerate(return_names)
)
returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
else:
elif len(return_names) == 1:
ret_name = return_names[0]
updates = f"{copy_op}({func_res}, {ret_name});"
returns = ret_name
else:
assert len(f.func.arguments.out) == 1
returns = ""
out_arg = f.func.arguments.out[0]
if out_arg.type.is_list_like():
updates = f"""\
for (int64_t i = 0; i < {func_res}.size(); ++i) {{
{copy_op}({func_res}[i], {out_arg.name}[i]);
}}"""
else:
updates = f"{copy_op}({func_res}, {out_arg.name});"

functional_sig = self.wrapper_kernel_sig(g.functional)
wrapper_name = sig.name()
Expand Down
17 changes: 16 additions & 1 deletion torchgen/gen_functionalization_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,11 @@ def emit_inplace_functionalization_body(
for a in f.func.arguments.flat_all
if a.type.is_tensor_like() and a.annotation is None
]
non_mutated_tensor_names = [
a.name
for a in f.func.arguments.flat_all
if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
]
# all mutable inputs must be functional tensors in order to participate in functionalization
check_all_mutated_args_are_functional = " && ".join(
["true"]
Expand All @@ -556,6 +561,14 @@ def emit_inplace_functionalization_body(
for a in non_mutated_names
]
)

check_any_non_mutated_tensors_are_xla = " || ".join(
["false"]
+ [
f"{a}.device().type() == c10::DeviceType::XLA"
for a in non_mutated_tensor_names
]
)
# These are used in the cases where we don't functionalize and redispatch to the inplace op
# case 1: we hit an inplace op that doesn't have an out-of-place equivalent
# case 2: we hit an inplace ops but our inputs are not functional tensors (in which case our kernel just no-ops)
Expand Down Expand Up @@ -619,7 +632,9 @@ def emit_inplace_functionalization_body(
}}
{unwrap_tensor_args_str}
if (!({check_all_mutated_args_are_functional})) {{
if (({check_any_non_mutated_args_are_functional})) {{
// We want to disable this check if there are any XLA tensors.
// cpu_tensor.copy_(xla_tensor) is valid code.
if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
// case 1: trying to mutate a non functional tensor with a functional tensor is an error
TORCH_INTERNAL_ASSERT(false,
"mutating a non-functional tensor with a functional tensor is not allowed.",
Expand Down