Striding for lists
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 74bc6d79a75331aab5452d9034f8753891fe3ca3
Pull Request resolved: #48719
tugsbayasgalan committed Dec 14, 2020
1 parent 7d406b4 commit 41ea666
Showing 10 changed files with 95 additions and 58 deletions.
7 changes: 6 additions & 1 deletion aten/src/ATen/BatchingRegistrations.cpp
@@ -340,7 +340,12 @@ Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes
return grad_physical.newLogicalFromPhysical(grad_input);
}

Tensor slice_batching_rule(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice_batching_rule(
const Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
c10::optional<int64_t> step) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = self_physical.tensor().slice(dim_physical, start, end, step);
80 changes: 48 additions & 32 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1,23 +1,23 @@
#include <algorithm>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/ExpandUtils.h>
#include <ATen/InferSize.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <ATen/native/Copy.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/cpu/CatKernel.h>
#include <ATen/native/Copy.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/quantized/QTensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <algorithm>
#include <cstdint>
#include <vector>

namespace at {
namespace native {
@@ -1070,36 +1070,52 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
}
}

Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice(
const Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
c10::optional<int64_t> step) {
int64_t ndim = self.dim();
if (ndim == 0) {
TORCH_CHECK_INDEX(false, "slice() cannot be applied to a 0-dim tensor.");
}
dim = maybe_wrap_dim(dim, ndim);
auto sizes = self.sizes().vec();
auto strides = self.strides().vec();

// handle optional parameters
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
int64_t step_val = step.has_value() ? step.value() : 1;

// TODO: support negative strides
TORCH_CHECK(step > 0, "slice step must be positive");
if (start < 0) {
start += sizes[dim];
}
if (end < 0) {
end += sizes[dim];
}
if (start < 0) {
start = 0;
} else if (start >= sizes[dim]) {
start = sizes[dim];
}
if (end < start) {
end = start;
} else if (end >= sizes[dim]) {
end = sizes[dim];
}
auto storage_offset = self.storage_offset() + start * strides[dim];
auto len = end - start;
sizes[dim] = (len + step - 1) / step; // round-up
strides[dim] *= step;
TORCH_CHECK(step_val > 0, "slice step must be positive");

// INT64_MAX stands for the default value.
if (start_val == INT64_MAX) {
start_val = 0;
}
if (start_val < 0) {
start_val += sizes[dim];
}
if (end_val < 0) {
end_val += sizes[dim];
}
if (start_val < 0) {
start_val = 0;
} else if (start_val >= sizes[dim]) {
start_val = sizes[dim];
}
if (end_val < start_val) {
end_val = start_val;
} else if (end_val >= sizes[dim]) {
end_val = sizes[dim];
}
auto storage_offset = self.storage_offset() + start_val * strides[dim];
auto len = end_val - start_val;
sizes[dim] = (len + step_val - 1) / step_val; // round-up
strides[dim] *= step_val;
auto result = self.as_strided(sizes, strides, storage_offset);
namedinference::propagate_names(result, self);
return result;
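For reference, a minimal usage sketch of the new optional arguments from the C++ side (this assumes the generated at::slice overload exposes start/end/step as c10::optional<int64_t>, as the schema change below suggests; the tensor and the values are illustrative only):

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

void slice_examples() {
  auto t = at::arange(10);  // [0, 1, ..., 9]
  // All three arguments omitted: defaults to start=0, end clamped to the
  // dimension size, step=1.
  auto whole = at::slice(t, /*dim=*/0, c10::nullopt, c10::nullopt, c10::nullopt);
  // Step with omitted bounds: every other element -> [0, 2, 4, 6, 8].
  auto strided = at::slice(t, 0, c10::nullopt, c10::nullopt, 2);
  // Negative start wraps around, per the clamping logic above -> [7, 8, 9].
  auto tail = at::slice(t, 0, -3, c10::nullopt, c10::nullopt);
}
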
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -3569,7 +3569,7 @@
variants: function, method
device_guard: False

- func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int? step=1) -> Tensor(a)
use_c10_dispatcher: full
variants: function, method
device_guard: False
4 changes: 2 additions & 2 deletions tools/autograd/derivatives.yaml
@@ -964,8 +964,8 @@
- name: sinh(Tensor self) -> Tensor
self: grad * self.cosh().conj()

- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
self: slice_backward(grad, self.sizes(), dim, start, end, step)
- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int? step=1) -> Tensor(a)
self: slice_backward_wrapper(grad, self.sizes(), dim, start, end, step)

- name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
self: slogdet_backward(grad, self, sign, logabsdet)
16 changes: 15 additions & 1 deletion torch/csrc/autograd/FunctionsManual.cpp
@@ -1821,6 +1821,20 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
}
}

Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
c10::optional<int64_t> step) {
auto start_val = start.has_value() ? start.value() : 0;
auto end_val = end.has_value() ? end.value() : INT64_MAX;
auto step_val = step.has_value() ? step.value() : 1;

return slice_backward(grad, input_sizes, dim, start_val, end_val, step_val);
}

// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
@@ -2771,7 +2785,7 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
}

Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
// since first backward takes care of scaling by frequency,
// since first backward takes care of scaling by frequency,
// we don't need to worry about it here.
auto gg_weight = grad.index_select(0, indices.reshape(-1));

8 changes: 7 additions & 1 deletion torch/csrc/autograd/FunctionsManual.h
@@ -121,9 +121,15 @@ at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out,
at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx);
at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad);
at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity);

Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v);
Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
c10::optional<int64_t> step);
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v);
std::tuple<Tensor, Tensor> triangular_solve_backward(
6 changes: 5 additions & 1 deletion torch/csrc/jit/frontend/tracer.cpp
@@ -547,7 +547,11 @@ void addInputs(Node* n, const char* name, int64_t value) {
}

void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
if (value) {
using ArgumentStash = jit::tracer::ArgumentStash;
if (ArgumentStash::hasValue(name)) {
Value* v = ArgumentStash::popValue(name);
n->addInput(v);
} else if (value) {
detail::genericAddInput(n, *value);
} else {
Graph* g = n->owningGraph();
8 changes: 4 additions & 4 deletions torch/csrc/jit/frontend/tree_views.h
@@ -929,15 +929,15 @@ struct SliceExpr : public Expr {
Maybe<Expr> step() const {
return Maybe<Expr>(subtree(2));
}
Expr startOr(int alternative) const {
Expr startOr(int64_t alternative) const {
const auto startOption = start();
return startOption.present() ? startOption.get() : createInt(alternative);
}
Expr endOr(int alternative) const {
Expr endOr(int64_t alternative) const {
const auto endOption = end();
return endOption.present() ? endOption.get() : createInt(alternative);
}
Expr stepOr(int alternative) const {
Expr stepOr(int64_t alternative) const {
const auto stepOption = step();
return stepOption.present() ? stepOption.get() : createInt(alternative);
}
@@ -951,7 +951,7 @@
}

private:
Expr createInt(int value) const {
Expr createInt(int64_t value) const {
return Expr(Const::create(range(), c10::to_string(value)));
}
};
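The int -> int64_t widening matters because the frontend can plug the schema's default end sentinel, 9223372036854775807, into these alternatives; that constant is exactly INT64_MAX and would not fit in a 32-bit int. A standalone illustration (not code from this PR):

#include <cstdint>
#include <limits>

static_assert(std::numeric_limits<int64_t>::max() == 9223372036854775807LL,
              "the schema default end=9223372036854775807 is exactly INT64_MAX");
// Passing that sentinel through an `int` parameter would be a narrowing
// conversion and would lose the "no explicit end" meaning.
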
20 changes: 6 additions & 14 deletions torch/csrc/jit/runtime/register_ops_utils.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include "jit/runtime/slice_indices_adjust.h"

namespace torch {
namespace jit {
@@ -434,22 +435,13 @@ void listSlice(Stack* stack) {

const int64_t list_size = list.size();

// clamp start and end to the bounds of the list
const auto normalized_start =
std::max((int64_t)0, normalizeIndex(start, list_size));
const auto normalized_end =
std::min(list_size, normalizeIndex(end, list_size));

c10::List<IValue> sliced_list = make_result_list<IValue>(list.elementType());
if (normalized_end <= normalized_start) {
// early exit if the slice is trivially empty
push(stack, std::move(sliced_list));
return;
}

sliced_list.reserve(normalized_end - normalized_start);
const int64_t num_values =
slice_indices_adjust(list_size, &start, &end, step);
sliced_list.reserve(num_values);

for (auto i = normalized_start; i < normalized_end;) {
int64_t i = start;
for (int64_t j = 0; j < num_values; ++j) {
sliced_list.push_back(list.get(i));
i += step;
}
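For context, slice_indices_adjust mirrors CPython's PySlice_AdjustIndices: it wraps negative indices, clamps start and end to the valid range for the sign of step, and returns how many elements the slice yields. A self-contained sketch of that behavior (a hypothetical standalone helper under those assumptions, not the exact implementation this PR adds):

#include <cstdint>

// Assumes step != 0 (the caller is expected to reject a zero step).
int64_t adjust_slice_indices(int64_t length, int64_t* start, int64_t* end, int64_t step) {
  // Wrap negative indices, then clamp to the valid range for this step sign.
  if (*start < 0) {
    *start += length;
    if (*start < 0) *start = (step < 0) ? -1 : 0;
  } else if (*start >= length) {
    *start = (step < 0) ? length - 1 : length;
  }
  if (*end < 0) {
    *end += length;
    if (*end < 0) *end = (step < 0) ? -1 : 0;
  } else if (*end >= length) {
    *end = (step < 0) ? length - 1 : length;
  }
  // Number of elements the adjusted slice produces.
  if (step < 0) {
    return (*end < *start) ? (*start - *end - 1) / (-step) + 1 : 0;
  }
  return (*start < *end) ? (*end - *start - 1) / step + 1 : 0;
}

// Example: length=6, start=0, end=INT64_MAX, step=2 adjusts end to 6 and
// returns 3, so the loop above visits indices 0, 2, and 4.
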
2 changes: 1 addition & 1 deletion torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -603,7 +603,7 @@ RegisterOperators reg(
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA(
"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> t[]"),
"aten::slice.t(t[] l, int? start, int? end=9223372036854775807, int? step=1) -> t[]"),
listSlice,
aliasAnalysisFromSchema()),
OperatorGenerator(
