86 changes: 86 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -1156,6 +1156,92 @@ TEST_F(AtenXlaTensorTest, TestPermute) {
}
}

TEST_F(AtenXlaTensorTest, TestTriu) {
int size = 5;
at::Tensor input = GetTestTensor({size, size});
// Test all diagonals, including out-of-bounds ones (which must degenerate
// to a full copy or an all-zeros result, never fail).
for (int diagonal = -size; diagonal <= size; ++diagonal) {
at::Tensor output = at::triu(input, diagonal);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::triu(xla_input, diagonal);
AllClose(output, xla_output);
});
}
}

TEST_F(AtenXlaTensorTest, TestTriuNonSquare) {
int size = 5;
at::Tensor input = GetTestTensor({size, size + 1});
// Test all diagonals, including out-of-bounds ones (which must degenerate
// to a full copy or an all-zeros result, never fail).
for (int diagonal = -size; diagonal <= size; ++diagonal) {
at::Tensor output = at::triu(input, diagonal);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::triu(xla_input, diagonal);
AllClose(output, xla_output);
});
}
}

TEST_F(AtenXlaTensorTest, TestTriuBatch) {
int size = 5;
int batch_size = 3;
at::Tensor input = GetTestTensor({batch_size, size, size});
// Test all diagonals, including out-of-bounds ones (which must degenerate
// to a full copy or an all-zeros result, never fail).
for (int diagonal = -size; diagonal <= size; ++diagonal) {
at::Tensor output = at::triu(input, diagonal);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::triu(xla_input, diagonal);
AllClose(output, xla_output);
});
}
}

TEST_F(AtenXlaTensorTest, TestTril) {
int size = 5;
at::Tensor input = GetTestTensor({size, size});
// Test all diagonals, including out-of-bounds ones (which must degenerate
// to a full copy or an all-zeros result, never fail).
for (int diagonal = -size; diagonal <= size; ++diagonal) {
at::Tensor output = at::tril(input, diagonal);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::tril(xla_input, diagonal);
AllClose(output, xla_output);
});
}
}

TEST_F(AtenXlaTensorTest, TestTrilNonSquare) {
int size = 5;
at::Tensor input = GetTestTensor({size, size + 1});
// Test all diagonals, including out-of-bounds ones (which must degenerate
// to a full copy or an all-zeros result, never fail).
for (int diagonal = -size; diagonal <= size; ++diagonal) {
at::Tensor output = at::tril(input, diagonal);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::tril(xla_input, diagonal);
AllClose(output, xla_output);
});
}
}

TEST_F(AtenXlaTensorTest, TestTrilBatch) {
int size = 5;
int batch_size = 3;
at::Tensor input = GetTestTensor({batch_size, size, size});
// Test all diagonals, including out-of-bounds ones (which must degenerate
// to a full copy or an all-zeros result, never fail).
for (int diagonal = -size; diagonal <= size; ++diagonal) {
at::Tensor output = at::tril(input, diagonal);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::tril(xla_input, diagonal);
AllClose(output, xla_output);
});
}
}

TEST_F(AtenXlaTensorTest, TestAvgPool2DBackward) {
int kernel_size = 2;
for (int stride = 1; stride <= 2; ++stride) {
10 changes: 10 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -792,6 +792,16 @@ at::Tensor AtenXlaType::where(const at::Tensor& condition,
bridge::GetXlaTensor(other)));
}

at::Tensor AtenXlaType::triu(const at::Tensor& self, int64_t diagonal) const {
return bridge::AtenFromXlaTensor(
XLATensor::triu(bridge::GetXlaTensor(self), diagonal));
}

at::Tensor AtenXlaType::tril(const at::Tensor& self, int64_t diagonal) const {
return bridge::AtenFromXlaTensor(
XLATensor::tril(bridge::GetXlaTensor(self), diagonal));
}

void AtenXlaType::SetFullConvPrecision(
bool use_full_conv_precision /*= true*/) {
s_use_full_conv_precision_ = use_full_conv_precision;
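As context, a minimal usage sketch (not part of the diff) of the path these overrides enable: at::triu on an XLA tensor dispatches to AtenXlaType::triu, which records a lazy IR node via XLATensor::triu. `device` below is assumed to be a valid XLA device, mirroring the test fixtures above.

  // Sketch: builds the lazy graph only; nothing is lowered to XLA until the
  // result is materialized (e.g. copied back to CPU for comparison).
  at::Tensor input = at::rand({4, 4}, at::TensorOptions(at::kFloat));
  at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
  at::Tensor upper = at::triu(xla_input, /*diagonal=*/1);
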
4 changes: 4 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -263,6 +263,10 @@ class AtenXlaType : public AtenXlaTypeBase {
at::Tensor where(const at::Tensor& condition, const at::Tensor& self,
const at::Tensor& other) const override;

at::Tensor triu(const at::Tensor& self, int64_t diagonal) const override;

at::Tensor tril(const at::Tensor& self, int64_t diagonal) const override;

static void SetFullConvPrecision(bool use_full_conv_precision = true);

// Registers the ATEN types for the XLA tensors.
17 changes: 17 additions & 0 deletions torch_xla/csrc/matrix.cpp
@@ -0,0 +1,17 @@
#include "torch_xla/csrc/matrix.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"

namespace torch_xla {

xla::XlaOp BuildTriu(const xla::XlaOp& input, int diagonal) {
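// TriangleMask(input, d) is true on and below diagonal d, so masking with
// diagonal - 1 zeroes everything strictly below the requested diagonal.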
return xla::Select(xla::TriangleMask(input, diagonal - 1),
xla::ZerosLike(input), input);
}

xla::XlaOp BuildTril(const xla::XlaOp& input, int diagonal) {
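// Keep elements on and below the requested diagonal, zero the rest.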
return xla::Select(xla::TriangleMask(input, diagonal), input,
xla::ZerosLike(input));
}

} // namespace torch_xla
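
For reference, a host-side sketch (hypothetical, not part of this change) of the semantics the two helpers above implement: triu(d) keeps entry (i, j) iff j - i >= d, tril(d) iff j - i <= d, which is why BuildTriu offsets the mask by one while BuildTril passes the diagonal through unchanged.

  #include <vector>

  // Hypothetical reference triu over a row-major rows x cols float buffer:
  // entries strictly below the requested diagonal are zeroed.
  std::vector<float> ReferenceTriu(const std::vector<float>& m, int rows,
                                   int cols, int diagonal) {
    std::vector<float> out(m);
    for (int i = 0; i < rows; ++i) {
      for (int j = 0; j < cols; ++j) {
        if (j - i < diagonal) out[i * cols + j] = 0.0f;
      }
    }
    return out;
  }
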
12 changes: 12 additions & 0 deletions torch_xla/csrc/matrix.h
@@ -0,0 +1,12 @@
#pragma once

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "torch/csrc/jit/ir.h"

namespace torch_xla {

xla::XlaOp BuildTriu(const xla::XlaOp& input, int diagonal);

[Review comment from a collaborator, on the int diagonal parameter]: I'd use int64 as well here, for uniformity.

xla::XlaOp BuildTril(const xla::XlaOp& input, int diagonal);

} // namespace torch_xla
29 changes: 29 additions & 0 deletions torch_xla/csrc/ops/tril.cpp
@@ -0,0 +1,29 @@
#include "torch_xla/csrc/ops/tril.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/matrix.h"

namespace torch_xla {
namespace ir {
namespace ops {

Tril::Tril(const Value& input, xla::int64 diagonal)
: Node(ir::OpKind(at::aten::tril), {input}, input.shape(),
/*num_outputs=*/1, xla::util::MHash(diagonal)),
diagonal_(diagonal) {}

XlaOpVector Tril::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp output = BuildTril(input, diagonal_);
return ReturnOp(output, loctx);
}

std::string Tril::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", diagonal=" << diagonal_;
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
27 changes: 27 additions & 0 deletions torch_xla/csrc/ops/tril.h
@@ -0,0 +1,27 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

// Node for the lower triangular part of a matrix (2-D tensor) or of a batch
// of matrices.
class Tril : public Node {
public:
Tril(const Value& input, xla::int64 diagonal);

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

xla::int64 diagonal() const { return diagonal_; }

private:
xla::int64 diagonal_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
29 changes: 29 additions & 0 deletions torch_xla/csrc/ops/triu.cpp
@@ -0,0 +1,29 @@
#include "torch_xla/csrc/ops/triu.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/matrix.h"

namespace torch_xla {
namespace ir {
namespace ops {

Triu::Triu(const Value& input, xla::int64 diagonal)
: Node(ir::OpKind(at::aten::triu), {input}, input.shape(),
/*num_outputs=*/1, xla::util::MHash(diagonal)),
diagonal_(diagonal) {}

XlaOpVector Triu::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp output = BuildTriu(input, diagonal_);
return ReturnOp(output, loctx);
}

std::string Triu::ToString() const {
std::stringstream ss;
ss << Node::ToString() << ", diagonal=" << diagonal_;
return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
27 changes: 27 additions & 0 deletions torch_xla/csrc/ops/triu.h
@@ -0,0 +1,27 @@
#pragma once

#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

// Node for the upper triangular part of a matrix (2-D tensor) or of a batch
// of matrices.
class Triu : public Node {
public:
Triu(const Value& input, xla::int64 diagonal);

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

xla::int64 diagonal() const { return diagonal_; }

private:
xla::int64 diagonal_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
12 changes: 12 additions & 0 deletions torch_xla/csrc/tensor.cpp
@@ -49,6 +49,8 @@
#include "torch_xla/csrc/ops/squeeze.h"
#include "torch_xla/csrc/ops/threshold.h"
#include "torch_xla/csrc/ops/threshold_backward.h"
#include "torch_xla/csrc/ops/tril.h"
#include "torch_xla/csrc/ops/triu.h"
#include "torch_xla/csrc/ops/unsqueeze.h"
#include "torch_xla/csrc/ops/view.h"
#include "torch_xla/csrc/tensor_util.h"
@@ -950,6 +952,16 @@ XLATensor XLATensor::unsqueeze(const XLATensor& input, int dim) {
input.GetDevice());
}

XLATensor XLATensor::triu(const XLATensor& input, xla::int64 diagonal) {
return Create(ir::MakeNode<ir::ops::Triu>(input.GetIrValue(), diagonal),
input.GetDevice());
}

XLATensor XLATensor::tril(const XLATensor& input, xla::int64 diagonal) {
return Create(ir::MakeNode<ir::ops::Tril>(input.GetIrValue(), diagonal),
input.GetDevice());
}

XLATensor XLATensor::where(const XLATensor& condition, const XLATensor& input,
const XLATensor& other) {
return Create(ir::ops::Where(condition.GetIrValue(), input.GetIrValue(),
14 changes: 10 additions & 4 deletions torch_xla/csrc/tensor.h
@@ -315,11 +315,9 @@ class XLATensor {

static XLATensor max(const XLATensor& input, const XLATensor& other);

-static XLATensor argmax(const XLATensor& input, xla::int64 dim,
-                        bool keepdim);
+static XLATensor argmax(const XLATensor& input, xla::int64 dim, bool keepdim);

-static XLATensor argmin(const XLATensor& input, xla::int64 dim,
-                        bool keepdim);
+static XLATensor argmin(const XLATensor& input, xla::int64 dim, bool keepdim);

// Like batch_norm, but returns additional save_mean and save_invstd used by
// the backward pass.
@@ -349,6 +347,14 @@
// Insert a dimension of size one at the specified position.
static XLATensor unsqueeze(const XLATensor& input, int dim);

// Returns the upper triangular part of a matrix (2-D tensor) or of a batch
// of matrices; the other elements of the result tensor are set to 0.
static XLATensor triu(const XLATensor& input, xla::int64 diagonal);

// Returns the lower triangular part of a matrix (2-D tensor) or of a batch
// of matrices; the other elements of the result tensor are set to 0.
static XLATensor tril(const XLATensor& input, xla::int64 diagonal);

static XLATensor where(const XLATensor& condition, const XLATensor& input,
const XLATensor& other);
