32 changes: 32 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -793,6 +793,26 @@ TEST_F(AtenXlaTensorTest, TestMax) {
});
}

TEST_F(AtenXlaTensorTest, TestUnaryMin) {
at::Tensor input = at::rand({2, 2}, at::TensorOptions(at::kFloat));
at::Tensor output = at::min(input);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::min(xla_input);
AllClose(output, xla_output);
});
}

TEST_F(AtenXlaTensorTest, TestUnaryMax) {
at::Tensor input = at::rand({2, 2}, at::TensorOptions(at::kFloat));
at::Tensor output = at::max(input);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::max(xla_input);
AllClose(output, xla_output);
});
}

TEST_F(AtenXlaTensorTest, TestAll) {
at::Tensor a = at::randint(0, 5, {2, 3, 4}, at::TensorOptions(at::kByte));
at::Tensor b = at::all(a);
@@ -2732,6 +2752,18 @@ TEST_F(AtenXlaTensorTest, TestEmbedding) {
});
}

TEST_F(AtenXlaTensorTest, TestOneHot) {
int num_classes = 5;
at::Tensor input =
at::randint(0, num_classes, {10}, at::TensorOptions(at::kLong));
at::Tensor output = at::one_hot(input, num_classes);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::one_hot(xla_input, num_classes);
EXPECT_TRUE(EqualValues(output, xla_output));
});
}

TEST_F(AtenXlaTensorTest, TestTranspose) {
at::Tensor input = at::rand({2, 3}, at::TensorOptions(at::kFloat));
at::Tensor output = at::t(input);
13 changes: 13 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2080,12 +2080,20 @@ at::Tensor AtenXlaType::min(const at::Tensor& self,
XLATensor::min(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor AtenXlaType::min(const at::Tensor& self) const {
return bridge::AtenFromXlaTensor(XLATensor::min(bridge::GetXlaTensor(self)));
}

at::Tensor AtenXlaType::max(const at::Tensor& self,
const at::Tensor& other) const {
return bridge::AtenFromXlaTensor(
XLATensor::max(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor AtenXlaType::max(const at::Tensor& self) const {
return bridge::AtenFromXlaTensor(XLATensor::max(bridge::GetXlaTensor(self)));
}

at::Tensor AtenXlaType::mean(const at::Tensor& self,
at::ScalarType dtype) const {
XLATensor self_tensor = bridge::GetXlaTensor(self);
@@ -2199,6 +2207,11 @@ at::Tensor AtenXlaType::flip(const at::Tensor& self,
XLATensor::flip(bridge::GetXlaTensor(self), XlaHelpers::I64List(dims)));
}

at::Tensor AtenXlaType::one_hot(const at::Tensor& self,
int64_t num_classes) const {
return at::native::one_hot(self, num_classes);
}

at::Tensor AtenXlaType::transpose(const at::Tensor& self, int64_t dim0,
int64_t dim1) const {
return bridge::AtenFromXlaTensor(
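Note: the one_hot override above does not add a dedicated XLA lowering; it forwards to the composite ATen implementation, which decomposes into ops (zeros, scatter) that already lower to XLA, so the fallback stays on-device. A minimal sketch of that decomposition, assuming the upstream at::native::one_hot behavior at the time of this change (one_hot_sketch is a hypothetical name, for illustration only):

#include <ATen/ATen.h>

// Hedged sketch, not the verbatim upstream source: with an explicit
// num_classes, one_hot expands to an all-zeros tensor with ones scattered
// along a new trailing dimension.
at::Tensor one_hot_sketch(const at::Tensor& self, int64_t num_classes) {
  auto shape = self.sizes().vec();
  shape.push_back(num_classes);             // e.g. {10} -> {10, 5}
  at::Tensor ret = at::zeros(shape, self.options());
  ret.scatter_(-1, self.unsqueeze(-1), 1);  // ones at the class indices
  return ret;
}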
5 changes: 5 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -663,8 +663,10 @@ class AtenXlaType : public AtenXlaTypeBase {

at::Tensor min(const at::Tensor& self,
const at::Tensor& other) const override;
at::Tensor min(const at::Tensor& self) const override;
at::Tensor max(const at::Tensor& self,
const at::Tensor& other) const override;
at::Tensor max(const at::Tensor& self) const override;

at::Tensor mean(const at::Tensor& self, at::ScalarType dtype) const override;
at::Tensor mean(const at::Tensor& self) const override;
@@ -699,6 +701,9 @@ class AtenXlaType : public AtenXlaTypeBase {

at::Tensor flip(const at::Tensor& self, at::IntArrayRef dims) const override;

at::Tensor one_hot(const at::Tensor& self,
int64_t num_classes) const override;

at::Tensor transpose(const at::Tensor& self, int64_t dim0,
int64_t dim1) const override;
at::Tensor& transpose_(at::Tensor& self, int64_t dim0,
10 changes: 10 additions & 0 deletions torch_xla/csrc/helpers.cpp
@@ -185,6 +185,16 @@ xla::XlaComputation XlaHelpers::CreateMaxComputation(xla::PrimitiveType type) {
return ConsumeValue(builder.Build());
}

xla::XlaComputation XlaHelpers::CreateMinComputation(xla::PrimitiveType type) {
xla::XlaBuilder builder("MinComputation");
xla::XlaOp x =
xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(type, {}), "x");
xla::XlaOp y =
xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(type, {}), "y");
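// The return value of xla::Min is deliberately unused: XlaBuilder::Build()
// takes the last operation added to the builder as the computation's root.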
xla::Min(x, y);
return ConsumeValue(builder.Build());
}

xla::Shape XlaHelpers::ShapeOfXlaOp(const xla::XlaOp& op) {
return ConsumeValue(op.builder()->GetShape(op));
}
2 changes: 2 additions & 0 deletions torch_xla/csrc/helpers.h
@@ -150,6 +150,8 @@ class XlaHelpers {

static xla::XlaComputation CreateMaxComputation(xla::PrimitiveType type);

static xla::XlaComputation CreateMinComputation(xla::PrimitiveType type);

// Returns an XLA operation which is a reshape to the expected rank, by
// appending 1s to the major dimension. If offset is greater than zero, 1s
// will be prepended to the minor dimension as well.
36 changes: 36 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -542,6 +542,42 @@ NodePtr Remainder(const Value& input, const Value& divisor) {
ScalarOp(0, input.shape()));
}

NodePtr MaxUnary(const Value& input) {
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
xla::PrimitiveType element_type = input_shape.element_type();
XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(element_type);
xla::XlaOp init_value =
XlaHelpers::ScalarValue(min_max.min, element_type, loctx->builder());
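// Reduce across every dimension: Iota yields {0, ..., rank-1}, and the
// lowest representable value for the element type seeds the max reduction.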
xla::XlaOp result = xla::Reduce(
xla_input, init_value, XlaHelpers::CreateMaxComputation(element_type),
xla::util::Iota<xla::int64>(input_shape.rank()));
return node.ReturnOp(xla::Reshape(result, {1}), loctx);
};
return GenericOp(OpKind(at::aten::max), {input},
xla::ShapeUtil::MakeShape(input.shape().element_type(), {1}),
std::move(lower_fn));
}

NodePtr MinUnary(const Value& input) {
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
xla::PrimitiveType element_type = input_shape.element_type();
XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(element_type);
xla::XlaOp init_value =
XlaHelpers::ScalarValue(min_max.max, element_type, loctx->builder());
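// As in MaxUnary, but seeded with the highest representable value, the
// identity element for a min reduction.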
xla::XlaOp result = xla::Reduce(
xla_input, init_value, XlaHelpers::CreateMinComputation(element_type),
xla::util::Iota<xla::int64>(input_shape.rank()));
return node.ReturnOp(xla::Reshape(result, {1}), loctx);
};
return GenericOp(OpKind(at::aten::min), {input},
xla::ShapeUtil::MakeShape(input.shape().element_type(), {1}),
std::move(lower_fn));
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
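Both unary lowerings follow the same pattern: a full reduction over dims {0, ..., rank-1}, followed by a reshape to a one-element tensor. A standalone sketch of the equivalent reduction for max on a 2x2 constant (hedged; plain XLA client code assuming a live builder, not part of this diff):

#include <limits>

xla::XlaBuilder builder("max_all_sketch");
xla::XlaOp input = xla::ConstantR2<float>(&builder, {{1.f, 4.f}, {3.f, 2.f}});
// The identity for max over floats is the lowest representable value.
xla::XlaOp init =
    xla::ConstantR0<float>(&builder, std::numeric_limits<float>::lowest());
// Reducing over both dimensions of the rank-2 input yields a scalar (here 4).
xla::XlaOp result =
    xla::Reduce(input, init, XlaHelpers::CreateMaxComputation(xla::F32),
                {0, 1});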
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -177,6 +177,10 @@ NodePtr Rshift(const Value& input, const Value& other);

NodePtr Remainder(const Value& input, const Value& divisor);

NodePtr MaxUnary(const Value& input);

NodePtr MinUnary(const Value& input);

} // namespace ops
} // namespace ir
} // namespace torch_xla
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -508,6 +508,8 @@ class XLATensor {

static XLATensor max(const XLATensor& input, const XLATensor& other);

static XLATensor max(const XLATensor& input);

static XLATensor max_pool2d(const XLATensor& input,
std::vector<xla::int64> kernel_size,
std::vector<xla::int64> stride,
@@ -526,6 +528,8 @@

static XLATensor min(const XLATensor& input, const XLATensor& other);

static XLATensor min(const XLATensor& input);

static XLATensor mm(const XLATensor& input, const XLATensor& weight);

static XLATensor mul(const XLATensor& input, const XLATensor& other);
8 changes: 8 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1068,6 +1068,10 @@ XLATensor XLATensor::max(const XLATensor& input, const XLATensor& other) {
return input.CreateFrom(ir::ops::Max(input.GetIrValue(), other.GetIrValue()));
}

XLATensor XLATensor::max(const XLATensor& input) {
return input.CreateFrom(ir::ops::MaxUnary(input.GetIrValue()), input.dtype());
}

XLATensor XLATensor::max_pool2d(const XLATensor& input,
std::vector<xla::int64> kernel_size,
std::vector<xla::int64> stride,
@@ -1102,6 +1106,10 @@ XLATensor XLATensor::min(const XLATensor& input, const XLATensor& other) {
return input.CreateFrom(ir::ops::Min(input.GetIrValue(), other.GetIrValue()));
}

XLATensor XLATensor::min(const XLATensor& input) {
return input.CreateFrom(ir::ops::MinUnary(input.GetIrValue()), input.dtype());
}

XLATensor XLATensor::mm(const XLATensor& input, const XLATensor& weight) {
return input.CreateFrom(
ir::ops::Dot(input.GetIrValue(), weight.GetIrValue()));