Fix sym_{sizes,strides} slow path (#107839)
Previously, when a SymInt was returned from the sym_sizes slow path, it would segfault.

This is needed for tensors that have symbolic sizes and use the sym_sizes slow path, e.g. NestedTensor returning SingletonSymInt as its sizes from the slow path.

See also: https://github.com/pytorch/pytorch/pull/106405/files#r1303714865
Pull Request resolved: #107839
Approved by: https://github.com/ezyang
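
For context, here is a minimal sketch of the slow path this fixes, adapted from the regression test added below (the class name is illustrative): a wrapper subclass opts into the size/stride slow path with dispatch_sizes_strides_policy="sizes" and answers aten.sym_size / aten.sym_stride from __torch_dispatch__ with a SymInt. Before this change, the returned values were flattened into a plain int64 buffer by torch.overrides.get_buffer and reinterpreted, which segfaulted when they were actual SymInts; the C++ side now heap-allocates a SymInt array and keeps it alive via a py::capsule attached to the subclass instance.

import torch
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv


class SymSizedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        # The "sizes" policy routes size()/stride() queries through
        # __torch_dispatch__ (the slow path) instead of the cached fields.
        return torch.Tensor._make_wrapper_subclass(
            cls, (0,), dispatch_sizes_strides_policy="sizes"
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func in (
            torch.ops.aten.sym_size.default,
            torch.ops.aten.sym_stride.default,
        ):
            # Build a symbolic integer and return it as the only size/stride.
            shape_env = ShapeEnv()
            si = shape_env.create_symintnode(
                shape_env.create_symbol(
                    123,
                    source=ConstantSource("abc"),
                    dynamic_dim=DimDynamic.DUCK,
                    constraint_dim=None,
                ),
                hint=123,
            )
            return (si,)


t = SymSizedTensor()
assert isinstance(t.size()[0], torch.SymInt)    # segfaulted before this fix
assert isinstance(t.stride()[0], torch.SymInt)
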
soulitzer authored and pytorchmergebot committed Aug 24, 2023
1 parent 35de780 commit f6cce3c
Showing 3 changed files with 82 additions and 64 deletions.
34 changes: 34 additions & 0 deletions test/test_python_dispatch.py
@@ -1864,6 +1864,40 @@ def test_standard_is_not_subclass(self):
# https://github.com/pytorch/pytorch/issues/79079
self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0)))

def test_sym_sizes_strides_slow_path(self):
class TestTensor(torch.Tensor):
@staticmethod
def __new__(cls, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls, (0,), dispatch_sizes_strides_policy="sizes")
return r

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func in (
torch.ops.aten.sym_size.default,
torch.ops.aten.sym_stride.default
):
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic
shape_env = ShapeEnv()
si = shape_env.create_symintnode(
shape_env.create_symbol(
123,
source=ConstantSource("abc"),
dynamic_dim=DimDynamic.DUCK,
constraint_dim=None,
),
hint=123
)
return (si,)

t = TestTensor()
si = t.size()[0]
self.assertIsInstance(si, torch.SymInt)
si = t.stride()[0]
self.assertIsInstance(si, torch.SymInt)

def test_strides_slow_path(self):
for use_wrapper_subclass in [True, False]:
class StridesNotImplemented(torch.Tensor):
101 changes: 48 additions & 53 deletions torch/csrc/PyInterpreter.cpp
@@ -587,28 +587,22 @@ c10::IntArrayRef ConcretePyInterpreterVTable::strides(
return c10::IntArrayRef(start, len);
}

static std::vector<int64_t> values_from_buffer(
const c10::TensorImpl* self,
py::handle values) {
c10::TensorImpl* ptr = const_cast<c10::TensorImpl*>(self);
static void set_tensor_attr_with_capsule(
c10::TensorImpl* tensor,
py::capsule& capsule,
const char* attr_name) {
c10::optional<PyObject*> mb_obj =
ptr->pyobj_slot()->check_pyobj(getPyInterpreter());
tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
TORCH_CHECK(
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");

py::object os = py::module_::import("torch").attr("overrides");
py::function get_buffer =
py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
auto buffer = get_buffer(py::handle(*mb_obj), values, "size");
auto result = THPUtils_unpackLongs(buffer.ptr());
return result;
py::handle(mb_obj.value()).attr(attr_name) = capsule;
}

c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
at::impl::MaybeSetTLSOnEntryGuard guard;

HANDLE_TH_ERRORS
auto out = torchDispatchFromTensorImpl(
self,
"size",
@@ -619,20 +613,27 @@ c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
.attr("default")
.ptr(),
"torch.ops.aten");

if (out.is_none()) {
TORCH_CHECK(
!self->has_symbolic_sizes_strides(),
"Cannot call sizes on a tensor with symbolic shapes/strides");
return self->sizes_default();
}

py::object values = py::reinterpret_steal<py::object>(out.ptr());
auto result = values_from_buffer(self, values);
int64_t* start = (int64_t*)result[0];
int64_t len = result[1];

return c10::IntArrayRef(start, len);
TORCH_CHECK(
py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
"sizes must be a list or a tuple");
int64_t len = py::len(out);
int64_t* ptr = new int64_t[len];
auto capsule =
py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<int64_t*>(p); });
int64_t idx = 0;
for (auto it = out.begin(); it != out.end(); ++it, ++idx) {
ptr[idx] = py::cast<int64_t>(*it);
}
set_tensor_attr_with_capsule(
const_cast<c10::TensorImpl*>(self), capsule, "_sizes_capsule");
return c10::IntArrayRef(ptr, len);
END_HANDLE_TH_ERRORS_PYBIND
}

c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
@@ -654,24 +655,20 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
if (out.is_none()) {
return self->sym_sizes_default();
}
// We need to squeeze SymIntNodes and ints into `SymInts`
// since it's a format `sym_sizes()` are stored in
TORCH_CHECK(
py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
"Symshape must be a list or a tuple");
py::list symints;
for (auto it = out.begin(); it != out.end(); it++) {
auto elm = *it;
auto si = py::cast<c10::SymInt>(elm);
// TODO: the buffer will need to be made owning later
symints.append(si.as_int_unchecked());
}

auto result = values_from_buffer(self, symints);
c10::SymInt* start = (c10::SymInt*)result[0];
int64_t len = result[1];

return c10::SymIntArrayRef(start, len);
"sym_size must be a list or a tuple");
int64_t len = py::len(out);
c10::SymInt* ptr = new c10::SymInt[len];
auto capsule = py::capsule(
ptr, [](void* p) { delete[] reinterpret_cast<c10::SymInt*>(p); });
int64_t idx = 0;
for (auto it = out.begin(); it != out.end(); ++it, ++idx) {
ptr[idx] = py::cast<c10::SymInt>(*it);
}
set_tensor_attr_with_capsule(
const_cast<c10::TensorImpl*>(self), capsule, "_sym_sizes_capsule");
return c10::SymIntArrayRef(ptr, len);
END_HANDLE_TH_ERRORS_PYBIND
}

@@ -769,25 +766,21 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
// since it's a format `sym_strides()` are stored in
TORCH_CHECK(
py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
"Symshape must be a list or a tuple");
py::list symints;
for (auto it = out.begin(); it != out.end(); it++) {
auto elm = *it;
auto si = torch::is_symint(elm) ? elm.cast<c10::SymInt>()
: c10::SymInt{py::cast<int64_t>(elm)};
symints.append(si.as_int_unchecked());
}

auto result = values_from_buffer(self, symints);
c10::SymInt* start = (c10::SymInt*)result[0];
int64_t len = result[1];

return c10::SymIntArrayRef(start, len);
"sym_strides must be a list or a tuple");
int64_t len = py::len(out);
c10::SymInt* ptr = new c10::SymInt[len];
auto capsule = py::capsule(
ptr, [](void* p) { delete[] reinterpret_cast<c10::SymInt*>(p); });
int64_t idx = 0;
for (auto it = out.begin(); it != out.end(); ++it, ++idx) {
ptr[idx] = py::cast<c10::SymInt>(*it);
}
set_tensor_attr_with_capsule(
const_cast<c10::TensorImpl*>(self), capsule, "_sym_strides_capsule");
return c10::SymIntArrayRef(ptr, len);
END_HANDLE_TH_ERRORS_PYBIND
}

PyInterpreterHolder self_interpreter;

void ConcretePyInterpreterVTable::reset_backward_hooks(
const c10::TensorImpl* self) const {
pybind11::gil_scoped_acquire gil;
@@ -802,6 +795,8 @@ void ConcretePyInterpreterVTable::reset_backward_hooks(
END_HANDLE_TH_ERRORS_PYBIND
}

PyInterpreterHolder self_interpreter;

} // anonymous namespace

c10::impl::PyInterpreter* getPyInterpreter() {
11 changes: 0 additions & 11 deletions torch/overrides.py
@@ -47,7 +47,6 @@
"is_tensor_method_or_property",
"wrap_torch_function",
"enable_reentrant_dispatch",
"get_buffer",
]

@functools.lru_cache(None)
@@ -1907,13 +1906,3 @@ def enable_reentrant_dispatch():
yield
finally:
pass

def get_buffer(tensor_subclass, data, prefix):
import ctypes
assert prefix in {"stride", "size", "sym_size"}
buffer_name = f"_{prefix}_buffer"
if not hasattr(tensor_subclass, buffer_name):
SizeType = ctypes.c_longlong * len(data)
setattr(tensor_subclass, buffer_name, SizeType(*data))
ptr = ctypes.addressof(getattr(tensor_subclass, buffer_name))
return (ptr, len(data))
