Change list striding parameters to take optional integer (preceding work for #49352) #48719

Closed
7 changes: 6 additions & 1 deletion aten/src/ATen/BatchingRegistrations.cpp
@@ -359,7 +359,12 @@ Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes
return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor slice_batching_rule(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice_batching_rule(
const Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = self_physical.tensor().slice(dim_physical, start, end, step);
69 changes: 42 additions & 27 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1,24 +1,24 @@
#include <algorithm>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/ExpandUtils.h>
#include <ATen/InferSize.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>
#include <ATen/native/Copy.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/cpu/CatKernel.h>
#include <ATen/native/Copy.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/quantized/QTensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>
#include <algorithm>
#include <cstdint>
#include <vector>

namespace at {
namespace native {
@@ -1190,35 +1190,50 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
}
}

Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
Tensor slice(
const Tensor& self,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
int64_t ndim = self.dim();
if (ndim == 0) {
TORCH_CHECK_INDEX(false, "slice() cannot be applied to a 0-dim tensor.");
}
dim = maybe_wrap_dim(dim, ndim);
auto sizes = self.sizes().vec();
auto strides = self.strides().vec();

// handle optional parameters
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;

// TODO: support negative strides
TORCH_CHECK(step > 0, "slice step must be positive");
if (start < 0) {
start += sizes[dim];

// INT64_MAX stands for default value.
if (start_val == INT64_MAX) {
start_val = 0;
}
if (start_val < 0) {
start_val += sizes[dim];
}
if (end < 0) {
end += sizes[dim];
if (end_val < 0) {
end_val += sizes[dim];
}
if (start < 0) {
start = 0;
} else if (start >= sizes[dim]) {
start = sizes[dim];
if (start_val < 0) {
start_val = 0;
} else if (start_val >= sizes[dim]) {
start_val = sizes[dim];
}
if (end < start) {
end = start;
} else if (end >= sizes[dim]) {
end = sizes[dim];
if (end_val < start_val) {
end_val = start_val;
} else if (end_val >= sizes[dim]) {
end_val = sizes[dim];
}
auto storage_offset = self.storage_offset() + start * strides[dim];
auto len = end - start;
sizes[dim] = (len + step - 1) / step; // round-up
auto storage_offset = self.storage_offset() + start_val * strides[dim];
auto len = end_val - start_val;
sizes[dim] = (len + step - 1) / step; // round-up
strides[dim] *= step;
auto result = self.as_strided(sizes, strides, storage_offset);
namedinference::propagate_names(result, self);
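For reference, a minimal sketch of the bound normalization the new optional-argument slice() performs; the helper name and the std::pair return are illustrative, not part of the patch. A missing start defaults to 0, a missing end to INT64_MAX, negative values wrap around the dimension size, and both bounds are clamped to [0, size].

#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
#include <c10/util/Optional.h>

// Illustrative only: mirrors the clamping done in at::native::slice above.
std::pair<int64_t, int64_t> normalize_slice_bounds(
    c10::optional<int64_t> start,
    c10::optional<int64_t> end,
    int64_t size) {
  int64_t start_val = start.value_or(0);
  int64_t end_val = end.value_or(std::numeric_limits<int64_t>::max());
  // INT64_MAX stands for "default" start, matching the patch.
  if (start_val == std::numeric_limits<int64_t>::max()) start_val = 0;
  if (start_val < 0) start_val += size;
  if (end_val < 0) end_val += size;
  start_val = std::min(std::max<int64_t>(start_val, 0), size);
  end_val = std::min(std::max(end_val, start_val), size);
  return {start_val, end_val};
}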
3 changes: 2 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -3426,7 +3426,8 @@
variants: function, method
device_guard: False

- func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
- func: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)
use_c10_dispatcher: full
variants: function, method
device_guard: False
dispatch:
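Assuming the codegen maps int? to c10::optional<int64_t> in the generated C++ API, a hedged caller-side sketch of what the updated schema allows:

// start/end omitted: slice over the whole dimension with step 2
// (expected values 0, 2, 4, 6, 8, assuming Python-style slicing semantics).
at::Tensor x = at::arange(10);
at::Tensor every_other = x.slice(/*dim=*/0, /*start=*/c10::nullopt,
                                 /*end=*/c10::nullopt, /*step=*/2);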
3 changes: 2 additions & 1 deletion aten/src/ATen/templates/RegisterSchema.cpp
@@ -23,7 +23,8 @@ TORCH_LIBRARY(aten, m) {
// String Ops
// Implementations located in torch/csrc/jit/runtime/register_prim_ops.cpp
m.def(TORCH_SELECTIVE_SCHEMA("aten::splitlines(str self, bool keepends=False) -> str[]"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::slice.str(str string, int start, int end=9223372036854775807, int step=1) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA(
"aten::slice.str(str string, int? start=0, int? end=9223372036854775807, int step=1) -> str"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::isupper(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::islower(str self) -> bool"));
m.def(TORCH_SELECTIVE_SCHEMA("aten::capitalize(str self) -> str"));
4 changes: 2 additions & 2 deletions tools/autograd/derivatives.yaml
@@ -980,8 +980,8 @@
- name: sinh(Tensor self) -> Tensor
self: grad * self.cosh().conj()

- name: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
self: slice_backward(grad, self.sizes(), dim, start, end, step)
- name: slice.Tensor(Tensor(a) self, int dim=0, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor(a)
self: slice_backward_wrapper(grad, self.sizes(), dim, start, end, step)

- name: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
self: slogdet_backward(grad, self, sign, logabsdet)
13 changes: 13 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -1906,6 +1906,19 @@ Tensor elu_double_backward(
}
}

Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef& input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
auto start_val = start.has_value() ? start.value() : 0;
auto end_val = end.has_value() ? end.value() : INT64_MAX;

return slice_backward(grad, input_sizes, dim, start_val, end_val, step);
}

// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
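For context, a hedged sketch of what the wrapped slice_backward is expected to compute, under the assumption that it zero-fills the original shape and scatters the gradient back into the sliced region; the function name below is illustrative, not the actual at::native implementation.

// Illustrative sketch, assuming slice_backward's zero-fill-and-copy behavior.
at::Tensor slice_backward_sketch(
    const at::Tensor& grad,
    at::IntArrayRef input_sizes,
    int64_t dim,
    int64_t start,
    int64_t end,
    int64_t step) {
  auto grad_input = at::zeros(input_sizes, grad.options());
  grad_input.slice(dim, start, end, step).copy_(grad);
  return grad_input;
}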
7 changes: 7 additions & 0 deletions torch/csrc/autograd/FunctionsManual.h
@@ -131,6 +131,13 @@ at::Tensor elu_double_backward(const Tensor& grad, const Tensor& grad_output, Sc

Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v);
Tensor slice_backward_wrapper(
const at::Tensor& grad,
const c10::IntArrayRef& input_sizes,
int64_t dim,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step);
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v);
std::tuple<Tensor, Tensor> triangular_solve_backward(
6 changes: 5 additions & 1 deletion torch/csrc/jit/frontend/tracer.cpp
@@ -562,7 +562,11 @@ void addInputs(Node* n, const char* name, int64_t value) {
}

void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
if (value) {
using ArgumentStash = jit::tracer::ArgumentStash;
if (ArgumentStash::hasValue(name)) {
Value* v = ArgumentStash::popValue(name);
n->addInput(v);
} else if (value) {
detail::genericAddInput(n, *value);
} else {
Graph* g = n->owningGraph();
8 changes: 4 additions & 4 deletions torch/csrc/jit/frontend/tree_views.h
@@ -965,15 +965,15 @@ struct SliceExpr : public Expr {
Maybe<Expr> step() const {
return Maybe<Expr>(subtree(2));
}
Expr startOr(int alternative) const {
Expr startOr(int64_t alternative) const {
const auto startOption = start();
return startOption.present() ? startOption.get() : createInt(alternative);
}
Expr endOr(int alternative) const {
Expr endOr(int64_t alternative) const {
const auto endOption = end();
return endOption.present() ? endOption.get() : createInt(alternative);
}
Expr stepOr(int alternative) const {
Expr stepOr(int64_t alternative) const {
const auto stepOption = step();
return stepOption.present() ? stepOption.get() : createInt(alternative);
}
@@ -987,7 +987,7 @@
}

private:
Expr createInt(int value) const {
Expr createInt(int64_t value) const {
return Expr(Const::create(range(), c10::to_string(value)));
}
};
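The int -> int64_t widening matters because the default end sentinel 9223372036854775807 does not fit in a 32-bit int. A hypothetical call site (slice here stands for a SliceExpr already in scope):

#include <limits>
// Hypothetical usage: the INT64_MAX sentinel used as the default end bound
// now reaches createInt without narrowing.
Expr end = slice.endOr(std::numeric_limits<int64_t>::max());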
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/shape_analysis.cpp
@@ -883,7 +883,7 @@ class ShapePropagator {
"aten::trunc(Tensor self) -> Tensor",
"aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
"aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
"aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor",
"aten::slice(Tensor self, int dim, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor",
"aten::alias(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {
20 changes: 6 additions & 14 deletions torch/csrc/jit/runtime/register_ops_utils.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>

namespace torch {
namespace jit {
@@ -434,22 +435,13 @@ void listSlice(Stack* stack) {

const int64_t list_size = list.size();

// clamp start and end to the bounds of the list
const auto normalized_start =
std::max((int64_t)0, normalizeIndex(start, list_size));
const auto normalized_end =
std::min(list_size, normalizeIndex(end, list_size));

c10::List<IValue> sliced_list = make_result_list<IValue>(list.elementType());
if (normalized_end <= normalized_start) {
// early exit if the slice is trivially empty
push(stack, std::move(sliced_list));
return;
}

sliced_list.reserve(normalized_end - normalized_start);
const int64_t num_values =
slice_indices_adjust(list_size, &start, &end, step);
sliced_list.reserve(num_values);

for (auto i = normalized_start; i < normalized_end;) {
int i = start;
for (int j = 0; j < num_values; ++j) {
sliced_list.push_back(list.get(i));
i += step;
}
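slice_indices_adjust is declared in torch/csrc/jit/runtime/slice_indices_adjust.h and is assumed to follow CPython's PySlice_AdjustIndices semantics: clamp *start and *end in place to the container's bounds and return the number of elements the slice yields. A minimal sketch under that assumption (illustrative name, not the real implementation):

#include <cstdint>

int64_t slice_indices_adjust_sketch(
    int64_t length, int64_t* start, int64_t* end, int64_t step) {
  // Wrap negative indices and clamp to [0, length] (or [-1, length - 1] for
  // negative steps), CPython-style.
  if (*start < 0) {
    *start += length;
    if (*start < 0) *start = (step < 0) ? -1 : 0;
  } else if (*start >= length) {
    *start = (step < 0) ? length - 1 : length;
  }
  if (*end < 0) {
    *end += length;
    if (*end < 0) *end = (step < 0) ? -1 : 0;
  } else if (*end >= length) {
    *end = (step < 0) ? length - 1 : length;
  }
  // Number of elements the adjusted slice produces.
  if (step < 0) {
    if (*end < *start) return (*start - *end - 1) / (-step) + 1;
  } else {
    if (*start < *end) return (*end - *start - 1) / step + 1;
  }
  return 0;
}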
27 changes: 14 additions & 13 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -1,6 +1,8 @@
#include <c10/util/Optional.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
#include <torch/library.h>

#include <algorithm>
@@ -29,23 +31,22 @@ namespace {

std::string stringSlice(
std::string string,
int64_t start,
int64_t end,
c10::optional<int64_t> start,
c10::optional<int64_t> end,
int64_t step) {
TORCH_CHECK(step == 1, "Slicing a string only supports step=1");

const int64_t size = string.size();
int64_t start_val = start.has_value() ? start.value() : INT64_MAX;
Contributor: Why does this default to INT64_MAX? Why not just set it to 0? cc @eellison this is impeding SymInt-ification

Contributor: Ugh, I forget, but I think that was mirroring how the python slice handling code went where None would be encoded as INT64_MAX

Contributor: It may have also been a forward compatible thing because None support had to land in multiple parts

int64_t end_val = end.has_value() ? end.value() : INT64_MAX;

// Clamp start and end to the bounds of the list
start = std::max(int64_t(0), normalizeIndex(start, size));
end = std::min(size, normalizeIndex(end, size));
const int64_t num_vals =
slice_indices_adjust(string.size(), &start_val, &end_val, step);

if (end <= start) {
// Slice is empty
return std::string("");
int64_t i = start_val;
std::string result = "";
for (int64_t j = 0; j < num_vals; j++) {
result += string[i];
i += step;
}

std::string result(string.begin() + start, string.begin() + end);
return result;
}
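
A couple of hedged call examples for the new stringSlice signature (expected results assume Python-style slicing semantics):

stringSlice("hello", /*start=*/0, /*end=*/c10::nullopt, /*step=*/1);   // expected: "hello"
stringSlice("hello", /*start=*/-3, /*end=*/c10::nullopt, /*step=*/1);  // expected: "llo"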

@@ -603,7 +604,7 @@ RegisterOperators reg(
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA(
"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> t[]"),
"aten::slice.t(t[] l, int? start=0, int? end=9223372036854775807, int step=1) -> t[]"),
listSlice,
aliasAnalysisFromSchema()),
OperatorGenerator(