Replace Shape with a C++ class in XLA.
No functional change. Rename the proto message Shape to ShapeProto and define an in-place replacement C++ class named Shape whose interface mirrors the protobuf-generated one. Having Shape as a C++ class enables greater flexibility in the interface, enforcement of invariants, and potential performance improvements.

PiperOrigin-RevId: 223252977
meheffernan authored and tensorflower-gardener committed Nov 29, 2018
1 parent 0f98c06 commit bd737c8
Showing 46 changed files with 670 additions and 442 deletions.
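For illustration, a minimal sketch of a proto-wrapping class in the spirit described above. The member set and constructor below are simplifying assumptions; the real xla::Shape (added in tensorflow/compiler/xla/shape.h) also carries layout and tuple-shape information plus validation that this sketch omits.

#include <cstdint>
#include <vector>

#include "tensorflow/compiler/xla/xla_data.pb.h"  // ShapeProto, PrimitiveType

namespace xla {

// Simplified sketch: a C++ class whose accessors mirror the protobuf-generated
// interface of the (renamed) ShapeProto message, so shape-reading call sites
// look the same after the switch.
class Shape {
 public:
  Shape() = default;
  explicit Shape(const ShapeProto& proto)
      : element_type_(proto.element_type()),
        dimensions_(proto.dimensions().begin(), proto.dimensions().end()) {}

  // Round-trips back to the proto form; call sites in this commit follow the
  // pattern *message.mutable_shape() = shape.ToProto().
  ShapeProto ToProto() const {
    ShapeProto proto;
    proto.set_element_type(element_type_);
    for (int64_t dimension : dimensions_) proto.add_dimensions(dimension);
    return proto;
  }

  // Accessors named like the protobuf-generated ones.
  PrimitiveType element_type() const { return element_type_; }
  int dimensions_size() const { return static_cast<int>(dimensions_.size()); }
  int64_t dimensions(int index) const { return dimensions_[index]; }

 private:
  PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
  std::vector<int64_t> dimensions_;
};

}  // namespace xla

Call sites then construct a Shape where a proto used to be passed directly (for example xla::Shape(ps.parameters(i)) in codegen.cc below) and convert back with ToProto() wherever a proto field is populated.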
9 changes: 5 additions & 4 deletions tensorflow/compiler/aot/codegen.cc
@@ -175,7 +175,8 @@ Status GenArgMethods(const tf2xla::Config& config,
}
for (int i = 0; i < num_args; ++i) {
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
TF_RETURN_IF_ERROR(
AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
const string code = R"(
void set_arg{{NAME}}_data(void* data) {
set_arg_data({{I}}, data);
@@ -218,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config,
}
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(
AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
TF_RETURN_IF_ERROR(AddRewritesForShape(
i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
string code = R"(
{{TYPE}}* result{{NAME}}_data() {
return static_cast<{{TYPE}}*>(result_data({{I}}));
@@ -588,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim},
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
16 changes: 11 additions & 5 deletions tensorflow/compiler/aot/compile.cc
@@ -58,15 +58,21 @@ Status CompileXla(xla::CompileOnlyClient* client,
}
compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
xla::ProgramShapeProto* pshape = &compile_result->program_shape;
std::vector<const xla::Shape*> arg_layouts;
arg_layouts.reserve(pshape->parameters_size());

// AotXlaComputationInstance::argument_layouts is a vector of Shape
// pointers. Accumulate the Shape objects themselves in a separate vector
// while building the vector of pointers.
std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
arg_layout_ptrs[i] = &arg_layouts[i];
}
xla::CompileOnlyClient::AotXlaComputationInstance instance;
instance.computation = &computation;
instance.argument_layouts = std::move(arg_layouts);
instance.result_layout = &pshape->result();
instance.argument_layouts = std::move(arg_layout_ptrs);
xla::Shape result_shape(pshape->result());
instance.result_layout = &result_shape;
xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
aot_or = client->CompileAheadOfTime({instance}, aot_opts);
if (!aot_or.ok()) {
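The comment added to compile.cc above flags the ownership subtlety: argument_layouts holds raw const Shape* pointers, so the Shape objects built from the protos must stay alive, and stay put, for as long as the pointers are used. A minimal sketch of that pattern, using a hypothetical helper name and assuming only the Shape(const ShapeProto&) constructor shown in this commit:

#include <vector>

#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Hypothetical helper showing the pointer/object split used in compile.cc:
// build the Shape objects first, then a parallel vector of pointers into that
// still-alive storage.
void BuildArgumentLayouts(const xla::ProgramShapeProto& pshape,
                          std::vector<xla::Shape>* storage,
                          std::vector<const xla::Shape*>* pointers) {
  storage->clear();
  pointers->clear();
  // Reserve up front so push_back never reallocates and the addresses taken
  // below remain valid.
  storage->reserve(pshape.parameters_size());
  pointers->reserve(pshape.parameters_size());
  for (int i = 0; i < pshape.parameters_size(); ++i) {
    storage->push_back(xla::Shape(pshape.parameters(i)));
    pointers->push_back(&storage->back());
  }
  // The pointers are valid only while *storage is alive and unmodified.
}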
8 changes: 5 additions & 3 deletions tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -529,10 +529,12 @@ TEST(TFCompileTest, ProgramShape) {
const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
ASSERT_TRUE(muladd_shape != nullptr);
ASSERT_EQ(muladd_shape->parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2));
EXPECT_TRUE(
ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2));
EXPECT_TRUE(
ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2));

const xla::Shape& muladd_result = muladd_shape->result();
const xla::Shape muladd_result(muladd_shape->result());
ASSERT_EQ(muladd_result.element_type(), xla::TUPLE);
ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2);
const xla::Shape& muladd_result0 =
1 change: 1 addition & 0 deletions tensorflow/compiler/tf2xla/shape_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_

#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
12 changes: 6 additions & 6 deletions tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
@@ -116,13 +116,13 @@ TEST(XlaJitCompiledCpuFunction, Sum) {
// Check program shape.
using xla::ShapeUtil;
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
const xla::ProgramShapeProto* program_shape = function.ProgramShape();
ASSERT_TRUE(program_shape != nullptr);
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32));
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32));
ASSERT_TRUE(function.ProgramShape() != nullptr);
const xla::ProgramShape program_shape(*function.ProgramShape());
ASSERT_EQ(program_shape.parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(0), s32));
EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(1), s32));

const xla::Shape& result = program_shape->result();
const xla::Shape& result = program_shape.result();
ASSERT_EQ(result.element_type(), xla::TUPLE);
ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1);
const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0);
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/client/BUILD
@@ -81,6 +81,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
18 changes: 11 additions & 7 deletions tensorflow/compiler/xla/client/client.cc
@@ -20,6 +20,7 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
@@ -42,7 +43,7 @@ StatusOr<Literal> Client::Transfer(const GlobalData& data,
TransferToClientRequest request;
*request.mutable_data() = data.handle();
if (shape_with_layout != nullptr) {
*request.mutable_shape_with_layout() = *shape_with_layout;
*request.mutable_shape_with_layout() = shape_with_layout->ToProto();
}
TransferToClientResponse response;

@@ -123,7 +124,7 @@ StatusOr<Literal> Client::TransferFromOutfeed(
}
request.set_replica_id(replica_id);
if (shape_with_layout != nullptr) {
*request.mutable_shape_with_layout() = *shape_with_layout;
*request.mutable_shape_with_layout() = shape_with_layout->ToProto();
}
TransferFromOutfeedResponse response;

@@ -170,11 +171,14 @@ StatusOr<Literal> Client::ExecuteAndTransfer(
std::unique_ptr<GlobalData> data,
Execute(computation, arguments, execution_options, execution_profile));

const Shape* shape_with_output_layout = nullptr;
absl::optional<Shape> shape_with_output_layout;
if (execution_options && execution_options->has_shape_with_output_layout()) {
shape_with_output_layout = &execution_options->shape_with_output_layout();
shape_with_output_layout =
Shape(execution_options->shape_with_output_layout());
}
return Transfer(*data, shape_with_output_layout);
return Transfer(*data, shape_with_output_layout.has_value()
? &(*shape_with_output_layout)
: nullptr);
}

StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
@@ -229,7 +233,7 @@ StatusOr<ExecutionHandle> Client::Compile(

// The argument shapes affect how the computation is compiled.
for (const auto& arg_shape : argument_shapes) {
*request.add_input_shape_with_layout() = arg_shape;
*request.add_input_shape_with_layout() = arg_shape.ToProto();
}

CompileResponse response;
@@ -458,7 +462,7 @@ StatusOr<Shape> Client::GetShape(const GlobalData& data) {
return s;
}

return response.shape();
return Shape(response.shape());
}

StatusOr<string> Client::ExecutionStatsAsString(
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/client/executable_build_options.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
6 changes: 3 additions & 3 deletions tensorflow/compiler/xla/client/lib/testing.cc
@@ -66,7 +66,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
XlaComputation computation = b.Build().ConsumeValueOrDie();

auto execution_options = CreateDefaultExecutionOptions();
*execution_options.mutable_shape_with_output_layout() = shape;
*execution_options.mutable_shape_with_output_layout() = shape.ToProto();
return client->Execute(computation, /*arguments=*/{}, &execution_options)
.ConsumeValueOrDie();
}
@@ -98,8 +98,8 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
auto program_shape = computation.proto().host_program_shape();

std::vector<std::unique_ptr<GlobalData>> results;
for (const Shape& shape : program_shape.parameters()) {
results.push_back(MakeFakeDataOrDie(shape, client));
for (const ShapeProto& shape : program_shape.parameters()) {
results.push_back(MakeFakeDataOrDie(Shape(shape), client));
}
return results;
}
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/client/sharding_builder.cc
@@ -36,7 +36,7 @@ OpSharding Tile(const Shape& tile_shape,
const TileAssignment& tile_assignment) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
*result.mutable_tile_shape() = tile_shape;
*result.mutable_tile_shape() = tile_shape.ToProto();
for (int64 dim : tile_assignment.dimensions()) {
result.add_tile_assignment_dimensions(dim);
}
@@ -52,7 +52,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) {

CHECK_EQ(ShapeUtil::Rank(tile_shape), 1);
std::vector<int64> dimensions(1, num_tiles);
*result.mutable_tile_shape() = tile_shape;
*result.mutable_tile_shape() = tile_shape.ToProto();
auto& tile_dimension =
(*result.mutable_tile_shape()->mutable_dimensions())[0];
tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);