Skip to content

Commit

Permalink
Revert "Symintify pytorch slicing logic (pytorch#91340)"
Browse files — browse the repository at this point in the history
  • Loading branch information
pytorchmergebot authored and pull[bot] committed Sep 20, 2023
1 parent b019485 commit 7ee2cf9
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 123 deletions.
55 changes: 30 additions & 25 deletions aten/src/ATen/TensorIndexing.h
Expand Up @@ -23,8 +23,8 @@
namespace at {
namespace indexing {

const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
const int64_t INDEX_MAX = -(INDEX_MIN + 1);
const int64_t INDEX_MAX = std::numeric_limits<int64_t>::max();
const int64_t INDEX_MIN = std::numeric_limits<int64_t>::min();

enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor };

Expand All @@ -37,47 +37,52 @@ TORCH_API extern const EllipsisIndexType Ellipsis;

struct TORCH_API Slice final {
public:
// This mirrors `__PySlice_Unpack` in torch/csrc/utils/python_compat.h
Slice(
c10::optional<c10::SymInt> start_index = c10::nullopt,
c10::optional<c10::SymInt> stop_index = c10::nullopt,
c10::optional<c10::SymInt> step_index = c10::nullopt) {
c10::optional<int64_t> start_index = c10::nullopt,
c10::optional<int64_t> stop_index = c10::nullopt,
c10::optional<int64_t> step_index = c10::nullopt) {
if (!step_index.has_value()) {
step_ = c10::SymInt(1);
step_ = 1;
} else {
step_ = std::move(step_index).value();
step_ = step_index.value();
TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");

// Here step might be -INDEX_MAX-1; in this case we replace it
// with -INDEX_MAX. This doesn't affect the semantics, and it
// guards against later undefined behaviour resulting from code that
// does "step = -step" as part of a slice reversal.
if (step_ < -INDEX_MAX)
step_ = -INDEX_MAX;
}

TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");

if (!start_index.has_value()) {
start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
start_ = step_ < 0 ? INDEX_MAX : 0;
} else {
start_ = std::move(start_index).value();
start_ = start_index.value();
}

if (!stop_index.has_value()) {
stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
stop_ = step_ < 0 ? INDEX_MIN : INDEX_MAX;
} else {
stop_ = std::move(stop_index).value();
stop_ = stop_index.value();
}
}

inline c10::SymInt start() const {
inline int64_t start() const {
return start_;
}

inline c10::SymInt stop() const {
inline int64_t stop() const {
return stop_;
}

inline c10::SymInt step() const {
inline int64_t step() const {
return step_;
}

private:
c10::SymInt start_;
c10::SymInt stop_;
c10::SymInt step_;
int64_t start_;
int64_t stop_;
int64_t step_;
};

TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
Expand Down Expand Up @@ -208,9 +213,9 @@ namespace impl {
static inline Tensor applySlice(
const Tensor& self,
int64_t dim,
c10::SymInt start,
c10::SymInt stop,
c10::SymInt step,
int64_t start,
int64_t stop,
int64_t step,
bool disable_slice_optimization,
const at::Device& self_device,
const c10::optional<SymIntArrayRef>& self_sizes) {
Expand All @@ -230,7 +235,7 @@ static inline Tensor applySlice(
return self;
}
}
return self.slice_symint(dim, start, stop, step);
return self.slice(dim, start, stop, step);
}

static inline Tensor applySelect(
Expand Down
5 changes: 0 additions & 5 deletions c10/core/SymInt.h
Expand Up @@ -190,11 +190,6 @@ class C10_API SymInt {
return i > MAX_UNREPRESENTABLE_INT;
}

// Return the minimum representable integer as a SymInt.
// This is MAX_UNREPRESENTABLE_INT + 1: the first value below it is reserved
// for the SymInt internal (heap-allocated / symbolic) representation.
static constexpr int64_t min_representable_int() {
return MAX_UNREPRESENTABLE_INT + 1;
}

private:
// Constraints on the internal representation:
//
Expand Down
29 changes: 21 additions & 8 deletions torch/csrc/autograd/python_variable_indexing.cpp
Expand Up @@ -128,6 +128,16 @@ inline Variable valueToTensor(
}
}

// Unpack a Python slice object into start/stop/step out-parameters.
// Converts a CPython failure (non-zero return from PySlice_Unpack, with a
// Python exception already set) into a thrown C++ python_error.
static inline void checkUnpackSlice(
    PyObject* index,
    Py_ssize_t* start_ptr,
    Py_ssize_t* stop_ptr,
    Py_ssize_t* step_ptr) {
  const int rc = PySlice_Unpack(index, start_ptr, stop_ptr, step_ptr);
  if (rc != 0) {
    throw python_error();
  }
}

static inline void recordSliceTrace(PyObject* obj) {
PySliceObject* sliceobj = (PySliceObject*)obj;
if (THPVariable_Check(sliceobj->start)) {
Expand Down Expand Up @@ -199,12 +209,14 @@ static inline Variable applySlicing(
}
return at::indexing::TensorIndex(THPUtils_unpackLong(obj));
} else if (PySlice_Check(obj)) {
auto val = __PySlice_Unpack(obj);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Py_ssize_t start, stop, step;
checkUnpackSlice(obj, &start, &stop, &step);
if (is_tracing) {
recordSliceTrace(obj);
}
return at::indexing::TensorIndex(
at::indexing::Slice(val.start, val.stop, val.step));
at::indexing::Slice(start, stop, step));
} else if (obj == Py_Ellipsis) {
return at::indexing::TensorIndex(at::indexing::Ellipsis);
} else if (obj == Py_None) {
Expand Down Expand Up @@ -348,14 +360,15 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
return THPVariable_Wrap(at::indexing::get_item(
self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}));
} else if (PySlice_Check(index)) {
auto val = __PySlice_Unpack(index);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Py_ssize_t start, stop, step;
checkUnpackSlice(index, &start, &stop, &step);
if (is_tracing) {
recordSliceTrace(index);
}
return THPVariable_Wrap(at::indexing::get_item(
self_,
{at::indexing::TensorIndex(
at::indexing::Slice(val.start, val.stop, val.step))}));
{at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))}));
} else if (index == Py_False || index == Py_True) {
return THPVariable_Wrap(([&]() {
pybind11::gil_scoped_release no_gil;
Expand Down Expand Up @@ -476,16 +489,16 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
return 0;
} else if (PySlice_Check(index)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
auto val = __PySlice_Unpack(index);
Py_ssize_t start, stop, step;
checkUnpackSlice(index, &start, &stop, &step);
if (is_tracing) {
recordSliceTrace(index);
}
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
dispatch_set_item(
self_,
{at::indexing::TensorIndex(
at::indexing::Slice(val.start, val.stop, val.step))},
{at::indexing::TensorIndex(at::indexing::Slice(start, stop, step))},
value,
/*disable_slice_optimization=*/is_tracing);
return 0;
Expand Down
85 changes: 0 additions & 85 deletions torch/csrc/autograd/python_variable_indexing.h
@@ -1,96 +1,11 @@
#pragma once

#include <c10/core/SymInt.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_symnode.h>

namespace torch {
namespace autograd {

// Result of unpacking a Python slice object: start/stop/step held as SymInt
// so that symbolic integers (torch.SymInt) coming from Python are preserved.
// Member order matters: it must match the aggregate initialization in
// __PySlice_Unpack (start, stop, step).
struct UnpackedSlice {
c10::SymInt start;
c10::SymInt stop;
c10::SymInt step;
};

// This mirrors Cpython's PySlice_Unpack method.
//
// Unpacks a Python slice object into SymInt start/stop/step. Symbolic
// integers (torch.SymInt) arriving from Python are kept as-is; plain Python
// ints are converted via _PyEval_SliceIndex. Values below SymInt's
// representable minimum are clamped (with a UserWarning), since old
// serialized models may contain extreme start/stop/step values.
//
// Throws python_error if a CPython call fails or if step == 0.
static inline UnpackedSlice __PySlice_Unpack(PyObject* _r) {
  PySliceObject* r = (PySliceObject*)_r;
  /* this is harder to get right than you might think */

  c10::SymInt start_sym, stop_sym, step_sym;

  // Clamp values below SymInt's minimum. Warns because this silently alters
  // the slice bounds of old saved models; a failed warning (e.g. warnings
  // turned into errors) is propagated as python_error.
  auto clip_val = [](Py_ssize_t val) {
    if (val < c10::SymInt::min_representable_int()) {
      auto r = PyErr_WarnEx(
          PyExc_UserWarning,
          "Truncating the start/stop/step "
          "of slice. This is likely because of "
          "saved old models when the start/stop/step were larger.",
          1);
      if (r != 0) {
        throw python_error();
      }
      // Cast so both return statements deduce the same lambda return type.
      return static_cast<Py_ssize_t>(c10::SymInt::min_representable_int());
    }
    return val;
  };

  if (r->step == Py_None) {
    step_sym = c10::SymInt(1);
  } else {
    if (torch::is_symint(r->step)) {
      // Bug fix: this previously declared a shadowing local `auto step_sym`,
      // which discarded the symbolic step instead of assigning the outer one.
      step_sym = py::handle(r->step).cast<c10::SymInt>();
    } else {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      Py_ssize_t step;
      if (!_PyEval_SliceIndex(r->step, &step)) {
        throw python_error();
      }
      if (step == 0) {
        PyErr_SetString(PyExc_ValueError, "slice step cannot be zero");
        // Bug fix: propagate the error. Previously execution fell through
        // with step == 0 and a pending Python exception.
        throw python_error();
      }

      step = clip_val(step);
      step_sym = c10::SymInt(step);
    }
  }

  if (torch::is_symint(r->start)) {
    start_sym = py::handle(r->start).cast<c10::SymInt>();
  } else if (r->start == Py_None) {
    // Default start mirrors CPython: max for a negative step, else 0.
    start_sym = c10::SymInt(step_sym < 0 ? PY_SSIZE_T_MAX : 0);
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Py_ssize_t start;
    if (!_PyEval_SliceIndex(r->start, &start)) {
      throw python_error();
    }
    start = clip_val(start);
    start_sym = c10::SymInt(start);
  }

  if (torch::is_symint(r->stop)) {
    stop_sym = py::handle(r->stop).cast<c10::SymInt>();
  } else if (r->stop == Py_None) {
    // Default stop mirrors CPython: min for a negative step, else max.
    stop_sym = c10::SymInt(
        step_sym < 0 ? c10::SymInt::min_representable_int() : PY_SSIZE_T_MAX);
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Py_ssize_t stop;
    if (!_PyEval_SliceIndex(r->stop, &stop)) {
      throw python_error();
    }
    stop = clip_val(stop);
    stop_sym = c10::SymInt(stop);
  }

  return UnpackedSlice{
      std::move(start_sym), std::move(stop_sym), std::move(step_sym)};
}

Py_ssize_t THPVariable_length(PyObject* self);
PyObject* THPVariable_getitem(PyObject* self, PyObject* index);
int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* value);
Expand Down

0 comments on commit 7ee2cf9

Please sign in to comment.