Striding for lists Part 1 (#48719)
Summary:
Pull Request resolved: #48719

Attempt to break PR #33019 into two parts. As per our discussion with eellison, the first part makes sure our aten::slice operator takes optional parameters for begin/step/end. This will help with refactoring ir_emitter.cpp for generic handling of list and slice striding. Once this PR is merged, we will submit a second PR with the compiler change.
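
As a rough usage sketch (not part of this diff; names and values are illustrative), the new schema lets C++ callers omit either bound by passing c10::nullopt:

#include <ATen/ATen.h>

int main() {
  at::Tensor t = at::arange(10);
  // Roughly t[2:] in Python: end is omitted, so the slice runs to the end of dim 0.
  at::Tensor a = at::slice(t, /*dim=*/0, /*start=*/2, /*end=*/c10::nullopt, /*step=*/1);
  // Roughly t[::3]: both bounds omitted, only a step is given.
  at::Tensor b = at::slice(t, /*dim=*/0, /*start=*/c10::nullopt, /*end=*/c10::nullopt, /*step=*/3);
  return 0;
}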

Test Plan:
None for this PR, but new tests will be added for the second part.

Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D25929902

fbshipit-source-id: 5385df04e6d61ded0699b09bbfec6691396b56c3
tugsbayasgalan authored and facebook-github-bot committed Jan 19, 2021
1 parent 1154a85 commit 1a38fa9
Showing 12 changed files with 104 additions and 65 deletions.
7 changes: 6 additions & 1 deletion aten/src/ATen/BatchingRegistrations.cpp
@@ -359,7 +359,12 @@ Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes
return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor slice_batching_rule(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice_batching_rule(
const Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = self_physical.tensor().slice(dim_physical, start, end, step);
69 changes: 42 additions & 27 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1,24 +1,24 @@
#include <algorithm>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/ExpandUtils.h>
#include <ATen/InferSize.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>
#include <ATen/native/Copy.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/cpu/CatKernel.h>
#include <ATen/native/Copy.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/quantized/QTensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>
#include <algorithm>
#include <cstdint>
#include <vector>

namespace at {
namespace native {
@@ -1190,35 +1190,50 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
}
}

Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice(
const Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
int64_t ndim = self.dim();
if (ndim == 0) {
TORCH_CHECK_INDEX(false, "slice() cannot be applied to a 0-dim tensor.");
}
dim = maybe_wrap_dim(dim, ndim);
auto sizes = self.sizes().vec();
auto strides = self.strides().vec();

// handle optional parameters
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;

// TODO: support negative strides
TORCH_CHECK(step > 0, "slice step must be positive");
if (start < 0) {
start += sizes[dim];

// INT64_MAX stands for default value.
if (start_val == INT64_MAX) {
start_val = 0;
}
if (start_val < 0) {
start_val += sizes[dim];
}
if (end < 0) {
end += sizes[dim];
if (end_val < 0) {
end_val += sizes[dim];
}
if (start < 0) {
start = 0;
} else if (start >= sizes[dim]) {
start = sizes[dim];
if (start_val < 0) {
start_val = 0;
} else if (start_val >= sizes[dim]) {
start_val = sizes[dim];
}
if (end < start) {
end = start;
} else if (end >= sizes[dim]) {
end = sizes[dim];
if (end_val < start_val) {
end_val = start_val;
} else if (end_val >= sizes[dim]) {
end_val = sizes[dim];
}
auto storage_offset = self.storage_offset() + start * strides[dim];
auto len = end - start;
sizes[dim] = (len + step - 1) / step; // round-up
auto storage_offset = self.storage_offset() + start_val * strides[dim];
auto len = end_val - start_val;
sizes[dim] = (len + step - 1) / step; // round-up
strides[dim] *= step;
auto result = self.as_strided(sizes, strides, storage_offset);
namedinference::propagate_names(result, self);
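
A minimal, standalone sketch of the index math above (assuming the same clamping rules; illustrative only, not the actual ATen kernel):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Resolve optional start/end into the resulting length of the sliced dimension.
int64_t sliced_length(int64_t dim_size,
                      std::optional<int64_t> start,
                      std::optional<int64_t> end,
                      int64_t step) {
  int64_t start_val = start.value_or(0);
  int64_t end_val = end.value_or(std::numeric_limits<int64_t>::max());
  if (start_val == std::numeric_limits<int64_t>::max()) {
    start_val = 0;  // INT64_MAX stands for the default value
  }
  if (start_val < 0) start_val += dim_size;
  if (end_val < 0) end_val += dim_size;
  start_val = std::min(std::max<int64_t>(start_val, 0), dim_size);
  end_val = std::min(std::max(end_val, start_val), dim_size);
  return (end_val - start_val + step - 1) / step;  // round up
}

int main() {
  assert(sliced_length(10, -3, std::nullopt, 2) == 2);            // t[-3::2] on a size-10 dim
  assert(sliced_length(10, std::nullopt, std::nullopt, 1) == 10); // full slice
  return 0;
}
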
3 changes: 2 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -3426,7 +3426,8 @@
variants: function, method
device_guard: False

- func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)
use_c10_dispatcher: full
variants: function, method
device_guard: False
dispatch:
3 changes: 2 additions & 1 deletion aten/src/ATen/templates/RegisterSchema.cpp
@@ -23,7 +23,8 @@ TORCH_LIBRARY(aten, m) {
// String Ops
// Implementations located in torch/csrc/jit/runtime/register_prim_ops.cpp
m.def(TORCH_SELECTIVE_SCHEMA("aten::splitlines(str self, bool keepends=False) -> str[]"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::slice.str(str string, int start, int end=9223372036854775807, int step=1) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA(
"aten::slice.str(str string, int? start=0, int? end=9223372036854775807, int step=1) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::isupper(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::islower(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::capitalize(str self) -> str"));
4 changes: 2 additions & 2 deletions tools/autograd/derivatives.yaml
@@ -980,8 +980,8 @@
- name: sinh(Tensor self) -> Tensor
self: grad * self.cosh().conj()

- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
self: slice_backward(grad, self.sizes(), dim, start, end, step)
- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)
self: slice_backward_wrapper(grad, self.sizes(), dim, start, end, step)

- name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
self: slogdet_backward(grad, self, sign, logabsdet)
13 changes: 13 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -1921,6 +1921,19 @@ Tensor elu_double_backward(
}
}

Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef& input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
auto start_val = start.has_value() ? start.value() : 0;
auto end_val = end.has_value() ? end.value() : INT64_MAX;

return slice_backward(grad, input_sizes, dim, start_val, end_val, step);
}

// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
7 changes: 7 additions & 0 deletions torch/csrc/autograd/FunctionsManual.h
@@ -134,6 +134,13 @@ at::Tensor elu_double_backward(const Tensor& grad, const Tensor& grad_output, Sc

Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v);
Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef& input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step);
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v);
std::tuple<Tensor, Tensor> triangular_solve_backward(
6 changes: 5 additions & 1 deletion torch/csrc/jit/frontend/tracer.cpp
@@ -562,7 +562,11 @@ void addInputs(Node* n, const char* name, int64_t value) {
}

void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
if (value) {
using ArgumentStash = jit::tracer::ArgumentStash;
if (ArgumentStash::hasValue(name)) {
Value* v = ArgumentStash::popValue(name);
n->addInput(v);
} else if (value) {
detail::genericAddInput(n, *value);
} else {
Graph* g = n->owningGraph();
8 changes: 4 additions & 4 deletions torch/csrc/jit/frontend/tree_views.h
@@ -965,15 +965,15 @@ struct SliceExpr : public Expr {
Maybe<Expr> step() const {
return Maybe<Expr>(subtree(2));
}
Expr startOr(int alternative) const {
Expr startOr(int64_t alternative) const {
const auto startOption = start();
return startOption.present() ? startOption.get() : createInt(alternative);
}
Expr endOr(int alternative) const {
Expr endOr(int64_t alternative) const {
const auto endOption = end();
return endOption.present() ? endOption.get() : createInt(alternative);
}
Expr stepOr(int alternative) const {
Expr stepOr(int64_t alternative) const {
const auto stepOption = step();
return stepOption.present() ? stepOption.get() : createInt(alternative);
}
@@ -987,7 +987,7 @@
}

private:
Expr createInt(int value) const {
Expr createInt(int64_t value) const {
return Expr(Const::create(range(), c10::to_string(value)));
}
};
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/shape_analysis.cpp
@@ -883,7 +883,7 @@ class ShapePropagator {
"aten::trunc(Tensor self) -> Tensor",
"aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
"aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
"aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor",
"aten::slice(Tensor self, int dim, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor",
"aten::alias(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {
20 changes: 6 additions & 14 deletions torch/csrc/jit/runtime/register_ops_utils.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>

namespace torch {
namespace jit {
@@ -434,22 +435,13 @@ void listSlice(Stack* stack) {

const int64_t list_size = list.size();

// clamp start and end to the bounds of the list
const auto normalized_start =
std::max((int64_t)0, normalizeIndex(start, list_size));
const auto normalized_end =
std::min(list_size, normalizeIndex(end, list_size));

c10::List<IValue> sliced_list = make_result_list<IValue>(list.elementType());
if (normalized_end <= normalized_start) {
// early exit if the slice is trivially empty
push(stack, std::move(sliced_list));
return;
}

sliced_list.reserve(normalized_end - normalized_start);
const int64_t num_values =
slice_indices_adjust(list_size, &start, &end, step);
sliced_list.reserve(num_values);

for (auto i = normalized_start; i < normalized_end;) {
int i = start;
for (int j = 0; j < num_values; ++j) {
sliced_list.push_back(list.get(i));
i += step;
}
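
For reference, slice_indices_adjust (used above) clamps start/end and returns the number of selected elements. A rough sketch of CPython-style slice index adjustment, which this helper appears to follow (assumption: illustrative only, not the actual PyTorch implementation):

#include <cassert>
#include <cstdint>

// Clamp *start/*stop into range for the given length (handling negative
// indices and steps) and return how many elements the slice selects.
int64_t adjust_slice_indices(int64_t length, int64_t* start, int64_t* stop, int64_t step) {
  if (*start < 0) {
    *start += length;
    if (*start < 0) *start = (step < 0) ? -1 : 0;
  } else if (*start >= length) {
    *start = (step < 0) ? length - 1 : length;
  }
  if (*stop < 0) {
    *stop += length;
    if (*stop < 0) *stop = (step < 0) ? -1 : 0;
  } else if (*stop >= length) {
    *stop = (step < 0) ? length - 1 : length;
  }
  if (step < 0) {
    if (*stop < *start) return (*start - *stop - 1) / (-step) + 1;
  } else {
    if (*start < *stop) return (*stop - *start - 1) / step + 1;
  }
  return 0;
}

int main() {
  int64_t start = 1, stop = 9223372036854775807LL;  // "end omitted" sentinel
  // l[1::2] on a 5-element list selects indices 1 and 3.
  assert(adjust_slice_indices(5, &start, &stop, 2) == 2);
  return 0;
}
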
27 changes: 14 additions & 13 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -1,6 +1,8 @@
#include <c10/util/Optional.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <torch/library.h>

#include <algorithm>
@@ -29,23 +31,22 @@ namespace {

std::string stringSlice(
std::string string,
int64_t start,
int64_t end,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
TORCH_CHECK(step == 1, "Slicing a string only supports step=1");

const int64_t size = string.size();
int64_t start_val = start.has_value() ? start.value() : INT64_MAX;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;

// Clamp start and end to the bounds of the list
start = std::max(int64_t(0), normalizeIndex(start, size));
end = std::min(size, normalizeIndex(end, size));
const int64_t num_vals =
slice_indices_adjust(string.size(), &start_val, &end_val, step);

if (end <= start) {
// Slice is empty
return std::string("");
int64_t i = start_val;
std::string result = "";
for (int64_t j = 0; j < num_vals; j++) {
result += string[i];
i += step;
}

std::string result(string.begin() + start, string.begin() + end);
return result;
}

@@ -603,7 +604,7 @@ RegisterOperators reg(
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA(
"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> t[]"),
"aten::slice.t(t[] l, int? start=0, int? end=9223372036854775807, int step=1) -> t[]"),
listSlice,
aliasAnalysisFromSchema()),
OperatorGenerator(
