Support randn_like() for NT #96528

Closed
wants to merge 2 commits into from
Changes from 1 commit
3 changes: 2 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -4439,7 +4439,7 @@
dispatch:
# NB: Although this composite mutates on the inside, it is
# non-differentiable so NonFunctional doesn't apply
CompositeExplicitAutograd: randn_like
CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like
autogen: randn_like.out

- func: randperm(int n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -9646,6 +9646,7 @@
MPS: normal_mps_
Meta: normal_meta_
SparseCsrCPU, SparseCsrCUDA: normal_sparse_csr_
NestedTensorCPU, NestedTensorCUDA: normal_nested_
autogen: normal.out

# Only used by the functionalization pass.
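Taken together, the two dispatch additions in this file are what make `torch.randn_like()` work for nested tensors: the `CompositeImplicitAutogradNestedTensor` entry lets the existing `randn_like` composite (roughly `empty_like` followed by an in-place `normal_`) run on NT inputs, and the new `NestedTensorCPU, NestedTensorCUDA: normal_nested_` entry supplies the in-place normal fill it bottoms out in. A minimal usage sketch, assuming a PyTorch build that includes this change (component sizes are illustrative):

```python
import torch

# Two ragged components sharing the last dimension.
nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])

# New with this PR: randn_like dispatches for nested tensors.
out = torch.randn_like(nt)

print(out.is_nested)        # True
for component in out.unbind():
    print(component.shape)  # matches the corresponding component of nt
```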
6 changes: 6 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -970,5 +970,11 @@ Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
return self.reshape(sizes);
}

Tensor& normal_nested_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
self_buf.normal_(mean, std, gen);
return self;
}

} // namespace native
} // namespace at
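The new `normal_nested_` kernel operates on the nested tensor's single contiguous buffer: it fetches the buffer via `get_nested_tensor_impl(self)->get_buffer()` and fills it in place with `normal_`, so every component ends up with freshly drawn values. A hedged Python-level sketch of the semantics this registers for in-place `normal_` on a nested tensor, assuming this PR is applied:

```python
import torch

nt = torch.nested.nested_tensor([torch.zeros(2, 3), torch.zeros(4, 3)])

# In-place fill; under the hood the whole nested buffer is filled at once.
nt.normal_(mean=0.0, std=1.0)

for component in nt.unbind():
    # Roughly zero mean / unit std for sufficiently large components.
    print(component.mean().item(), component.std().item())
```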
1 change: 1 addition & 0 deletions docs/source/nested.rst
@@ -213,4 +213,5 @@ NestedTensor and any constraints they have.
:func:`torch.transpose`; "Supports transposing of all dims except ``dim=0``."
:func:`torch.Tensor.view`; "Rules for the new shape are similar to that of ``reshape``."
:func:`torch.empty_like`; "Behavior is analogous to that of regular tensors; returns a new empty nested tensor (i.e. with uninitialized values) matching the nested structure of the input."
:func:`torch.randn_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with values randomly initialized according to a standard normal distribution matching the nested structure of the input."
:func:`torch.zeros_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with all zero values matching the nested structure of the input."
4 changes: 3 additions & 1 deletion test/test_nestedtensor.py
@@ -499,13 +499,15 @@ def test_zero_(self):
t.fill_(0.)
self.assertEqual(nt_ub, t)

@parametrize("func", [torch.ones_like, torch.zeros_like],
Comment from the Contributor Author:
Note: This currently passes, but I can imagine an alternative implementation of randn_like() for NT that doesn't have the property of "generating all random numbers at once -> gives same numbers as iteratively generating numbers matching underlying component sizes", and it could be argued this is no less correct. So I'm open to better forms of testing.

@parametrize("func", [torch.ones_like, torch.zeros_like, torch.randn_like],
name_fn=lambda f: f.__name__)
def test_like_functions(self, func):
ntensors = 4
nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4))
torch.manual_seed(1)
nt_like = func(nt)

torch.manual_seed(1)
for nt_ub in nt_like.unbind():
t_like = func(nt_ub)
self.assertEqual(nt_ub, t_like)
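The seeded comparison in `test_like_functions`, and the author's comment about it, lean on one property of the RNG: filling one large buffer consumes the generator stream in the same order as filling the components one after another from the same seed. A small stand-alone sketch of that assumption with plain dense tensors (sizes are illustrative; whether the two draws match is exactly the property the test checks):

```python
import torch

torch.manual_seed(1)
all_at_once = torch.randn(2 * 3 + 4 * 3)   # draw every value in one call

torch.manual_seed(1)
piecewise = torch.cat([torch.randn(2 * 3), torch.randn(4 * 3)])  # component by component

# Assumption the test relies on: both draws consume the RNG stream identically,
# so the results match element for element.
print(torch.equal(all_at_once, piecewise))
```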