diff --git a/.torch_commit_id b/.torch_commit_id
deleted file mode 100644
index 2ca54e3374cb..000000000000
--- a/.torch_commit_id
+++ /dev/null
@@ -1 +0,0 @@
-19addc7eb0bbb0074a3a57b6598382aeaa2222c9
diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp
index dac8d6c2e261..493dcdf8ac59 100644
--- a/torch_xla/csrc/data_ops.cpp
+++ b/torch_xla/csrc/data_ops.cpp
@@ -72,18 +72,8 @@ xla::XlaOp BuildView(const torch::jit::Node* node, const xla::XlaOp& input) {
   XLA_CHECK_EQ(node_inputs.size(), 2);
   const auto node_outputs = node->outputs();
   XLA_CHECK_EQ(node_outputs.size(), 1);
-  // Try to use the second argument of the operator as the target shape.
-  std::vector<int64_t> output_sizes;
-  switch (node->kind()) {
-    case at::aten::view:
-      output_sizes = node->get<std::vector<int64_t>>(at::attr::size).value();
-      break;
-    case at::aten::reshape:
-      output_sizes = node->get<std::vector<int64_t>>(at::attr::shape).value();
-      break;
-    default:
-      XLA_ERROR() << "Unexpected node kind, must be view or reshape";
-  }
+  std::vector<int64_t> output_sizes =
+      node->get<std::vector<int64_t>>(at::attr::size).value();
   return BuildView(input, XlaHelpers::I64List(output_sizes));
 }
 
@@ -95,6 +85,26 @@ xla::XlaOp BuildView(
   return xla::Reshape(input, complete_output_sizes);
 }
 
+xla::XlaOp BuildReshape(
+    const torch::jit::Node* node, const xla::XlaOp& input,
+    const std::unordered_map<size_t, std::vector<xla::int64>>&
+        size_op_values_tracking) {
+  std::vector<xla::int64> output_sizes;
+  if (node->hasAttribute(at::attr::shape)) {
+    output_sizes = XlaHelpers::I64List(
+        node->get<std::vector<int64_t>>(at::attr::shape).value());
+  } else {
+    const auto size_op_value_it =
+        size_op_values_tracking.find(node->input(1)->unique());
+    XLA_CHECK(size_op_value_it != size_op_values_tracking.end())
+        << "at::aten::reshape only allowed when second parameter is a "
+           "constant size: "
+        << *node;
+    output_sizes = size_op_value_it->second;
+  }
+  return BuildView(input, output_sizes);
+}
+
 xla::XlaOp SqueezeTrivialDimension(const xla::XlaOp& input, size_t dim) {
   auto input_sizes = XlaHelpers::SizesOfXlaOp(input);
   XLA_CHECK_LT(dim, input_sizes.size());
diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h
index f46ba43dfb0f..c2892ceb944b 100644
--- a/torch_xla/csrc/data_ops.h
+++ b/torch_xla/csrc/data_ops.h
@@ -27,6 +27,12 @@ xla::XlaOp BuildView(
     const xla::XlaOp& input,
     tensorflow::gtl::ArraySlice<const xla::int64> output_sizes);
 
+// Creates a tensor reshape operation.
+xla::XlaOp BuildReshape(
+    const torch::jit::Node* node, const xla::XlaOp& input,
+    const std::unordered_map<size_t, std::vector<xla::int64>>&
+        size_op_values_tracking);
+
 // Squeezes the given dimension if trivial (size 1), returns the unchanged input
 // otherwise.
 xla::XlaOp SqueezeTrivialDimension(const xla::XlaOp& input, size_t dim);
diff --git a/torch_xla/csrc/translator.cpp b/torch_xla/csrc/translator.cpp
index e9f1618a903d..a43a5ae51703 100644
--- a/torch_xla/csrc/translator.cpp
+++ b/torch_xla/csrc/translator.cpp
@@ -422,11 +422,20 @@ void TranslateLogSoftmaxBackward(
   cctx->AddNodeOp(node, xla_output);
 }
 
+void TranslateView(const torch::jit::Node* node, ComputationContext* cctx,
+                   xla::PrecisionConfig::Precision /*conv_precision*/,
+                   xla::XlaBuilder* /*b*/) {
+  XLA_CHECK_EQ(node->inputs().size(), 2);
+  xla::XlaOp xla_output = BuildView(node, cctx->OpForInput(node, 0));
+  cctx->AddNodeOp(node, xla_output);
+}
+
 void TranslateReshape(const torch::jit::Node* node, ComputationContext* cctx,
                       xla::PrecisionConfig::Precision /*conv_precision*/,
                       xla::XlaBuilder* /*b*/) {
   XLA_CHECK_EQ(node->inputs().size(), 2);
-  xla::XlaOp xla_output = BuildView(node, cctx->OpForInput(node, 0));
+  xla::XlaOp xla_output =
+      BuildReshape(node, cctx->OpForInput(node, 0), cctx->GetSizeOpValues());
   cctx->AddNodeOp(node, xla_output);
 }
 
@@ -624,7 +633,7 @@ CreateTranslationHandlers() {
   (*t)[at::aten::log_softmax] = TranslateLogSoftmax;
   (*t)[at::aten::_log_softmax_backward_data] = TranslateLogSoftmaxBackward;
   (*t)[at::aten::reshape] = TranslateReshape;
-  (*t)[at::aten::view] = TranslateReshape;
+  (*t)[at::aten::view] = TranslateView;
   (*t)[at::aten::expand] = TranslateExpand;
   (*t)[at::aten::stack] = TranslateStack;
   (*t)[at::aten::cat] = TranslateCat;
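
Note on the change above: `TranslateReshape` now resolves the target shape from the translator's tracked `size()` results (`cctx->GetSizeOpValues()`) whenever the node carries no constant `shape` attribute, while `at::aten::view` keeps the purely static path via the new `TranslateView`. The standalone C++ sketch below mirrors that dispatch under stated assumptions; it is not the torch_xla API, and every name in it (`Node`, `SizeOpValues`, `ResolveReshapeSizes`) is hypothetical.

// Minimal sketch of the dispatch idea: a reshape's target shape comes
// either from a constant attribute or from a previously recorded size()
// result, keyed by the unique id of the value that produced it (the role
// played by size_op_values_tracking in the patch). All names illustrative.
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <unordered_map>
#include <vector>

using Shape = std::vector<int64_t>;
// Maps the unique id of a size() op's output to its concrete value.
using SizeOpValues = std::unordered_map<size_t, Shape>;

struct Node {
  std::optional<Shape> shape_attr;  // set when the shape is a constant
  size_t size_input_id = 0;         // id of the size() input otherwise
};

Shape ResolveReshapeSizes(const Node& node, const SizeOpValues& tracking) {
  if (node.shape_attr) {
    return *node.shape_attr;  // static shape: the same path view takes
  }
  auto it = tracking.find(node.size_input_id);
  if (it == tracking.end()) {
    // Mirrors the XLA_CHECK in BuildReshape: without a tracked value the
    // shape is genuinely dynamic and cannot be lowered.
    throw std::runtime_error(
        "reshape only allowed when second parameter is a constant size");
  }
  return it->second;  // shape recovered from the tracked size() op
}

int main() {
  SizeOpValues tracking{{42, {2, 8}}};      // value id 42 produced size [2, 8]
  Node constant{Shape{4, 4}, 0};            // shape given as an attribute
  Node dynamic{std::nullopt, 42};           // shape comes from a size() op
  for (const Node* n : {&constant, &dynamic}) {
    for (int64_t d : ResolveReshapeSizes(*n, tracking)) std::cout << d << ' ';
    std::cout << '\n';
  }
  return 0;
}

Keeping the lookup keyed by the producing value's id means the translator only accepts shapes it has already seen materialize as constants, which is why the patch can reject truly dynamic reshapes with a hard check instead of emitting an unloweredable op.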