From 87b75b201ee33a93ee675b83c16b0053fd51e205 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alex=20=C5=9Euhan?=
Date: Fri, 15 Mar 2019 12:13:16 -0700
Subject: [PATCH 1/2] Add min, max (without indices) to ATen XLA tensor

---
 test/cpp/test_aten_xla_tensor.cpp | 20 +++++++++++++++++
 torch_xla/csrc/aten_xla_type.cpp  |  8 +++++++
 torch_xla/csrc/aten_xla_type.h    |  2 ++
 torch_xla/csrc/helpers.cpp        | 10 +++++++++
 torch_xla/csrc/helpers.h          |  2 ++
 torch_xla/csrc/ops/ops.cpp        | 36 +++++++++++++++++++++++++++++++
 torch_xla/csrc/ops/ops.h          |  4 ++++
 torch_xla/csrc/tensor.h           |  4 ++++
 torch_xla/csrc/tensor_methods.cpp |  8 +++++++
 9 files changed, 94 insertions(+)

diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 0491d12e5509..bad41a1e38b1 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -793,6 +793,26 @@ TEST_F(AtenXlaTensorTest, TestMax) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestUnaryMin) {
+  at::Tensor input = at::rand({2, 2}, at::TensorOptions(at::kFloat));
+  at::Tensor output = at::min(input);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+    at::Tensor xla_output = at::min(xla_input);
+    AllClose(output, xla_output);
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestUnaryMax) {
+  at::Tensor input = at::rand({2, 2}, at::TensorOptions(at::kFloat));
+  at::Tensor output = at::max(input);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+    at::Tensor xla_output = at::max(xla_input);
+    AllClose(output, xla_output);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestAll) {
   at::Tensor a = at::randint(0, 5, {2, 3, 4}, at::TensorOptions(at::kByte));
   at::Tensor b = at::all(a);
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 07599335943f..bea0b31b7b97 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2080,12 +2080,20 @@ at::Tensor AtenXlaType::min(const at::Tensor& self,
       XLATensor::min(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
 }
 
+at::Tensor AtenXlaType::min(const at::Tensor& self) const {
+  return bridge::AtenFromXlaTensor(XLATensor::min(bridge::GetXlaTensor(self)));
+}
+
 at::Tensor AtenXlaType::max(const at::Tensor& self,
                             const at::Tensor& other) const {
   return bridge::AtenFromXlaTensor(
       XLATensor::max(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
 }
 
+at::Tensor AtenXlaType::max(const at::Tensor& self) const {
+  return bridge::AtenFromXlaTensor(XLATensor::max(bridge::GetXlaTensor(self)));
+}
+
 at::Tensor AtenXlaType::mean(const at::Tensor& self,
                              at::ScalarType dtype) const {
   XLATensor self_tensor = bridge::GetXlaTensor(self);
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 7ce6e0729459..9217fd35d979 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -663,8 +663,10 @@ class AtenXlaType : public AtenXlaTypeBase {
   at::Tensor min(const at::Tensor& self,
                  const at::Tensor& other) const override;
+  at::Tensor min(const at::Tensor& self) const override;
 
   at::Tensor max(const at::Tensor& self,
                  const at::Tensor& other) const override;
+  at::Tensor max(const at::Tensor& self) const override;
 
   at::Tensor mean(const at::Tensor& self, at::ScalarType dtype) const override;
   at::Tensor mean(const at::Tensor& self) const override;
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index ec0e5863d85f..bb62d77565d0 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -185,6 +185,16 @@ xla::XlaComputation XlaHelpers::CreateMaxComputation(xla::PrimitiveType type) {
   return ConsumeValue(builder.Build());
 }
 
+xla::XlaComputation XlaHelpers::CreateMinComputation(xla::PrimitiveType type) {
+  xla::XlaBuilder builder("MinComputation");
+  xla::XlaOp x =
+      xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(type, {}), "x");
+  xla::XlaOp y =
+      xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(type, {}), "y");
+  xla::Min(x, y);
+  return ConsumeValue(builder.Build());
+}
+
 xla::Shape XlaHelpers::ShapeOfXlaOp(const xla::XlaOp& op) {
   return ConsumeValue(op.builder()->GetShape(op));
 }
diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h
index da806f420118..5719b3af1757 100644
--- a/torch_xla/csrc/helpers.h
+++ b/torch_xla/csrc/helpers.h
@@ -150,6 +150,8 @@ class XlaHelpers {
 
   static xla::XlaComputation CreateMaxComputation(xla::PrimitiveType type);
 
+  static xla::XlaComputation CreateMinComputation(xla::PrimitiveType type);
+
   // Returns an XLA operation which is a reshape to the expected rank, by
   // appending 1s to the major dimension. If offset is greater than zero, 1s
   // will be prepended to the minor dimension as well.
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index e53596e6a81e..9a452205bf97 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -542,6 +542,42 @@ NodePtr Remainder(const Value& input, const Value& divisor) {
                 ScalarOp(0, input.shape()));
 }
 
+NodePtr MaxUnary(const Value& input) {
+  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
+    xla::PrimitiveType element_type = input_shape.element_type();
+    XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(element_type);
+    xla::XlaOp init_value =
+        XlaHelpers::ScalarValue(min_max.min, element_type, loctx->builder());
+    xla::XlaOp result = xla::Reduce(
+        xla_input, init_value, XlaHelpers::CreateMaxComputation(element_type),
+        xla::util::Iota<xla::int64>(input_shape.rank()));
+    return node.ReturnOp(xla::Reshape(result, {1}), loctx);
+  };
+  return GenericOp(OpKind(at::aten::max), {input},
+                   xla::ShapeUtil::MakeShape(input.shape().element_type(), {1}),
+                   std::move(lower_fn));
+}
+
+NodePtr MinUnary(const Value& input) {
+  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
+    xla::PrimitiveType element_type = input_shape.element_type();
+    XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(element_type);
+    xla::XlaOp init_value =
+        XlaHelpers::ScalarValue(min_max.max, element_type, loctx->builder());
+    xla::XlaOp result = xla::Reduce(
+        xla_input, init_value, XlaHelpers::CreateMinComputation(element_type),
+        xla::util::Iota<xla::int64>(input_shape.rank()));
+    return node.ReturnOp(xla::Reshape(result, {1}), loctx);
+  };
+  return GenericOp(OpKind(at::aten::min), {input},
+                   xla::ShapeUtil::MakeShape(input.shape().element_type(), {1}),
+                   std::move(lower_fn));
+}
+
 }  // namespace ops
 }  // namespace ir
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index edf27383e6cb..dcaa27cc8ec0 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -177,6 +177,10 @@ NodePtr Rshift(const Value& input, const Value& other);
 
 NodePtr Remainder(const Value& input, const Value& divisor);
 
+NodePtr MaxUnary(const Value& input);
+
+NodePtr MinUnary(const Value& input);
+
 }  // namespace ops
 }  // namespace ir
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 2895586c0fdb..d56597d68792 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -508,6 +508,8 @@ class XLATensor {
 
   static XLATensor max(const XLATensor& input, const XLATensor& other);
 
+  static XLATensor max(const XLATensor& input);
+
   static XLATensor max_pool2d(const XLATensor& input,
                               std::vector<xla::int64> kernel_size,
                               std::vector<xla::int64> stride,
@@ -526,6 +528,8 @@ class XLATensor {
 
   static XLATensor min(const XLATensor& input, const XLATensor& other);
 
+  static XLATensor min(const XLATensor& input);
+
   static XLATensor mm(const XLATensor& input, const XLATensor& weight);
 
   static XLATensor mul(const XLATensor& input, const XLATensor& other);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 45d26a89f643..47a491877e81 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -1068,6 +1068,10 @@ XLATensor XLATensor::max(const XLATensor& input, const XLATensor& other) {
   return input.CreateFrom(ir::ops::Max(input.GetIrValue(), other.GetIrValue()));
 }
 
+XLATensor XLATensor::max(const XLATensor& input) {
+  return input.CreateFrom(ir::ops::MaxUnary(input.GetIrValue()), input.dtype());
+}
+
 XLATensor XLATensor::max_pool2d(const XLATensor& input,
                                 std::vector<xla::int64> kernel_size,
                                 std::vector<xla::int64> stride,
@@ -1102,6 +1106,10 @@ XLATensor XLATensor::min(const XLATensor& input, const XLATensor& other) {
   return input.CreateFrom(ir::ops::Min(input.GetIrValue(), other.GetIrValue()));
 }
 
+XLATensor XLATensor::min(const XLATensor& input) {
+  return input.CreateFrom(ir::ops::MinUnary(input.GetIrValue()), input.dtype());
+}
+
 XLATensor XLATensor::mm(const XLATensor& input, const XLATensor& weight) {
   return input.CreateFrom(
       ir::ops::Dot(input.GetIrValue(), weight.GetIrValue()));

From 50cf27298a2037afb4b1064aaa14a565249c8e08 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alex=20=C5=9Euhan?=
Date: Fri, 15 Mar 2019 13:49:38 -0700
Subject: [PATCH 2/2] Add one_hot to ATen XLA tensor

It's sufficient to route it to at::native::one_hot. It still goes back
to the host for range validation of the labels, which cannot be
avoided.
---
 test/cpp/test_aten_xla_tensor.cpp | 12 ++++++++++++
 torch_xla/csrc/aten_xla_type.cpp  |  5 +++++
 torch_xla/csrc/aten_xla_type.h    |  3 +++
 3 files changed, 20 insertions(+)

diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index bad41a1e38b1..87cd61f13ab7 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -2752,6 +2752,18 @@ TEST_F(AtenXlaTensorTest, TestEmbedding) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestOneHot) {
+  int num_classes = 5;
+  at::Tensor input =
+      at::randint(0, num_classes, {10}, at::TensorOptions(at::kLong));
+  at::Tensor output = at::one_hot(input, num_classes);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+    at::Tensor xla_output = at::one_hot(xla_input, num_classes);
+    EXPECT_TRUE(EqualValues(output, xla_output));
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestTranspose) {
   at::Tensor input = at::rand({2, 3}, at::TensorOptions(at::kFloat));
   at::Tensor output = at::t(input);
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index bea0b31b7b97..0a69faeb64f5 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2207,6 +2207,11 @@ at::Tensor AtenXlaType::flip(const at::Tensor& self,
       XLATensor::flip(bridge::GetXlaTensor(self), XlaHelpers::I64List(dims)));
 }
 
+at::Tensor AtenXlaType::one_hot(const at::Tensor& self,
+                                int64_t num_classes) const {
+  return at::native::one_hot(self, num_classes);
+}
+
 at::Tensor AtenXlaType::transpose(const at::Tensor& self, int64_t dim0,
                                   int64_t dim1) const {
   return bridge::AtenFromXlaTensor(
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 9217fd35d979..4a7702a329e8 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -701,6 +701,9 @@ class AtenXlaType : public AtenXlaTypeBase {
 
   at::Tensor flip(const at::Tensor& self, at::IntArrayRef dims) const override;
 
+  at::Tensor one_hot(const at::Tensor& self,
+                     int64_t num_classes) const override;
+
   at::Tensor transpose(const at::Tensor& self, int64_t dim0,
                        int64_t dim1) const override;
 
   at::Tensor& transpose_(at::Tensor& self, int64_t dim0,
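
Background note on the lowering pattern (not part of the patches above): a
full reduction in XLA seeds xla::Reduce with the identity of the combining
computation (the lowest representable value for a max-reduction, the highest
for a min-reduction) and reduces over every dimension. That is also why
CreateMinComputation can discard the result of xla::Min(x, y):
XlaBuilder::Build() uses the last operation added to the builder as the
computation root. Below is a minimal standalone sketch of the same pattern
against the bare XLA client API; it uses xla::MinValue from
xla/client/lib/constants.h in place of XlaHelpers::MinMaxValues, the include
paths assume the TensorFlow pin of that era, and the function names are
illustrative only, not part of this patch set:

  #include <numeric>
  #include <vector>

  #include "tensorflow/compiler/xla/client/lib/constants.h"
  #include "tensorflow/compiler/xla/client/xla_builder.h"

  // Scalar combining computation max(x, y). The Max op is the last one
  // added to the builder, so Build() picks it as the root.
  xla::XlaComputation MakeMaxCombiner(xla::PrimitiveType type) {
    xla::XlaBuilder builder("MaxCombiner");
    xla::XlaOp x =
        xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(type, {}), "x");
    xla::XlaOp y =
        xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(type, {}), "y");
    xla::Max(x, y);
    return builder.Build().ValueOrDie();
  }

  // Reduces `input` over all of its dimensions, yielding a rank-0 result.
  // The init value is the identity of max: the type's lowest value.
  xla::XlaOp ReduceMaxAllDims(xla::XlaOp input, xla::PrimitiveType type) {
    xla::XlaBuilder* builder = input.builder();
    xla::Shape shape = builder->GetShape(input).ValueOrDie();
    std::vector<xla::int64> all_dims(shape.rank());
    std::iota(all_dims.begin(), all_dims.end(), 0);  // {0, 1, ..., rank-1}
    return xla::Reduce(input, xla::MinValue(builder, type),
                       MakeMaxCombiner(type), all_dims);
  }

The MaxUnary lowering in patch 1 follows this recipe via the XlaHelpers
wrappers and additionally reshapes the rank-0 reduction result to shape {1}
before handing it back to ATen.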