Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MPS] add aten::normal.Tensor_float aten::normal.float_Tensor aten::normal.Tensor_Tensor #80297

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions aten/src/ATen/native/mps/operations/Distributions.mm
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,62 @@
return normal_mps_out(mean_t, std_t, gen, self);
}

// Samples a Normal(mean, std) distribution on the MPS device, with a
// per-element mean tensor and a single scalar standard deviation.
//
// The scalar `std` is materialized as a tensor of the output's shape so the
// tensor/tensor kernel (`normal_mps_out`) can be reused.
//
// @param mean per-element means; fixes the output's shape and dtype
// @param std  shared standard deviation; must be non-negative
// @param gen  optional RNG generator
// @return a freshly allocated MPS tensor of samples
Tensor normal_mps(const Tensor& mean, double std, c10::optional<Generator> gen) {
  // Validate here as well, matching the scalar-std out-variant: once `std` is
  // baked into a tensor the downstream path can no longer check the scalar.
  TORCH_CHECK(std >= 0.0, "normal_mps expects std >= 0.0, but found std=", std);

  Tensor output = empty_mps(
      mean.sizes(),
      mean.scalar_type(),
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);

  // Broadcast the scalar std into a full tensor so the tensor/tensor
  // implementation applies.
  Tensor std_t = empty_mps(
      output.sizes(),
      output.scalar_type(),
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);
  std_t.fill_(std);

  return normal_mps_out(mean, std_t, gen, output);
}

// Samples a Normal(mean, std) distribution on the MPS device, with a single
// scalar mean and a per-element standard-deviation tensor.
//
// @param mean shared mean applied to every element
// @param std  per-element standard deviations; fixes output shape and dtype
// @param gen  optional RNG generator
// @return a freshly allocated MPS tensor of samples
Tensor normal_mps(double mean, const Tensor& std, c10::optional<Generator> gen) {
  // Result mirrors the std tensor's shape and dtype.
  auto output = empty_mps(
      std.sizes(), std.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);

  // Lift the scalar mean into a tensor so the tensor/tensor kernel handles it.
  auto mean_t = empty_mps(
      output.sizes(), output.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
  mean_t.fill_(mean);

  return normal_mps_out(mean_t, std, gen, output);
}

// Samples a Normal(mean, std) distribution on the MPS device, with both the
// mean and the standard deviation given per element.
//
// The output shape is the broadcast of the two input shapes; its dtype follows
// `mean`.
//
// @param mean per-element means
// @param std  per-element standard deviations
// @param gen  optional RNG generator
// @return a freshly allocated MPS tensor of samples
Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
  // Broadcast the two operand shapes to get the result shape.
  const auto out_shape = at::infer_size(mean.sizes(), std.sizes());

  auto output = empty_mps(
      out_shape, mean.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);

  return normal_mps_out(mean, std, gen, output);
}

Tensor& normal_mps_out(const Tensor& mean, double std, c10::optional<Generator> gen, Tensor& output) {
TORCH_CHECK(std >= 0.0, "normal_mps_out expects std >= 0.0, but found std=", std);

Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8388,7 +8388,7 @@
- func: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor
dispatch:
CPU, CUDA: normal
#MPS: normal_mps
MPS: normal_mps
Meta: normal_meta

- func: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
Expand All @@ -8400,8 +8400,8 @@
- func: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor
dispatch:
CPU, CUDA: normal
MPS: normal_mps
Meta: normal_meta
#MPS: normal_mps

- func: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
Expand All @@ -8412,8 +8412,8 @@
- func: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
dispatch:
CPU, CUDA: normal
MPS: normal_mps
Meta: normal_meta
#MPS: normal_mps

- func: normal.float_float(float mean, float std, int[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

Expand Down
18 changes: 12 additions & 6 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4145,9 +4145,6 @@ def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):
# Test normal
def test_normal(self):
def helper(shape, mean=0.0, std=1.0):
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
x = cpu_x.detach().clone().to('mps')

mps_out = torch.normal(mean, std, shape, device='mps')

mean_array = np.ones(shape)
Expand All @@ -4160,6 +4157,7 @@ def helper(shape, mean=0.0, std=1.0):
cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
std_tensor = cpu_std_tensor.detach().clone().to('mps')

# test out
mps_out = torch.zeros(shape, device='mps')
torch.normal(mean_tensor, std, out=mps_out)

Expand All @@ -4169,14 +4167,22 @@ def helper(shape, mean=0.0, std=1.0):
mps_out = torch.zeros(shape, device='mps')
torch.normal(mean_tensor, std_tensor, out=mps_out)

# test without out
mps_out = torch.normal(mean_tensor, std)
self.assertEqual(mps_out.size(), mean_tensor.size())

mps_out = torch.normal(mean, std_tensor)
self.assertEqual(mps_out.size(), std_tensor.size())

inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
mps_out = torch.normal(mean_tensor, std_tensor)
self.assertEqual(mps_out.size(), inferred_shape)

helper((2, 3, 4, 5, 6))
helper((100, 100), 2.5, 1.2)

def test_bernoulli(self):
def helper(shape, prob=0.5):
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
x = cpu_x.detach().clone().to('mps')

prob_array = np.ones(shape)
prob_array *= prob
cpu_prob_tensor = torch.tensor(prob_array, device='cpu', dtype=torch.float, requires_grad=False)
Expand Down