Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 25 additions & 30 deletions aten/src/ATen/TensorIndexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
namespace at {
namespace indexing {

// Sentinel bounds used for open-ended slices (e.g. `t[::-1]`).
// INDEX_MIN is the most negative value a SymInt can represent; INDEX_MAX is
// chosen as -(INDEX_MIN + 1) so that both INDEX_MAX and -INDEX_MAX are
// representable, guarding the "step = -step" slice-reversal pattern against
// overflow. constexpr (not just const) so they are usable in constant
// expressions, as requested in review.
constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);

enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor };

Expand All @@ -37,52 +37,47 @@ TORCH_API extern const EllipsisIndexType Ellipsis;

struct TORCH_API Slice final {
 public:
  // Normalizes an (optionally omitted) start/stop/step triple into concrete
  // SymInt values, defaulting exactly like Python slice semantics:
  //   step  -> 1
  //   start -> 0 (forward) or INDEX_MAX (backward)
  //   stop  -> INDEX_MAX (forward) or INDEX_MIN (backward)
  // This mirrors `__PySlice_Unpack` in torch/csrc/utils/python_compat.h
  Slice(
      c10::optional<c10::SymInt> start_index = c10::nullopt,
      c10::optional<c10::SymInt> stop_index = c10::nullopt,
      c10::optional<c10::SymInt> step_index = c10::nullopt) {
    if (!step_index.has_value()) {
      step_ = c10::SymInt(1);
    } else {
      step_ = std::move(step_index).value();
    }

    // A zero step would loop forever / divide by zero downstream.
    TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");

    if (!start_index.has_value()) {
      // Default start depends on direction of traversal.
      start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
    } else {
      start_ = std::move(start_index).value();
    }

    if (!stop_index.has_value()) {
      // Default stop depends on direction of traversal.
      stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
    } else {
      stop_ = std::move(stop_index).value();
    }
  }

  // Normalized slice start (never nullopt after construction).
  inline c10::SymInt start() const {
    return start_;
  }

  // Normalized slice stop.
  inline c10::SymInt stop() const {
    return stop_;
  }

  // Normalized slice step; guaranteed non-zero.
  inline c10::SymInt step() const {
    return step_;
  }

 private:
  c10::SymInt start_;
  c10::SymInt stop_;
  c10::SymInt step_;
};

TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
Expand Down Expand Up @@ -213,9 +208,9 @@ namespace impl {
static inline Tensor applySlice(
const Tensor& self,
int64_t dim,
int64_t start,
int64_t stop,
int64_t step,
c10::SymInt start,
c10::SymInt stop,
c10::SymInt step,
bool disable_slice_optimization,
const at::Device& self_device,
const c10::optional<SymIntArrayRef>& self_sizes) {
Expand All @@ -235,7 +230,7 @@ static inline Tensor applySlice(
return self;
}
}
return self.slice(dim, start, stop, step);
return self.slice_symint(dim, start, stop, step);
}

static inline Tensor applySelect(
Expand Down
5 changes: 5 additions & 0 deletions c10/core/SymInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,11 @@ class C10_API SymInt {
return i > MAX_UNREPRESENTABLE_INT;
}

  // Return the minimum representable concrete integer value, as an int64_t.
  // Values below this are reserved by SymInt's internal encoding (see the
  // representation constraints below) and cannot be stored in a SymInt.
  static constexpr int64_t min_representable_int() {
    return MAX_UNREPRESENTABLE_INT + 1;
  }

private:
// Constraints on the internal representation:
//
Expand Down
29 changes: 8 additions & 21 deletions torch/csrc/autograd/python_variable_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,6 @@ inline Variable valueToTensor(
}
}

static inline void checkUnpackSlice(
PyObject* index,
Py_ssize_t* start_ptr,
Py_ssize_t* stop_ptr,
Py_ssize_t* step_ptr) {
if (PySlice_Unpack(index, start_ptr, stop_ptr, step_ptr) != 0) {
throw python_error();
}
}

static inline void recordSliceTrace(PyObject* obj) {
PySliceObject* sliceobj = (PySliceObject*)obj;
if (THPVariable_Check(sliceobj->start)) {
Expand Down Expand Up @@ -209,14 +199,12 @@ static inline Variable applySlicing(
}
return at::indexing::TensorIndex(THPUtils_unpackLong(obj));
} else if (PySlice_Check(obj)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Py_ssize_t start, stop, step;
checkUnpackSlice(obj, &start, &stop, &step);
auto val = __PySlice_Unpack(obj);
if (is_tracing) {
recordSliceTrace(obj);
}
return at::indexing::TensorIndex(
at::indexing::Slice(start, stop, step));
at::indexing::Slice(val.start, val.stop, val.step));
} else if (obj == Py_Ellipsis) {
return at::indexing::TensorIndex(at::indexing::Ellipsis);
} else if (obj == Py_None) {
Expand Down Expand Up @@ -360,15 +348,14 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
return THPVariable_Wrap(at::indexing::get_item(
self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}));
} else if (PySlice_Check(index)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Py_ssize_t start, stop, step;
checkUnpackSlice(index, &start, &stop, &step);
auto val = __PySlice_Unpack(index);
if (is_tracing) {
recordSliceTrace(index);
}
return THPVariable_Wrap(at::indexing::get_item(
self_,
{at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))}));
{at::indexing::TensorIndex(
at::indexing::Slice(val.start, val.stop, val.step))}));
} else if (index == Py_False || index == Py_True) {
return THPVariable_Wrap(([&]() {
pybind11::gil_scoped_release no_gil;
Expand Down Expand Up @@ -489,16 +476,16 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
return 0;
} else if (PySlice_Check(index)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Py_ssize_t start, stop, step;
checkUnpackSlice(index, &start, &stop, &step);
auto val = __PySlice_Unpack(index);
if (is_tracing) {
recordSliceTrace(index);
}
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
dispatch_set_item(
self_,
{at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))},
{at::indexing::TensorIndex(
at::indexing::Slice(val.start, val.stop, val.step))},
value,
/*disable_slice_optimization=*/is_tracing);
return 0;
Expand Down
85 changes: 85 additions & 0 deletions torch/csrc/autograd/python_variable_indexing.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,96 @@
#pragma once

#include <c10/core/SymInt.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_symnode.h>

namespace torch {
namespace autograd {

// Result of unpacking a Python slice object: start/stop/step as SymInts so
// that symbolic (traced) slice bounds are preserved, not forced to concrete
// integers.
struct UnpackedSlice {
  c10::SymInt start;
  c10::SymInt stop;
  c10::SymInt step;
};

// This mirrors CPython's PySlice_Unpack method, extended to pass through
// symbolic integers (torch SymInt) instead of forcing them to Py_ssize_t.
// Raises (via throw python_error()) on a zero step or on conversion failure.
static inline UnpackedSlice __PySlice_Unpack(PyObject* _r) {
  PySliceObject* r = (PySliceObject*)_r;
  /* this is harder to get right than you might think */

  c10::SymInt start_sym, stop_sym, step_sym;

  // Values below SymInt's representable minimum are clamped with a warning;
  // such out-of-range bounds can appear in old serialized models.
  auto clip_val = [](Py_ssize_t val) {
    if (val < c10::SymInt::min_representable_int()) {
      auto r = PyErr_WarnEx(
          PyExc_UserWarning,
          "Truncating the start/stop/step "
          "of slice. This is likely because of "
          "saved old models when the start/stop/step were larger.",
          1);
      if (r != 0) {
        // Warning was turned into an exception (e.g. -W error).
        throw python_error();
      }
      return (Py_ssize_t)(c10::SymInt::min_representable_int());
    }
    return val;
  };

  if (r->step == Py_None) {
    step_sym = c10::SymInt(1);
  } else {
    if (torch::is_symint(r->step)) {
      // Assign to the outer step_sym. A local `auto step_sym = ...` here
      // would shadow it and silently discard the symbolic step, leaving the
      // outer value default-constructed.
      step_sym = py::handle(r->step).cast<c10::SymInt>();
    } else {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      Py_ssize_t step;
      if (!_PyEval_SliceIndex(r->step, &step)) {
        throw python_error();
      }
      if (step == 0) {
        PyErr_SetString(PyExc_ValueError, "slice step cannot be zero");
        // Must raise here: continuing with a pending Python error set leads
        // to undefined behavior in later CPython API calls. This mirrors
        // CPython's PySlice_Unpack returning -1.
        throw python_error();
      }

      step = clip_val(step);
      step_sym = c10::SymInt(step);
    }
  }

  if (torch::is_symint(r->start)) {
    start_sym = py::handle(r->start).cast<c10::SymInt>();
  } else if (r->start == Py_None) {
    // Default start depends on traversal direction, as in CPython.
    start_sym = c10::SymInt(step_sym < 0 ? PY_SSIZE_T_MAX : 0);
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Py_ssize_t start;
    if (!_PyEval_SliceIndex(r->start, &start)) {
      throw python_error();
    }
    start = clip_val(start);
    start_sym = c10::SymInt(start);
  }

  if (torch::is_symint(r->stop)) {
    stop_sym = py::handle(r->stop).cast<c10::SymInt>();
  } else if (r->stop == Py_None) {
    // Default stop depends on traversal direction, as in CPython.
    stop_sym = c10::SymInt(
        step_sym < 0 ? c10::SymInt::min_representable_int() : PY_SSIZE_T_MAX);
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Py_ssize_t stop;
    if (!_PyEval_SliceIndex(r->stop, &stop)) {
      throw python_error();
    }
    stop = clip_val(stop);
    stop_sym = c10::SymInt(stop);
  }

  return UnpackedSlice{
      std::move(start_sym), std::move(stop_sym), std::move(step_sym)};
}

Py_ssize_t THPVariable_length(PyObject* self);
PyObject* THPVariable_getitem(PyObject* self, PyObject* index);
int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* value);
Expand Down