Revert "Revert "Symintifying slice ops (#85196)"" (#85746)
This reverts commit 3a171df.
Pull Request resolved: #85746
Approved by: https://github.com/albanD
ezyang authored and pytorchmergebot committed Sep 28, 2022
1 parent 049be5a commit 793488c
Showing 20 changed files with 153 additions and 71 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -172,9 +172,9 @@ Tensor FunctionalInverses::lift_fresh_copy_inverse(const Tensor& base, const Ten
return mutated_view;
}

Tensor FunctionalInverses::slice_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) {
Tensor FunctionalInverses::slice_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) {
// Pessimism: we can't reapply views for slice_scatter.
return base.slice_scatter(mutated_view, dim, start, end, step);
return base.slice_scatter_symint(mutated_view, dim, start, end, step);
}

Tensor FunctionalInverses::split_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, int64_t split_size, int64_t dim) {
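The change above only symintifies the signature; the behavior is unchanged, and the body now routes through slice_scatter_symint. slice_scatter writes the mutated slice back into the base tensor, which is why it serves as the functional inverse of slice_copy here. A minimal Python sketch of that relationship, with assumed shapes and values (not taken from this diff):

    import torch

    base = torch.zeros(8)
    view = base[2:6]                       # slice(dim=0, start=2, end=6, step=1)
    mutated_view = view + 1.0              # stand-in for an out-of-place mutation of the view
    rebuilt = base.slice_scatter(mutated_view, dim=0, start=2, end=6, step=1)
    # rebuilt equals base with elements 2..5 replaced by the mutated values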
14 changes: 7 additions & 7 deletions aten/src/ATen/functorch/BatchRulesViews.cpp
@@ -346,13 +346,13 @@ std::tuple<Tensor,optional<int64_t>> slice_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
c10::optional<c10::SymInt> start,
c10::optional<c10::SymInt> end,
c10::SymInt step) {
auto self_ = moveBatchDimToFront(self, self_bdim);
dim = getPhysicalDim(self, self_bdim.has_value(), dim);

auto result = self_.slice(dim, start, end, step);
auto result = self_.slice_symint(dim, start, end, step);
return std::make_tuple(result, 0);
}

@@ -415,14 +415,14 @@ std::tuple<Tensor,optional<int64_t>> select_backward_batch_rule(

std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
const Tensor& grad_input, optional<int64_t> grad_input_bdim,
IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) {
auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
dim = maybe_wrap_dim(dim, logical_rank) + 1;
c10::SmallBuffer<int64_t, 5> input_sizes_(input_sizes.size() + 1);
c10::SymDimVector input_sizes_(input_sizes.size() + 1);
input_sizes_[0] = grad_input_.size(0);
std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
auto result = at::slice_backward(grad_input_, input_sizes_, dim, start, end, step);
auto result = at::slice_backward_symint(grad_input_, input_sizes_, dim, start, end, step);
return std::make_tuple(std::move(result), 0);
}

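The batch rules keep their original structure: move the batch dimension to the front, shift the logical dim to the physical dim, then call the _symint variant. A small sketch of what slice_batch_rule buys the user, assuming a build where vmap is exposed as torch.vmap (around the time of this commit it lived in functorch.vmap):

    import torch

    x = torch.arange(2 * 5).reshape(2, 5)      # a batch of 2 vectors of length 5
    y = torch.vmap(lambda t: t[1:4])(x)        # slicing inside vmap hits slice_batch_rule
    assert torch.equal(y, x[:, 1:4])           # same as slicing the physical dim directly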
10 changes: 5 additions & 5 deletions aten/src/ATen/native/native_functions.yaml
@@ -4682,7 +4682,7 @@
device_check: NoCheck
device_guard: False

- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)
- func: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
variants: function, method
device_check: NoCheck
device_guard: False
@@ -4691,15 +4691,15 @@
# NOTE: The implementation of split_with_sizes bypasses the dispatcher to call this; undo
# that if adding specific implementations here!

- func: slice_backward(Tensor grad_output, int[] input_sizes, int dim, int start, int end, int step) -> Tensor
- func: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor
variants: function
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: slice_backward
autogen: slice_backward.out

- func: slice_scatter(Tensor self, Tensor src, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor
- func: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
variants: function, method
device_check: NoCheck
device_guard: False
@@ -12806,7 +12806,7 @@
CompositeExplicitAutogradNonFunctional: detach_copy
tags: view_copy

- func: slice_copy.Tensor(Tensor self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor
- func: slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
variants: function
dispatch:
CompositeExplicitAutogradNonFunctional: slice_copy_Tensor
@@ -13018,7 +13018,7 @@
CompositeExplicitAutograd: detach_copy_out


- func: slice_copy.Tensor_out(Tensor self, int dim=0, int? start=None, int? end=None, int step=1, *, Tensor(a!) out) -> Tensor(a!)
- func: slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CompositeExplicitAutograd: slice_copy_Tensor_out
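The schema edits above swap int for SymInt in slice.Tensor, slice_backward, slice_scatter, and the slice_copy variants. From the Python side the operators accept concrete integers exactly as before; the SymInt path only matters under symbolic tracing. A small sanity sketch:

    import torch

    x = torch.arange(10)
    y = torch.ops.aten.slice.Tensor(x, 0, 1, 7, 2)   # dim=0, start=1, end=7, step=2
    assert torch.equal(y, x[1:7:2])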
3 changes: 3 additions & 0 deletions aten/src/ATen/native/ts_native_functions.yaml
@@ -208,6 +208,9 @@ symint:
- view_copy
- as_strided_copy
- as_strided_scatter
- slice_backward
- slice_copy.Tensor
- slice_scatter
- empty_strided
- new_empty_strided
autograd:
3 changes: 2 additions & 1 deletion aten/src/ATen/test/vulkan_api_test.cpp
@@ -173,7 +173,8 @@ static void slice_tests(const std::unordered_map<int64_t, std::vector<int64_t>>&
slice_test(dim2size.second, dim2size.first, -30, -10, 1); // i.e., 4D tensor's equivalent indexing = [:,:,:,-30:-10:1] with negative start/end
slice_test(dim2size.second, dim2size.first, 0, INT64_MAX, 1); // i.e., 4D 's equivalent indexing = [:,:,:,0:9223372036854775807:1] with end=INT64_MAX
slice_test(dim2size.second, dim2size.first, -10, INT64_MAX, 1); // i.e., 4D 's equivalent indexing = [:,:,:,-10:9223372036854775807:1] with negative start and end=INT64_MAX
slice_test(dim2size.second, dim2size.first, INT64_MIN, INT64_MAX, 1); // i.e., 4D 's equivalent indexing = [:,:,:,-9223372036854775808:9223372036854775807:1] with start=INT64_MIN and end=INT64_MAX
// This triggers a SymInt assert since [-2^63, -2^62-1] range is reserved for packed symints
//slice_test(dim2size.second, dim2size.first, INT64_MIN, INT64_MAX, 1); // i.e., 4D 's equivalent indexing = [:,:,:,-9223372036854775808:9223372036854775807:1] with start=INT64_MIN and end=INT64_MAX
slice_test(dim2size.second, dim2size.first, {}, {}, 1); // i.e., 4D 's equivalent indexing = [:,:,:,::1] with empty start/end
}
}
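The disabled Vulkan case follows from how SymInt packs symbolic handles into an int64_t: the range from -2^63 through -2^62 - 1 is reserved, so INT64_MIN can no longer be passed through as a plain start index. The arithmetic behind the comment, as a sketch (the packing scheme itself is an implementation detail not shown in this diff):

    INT64_MIN = -(1 << 63)
    reserved_lo, reserved_hi = -(1 << 63), -(1 << 62) - 1
    assert reserved_lo <= INT64_MIN <= reserved_hi   # hence the test above is commented out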
41 changes: 24 additions & 17 deletions test/lazy/test_ts_opinfo.py
@@ -33,6 +33,7 @@ def init_lists():
with open(TS_NATIVE_FUNCTIONS_PATH) as f:
yaml_ts = yaml.load(f, yaml.Loader)
LAZY_OPS_LIST = set(remove_suffixes(itertools.chain(yaml_ts["full_codegen"], yaml_ts["supported"], yaml_ts["autograd"])))
HAS_SYMINT_SUFFIX = yaml_ts["symint"]
FALLBACK_LIST = set(["clamp"])
SKIP_RUNTIME_ERROR_LIST = set([
'index_select', # Empty output_sizes is not supported
@@ -70,9 +71,19 @@ def init_lists():
'logsumexp',
])

return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST)
return (LAZY_OPS_LIST,
FALLBACK_LIST,
SKIP_RUNTIME_ERROR_LIST,
SKIP_INCORRECT_RESULTS_LIST,
FUNCTIONAL_DECOMPOSE_LIST,
HAS_SYMINT_SUFFIX)

(LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST) = init_lists()
(LAZY_OPS_LIST,
FALLBACK_LIST,
SKIP_RUNTIME_ERROR_LIST,
SKIP_INCORRECT_RESULTS_LIST,
FUNCTIONAL_DECOMPOSE_LIST,
HAS_SYMINT_SUFFIX) = init_lists()

torch.manual_seed(42)

@@ -162,7 +173,7 @@ def get_name(op):
l.append(op.variant_test_name)
return '.'.join(l)

global FALLBACK_LIST
global HAS_SYMINT_SUFFIX, FALLBACK_LIST
samples = op.sample_inputs("lazy", dtype, requires_grad=False)
sample = list(samples)[0]
args = [sample.input] + list(sample.args)
@@ -175,20 +186,16 @@ def get_name(op):
torch._lazy.mark_step()
torch._lazy.wait_device_ops()
prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
cand_names = [
f"{prefix}::{op.name}",
f"{prefix}::{op.name}_symint",
]
for alias in op.aliases:
cand_names.extend([
f"{prefix}::{alias.name}",
f"{prefix}::{alias.name}_symint",
])
for name in cand_names:
if name in remove_suffixes(torch._lazy.metrics.counter_names()):
break
else:
self.fail("none of {cand_names} was not found in: {remove_suffixes(torch._lazy.metrics.counter_names())}")
symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes(torch._lazy.metrics.counter_names())
# check aliases
if not found:
for alias in op.aliases:
alias_found = f"{prefix}::{alias.name}{symint_suffix}" in remove_suffixes(torch._lazy.metrics.counter_names())
found = found or alias_found
if found:
break
self.assertTrue(found)


@ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST], allowed_dtypes=(torch.float,)) # noqa: B950
1 change: 0 additions & 1 deletion test/test_decomp.py
@@ -514,7 +514,6 @@ def check_decomposed(aten_name):
"only backwards is decomposed, but dtype doesn't support AD"
)


instantiate_device_type_tests(TestDecomp, globals())

if __name__ == "__main__":
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
@@ -1264,7 +1264,6 @@ def f(a, b, c, d, e):
xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
xfail('select', ''), # aten.select.int - couldn't find symbolic meta function/decomposition
xfail('select_scatter', ''), # aten.select_scatter.default - couldn't find symbolic meta function/decomposition
xfail('slice', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition
xfail('slice_scatter', ''), # aten.slice_scatter.default - couldn't find symbolic meta function/decomposition
xfail('sort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition
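Dropping the xfail('slice') entry reflects that aten.slice.Tensor now has a symbolic path, so proxy-tensor tracing of slicing is expected to pass. A rough sketch of the kind of trace the test exercises, assuming the make_fx symbolic tracing entry point used elsewhere in that test file:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return x[:, 1:-1]                      # lowers to aten.slice.Tensor

    g = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 8))
    print(g.graph)                             # slice shows up in the graph instead of failing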
14 changes: 7 additions & 7 deletions tools/autograd/derivatives.yaml
@@ -1429,17 +1429,17 @@
self: grad * self.cosh().conj()
result: auto_element_wise

- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)
self: slice_backward_wrapper(grad, self.sizes(), dim, start, end, step)
- name: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
self: slice_backward_wrapper(grad, self.sym_sizes(), dim, start, end, step)
result: auto_linear

- name: slice_backward(Tensor grad_output, int[] input_sizes, int dim, int start, int end, int step) -> Tensor
grad_output: grad.slice(dim, start, end, step)
- name: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor
grad_output: grad.slice_symint(dim, start, end, step)
result: auto_linear

- name: slice_scatter(Tensor self, Tensor src, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor
self: slice_scatter(grad, zeros_like(src), dim, start, end, step)
src: grad.slice(dim, start, end, step)
- name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step)
src: grad.slice_symint(dim, start, end, step)
result: auto_linear

- name: select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor
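The derivative formulas are unchanged apart from routing through the _symint wrappers: the gradient of slice scatters the incoming grad into a zero tensor of the input's (now symbolic) shape, while the gradients of slice_backward and slice_scatter slice the grad back out. A concrete sketch of the first rule, with assumed sizes:

    import torch

    x = torch.arange(10.0, requires_grad=True)
    y = x[1:7:2]                               # slice(dim=0, start=1, end=7, step=2)
    y.backward(torch.ones_like(y))
    expected = torch.zeros(10).slice_scatter(torch.ones(3), dim=0, start=1, end=7, step=2)
    assert torch.equal(x.grad, expected)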
1 change: 1 addition & 0 deletions tools/autograd/gen_inplace_or_view_type.py
@@ -327,6 +327,7 @@ def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str
known_view_arg_simple_types: List[CType] = [
BaseCType(longT),
OptionalCType(BaseCType(longT)),
BaseCType(SymIntT),
OptionalCType(BaseCType(SymIntT)),
BaseCType(boolT),
BaseCType(intArrayRefT),
53 changes: 53 additions & 0 deletions torch/_decomp/decompositions.py
@@ -1,5 +1,6 @@
import functools
import operator
import sys
from enum import Enum
from itertools import product
from typing import Callable, cast, Iterable, List, Optional, Tuple
@@ -609,6 +610,58 @@ def slice_backward(
return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)


@register_decomposition(aten.slice.Tensor)
def slice_forward(
# Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1
self: Tensor,
dim: int = 0,
start: Optional[int] = None,
end: Optional[int] = None,
step: int = 1,
):

ndim = self.dim()
if ndim == 0:
raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
dim = utils.canonicalize_dim(self.dim(), dim)
sizes = list(self.size())
strides = list(self.stride())

if step <= 0:
raise RuntimeError("slice step must be positive")

start_val = start if start is not None else 0
end_val = end if end is not None else sys.maxsize # 2^63 – 1

if start_val < 0:
start_val += sizes[dim]

if end_val < 0:
end_val += sizes[dim]

if start_val < 0:
start_val = 0
elif start_val >= sizes[dim]:
start_val = sizes[dim]

if end_val < start_val:
end_val = start_val
elif end_val >= sizes[dim]:
end_val = sizes[dim]

storage_offset = self.storage_offset() + start_val * strides[dim]
len = end_val - start_val
sizes[dim] = (len + step - 1) // step
strides[dim] *= step

if self.is_quantized:
raise NotImplementedError(
"Slice decomposition for quantized tensors aren't implemented"
)
else:
return self.as_strided(sizes, strides, storage_offset)


@register_decomposition(aten.select_backward)
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
grad_input = grad_output.new_zeros(input_sizes)
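The new slice_forward decomposition reduces slicing to an as_strided view: clamp start/end, bump the storage offset by start * stride[dim], divide the remaining length by step (rounding up), and scale the stride by step. A hand-worked sketch of that computation for one concrete case, mirroring the code above:

    import torch

    x = torch.arange(24).reshape(4, 6)         # strides (6, 1), storage_offset 0
    dim, start, end, step = 1, 1, 5, 2
    sizes, strides = list(x.size()), list(x.stride())
    storage_offset = x.storage_offset() + start * strides[dim]   # 0 + 1 * 1 = 1
    length = end - start                                         # 4
    sizes[dim] = (length + step - 1) // step                     # ceil(4 / 2) = 2
    strides[dim] *= step                                         # 1 * 2 = 2
    y = x.as_strided(sizes, strides, storage_offset)
    assert torch.equal(y, x[:, 1:5:2])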
11 changes: 6 additions & 5 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -3160,15 +3160,16 @@ Tensor elu_double_backward(

Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef& input_sizes,
const c10::SymIntArrayRef& input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
c10::optional<c10::SymInt> start,
c10::optional<c10::SymInt> end,
c10::SymInt step) {
auto start_val = start.has_value() ? start.value() : 0;
auto end_val = end.has_value() ? end.value() : INT64_MAX;

return slice_backward(grad, input_sizes, dim, start_val, end_val, step);
return slice_backward_symint(
grad, input_sizes, dim, start_val, end_val, step);
}

std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(
8 changes: 4 additions & 4 deletions torch/csrc/autograd/FunctionsManual.h
@@ -555,11 +555,11 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(
const bool full_matrices);
Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef& input_sizes,
const c10::SymIntArrayRef& input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step);
c10::optional<c10::SymInt> start,
c10::optional<c10::SymInt> end,
c10::SymInt step);
std::tuple<Tensor, Tensor> linalg_eig_jvp(
const Tensor& dA,
const Tensor& L,
10 changes: 5 additions & 5 deletions torch/csrc/lazy/core/shape_inference.cpp
@@ -1334,13 +1334,13 @@ std::vector<Shape> compute_shape_diagonal_scatter(
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

std::vector<Shape> compute_shape_slice_scatter(
std::vector<Shape> compute_shape_slice_scatter_symint(
const at::Tensor& self,
const at::Tensor& src,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
c10::optional<c10::SymInt> start,
c10::optional<c10::SymInt> end,
c10::SymInt step) {
auto self_meta = at::native::empty_strided_meta_symint(
self.sym_sizes(),
self.sym_strides(),
@@ -1355,7 +1355,7 @@ std::vector<Shape> compute_shape_slice_scatter(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::slice_scatter(
auto out_meta = at::compositeexplicitautograd::slice_scatter_symint(
self_meta, src_meta, dim, start, end, step);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
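compute_shape_slice_scatter_symint follows the usual lazy shape-inference pattern: build meta copies of self and src with empty_strided_meta_symint and run slice_scatter_symint on them, so only shapes and dtypes are computed, never data. A minimal sketch of the same idea from Python, with assumed sizes:

    import torch

    self_meta = torch.empty_strided((4, 6), (6, 1), device="meta")
    src_meta = torch.empty_strided((4, 2), (6, 2), device="meta")   # shape of self[:, 1:5:2]
    out = torch.slice_scatter(self_meta, src_meta, dim=1, start=1, end=5, step=2)
    print(out.shape, out.device)               # torch.Size([4, 6]) meta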
2 changes: 1 addition & 1 deletion torch/csrc/lazy/core/shape_inference.h
@@ -115,7 +115,7 @@ TORCH_API std::vector<Shape> compute_shape_unsqueeze(const Output& input, const

TORCH_API std::vector<torch::lazy::Shape> compute_shape_select_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_diagonal_scatter(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slice_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slice_scatter_symint(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_as_strided_scatter_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset);
// clang-format on
} // namespace lazy
10 changes: 5 additions & 5 deletions torch/csrc/lazy/ts_backend/ts_native_functions.cpp
@@ -532,13 +532,13 @@ at::Tensor LazyNativeFunctions::diagonal_backward(
diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
}

at::Tensor LazyNativeFunctions::slice_backward(
at::Tensor LazyNativeFunctions::slice_backward_symint(
const at::Tensor& grad_output,
at::IntArrayRef input_sizes,
at::SymIntArrayRef input_sizes,
int64_t dim,
int64_t start,
int64_t end,
int64_t step) {
c10::SymInt start,
c10::SymInt end,
c10::SymInt step) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}
