38 changes: 38 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -1275,6 +1275,25 @@ TEST_F(AtenXlaTensorTest, TestStdInDim) {
}
}

TEST_F(AtenXlaTensorTest, TestStdWithCorrection) {
torch::Tensor a = torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
int rank = a.dim();
c10::optional<int64_t> corrections[] = {1, 2, c10::nullopt};
for (const auto& correction : corrections) {
for (auto keepdim : {true, false}) {
for (const auto& dim :
std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
torch::Tensor b = torch::std(a, dim, correction, keepdim);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::Tensor xla_b = torch::std(xla_a, dim, correction, keepdim);
AllClose(b, xla_b);
});
}
}
}
}

TEST_F(AtenXlaTensorTest, TestSum) {
torch::Tensor a = torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::sum(a);
@@ -1387,6 +1406,25 @@ TEST_F(AtenXlaTensorTest, TestVarWithDim) {
ExpectCounterChanged("xla::var", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestVarWithCorrection) {
torch::Tensor a = torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
c10::optional<int64_t> corrections[] = {1, 2, c10::nullopt};
for (const auto& dim : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
for (bool keepDim : {true, false}) {
for (const auto& correction : corrections) {
torch::Tensor b = torch::var(a, dim, correction, keepDim);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::Tensor xla_b = torch::var(xla_a, dim, correction, keepDim);
AllClose(b, xla_b);
});
}
}
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::var", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestMaxInDim) {
torch::Tensor input =
torch::rand({4, 3, 4}, torch::TensorOptions(torch::kFloat));
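The tests above exercise `correction` values of 1, 2, and `c10::nullopt`. As a reference for the semantics they assume, here is a minimal standalone sketch (plain C++, not part of this change): `correction` generalizes the old `unbiased` flag, so `unbiased=true` corresponds to `correction=1`, `unbiased=false` to `correction=0`, and an absent correction defaults to 1.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Reference semantics for std with a correction term: the divisor of the
// summed squared deviations is N - correction. correction=0 gives the
// population estimator, correction=1 the Bessel-corrected sample estimator.
double StdWithCorrection(const std::vector<double>& x, int64_t correction) {
  const double n = static_cast<double>(x.size());
  double mean = 0.0;
  for (double v : x) mean += v;
  mean /= n;
  double sum_sq = 0.0;
  for (double v : x) sum_sq += (v - mean) * (v - mean);
  // Divisor is n - correction, mirroring the new lowering's count adjustment.
  return std::sqrt(sum_sq / (n - static_cast<double>(correction)));
}
```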
36 changes: 31 additions & 5 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -3128,15 +3128,27 @@ at::Tensor AtenXlaType::std(const at::Tensor& self, bool unbiased) {
return bridge::AtenFromXlaTensor(XLATensor::std(
self_tensor,
xla::util::Iota<xla::int64>(self_tensor.shape().get().rank()),
/*keep_reduced_dimensions=*/false, unbiased));
/*keep_reduced_dimensions=*/false, /*correction=*/unbiased ? 1 : 0));
}

at::Tensor AtenXlaType::std(const at::Tensor& self, at::IntArrayRef dim,
bool unbiased, bool keepdim) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::std(
bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(dim),
/*keep_reduced_dimensions=*/keepdim, unbiased));
bridge::GetXlaTensor(self), xla::util::ToVector<xla::int64>(dim), keepdim,
/*correction=*/unbiased ? 1 : 0));
}

at::Tensor AtenXlaType::std(const at::Tensor& self,
c10::optional<at::IntArrayRef> dim,
c10::optional<int64_t> correction, bool keepdim) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
return bridge::AtenFromXlaTensor(XLATensor::std(
self_tensor,
dim ? xla::util::ToVector<xla::int64>(*dim)
: xla::util::Iota<xla::int64>(self_tensor.shape().get().rank()),
keepdim, correction ? *correction : 1));
}

at::Tensor AtenXlaType::sub(const at::Tensor& self, const at::Tensor& other,
@@ -3532,7 +3544,7 @@ at::Tensor AtenXlaType::var(const at::Tensor& self, bool unbiased) {
XLATensor::var(bridge::GetXlaTensor(self),
xla::util::Iota<xla::int64>(
bridge::GetXlaTensor(self).shape().get().rank()),
unbiased,
/*correction=*/unbiased ? 1 : 0,
/*keep_reduced_dimensions=*/false));
}

@@ -3541,7 +3553,21 @@ at::Tensor AtenXlaType::var(const at::Tensor& self, at::IntArrayRef dim,
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
return bridge::AtenFromXlaTensor(
XLATensor::var(self_tensor, XlaHelpers::I64List(dim), unbiased, keepdim));
XLATensor::var(self_tensor, XlaHelpers::I64List(dim),
/*correction=*/unbiased ? 1 : 0, keepdim));
}

at::Tensor AtenXlaType::var(const at::Tensor& self,
c10::optional<at::IntArrayRef> dim,
c10::optional<int64_t> correction, bool keepdim) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
return bridge::AtenFromXlaTensor(
XLATensor::var(self_tensor,
dim ? XlaHelpers::I64List(*dim)
: xla::util::Iota<xla::int64>(
bridge::GetXlaTensor(self).shape().get().rank()),
correction ? *correction : 1, keepdim));
}

at::Tensor AtenXlaType::view(const at::Tensor& self, at::IntArrayRef size) {
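Both new overloads resolve their optional arguments the same way: an absent `dim` means "reduce over every dimension" (the `Iota` over the input rank) and an absent `correction` defaults to 1. A sketch of that defaulting, using `std::optional` in place of `c10::optional` and hypothetical helper names (the real code inlines these expressions):

```cpp
#include <cstdint>
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical helper: no dim list means reduce over all dimensions.
std::vector<int64_t> ResolveDims(
    const std::optional<std::vector<int64_t>>& dim, int64_t rank) {
  if (dim) return *dim;
  std::vector<int64_t> all(rank);
  std::iota(all.begin(), all.end(), 0);  // 0, 1, ..., rank - 1
  return all;
}

// Hypothetical helper: nullopt means the unbiased default, correction = 1.
int64_t ResolveCorrection(std::optional<int64_t> correction) {
  return correction.value_or(1);
}
```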
8 changes: 8 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -977,6 +977,10 @@ class AtenXlaType {
static at::Tensor std(const at::Tensor& self, at::IntArrayRef dim,
bool unbiased, bool keepdim);

static at::Tensor std(const at::Tensor& self,
c10::optional<at::IntArrayRef> dim,
c10::optional<int64_t> correction, bool keepdim);

static at::Tensor sub(const at::Tensor& self, const at::Tensor& other,
const at::Scalar& alpha);

@@ -1102,6 +1106,10 @@ class AtenXlaType {
static at::Tensor var(const at::Tensor& self, at::IntArrayRef dim,
bool unbiased, bool keepdim);

static at::Tensor var(const at::Tensor& self,
c10::optional<at::IntArrayRef> dim,
c10::optional<int64_t> correction, bool keepdim);

static at::Tensor view(const at::Tensor& self, at::IntArrayRef size);

static at::Tensor& zero_(at::Tensor& self);
19 changes: 10 additions & 9 deletions torch_xla/csrc/ops/std.cpp
@@ -13,47 +13,48 @@ namespace {

xla::Shape NodeOutputShape(const Value& input,
std::vector<xla::int64>& dimensions,
bool keep_reduced_dimensions, bool unbiased) {
bool keep_reduced_dimensions,
xla::int64 correction) {
auto lower_for_shape_fn =
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildStdDeviation(operands[0], dimensions, keep_reduced_dimensions,
unbiased);
correction);
};
return InferOutputShape({input.shape()}, lower_for_shape_fn);
}

} // namespace

Std::Std(const Value& input, std::vector<xla::int64> dimensions,
bool keep_reduced_dimensions, bool unbiased)
bool keep_reduced_dimensions, xla::int64 correction)
: Node(ir::OpKind(at::aten::std), {input},
[&]() {
return NodeOutputShape(input, dimensions, keep_reduced_dimensions,
unbiased);
correction);
},
/*num_outputs=*/1,
xla::util::MHash(dimensions, keep_reduced_dimensions, unbiased)),
xla::util::MHash(dimensions, keep_reduced_dimensions, correction)),
dimensions_(std::move(dimensions)),
keep_reduced_dimensions_(keep_reduced_dimensions),
unbiased_(unbiased) {}
correction_(correction) {}

NodePtr Std::Clone(OpList operands) const {
return MakeNode<Std>(operands.at(0), dimensions_, keep_reduced_dimensions_,
unbiased_);
correction_);
}

XlaOpVector Std::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
return ReturnOp(BuildStdDeviation(input, dimensions_,
keep_reduced_dimensions_, unbiased_),
keep_reduced_dimensions_, correction_),
loctx);
}

std::string Std::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", dimensions=(" << absl::StrJoin(dimensions_, ", ")
<< "), keep_reduced_dimensions=" << keep_reduced_dimensions_
<< ", unbiased=" << unbiased_;
<< ", correction=" << correction_;
return ss.str();
}

6 changes: 3 additions & 3 deletions torch_xla/csrc/ops/std.h
@@ -12,7 +12,7 @@ namespace ops {
class Std : public Node {
public:
Std(const Value& input, std::vector<xla::int64> dimensions,
bool keep_reduced_dimensions, bool unbiased);
bool keep_reduced_dimensions, xla::int64 correction);

std::string ToString() const override;

@@ -24,12 +24,12 @@ class Std : public Node {

bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; }

bool unbiased() const { return unbiased_; }
xla::int64 correction() const { return correction_; }

private:
std::vector<xla::int64> dimensions_;
bool keep_reduced_dimensions_;
bool unbiased_;
xla::int64 correction_;
};

} // namespace ops
29 changes: 16 additions & 13 deletions torch_xla/csrc/ops/var.cpp
@@ -15,43 +15,46 @@ namespace ops {
namespace {

xla::Shape NodeOutputShape(const Value& input,
std::vector<xla::int64>& dimensions, bool unbiased,
std::vector<xla::int64>& dimensions,
xla::int64 correction,
bool keep_reduced_dimensions) {
auto lower_for_shape_fn =
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildVar(operands[0], dimensions, unbiased, keep_reduced_dimensions);
return BuildVar(operands[0], dimensions, correction,
keep_reduced_dimensions);
};
return InferOutputShape({input.shape()}, lower_for_shape_fn);
}

} // namespace

Var::Var(const Value& input, std::vector<xla::int64> dimensions, bool unbiased,
bool keep_reduced_dimensions)
: Node(
ir::OpKind(at::aten::var), {input},
NodeOutputShape(input, dimensions, unbiased, keep_reduced_dimensions),
/*num_outputs=*/1,
xla::util::MHash(dimensions, unbiased, keep_reduced_dimensions)),
Var::Var(const Value& input, std::vector<xla::int64> dimensions,
xla::int64 correction, bool keep_reduced_dimensions)
: Node(ir::OpKind(at::aten::var), {input},
NodeOutputShape(input, dimensions, correction,
keep_reduced_dimensions),
/*num_outputs=*/1,
xla::util::MHash(dimensions, correction, keep_reduced_dimensions)),
dimensions_(std::move(dimensions)),
unbiased_(unbiased),
correction_(correction),
keep_reduced_dimensions_(keep_reduced_dimensions) {}

NodePtr Var::Clone(OpList operands) const {
return MakeNode<Var>(operands.at(0), dimensions_, unbiased_,
return MakeNode<Var>(operands.at(0), dimensions_, correction_,
keep_reduced_dimensions_);
}

XlaOpVector Var::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
return ReturnOp(
BuildVar(input, dimensions_, unbiased_, keep_reduced_dimensions_), loctx);
BuildVar(input, dimensions_, correction_, keep_reduced_dimensions_),
loctx);
}

std::string Var::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", dimensions=(" << absl::StrJoin(dimensions_, ", ")
<< "), unbiased=" << unbiased_
<< "), correction=" << correction_
<< ", keep_reduced_dimensions=" << keep_reduced_dimensions_;
return ss.str();
}
8 changes: 4 additions & 4 deletions torch_xla/csrc/ops/var.h
@@ -11,8 +11,8 @@ namespace ops {

class Var : public Node {
public:
Var(const Value& input, std::vector<xla::int64> dimensions, bool unbiased,
bool keep_reduced_dimensions);
Var(const Value& input, std::vector<xla::int64> dimensions,
xla::int64 correction, bool keep_reduced_dimensions);

std::string ToString() const override;

@@ -24,11 +24,11 @@ class Var : public Node {

bool keep_reduced_dimensions() const { return keep_reduced_dimensions_; }

bool unbiased() const { return unbiased_; }
xla::int64 correction() const { return correction_; }

private:
std::vector<xla::int64> dimensions_;
bool unbiased_;
xla::int64 correction_;
bool keep_reduced_dimensions_;
};

29 changes: 17 additions & 12 deletions torch_xla/csrc/reduction.cpp
@@ -284,7 +284,8 @@ xla::XlaOp BuildMean(xla::XlaOp input, absl::Span<const xla::int64> dimensions,

xla::XlaOp BuildStdDeviation(xla::XlaOp input,
absl::Span<const xla::int64> dimensions,
bool keep_reduced_dimensions, bool unbiased) {
bool keep_reduced_dimensions,
xla::int64 correction) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp mean =
BuildMean(input, dimensions, /*keep_reduced_dimensions*/ true);
@@ -294,15 +295,17 @@ xla::XlaOp BuildStdDeviation(xla::XlaOp input,
xla::XlaOp input_mean_diff = input - bcast_mean;
xla::XlaOp squared_var = input_mean_diff * input_mean_diff;
xla::XlaOp squared_result;
if (unbiased) {
if (correction != 0) {
SummationResult sum_result = CreateSummation(
squared_var, dimensions, keep_reduced_dimensions, /*scale=*/false);
squared_result = GetScaleValue(
sum_result.result,
sum_result.rinfo.element_count.size -
xla::One(input.builder(), XlaHelpers::TypeOfXlaOp(
sum_result.rinfo.element_count.size)),
input_shape.element_type());
xla::XlaOp correction_scalar = XlaHelpers::ScalarValue(
correction,
XlaHelpers::TypeOfXlaOp(sum_result.rinfo.element_count.size),
input.builder());
squared_result =
GetScaleValue(sum_result.result,
sum_result.rinfo.element_count.size - correction_scalar,
input_shape.element_type());
} else {
SummationResult sum_result = CreateSummation(
squared_var, dimensions, keep_reduced_dimensions, /*scale=*/true);
@@ -458,7 +461,7 @@ xla::XlaOp BuildAny(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
}

xla::XlaOp BuildVar(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
bool unbiased, bool keep_reduced_dimensions) {
xla::int64 correction, bool keep_reduced_dimensions) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
SummationResult mean_result =
CreateSummation(input, dimensions, /*keep_reduced_dimensions=*/true,
@@ -470,9 +473,11 @@ xla::XlaOp BuildVar(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
/*scale=*/false)
.result;
xla::XlaOp count = mean_result.rinfo.element_count.size;
if (unbiased) {
count = count - xla::One(input.builder(),
XlaHelpers::ShapeOfXlaOp(count).element_type());
if (correction != 0) {
count =
count - XlaHelpers::ScalarValue(
correction, XlaHelpers::ShapeOfXlaOp(count).element_type(),
input.builder());
}
return GetScaleValue(unscaled_result, count, input_shape.element_type());
}
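For intuition, the variance lowering follows the usual two-pass formulation: a scaled summation for the mean, an unscaled summation of squared deviations, then a scale by the (possibly corrected) element count, i.e. var = sum((x - mean)^2) / (N - correction). A simplified 1-D reference in plain C++ (a sketch, not the XLA code):

```cpp
#include <cstdint>
#include <vector>

// Mirrors BuildVar's structure on a 1-D input.
double VarWithCorrection(const std::vector<double>& x, int64_t correction) {
  double count = static_cast<double>(x.size());
  double mean = 0.0;
  for (double v : x) mean += v / count;    // CreateSummation(/*scale=*/true)
  double unscaled = 0.0;
  for (double v : x) unscaled += (v - mean) * (v - mean);  // /*scale=*/false
  if (correction != 0) count -= static_cast<double>(correction);
  return unscaled / count;                 // GetScaleValue(unscaled, count, ...)
}
```

The same count adjustment appears in `BuildStdDeviation` above, which scales the summed squared deviations by `N - correction` via a `correction_scalar` instead of the old hard-coded `xla::One`.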
5 changes: 3 additions & 2 deletions torch_xla/csrc/reduction.h
@@ -39,7 +39,8 @@ xla::XlaOp BuildMean(xla::XlaOp input, absl::Span<const xla::int64> dimensions,

xla::XlaOp BuildStdDeviation(xla::XlaOp input,
absl::Span<const xla::int64> dimensions,
bool keep_reduced_dimensions, bool unbiased);
bool keep_reduced_dimensions,
xla::int64 correction);

// Builds the sum of all values by reducing all the dimensions listed in
// dimensions. If keep_reduced_dimensions is true, the reduced dimensions will
@@ -91,7 +92,7 @@ xla::XlaOp BuildAny(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
bool keep_reduced_dimensions);

xla::XlaOp BuildVar(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
bool unbiased, bool keep_reduced_dimensions);
xla::int64 correction, bool keep_reduced_dimensions);

xla::XlaOp BuildLogsumexp(xla::XlaOp input,
absl::Span<const xla::int64> dimensions,