
Conversation

zeshengzong
Contributor

@zeshengzong zeshengzong commented Jan 2, 2025

Fixes #122886

  1. Enable torch.normal to work with DeviceContext, so it picks up the default device set via set_default_device.
  2. Add a hint to the set_default_device doc suggesting torch.Tensor.to as the explicit way to move a tensor to the desired device (see the sketch below).
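
For reference, a minimal sketch of the explicit alternative the doc hint suggests (my illustration, not code from this PR; assumes a CUDA device is available):

import torch

# Move the result to the desired device explicitly instead of relying on
# set_default_device.
t = torch.normal(0., 1., (10, 10)).to("cuda")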

Test Result

  1. Doc Preview (screenshot)

  2. Local Test

>>> import torch
>>> torch.normal(0.,1., (10,10)).device
device(type='cpu')
>>> torch.set_default_device('cuda')
>>> torch.normal(0.,1., (10,10)).device
device(type='cuda', index=0)
pytest test/test_tensor_creation_ops.py
(test results screenshot)

lintrunner
(lint results screenshot)

cc @fritzo @neerajprad @alicanb @nikitaved


pytorch-bot bot commented Jan 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144070

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 7fcad7e with merge base e141cb9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zeshengzong
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the "topic: not user facing" (topic category) label Jan 2, 2025
@zeshengzong zeshengzong marked this pull request as ready for review January 3, 2025 07:34
torch.logspace,
torch.nested.nested_tensor,
# This function doesn't actually take a device argument
# torch.normal,
Contributor Author

@ezyang Hello, could you help review the restriction here? It seems to work now; is there anything else to be aware of? Thanks!

It seems normal has a device argument now:

Tensor normal(
    double mean,
    double std,
    IntArrayRef size,
    std::optional<Generator> generator,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options =
      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
          pin_memory);
  auto result = at::empty(size, options);
  return result.normal_(mean, std, std::move(generator));
}
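
This is the (float mean, float std, size) factory overload, which does route through TensorOptions. A quick illustration of what that implies from Python (my sketch, not part of the PR):

import torch

# The (float, float, size) overload accepts factory kwargs such as device,
# so an explicit device argument is honored here.
t = torch.normal(0., 1., (2, 3), device="cpu")
print(t.device)  # device(type='cpu')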

Contributor

Not sure how useful this is, but the following refs to normal don't include a device:

std::vector<c10::FunctionSchema> SchemaInfo::getNonDeterministicOps() {
  // This list of nondeterministic ops is copied from JIT ir.cpp.
  static const std::vector<std::string> nondeterministic_op_strings = {
      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
  std::vector<c10::FunctionSchema> nondeterministic_ops;
  nondeterministic_ops.reserve(nondeterministic_op_strings.size());
  for (const std::string& signature : nondeterministic_op_strings) {
    nondeterministic_ops.emplace_back(torch::jit::parseSchema(signature));
  }
  return nondeterministic_ops;
}

"aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
"aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",

"aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",

@cpuhrsch cpuhrsch requested a review from malfet January 7, 2025 06:21
@cpuhrsch cpuhrsch added the "module: distributions" (Related to torch.distributions) and "triaged" (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Jan 7, 2025
@ezyang
Contributor

ezyang commented Jan 10, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the "ciflow/trunk" (Trigger trunk jobs on your pull request) label Jan 10, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@vmoens
Contributor

vmoens commented Jan 13, 2025

I believe this causes an issue when a context manager is used:

import torch
with torch.device("cpu"):
    torch.normal(torch.zeros(()), torch.ones(()))

results in

Traceback (most recent call last):
  File "/Users/vmoens/Library/Application Support/JetBrains/PyCharm2023.3/scratches/scratch_8.py", line 7, in <module>
    torch.normal(torch.zeros(()), torch.ones(()))
  File "/Users/vmoens/venv/rl2/lib/python3.10/site-packages/torch/utils/_device.py", line 103, in __torch_function__
    return func(*args, **kwargs)
TypeError: normal() received an invalid combination of arguments - got (Tensor, Tensor, device=torch.device), but expected one of:
 * (Tensor mean, Tensor std, *, torch.Generator generator = None, Tensor out = None)
 * (Tensor mean, float std = 1, *, torch.Generator generator = None, Tensor out = None)
 * (float mean, Tensor std, *, torch.Generator generator = None, Tensor out = None)
 * (float mean, float std, tuple of ints size, *, torch.Generator generator = None, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
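
My reading of the failure (an assumption, not confirmed in this thread): the device context's __torch_function__ mode injects a device= kwarg into any call it treats as a factory function, and the Tensor overloads of normal reject that kwarg. Under that assumption, one workaround sketch is to sample via the in-place Tensor.normal_, which the device mode does not rewrite:

import torch

with torch.device("cpu"):
    # torch.empty is a factory call, so the context supplies its device;
    # normal_ is an in-place method and receives no injected device kwarg.
    t = torch.empty(()).normal_(mean=0.0, std=1.0)
print(t.device)  # device(type='cpu')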

Comment on lines +3385 to +3391
def test_normal_default_device(self, device):
    try:
        torch.set_default_device(device)
        t = torch.normal(0, 1, (10, 10))
    finally:
        torch.set_default_device(None)
    self.assertEqual(str(t.device), device)
Contributor

Shouldn't you set back the previous default device instead?

Like

Suggested change

 def test_normal_default_device(self, device):
+    default_device = torch.get_default_device()
     try:
         torch.set_default_device(device)
         t = torch.normal(0, 1, (10, 10))
     finally:
-        torch.set_default_device(None)
+        torch.set_default_device(default_device)
     self.assertEqual(str(t.device), device)

Also, note that this works

torch.normal(0, 1, (10, 10))

but this doesn't

torch.normal(0, 1)
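
My reading of why (based on the overload list in the traceback above): only the (float, float, size) overload accepts the factory kwargs, so only it participates in default-device handling, while torch.normal(0, 1) matches no overload at all and raises a TypeError regardless of the default device. A small sketch:

import torch

torch.set_default_device("cpu")
torch.normal(0, 1, (10, 10))  # factory overload: honors the default device
try:
    torch.normal(0, 1)        # matches no overload: TypeError either way
except TypeError as e:
    print(type(e).__name__)   # TypeError
torch.set_default_device(None)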

@zeshengzong
Contributor Author

@vmoens Hello, sorry about that. Please help revert this one if needed, and let me check how to fix it. Thanks!

@vmoens
Contributor

vmoens commented Jan 14, 2025

Maybe @ezyang can advise - IMO, if we can get a quick fix and additional tests, it's not necessary to revert

@ezyang
Contributor

ezyang commented Jan 14, 2025

@pytorchbot revert -c nosignal -m "broken a specific use case"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@zeshengzong your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Jan 14, 2025
This reverts commit 184549b.

Reverted #144070 on behalf of https://github.com/ezyang due to broken a specific use case ([comment](#144070 (comment)))
@pytorchmergebot pytorchmergebot added the "Reverted" and "ci-no-td" (Do not run TD on this PR) labels Jan 14, 2025
@ezyang
Contributor

ezyang commented Jan 14, 2025

@zeshengzong do you mind reopening a new PR? Thanks

@zeshengzong
Contributor Author

Sure, thanks for help!


Labels

ci-no-td - Do not run TD on this PR
ciflow/trunk - Trigger trunk jobs on your pull request
Merged
module: distributions - Related to torch.distributions
open source
Reverted
topic: not user facing - topic category
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.normal ignores default_device

6 participants