[NT] Make NestedTensor register as having symbolic sizes/strides (#124687)

Fixes #123698

This PR makes `TensorImpl::has_symbolic_sizes_strides` return true for NestedTensors.

1. It passes in the actual sizes when we call `_make_wrapper_subclass` - this is the change that makes the subclass register as `has_symbolic_sizes_strides() == True`
2. It adds a `storage_size` field to `_make_wrapper_subclass` so that an explicit storage size can be provided. This lets us skip computing the storage size from sizes/strides, which previously failed due to arithmetic on NestedInts.
3. It implements `aten::numel` for NJT; this is separate from the overridden `numel` in `_make_wrapper_subclass` for now. Note that this means we keep `dispatch_sizes_strides_policy="sizes"` so that the custom `numel` implementation (along with `sizes` and `strides`) is called, because `numel` cannot currently be computed from `sizes` for NJT.

Note also that this depends on #121361: calling `TensorImpl::set_sizes_and_strides()` tries to clone the sizes into the tensor, so `clone` needs to be implemented for NestedInt.
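
A rough end-to-end sketch of what this enables for the jagged NJT (the factory name `nested_tensor_from_jagged` and the exact printed symbol names are assumptions about the current API and may vary by build):

```python
import torch

values = torch.randn(10, 3)
offsets = torch.tensor([0, 2, 5, 10])
nt = torch.nested.nested_tensor_from_jagged(values, offsets)

print(nt.dim())    # 3
print(nt.numel())  # 30, served by the custom aten::numel implementation
print(nt.shape)    # e.g. torch.Size([3, j1, 3]); the ragged dim is a NestedInt
```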

Differential Revision: [D57225736](https://our.internmc.facebook.com/intern/diff/D57225736)
Pull Request resolved: #124687
Approved by: https://github.com/albanD
davidberard98 authored and pytorchmergebot committed May 13, 2024
1 parent 96bdb7a commit 82edc8b
Showing 6 changed files with 59 additions and 40 deletions.
12 changes: 0 additions & 12 deletions c10/core/TensorImpl.h
@@ -2445,18 +2445,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return has_symbolic_sizes_strides_;
}

// if this returns true, then it is guaranteed that this tensor does NOT have
// symbolic sizes/strides. This is different from the above, because it's
// possible that has_symbolic_sizes_strides() returns false, but we do
// not have symbolic sizes/strides. This exists for the case of
// Nested Tensor python subclass, where the sizes are implemented in python
// (TODO: clean this up and just implement sizes in nested tensor without a
// python implementation)
bool does_not_have_symbolic_sizes_strides() const {
return !has_symbolic_sizes_strides() &&
!matches_policy(SizesStridesPolicy::CustomStrides);
}

private:
void HandleResize();

1 change: 1 addition & 0 deletions test/test_nestedtensor.py
@@ -3070,6 +3070,7 @@ def test_tensor_attributes(self, device):
torch.ops.aten.is_non_overlapping_and_dense.default,
torch.ops.aten.sym_size.default,
torch.ops.aten.dim.default,
torch.ops.aten.numel.default,
torch.ops.aten.sym_numel.default,
torch.ops.aten.sym_stride.default,
torch.ops.aten.sym_storage_offset.default,
27 changes: 24 additions & 3 deletions torch/csrc/autograd/python_variable.cpp
@@ -746,14 +746,31 @@ static PyObject* THPVariable_make_wrapper_subclass(
HANDLE_TH_ERRORS
// NB: pin_memory doesn't actually do anything
// TODO: strides variant?

// cls: Python subclass type
// size, strides, storage_offset, memory_format, dtype: self-explanatory
// layout: memory layout, e.g. for types of Nested Tensors or other sparse
// tensors
// pin_memory, requires_grad: self-explanatory
// dispatch_sizes_strides_policy: string - which sizes/strides we should
// dispatch to a custom python implementation.
// dispatch_device: whether to dispatch to a custom python implementation
// for device
// dispatch_layout: whether to dispatch to a custom python implementation
// for layout
// _extra_dispatch_keys: additional dispatch keys to add to the tensor
// storage_size: if provided, skip storage size calculation and just use the
// value provided. One use case is for Nested Tensor, where the
// storage size cannot be calculated from the sizes/strides
// (because they contain a NestedInt).
static PythonArgParser parser({
"_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef? strides=None, "
"SymInt? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
"Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
"c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, "
"DispatchKeySet _extra_dispatch_keys=None)",
"DispatchKeySet _extra_dispatch_keys=None, SymInt? storage_size=None)",
});
ParsedArgs<14> parsed_args{};
ParsedArgs<15> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);

@@ -803,7 +820,11 @@ static PyObject* THPVariable_make_wrapper_subclass(

c10::SymInt size_bytes;
auto dtype_itemsize = static_cast<int64_t>(options.dtype().itemsize());
if (sym_strides.has_value()) {
auto storage_size = r.toSymIntOptional(14);

if (storage_size.has_value()) {
size_bytes = storage_size.value();
} else if (sym_strides.has_value()) {
size_bytes = at::detail::computeStorageNbytes(
sym_sizes,
sym_strides.value(),
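
A minimal Python-side sketch of how a wrapper subclass might pass the new `storage_size` argument (the subclass name and inner tensor are hypothetical, and a real subclass would also define `__torch_dispatch__`):

```python
import torch

class MyWrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner):
        # Provide the storage size explicitly so it is not recomputed from
        # sizes/strides (useful when those contain symbolic ints).
        return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            inner.shape,
            inner.stride(),
            dtype=inner.dtype,
            device=inner.device,
            requires_grad=inner.requires_grad,
            storage_size=inner.untyped_storage().size(),
        )
```
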
4 changes: 2 additions & 2 deletions torch/csrc/profiler/standalone/execution_trace_observer.cpp
@@ -86,7 +86,7 @@ inline std::string getValueShape(
if (val.isTensor()) {
auto& tensor = val.toTensor();
if (tensor.defined() &&
tensor.unsafeGetTensorImpl()->does_not_have_symbolic_sizes_strides()) {
!tensor.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
return vectorToString(tensor.sizes().vec());
}
} else if (val.isTuple()) {
@@ -396,7 +396,7 @@ inline std::string convertIValue(
size_t itemsize = 0;
std::string device_str = "";
// symbolic sizes/strides implies t->storage_offset() will fail
if (t->has_storage() && t->does_not_have_symbolic_sizes_strides()) {
if (t->has_storage() && !t->has_symbolic_sizes_strides()) {
auto& t_storage = t->storage();
storage_id = getObjectID(ob, t_storage.data());
offset = t->storage_offset();
52 changes: 30 additions & 22 deletions torch/nested/_internal/nested_tensor.py
@@ -61,10 +61,32 @@ def __new__(
):
ks = DispatchKeySet(DispatchKey.NestedTensor)
ks = ks.add(DispatchKey.AutogradNestedTensor)

# Only support jagged for now.
assert offsets is not None
assert offsets.ndim == 1
assert not isinstance(values, NestedTensor)
assert values.device == offsets.device

# Query cache for the symint associated with offsets or lengths
# (create a new one if needed).
ragged_source = offsets if lengths is None else lengths
ragged_size = get_tensor_symint(ragged_source, coeff=1)
_ragged_idx = kwargs.get("_ragged_idx", 1)
B = offsets.shape[0] - 1
if lengths is not None:
assert B == lengths.shape[0]

# subtract 1 to convert to values dim space
r = _ragged_idx - 1
_size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :])
stride = values.stride()
_strides = (ragged_size * stride[r], *stride)

r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
(0,),
(0,),
_size,
_strides,
0,
torch.contiguous_format,
values.dtype,
@@ -76,31 +98,17 @@ def __init__(self, values, offsets, *, lengths=None, **kwargs):
False,
True, # dispatch_layout
ks,
# don't try to calculate storage based on non-zero size
storage_size=values.untyped_storage().size(),
)
r._ragged_idx = _ragged_idx
r._size = _size
r._strides = _strides

return r

def __init__(self, values, offsets, *, lengths=None, **kwargs):
super().__init__()
# Only support jagged for now.
assert offsets is not None
assert offsets.ndim == 1
assert not isinstance(values, NestedTensor)
assert values.device == offsets.device

# Query cache for the symint associated with offsets or lengths
# (create a new one if needed).
ragged_source = offsets if lengths is None else lengths
ragged_size = get_tensor_symint(ragged_source, coeff=1)
self._ragged_idx = kwargs.get("_ragged_idx", 1)
B = offsets.shape[0] - 1
if lengths is not None:
assert B == lengths.shape[0]

# subtract 1 to convert to values dim space
r = self._ragged_idx - 1
self._size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :])
stride = values.stride()
self._strides = (ragged_size * stride[r], *stride)

self._values = values
self._offsets = offsets
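
For readers skimming the `__new__` changes above, here is an illustrative re-run of the size/stride computation, with a plain integer standing in for the `NestedInt` ragged size (sample shapes are made up):

```python
import torch

values = torch.randn(10, 3)            # packed values, stride (3, 1)
offsets = torch.tensor([0, 2, 5, 10])  # B = 3 jagged components
B = offsets.shape[0] - 1
ragged_idx = 1
r = ragged_idx - 1                     # ragged dim in values-space
ragged_size = 4                        # stand-in; really a NestedInt tied to offsets

_size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1:])
stride = values.stride()
_strides = (ragged_size * stride[r], *stride)

print(_size)     # (3, 4, 3)   -> really (B, j0, 3) with a symbolic j0
print(_strides)  # (12, 3, 1)  -> really (j0 * 3, 3, 1)
```
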
3 changes: 2 additions & 1 deletion torch/nested/_internal/ops.py
@@ -329,6 +329,7 @@ def _flatten_sig(input, start_dim=0, end_dim=-1):
torch.ops.aten.is_non_overlapping_and_dense.default,
torch.ops.aten.sym_size.default,
torch.ops.aten.dim.default,
torch.ops.aten.numel.default,
torch.ops.aten.sym_numel.default,
torch.ops.aten.sym_stride.default,
torch.ops.aten.sym_storage_offset.default,
@@ -345,7 +346,7 @@ def tensor_attr_supported_getter(func, *args, **kwargs):
if func == torch.ops.aten.dim.default:
return len(args[0]._size)

if func == torch.ops.aten.sym_numel.default:
if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
if args[0]._lengths is not None:
return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
return args[0]._values.numel()
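
The `lengths` branch above is plain arithmetic; a toy illustration with made-up numbers:

```python
import math

lengths = [2, 1, 3]    # per-component lengths of a non-contiguous NJT
inner_dims = (4, 5)    # the dense dims after the ragged one, i.e. _size[2:]

numel = int(sum(lengths) * math.prod(inner_dims))
print(numel)  # (2 + 1 + 3) * 4 * 5 = 120
```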
