Change list striding parameters to take optional integer (preceding work for #49352) #48719

Closed

Changes from 4 commits
7 changes: 6 additions & 1 deletion aten/src/ATen/BatchingRegistrations.cpp
@@ -314,7 +314,12 @@ Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes
return grad_physical.newLogicalFromPhysical(grad_input);
}

Tensor slice_batching_rule(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice_batching_rule(
const Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
c10::optional<int64_t> step) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = self_physical.tensor().slice(dim_physical, start, end, step);
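With this change the vmap batching rule forwards the optional start/end/step untouched to Tensor::slice instead of unpacking them itself. A minimal usage sketch, assuming the regenerated C++ Tensor::slice overload takes c10::optional<int64_t> as the schema change below implies:

#include <ATen/ATen.h>

// Sketch: the C++ analogue of t[::2] in Python — start and end are left
// unspecified (c10::nullopt), so the kernel substitutes its defaults.
at::Tensor every_other_row(const at::Tensor& t) {
  return t.slice(/*dim=*/0, c10::nullopt, c10::nullopt, /*step=*/2);
}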
80 changes: 48 additions & 32 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1,23 +1,23 @@
#include <algorithm>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/ExpandUtils.h>
#include <ATen/InferSize.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <ATen/native/Copy.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/cpu/CatKernel.h>
#include <ATen/native/Copy.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/quantized/QTensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <algorithm>
#include <cstdint>
#include <vector>

namespace at {
namespace native {
@@ -1070,36 +1070,52 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
}
}

Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice(
const Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
c10::optional<int64_t> step) {
int64_t ndim = self.dim();
if (ndim == 0) {
TORCH_CHECK_INDEX(false, "slice() cannot be applied to a 0-dim tensor.");
}
dim = maybe_wrap_dim(dim, ndim);
auto sizes = self.sizes().vec();
auto strides = self.strides().vec();

// handle optional parameters
int64_t start_val = start.has_value() ? start.value() : INT64_MAX;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
int64_t step_val = step.has_value() ? step.value() : 1;

// TODO: support negative strides
TORCH_CHECK(step > 0, "slice step must be positive");
if (start < 0) {
start += sizes[dim];
}
if (end < 0) {
end += sizes[dim];
}
if (start < 0) {
start = 0;
} else if (start >= sizes[dim]) {
start = sizes[dim];
}
if (end < start) {
end = start;
} else if (end >= sizes[dim]) {
end = sizes[dim];
}
auto storage_offset = self.storage_offset() + start * strides[dim];
auto len = end - start;
sizes[dim] = (len + step - 1) / step; // round-up
strides[dim] *= step;
TORCH_CHECK(step_val > 0, "slice step must be positive");

// INT64_MAX stands for default value.
if (start_val == INT64_MAX) {
start_val = 0;
}
if (start_val < 0) {
start_val += sizes[dim];
}
if (end_val < 0) {
end_val += sizes[dim];
}
if (start_val < 0) {
start_val = 0;
} else if (start_val >= sizes[dim]) {
start_val = sizes[dim];
}
if (end_val < start_val) {
end_val = start_val;
} else if (end_val >= sizes[dim]) {
end_val = sizes[dim];
}
auto storage_offset = self.storage_offset() + start_val * strides[dim];
auto len = end_val - start_val;
sizes[dim] = (len + step_val - 1) / step_val; // round-up
strides[dim] *= step_val;
auto result = self.as_strided(sizes, strides, storage_offset);
namedinference::propagate_names(result, self);
return result;
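The clamping above mirrors Python slice semantics for positive steps. A self-contained restatement, as an illustrative sketch (not part of the patch): for a dimension of length 5, slice(start=-7, end=100, step=2) normalizes start to 0 and end to 5, yielding ceil(5/2) == 3 elements.

#include <cstdint>

// Standalone restatement of the index normalization in slice() above.
int64_t sliced_length(int64_t size, int64_t start, int64_t end, int64_t step) {
  if (start < 0) start += size;            // negative indices count from the end
  if (end < 0) end += size;
  if (start < 0) start = 0;
  else if (start >= size) start = size;    // clamp start into [0, size]
  if (end < start) end = start;            // empty slice
  else if (end >= size) end = size;        // clamp end into [start, size]
  return (end - start + step - 1) / step;  // round up, as in the kernel
}
// sliced_length(5, -7, 100, 2) == 3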
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -3548,7 +3548,7 @@
variants: function, method
device_guard: False

- func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int? step=1) -> Tensor(a)
use_c10_dispatcher: full
variants: function, method
device_guard: False
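The default 9223372036854775807 in the schema is simply INT64_MAX, the sentinel the kernel either treats as "not provided" (start) or clamps down to the dimension size (end). A one-line sanity check, purely illustrative:

#include <cstdint>
static_assert(INT64_MAX == 9223372036854775807LL, "schema default is the INT64_MAX sentinel");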
4 changes: 2 additions & 2 deletions tools/autograd/derivatives.yaml
@@ -964,8 +964,8 @@
- name: sinh(Tensor self) -> Tensor
self: grad * self.cosh().conj()

- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
self: slice_backward(grad, self.sizes(), dim, start, end, step)
- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int? step=1) -> Tensor(a)
self: slice_backward_wrapper(grad, self.sizes(), dim, start, end, step)

- name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
self: slogdet_backward(grad, self, sign, logabsdet)
16 changes: 15 additions & 1 deletion torch/csrc/autograd/FunctionsManual.cpp
@@ -1821,6 +1821,20 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
}
}

Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
c10::optional<int64_t> step) {
auto start_val = start.has_value() ? start.value() : INT64_MAX;
auto end_val = end.has_value() ? end.value() : INT64_MAX;
auto step_val = step.has_value() ? step.value() : 1;

return slice_backward(grad, input_sizes, dim, start_val, end_val, step_val);
}
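Note the defaults the wrapper substitutes: INT64_MAX for start and end (the forward sentinel) but 1 for step. The heavy lifting stays in slice_backward, which embeds the incoming gradient into a zero tensor of the input's shape; a rough functional equivalent, as a sketch rather than the real kernel:

#include <ATen/ATen.h>

// Sketch of what slice_backward computes (illustrative only):
at::Tensor slice_backward_sketch(
    const at::Tensor& grad, at::IntArrayRef input_sizes,
    int64_t dim, int64_t start, int64_t end, int64_t step) {
  auto grad_input = at::zeros(input_sizes, grad.options());
  grad_input.slice(dim, start, end, step).copy_(grad);  // scatter grad into the sliced view
  return grad_input;
}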

// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
@@ -2802,7 +2816,7 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
}

Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
// since first backward takes care of scaling by frequency,
// since first backward takes care of scaling by frequency,
// we don't need to worry about it here.
auto gg_weight = grad.index_select(0, indices.reshape(-1));

8 changes: 7 additions & 1 deletion torch/csrc/autograd/FunctionsManual.h
@@ -121,9 +121,15 @@ at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out,
at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx);
at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad);
at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity);

Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v);
Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
c10::optional<int64_t> step);
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v);
std::tuple<Tensor, Tensor> triangular_solve_backward(
8 changes: 4 additions & 4 deletions torch/csrc/jit/frontend/tree_views.h
@@ -929,15 +929,15 @@ struct SliceExpr : public Expr {
Maybe<Expr> step() const {
return Maybe<Expr>(subtree(2));
}
Expr startOr(int alternative) const {
Expr startOr(int64_t alternative) const {
const auto startOption = start();
return startOption.present() ? startOption.get() : createInt(alternative);
}
Expr endOr(int alternative) const {
Expr endOr(int64_t alternative) const {
const auto endOption = end();
return endOption.present() ? endOption.get() : createInt(alternative);
}
Expr stepOr(int alternative) const {
Expr stepOr(int64_t alternative) const {
const auto stepOption = step();
return stepOption.present() ? stepOption.get() : createInt(alternative);
}
@@ -951,7 +951,7 @@
}

private:
Expr createInt(int value) const {
Expr createInt(int64_t value) const {
return Expr(Const::create(range(), c10::to_string(value)));
}
};
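Widening the alternative from int to int64_t matters because the frontend now passes the INT64_MAX sentinel as the fallback value, which a 32-bit int would silently truncate. An illustrative check, not from the patch:

#include <cstdint>
#include <cassert>

int main() {
  const int64_t sentinel = INT64_MAX;
  // On two's-complement targets the narrowing cast destroys the sentinel,
  // which is why SliceExpr must carry int64_t end to end.
  assert(static_cast<int32_t>(sentinel) == -1);
}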
73 changes: 59 additions & 14 deletions torch/csrc/jit/runtime/register_ops_utils.cpp
@@ -425,6 +425,60 @@ void listMulIntRight(Stack* stack) {

push(stack, std::move(ret));
}
// "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
// 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software
// Foundation; All Rights Reserved" Stolen (with appropriate modifications) from
// cpython repo Objects/sliceobject.c with comment: this is harder to get right
// than you might think
//
// This adjusts indices according to Python list semantics and returns the
// number of elements in the resulting list.
static int64_t PySlice_AdjustIndices(
int64_t length,
int64_t* start,
int64_t* stop,
int64_t step) {
TORCH_CHECK(step != 0, "List slice should have non-zero step")
TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds")

// Comes from PySlice_Unpack.
if (*start == INT64_MAX) {
*start = (step < 0) ? INT64_MAX : 0;
}
if (*stop == INT64_MAX) {
*stop = (step < 0) ? INT64_MIN : INT64_MAX;
}

// Comes from PySlice_AdjustIndices.
if (*start < 0) {
*start += length;
if (*start < 0) {
*start = (step < 0) ? -1 : 0;
}
} else if (*start >= length) {
*start = (step < 0) ? length - 1 : length;
}

if (*stop < 0) {
*stop += length;
if (*stop < 0) {
*stop = (step < 0) ? -1 : 0;
}
} else if (*stop >= length) {
*stop = (step < 0) ? length - 1 : length;
}

if (step < 0) {
if (*stop < *start) {
return (*start - *stop - 1) / (-step) + 1;
}
} else {
if (*start < *stop) {
return (*stop - *start - 1) / step + 1;
}
}
return 0;
}
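
A worked trace, as an illustrative check (not part of the patch): Python's l[::-1] on a five-element list reaches this function with both endpoints at their sentinels.

// Illustrative: length 5, step -1, start/stop omitted (sentinel INT64_MAX).
int64_t start = INT64_MAX, stop = INT64_MAX;
int64_t n = PySlice_AdjustIndices(5, &start, &stop, /*step=*/-1);
// start == 4, stop == -1, n == 5: the caller iterates indices 4, 3, 2, 1, 0.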

void listSlice(Stack* stack) {
int64_t step = pop(stack).to<int64_t>();
@@ -434,22 +434,13 @@

const int64_t list_size = list.size();

// clamp start and end to the bounds of the list
const auto normalized_start =
std::max((int64_t)0, normalizeIndex(start, list_size));
const auto normalized_end =
std::min(list_size, normalizeIndex(end, list_size));

c10::List<IValue> sliced_list = make_result_list<IValue>(list.elementType());
if (normalized_end <= normalized_start) {
// early exit if the slice is trivially empty
push(stack, std::move(sliced_list));
return;
}

sliced_list.reserve(normalized_end - normalized_start);
const int64_t num_values =
PySlice_AdjustIndices(list_size, &start, &end, step);
sliced_list.reserve(num_values);

for (auto i = normalized_start; i < normalized_end;) {
int64_t i = start;
for (int64_t j = 0; j < num_values; ++j) {
sliced_list.push_back(list.get(i));
i += step;
}
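With PySlice_AdjustIndices doing the clamping, listSlice now handles negative steps the way Python does; the old code (shown removed above) normalized start and end for positive steps only and could not walk a list backwards. As a worked example under the sketch above: a four-element list with step = -2 and both endpoints omitted adjusts to start = 3 with a count of 2, so the loop visits indices 3 and 1 — Python's l[::-2].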
2 changes: 1 addition & 1 deletion torch/csrc/jit/runtime/register_prim_ops.cpp
Expand Up @@ -603,7 +603,7 @@ RegisterOperators reg(
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA(
"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> t[]"),
"aten::slice.t(t[] l, int? start, int? end=9223372036854775807, int? step=1) -> t[]"),
listSlice,
aliasAnalysisFromSchema()),
OperatorGenerator(