-
Notifications
You must be signed in to change notification settings - Fork 21.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NT] Make NestedTensor register as having symbolic sizes/strides (#12…
…4687) Fixes #123698 This PR makes TensorImpl::has_symbolic_sizes_strides return false for NestedTensors. 1. It passes in the actual sizes when we call `_make_wrapper_subclass` - this is the change that makes the subclass register as `has_symbolic_sizes_strides() == True` 2. It adds a field to `_make_wrapper_subclass` where an explicit `numel` can be provided. This allows us to skip the numel computation for the storage, which previously fails due to arithmetic on NestedInts. 3. Implements `aten::numel` for NJT - this is separate from the overridden numel in `make_wrapper_subclass` for now. Note also that this means that we leave `dispatch_sizes_strides_policy="sizes"`, so that we call into the custom `numel` implementation (as well as `sizes` and `strides`), because `numel` cannot currently be computed from `sizes` for NJT. Note also that this depends on #121361, because calling TensorImpl::set_sizes_and_strides() tries to clone the sizes into the tensor, which means that we need `clone` to be implemented on NestedInt. Differential Revision: [D57225736](https://our.internmc.facebook.com/intern/diff/D57225736) Pull Request resolved: #124687 Approved by: https://github.com/albanD
- Loading branch information
1 parent
96bdb7a
commit 82edc8b
Showing
6 changed files
with
59 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters