2 changes: 1 addition & 1 deletion torch_xla/csrc/aten_xla_bridge.cpp
@@ -27,7 +27,7 @@ std::vector<at::Tensor> XlaCreateTensorList(
   std::vector<bool> defined_writeable;
   std::vector<bool> tensor_is_defined(tensors.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
-    auto& tensor = tensors[i];
+    const at::Tensor& tensor = tensors[i];
     if (!tensor.defined()) {
       XLA_CHECK(writeable == nullptr || !(*writeable)[i])
           << "Trying to write to an undefined tensor";
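
The pattern applied throughout this diff, spelling out the type wherever the right-hand side does not make it obvious, matters most when deduction yields a surprising type. A minimal, self-contained illustration (not from this PR) using std::vector<bool>, whose operator[] returns a proxy object rather than a bool:

  #include <cassert>
  #include <vector>

  int main() {
    std::vector<bool> flags = {true, false};
    // `auto` deduces std::vector<bool>::reference, a proxy that aliases
    // the container, not an independent bool copy.
    auto proxy = flags[0];
    bool copy = flags[0];  // Spelling out the type forces a real conversion.
    flags[0] = false;
    assert(copy == true);    // The copy is unaffected.
    assert(proxy == false);  // The proxy observes the mutation.
    return 0;
  }

Writing const at::Tensor& above serves the same goal: the reader sees the element type without chasing the declaration of tensors.
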
37 changes: 19 additions & 18 deletions torch_xla/csrc/batch_norm.cpp
@@ -7,22 +7,23 @@ BatchNormOutput BuildBatchNorm(const torch::jit::Node* node,
                                const xla::XlaOp& input,
                                const xla::XlaOp& weight,
                                const xla::XlaOp& bias) {
-  auto builder = input.builder();
+  xla::XlaBuilder* builder = input.builder();
   xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
   const float eps_value =
       node->get<at::Scalar>(at::attr::eps).value().to<float>();
-  const auto eps =
+  xla::XlaOp eps =
       XlaHelpers::ScalarValue(eps_value, input_shape.element_type(), builder);
-  const auto one =
+  xla::XlaOp one =
       XlaHelpers::ScalarValue<float>(1, input_shape.element_type(), builder);
-  const auto half =
+  xla::XlaOp half =
       XlaHelpers::ScalarValue<float>(0.5f, input_shape.element_type(), builder);

-  auto outputs = xla::BatchNormTraining(input, weight, bias, eps_value, 1);
-  auto output = xla::GetTupleElement(outputs, 0);
-  auto save_mean = xla::GetTupleElement(outputs, 1);
-  auto save_var = xla::GetTupleElement(outputs, 2);
-  auto save_invstd_eps = one / xla::Pow(save_var + eps, half);
+  xla::XlaOp outputs =
+      xla::BatchNormTraining(input, weight, bias, eps_value, 1);
+  xla::XlaOp output = xla::GetTupleElement(outputs, 0);
+  xla::XlaOp save_mean = xla::GetTupleElement(outputs, 1);
+  xla::XlaOp save_var = xla::GetTupleElement(outputs, 2);
+  xla::XlaOp save_invstd_eps = one / xla::Pow(save_var + eps, half);
   return {output, save_mean, save_invstd_eps};
 }

@@ -32,22 +33,22 @@ BatchNormGrads BuildBatchNormBackward(const torch::jit::Node* node,
                                       const xla::XlaOp& weight,
                                       const xla::XlaOp& save_mean,
                                       const xla::XlaOp& save_invstd_eps) {
-  auto builder = grad.builder();
+  xla::XlaBuilder* builder = grad.builder();
   xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
   const float eps_value =
       node->get<at::Scalar>(at::attr::eps).value().to<float>();
-  const auto eps =
+  xla::XlaOp eps =
       XlaHelpers::ScalarValue(eps_value, input_shape.element_type(), builder);
-  const auto one =
+  xla::XlaOp one =
       XlaHelpers::ScalarValue<float>(1, input_shape.element_type(), builder);
-  const auto two =
+  xla::XlaOp two =
       XlaHelpers::ScalarValue<float>(2, input_shape.element_type(), builder);
-  const auto save_var = xla::Pow(one / save_invstd_eps, two) - eps;
-  const auto grads = xla::BatchNormGrad(input, weight, save_mean, save_var,
+  xla::XlaOp save_var = xla::Pow(one / save_invstd_eps, two) - eps;
+  xla::XlaOp grads = xla::BatchNormGrad(input, weight, save_mean, save_var,
                                         grad, eps_value, 1);
-  const auto grad_input = xla::GetTupleElement(grads, 0);
-  const auto grad_weight = xla::GetTupleElement(grads, 1);
-  const auto grad_bias = xla::GetTupleElement(grads, 2);
+  xla::XlaOp grad_input = xla::GetTupleElement(grads, 0);
+  xla::XlaOp grad_weight = xla::GetTupleElement(grads, 1);
+  xla::XlaOp grad_bias = xla::GetTupleElement(grads, 2);
   return {grad_input, grad_weight, grad_bias};
 }
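
A note on the algebra these two hunks carry: the forward pass saves 1 / sqrt(var + eps) rather than the variance, and the backward pass inverts that before calling xla::BatchNormGrad. A minimal sketch of the round trip with plain floats (illustrative values, not part of the PR):

  #include <cassert>
  #include <cmath>

  int main() {
    const float eps = 1e-5f;
    const float var = 0.25f;  // Batch variance from the forward pass.
    // Forward: save_invstd_eps = 1 / (var + eps)^0.5.
    float save_invstd_eps = 1.0f / std::pow(var + eps, 0.5f);
    // Backward: recover var as (1 / save_invstd_eps)^2 - eps.
    float recovered_var = std::pow(1.0f / save_invstd_eps, 2.0f) - eps;
    assert(std::abs(recovered_var - var) < 1e-6f);
    return 0;
  }
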
29 changes: 15 additions & 14 deletions torch_xla/csrc/convolution.cpp
@@ -23,8 +23,8 @@ xla::XlaOp BuildThnnConv2dBackwardInput(
     input_size[2 + i] += 2 * padding_attr[i];
   }
   tensorflow::TensorShape input_shape(input_size);
-  const auto filter = xla::Transpose(weight, {2, 3, 1, 0});
-  auto builder = grad_output.builder();
+  xla::XlaOp filter = xla::Transpose(weight, {2, 3, 1, 0});
+  xla::XlaBuilder* builder = grad_output.builder();
   const auto filter_size = XlaHelpers::SizesOfXlaOp(filter);
   tensorflow::TensorShape filter_shape(filter_size);
   tensorflow::TensorShape out_backprop_shape(
@@ -35,7 +35,7 @@ xla::XlaOp BuildThnnConv2dBackwardInput(
   tensorflow::ConvBackpropDimensions dims;
   constexpr int num_spatial_dims = 2;
   std::vector<int> dilations{1, 1, 1, 1};
-  const auto status = ConvBackpropComputeDimensionsV2(
+  xla::Status status = ConvBackpropComputeDimensionsV2(
       "thnn_conv2d_backward", num_spatial_dims, input_shape, filter_shape,
       out_backprop_shape, dilations, strides, tensorflow::Padding::VALID,
       /*explicit_paddings=*/{}, tensorflow::TensorFormat::FORMAT_NCHW, &dims);
@@ -87,7 +87,8 @@ xla::XlaOp BuildThnnConv2dBackwardInput(
     padding_config.add_dimensions();
   }
   for (int i = 0; i < 2; ++i) {
-    auto* dims = padding_config.add_dimensions();
+    xla::PaddingConfig::PaddingConfigDimension* dims =
+        padding_config.add_dimensions();
     dims->set_edge_padding_low(-padding_attr[i]);
     dims->set_edge_padding_high(-padding_attr[i]);
   }
@@ -135,7 +136,7 @@ xla::XlaOp BuildThnnConv2dBackwardWeight(
   tensorflow::ConvBackpropDimensions dims;
   constexpr int num_spatial_dims = 2;
   std::vector<int> dilations{1, 1, 1, 1};
-  const auto status = ConvBackpropComputeDimensionsV2(
+  xla::Status status = ConvBackpropComputeDimensionsV2(
      "thnn_conv2d_backward", num_spatial_dims, activations_shape, filter_shape,
       out_backprop_shape, dilations, strides, tensorflow::Padding::VALID,
       /*explicit_paddings=*/{}, tensorflow::TensorFormat::FORMAT_NCHW, &dims);
@@ -215,12 +216,12 @@ xla::XlaOp BuildThnnConv2dBackwardWeight(
   }

   // Redo the initial input padding.
-  const auto padding_config =
+  xla::PaddingConfig padding_config =
       XlaHelpers::MakeXlaPaddingConfig(XlaHelpers::I64List(padding_attr));

-  auto builder = grad_output.builder();
+  xla::XlaBuilder* builder = grad_output.builder();
   xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
-  const auto padded_input = xla::Pad(
+  xla::XlaOp padded_input = xla::Pad(
       input,
       XlaHelpers::ScalarValue<float>(0, input_shape.element_type(), builder),
       padding_config);
@@ -290,14 +291,14 @@ xla::XlaOp BuildConvolutionBias(
     tensorflow::gtl::ArraySlice<const xla::int64> stride,
     tensorflow::gtl::ArraySlice<const xla::int64> padding,
     const xla::PrecisionConfig::Precision conv_precision) {
-  const auto conv =
+  xla::XlaOp conv =
       BuildConvolution(input, kernel, stride, padding, conv_precision);
   auto broadcast_sizes = XlaHelpers::SizesOfXlaOp(conv);
   XLA_CHECK_EQ(broadcast_sizes.size(), 4);
   // Remove the channels dimension.
   broadcast_sizes.erase(broadcast_sizes.begin() + 1);
   // Make the bias match the output dimensions.
-  const auto bias_broadcast =
+  xla::XlaOp bias_broadcast =
       xla::Transpose(xla::Broadcast(bias, broadcast_sizes), {0, 3, 1, 2});
   return conv + bias_broadcast;
 }
@@ -320,13 +321,13 @@ Conv2DGrads BuildConv2dBackward(
     tensorflow::gtl::ArraySlice<const xla::int64> stride,
     tensorflow::gtl::ArraySlice<const xla::int64> padding,
     const xla::PrecisionConfig::Precision conv_precision) {
-  const auto grad_input = BuildThnnConv2dBackwardInput(
+  xla::XlaOp grad_input = BuildThnnConv2dBackwardInput(
       grad_output, input, weight, stride, padding, conv_precision);
-  const auto grad_weight = BuildThnnConv2dBackwardWeight(
+  xla::XlaOp grad_weight = BuildThnnConv2dBackwardWeight(
       grad_output, input, weight, stride, padding, conv_precision);
-  auto builder = grad_output.builder();
+  xla::XlaBuilder* builder = grad_output.builder();
   xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
-  const auto grad_bias = xla::Reduce(
+  xla::XlaOp grad_bias = xla::Reduce(
       grad_output,
       XlaHelpers::ScalarValue<float>(0, input_shape.element_type(), builder),
       XlaHelpers::CreateAddComputation(input_shape.element_type()), {0, 2, 3});
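
The grad_bias computation above reduces grad_output over dimensions {0, 2, 3} of the NCHW layout, i.e. it sums over batch, height and width and keeps one scalar per channel. A plain-C++ sketch of the same reduction (sizes are illustrative, not from the PR):

  #include <cassert>
  #include <vector>

  int main() {
    const int N = 2, C = 3, H = 4, W = 4;
    std::vector<float> grad_output(N * C * H * W, 1.0f);  // NCHW layout.
    std::vector<float> grad_bias(C, 0.0f);
    // Sum over batch, height and width; one accumulator per channel.
    for (int n = 0; n < N; ++n)
      for (int c = 0; c < C; ++c)
        for (int h = 0; h < H; ++h)
          for (int w = 0; w < W; ++w)
            grad_bias[c] += grad_output[((n * C + c) * H + h) * W + w];
    assert(grad_bias[0] == N * H * W);  // 32 with all-ones gradients.
    return 0;
  }
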
14 changes: 7 additions & 7 deletions torch_xla/csrc/data_ops.cpp
@@ -42,7 +42,7 @@ std::vector<xla::int64> GetCompleteShape(
   c10::optional<size_t> incomplete_dim;
   int64_t incomplete_element_count = 1;
   for (size_t dim = 0; dim < output_sizes.size(); ++dim) {
-    const auto dim_size = output_sizes[dim];
+    xla::int64 dim_size = output_sizes[dim];
     if (dim_size < 0) {
       XLA_CHECK(!incomplete_dim)
           << "More than one incomplete dimension found: " << *incomplete_dim
@@ -55,7 +55,7 @@ std::vector<xla::int64> GetCompleteShape(
   if (!incomplete_dim) {
     return std::vector<xla::int64>(output_sizes.begin(), output_sizes.end());
   }
-  const auto total_element_count =
+  int64_t total_element_count =
       std::accumulate(input_sizes.begin(), input_sizes.end(), int64_t(1),
                       std::multiplies<int64_t>());
   XLA_CHECK_EQ(total_element_count % incomplete_element_count, 0)
@@ -114,13 +114,13 @@ xla::XlaOp BuildExpand(
   for (size_t i = 0; i < output_sizes.size() - input_sizes.size(); ++i) {
     input_sizes.insert(input_sizes.begin(), 1);
   }
-  const auto implicit_reshape = xla::Reshape(input, input_sizes);
+  xla::XlaOp implicit_reshape = xla::Reshape(input, input_sizes);
   // Squeeze the trivial (of size 1) dimensions.
   std::vector<xla::int64> non_singleton_dimensions;
   std::copy_if(input_sizes.begin(), input_sizes.end(),
                std::back_inserter(non_singleton_dimensions),
                [](const size_t dim_size) { return dim_size != 1; });
-  const auto squeezed_input =
+  xla::XlaOp squeezed_input =
       xla::Reshape(implicit_reshape, non_singleton_dimensions);
   // Broadcast the squeezed tensor, the additional dimensions are to the left.
   std::vector<xla::int64> broadcast_sizes;
@@ -129,7 +129,7 @@ xla::XlaOp BuildExpand(
       broadcast_sizes.push_back(output_sizes[i]);
     }
   }
-  const auto broadcast = xla::Broadcast(squeezed_input, broadcast_sizes);
+  xla::XlaOp broadcast = xla::Broadcast(squeezed_input, broadcast_sizes);
   // Bring the dimensions added by broadcast where the trivial dimensions were.
   std::vector<xla::int64> reshape_permutation;
   for (size_t i = 0; i < input_sizes.size(); ++i) {
@@ -157,7 +157,7 @@ xla::XlaOp BuildStack(
   // Reshape inputs along the dim axis.
   for (size_t i = 0; i < stack_inputs.size(); ++i) {
     const auto stack_input = stack_inputs[i];
-    const auto stack_input_op = node_op(stack_input);
+    xla::XlaOp stack_input_op = node_op(stack_input);
     auto reshaped_input_size = XlaHelpers::SizesOfXlaOp(stack_input_op);
     reshaped_input_size.insert(reshaped_input_size.begin() + dim, 1);
     reshaped_inputs.push_back(
@@ -198,7 +198,7 @@ std::vector<xla::XlaOp> BuildChunk(const torch::jit::Node* node,
   std::vector<xla::XlaOp> splits(chunks);
   int64_t start_idx = 0;
   for (int64_t i = 0; i < chunks; ++i) {
-    const auto length = split_sizes[i];
+    int64_t length = split_sizes[i];
     splits[i] = SliceInDim(input, start_idx, start_idx + length, 1, dim);
     start_idx += length;
   }
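
GetCompleteShape, touched in the first two hunks, infers at most one negative ("incomplete") dimension the way torch.reshape treats -1: the missing extent is the total element count divided by the product of the known extents. A standalone sketch under that reading (plain int64_t stands in for xla::int64):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Infers a single -1 entry in `output_sizes` from the element count of
  // `input_sizes`, mirroring the logic in GetCompleteShape.
  std::vector<int64_t> CompleteShape(const std::vector<int64_t>& input_sizes,
                                     std::vector<int64_t> output_sizes) {
    int64_t total = 1;
    for (int64_t s : input_sizes) total *= s;
    int64_t known = 1;
    int incomplete = -1;
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      if (output_sizes[i] < 0) incomplete = static_cast<int>(i);
      else known *= output_sizes[i];
    }
    if (incomplete >= 0) output_sizes[incomplete] = total / known;
    return output_sizes;
  }

  int main() {
    std::vector<int64_t> completed = CompleteShape({2, 3, 4}, {6, -1});
    assert(completed[1] == 4);  // 24 elements / 6 = 4.
    return 0;
  }
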
12 changes: 6 additions & 6 deletions torch_xla/csrc/elementwise.cpp
@@ -27,9 +27,9 @@ xla::XlaOp BuildArithmeticOp(const torch::jit::Node* node,

 xla::XlaOp BuildComparisonOp(const torch::jit::Node* node,
                              const xla::XlaOp& operand) {
-  auto builder = operand.builder();
+  xla::XlaBuilder* builder = operand.builder();
   xla::Shape operand_shape = XlaHelpers::ShapeOfXlaOp(operand);
-  const auto xla_other = XlaHelpers::ScalarValue(
+  xla::XlaOp xla_other = XlaHelpers::ScalarValue(
       node->get<at::Scalar>(at::attr::other).value().to<float>(),
       operand_shape.element_type(), builder);
   xla::XlaOp pred;
@@ -46,15 +46,15 @@ xla::XlaOp BuildComparisonOp(const torch::jit::Node* node,

 xla::XlaOp BuildThreshold(const xla::XlaOp& input, const xla::XlaOp& output,
                           const float threshold, const float value) {
-  auto builder = input.builder();
+  xla::XlaBuilder* builder = input.builder();
   xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
   const auto input_sizes = XlaHelpers::ShapeSizes(input_shape);
   std::vector<xla::int64> broadcast_sizes(input_sizes.begin(),
                                           input_sizes.end());
   xla::Shape output_shape = XlaHelpers::ShapeOfXlaOp(output);
-  const auto xla_threshold = XlaHelpers::ScalarValue<float>(
+  xla::XlaOp xla_threshold = XlaHelpers::ScalarValue<float>(
       threshold, input_shape.element_type(), builder);
-  const auto xla_value = XlaHelpers::ScalarValue<float>(
+  xla::XlaOp xla_value = XlaHelpers::ScalarValue<float>(
       value, output_shape.element_type(), builder);
   return xla::Select(xla::Gt(input, xla_threshold), output,
                      xla::Broadcast(xla_value, broadcast_sizes));
@@ -73,7 +73,7 @@ xla::XlaOp BuildTypeAs(const torch::jit::Node* node,
   const auto output_tensor_type =
       node_outputs[0]->type()->cast<at::DimensionedTensorType>();
   XLA_CHECK(output_tensor_type);
-  const auto target_type = XlaHelpers::MakeXlaPrimitiveType(
+  xla::PrimitiveType target_type = XlaHelpers::MakeXlaPrimitiveType(
       output_tensor_type->scalarType(), /*device=*/nullptr);
   return xla::ConvertElementType(operand, target_type);
 }
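
BuildThreshold lowers the thresholding op as a select: where input exceeds the threshold, keep output; elsewhere substitute the broadcast constant value. An elementwise sketch of those semantics (hypothetical helper, not part of the diff):

  #include <cassert>
  #include <vector>

  // Elementwise equivalent of the xla::Select(xla::Gt(...)) lowering above:
  // keep `output[i]` where `input[i] > threshold`, else use `value`.
  std::vector<float> Threshold(const std::vector<float>& input,
                               const std::vector<float>& output,
                               float threshold, float value) {
    std::vector<float> result(input.size());
    for (size_t i = 0; i < input.size(); ++i)
      result[i] = input[i] > threshold ? output[i] : value;
    return result;
  }

  int main() {
    std::vector<float> r = Threshold({-1.0f, 2.0f}, {-1.0f, 2.0f},
                                     /*threshold=*/0.0f, /*value=*/0.0f);
    assert(r[0] == 0.0f && r[1] == 2.0f);  // ReLU when threshold == value == 0.
    return 0;
  }
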
7 changes: 4 additions & 3 deletions torch_xla/csrc/helpers.cpp
@@ -47,7 +47,8 @@ xla::PaddingConfig XlaHelpers::MakeXlaPaddingConfig(
     padding_config.add_dimensions();
   }
   for (int i = 0; i < 2; ++i) {
-    auto* dims = padding_config.add_dimensions();
+    xla::PaddingConfig::PaddingConfigDimension* dims =
+        padding_config.add_dimensions();
     dims->set_edge_padding_low(padding[i]);
     dims->set_edge_padding_high(padding[i]);
   }
@@ -56,9 +57,9 @@ xla::PaddingConfig XlaHelpers::MakeXlaPaddingConfig(

 xla::XlaComputation XlaHelpers::CreateAddComputation(xla::PrimitiveType type) {
   xla::XlaBuilder reduction_builder("xla_add_computation");
-  const auto x = xla::Parameter(&reduction_builder, 0,
+  xla::XlaOp x = xla::Parameter(&reduction_builder, 0,
                                 xla::ShapeUtil::MakeShape(type, {}), "x");
-  const auto y = xla::Parameter(&reduction_builder, 1,
+  xla::XlaOp y = xla::Parameter(&reduction_builder, 1,
                                 xla::ShapeUtil::MakeShape(type, {}), "y");
   Add(x, y);
   return reduction_builder.Build().ConsumeValueOrDie();
4 changes: 2 additions & 2 deletions torch_xla/csrc/helpers.h
@@ -34,7 +34,7 @@ class XlaHelpers {
   // Creates a XLA constant for the given scalar_value.
   template <class T>
   static xla::XlaOp ScalarValue(T scalar_value, xla::XlaBuilder* builder) {
-    const auto scalar_literal = xla::LiteralUtil::CreateR0<T>(scalar_value);
+    xla::Literal scalar_literal = xla::LiteralUtil::CreateR0<T>(scalar_value);
     return xla::ConstantLiteral(builder, scalar_literal);
   }

@@ -65,7 +65,7 @@ class XlaHelpers {
   template <class T>
   static xla::XlaOp ScalarBroadcast(T scalar_value, const xla::Shape& shape,
                                     xla::XlaBuilder* builder) {
-    auto scalar_op =
+    xla::XlaOp scalar_op =
         ScalarValue<T>(scalar_value, shape.element_type(), builder);
     return xla::Broadcast(scalar_op, ShapeSizes(shape));
   }
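
For context, a typical use of these two helpers together — a sketch assuming an in-scope xla::XlaOp named input, not code taken from this diff:

  // Sketch only: materialize an all-zeros tensor shaped like `input`.
  // ScalarBroadcast creates a scalar constant of the operand's element type,
  // then broadcasts it to the operand's dimensions.
  xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp zeros =
      XlaHelpers::ScalarBroadcast<float>(0, input_shape, input.builder());
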
2 changes: 1 addition & 1 deletion torch_xla/csrc/lowering_context.cpp
@@ -34,7 +34,7 @@ xla::int64 LoweringContext::AddResult(xla::XlaOp op) {

 xla::StatusOr<xla::XlaComputation> LoweringContext::Build() {
   if (!root_tuple_.empty()) {
-    auto root = xla::Tuple(builder(), root_tuple_);
+    xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
     return builder()->Build(root);
   }
   return builder()->Build();
22 changes: 12 additions & 10 deletions torch_xla/csrc/module.cpp
@@ -97,7 +97,7 @@ void XlaModule::Initialize(const TensorBatchVector& inputs) {
     TensorBatchVector::value_type replica_params;
     TensorBatchVector::value_type optimizable_replica_params;
     for (size_t j = 0; j < params_buffers_regather.size(); ++j) {
-      const auto& var_ref =
+      const torch::autograd::Variable& var_ref =
           torch::autograd::as_variable_ref(*params_buffers_regather[j]);
       replica_params.push_back(
           XLATensor::Create(var_ref, device, var_ref.requires_grad()));
@@ -244,7 +244,7 @@ void XlaModule::backward(const TensorBatchVector& grad_outputs) {
       XLA_CHECK_GE(raw_output_index, input_vjps_real_outputs);
       XLA_CHECK_LT(raw_output_index - input_vjps_real_outputs,
                    replica_captured_outputs.size());
-      auto p =
+      XLATensor p =
          replica_captured_outputs[raw_output_index - input_vjps_real_outputs];
       replica_raw_grad_outputs.push_back(p);
       if (i == 0) {
@@ -397,8 +397,9 @@ xla::XlaComputation XlaModule::BuildFusedTrainComputation(
   xla::XlaBuilder b("XlaFusedComputation");
   // Build the forward pass program without compiling it, the backward pass
   // needs to be called before finalizing it.
-  auto computation_in_outs = xla_fwd_impl.BuildComputationProgram(
-      forward_shapes, backward_size_op_values_, &b);
+  XlaComputationInOut computation_in_outs =
+      xla_fwd_impl.BuildComputationProgram(forward_shapes,
+                                           backward_size_op_values_, &b);
   // Take the XLA outputs from the forward pass and set them for the backward
   // call in the same order the standalone, unfused version takes its arguments.
   XLA_CHECK(!computation_in_outs.outputs.empty());
@@ -447,7 +448,7 @@ xla::XlaComputation XlaModule::BuildFusedTrainComputation(
   }
   // The arguments are set up correctly, call into the backward computation.
   XlaTranslator xla_bwd_impl(gradient_.df, GetPrecisionConfig());
-  auto backward_computation =
+  xla::XlaComputation backward_computation =
       xla_bwd_impl
           .BuildComputation("XlaBackward", backward_shapes,
                             backward_size_op_values_,
@@ -497,8 +498,9 @@ XlaModule::TensorBatchVector XlaModule::RunUnfusedForward(
   }

   XlaTranslator xla_fwd_impl(gradient_.f, GetPrecisionConfig());
-  auto forward_translation_result = xla_fwd_impl.BuildComputation(
-      "XlaForward", forward_shapes, backward_size_op_values_);
+  XlaTranslationResult forward_translation_result =
+      xla_fwd_impl.BuildComputation("XlaForward", forward_shapes,
+                                    backward_size_op_values_);
   backward_size_op_values_ = SetBackwardSizeOpValues(
       forward_translation_result.ret_size_op_values, gradient_);

@@ -691,9 +693,9 @@ std::vector<Device> XlaModule::CommonDevicesForReplicas(
 xla::Shape XlaModule::GetResultShape(const xla::XlaComputation& computation,
                                      const TensorBatchVector& input_tensors) {
   auto devices = CommonDevicesForReplicas(input_tensors);
-  const auto program_shape = computation.GetProgramShape().ValueOrDie();
-  const auto result_shape = program_shape.result();
-  return MakeShapeWithDeviceLayout(result_shape, devices.front().hw_type);
+  xla::ProgramShape program_shape = computation.GetProgramShape().ValueOrDie();
+  return MakeShapeWithDeviceLayout(program_shape.result(),
+                                   devices.front().hw_type);
 }

 } // namespace torch_xla