1 change: 1 addition & 0 deletions scripts/gen.py
@@ -121,6 +121,7 @@ class ArgTemplate(string.Template):
'take_out': FuncOpts(),
'true_divide_out': FuncOpts(),
'topk_out': FuncOpts(),
'var_out': FuncOpts(),
}

# List of tuples with the regex match first, and the corresponding FuncOpts()
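(Context: this table in gen.py drives the ATen-to-XLA binding generator; listing 'var_out' with an empty FuncOpts() presumably makes the generator emit the out= variant of var as a thin wrapper over the functional lowering added below, matching how the neighboring *_out entries are handled.)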
32 changes: 32 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -1270,6 +1270,38 @@ TEST_F(AtenXlaTensorTest, TestSumInDimsKeepCast) {
}
}

TEST_F(AtenXlaTensorTest, TestVar) {
torch::Tensor a = torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
for (bool unbiased : {true, false}) {
torch::Tensor b = torch::var(a, unbiased);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::Tensor xla_b = torch::var(xla_a, unbiased);
AllClose(b, xla_b);
});
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::var", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestVarWithDim) {
torch::Tensor a = torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
for (bool keepDim : {true, false}) {
for (bool unbiased : {true, false}) {
torch::Tensor b = torch::var(a, dims, unbiased, keepDim);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::Tensor xla_b = torch::var(xla_a, dims, unbiased, keepDim);
AllClose(b, xla_b);
});
}
}
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::var", cpp_test::GetIgnoredCounters());
}
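The same behavior can be checked from the Python side. A minimal sketch, assuming a working torch_xla install (xm.xla_device() is the standard device helper; the rest is plain PyTorch, and the default allclose tolerances may need loosening on real hardware):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.rand(4, 3, 4)
xla_a = a.to(device)

# Full reduction, with and without Bessel's correction.
for unbiased in (True, False):
    assert torch.allclose(torch.var(a, unbiased=unbiased),
                          torch.var(xla_a, unbiased=unbiased).cpu())

# Dimension-wise reduction, mirroring TestVarWithDim above.
b = torch.var(a, dim=[0, 1], unbiased=True, keepdim=True)
xla_b = torch.var(xla_a, dim=[0, 1], unbiased=True, keepdim=True)
assert torch.allclose(b, xla_b.cpu())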

TEST_F(AtenXlaTensorTest, TestMaxInDim) {
torch::Tensor input =
torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
1 change: 1 addition & 0 deletions test/pytorch_test_base.py
@@ -15,6 +15,7 @@
# test_name : floating_precision,
'test_pow_xla_float32': 0.0035,
'test_pow_xla_float64': 0.0045,
'test_var_neg_dim_xla_bfloat16': 0.01,
}
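The new tolerance entry is needed because bfloat16 carries only 8 bits of significand precision, so the subtract-square-sum sequence inside a variance reduction accumulates substantially more rounding error than float32, and the default comparison precision would flag spurious failures.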

DISABLED_TORCH_TESTS_ANY = {
19 changes: 19 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -3282,6 +3282,25 @@ at::Tensor AtenXlaType::upsample_nearest2d_backward(
xla::util::ToVector<xla::int64>(input_size)));
}

at::Tensor AtenXlaType::var(const at::Tensor& self, bool unbiased) {
  XLA_FN_COUNTER("xla::");
  XLATensor self_tensor = bridge::GetXlaTensor(self);
  return bridge::AtenFromXlaTensor(XLATensor::var(
      self_tensor,
      xla::util::Iota<xla::int64>(self_tensor.shape().get().rank()), unbiased,
      /*keep_reduced_dimensions=*/false));
}

at::Tensor AtenXlaType::var(const at::Tensor& self, at::IntArrayRef dim,
bool unbiased, bool keepdim) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
return bridge::AtenFromXlaTensor(
XLATensor::var(self_tensor, XlaHelpers::I64List(dim), unbiased, keepdim));
}
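Note the dimension-less overload reduces over every axis by passing an iota of the input's rank as the dimension list. A quick plain-PyTorch sketch of the equivalence this encodes (illustration only):

import torch

x = torch.rand(4, 3, 4)
# torch.var(x) is the all-dimensions reduction, i.e. dim = [0, 1, ..., rank-1].
assert torch.allclose(torch.var(x, unbiased=True),
                      torch.var(x, dim=list(range(x.dim())), unbiased=True))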

at::Tensor AtenXlaType::view(const at::Tensor& self, at::IntArrayRef size) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
5 changes: 5 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -1051,6 +1051,11 @@ class AtenXlaType {
c10::optional<double> scales_h,
c10::optional<double> scales_w);

static at::Tensor var(const at::Tensor& self, bool unbiased);

static at::Tensor var(const at::Tensor& self, at::IntArrayRef dim,
bool unbiased, bool keepdim);

static at::Tensor view(const at::Tensor& self, at::IntArrayRef size);

static at::Tensor& zero_(at::Tensor& self);
61 changes: 61 additions & 0 deletions torch_xla/csrc/ops/var.cpp
@@ -0,0 +1,61 @@
#include "torch_xla/csrc/ops/var.h"

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

xla::Shape NodeOutputShape(const Value& input,
                           const std::vector<xla::int64>& dimensions,
                           bool unbiased, bool keep_reduced_dimensions) {
auto lower_for_shape_fn =
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildVar(operands[0], dimensions, unbiased, keep_reduced_dimensions);
};
return InferOutputShape({input.shape()}, lower_for_shape_fn);
}

} // namespace

Var::Var(const Value& input, std::vector<xla::int64> dimensions, bool unbiased,
bool keep_reduced_dimensions)
: Node(
ir::OpKind(at::aten::var), {input},
NodeOutputShape(input, dimensions, unbiased, keep_reduced_dimensions),
/*num_outputs=*/1,
xla::util::MHash(dimensions, unbiased, keep_reduced_dimensions)),
dimensions_(std::move(dimensions)),
unbiased_(unbiased),
keep_reduced_dimensions_(keep_reduced_dimensions) {}

NodePtr Var::Clone(OpList operands) const {
return MakeNode<Var>(operands.at(0), dimensions_, unbiased_,
keep_reduced_dimensions_);
}

XlaOpVector Var::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
return ReturnOp(
BuildVar(input, dimensions_, unbiased_, keep_reduced_dimensions_), loctx);
}

std::string Var::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", dimensions=(" << absl::StrJoin(dimensions_, ", ")
<< "), unbiased=" << unbiased_
<< ", keep_reduced_dimensions=" << keep_reduced_dimensions_;
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
37 changes: 37 additions & 0 deletions torch_xla/csrc/ops/var.h
@@ -0,0 +1,37 @@
#pragma once

#include <vector>

#include "tensorflow/compiler/xla/types.h"
#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

class Var : public Node {
public:
Var(const Value& input, std::vector<xla::int64> dimensions, bool unbiased,
bool keep_reduced_dimensions);

std::string ToString() const override;

NodePtr Clone(OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

const std::vector<xla::int64>& dimensions() const { return dimensions_; }

bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; }

bool unbiased() const { return unbiased_; }

private:
std::vector<xla::int64> dimensions_;
bool unbiased_;
bool keep_reduced_dimensions_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
20 changes: 20 additions & 0 deletions torch_xla/csrc/reduction.cpp
@@ -444,6 +444,26 @@ xla::XlaOp BuildAny(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
return result;
}

xla::XlaOp BuildVar(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
bool unbiased, bool keep_reduced_dimensions) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
SummationResult mean_result =
CreateSummation(input, dimensions, /*keep_reduced_dimensions=*/true,
/*scale=*/true);
// var = ((input - mean)^2).sum(dim) / (reduced_element_count - correction),
// where correction is 1 when unbiased (Bessel's correction), else 0.
xla::XlaOp diff = input - mean_result.result;
xla::XlaOp unscaled_result =
CreateSummation(diff * diff, dimensions, keep_reduced_dimensions,
/*scale=*/false)
.result;
xla::XlaOp count = mean_result.rinfo.element_count.size;
if (unbiased) {
count = count - xla::One(input.builder(),
XlaHelpers::ShapeOfXlaOp(count).element_type());
}
return GetScaleValue(unscaled_result, count, input_shape.element_type());
}
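What BuildVar computes, in formula form: with reduction set $D$, $N$ the number of reduced elements, and $\bar{x}$ the mean over $D$,

$$\mathrm{Var}(x) = \frac{\sum_{i \in D} (x_i - \bar{x})^2}{N - \delta}, \qquad \delta = \begin{cases} 1 & \text{unbiased (Bessel's correction)} \\ 0 & \text{otherwise.} \end{cases}$$

The first CreateSummation (scale=true, keep_reduced_dimensions=true) produces $\bar{x}$ at full rank so the subtraction broadcasts; the second produces the unscaled numerator; GetScaleValue performs the final division by the (possibly corrected) count.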

xla::XlaOp BuildLogsumexp(xla::XlaOp input,
absl::Span<const xla::int64> dimensions,
bool keep_reduced_dimensions) {
3 changes: 3 additions & 0 deletions torch_xla/csrc/reduction.h
@@ -90,6 +90,9 @@ xla::XlaOp BuildAll(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
xla::XlaOp BuildAny(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
bool keep_reduced_dimensions);

xla::XlaOp BuildVar(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
bool unbiased, bool keep_reduced_dimensions);

xla::XlaOp BuildLogsumexp(xla::XlaOp input,
absl::Span<const xla::int64> dimensions,
bool keep_reduced_dimensions);
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -1115,6 +1115,10 @@ class XLATensor {
const XLATensor& grad_output, std::vector<xla::int64> output_size,
std::vector<xla::int64> input_size);

static XLATensor var(const XLATensor& input,
std::vector<xla::int64> dimensions, bool unbiased,
bool keep_reduced_dimensions);

// Like reshape, but it returns a view into the original tensor.
static XLATensor view(const XLATensor& input,
absl::Span<const xla::int64> output_size);
11 changes: 11 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -123,6 +123,7 @@
#include "torch_xla/csrc/ops/upsample_nearest2d.h"
#include "torch_xla/csrc/ops/upsample_nearest2d_backward.h"
#include "torch_xla/csrc/ops/user_computation.h"
#include "torch_xla/csrc/ops/var.h"
#include "torch_xla/csrc/ops/view.h"
#include "torch_xla/csrc/shape_builder.h"
#include "torch_xla/csrc/tensor.h"
@@ -2783,6 +2784,16 @@ XLATensor XLATensor::view(const XLATensor& input,
return input.CreateViewTensor(std::move(view_info));
}

XLATensor XLATensor::var(const XLATensor& input,
std::vector<xla::int64> dimensions, bool unbiased,
bool keep_reduced_dimensions) {
return input.CreateFrom(
ir::MakeNode<ir::ops::Var>(input.GetIrValue(),
XlaHelpers::GetCanonicalDimensionIndices(
dimensions, input.shape().get().rank()),
unbiased, keep_reduced_dimensions));
}
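GetCanonicalDimensionIndices wraps the possibly-negative ATen dims (e.g. the {-3, -2} case in TestVarWithDim) into [0, rank). The helper is real; the Python below is only a sketch of its wrap-around rule:

def canonicalize_dims(dims, rank):
    # -1 maps to rank - 1, -rank maps to 0; Python's % gives the wrap directly.
    return [d % rank for d in dims]

assert canonicalize_dims([-3, -2], 3) == [0, 1]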

void XLATensor::zero_(XLATensor& input) {
ir::Value constant =
GetIrValueForScalar(0.0, input.shape(), input.GetDevice());