Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,28 @@ def test_max_pool2d(self):
out = torch_xla._XLAC.max_pool2d(xt_x, 3, stride=stride, padding=padding).to_tensor()
self.assertEqualRel(out.data, expected.data)

def test_transpose(self):
  """Compares XLATensor.t() with torch's Tensor.t() on a generated 2x3 input."""
  cpu_input = _gen_tensor(2, 3)
  xla_input = torch_xla._XLAC.XLATensor(cpu_input)
  want = cpu_input.t()
  got = xla_input.t().to_tensor()
  self.assertEqualDbg(got.data, want.data)

def test_view(self):
  """Compares XLATensor.view(-1, 320) with torch's view on a 32x20x4x4 input."""
  source = _gen_tensor(32, 20, 4, 4)
  xla_source = torch_xla._XLAC.XLATensor(source)
  # One dimension is left as -1 so that both implementations must infer it.
  target_sizes = (-1, 320)
  reference = source.view(*target_sizes)
  result = xla_source.view(*target_sizes).to_tensor()
  self.assertEqualDbg(result.data, reference.data)

def test_log_softmax(self):
  """Compares XLATensor.log_softmax with torch's log_softmax on every dim.

  The name carries the `test_` prefix, like the sibling tests, so that
  prefix-based test discovery (e.g. unittest's default loader) runs it.
  """
  x = _gen_tensor(5, 3, 4, 2)
  xt_x = torch_xla._XLAC.XLATensor(x)
  for dim in range(x.dim()):
    expected = x.log_softmax(dim)
    out = xt_x.log_softmax(dim).to_tensor()
    self.assertEqualDbg(out.data, expected.data)


if __name__ == '__main__':
torch.set_default_tensor_type('torch.FloatTensor')
Expand Down
23 changes: 15 additions & 8 deletions torch_xla/csrc/data_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "data_ops.h"
#include "helpers.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

#include <functional>
#include <numeric>
Expand All @@ -12,9 +13,9 @@ namespace {
// output shape. The complete output shape has same total number of elements as
// input_sizes and matches output_sizes in all dimensions except for at most
// one, which can be inferred and stored as -1 in output_sizes.
std::vector<int64_t> GetCompleteShape(
const std::vector<int64_t>& output_sizes,
const std::vector<xla::int64>& input_sizes) {
std::vector<xla::int64> GetCompleteShape(
tensorflow::gtl::ArraySlice<const xla::int64> output_sizes,
tensorflow::gtl::ArraySlice<const xla::int64> input_sizes) {
c10::optional<size_t> incomplete_dim;
int64_t incomplete_element_count = 1;
for (size_t dim = 0; dim < output_sizes.size(); ++dim) {
Expand All @@ -29,15 +30,15 @@ std::vector<int64_t> GetCompleteShape(
}
}
if (!incomplete_dim) {
return output_sizes;
return std::vector<xla::int64>(output_sizes.begin(), output_sizes.end());
}
const auto total_element_count =
std::accumulate(input_sizes.begin(), input_sizes.end(), int64_t(1),
std::multiplies<int64_t>());
XLA_CHECK_EQ(total_element_count % incomplete_element_count, 0)
<< "Cannot infer remaining dimension";
std::vector<int64_t> complete_output_sizes(output_sizes.begin(),
output_sizes.end());
std::vector<xla::int64> complete_output_sizes(output_sizes.begin(),
output_sizes.end());
complete_output_sizes[*incomplete_dim] =
total_element_count / incomplete_element_count;
return complete_output_sizes;
Expand Down Expand Up @@ -86,9 +87,15 @@ xla::XlaOp BuildView(const torch::jit::Node* node, const xla::XlaOp& input) {
default:
XLA_ERROR() << "Unexpected node kind, must be view or reshape";
}
output_sizes =
return BuildView(input, XlaHelpers::I64List(output_sizes));
}

// Reshapes `input` to `output_sizes`. The requested sizes are first passed
// through GetCompleteShape() together with the sizes of `input`, and the
// sizes it returns are the ones handed to xla::Reshape.
xla::XlaOp BuildView(
    const xla::XlaOp& input,
    tensorflow::gtl::ArraySlice<const xla::int64> output_sizes) {
  const auto complete_output_sizes =
      GetCompleteShape(output_sizes, XlaHelpers::SizesOfXlaOp(input));
  // Reshape with the completed sizes rather than the raw request, which may
  // still hold the -1 placeholder for an inferred dimension.
  return xla::Reshape(input, complete_output_sizes);
}

xla::XlaOp BuildExpand(const torch::jit::Node* node, const xla::XlaOp& input) {
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/data_ops.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "torch/csrc/jit/ir.h"

// Collection of XLA lowerings for operations which only involve some form of
Expand All @@ -12,6 +13,11 @@ namespace torch_xla {
// specified by the "size" attribute of the given node.
xla::XlaOp BuildView(const torch::jit::Node* node, const xla::XlaOp& input);

// Same as above, but the output sizes are passed directly in output_sizes
// instead of being read from the "size" attribute of a node.
xla::XlaOp BuildView(
const xla::XlaOp& input,
tensorflow::gtl::ArraySlice<const xla::int64> output_sizes);

// Creates a new tensor with the singleton dimensions expanded to the sizes
// specified by the "size" attribute of the given node.
xla::XlaOp BuildExpand(const torch::jit::Node* node, const xla::XlaOp& input);
Expand Down
11 changes: 11 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,17 @@ void InitXlaTensorBindings(py::module m) {
self->addcmul_(alpha, tensor1, tensor2);
return self;
})
.def("t", [](std::shared_ptr<XLATensor> self) { return self->t(); })
.def("view",
[](std::shared_ptr<XLATensor> self, py::args args) {
std::vector<xla::int64> output_sizes;
for (const auto& output_dim_size : args) {
output_sizes.push_back(output_dim_size.cast<xla::int64>());
}
return self->view(output_sizes);
})
.def("log_softmax", [](std::shared_ptr<XLATensor> self,
int dim) { return self->log_softmax(dim); })
.def("cross_replica_sum",
[](std::shared_ptr<XLATensor> self, const py::list& groups) {
std::vector<std::vector<xla::int64>> crs_groups;
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/log_softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ xla::XlaOp BuildLogSoftmax(const torch::jit::Node* node,
const auto node_inputs = node->inputs();
XLA_CHECK_EQ(node_inputs.size(), size_t(2));
xla::int64 dim = node->get<int64_t>(at::attr::dim).value();
return BuildLogSoftmax(logits, dim);
}

xla::XlaOp BuildLogSoftmax(const xla::XlaOp& logits, xla::int64 dim) {
xla::Shape logits_shape = XlaHelpers::ShapeOfXlaOp(logits);
auto input_size = XlaHelpers::ShapeSizes(logits_shape);

Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/log_softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ namespace torch_xla {
xla::XlaOp BuildLogSoftmax(const torch::jit::Node* node,
const xla::XlaOp& logits);

// Same as above, but the dimension is passed directly in dim instead of
// being read from the "dim" attribute of a node.
xla::XlaOp BuildLogSoftmax(const xla::XlaOp& logits, xla::int64 dim);

// Computes the gradient of the input of the LogSoftmax function.
xla::XlaOp BuildLogSoftmaxGrad(const torch::jit::Node* node,
const xla::XlaOp& grad_output,
Expand Down
1 change: 0 additions & 1 deletion torch_xla/csrc/ops/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
namespace torch_xla {
namespace ir {
namespace ops {

namespace {

// The bias doesn't matter for shape inference.
Expand Down
1 change: 0 additions & 1 deletion torch_xla/csrc/ops/max_pool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
namespace torch_xla {
namespace ir {
namespace ops {

namespace {

// Infers the output shape of the max pooling operation.
Expand Down
45 changes: 45 additions & 0 deletions torch_xla/csrc/ops/softmax.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "ops/softmax.h"
#include "log_softmax.h"
#include "lowering_context.h"
#include "ops/infer_output_shape.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

// Computes the output shape of log softmax along `dim` for `input`, by
// handing InferOutputShape a callback that builds the log softmax op on its
// single operand.
xla::Shape NodeOutputShape(const NodeOperand& input, xla::int64 dim) {
  return InferOutputShape(
      {input.node->shape()},
      [dim](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
          -> xla::XlaOp {
        XLA_CHECK_EQ(operands.size(), 1)
            << "Unexpected number of operands: " << operands.size();
        return BuildLogSoftmax(operands[0], dim);
      });
}

} // namespace

LogSoftmax::LogSoftmax(const NodeOperand& input, xla::int64 dim)
: Node(ir::OpKind(at::aten::log_softmax), {input},
NodeOutputShape(input, dim)),
dim_(dim) {}

// Lowers this node: fetches the XLA op of operand 0 from the lowering
// context and applies BuildLogSoftmax along dim_.
XlaOpVector LogSoftmax::Lower(LoweringContext* loctx) const {
  return ReturnOp(BuildLogSoftmax(loctx->GetOutputOp(operand(0)), dim_),
                  loctx);
}

// Returns the base Node description followed by ", dim=<dim_>".
std::string LogSoftmax::ToString() const {
  std::stringstream description;
  description << Node::ToString();
  description << ", dim=" << dim_;
  return description.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
25 changes: 25 additions & 0 deletions torch_xla/csrc/ops/softmax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

// IR node for log(softmax) operation.
class LogSoftmax : public Node {
public:
// Creates the node over the given input, computing along dimension dim.
LogSoftmax(const NodeOperand& input, xla::int64 dim);

// Emits the XLA ops for this node into the given lowering context.
XlaOpVector Lower(LoweringContext* loctx) const override;

// Returns the Node description extended with the dim value.
std::string ToString() const override;

private:
// The dimension along which the result is computed.
xla::int64 dim_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
48 changes: 48 additions & 0 deletions torch_xla/csrc/ops/view.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "ops/view.h"
#include "data_ops.h"
#include "lowering_context.h"
#include "ops/infer_output_shape.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"

namespace torch_xla {
namespace ir {
namespace ops {
namespace {

// Computes the output shape of viewing `input` with `output_sizes`, by
// handing InferOutputShape a callback that builds the view op on its single
// operand.
xla::Shape NodeOutputShape(
    const NodeOperand& input,
    tensorflow::gtl::ArraySlice<const xla::int64> output_sizes) {
  return InferOutputShape(
      {input.node->shape()},
      [&output_sizes](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands)
          -> xla::XlaOp {
        XLA_CHECK_EQ(operands.size(), 1)
            << "Unexpected number of operands: " << operands.size();
        return BuildView(operands[0], output_sizes);
      });
}

} // namespace

// Builds a view node over `input`. The requested sizes drive the shape
// computed by NodeOutputShape and are also copied into output_size_ so the
// node owns them for lowering.
View::View(const NodeOperand& input,
           tensorflow::gtl::ArraySlice<const xla::int64> output_size)
    : Node(ir::OpKind(at::aten::view), {input},
           NodeOutputShape(input, output_size)),
      output_size_(output_size.begin(), output_size.end()) {}

// Lowers this node: fetches the XLA op of operand 0 from the lowering
// context and applies BuildView with the stored output size.
XlaOpVector View::Lower(LoweringContext* loctx) const {
  return ReturnOp(BuildView(loctx->GetOutputOp(operand(0)), output_size_),
                  loctx);
}

// Returns the base Node description followed by
// ", output_size=(<comma-separated sizes>)".
std::string View::ToString() const {
  std::stringstream description;
  description << Node::ToString();
  description << ", output_size=(" << absl::StrJoin(output_size_, ", ");
  description << ")";
  return description.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
27 changes: 27 additions & 0 deletions torch_xla/csrc/ops/view.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "ir.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace torch_xla {
namespace ir {
namespace ops {

// IR node for a tensor view.
class View : public Node {
public:
// Creates the node over the given input; output_size is copied into the node.
View(const NodeOperand& input,
tensorflow::gtl::ArraySlice<const xla::int64> output_size);

// Emits the XLA ops for this node into the given lowering context.
XlaOpVector Lower(LoweringContext* loctx) const override;

// Returns the Node description extended with the output size.
std::string ToString() const override;

private:
// The possibly incomplete output size.
std::vector<xla::int64> output_size_;
};

} // namespace ops
} // namespace ir
} // namespace torch_xla
35 changes: 35 additions & 0 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "ops/max_pool2d.h"
#include "ops/ops.h"
#include "ops/scalar.h"
#include "ops/softmax.h"
#include "ops/view.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/metrics.h"
Expand Down Expand Up @@ -676,6 +678,39 @@ std::shared_ptr<XLATensor> XLATensor::max_pool2d(int kernel_size, int stride,
GetDevice());
}

std::shared_ptr<XLATensor> XLATensor::t() {
auto lower_fn = [](const ir::Node& node,
ir::LoweringContext* loctx) -> ir::XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_output = xla::Transpose(xla_input, {1, 0});
return node.ReturnOp(xla_output, loctx);
};
auto lower_for_shape_fn =
[](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 1) << "Unexpected number of operands";
return xla::Transpose(operands[0], {1, 0});
};
xla::Shape output_shape =
ir::ops::InferOutputShape({shape()}, lower_for_shape_fn);
return Create(ir::ops::GenericOp(ir::OpKind(at::aten::t),
ir::OpList{ir::NodeOperand(GetIrNode())},
output_shape, std::move(lower_fn)),
GetDevice());
}

// Creates an XLATensor on this tensor's device from an ir::ops::View node
// built over this tensor's IR node with the requested output_size.
std::shared_ptr<XLATensor> XLATensor::view(
    tensorflow::gtl::ArraySlice<const xla::int64> output_size) {
  ir::NodePtr view_node = std::make_shared<ir::ops::View>(
      ir::NodeOperand(GetIrNode()), output_size);
  return Create(std::move(view_node), GetDevice());
}

// Creates an XLATensor on this tensor's device from an ir::ops::LogSoftmax
// node built over this tensor's IR node along dimension dim.
std::shared_ptr<XLATensor> XLATensor::log_softmax(xla::int64 dim) {
  ir::NodePtr softmax_node = std::make_shared<ir::ops::LogSoftmax>(
      ir::NodeOperand(GetIrNode()), dim);
  return Create(std::move(softmax_node), GetDevice());
}

std::shared_ptr<XLATensor> XLATensor::cross_replica_sum(
const std::vector<std::vector<xla::int64>>& groups) {
ir::NodePtr crs =
Expand Down
7 changes: 7 additions & 0 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ class XLATensor {
std::shared_ptr<XLATensor> max_pool2d(int kernel_size, int stride,
int padding);

std::shared_ptr<XLATensor> t();

std::shared_ptr<XLATensor> view(
tensorflow::gtl::ArraySlice<const xla::int64> output_size);

std::shared_ptr<XLATensor> log_softmax(xla::int64 dim);

std::shared_ptr<XLATensor> cross_replica_sum(
const std::vector<std::vector<xla::int64>>& groups);

Expand Down