From 8cf376fc022606b445a97b7bb4b3189166d9ca7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20=C5=9Euhan?= Date: Thu, 14 Feb 2019 17:20:03 -0800 Subject: [PATCH 1/2] Allow graph parameters for reshapes target size --- .torch_commit_id | 1 - torch_xla/csrc/data_ops.cpp | 33 +++++++++++++++++++++------------ torch_xla/csrc/data_ops.h | 6 ++++++ torch_xla/csrc/translator.cpp | 13 +++++++++++-- 4 files changed, 38 insertions(+), 15 deletions(-) delete mode 100644 .torch_commit_id diff --git a/.torch_commit_id b/.torch_commit_id deleted file mode 100644 index 2ca54e3374cb..000000000000 --- a/.torch_commit_id +++ /dev/null @@ -1 +0,0 @@ -19addc7eb0bbb0074a3a57b6598382aeaa2222c9 diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index dac8d6c2e261..e63fe63f1e9e 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -72,18 +72,8 @@ xla::XlaOp BuildView(const torch::jit::Node* node, const xla::XlaOp& input) { XLA_CHECK_EQ(node_inputs.size(), 2); const auto node_outputs = node->outputs(); XLA_CHECK_EQ(node_outputs.size(), 1); - // Try to use the second argument of the operator as the target shape. - std::vector output_sizes; - switch (node->kind()) { - case at::aten::view: - output_sizes = node->get>(at::attr::size).value(); - break; - case at::aten::reshape: - output_sizes = node->get>(at::attr::shape).value(); - break; - default: - XLA_ERROR() << "Unexpected node kind, must be view or reshape"; - } + std::vector output_sizes = + node->get>(at::attr::size).value(); return BuildView(input, XlaHelpers::I64List(output_sizes)); } @@ -95,6 +85,25 @@ xla::XlaOp BuildView( return xla::Reshape(input, complete_output_sizes); } +xla::XlaOp BuildReshape( + const torch::jit::Node* node, const xla::XlaOp& input, + const XlaComputationInOut::SizeOpValues& size_op_values_tracking) { + std::vector output_sizes; + if (node->hasAttribute(at::attr::shape)) { + output_sizes = XlaHelpers::I64List( + node->get>(at::attr::shape).value()); + } else { + const auto size_op_value_it = + size_op_values_tracking.find(node->input(1)->unique()); + XLA_CHECK(size_op_value_it != size_op_values_tracking.end()) + << "at::aten::reshape only allowed when second parameter is a " + "constant size: " + << *node; + output_sizes = size_op_value_it->second; + } + return BuildView(input, output_sizes); +} + xla::XlaOp SqueezeTrivialDimension(const xla::XlaOp& input, size_t dim) { auto input_sizes = XlaHelpers::SizesOfXlaOp(input); XLA_CHECK_LT(dim, input_sizes.size()); diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h index f46ba43dfb0f..a4f5d523620f 100644 --- a/torch_xla/csrc/data_ops.h +++ b/torch_xla/csrc/data_ops.h @@ -5,6 +5,7 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "torch/csrc/jit/ir.h" +#include "torch_xla/csrc/translator.h" // Collection of XLA lowerings for operations which only involve some form of // data movement and no computation. @@ -27,6 +28,11 @@ xla::XlaOp BuildView( const xla::XlaOp& input, tensorflow::gtl::ArraySlice output_sizes); +// Creates a tensor reshape operation. +xla::XlaOp BuildReshape( + const torch::jit::Node* node, const xla::XlaOp& input, + const XlaComputationInOut::SizeOpValues& size_op_values_tracking); + // Squeezes the given dimension if trivial (size 1), returns the unchanged input // otherwise. xla::XlaOp SqueezeTrivialDimension(const xla::XlaOp& input, size_t dim); diff --git a/torch_xla/csrc/translator.cpp b/torch_xla/csrc/translator.cpp index e9f1618a903d..a43a5ae51703 100644 --- a/torch_xla/csrc/translator.cpp +++ b/torch_xla/csrc/translator.cpp @@ -422,11 +422,20 @@ void TranslateLogSoftmaxBackward( cctx->AddNodeOp(node, xla_output); } +void TranslateView(const torch::jit::Node* node, ComputationContext* cctx, + xla::PrecisionConfig::Precision /*conv_precision*/, + xla::XlaBuilder* /*b*/) { + XLA_CHECK_EQ(node->inputs().size(), 2); + xla::XlaOp xla_output = BuildView(node, cctx->OpForInput(node, 0)); + cctx->AddNodeOp(node, xla_output); +} + void TranslateReshape(const torch::jit::Node* node, ComputationContext* cctx, xla::PrecisionConfig::Precision /*conv_precision*/, xla::XlaBuilder* /*b*/) { XLA_CHECK_EQ(node->inputs().size(), 2); - xla::XlaOp xla_output = BuildView(node, cctx->OpForInput(node, 0)); + xla::XlaOp xla_output = + BuildReshape(node, cctx->OpForInput(node, 0), cctx->GetSizeOpValues()); cctx->AddNodeOp(node, xla_output); } @@ -624,7 +633,7 @@ CreateTranslationHandlers() { (*t)[at::aten::log_softmax] = TranslateLogSoftmax; (*t)[at::aten::_log_softmax_backward_data] = TranslateLogSoftmaxBackward; (*t)[at::aten::reshape] = TranslateReshape; - (*t)[at::aten::view] = TranslateReshape; + (*t)[at::aten::view] = TranslateView; (*t)[at::aten::expand] = TranslateExpand; (*t)[at::aten::stack] = TranslateStack; (*t)[at::aten::cat] = TranslateCat; From e2b4cb902b7d731305760d5f98cbe4c35256bec8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20=C5=9Euhan?= Date: Thu, 14 Feb 2019 17:42:33 -0800 Subject: [PATCH 2/2] Don't depend on translator header in data_ops header --- torch_xla/csrc/data_ops.cpp | 3 ++- torch_xla/csrc/data_ops.h | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index e63fe63f1e9e..493dcdf8ac59 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -87,7 +87,8 @@ xla::XlaOp BuildView( xla::XlaOp BuildReshape( const torch::jit::Node* node, const xla::XlaOp& input, - const XlaComputationInOut::SizeOpValues& size_op_values_tracking) { + const std::unordered_map>& + size_op_values_tracking) { std::vector output_sizes; if (node->hasAttribute(at::attr::shape)) { output_sizes = XlaHelpers::I64List( diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h index a4f5d523620f..c2892ceb944b 100644 --- a/torch_xla/csrc/data_ops.h +++ b/torch_xla/csrc/data_ops.h @@ -5,7 +5,6 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "torch/csrc/jit/ir.h" -#include "torch_xla/csrc/translator.h" // Collection of XLA lowerings for operations which only involve some form of // data movement and no computation. @@ -31,7 +30,8 @@ xla::XlaOp BuildView( // Creates a tensor reshape operation. xla::XlaOp BuildReshape( const torch::jit::Node* node, const xla::XlaOp& input, - const XlaComputationInOut::SizeOpValues& size_op_values_tracking); + const std::unordered_map>& + size_op_values_tracking); // Squeezes the given dimension if trivial (size 1), returns the unchanged input // otherwise.