From e810bcb23bd897eab4e126807721c2e1084a7949 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20=C5=9Euhan?= Date: Tue, 2 Apr 2019 13:21:09 -0700 Subject: [PATCH] Take the absolute value of a norm argument --- test/cpp/test_aten_xla_tensor.cpp | 2 +- torch_xla/csrc/ops/ops.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 352c298d70dc..ca7de62ffccf 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -1091,7 +1091,7 @@ TEST_F(AtenXlaTensorTest, TestNormInDimsKeep) { } TEST_F(AtenXlaTensorTest, TestNormGeneral) { - at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat)); + at::Tensor a = at::randn({4, 3, 4}, at::TensorOptions(at::kFloat)); at::Tensor b = at::norm(a, 3.5); ForEachDevice([&](const Device& device) { at::Tensor xla_a = bridge::CreateXlaTensor(a, device); diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index a2d9c913abf4..b3c3aecda5a1 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -481,7 +481,7 @@ NodePtr Norm(const Value& input, c10::optional p, NodePtr norm_exp = ScalarOp(norm_value, input.shape().element_type()); NodePtr norm_exp_inv = ScalarOp(1.0 / norm_value, input.shape().element_type()); - NodePtr exp = Pow(input, norm_exp); + NodePtr exp = Pow(Abs(input), norm_exp); NodePtr result = MakeNode(exp, dimensions, keepdim, dtype); return Pow(result, norm_exp_inv); }