Skip to content

Commit

Permalink
Striding for lists
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: b8f40e468b77d1b93324fe21c8a2d8d9e7248c7e
Pull Request resolved: #48719
  • Loading branch information
tugsbayasgalan committed Dec 3, 2020
1 parent 1c02be1 commit af48043
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 44 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/BatchingRegistrations.cpp
Expand Up @@ -314,7 +314,7 @@ Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes
return grad_physical.newLogicalFromPhysical(grad_input);
}

Tensor slice_batching_rule(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice_batching_rule(const Tensor& self, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, c10::optional<int64_t> step) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = self_physical.tensor().slice(dim_physical, start, end, step);
Expand Down
56 changes: 34 additions & 22 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1,4 +1,5 @@
#include <algorithm>
#include <cstdint>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
Expand Down Expand Up @@ -1070,36 +1071,47 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
}
}

Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice(const Tensor& self, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, c10::optional<int64_t> step) {
int64_t ndim = self.dim();
if (ndim == 0) {
TORCH_CHECK_INDEX(false, "slice() cannot be applied to a 0-dim tensor.");
}
dim = maybe_wrap_dim(dim, ndim);
auto sizes = self.sizes().vec();
auto strides = self.strides().vec();

// Handle optional parameters.
int64_t start_val = start.has_value() ? start.value() : INT64_MAX;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
int64_t step_val = step.has_value() ? step.value() : INT64_MAX;

// TODO: support negative strides
TORCH_CHECK(step > 0, "slice step must be positive");
if (start < 0) {
start += sizes[dim];
}
if (end < 0) {
end += sizes[dim];
}
if (start < 0) {
start = 0;
} else if (start >= sizes[dim]) {
start = sizes[dim];
}
if (end < start) {
end = start;
} else if (end >= sizes[dim]) {
end = sizes[dim];
}
auto storage_offset = self.storage_offset() + start * strides[dim];
auto len = end - start;
sizes[dim] = (len + step - 1) / step; // round-up
strides[dim] *= step;
TORCH_CHECK(step_val > 0, "slice step must be positive");

// INT64_MAX is the sentinel indicating the argument was omitted (default).
if (start_val == INT64_MAX) {
start_val = 0;
}
if (start_val < 0) {
start_val += sizes[dim];
}
if (end_val < 0) {
end_val += sizes[dim];
}
if (start_val < 0) {
start_val = 0;
} else if (start_val >= sizes[dim]) {
start_val = sizes[dim];
}
if (end_val < start_val) {
end_val = start_val;
} else if (end_val >= sizes[dim]) {
end_val = sizes[dim];
}
auto storage_offset = self.storage_offset() + start_val * strides[dim];
auto len = end_val - start_val;
sizes[dim] = (len + step_val - 1) / step_val; // round-up
strides[dim] *= step_val;
auto result = self.as_strided(sizes, strides, storage_offset);
namedinference::propagate_names(result, self);
return result;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
Expand Up @@ -3550,7 +3550,7 @@
variants: function, method
device_guard: False

- func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int? step=1) -> Tensor(a)
use_c10_dispatcher: full
variants: function, method
device_guard: False
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
Expand Up @@ -964,7 +964,7 @@
- name: sinh(Tensor self) -> Tensor
self: grad * self.cosh().conj()

- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int? step=1) -> Tensor(a)
self: slice_backward(grad, self.sizes(), dim, start, end, step)

- name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/jit/frontend/tree_views.h
Expand Up @@ -928,15 +928,15 @@ struct SliceExpr : public Expr {
Maybe<Expr> step() const {
return Maybe<Expr>(subtree(2));
}
Expr startOr(int alternative) const {
Expr startOr(int64_t alternative) const {
const auto startOption = start();
return startOption.present() ? startOption.get() : createInt(alternative);
}
Expr endOr(int alternative) const {
Expr endOr(int64_t alternative) const {
const auto endOption = end();
return endOption.present() ? endOption.get() : createInt(alternative);
}
Expr stepOr(int alternative) const {
Expr stepOr(int64_t alternative) const {
const auto stepOption = step();
return stepOption.present() ? stepOption.get() : createInt(alternative);
}
Expand All @@ -950,7 +950,7 @@ struct SliceExpr : public Expr {
}

private:
Expr createInt(int value) const {
Expr createInt(int64_t value) const {
return Expr(Const::create(range(), c10::to_string(value)));
}
};
Expand Down
68 changes: 54 additions & 14 deletions torch/csrc/jit/runtime/register_ops_utils.cpp
Expand Up @@ -425,6 +425,56 @@ void listMulIntRight(Stack* stack) {

push(stack, std::move(ret));
}
// "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020 Python Software Foundation; All Rights Reserved"
// Stolen (with appropriate modifications) from cpython repo
// Objects/sliceobject.c with comment:
// this is harder to get right than you might think
//
// This adjusts indexes according to python list semantics and returns number
// of elements in the resulting list.
// Adjusts *start and *stop in place according to Python list-slice
// semantics (resolving negative "from the end" indices and clamping to
// the list bounds, honoring the sign of step) and returns the number of
// elements in the resulting slice.
//
// INT64_MAX in *start / *stop acts as the sentinel for an omitted
// argument (it matches the schema default 9223372036854775807).
static int64_t PySlice_AdjustIndices(
    int64_t length, int64_t* start, int64_t* stop, int64_t step) {
  TORCH_CHECK(step != 0, "List slice should have non-zero step")
  // step == INT64_MIN is rejected because -step (used for the element
  // count below) would overflow; -INT64_MIN is not representable.
  TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds")

  // Comes from PySlice_Unpack: replace the "omitted" sentinel with the
  // extreme value that makes the slice span the whole list for this
  // step direction.
  if (*start == INT64_MAX) {
    *start = (step < 0) ? INT64_MAX : 0;
  }
  if (*stop == INT64_MAX) {
    *stop = (step < 0) ? INT64_MIN : INT64_MAX;
  }

  // Comes from PySlice_AdjustIndices: resolve negative indices relative
  // to the end, then clamp into the valid range.  For a negative step
  // the valid range is [-1, length - 1], since iteration runs backwards
  // and stop is exclusive.
  // (Note: *stop == INT64_MIN plus a non-negative length cannot
  // overflow — the sum only moves toward zero.)
  if (*start < 0) {
    *start += length;
    if (*start < 0) {
      *start = (step < 0) ? -1 : 0;
    }
  } else if (*start >= length) {
    *start = (step < 0) ? length - 1 : length;
  }

  if (*stop < 0) {
    *stop += length;
    if (*stop < 0) {
      *stop = (step < 0) ? -1 : 0;
    }
  } else if (*stop >= length) {
    *stop = (step < 0) ? length - 1 : length;
  }

  // Element count: ceil(|*stop - *start| / |step|) when the slice is
  // non-empty in the direction of step, otherwise 0.
  if (step < 0) {
    if (*stop < *start) {
      return (*start - *stop - 1) / (-step) + 1;
    }
  } else {
    if (*start < *stop) {
      return (*stop - *start - 1) / step + 1;
    }
  }
  return 0;
}

void listSlice(Stack* stack) {
int64_t step = pop(stack).to<int64_t>();
Expand All @@ -434,22 +484,12 @@ void listSlice(Stack* stack) {

const int64_t list_size = list.size();

// clamp start and end to the bounds of the list
const auto normalized_start =
std::max((int64_t)0, normalizeIndex(start, list_size));
const auto normalized_end =
std::min(list_size, normalizeIndex(end, list_size));

c10::List<IValue> sliced_list = make_result_list<IValue>(list.elementType());
if (normalized_end <= normalized_start) {
// early exit if the slice is trivially empty
push(stack, std::move(sliced_list));
return;
}

sliced_list.reserve(normalized_end - normalized_start);
const int64_t num_values = PySlice_AdjustIndices(list_size, &start, &end, step);
sliced_list.reserve(num_values);

for (auto i = normalized_start; i < normalized_end;) {
int i = start;
for (int j = 0; j < num_values; ++j) {
sliced_list.push_back(list.get(i));
i += step;
}
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/runtime/register_prim_ops.cpp
Expand Up @@ -603,7 +603,7 @@ RegisterOperators reg(
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA(
"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> t[]"),
"aten::slice.t(t[] l, int? start, int? end=9223372036854775807, int? step=1) -> t[]"),
listSlice,
aliasAnalysisFromSchema()),
OperatorGenerator(
Expand Down

0 comments on commit af48043

Please sign in to comment.