Symintify pytorch slicing logic #91340
Closed

tugsbayasgalan wants to merge 18 commits into gh/tugsbayasgalan/86/base from gh/tugsbayasgalan/86/head
Commits (18, showing changes from all commits):

ae8889f  Symintify pytorch slicing logic (tugsbayasgalan)
88ac30a  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
99540eb  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
85043a8  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
00ed450  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
aeba29c  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
0691dc8  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
7052a80  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
8555d26  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
4bd8ffe  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
a6d5338  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
8423771  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
5db8aa5  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
acdb2d7  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
92e1382  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
3d9bb36  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
1cdcd7e  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
18a466e  Update on "Symintify pytorch slicing logic" (tugsbayasgalan)
```diff
@@ -23,8 +23,8 @@
 namespace at {
 namespace indexing {
 
-const int64_t INDEX_MAX = std::numeric_limits<int64_t>::max();
-const int64_t INDEX_MIN = std::numeric_limits<int64_t>::min();
+const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
+const int64_t INDEX_MAX = -(INDEX_MIN + 1);
 
 enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor };
```

Review comment on the new `INDEX_MIN` line: "Really tiny nit, this should be constexpr right?"
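As context (not part of the PR), here is a small standalone sketch of why `INDEX_MAX` is defined as `-(INDEX_MIN + 1)`. The minimum used below is a hypothetical stand-in; the real bound comes from `c10::SymInt::min_representable_int()`.

```cpp
// Standalone illustration of the INDEX_MAX = -(INDEX_MIN + 1) identity.
// kMinDemo is a hypothetical stand-in for c10::SymInt::min_representable_int().
#include <cstdint>
#include <limits>

int main() {
  constexpr int64_t kMinDemo = -(int64_t(1) << 62);  // assumed value, for illustration only
  constexpr int64_t kMaxDemo = -(kMinDemo + 1);

  // Because kMinDemo sits well above INT64_MIN, negating any in-range step can
  // no longer overflow int64_t, which appears to be the undefined behaviour the
  // removed "if (step_ < -INDEX_MAX) step_ = -INDEX_MAX;" clamp guarded against.
  static_assert(kMaxDemo == -kMinDemo - 1, "symmetric, two's-complement-style range");
  static_assert(-kMinDemo <= std::numeric_limits<int64_t>::max(),
                "negating the minimum stays within int64_t");
  return 0;
}
```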
```diff
@@ -37,52 +37,47 @@ TORCH_API extern const EllipsisIndexType Ellipsis;
 
 struct TORCH_API Slice final {
  public:
   // This mirrors `__PySlice_Unpack` in torch/csrc/utils/python_compat.h
   Slice(
-      c10::optional<int64_t> start_index = c10::nullopt,
-      c10::optional<int64_t> stop_index = c10::nullopt,
-      c10::optional<int64_t> step_index = c10::nullopt) {
+      c10::optional<c10::SymInt> start_index = c10::nullopt,
+      c10::optional<c10::SymInt> stop_index = c10::nullopt,
+      c10::optional<c10::SymInt> step_index = c10::nullopt) {
     if (!step_index.has_value()) {
-      step_ = 1;
+      step_ = c10::SymInt(1);
     } else {
-      step_ = step_index.value();
-      TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
-
-      // Here step might be -INDEX_MAX-1; in this case we replace it
-      // with -INDEX_MAX. This doesn't affect the semantics, and it
-      // guards against later undefined behaviour resulting from code that
-      // does "step = -step" as part of a slice reversal.
-      if (step_ < -INDEX_MAX)
-        step_ = -INDEX_MAX;
+      step_ = std::move(step_index).value();
     }
 
+    TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
+
     if (!start_index.has_value()) {
-      start_ = step_ < 0 ? INDEX_MAX : 0;
+      start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
     } else {
-      start_ = start_index.value();
+      start_ = std::move(start_index).value();
     }
 
     if (!stop_index.has_value()) {
-      stop_ = step_ < 0 ? INDEX_MIN : INDEX_MAX;
+      stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
     } else {
-      stop_ = stop_index.value();
+      stop_ = std::move(stop_index).value();
     }
   }
 
-  inline int64_t start() const {
+  inline c10::SymInt start() const {
     return start_;
   }
 
-  inline int64_t stop() const {
+  inline c10::SymInt stop() const {
     return stop_;
   }
 
-  inline int64_t step() const {
+  inline c10::SymInt step() const {
     return step_;
   }
 
  private:
-  int64_t start_;
-  int64_t stop_;
-  int64_t step_;
+  c10::SymInt start_;
+  c10::SymInt stop_;
+  c10::SymInt step_;
 };
 
 TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
```

(albanD marked a review conversation on the removed `if (step_ < -INDEX_MAX)` clamp as resolved.)
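As context (not part of the diff), a minimal libtorch sketch of how these `Slice` defaults surface through the public C++ tensor indexing API; the tensor values and calls are illustrative only:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  using namespace torch::indexing;
  torch::Tensor t = torch::arange(10);  // [0, 1, ..., 9]

  // Explicit start/stop/step, constructing Slice(2, 8, 2) above.
  std::cout << t.index({Slice(2, 8, 2)}) << "\n";  // 2, 4, 6

  // Omitted bounds with a negative step: the constructor fills in
  // start = INDEX_MAX and stop = INDEX_MIN, which yields a full reversal.
  std::cout << t.index({Slice(None, None, -1)}) << "\n";  // 9, 8, ..., 0
  return 0;
}
```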
```diff
@@ -213,9 +208,9 @@ namespace impl {
 static inline Tensor applySlice(
     const Tensor& self,
     int64_t dim,
-    int64_t start,
-    int64_t stop,
-    int64_t step,
+    c10::SymInt start,
+    c10::SymInt stop,
+    c10::SymInt step,
     bool disable_slice_optimization,
     const at::Device& self_device,
     const c10::optional<SymIntArrayRef>& self_sizes) {
@@ -235,7 +230,7 @@ static inline Tensor applySlice(
       return self;
     }
   }
-  return self.slice(dim, start, stop, step);
+  return self.slice_symint(dim, start, stop, step);
 }
 
 static inline Tensor applySelect(
```
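For reference (not from the PR), a small libtorch sketch of the call that `applySlice` forwards to: with plain integer bounds this is `Tensor::slice`, while SymInt bounds now go through the `slice_symint` overload shown in the hunk above.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor x = torch::arange(10);  // [0, 1, ..., 9]
  // Equivalent to the Python expression x[1:8:2].
  torch::Tensor y = x.slice(/*dim=*/0, /*start=*/1, /*end=*/8, /*step=*/2);
  std::cout << y << "\n";  // 1, 3, 5, 7
  return 0;
}
```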
A later file in the diff is a header whose hunk (`@@ -1,11 +1,96 @@`) is almost entirely new code; the resulting contents of the hunk are:

```cpp
#pragma once

#include <c10/core/SymInt.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_symnode.h>

namespace torch {
namespace autograd {

struct UnpackedSlice {
  c10::SymInt start;
  c10::SymInt stop;
  c10::SymInt step;
};

// This mirrors CPython's PySlice_Unpack method
static inline UnpackedSlice __PySlice_Unpack(PyObject* _r) {
  PySliceObject* r = (PySliceObject*)_r;
  /* this is harder to get right than you might think */

  c10::SymInt start_sym, stop_sym, step_sym;

  auto clip_val = [](Py_ssize_t val) {
    if (val < c10::SymInt::min_representable_int()) {
      auto r = PyErr_WarnEx(
          PyExc_UserWarning,
          "Truncating the start/stop/step "
          "of slice. This is likely because of "
          "saved old models when the start/stop/step were larger.",
          1);
      if (r != 0) {
        throw python_error();
      }
      return (Py_ssize_t)(c10::SymInt::min_representable_int());
    }
    return val;
  };

  if (r->step == Py_None) {
    step_sym = c10::SymInt(1);
  } else {
    if (torch::is_symint(r->step)) {
      step_sym = py::handle(r->step).cast<c10::SymInt>();
    } else {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      Py_ssize_t step;
      if (!_PyEval_SliceIndex(r->step, &step)) {
        throw python_error();
      }
      if (step == 0) {
        PyErr_SetString(PyExc_ValueError, "slice step cannot be zero");
      }

      step = clip_val(step);
      step_sym = c10::SymInt(step);
    }
  }

  if (torch::is_symint(r->start)) {
    start_sym = py::handle(r->start).cast<c10::SymInt>();
  } else if (r->start == Py_None) {
    start_sym = c10::SymInt(step_sym < 0 ? PY_SSIZE_T_MAX : 0);
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Py_ssize_t start;
    if (!_PyEval_SliceIndex(r->start, &start)) {
      throw python_error();
    }
    start = clip_val(start);
    start_sym = c10::SymInt(start);
  }

  if (torch::is_symint(r->stop)) {
    stop_sym = py::handle(r->stop).cast<c10::SymInt>();
  } else if (r->stop == Py_None) {
    stop_sym = c10::SymInt(
        step_sym < 0 ? c10::SymInt::min_representable_int() : PY_SSIZE_T_MAX);
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Py_ssize_t stop;
    if (!_PyEval_SliceIndex(r->stop, &stop)) {
      throw python_error();
    }
    stop = clip_val(stop);
    stop_sym = c10::SymInt(stop);
  }

  return UnpackedSlice{
      std::move(start_sym), std::move(stop_sym), std::move(step_sym)};
}

Py_ssize_t THPVariable_length(PyObject* self);
PyObject* THPVariable_getitem(PyObject* self, PyObject* index);
int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* value);
```

Review notes: albanD marked a conversation on the `#include <c10/core/SymInt.h>` line as resolved. A review comment on the `PyErr_WarnEx` call suggests: `TORCH_WARN("The message");`
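As a side illustration (not PR code; the minimum below is a hypothetical stand-in for `c10::SymInt::min_representable_int()`), the clamping rule that `clip_val` applies can be sketched in isolation:

```cpp
// Standalone sketch of clip_val's behaviour: bounds below the smallest integer
// a SymInt can represent are truncated up to that minimum, with a warning,
// instead of raising an error.
#include <cstdint>
#include <iostream>
#include <limits>

// Hypothetical stand-in for c10::SymInt::min_representable_int().
constexpr int64_t kSymIntMinDemo = -(int64_t(1) << 62);

int64_t clip_val_demo(int64_t val) {
  if (val < kSymIntMinDemo) {
    // The real code emits PyErr_WarnEx with a similar message.
    std::cerr << "Truncating the start/stop/step of slice.\n";
    return kSymIntMinDemo;
  }
  return val;
}

int main() {
  std::cout << clip_val_demo(std::numeric_limits<int64_t>::min()) << "\n";  // clamped to the minimum
  std::cout << clip_val_demo(-7) << "\n";                                   // unchanged: -7
  return 0;
}
```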