From f605d7f956bd3c5b9b1718ff407a132a6b7968c6 Mon Sep 17 00:00:00 2001
From: Tugsbayasgalan Manlaibaatar
Date: Tue, 8 Dec 2020 08:53:50 -0800
Subject: [PATCH] Striding for lists

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: fdf82e4c58b24e22ab5787490d26803537761de8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48719
---
 aten/src/ATen/BatchingRegistrations.cpp       |  7 +-
 aten/src/ATen/native/TensorShape.cpp          | 80 +++++++++++--------
 aten/src/ATen/native/native_functions.yaml    |  2 +-
 tools/autograd/derivatives.yaml               |  4 +-
 torch/csrc/autograd/FunctionsManual.cpp       | 16 +++-
 torch/csrc/autograd/FunctionsManual.h         |  8 +-
 torch/csrc/jit/frontend/tree_views.h          |  8 +-
 torch/csrc/jit/runtime/register_ops_utils.cpp | 73 +++++++++++++----
 torch/csrc/jit/runtime/register_prim_ops.cpp  |  2 +-
 9 files changed, 143 insertions(+), 57 deletions(-)

diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index 16470f39ad54..e9ef9da9883f 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -314,7 +314,12 @@ Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes
   return grad_physical.newLogicalFromPhysical(grad_input);
 }
 
-Tensor slice_batching_rule(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
+Tensor slice_batching_rule(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<int64_t> start,
+    c10::optional<int64_t> end,
+    c10::optional<int64_t> step) {
   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
   auto dim_physical = self_physical.getPhysicalDim(dim);
   auto result = self_physical.tensor().slice(dim_physical, start, end, step);
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index eda688ad6e1d..d7bbcc79b9b2 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -1,23 +1,23 @@
-#include
-#include
 #include
 #include
 #include
+#include
+#include
 #include
+#include
 #include
-#include
-#include
+#include
 #include
-#include
-#include
-#include
-#include
 #include
 #include
 #include
-#include
-#include
+#include
+#include
+#include
+#include
+#include
+#include
 
 namespace at {
 namespace native {
@@ -1070,7 +1070,12 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
   }
 }
 
-Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
+Tensor slice(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<int64_t> start,
+    c10::optional<int64_t> end,
+    c10::optional<int64_t> step) {
   int64_t ndim = self.dim();
   if (ndim == 0) {
     TORCH_CHECK_INDEX(false, "slice() cannot be applied to a 0-dim tensor.");
@@ -1078,28 +1083,39 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
   dim = maybe_wrap_dim(dim, ndim);
   auto sizes = self.sizes().vec();
   auto strides = self.strides().vec();
+
+  // handle optional parameters
+  int64_t start_val = start.has_value() ? start.value() : INT64_MAX;
+  int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
+  int64_t step_val = step.has_value() ? step.value() : 1;
+
   // TODO: support negative strides
-  TORCH_CHECK(step > 0, "slice step must be positive");
-  if (start < 0) {
-    start += sizes[dim];
-  }
-  if (end < 0) {
-    end += sizes[dim];
-  }
-  if (start < 0) {
-    start = 0;
-  } else if (start >= sizes[dim]) {
-    start = sizes[dim];
-  }
-  if (end < start) {
-    end = start;
-  } else if (end >= sizes[dim]) {
-    end = sizes[dim];
-  }
-  auto storage_offset = self.storage_offset() + start * strides[dim];
-  auto len = end - start;
-  sizes[dim] = (len + step - 1) / step; // round-up
-  strides[dim] *= step;
+  TORCH_CHECK(step_val > 0, "slice step must be positive");
+
+  // INT64_MAX stands for default value.
+  if (start_val == INT64_MAX) {
+    start_val = 0;
+  }
+  if (start_val < 0) {
+    start_val += sizes[dim];
+  }
+  if (end_val < 0) {
+    end_val += sizes[dim];
+  }
+  if (start_val < 0) {
+    start_val = 0;
+  } else if (start_val >= sizes[dim]) {
+    start_val = sizes[dim];
+  }
+  if (end_val < start_val) {
+    end_val = start_val;
+  } else if (end_val >= sizes[dim]) {
+    end_val = sizes[dim];
+  }
+  auto storage_offset = self.storage_offset() + start_val * strides[dim];
+  auto len = end_val - start_val;
+  sizes[dim] = (len + step_val - 1) / step_val; // round-up
+  strides[dim] *= step_val;
   auto result = self.as_strided(sizes, strides, storage_offset);
   namedinference::propagate_names(result, self);
   return result;
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e7ac20599214..4e5c6e8d6b91 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3548,7 +3548,7 @@
   variants: function, method
   device_guard: False
 
-- func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
+- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int? step=1) -> Tensor(a)
   use_c10_dispatcher: full
   variants: function, method
   device_guard: False
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index dadfe6018939..95816dab65e7 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -964,8 +964,8 @@
 - name: sinh(Tensor self) -> Tensor
   self: grad * self.cosh().conj()
 
-- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
-  self: slice_backward(grad, self.sizes(), dim, start, end, step)
+- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int? step=1) -> Tensor(a)
+  self: slice_backward_wrapper(grad, self.sizes(), dim, start, end, step)
 
 - name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
   self: slogdet_backward(grad, self, sign, logabsdet)
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 91bad195f47e..3929f5b9aa58 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1821,6 +1821,20 @@ std::tuple prelu_double_backward(
   }
 }
 
+Tensor slice_backward_wrapper(
+    const at::Tensor& grad,
+    const c10::IntArrayRef input_sizes,
+    int64_t dim,
+    c10::optional<int64_t> start,
+    c10::optional<int64_t> end,
+    c10::optional<int64_t> step) {
+  auto start_val = start.has_value() ? start.value() : INT64_MAX;
+  auto end_val = end.has_value() ? end.value() : INT64_MAX;
+  auto step_val = step.has_value() ? step.value() : 1;
+
+  return slice_backward(grad, input_sizes, dim, start_val, end_val, step_val);
+}
+
 // https://j-towns.github.io/papers/svd-derivative.pdf
 //
 // This makes no assumption on the signs of sigma.
@@ -2802,7 +2816,7 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
 }
 
 Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
-  // since first backward takes care of scaling by frequency, 
+  // since first backward takes care of scaling by frequency,
   // we don't need to worry about it here.
   auto gg_weight = grad.index_select(0, indices.reshape(-1));
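Reviewer note (not part of the patch): a minimal eager-mode sanity sketch of the indexing arithmetic that the optional handling above preserves. It assumes only standard tensor indexing; with default bounds the output length along the sliced dim is (len + step - 1) / step, and a negative end wraps, as in Python.

import torch

# Sanity sketch only; no new APIs from this patch are used here.
x = torch.arange(10)
assert x[::3].numel() == (10 + 3 - 1) // 3              # ceil(10 / 3) == 4 elements: 0, 3, 6, 9
assert torch.equal(x[2:-2:2], torch.tensor([2, 4, 6]))  # end == -2 wraps to 8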
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 0fba31bdd894..6bdc65927642 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -121,9 +121,15 @@ at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out,
 at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx);
 at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad);
 at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity);
-
 Tensor svd_backward(const std::vector &grads, const Tensor& self, bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v);
+Tensor slice_backward_wrapper(
+    const at::Tensor& grad,
+    const c10::IntArrayRef input_sizes,
+    int64_t dim,
+    c10::optional<int64_t> start,
+    c10::optional<int64_t> end,
+    c10::optional<int64_t> step);
 Tensor symeig_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v);
 std::tuple triangular_solve_backward(
diff --git a/torch/csrc/jit/frontend/tree_views.h b/torch/csrc/jit/frontend/tree_views.h
index e33d93f37566..a56b9e893218 100644
--- a/torch/csrc/jit/frontend/tree_views.h
+++ b/torch/csrc/jit/frontend/tree_views.h
@@ -929,15 +929,15 @@ struct SliceExpr : public Expr {
   Maybe<Expr> step() const {
     return Maybe<Expr>(subtree(2));
   }
-  Expr startOr(int alternative) const {
+  Expr startOr(int64_t alternative) const {
     const auto startOption = start();
     return startOption.present() ? startOption.get() : createInt(alternative);
   }
-  Expr endOr(int alternative) const {
+  Expr endOr(int64_t alternative) const {
     const auto endOption = end();
     return endOption.present() ? endOption.get() : createInt(alternative);
  }
-  Expr stepOr(int alternative) const {
+  Expr stepOr(int64_t alternative) const {
    const auto stepOption = step();
    return stepOption.present() ? stepOption.get() : createInt(alternative);
  }
@@ -951,7 +951,7 @@ struct SliceExpr : public Expr {
   }
 
  private:
-  Expr createInt(int value) const {
+  Expr createInt(int64_t value) const {
     return Expr(Const::create(range(), c10::to_string(value)));
   }
 };
diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp
index 95f8dcee8763..5da0b8e8a9f3 100644
--- a/torch/csrc/jit/runtime/register_ops_utils.cpp
+++ b/torch/csrc/jit/runtime/register_ops_utils.cpp
@@ -425,6 +425,60 @@ void listMulIntRight(Stack* stack) {
   push(stack, std::move(ret));
 }
 
+// "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
+// 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software
+// Foundation; All Rights Reserved" Stolen (with appropriate modifications) from
+// cpython repo Objects/sliceobject.c with comment: this is harder to get right
+// than you might think
+//
+// This adjusts indexes according to python list semantics and returns number
+// of elements in the resulting list.
+static int64_t PySlice_AdjustIndices(
+    int64_t length,
+    int64_t* start,
+    int64_t* stop,
+    int64_t step) {
+  TORCH_CHECK(step != 0, "List slice should have non-zero step")
+  TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds")
+
+  // Comes from PySlice_Unpack.
+  if (*start == INT64_MAX) {
+    *start = (step < 0) ? INT64_MAX : 0;
+  }
+  if (*stop == INT64_MAX) {
+    *stop = (step < 0) ? INT64_MIN : INT64_MAX;
+  }
+
+  // Comes from PySlice_AdjustIndices.
+  if (*start < 0) {
+    *start += length;
+    if (*start < 0) {
+      *start = (step < 0) ? -1 : 0;
+    }
+  } else if (*start >= length) {
+    *start = (step < 0) ? length - 1 : length;
+  }
+
+  if (*stop < 0) {
+    *stop += length;
+    if (*stop < 0) {
+      *stop = (step < 0) ? -1 : 0;
+    }
+  } else if (*stop >= length) {
+    *stop = (step < 0) ? length - 1 : length;
+  }
+
+  if (step < 0) {
+    if (*stop < *start) {
+      return (*start - *stop - 1) / (-step) + 1;
+    }
+  } else {
+    if (*start < *stop) {
+      return (*stop - *start - 1) / step + 1;
+    }
+  }
+  return 0;
+}
 
 void listSlice(Stack* stack) {
   int64_t step = pop(stack).to<int64_t>();
@@ -434,22 +488,13 @@
   const int64_t list_size = list.size();
 
-  // clamp start and end to the bounds of the list
-  const auto normalized_start =
-      std::max((int64_t)0, normalizeIndex(start, list_size));
-  const auto normalized_end =
-      std::min(list_size, normalizeIndex(end, list_size));
-
   c10::List<IValue> sliced_list = make_result_list<IValue>(list.elementType());
-  if (normalized_end <= normalized_start) {
-    // early exit if the slice is trivially empty
-    push(stack, std::move(sliced_list));
-    return;
-  }
-
-  sliced_list.reserve(normalized_end - normalized_start);
+  const int64_t num_values =
+      PySlice_AdjustIndices(list_size, &start, &end, step);
+  sliced_list.reserve(num_values);
 
-  for (auto i = normalized_start; i < normalized_end;) {
+  int64_t i = start;
+  for (int64_t j = 0; j < num_values; ++j) {
     sliced_list.push_back(list.get(i));
     i += step;
   }
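Reviewer note (not part of the patch): PySlice_AdjustIndices mirrors CPython's clamping, so the element count it returns should agree with plain Python list slicing. A hedged Python reference follows; slice_len is an illustrative name only, not something added by this PR.

def slice_len(length, start, stop, step):
    # slice.indices() applies the same wrap-around and clamping as the C++ helper.
    start, stop, step = slice(start, stop, step).indices(length)
    if step < 0:
        return max(0, (start - stop - 1) // (-step) + 1)
    return max(0, (stop - start - 1) // step + 1)

# Check against ground truth: slicing an actual list of the same length.
for args in [(None, None, 2), (None, None, -1), (10, -10, -2), (1, 4, 1)]:
    assert slice_len(5, *args) == len(list(range(5))[slice(*args)])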
diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp
index d9bffa7e4644..a486572e7f8c 100644
--- a/torch/csrc/jit/runtime/register_prim_ops.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -603,7 +603,7 @@ RegisterOperators reg(
         aliasAnalysisFromSchema()),
     OperatorGenerator(
         TORCH_SELECTIVE_SCHEMA(
-            "aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> t[]"),
+            "aten::slice.t(t[] l, int? start, int? end=9223372036854775807, int? step=1) -> t[]"),
         listSlice,
         aliasAnalysisFromSchema()),
     OperatorGenerator(
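Reviewer note (not part of the patch): a hedged usage sketch of what the "int?" schema and the listSlice changes are expected to enable in TorchScript, namely list slicing with a step (including a negative one) matching eager Python. The function name below is illustrative, not part of the patch.

from typing import List

import torch

@torch.jit.script
def strided_slices(xs: List[int]) -> List[int]:
    # Step slicing on lists; previously only start/end were accepted by aten::slice.t.
    return xs[::2] + xs[::-1]

print(strided_slices([0, 1, 2, 3, 4]))  # expected: [0, 2, 4, 4, 3, 2, 1, 0]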