Skip to content

Commit 3375ed1

Browse files
committed
Add trunc to ATen XLA tensor
1 parent 96f85b1 commit 3375ed1

File tree

9 files changed

+48
-8
lines changed

9 files changed

+48
-8
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,16 @@ TEST_F(AtenXlaTensorTest, TestFloor) {
11431143
});
11441144
}
11451145

1146+
TEST_F(AtenXlaTensorTest, TestTrunc) {
  // Scale the random input by 100 so values have a non-trivial integer part,
  // then check the XLA lowering of aten::trunc against the reference result
  // on every available device.
  at::Tensor input = at::randn({2, 2}, at::TensorOptions(at::kFloat)) * 100.0;
  at::Tensor expected = at::trunc(input);
  ForEachDevice([&](const Device& device) {
    at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
    at::Tensor xla_output = at::trunc(xla_input);
    AllClose(expected, xla_output);
  });
}
1155+
11461156
TEST_F(AtenXlaTensorTest, TestNeg) {
11471157
at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
11481158
at::Tensor b = at::neg(a);

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,17 @@ at::Tensor& AtenXlaType::floor_(at::Tensor& self) const {
10271027
return self;
10281028
}
10291029

1030+
at::Tensor AtenXlaType::trunc(const at::Tensor& self) const {
  // Unwrap the aten tensor, lower trunc through the XLA tensor layer, and
  // wrap the result back into an at::Tensor for the caller.
  auto self_tensor = bridge::GetXlaTensor(self);
  return bridge::AtenFromXlaTensor(XLATensor::trunc(self_tensor));
}
1034+
1035+
at::Tensor& AtenXlaType::trunc_(at::Tensor& self) const {
  // In-place variant: mutate the underlying XLA tensor, then return self
  // per the aten in-place op convention.
  auto self_tensor = bridge::GetXlaTensor(self);
  XLATensor::trunc_(self_tensor);
  return self;
}
1040+
10301041
int64_t AtenXlaType::size(const at::Tensor& self, int64_t dim) const {
10311042
return bridge::GetXlaTensor(self).size(dim);
10321043
}

torch_xla/csrc/aten_xla_type.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,9 @@ class AtenXlaType : public AtenXlaTypeBase {
300300
at::Tensor floor(const at::Tensor& self) const override;
301301
at::Tensor& floor_(at::Tensor& self) const override;
302302

303+
at::Tensor trunc(const at::Tensor& self) const override;
304+
at::Tensor& trunc_(at::Tensor& self) const override;
305+
303306
int64_t size(const at::Tensor& self, int64_t dim) const override;
304307

305308
std::tuple<at::Tensor, at::Tensor> kthvalue(const at::Tensor& self, int64_t k,

torch_xla/csrc/ops/arithmetic_ir_ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
namespace torch_xla {
1010
namespace ir {
1111

12-
Value operator+(const Value& node1, const Value& node2) {
12+
NodePtr operator+(const Value& node1, const Value& node2) {
1313
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
1414
xla::XlaOp op0 = loctx->GetOutputOp(node.operand(0));
1515
xla::XlaOp op1 = loctx->GetOutputOp(node.operand(1));
@@ -21,7 +21,7 @@ Value operator+(const Value& node1, const Value& node2) {
2121
std::move(lower_fn));
2222
}
2323

24-
Value operator-(const Value& node1, const Value& node2) {
24+
NodePtr operator-(const Value& node1, const Value& node2) {
2525
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
2626
xla::XlaOp op0 = loctx->GetOutputOp(node.operand(0));
2727
xla::XlaOp op1 = loctx->GetOutputOp(node.operand(1));
@@ -33,7 +33,7 @@ Value operator-(const Value& node1, const Value& node2) {
3333
std::move(lower_fn));
3434
}
3535

36-
Value operator*(const Value& node1, const Value& node2) {
36+
NodePtr operator*(const Value& node1, const Value& node2) {
3737
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
3838
xla::XlaOp op0 = loctx->GetOutputOp(node.operand(0));
3939
xla::XlaOp op1 = loctx->GetOutputOp(node.operand(1));
@@ -45,7 +45,7 @@ Value operator*(const Value& node1, const Value& node2) {
4545
std::move(lower_fn));
4646
}
4747

48-
Value operator/(const Value& node1, const Value& node2) {
48+
NodePtr operator/(const Value& node1, const Value& node2) {
4949
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
5050
xla::XlaOp op0 = loctx->GetOutputOp(node.operand(0));
5151
xla::XlaOp op1 = loctx->GetOutputOp(node.operand(1));

torch_xla/csrc/ops/arithmetic_ir_ops.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
namespace torch_xla {
66
namespace ir {
77

8-
Value operator+(const Value& node1, const Value& node2);
9-
Value operator-(const Value& node1, const Value& node2);
10-
Value operator*(const Value& node1, const Value& node2);
11-
Value operator/(const Value& node1, const Value& node2);
8+
NodePtr operator+(const Value& node1, const Value& node2);
9+
NodePtr operator-(const Value& node1, const Value& node2);
10+
NodePtr operator*(const Value& node1, const Value& node2);
11+
NodePtr operator/(const Value& node1, const Value& node2);
1212

1313
} // namespace ir
1414
} // namespace torch_xla

torch_xla/csrc/ops/ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "torch_xla/csrc/helpers.h"
1212
#include "torch_xla/csrc/lowering_context.h"
1313
#include "torch_xla/csrc/nll_loss.h"
14+
#include "torch_xla/csrc/ops/arithmetic_ir_ops.h"
1415
#include "torch_xla/csrc/ops/constant.h"
1516
#include "torch_xla/csrc/ops/infer_output_shape.h"
1617
#include "torch_xla/csrc/pooling.h"
@@ -73,6 +74,8 @@ PTXLA_BINARY_OP(Pow, at::aten::pow, xla::Pow);
7374
PTXLA_BINARY_OP(Fmod, at::aten::fmod, xla::Rem);
7475
PTXLA_BINARY_OP(Atan2, at::aten::atan2, xla::Atan2);
7576

77+
NodePtr Trunc(const Value& input) {
  // trunc(x) rounds toward zero. Expressed as floor(|x|) * sign(x) so it is
  // built from already-lowered ops instead of needing a dedicated lowering;
  // sign(0) == 0 keeps the zero case correct.
  return Floor(Abs(input)) * SignOp(input);
}
78+
7679
NodePtr LogBase(const Value& input, OpKind op, double base) {
7780
auto lower_fn = [base](const Node& node,
7881
LoweringContext* loctx) -> XlaOpVector {

torch_xla/csrc/ops/ops.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ NodePtr Ceil(const Value& input);
114114

115115
NodePtr Floor(const Value& input);
116116

117+
NodePtr Trunc(const Value& input);
118+
117119
NodePtr AddMatMulOp(const Value& input, const Value& weight, const Value& bias);
118120

119121
NodePtr Dot(const Value& input, const Value& weight);

torch_xla/csrc/tensor.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,14 @@ void XLATensor::floor_(XLATensor& input) {
11771177
input.SetIrValue(ir::ops::Floor(input.GetIrValue()));
11781178
}
11791179

1180+
XLATensor XLATensor::trunc(const XLATensor& input) {
  // Build a new tensor whose IR value is trunc of the input's current value.
  auto ir_input = input.GetIrValue();
  return input.CreateFrom(ir::ops::Trunc(ir_input));
}
1183+
1184+
void XLATensor::trunc_(XLATensor& input) {
  // In-place: rebind input's IR value to the truncation of its current value.
  auto truncated = ir::ops::Trunc(input.GetIrValue());
  input.SetIrValue(truncated);
}
1187+
11801188
XLATensor XLATensor::slice(const XLATensor& input, xla::int64 dim,
11811189
xla::int64 start, xla::int64 end, xla::int64 step) {
11821190
auto input_shape = input.shape();

torch_xla/csrc/tensor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,9 @@ class XLATensor {
359359
static XLATensor floor(const XLATensor& input);
360360
static void floor_(XLATensor& input);
361361

362+
static XLATensor trunc(const XLATensor& input);
363+
static void trunc_(XLATensor& input);
364+
362365
static XLATensor slice(const XLATensor& input, xla::int64 dim,
363366
xla::int64 start, xla::int64 end, xla::int64 step);
364367

0 commit comments

Comments
 (0)