Skip to content

Commit

Permalink
Fix layer norm
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisVieriu97 committed Feb 9, 2023
1 parent 4c94a24 commit 604b8ad
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 2 additions & 4 deletions aten/src/ATen/native/mps/operations/Normalization.mm
Original file line number Diff line number Diff line change
Expand Up @@ -926,8 +926,6 @@ string get_mem_string(c10::MemoryFormat memory_format) {
at::Tensor mean = std::get<1>(outputs);
at::Tensor variance = std::get<2>(outputs);

at::Tensor rstd = at::rsqrt(at::add(variance, eps));

std::vector<int64_t> stat_shape;
for (const auto idx : c10::irange(axis)) {
stat_shape.push_back(input_shape[idx]);
Expand All @@ -937,8 +935,8 @@ string get_mem_string(c10::MemoryFormat memory_format) {
stat_shape.push_back(1);
}
mean = mean.view(stat_shape);
rstd = rstd.view(stat_shape);
return std::make_tuple(out, mean, rstd);
variance = variance.view(stat_shape);
return std::make_tuple(out, mean, variance);
}

std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(
Expand Down
4 changes: 4 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8972,6 +8972,7 @@ class TestConsistency(TestCase):
'triangular_solve': ['f32'],
'_native_batch_norm_legit': ['f32'],
'native_batch_norm': ['f32'],
'native_layer_norm': ['f32'],
}

# These ops that are problematic. So never run them even when
Expand Down Expand Up @@ -9187,6 +9188,9 @@ def get_samples():
elif (op.name == "masked.mean"):
atol = 7e-4
rtol = 2e-3
elif (op.name == "native_layer_norm"):
atol = 1e-4
rtol = 1.3e-5
else:
atol = None
rtol = None
Expand Down

0 comments on commit 604b8ad

Please sign in to comment.