diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index e86ec4786c9..22bc39f355f 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -1275,6 +1275,25 @@ TEST_F(AtenXlaTensorTest, TestStdInDim) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestStdWithCorrection) {
+  torch::Tensor a = torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
+  int rank = a.dim();
+  c10::optional<int64_t> corrections[] = {1, 2, c10::nullopt};
+  for (const auto& correction : corrections) {
+    for (auto keepdim : {true, false}) {
+      for (const auto& dim :
+           std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
+        torch::Tensor b = torch::std(a, dim, correction, keepdim);
+        ForEachDevice([&](const torch::Device& device) {
+          torch::Tensor xla_a = CopyToDevice(a, device);
+          torch::Tensor xla_b = torch::std(xla_a, dim, correction, keepdim);
+          AllClose(b, xla_b);
+        });
+      }
+    }
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestSum) {
   torch::Tensor a = torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::sum(a);
@@ -1387,6 +1406,25 @@ TEST_F(AtenXlaTensorTest, TestVarWithDim) {
   ExpectCounterChanged("xla::var", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestVarWithCorrection) {
+  torch::Tensor a = torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
+  c10::optional<int64_t> corrections[] = {1, 2, c10::nullopt};
+  for (const auto& dim : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
+    for (bool keepDim : {true, false}) {
+      for (const auto& correction : corrections) {
+        torch::Tensor b = torch::var(a, dim, correction, keepDim);
+        ForEachDevice([&](const torch::Device& device) {
+          torch::Tensor xla_a = CopyToDevice(a, device);
+          torch::Tensor xla_b = torch::var(xla_a, dim, correction, keepDim);
+          AllClose(b, xla_b);
+        });
+      }
+    }
+  }
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::var", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestMaxInDim) {
   torch::Tensor input =
       torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 7d0e8a6db5c..e0d19adeecb 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -3128,15 +3128,27 @@ at::Tensor AtenXlaType::std(const at::Tensor& self, bool unbiased) {
   return bridge::AtenFromXlaTensor(XLATensor::std(
       self_tensor,
       xla::util::Iota<xla::int64>(self_tensor.shape().get().rank()),
-      /*keep_reduced_dimensions=*/false, unbiased));
+      /*keep_reduced_dimensions=*/false, /*correction=*/unbiased ? 1 : 0));
 }
 
 at::Tensor AtenXlaType::std(const at::Tensor& self, at::IntArrayRef dim,
                             bool unbiased, bool keepdim) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(XLATensor::std(
-      bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(dim),
-      /*keep_reduced_dimensions=*/keepdim, unbiased));
+      bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(dim), keepdim,
+      /*correction=*/unbiased ? 1 : 0));
+}
+
+at::Tensor AtenXlaType::std(const at::Tensor& self,
+                            c10::optional<at::IntArrayRef> dim,
+                            c10::optional<int64_t> correction, bool keepdim) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  return bridge::AtenFromXlaTensor(XLATensor::std(
+      self_tensor,
+      dim ? xla::util::ToVector<xla::int64>(*dim)
+          : xla::util::Iota<xla::int64>(self_tensor.shape().get().rank()),
+      keepdim, correction ? *correction : 1));
 }
 
 at::Tensor AtenXlaType::sub(const at::Tensor& self, const at::Tensor& other,
@@ -3532,7 +3544,7 @@ at::Tensor AtenXlaType::var(const at::Tensor& self, bool unbiased) {
       XLATensor::var(bridge::GetXlaTensor(self),
                      xla::util::Iota<xla::int64>(
                          bridge::GetXlaTensor(self).shape().get().rank()),
-                     unbiased,
+                     /*correction=*/unbiased ? 1 : 0,
                      /*keep_reduced_dimensions=*/false));
 }
 
@@ -3541,7 +3553,21 @@ at::Tensor AtenXlaType::var(const at::Tensor& self, at::IntArrayRef dim,
   XLA_FN_COUNTER("xla::");
   XLATensor self_tensor = bridge::GetXlaTensor(self);
   return bridge::AtenFromXlaTensor(
-      XLATensor::var(self_tensor, XlaHelpers::I64List(dim), unbiased, keepdim));
+      XLATensor::var(self_tensor, XlaHelpers::I64List(dim),
+                     /*correction=*/unbiased ? 1 : 0, keepdim));
+}
+
+at::Tensor AtenXlaType::var(const at::Tensor& self,
+                            c10::optional<at::IntArrayRef> dim,
+                            c10::optional<int64_t> correction, bool keepdim) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  return bridge::AtenFromXlaTensor(
+      XLATensor::var(self_tensor,
+                     dim ? XlaHelpers::I64List(*dim)
+                         : xla::util::Iota<xla::int64>(
+                               bridge::GetXlaTensor(self).shape().get().rank()),
+                     correction ? *correction : 1, keepdim));
 }
 
 at::Tensor AtenXlaType::view(const at::Tensor& self, at::IntArrayRef size) {
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 2d5496ac917..dcc166713f4 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -977,6 +977,10 @@ class AtenXlaType {
   static at::Tensor std(const at::Tensor& self, at::IntArrayRef dim,
                         bool unbiased, bool keepdim);
 
+  static at::Tensor std(const at::Tensor& self,
+                        c10::optional<at::IntArrayRef> dim,
+                        c10::optional<int64_t> correction, bool keepdim);
+
   static at::Tensor sub(const at::Tensor& self, const at::Tensor& other,
                         const at::Scalar& alpha);
 
@@ -1102,6 +1106,10 @@ class AtenXlaType {
   static at::Tensor var(const at::Tensor& self, at::IntArrayRef dim,
                         bool unbiased, bool keepdim);
 
+  static at::Tensor var(const at::Tensor& self,
+                        c10::optional<at::IntArrayRef> dim,
+                        c10::optional<int64_t> correction, bool keepdim);
+
   static at::Tensor view(const at::Tensor& self, at::IntArrayRef size);
 
   static at::Tensor& zero_(at::Tensor& self);
diff --git a/torch_xla/csrc/ops/std.cpp b/torch_xla/csrc/ops/std.cpp
index cec37ac697b..b6ec485a055 100644
--- a/torch_xla/csrc/ops/std.cpp
+++ b/torch_xla/csrc/ops/std.cpp
@@ -13,11 +13,12 @@ namespace {
 
 xla::Shape NodeOutputShape(const Value& input,
                            std::vector<xla::int64>& dimensions,
-                           bool keep_reduced_dimensions, bool unbiased) {
+                           bool keep_reduced_dimensions,
+                           xla::int64 correction) {
   auto lower_for_shape_fn =
       [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
     return BuildStdDeviation(operands[0], dimensions, keep_reduced_dimensions,
-                             unbiased);
+                             correction);
   };
   return InferOutputShape({input.shape()}, lower_for_shape_fn);
 }
@@ -25,27 +26,27 @@ xla::Shape NodeOutputShape(const Value& input,
 }  // namespace
 
 Std::Std(const Value& input, std::vector<xla::int64> dimensions,
-         bool keep_reduced_dimensions, bool unbiased)
+         bool keep_reduced_dimensions, xla::int64 correction)
     : Node(ir::OpKind(at::aten::std), {input},
            [&]() {
              return NodeOutputShape(input, dimensions, keep_reduced_dimensions,
-                                    unbiased);
+                                    correction);
            },
            /*num_outputs=*/1,
-           xla::util::MHash(dimensions, keep_reduced_dimensions, unbiased)),
+           xla::util::MHash(dimensions, keep_reduced_dimensions, correction)),
       dimensions_(std::move(dimensions)),
       keep_reduced_dimensions_(keep_reduced_dimensions),
-      unbiased_(unbiased) {}
+      correction_(correction) {}
 
 NodePtr Std::Clone(OpList operands) const {
   return MakeNode<Std>(operands.at(0), dimensions_, keep_reduced_dimensions_,
-                       unbiased_);
+                       correction_);
 }
 
 XlaOpVector Std::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildStdDeviation(input, dimensions_,
-                                    keep_reduced_dimensions_, unbiased_),
+                                    keep_reduced_dimensions_, correction_),
                   loctx);
 }
 
@@ -53,7 +54,7 @@ std::string Std::ToString() const {
   std::stringstream ss;
   ss << Node::ToString() << ", dimensions=(" << absl::StrJoin(dimensions_, ", ")
      << "), keep_reduced_dimensions=" << keep_reduced_dimensions_
-     << ", unbiased=" << unbiased_;
+     << ", correction=" << correction_;
   return ss.str();
 }
diff --git a/torch_xla/csrc/ops/std.h b/torch_xla/csrc/ops/std.h
index c809bcaa58e..59e71f26d62 100644
--- a/torch_xla/csrc/ops/std.h
+++ b/torch_xla/csrc/ops/std.h
@@ -12,7 +12,7 @@ namespace ops {
 class Std : public Node {
  public:
   Std(const Value& input, std::vector<xla::int64> dimensions,
-      bool keep_reduced_dimensions, bool unbiased);
+      bool keep_reduced_dimensions, xla::int64 correction);
 
   std::string ToString() const override;
 
@@ -24,12 +24,12 @@ class Std : public Node {
 
   bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; }
 
-  bool unbiased() const { return unbiased_; }
+  xla::int64 correction() const { return correction_; }
 
  private:
   std::vector<xla::int64> dimensions_;
   bool keep_reduced_dimensions_;
-  bool unbiased_;
+  xla::int64 correction_;
 };
 
 }  // namespace ops
diff --git a/torch_xla/csrc/ops/var.cpp b/torch_xla/csrc/ops/var.cpp
index c2d26ff3c4b..ed8f88c50f2 100644
--- a/torch_xla/csrc/ops/var.cpp
+++ b/torch_xla/csrc/ops/var.cpp
@@ -15,43 +15,46 @@ namespace ops {
 namespace {
 
 xla::Shape NodeOutputShape(const Value& input,
-                           std::vector<xla::int64>& dimensions, bool unbiased,
+                           std::vector<xla::int64>& dimensions,
+                           xla::int64 correction,
                            bool keep_reduced_dimensions) {
   auto lower_for_shape_fn =
       [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
-    return BuildVar(operands[0], dimensions, unbiased, keep_reduced_dimensions);
+    return BuildVar(operands[0], dimensions, correction,
+                    keep_reduced_dimensions);
   };
   return InferOutputShape({input.shape()}, lower_for_shape_fn);
 }
 
 }  // namespace
 
-Var::Var(const Value& input, std::vector<xla::int64> dimensions, bool unbiased,
-         bool keep_reduced_dimensions)
-    : Node(
-          ir::OpKind(at::aten::var), {input},
-          NodeOutputShape(input, dimensions, unbiased, keep_reduced_dimensions),
-          /*num_outputs=*/1,
-          xla::util::MHash(dimensions, unbiased, keep_reduced_dimensions)),
+Var::Var(const Value& input, std::vector<xla::int64> dimensions,
+         xla::int64 correction, bool keep_reduced_dimensions)
+    : Node(ir::OpKind(at::aten::var), {input},
+           NodeOutputShape(input, dimensions, correction,
+                           keep_reduced_dimensions),
+           /*num_outputs=*/1,
+           xla::util::MHash(dimensions, correction, keep_reduced_dimensions)),
       dimensions_(std::move(dimensions)),
-      unbiased_(unbiased),
+      correction_(correction),
       keep_reduced_dimensions_(keep_reduced_dimensions) {}
 
 NodePtr Var::Clone(OpList operands) const {
-  return MakeNode<Var>(operands.at(0), dimensions_, unbiased_,
+  return MakeNode<Var>(operands.at(0), dimensions_, correction_,
                        keep_reduced_dimensions_);
 }
 
 XlaOpVector Var::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
   return ReturnOp(
-      BuildVar(input, dimensions_, unbiased_, keep_reduced_dimensions_), loctx);
+      BuildVar(input, dimensions_, correction_, keep_reduced_dimensions_),
+      loctx);
 }
 
 std::string Var::ToString() const {
   std::stringstream ss;
   ss << Node::ToString() << ", dimensions=(" << absl::StrJoin(dimensions_, ", ")
-     << "), unbiased=" << unbiased_
+     << "), correction=" << correction_
      << ", keep_reduced_dimensions=" << keep_reduced_dimensions_;
   return ss.str();
 }
diff --git a/torch_xla/csrc/ops/var.h b/torch_xla/csrc/ops/var.h
index 583ab1afe3f..f7647d11423 100644
--- a/torch_xla/csrc/ops/var.h
+++ b/torch_xla/csrc/ops/var.h
@@ -11,8 +11,8 @@ namespace ops {
 
 class Var : public Node {
  public:
-  Var(const Value& input, std::vector<xla::int64> dimensions, bool unbiased,
-      bool keep_reduced_dimensions);
+  Var(const Value& input, std::vector<xla::int64> dimensions,
+      xla::int64 correction, bool keep_reduced_dimensions);
 
   std::string ToString() const override;
 
@@ -24,11 +24,11 @@ class Var : public Node {
 
   bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; }
 
-  bool unbiased() const { return unbiased_; }
+  xla::int64 correction() const { return correction_; }
 
  private:
   std::vector<xla::int64> dimensions_;
-  bool unbiased_;
+  xla::int64 correction_;
   bool keep_reduced_dimensions_;
 };
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp
index d5a663af7e4..4694c097740 100644
--- a/torch_xla/csrc/reduction.cpp
+++ b/torch_xla/csrc/reduction.cpp
@@ -284,7 +284,8 @@ xla::XlaOp BuildMean(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
 
 xla::XlaOp BuildStdDeviation(xla::XlaOp input,
                              absl::Span<const xla::int64> dimensions,
-                             bool keep_reduced_dimensions, bool unbiased) {
+                             bool keep_reduced_dimensions,
+                             xla::int64 correction) {
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp mean =
       BuildMean(input, dimensions, /*keep_reduced_dimensions*/ true);
@@ -294,15 +295,17 @@ xla::XlaOp BuildStdDeviation(xla::XlaOp input,
   xla::XlaOp input_mean_diff = input - bcast_mean;
   xla::XlaOp squared_var = input_mean_diff * input_mean_diff;
   xla::XlaOp squared_result;
-  if (unbiased) {
+  if (correction != 0) {
     SummationResult sum_result = CreateSummation(
         squared_var, dimensions, keep_reduced_dimensions, /*scale=*/false);
-    squared_result = GetScaleValue(
-        sum_result.result,
-        sum_result.rinfo.element_count.size -
-            xla::One(input.builder(), XlaHelpers::TypeOfXlaOp(
-                                          sum_result.rinfo.element_count.size)),
-        input_shape.element_type());
+    xla::XlaOp correction_scalar = XlaHelpers::ScalarValue(
+        correction,
+        XlaHelpers::TypeOfXlaOp(sum_result.rinfo.element_count.size),
+        input.builder());
+    squared_result =
+        GetScaleValue(sum_result.result,
+                      sum_result.rinfo.element_count.size - correction_scalar,
+                      input_shape.element_type());
   } else {
     SummationResult sum_result = CreateSummation(
         squared_var, dimensions, keep_reduced_dimensions, /*scale=*/true);
@@ -458,7 +461,7 @@ xla::XlaOp BuildAny(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
 }
 
 xla::XlaOp BuildVar(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
-                    bool unbiased, bool keep_reduced_dimensions) {
+                    xla::int64 correction, bool keep_reduced_dimensions) {
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   SummationResult mean_result =
       CreateSummation(input, dimensions, /*keep_reduced_dimensions=*/true,
@@ -470,9 +473,11 @@ xla::XlaOp BuildVar(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
                       /*scale=*/false)
           .result;
   xla::XlaOp count = mean_result.rinfo.element_count.size;
-  if (unbiased) {
-    count = count - xla::One(input.builder(),
-                             XlaHelpers::ShapeOfXlaOp(count).element_type());
+  if (correction != 0) {
+    count =
+        count - XlaHelpers::ScalarValue(
+                    correction, XlaHelpers::ShapeOfXlaOp(count).element_type(),
+                    input.builder());
   }
   return GetScaleValue(unscaled_result, count, input_shape.element_type());
 }
diff --git a/torch_xla/csrc/reduction.h b/torch_xla/csrc/reduction.h
index 12cc21d4ccd..0d0d6208880 100644
--- a/torch_xla/csrc/reduction.h
+++ b/torch_xla/csrc/reduction.h
@@ -39,7 +39,8 @@ xla::XlaOp BuildMean(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
 
 xla::XlaOp BuildStdDeviation(xla::XlaOp input,
                              absl::Span<const xla::int64> dimensions,
-                             bool keep_reduced_dimensions, bool unbiased);
+                             bool keep_reduced_dimensions,
+                             xla::int64 correction);
 
 // Builds the sum of all values by reducing all the dimensions listed in
 // dimensions. If keep_reduced_dimensions is true, the reduced dimensions will
@@ -91,7 +92,7 @@ xla::XlaOp BuildAny(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
                     bool keep_reduced_dimensions);
 
 xla::XlaOp BuildVar(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
-                    bool unbiased, bool keep_reduced_dimensions);
+                    xla::int64 correction, bool keep_reduced_dimensions);
 
 xla::XlaOp BuildLogsumexp(xla::XlaOp input,
                           absl::Span<const xla::int64> dimensions,
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 170adecb2d5..7ae3a214303 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -1053,7 +1053,7 @@ class XLATensor {
 
   static XLATensor std(const XLATensor& input,
                        std::vector<xla::int64> dimensions,
-                       bool keep_reduced_dimensions, bool unbiased);
+                       bool keep_reduced_dimensions, xla::int64 correction);
 
   static XLATensor sub(
       const XLATensor& input, const XLATensor& other, const at::Scalar& alpha,
@@ -1161,8 +1161,8 @@ class XLATensor {
                               std::vector<xla::int64> input_size);
 
   static XLATensor var(const XLATensor& input,
-                       std::vector<xla::int64> dimensions, bool unbiased,
-                       bool keep_reduced_dimensions);
+                       std::vector<xla::int64> dimensions,
+                       xla::int64 correction, bool keep_reduced_dimensions);
 
   // Like reshape, but it returns a view into the original tensor.
   static XLATensor view(const XLATensor& input,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 20cdb1987a4..3a6a7e2600e 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2619,12 +2619,12 @@ XLATensor XLATensor::stack(absl::Span<const XLATensor> tensors,
 
 XLATensor XLATensor::std(const XLATensor& input,
                          std::vector<xla::int64> dimensions,
-                         bool keep_reduced_dimensions, bool unbiased) {
+                         bool keep_reduced_dimensions, xla::int64 correction) {
   return input.CreateFrom(
       ir::MakeNode<ir::ops::Std>(input.GetIrValue(),
                                  XlaHelpers::GetCanonicalDimensionIndices(
                                      dimensions, input.shape().get().rank()),
-                                 keep_reduced_dimensions, unbiased));
+                                 keep_reduced_dimensions, correction));
 }
 
 XLATensor XLATensor::sub(const XLATensor& input, const XLATensor& other,
@@ -2916,13 +2916,13 @@ XLATensor XLATensor::view(const XLATensor& input,
 }
 
 XLATensor XLATensor::var(const XLATensor& input,
-                         std::vector<xla::int64> dimensions, bool unbiased,
-                         bool keep_reduced_dimensions) {
+                         std::vector<xla::int64> dimensions,
+                         xla::int64 correction, bool keep_reduced_dimensions) {
   return input.CreateFrom(
       ir::MakeNode<ir::ops::Var>(input.GetIrValue(),
                                  XlaHelpers::GetCanonicalDimensionIndices(
                                      dimensions, input.shape().get().rank()),
-                                 unbiased, keep_reduced_dimensions));
+                                 correction, keep_reduced_dimensions));
 }
 
 void XLATensor::zero_(XLATensor& input) {
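
Note (illustration, not part of the patch): the `correction` semantics lowered above scale the sum of squared deviations by `1 / (N - correction)`, so `correction=1` reproduces the old `unbiased=true` path and `correction=0` the biased one; a missing correction defaults to 1. A minimal standalone libtorch sketch of this invariant, assuming a PyTorch build that exposes the `var.correction` overload exercised by the tests:

```cpp
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  // 2x4 input; reducing over dim 1 gives N = 4 elements per reduction.
  torch::Tensor x = torch::rand({2, 4}, torch::TensorOptions(torch::kFloat));
  std::vector<int64_t> dim{1};
  c10::optional<int64_t> correction = 2;

  // var.correction overload: the divisor is (N - correction) = 2.
  torch::Tensor v = torch::var(x, dim, correction, /*keepdim=*/false);

  // Manual equivalent: sum of squared deviations scaled by 1 / (N - 2).
  torch::Tensor mean = x.mean(dim, /*keepdim=*/true);
  torch::Tensor manual = (x - mean).pow(2).sum(dim) / (4 - 2);

  std::cout << std::boolalpha << torch::allclose(v, manual) << std::endl;
  return 0;
}
```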