From 3817de5d840bdff3f11ee23782494b5a13ae2001 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Wed, 30 Aug 2023 13:24:26 -0700
Subject: [PATCH] Fix layernorm cpu precision issues (#108089)

#108072

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108089
Approved by: https://github.com/mingfeima, https://github.com/albanD
---
 aten/src/ATen/native/cpu/layer_norm_kernel.cpp | 6 +++---
 test/test_nn.py                                | 6 ++++++
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
index 3171f3ff04fe1..a0c3e0955e17a 100644
--- a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
+++ b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
@@ -55,17 +55,17 @@ void LayerNormKernelImplInternal(
       std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N);
       rstd_val = T(1) / std::sqrt(rstd_val + eps);
       const T scale = rstd_val;
-      const T bias = -rstd_val * mean_val;
+      const T bias = - mean_val;
       if (gamma_null || beta_null) {
         for (const auto j : c10::irange(N)) {
           const T gamma_v = gamma_null ? T(1) : gamma_data[j];
           const T beta_v = beta_null ? T(0) : beta_data[j];
-          Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v;
+          Y_ptr[j] = (X_ptr[j] + bias) * rstd_val * gamma_v + beta_v;
         }
       } else {
         vec::map3<T>(
             [scale, bias](Vec x, Vec gamma, Vec beta) {
-              return (x * Vec(scale) + Vec(bias)) * gamma + beta;
+              return (x + Vec(bias)) * Vec(scale) * gamma + beta;
             },
             Y_ptr,
             X_ptr,
diff --git a/test/test_nn.py b/test/test_nn.py
index ad201dac7b3d7..c03b0221368c4 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -7273,6 +7273,12 @@ def test_layer_norm_grads_with_create_graph_flag(self):
 
         self.assertEqual(grads1, grads2, rtol=rtol, atol=atol)
 
+    def test_layer_norm_eps(self):
+        # test for https://github.com/pytorch/pytorch/issues/108072
+        x = torch.Tensor([[[2.0, 2.0], [14.0, 14.0]], [[2.0, 2.0], [14.0, 14.0]]])
+        ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
+        self.assertEqual(ln.forward(x), torch.zeros_like(x))
+
     def test_padding_list(self):
         # Padding can be a list, or tuple (regression test for gh-54452)
         x = torch.randn(4, 8, 32, 32)
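
Note (not part of the patch): the kernel change reorders the per-row affine math from x * rstd + (-rstd * mean) to (x - mean) * rstd. When a normalized slice is constant, x - mean is exactly zero before the 1/sqrt(var + eps) scaling is applied, so the output is exactly zero; per issue #108072 the old ordering could leave a small nonzero residual. The following is a minimal standalone sketch mirroring the new test_layer_norm_eps test, assuming a CPU build that includes this patch:

    # Standalone sketch of the scenario covered by test_layer_norm_eps.
    # Assumption: run on a CPU build with this fix applied.
    import torch

    x = torch.tensor([[[2.0, 2.0], [14.0, 14.0]],
                      [[2.0, 2.0], [14.0, 14.0]]])
    ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
    out = ln(x)

    # Each normalized slice of size 2 is constant, so (x - mean) is exactly 0
    # and the normalized output should be exactly zero.
    print(out)
    print(torch.equal(out, torch.zeros_like(x)))  # expected: True after the fix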