diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 1b046593991f..55468051768a 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -27,7 +27,7 @@ std::vector XlaCreateTensorList( std::vector defined_writeable; std::vector tensor_is_defined(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { - auto& tensor = tensors[i]; + const at::Tensor& tensor = tensors[i]; if (!tensor.defined()) { XLA_CHECK(writeable == nullptr || !(*writeable)[i]) << "Trying to write to an undefined tensor"; diff --git a/torch_xla/csrc/batch_norm.cpp b/torch_xla/csrc/batch_norm.cpp index 3e589f1fc310..bfb1d02ac9de 100644 --- a/torch_xla/csrc/batch_norm.cpp +++ b/torch_xla/csrc/batch_norm.cpp @@ -7,22 +7,23 @@ BatchNormOutput BuildBatchNorm(const torch::jit::Node* node, const xla::XlaOp& input, const xla::XlaOp& weight, const xla::XlaOp& bias) { - auto builder = input.builder(); + xla::XlaBuilder* builder = input.builder(); xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); const float eps_value = node->get(at::attr::eps).value().to(); - const auto eps = + xla::XlaOp eps = XlaHelpers::ScalarValue(eps_value, input_shape.element_type(), builder); - const auto one = + xla::XlaOp one = XlaHelpers::ScalarValue(1, input_shape.element_type(), builder); - const auto half = + xla::XlaOp half = XlaHelpers::ScalarValue(0.5f, input_shape.element_type(), builder); - auto outputs = xla::BatchNormTraining(input, weight, bias, eps_value, 1); - auto output = xla::GetTupleElement(outputs, 0); - auto save_mean = xla::GetTupleElement(outputs, 1); - auto save_var = xla::GetTupleElement(outputs, 2); - auto save_invstd_eps = one / xla::Pow(save_var + eps, half); + xla::XlaOp outputs = + xla::BatchNormTraining(input, weight, bias, eps_value, 1); + xla::XlaOp output = xla::GetTupleElement(outputs, 0); + xla::XlaOp save_mean = xla::GetTupleElement(outputs, 1); + xla::XlaOp save_var = xla::GetTupleElement(outputs, 2); + xla::XlaOp save_invstd_eps = one / xla::Pow(save_var + eps, half); return {output, save_mean, save_invstd_eps}; } @@ -32,22 +33,22 @@ BatchNormGrads BuildBatchNormBackward(const torch::jit::Node* node, const xla::XlaOp& weight, const xla::XlaOp& save_mean, const xla::XlaOp& save_invstd_eps) { - auto builder = grad.builder(); + xla::XlaBuilder* builder = grad.builder(); xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); const float eps_value = node->get(at::attr::eps).value().to(); - const auto eps = + xla::XlaOp eps = XlaHelpers::ScalarValue(eps_value, input_shape.element_type(), builder); - const auto one = + xla::XlaOp one = XlaHelpers::ScalarValue(1, input_shape.element_type(), builder); - const auto two = + xla::XlaOp two = XlaHelpers::ScalarValue(2, input_shape.element_type(), builder); - const auto save_var = xla::Pow(one / save_invstd_eps, two) - eps; - const auto grads = xla::BatchNormGrad(input, weight, save_mean, save_var, + xla::XlaOp save_var = xla::Pow(one / save_invstd_eps, two) - eps; + xla::XlaOp grads = xla::BatchNormGrad(input, weight, save_mean, save_var, grad, eps_value, 1); - const auto grad_input = xla::GetTupleElement(grads, 0); - const auto grad_weight = xla::GetTupleElement(grads, 1); - const auto grad_bias = xla::GetTupleElement(grads, 2); + xla::XlaOp grad_input = xla::GetTupleElement(grads, 0); + xla::XlaOp grad_weight = xla::GetTupleElement(grads, 1); + xla::XlaOp grad_bias = xla::GetTupleElement(grads, 2); return {grad_input, grad_weight, grad_bias}; } diff --git a/torch_xla/csrc/convolution.cpp b/torch_xla/csrc/convolution.cpp index a19c15128627..7550bd41e59d 100644 --- a/torch_xla/csrc/convolution.cpp +++ b/torch_xla/csrc/convolution.cpp @@ -23,8 +23,8 @@ xla::XlaOp BuildThnnConv2dBackwardInput( input_size[2 + i] += 2 * padding_attr[i]; } tensorflow::TensorShape input_shape(input_size); - const auto filter = xla::Transpose(weight, {2, 3, 1, 0}); - auto builder = grad_output.builder(); + xla::XlaOp filter = xla::Transpose(weight, {2, 3, 1, 0}); + xla::XlaBuilder* builder = grad_output.builder(); const auto filter_size = XlaHelpers::SizesOfXlaOp(filter); tensorflow::TensorShape filter_shape(filter_size); tensorflow::TensorShape out_backprop_shape( @@ -35,7 +35,7 @@ xla::XlaOp BuildThnnConv2dBackwardInput( tensorflow::ConvBackpropDimensions dims; constexpr int num_spatial_dims = 2; std::vector dilations{1, 1, 1, 1}; - const auto status = ConvBackpropComputeDimensionsV2( + xla::Status status = ConvBackpropComputeDimensionsV2( "thnn_conv2d_backward", num_spatial_dims, input_shape, filter_shape, out_backprop_shape, dilations, strides, tensorflow::Padding::VALID, /*explicit_paddings=*/{}, tensorflow::TensorFormat::FORMAT_NCHW, &dims); @@ -87,7 +87,8 @@ xla::XlaOp BuildThnnConv2dBackwardInput( padding_config.add_dimensions(); } for (int i = 0; i < 2; ++i) { - auto* dims = padding_config.add_dimensions(); + xla::PaddingConfig::PaddingConfigDimension* dims = + padding_config.add_dimensions(); dims->set_edge_padding_low(-padding_attr[i]); dims->set_edge_padding_high(-padding_attr[i]); } @@ -135,7 +136,7 @@ xla::XlaOp BuildThnnConv2dBackwardWeight( tensorflow::ConvBackpropDimensions dims; constexpr int num_spatial_dims = 2; std::vector dilations{1, 1, 1, 1}; - const auto status = ConvBackpropComputeDimensionsV2( + xla::Status status = ConvBackpropComputeDimensionsV2( "thnn_conv2d_backward", num_spatial_dims, activations_shape, filter_shape, out_backprop_shape, dilations, strides, tensorflow::Padding::VALID, /*explicit_paddings=*/{}, tensorflow::TensorFormat::FORMAT_NCHW, &dims); @@ -215,12 +216,12 @@ xla::XlaOp BuildThnnConv2dBackwardWeight( } // Redo the initial input padding. - const auto padding_config = + xla::PaddingConfig padding_config = XlaHelpers::MakeXlaPaddingConfig(XlaHelpers::I64List(padding_attr)); - auto builder = grad_output.builder(); + xla::XlaBuilder* builder = grad_output.builder(); xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); - const auto padded_input = xla::Pad( + xla::XlaOp padded_input = xla::Pad( input, XlaHelpers::ScalarValue(0, input_shape.element_type(), builder), padding_config); @@ -290,14 +291,14 @@ xla::XlaOp BuildConvolutionBias( tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, const xla::PrecisionConfig::Precision conv_precision) { - const auto conv = + xla::XlaOp conv = BuildConvolution(input, kernel, stride, padding, conv_precision); auto broadcast_sizes = XlaHelpers::SizesOfXlaOp(conv); XLA_CHECK_EQ(broadcast_sizes.size(), 4); // Remove the channels dimension. broadcast_sizes.erase(broadcast_sizes.begin() + 1); // Make the bias match the output dimensions. - const auto bias_broadcast = + xla::XlaOp bias_broadcast = xla::Transpose(xla::Broadcast(bias, broadcast_sizes), {0, 3, 1, 2}); return conv + bias_broadcast; } @@ -320,13 +321,13 @@ Conv2DGrads BuildConv2dBackward( tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, const xla::PrecisionConfig::Precision conv_precision) { - const auto grad_input = BuildThnnConv2dBackwardInput( + xla::XlaOp grad_input = BuildThnnConv2dBackwardInput( grad_output, input, weight, stride, padding, conv_precision); - const auto grad_weight = BuildThnnConv2dBackwardWeight( + xla::XlaOp grad_weight = BuildThnnConv2dBackwardWeight( grad_output, input, weight, stride, padding, conv_precision); - auto builder = grad_output.builder(); + xla::XlaBuilder* builder = grad_output.builder(); xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); - const auto grad_bias = xla::Reduce( + xla::XlaOp grad_bias = xla::Reduce( grad_output, XlaHelpers::ScalarValue(0, input_shape.element_type(), builder), XlaHelpers::CreateAddComputation(input_shape.element_type()), {0, 2, 3}); diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index 41cbbb30e477..b3210bf8f528 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -42,7 +42,7 @@ std::vector GetCompleteShape( c10::optional incomplete_dim; int64_t incomplete_element_count = 1; for (size_t dim = 0; dim < output_sizes.size(); ++dim) { - const auto dim_size = output_sizes[dim]; + xla::int64 dim_size = output_sizes[dim]; if (dim_size < 0) { XLA_CHECK(!incomplete_dim) << "More than one incomplete dimension found: " << *incomplete_dim @@ -55,7 +55,7 @@ std::vector GetCompleteShape( if (!incomplete_dim) { return std::vector(output_sizes.begin(), output_sizes.end()); } - const auto total_element_count = + int64_t total_element_count = std::accumulate(input_sizes.begin(), input_sizes.end(), int64_t(1), std::multiplies()); XLA_CHECK_EQ(total_element_count % incomplete_element_count, 0) @@ -114,13 +114,13 @@ xla::XlaOp BuildExpand( for (size_t i = 0; i < output_sizes.size() - input_sizes.size(); ++i) { input_sizes.insert(input_sizes.begin(), 1); } - const auto implicit_reshape = xla::Reshape(input, input_sizes); + xla::XlaOp implicit_reshape = xla::Reshape(input, input_sizes); // Squeeze the trivial (of size 1) dimensions. std::vector non_singleton_dimensions; std::copy_if(input_sizes.begin(), input_sizes.end(), std::back_inserter(non_singleton_dimensions), [](const size_t dim_size) { return dim_size != 1; }); - const auto squeezed_input = + xla::XlaOp squeezed_input = xla::Reshape(implicit_reshape, non_singleton_dimensions); // Broadcast the squeezed tensor, the additional dimensions are to the left. std::vector broadcast_sizes; @@ -129,7 +129,7 @@ xla::XlaOp BuildExpand( broadcast_sizes.push_back(output_sizes[i]); } } - const auto broadcast = xla::Broadcast(squeezed_input, broadcast_sizes); + xla::XlaOp broadcast = xla::Broadcast(squeezed_input, broadcast_sizes); // Bring the dimensions added by broadcast where the trivial dimensions were. std::vector reshape_permutation; for (size_t i = 0; i < input_sizes.size(); ++i) { @@ -157,7 +157,7 @@ xla::XlaOp BuildStack( // Reshape inputs along the dim axis. for (size_t i = 0; i < stack_inputs.size(); ++i) { const auto stack_input = stack_inputs[i]; - const auto stack_input_op = node_op(stack_input); + xla::XlaOp stack_input_op = node_op(stack_input); auto reshaped_input_size = XlaHelpers::SizesOfXlaOp(stack_input_op); reshaped_input_size.insert(reshaped_input_size.begin() + dim, 1); reshaped_inputs.push_back( @@ -198,7 +198,7 @@ std::vector BuildChunk(const torch::jit::Node* node, std::vector splits(chunks); int64_t start_idx = 0; for (int64_t i = 0; i < chunks; ++i) { - const auto length = split_sizes[i]; + int64_t length = split_sizes[i]; splits[i] = SliceInDim(input, start_idx, start_idx + length, 1, dim); start_idx += length; } diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 768acdb901ce..709abd861db6 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -27,9 +27,9 @@ xla::XlaOp BuildArithmeticOp(const torch::jit::Node* node, xla::XlaOp BuildComparisonOp(const torch::jit::Node* node, const xla::XlaOp& operand) { - auto builder = operand.builder(); + xla::XlaBuilder* builder = operand.builder(); xla::Shape operand_shape = XlaHelpers::ShapeOfXlaOp(operand); - const auto xla_other = XlaHelpers::ScalarValue( + xla::XlaOp xla_other = XlaHelpers::ScalarValue( node->get(at::attr::other).value().to(), operand_shape.element_type(), builder); xla::XlaOp pred; @@ -46,15 +46,15 @@ xla::XlaOp BuildComparisonOp(const torch::jit::Node* node, xla::XlaOp BuildThreshold(const xla::XlaOp& input, const xla::XlaOp& output, const float threshold, const float value) { - auto builder = input.builder(); + xla::XlaBuilder* builder = input.builder(); xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); const auto input_sizes = XlaHelpers::ShapeSizes(input_shape); std::vector broadcast_sizes(input_sizes.begin(), input_sizes.end()); xla::Shape output_shape = XlaHelpers::ShapeOfXlaOp(output); - const auto xla_threshold = XlaHelpers::ScalarValue( + xla::XlaOp xla_threshold = XlaHelpers::ScalarValue( threshold, input_shape.element_type(), builder); - const auto xla_value = XlaHelpers::ScalarValue( + xla::XlaOp xla_value = XlaHelpers::ScalarValue( value, output_shape.element_type(), builder); return xla::Select(xla::Gt(input, xla_threshold), output, xla::Broadcast(xla_value, broadcast_sizes)); @@ -73,7 +73,7 @@ xla::XlaOp BuildTypeAs(const torch::jit::Node* node, const auto output_tensor_type = node_outputs[0]->type()->cast(); XLA_CHECK(output_tensor_type); - const auto target_type = XlaHelpers::MakeXlaPrimitiveType( + xla::PrimitiveType target_type = XlaHelpers::MakeXlaPrimitiveType( output_tensor_type->scalarType(), /*device=*/nullptr); return xla::ConvertElementType(operand, target_type); } diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index f8b571648674..ad62b37f0460 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -47,7 +47,8 @@ xla::PaddingConfig XlaHelpers::MakeXlaPaddingConfig( padding_config.add_dimensions(); } for (int i = 0; i < 2; ++i) { - auto* dims = padding_config.add_dimensions(); + xla::PaddingConfig::PaddingConfigDimension* dims = + padding_config.add_dimensions(); dims->set_edge_padding_low(padding[i]); dims->set_edge_padding_high(padding[i]); } @@ -56,9 +57,9 @@ xla::PaddingConfig XlaHelpers::MakeXlaPaddingConfig( xla::XlaComputation XlaHelpers::CreateAddComputation(xla::PrimitiveType type) { xla::XlaBuilder reduction_builder("xla_add_computation"); - const auto x = xla::Parameter(&reduction_builder, 0, + xla::XlaOp x = xla::Parameter(&reduction_builder, 0, xla::ShapeUtil::MakeShape(type, {}), "x"); - const auto y = xla::Parameter(&reduction_builder, 1, + xla::XlaOp y = xla::Parameter(&reduction_builder, 1, xla::ShapeUtil::MakeShape(type, {}), "y"); Add(x, y); return reduction_builder.Build().ConsumeValueOrDie(); diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 38f2130c4eca..c26a8651b683 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -34,7 +34,7 @@ class XlaHelpers { // Creates a XLA constant for the given scalar_value. template static xla::XlaOp ScalarValue(T scalar_value, xla::XlaBuilder* builder) { - const auto scalar_literal = xla::LiteralUtil::CreateR0(scalar_value); + xla::Literal scalar_literal = xla::LiteralUtil::CreateR0(scalar_value); return xla::ConstantLiteral(builder, scalar_literal); } @@ -65,7 +65,7 @@ class XlaHelpers { template static xla::XlaOp ScalarBroadcast(T scalar_value, const xla::Shape& shape, xla::XlaBuilder* builder) { - auto scalar_op = + xla::XlaOp scalar_op = ScalarValue(scalar_value, shape.element_type(), builder); return xla::Broadcast(scalar_op, ShapeSizes(shape)); } diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 1f971c8ccb37..d2ead25123c9 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -34,7 +34,7 @@ xla::int64 LoweringContext::AddResult(xla::XlaOp op) { xla::StatusOr LoweringContext::Build() { if (!root_tuple_.empty()) { - auto root = xla::Tuple(builder(), root_tuple_); + xla::XlaOp root = xla::Tuple(builder(), root_tuple_); return builder()->Build(root); } return builder()->Build(); diff --git a/torch_xla/csrc/module.cpp b/torch_xla/csrc/module.cpp index 42213eae3b38..6874a18d2fca 100644 --- a/torch_xla/csrc/module.cpp +++ b/torch_xla/csrc/module.cpp @@ -97,7 +97,7 @@ void XlaModule::Initialize(const TensorBatchVector& inputs) { TensorBatchVector::value_type replica_params; TensorBatchVector::value_type optimizable_replica_params; for (size_t j = 0; j < params_buffers_regather.size(); ++j) { - const auto& var_ref = + const torch::autograd::Variable& var_ref = torch::autograd::as_variable_ref(*params_buffers_regather[j]); replica_params.push_back( XLATensor::Create(var_ref, device, var_ref.requires_grad())); @@ -244,7 +244,7 @@ void XlaModule::backward(const TensorBatchVector& grad_outputs) { XLA_CHECK_GE(raw_output_index, input_vjps_real_outputs); XLA_CHECK_LT(raw_output_index - input_vjps_real_outputs, replica_captured_outputs.size()); - auto p = + XLATensor p = replica_captured_outputs[raw_output_index - input_vjps_real_outputs]; replica_raw_grad_outputs.push_back(p); if (i == 0) { @@ -397,8 +397,9 @@ xla::XlaComputation XlaModule::BuildFusedTrainComputation( xla::XlaBuilder b("XlaFusedComputation"); // Build the forward pass program without compiling it, the backward pass // needs to be called before finalizing it. - auto computation_in_outs = xla_fwd_impl.BuildComputationProgram( - forward_shapes, backward_size_op_values_, &b); + XlaComputationInOut computation_in_outs = + xla_fwd_impl.BuildComputationProgram(forward_shapes, + backward_size_op_values_, &b); // Take the XLA outputs from the forward pass and set them for the backward // call in the same order the standalone, unfused version takes its arguments. XLA_CHECK(!computation_in_outs.outputs.empty()); @@ -447,7 +448,7 @@ xla::XlaComputation XlaModule::BuildFusedTrainComputation( } // The arguments are set up correctly, call into the backward computation. XlaTranslator xla_bwd_impl(gradient_.df, GetPrecisionConfig()); - auto backward_computation = + xla::XlaComputation backward_computation = xla_bwd_impl .BuildComputation("XlaBackward", backward_shapes, backward_size_op_values_, @@ -497,8 +498,9 @@ XlaModule::TensorBatchVector XlaModule::RunUnfusedForward( } XlaTranslator xla_fwd_impl(gradient_.f, GetPrecisionConfig()); - auto forward_translation_result = xla_fwd_impl.BuildComputation( - "XlaForward", forward_shapes, backward_size_op_values_); + XlaTranslationResult forward_translation_result = + xla_fwd_impl.BuildComputation("XlaForward", forward_shapes, + backward_size_op_values_); backward_size_op_values_ = SetBackwardSizeOpValues( forward_translation_result.ret_size_op_values, gradient_); @@ -691,9 +693,9 @@ std::vector XlaModule::CommonDevicesForReplicas( xla::Shape XlaModule::GetResultShape(const xla::XlaComputation& computation, const TensorBatchVector& input_tensors) { auto devices = CommonDevicesForReplicas(input_tensors); - const auto program_shape = computation.GetProgramShape().ValueOrDie(); - const auto result_shape = program_shape.result(); - return MakeShapeWithDeviceLayout(result_shape, devices.front().hw_type); + xla::ProgramShape program_shape = computation.GetProgramShape().ValueOrDie(); + return MakeShapeWithDeviceLayout(program_shape.result(), + devices.front().hw_type); } } // namespace torch_xla diff --git a/torch_xla/csrc/nll_loss.cpp b/torch_xla/csrc/nll_loss.cpp index 31509290c1cb..a71355f6f78d 100644 --- a/torch_xla/csrc/nll_loss.cpp +++ b/torch_xla/csrc/nll_loss.cpp @@ -25,7 +25,7 @@ xla::XlaOp LabelsToOneHot(xla::XlaBuilder* builder, xla::int64 depth, int axis, std::iota(linspace_data.begin(), linspace_data.end(), 0); std::vector linspace_dims(output_dims, 1); linspace_dims[axis] = depth; - const auto linspace_xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla::Shape linspace_xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( xla::PrimitiveType::S64, linspace_dims); xla::BorrowingLiteral linspace_literal( reinterpret_cast(linspace_data.data()), linspace_xla_shape); @@ -74,8 +74,8 @@ xla::XlaOp BuildNllLoss(const xla::XlaOp& logits, const xla::XlaOp& labels) { xla::XlaOp BuildNllLossBackward(const xla::XlaOp& logits, const xla::XlaOp& labels) { const int kBatchDim = 0; - auto builder = logits.builder(); - const auto logits_shape = XlaHelpers::ShapeOfXlaOp(logits); + xla::XlaBuilder* builder = logits.builder(); + xla::Shape logits_shape = XlaHelpers::ShapeOfXlaOp(logits); xla::XlaOp one_hot_labels = LabelsToOneHot( /*builder=*/builder, /*depth=*/logits_shape.dimensions(1), @@ -85,7 +85,7 @@ xla::XlaOp BuildNllLossBackward(const xla::XlaOp& logits, XlaHelpers::ScalarValue(1, logits_shape.element_type(), builder), /*off_value=*/ XlaHelpers::ScalarValue(0, logits_shape.element_type(), builder)); - const auto batch = XlaHelpers::ScalarValue( + xla::XlaOp batch = XlaHelpers::ScalarValue( logits_shape.dimensions(kBatchDim), logits_shape.element_type(), builder); // Compute -one_hot_labels / batch. return xla::Neg(one_hot_labels) / batch; diff --git a/torch_xla/csrc/pooling.cpp b/torch_xla/csrc/pooling.cpp index be66e9486919..cb455ecf20f1 100644 --- a/torch_xla/csrc/pooling.cpp +++ b/torch_xla/csrc/pooling.cpp @@ -22,9 +22,9 @@ struct PoolingOpAttributes { xla::XlaComputation CreateGeComputation(xla::PrimitiveType type) { xla::XlaBuilder reduction_builder("xla_ge_computation"); - const auto x = xla::Parameter(&reduction_builder, 0, + xla::XlaOp x = xla::Parameter(&reduction_builder, 0, xla::ShapeUtil::MakeShape(type, {}), "x"); - const auto y = xla::Parameter(&reduction_builder, 1, + xla::XlaOp y = xla::Parameter(&reduction_builder, 1, xla::ShapeUtil::MakeShape(type, {}), "y"); xla::Ge(x, y); return reduction_builder.Build().ConsumeValueOrDie(); @@ -115,14 +115,14 @@ xla::XlaOp BuildMaxPool2d( tensorflow::gtl::ArraySlice kernel_size, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding) { - auto builder = input.builder(); + xla::XlaBuilder* builder = input.builder(); xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); - const auto init_value = + xla::Literal init_value = xla::LiteralUtil::MinValue(input_shape.element_type()); - const auto xla_init_value = xla::ConstantLiteral(builder, init_value); - const auto padding_config = XlaHelpers::MakeXlaPaddingConfig(padding); - const auto padded_input = xla::Pad(input, xla_init_value, padding_config); - const auto pooling_op_attributes = + xla::XlaOp xla_init_value = xla::ConstantLiteral(builder, init_value); + xla::PaddingConfig padding_config = XlaHelpers::MakeXlaPaddingConfig(padding); + xla::XlaOp padded_input = xla::Pad(input, xla_init_value, padding_config); + PoolingOpAttributes pooling_op_attributes = Pooling2DOpAttributes(/*kernel_size_attr=*/kernel_size, /*stride_attr=*/stride, /*padding_attr=*/padding); return xla::MaxPool( @@ -136,14 +136,14 @@ xla::XlaOp BuildMaxPool2d( xla::XlaOp BuildMaxPool2dBackward(const torch::jit::Node* node, const xla::XlaOp& out_backprop, const xla::XlaOp& input) { - auto builder = out_backprop.builder(); + xla::XlaBuilder* builder = out_backprop.builder(); xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); - const auto init_value = + xla::XlaOp init_value = XlaHelpers::ScalarValue(0, input_shape.element_type(), builder); - const auto select = CreateGeComputation(input_shape.element_type()); - const auto scatter = + xla::XlaComputation select = CreateGeComputation(input_shape.element_type()); + xla::XlaComputation scatter = XlaHelpers::CreateAddComputation(input_shape.element_type()); - const auto pooling_op_attributes = Pooling2DOpAttributes(node); + PoolingOpAttributes pooling_op_attributes = Pooling2DOpAttributes(node); std::vector> window_padding; window_padding.resize(2); window_padding.insert(window_padding.end(), @@ -185,7 +185,7 @@ xla::XlaOp BuildAvgPool2d( tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool count_include_pad) { - const auto pooling_op_attributes = + PoolingOpAttributes pooling_op_attributes = Pooling2DOpAttributes(/*kernel_size_attr=*/kernel_size, /*stride_attr=*/stride, /*padding_attr=*/padding); return xla::AvgPool( @@ -222,7 +222,7 @@ xla::XlaOp BuildAvgPool2dBackward( tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool count_include_pad) { - const auto pooling_op_attributes = + PoolingOpAttributes pooling_op_attributes = Pooling2DOpAttributes(/*kernel_size_attr=*/kernel_size, /*stride_attr=*/stride, /*padding_attr=*/padding); auto gradients_size = XlaHelpers::SizesOfXlaOp(input); diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index 68dd8f2bdf5e..1bf9fc61bf41 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -8,10 +8,9 @@ xla::XlaOp BuildSum(const torch::jit::Node* node, const xla::XlaOp& operand) { if (node->get(at::attr::keepdim).value()) { XLA_ERROR() << "Sum with keepdim set not supported yet"; } - auto builder = operand.builder(); xla::Shape operand_shape = XlaHelpers::ShapeOfXlaOp(operand); - const auto init_value = - XlaHelpers::ScalarValue(0, operand_shape.element_type(), builder); + xla::XlaOp init_value = XlaHelpers::ScalarValue( + 0, operand_shape.element_type(), operand.builder()); const auto dimensions_to_reduce = node->get>(at::attr::dim).value(); return xla::Reduce( diff --git a/torch_xla/csrc/size_ops.cpp b/torch_xla/csrc/size_ops.cpp index 4e0b29f246cc..769237652492 100644 --- a/torch_xla/csrc/size_ops.cpp +++ b/torch_xla/csrc/size_ops.cpp @@ -9,7 +9,7 @@ xla::XlaOp BuildSize(const torch::jit::Node* node, const xla::XlaOp& input, std::vector* size_op_result) { const auto shape_sizes = XlaHelpers::SizesOfXlaOp(input); *size_op_result = shape_sizes; - auto builder = input.builder(); + xla::XlaBuilder* builder = input.builder(); return xla::ConstantR1(builder, shape_sizes); } diff --git a/torch_xla/csrc/translator.cpp b/torch_xla/csrc/translator.cpp index a103e7f3b457..eb28e48ef483 100644 --- a/torch_xla/csrc/translator.cpp +++ b/torch_xla/csrc/translator.cpp @@ -225,7 +225,7 @@ void TranslateConvolutionBackward( const torch::jit::Node* node, ComputationContext* cctx, xla::PrecisionConfig::Precision conv_precision, xla::XlaBuilder* /*b*/) { XLA_CHECK_EQ(node->inputs().size(), 9); - const auto conv2d_grads = BuildConv2dBackward( + Conv2DGrads conv2d_grads = BuildConv2dBackward( node, cctx->OpForInput(node, 0), cctx->OpForInput(node, 1), cctx->OpForInput(node, 2), conv_precision); const auto node_outputs = node->outputs(); @@ -329,7 +329,7 @@ void TranslateSqrt(const torch::jit::Node* node, ComputationContext* cctx, xla::PrecisionConfig::Precision /*conv_precision*/, xla::XlaBuilder* /*b*/) { XLA_CHECK_EQ(node->inputs().size(), 1); - const auto xla_input = cctx->OpForInput(node, 0); + xla::XlaOp xla_input = cctx->OpForInput(node, 0); xla::XlaOp xla_output = xla::Sqrt(xla_input); cctx->AddNodeOp(node, xla_output); } @@ -338,7 +338,7 @@ void TranslateRsqrt(const torch::jit::Node* node, ComputationContext* cctx, xla::PrecisionConfig::Precision /*conv_precision*/, xla::XlaBuilder* /*b*/) { XLA_CHECK_EQ(node->inputs().size(), 1); - const auto xla_input = cctx->OpForInput(node, 0); + xla::XlaOp xla_input = cctx->OpForInput(node, 0); xla::XlaOp xla_output = xla::Rsqrt(xla_input); cctx->AddNodeOp(node, xla_output); } @@ -347,7 +347,7 @@ void TranslateNeg(const torch::jit::Node* node, ComputationContext* cctx, xla::PrecisionConfig::Precision /*conv_precision*/, xla::XlaBuilder* /*b*/) { XLA_CHECK_EQ(node->inputs().size(), 1); - const auto xla_input = cctx->OpForInput(node, 0); + xla::XlaOp xla_input = cctx->OpForInput(node, 0); xla::XlaOp xla_output = xla::Neg(xla_input); cctx->AddNodeOp(node, xla_output); } @@ -356,7 +356,7 @@ void TranslateTanh(const torch::jit::Node* node, ComputationContext* cctx, xla::PrecisionConfig::Precision /*conv_precision*/, xla::XlaBuilder* /*b*/) { XLA_CHECK_EQ(node->inputs().size(), 1); - const auto xla_input = cctx->OpForInput(node, 0); + xla::XlaOp xla_input = cctx->OpForInput(node, 0); xla::XlaOp xla_output = xla::Tanh(xla_input); cctx->AddNodeOp(node, xla_output); } @@ -365,9 +365,9 @@ void TranslateSigmoid(const torch::jit::Node* node, ComputationContext* cctx, xla::PrecisionConfig::Precision /*conv_precision*/, xla::XlaBuilder* b) { XLA_CHECK_EQ(node->inputs().size(), 1); - const auto xla_input = cctx->OpForInput(node, 0); + xla::XlaOp xla_input = cctx->OpForInput(node, 0); xla::Shape xla_input_shape = XlaHelpers::ShapeOfXlaOp(xla_input); - const auto half = + xla::XlaOp half = XlaHelpers::ScalarValue(0.5, xla_input_shape.element_type(), b); xla::XlaOp xla_output = half + half * xla::Tanh(half * xla_input); cctx->AddNodeOp(node, xla_output); @@ -377,7 +377,7 @@ void TranslateRelu(const torch::jit::Node* node, ComputationContext* cctx, xla::PrecisionConfig::Precision /*conv_precision*/, xla::XlaBuilder* b) { XLA_CHECK_EQ(node->inputs().size(), 1); - const auto xla_input = cctx->OpForInput(node, 0); + xla::XlaOp xla_input = cctx->OpForInput(node, 0); xla::XlaOp xla_output = BuildRelu(xla_input); cctx->AddNodeOp(node, xla_output); } @@ -479,7 +479,7 @@ void TranslateBatchNorm(const torch::jit::Node* node, ComputationContext* cctx, xla::PrecisionConfig::Precision /*conv_precision*/, xla::XlaBuilder* /*b*/) { XLA_CHECK_EQ(node->inputs().size(), 8); - const auto outputs = + BatchNormOutput outputs = BuildBatchNorm(node, cctx->OpForInput(node, 0), cctx->OpForInput(node, 1), cctx->OpForInput(node, 2)); const auto node_outputs = node->outputs(); @@ -500,7 +500,7 @@ void TranslateBatchNormBackward( xla::PrecisionConfig::Precision /*conv_precision*/, xla::XlaBuilder* /*b*/) { XLA_CHECK_EQ(node->inputs().size(), 10); - auto grads = + BatchNormGrads grads = BuildBatchNormBackward(node, cctx->OpForInput(node, 0), // grad_output cctx->OpForInput(node, 1), // input cctx->OpForInput(node, 2), // weight @@ -652,7 +652,7 @@ XlaTranslationResult XlaTranslator::BuildComputation( const XlaComputationInOut::SizeOpValues& param_size_op_values, const BuildOptions& options) const { xla::XlaBuilder b(name); - auto computation_program = + XlaComputationInOut computation_program = BuildComputationProgram(parameter_shapes, param_size_op_values, &b); if (options.output_transform) { for (size_t i = 0; i < computation_program.outputs.size(); ++i) { @@ -678,7 +678,7 @@ XlaComputationInOut XlaTranslator::BuildComputationProgram( ++parameter_number) { torch::jit::Value* graph_input = graph_inputs[parameter_number]; if (parameter_shapes[parameter_number].kind == ParameterKind::kGraphInput) { - auto param_no = cctx.GetInputsSize(); + size_t param_no = cctx.GetInputsSize(); const auto parameter_op = xla::Parameter(b, param_no, parameter_shapes[parameter_number].shape, "param_" + std::to_string(param_no));