1 change: 0 additions & 1 deletion .torch_commit_id

This file was deleted.

34 changes: 22 additions & 12 deletions torch_xla/csrc/data_ops.cpp
@@ -72,18 +72,8 @@ xla::XlaOp BuildView(const torch::jit::Node* node, const xla::XlaOp& input) {
   XLA_CHECK_EQ(node_inputs.size(), 2);
   const auto node_outputs = node->outputs();
   XLA_CHECK_EQ(node_outputs.size(), 1);
-  // Try to use the second argument of the operator as the target shape.
-  std::vector<int64_t> output_sizes;
-  switch (node->kind()) {
-    case at::aten::view:
-      output_sizes = node->get<std::vector<int64_t>>(at::attr::size).value();
-      break;
-    case at::aten::reshape:
-      output_sizes = node->get<std::vector<int64_t>>(at::attr::shape).value();
-      break;
-    default:
-      XLA_ERROR() << "Unexpected node kind, must be view or reshape";
-  }
+  std::vector<int64_t> output_sizes =
+      node->get<std::vector<int64_t>>(at::attr::size).value();
   return BuildView(input, XlaHelpers::I64List(output_sizes));
 }
 
@@ -95,6 +85,26 @@ xla::XlaOp BuildView(
   return xla::Reshape(input, complete_output_sizes);
 }
 
+xla::XlaOp BuildReshape(
+    const torch::jit::Node* node, const xla::XlaOp& input,
+    const std::unordered_map<size_t, std::vector<xla::int64>>&
+        size_op_values_tracking) {
+  std::vector<xla::int64> output_sizes;
+  if (node->hasAttribute(at::attr::shape)) {
+    output_sizes = XlaHelpers::I64List(
+        node->get<std::vector<int64_t>>(at::attr::shape).value());
+  } else {
+    const auto size_op_value_it =
+        size_op_values_tracking.find(node->input(1)->unique());
+    XLA_CHECK(size_op_value_it != size_op_values_tracking.end())
+        << "at::aten::reshape only allowed when second parameter is a "
+           "constant size: "
+        << *node;
+    output_sizes = size_op_value_it->second;
+  }
+  return BuildView(input, output_sizes);
+}
+
 xla::XlaOp SqueezeTrivialDimension(const xla::XlaOp& input, size_t dim) {
   auto input_sizes = XlaHelpers::SizesOfXlaOp(input);
   XLA_CHECK_LT(dim, input_sizes.size());
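The key behavior added above is that an aten::reshape node may take its target shape from one of two places: a constant shape attribute on the node, or the previously recorded output of a size op feeding the node's second input. As a minimal, self-contained sketch of that second lookup path, assuming only the standard library (the names SizeValueMap, ResolveReshapeSizes, and the sample data are hypothetical stand-ins for size_op_values_tracking keyed by the producing value's unique id, not the torch_xla code):

    #include <cstdint>
    #include <optional>
    #include <unordered_map>
    #include <vector>

    // Hypothetical stand-in for the size_op_values_tracking map: keyed by the
    // unique id of the graph value produced by an aten::size op, mapping to the
    // concrete sizes captured when that op was translated.
    using SizeValueMap = std::unordered_map<size_t, std::vector<int64_t>>;

    // Resolve the target sizes for a reshape whose shape is not a constant
    // attribute: look up the second input's producing value in the tracking map.
    std::optional<std::vector<int64_t>> ResolveReshapeSizes(
        size_t shape_input_unique_id, const SizeValueMap& tracked_sizes) {
      auto it = tracked_sizes.find(shape_input_unique_id);
      if (it == tracked_sizes.end()) {
        // Not a tracked size op; the real code CHECK-fails in this case.
        return std::nullopt;
      }
      return it->second;
    }

    int main() {
      SizeValueMap tracked_sizes{{/*unique_id=*/7, {2, 3, 4}}};
      auto sizes = ResolveReshapeSizes(7, tracked_sizes);
      // sizes now holds {2, 3, 4}, the shape recorded for the size op feeding
      // the reshape's second input.
      return sizes.has_value() ? 0 : 1;
    }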
6 changes: 6 additions & 0 deletions torch_xla/csrc/data_ops.h
@@ -27,6 +27,12 @@ xla::XlaOp BuildView(
     const xla::XlaOp& input,
     tensorflow::gtl::ArraySlice<const xla::int64> output_sizes);
 
+// Creates a tensor reshape operation.
+xla::XlaOp BuildReshape(
+    const torch::jit::Node* node, const xla::XlaOp& input,
+    const std::unordered_map<size_t, std::vector<xla::int64>>&
+        size_op_values_tracking);
+
 // Squeezes the given dimension if trivial (size 1), returns the unchanged input
 // otherwise.
 xla::XlaOp SqueezeTrivialDimension(const xla::XlaOp& input, size_t dim);
13 changes: 11 additions & 2 deletions torch_xla/csrc/translator.cpp
@@ -422,11 +422,20 @@ void TranslateLogSoftmaxBackward(
   cctx->AddNodeOp(node, xla_output);
 }
 
+void TranslateView(const torch::jit::Node* node, ComputationContext* cctx,
+                   xla::PrecisionConfig::Precision /*conv_precision*/,
+                   xla::XlaBuilder* /*b*/) {
+  XLA_CHECK_EQ(node->inputs().size(), 2);
+  xla::XlaOp xla_output = BuildView(node, cctx->OpForInput(node, 0));
+  cctx->AddNodeOp(node, xla_output);
+}
+
 void TranslateReshape(const torch::jit::Node* node, ComputationContext* cctx,
                       xla::PrecisionConfig::Precision /*conv_precision*/,
                       xla::XlaBuilder* /*b*/) {
   XLA_CHECK_EQ(node->inputs().size(), 2);
-  xla::XlaOp xla_output = BuildView(node, cctx->OpForInput(node, 0));
+  xla::XlaOp xla_output =
+      BuildReshape(node, cctx->OpForInput(node, 0), cctx->GetSizeOpValues());
   cctx->AddNodeOp(node, xla_output);
 }
 
@@ -624,7 +633,7 @@ CreateTranslationHandlers() {
   (*t)[at::aten::log_softmax] = TranslateLogSoftmax;
   (*t)[at::aten::_log_softmax_backward_data] = TranslateLogSoftmaxBackward;
   (*t)[at::aten::reshape] = TranslateReshape;
-  (*t)[at::aten::view] = TranslateReshape;
+  (*t)[at::aten::view] = TranslateView;
   (*t)[at::aten::expand] = TranslateExpand;
   (*t)[at::aten::stack] = TranslateStack;
   (*t)[at::aten::cat] = TranslateCat;
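The registration change above splits the two operators across distinct handlers instead of funneling both through TranslateReshape: aten::view keeps the simple constant-size path, while aten::reshape goes through the size-tracking lookup. A rough sketch of the dispatch-table pattern this relies on, assuming placeholder types (OpKind, TranslateFn, and the string keys below are illustrative, not the torch_xla types):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Minimal sketch of the dispatch pattern behind CreateTranslationHandlers():
    // a map from operator kind to a translation callback, so aten::view and
    // aten::reshape can route to different handlers.
    using OpKind = std::string;
    using TranslateFn = std::function<void(const std::string& node_repr)>;

    std::unordered_map<OpKind, TranslateFn> CreateHandlers() {
      std::unordered_map<OpKind, TranslateFn> handlers;
      handlers["aten::view"] = [](const std::string& node_repr) {
        std::cout << "TranslateView: " << node_repr << "\n";
      };
      handlers["aten::reshape"] = [](const std::string& node_repr) {
        std::cout << "TranslateReshape: " << node_repr << "\n";
      };
      return handlers;
    }

    int main() {
      auto handlers = CreateHandlers();
      handlers.at("aten::view")("view(2, 3)");            // constant-size path
      handlers.at("aten::reshape")("reshape(y.size())");  // size-tracking path
      return 0;
    }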