
Commit 63de0b5

Added aten::norm operations.
1 parent 6db4fff commit 63de0b5

File tree (7 files changed: +148 −2 lines)

test/cpp/test_aten_xla_tensor.cpp
torch_xla/csrc/aten_xla_type.cpp
torch_xla/csrc/aten_xla_type.h
torch_xla/csrc/ops/ops.cpp
torch_xla/csrc/ops/ops.h
torch_xla/csrc/tensor.cpp
torch_xla/csrc/tensor.h

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 60 additions & 0 deletions
@@ -831,6 +831,66 @@ TEST_F(AtenXlaTensorTest, TestSumInDimsKeep) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestNorm) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::norm(a);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::norm(xla_a);
+    AllClose(b, xla_b);
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestNormInDim) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::norm(a, 2, {1}, /*keepdim=*/false);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::norm(xla_a, 2, {1}, /*keepdim=*/false);
+    AllClose(b, xla_b);
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestNormInDims) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::norm(a, 2, {1, 2}, /*keepdim=*/false);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::norm(xla_a, 2, {1, 2}, /*keepdim=*/false);
+    AllClose(b, xla_b);
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestNormInDimsKeep) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::norm(a, 2, {1, 2}, /*keepdim=*/true);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::norm(xla_a, 2, {1, 2}, /*keepdim=*/true);
+    AllClose(b, xla_b);
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestNormGeneral) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::norm(a, 3.5);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::norm(xla_a, 3.5);
+    AllClose(b, xla_b);
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestNormNuclear) {
+  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::norm(a, 1);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::norm(xla_a, 1);
+    AllClose(b, xla_b);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestProd) {
   at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
   at::Tensor b = at::prod(a);
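For reference, the dim-taking variants reduce only over the listed dimensions. For the (4, 3, 4) input with p = 2, the quantity that TestNormInDims / TestNormInDimsKeep compare against the CPU reference is, per remaining index i (a sketch of standard reduction semantics, not text from the diff):

\[
  \lVert a_i \rVert_2 = \Big( \sum_{j=0}^{2} \sum_{k=0}^{3} a_{ijk}^2 \Big)^{1/2},
\]

yielding a result of shape (4) with keepdim=false and (4, 1, 1) with keepdim=true.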

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 27 additions & 2 deletions
@@ -1039,8 +1039,7 @@ at::Tensor& AtenXlaType::trunc_(at::Tensor& self) const {
 }
 
 at::Tensor AtenXlaType::frac(const at::Tensor& self) const {
-  return bridge::AtenFromXlaTensor(
-      XLATensor::frac(bridge::GetXlaTensor(self)));
+  return bridge::AtenFromXlaTensor(XLATensor::frac(bridge::GetXlaTensor(self)));
 }
 
 at::Tensor& AtenXlaType::frac_(at::Tensor& self) const {
@@ -1365,6 +1364,32 @@ at::Tensor AtenXlaType::dropout(const at::Tensor& input, double p,
       XLATensor::dropout(bridge::GetXlaTensor(input), p));
 }
 
+at::Tensor AtenXlaType::norm(const at::Tensor& self,
+                             c10::optional<at::Scalar> p,
+                             at::ScalarType dtype) const {
+  return bridge::AtenFromXlaTensor(XLATensor::norm(
+      bridge::GetXlaTensor(self), p, dtype, {}, /*keepdim=*/false));
+}
+
+at::Tensor AtenXlaType::norm(const at::Tensor& self, at::Scalar p) const {
+  return bridge::AtenFromXlaTensor(XLATensor::norm(
+      bridge::GetXlaTensor(self), p, c10::nullopt, {}, /*keepdim=*/false));
+}
+
+at::Tensor AtenXlaType::norm(const at::Tensor& self,
+                             c10::optional<at::Scalar> p, at::IntArrayRef dim,
+                             bool keepdim, at::ScalarType dtype) const {
+  return bridge::AtenFromXlaTensor(
+      XLATensor::norm(bridge::GetXlaTensor(self), p, dtype, dim, keepdim));
+}
+
+at::Tensor AtenXlaType::norm(const at::Tensor& self,
+                             c10::optional<at::Scalar> p, at::IntArrayRef dim,
+                             bool keepdim) const {
+  return bridge::AtenFromXlaTensor(XLATensor::norm(
+      bridge::GetXlaTensor(self), p, c10::nullopt, dim, keepdim));
+}
+
 at::Tensor AtenXlaType::log_softmax(const at::Tensor& self, int64_t dim) const {
   return bridge::AtenFromXlaTensor(
       XLATensor::log_softmax(bridge::GetXlaTensor(self), dim));
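To illustrate how the overloads above route into XLATensor::norm, here is a minimal, hypothetical ATen snippet (not part of this commit; it only uses call forms that also appear in the new tests). The comments state the arguments the matching overload forwards, per the implementations above:

#include <ATen/ATen.h>

int main() {
  at::Tensor t = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));

  // No dim argument: the overloads above pass dim={} and keepdim=false,
  // so the norm is taken over every element of t.
  at::Tensor two_norm = at::norm(t);      // p unset or 2 -> Frobenius norm
  at::Tensor p_norm = at::norm(t, 3.5);   // generic sum(x^p)^(1/p) path

  // dim/keepdim overload: reduce over dims 1 and 2 only, keeping them
  // as size-1 dimensions in the result (dtype stays c10::nullopt).
  at::Tensor kept = at::norm(t, 2, {1, 2}, /*keepdim=*/true);
  return 0;
}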

torch_xla/csrc/aten_xla_type.h

Lines changed: 9 additions & 0 deletions
@@ -427,6 +427,15 @@ class AtenXlaType : public AtenXlaTypeBase {
   at::Tensor dropout(const at::Tensor& input, double p,
                      bool train) const override;
 
+  at::Tensor norm(const at::Tensor& self, c10::optional<at::Scalar> p,
+                  at::ScalarType dtype) const override;
+  at::Tensor norm(const at::Tensor& self, at::Scalar p) const override;
+  at::Tensor norm(const at::Tensor& self, c10::optional<at::Scalar> p,
+                  at::IntArrayRef dim, bool keepdim,
+                  at::ScalarType dtype) const override;
+  at::Tensor norm(const at::Tensor& self, c10::optional<at::Scalar> p,
+                  at::IntArrayRef dim, bool keepdim) const override;
+
   at::Tensor log_softmax(const at::Tensor& self, int64_t dim) const override;
   at::Tensor _log_softmax(const at::Tensor& self, int64_t dim,
                           bool half_to_float) const override;

torch_xla/csrc/ops/ops.cpp

Lines changed: 37 additions & 0 deletions
@@ -14,6 +14,7 @@
 #include "torch_xla/csrc/ops/arithmetic_ir_ops.h"
 #include "torch_xla/csrc/ops/constant.h"
 #include "torch_xla/csrc/ops/infer_output_shape.h"
+#include "torch_xla/csrc/ops/sum.h"
 #include "torch_xla/csrc/pooling.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/xla_lower_util.h"
@@ -415,6 +416,42 @@ NodePtr BroadcastTensors(tensorflow::gtl::ArraySlice<const Value> tensors) {
       std::move(lower_fn), /*num_outputs=*/tensors.size());
 }
 
+NodePtr Norm(const Value& input, c10::optional<at::Scalar> p,
+             c10::optional<at::ScalarType> dtype, at::IntArrayRef dim,
+             bool keepdim) {
+  std::vector<xla::int64> dimensions(dim.begin(), dim.end());
+  if (dimensions.empty()) {
+    dimensions = xla::util::Iota<xla::int64>(input.shape().rank());
+  }
+  if (!p.has_value() || p->toDouble() == 2.0) {
+    NodePtr square = input * input;
+    NodePtr result = MakeNode<Sum>(square, dimensions, keepdim, dtype);
+    return Sqrt(result);
+  }
+  double norm_value = p->toDouble();
+  if (norm_value == 1.0) {
+    // Contrary to documentation, norm(p=1) has nothing to do with traces and
+    // standard mathematical definitions of nuclear norms:
+    //
+    // >>> import torch
+    // >>> x = torch.randn(4, 4)
+    // >>> print(torch.norm(x, 1))
+    // tensor(11.9437)
+    // >>> print(torch.trace(x.abs()))
+    // tensor(3.1235)
+    // >>> print(x.abs().sum())
+    // tensor(11.9437)
+    return MakeNode<Sum>(Abs(input), dimensions, keepdim, dtype);
+  }
+  // Generic sum(x^p)^(1/p) norms.
+  NodePtr norm_exp = ScalarOp(norm_value, input.shape().element_type());
+  NodePtr norm_exp_inv =
+      ScalarOp(1.0 / norm_value, input.shape().element_type());
+  NodePtr exp = Pow(input, norm_exp);
+  NodePtr result = MakeNode<Sum>(exp, dimensions, keepdim, dtype);
+  return Pow(result, norm_exp_inv);
+}
+
 }  // namespace ops
 }  // namespace ir
 }  // namespace torch_xla
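In formulas, the Norm lowering above picks one of three reductions over the dimension set D (every dimension when dim is empty); this summary restates the code paths and is not part of the diff:

\[
  \mathrm{Norm}(x; p, D) =
  \begin{cases}
    \sqrt{\sum_{D} x^2} & \text{if } p \text{ is unset or } p = 2, \\
    \sum_{D} \lvert x \rvert & \text{if } p = 1, \\
    \bigl( \sum_{D} x^{p} \bigr)^{1/p} & \text{otherwise},
  \end{cases}
\]

with keepdim deciding whether the reduced dimensions are retained as size 1, and dtype forwarded to the underlying Sum node.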

torch_xla/csrc/ops/ops.h

Lines changed: 4 additions & 0 deletions
@@ -142,6 +142,10 @@ NodePtr ARange(const at::Scalar& start, const at::Scalar& end,
 
 NodePtr BroadcastTensors(tensorflow::gtl::ArraySlice<const Value> tensors);
 
+NodePtr Norm(const Value& input, c10::optional<at::Scalar> p,
+             c10::optional<at::ScalarType> dtype, at::IntArrayRef dim,
+             bool keepdim);
+
 }  // namespace ops
 }  // namespace ir
 }  // namespace torch_xla

torch_xla/csrc/tensor.cpp

Lines changed: 7 additions & 0 deletions
@@ -1037,6 +1037,13 @@ XLATensor XLATensor::dropout(const XLATensor& input, double p) {
       ir::MakeNode<ir::ops::Dropout>(input.GetIrValue(), p));
 }
 
+XLATensor XLATensor::norm(const XLATensor& input, c10::optional<at::Scalar> p,
+                          c10::optional<at::ScalarType> dtype,
+                          at::IntArrayRef dim, bool keepdim) {
+  return input.CreateFrom(
+      ir::ops::Norm(input.GetIrValue(), p, dtype, dim, keepdim));
+}
+
 XLATensor XLATensor::neg(const XLATensor& input) {
   return input.CreateFrom(ir::ops::Neg(input.GetIrValue()));
 }

torch_xla/csrc/tensor.h

Lines changed: 4 additions & 0 deletions
@@ -302,6 +302,10 @@ class XLATensor {
 
   static XLATensor dropout(const XLATensor& input, double p);
 
+  static XLATensor norm(const XLATensor& input, c10::optional<at::Scalar> p,
+                        c10::optional<at::ScalarType> dtype,
+                        at::IntArrayRef dim, bool keepdim);
+
   static XLATensor neg(const XLATensor& input);
   static void neg_(XLATensor& input);