Branch 178689056 #15286

Merged
merged 114 commits on Dec 12, 2017
Commits
e81b739
Enable more import_graph_def tests with C API.
skye Dec 7, 2017
b5c8cd6
Fix control flow test to not use session after it's gone out of scope.
skye Dec 7, 2017
1e54177
Increase test size of tensorflow/python/kernel_tests:dynamic_partitio…
tensorflower-gardener Dec 7, 2017
1fe793d
Adds support for loading model directly from a Flatbuffer object.
tensorflower-gardener Dec 7, 2017
b1c3222
Use sparse loss to avoid the warning being thrown by tf.nn.softmax_cr…
MarkDaoust Dec 7, 2017
cdc36a3
tfdbg: cosmetic fix to MonitoredSession.__del__ AttributeError
caisq Dec 7, 2017
2a41c99
Optimize control dependencies driven by constants
benoitsteiner Dec 7, 2017
e35da0b
Add logging to explicitly report the op types of nodes that fail shap…
tensorflower-gardener Dec 7, 2017
9620b2d
Fix the issue with shared saver on GPU.
isaprykin Dec 7, 2017
a6d01af
Graph._create_op_from_tf_operation should update _names_in_use.
skye Dec 7, 2017
507dc8c
Use size=large, shard_count=2 for tensorflow/python/kernel_tests:tran…
tensorflower-gardener Dec 7, 2017
ee1bd60
Add `discriminator_and_aux_fn` to `InfoGANModel`.
tensorflower-gardener Dec 7, 2017
f282544
WhileContext.AddBackpropIndexedSlicesAccumulator shouldn't create inv…
skye Dec 7, 2017
25adb9c
Add "apple CROSSTOOL" to common problems
MarkDaoust Dec 7, 2017
7772360
Raise exception on bad while loop shapes sooner.
skye Dec 7, 2017
488f091
Remove unused BUILD dependencies
tensorflower-gardener Dec 7, 2017
89804a9
Eliminate redundant control dependencies by computing the transitive …
tensorflower-gardener Dec 7, 2017
b02eae0
[tf.data] Add usage notes to the `Dataset.from_generator()` documenta…
mrry Dec 7, 2017
0509f07
Make BaseSession.make_callable work with the C API enabled.
skye Dec 7, 2017
2ea1141
Simplify and fix some bugs in constant folding of neutral/absorbing e…
tensorflower-gardener Dec 7, 2017
daab44f
Allow partial handling of some cyclic graphs, at least to
tensorflower-gardener Dec 7, 2017
2d4c29c
Added __str__ to ClusterSpec. This will improve the logs users are at…
ispirmustafa Dec 7, 2017
f37380b
Add tfe.py_func, a tf.py_func-like construct that wraps a Python func…
akshayka Dec 7, 2017
51fa3f7
Added unknown rank support to dense to sparse conversion within cate…
ispirmustafa Dec 7, 2017
6bb91a0
Make import_graph_def import functions with C API enabled.
skye Dec 7, 2017
7b0458a
BUGFIX: Ensure that rejected states don't propagate NaN. Make float…
langmore Dec 7, 2017
029109b
[XLA] Make hlo_test_base support running and comparing an hlo module.
tensorflower-gardener Dec 7, 2017
6e04085
Change TraceListener::BlockHostUntilDoneComplete to pass Status* rath…
tensorflower-gardener Dec 8, 2017
6c4af62
Make SQLite random IDs unique across tables
jart Dec 8, 2017
7ebe79f
Fix nullptr dereferencing bug in dependency optimizer in VLOG statement.
tensorflower-gardener Dec 8, 2017
fa9ebbc
Don't enforce shape invariance when creating backprop while loop.
skye Dec 8, 2017
affd993
[XLA:CPU] Rename cpu/layout_assignment to cpu/cpu_layout_assignment
Dec 8, 2017
1667d4d
Parallelizing the creation of asynchronous function calls in `tf.data…
jsimsa Dec 8, 2017
0e9cc7f
[XLA] Implement Conditional in XLA service, client ComputationBuilder…
tensorflower-gardener Dec 8, 2017
327fe05
boosted_trees: to name variables so that they can be distinguished in…
tensorflower-gardener Dec 8, 2017
3edc4ee
Add ":grpc_debug_test_server" to the deps of ":debug_py".
chihuahua Dec 8, 2017
0903098
Don't optimize away control triggers
benoitsteiner Dec 8, 2017
233aff8
Create a new Var-like object, LegacyVar, which allows access to its m…
ebrevdo Dec 8, 2017
1c4f65b
Fix a bug in string to number conversion code.
gunan Dec 8, 2017
516f97c
Add a maximum_iterations argument to tf.while_loop.
ebrevdo Dec 8, 2017
43d2411
[XLA] Fix an issue in ParseShapeString which made the hlo parser fail …
tensorflower-gardener Dec 8, 2017
6605938
[XLA:GPU] Support atomic operations on small data types.
tensorflower-gardener Dec 8, 2017
a1c2e20
Introduce an experimental API to pass sharding information from tenso…
tensorflower-gardener Dec 8, 2017
2139d20
Finishing the migration of the RNN / shared fully-connected layer blo…
tensorflower-gardener Dec 8, 2017
ad1834d
Add a "file_pattern" argument to the @{} reference replacer.
MarkDaoust Dec 8, 2017
8b86edc
Return an error when shape inference fails to converge
benoitsteiner Dec 8, 2017
f186c48
Relax the requirements on --input_arrays and --input_shapes.
tensorflower-gardener Dec 8, 2017
29e7222
Misc improvements required for some internal model:
tensorflower-gardener Dec 8, 2017
3a0dd45
Removes invariant that bitcasts must have the same shape byte size as…
tensorflower-gardener Dec 8, 2017
ec0f204
Fix tf.while_loop with maximum_iterations != None and single loop_var.
ebrevdo Dec 8, 2017
bb4ada7
[XLA:GPU] Remove the comment that says b/34969189 blocking TruncateNo…
tensorflower-gardener Dec 8, 2017
d8425f5
Add n=1 special case to the DeserializeSparse op.
mrry Dec 8, 2017
f6bf99e
Add unittest for `tf.contrib.distributions.MixtureSameFamily` `batch_…
jvdillon Dec 8, 2017
213d4be
[XLA:GPU] Support conditional in GPU backend.
tensorflower-gardener Dec 8, 2017
1afc614
BUGFIX: Call convert_to_tensor on input in fill_triangular. Also ch…
langmore Dec 8, 2017
dc04e89
Adding support for new TensorFlow operators. Also adding a transforma…
tensorflower-gardener Dec 8, 2017
2f16f3a
Add bfloat16 support to the CPU backend.
broune Dec 8, 2017
d94d6c4
Add DataFormatVecPermute op.
Dec 8, 2017
e8fcbf6
Always instantiate default attribute values when building a grappler …
benoitsteiner Dec 8, 2017
b8368b7
Extend neutral element optimization to handle
tensorflower-gardener Dec 8, 2017
76b22ba
Share constant creation code.
Dec 8, 2017
b1c7d17
Add a test to simulate the environment where GPU binary running on no…
aaroey Dec 8, 2017
28807c5
Add Operation._remove_all_control_inputs and use in ControlFlowContext.
skye Dec 8, 2017
54c2584
Add a "Snapshot" kernel that always makes a copy of its input. This s…
tensorflower-gardener Dec 8, 2017
e70078e
Add default values of attributes that might have been stripped by
benoitsteiner Dec 9, 2017
ddfd625
Adjust verbosity of per-node logging in graph_properties.
yacoder Dec 9, 2017
a0c2121
Make profiler memory profiling work with tf.while_loop
tensorflower-gardener Dec 9, 2017
b1c64c6
[tf.data] Move tests from python/kernel_tests to python/data/kernel_t…
mrry Dec 9, 2017
7478053
In MovingAverageOptimizer, delegate compute_gradients() to the wrappe…
tensorflower-gardener Dec 9, 2017
c697d96
Preserve symbolic shape information as much as possible during shape …
benoitsteiner Dec 9, 2017
922fbb5
Check that Rendezvous is not null.
tensorflower-gardener Dec 9, 2017
d85eec0
Always create a Rendezvous in RemoteCallOp.
mrry Dec 9, 2017
3457d12
Clear existing layouts before running the layout assignment.
hyouklee Dec 9, 2017
4a19c34
Support non-const input sizes for Conv2DBackpropInput.
Dec 9, 2017
dda6d1b
[XLA] Hlo parser: support reporting error messages with locations poi…
tensorflower-gardener Dec 9, 2017
fefa1c2
Set Operation._id_value before adding to control flow context.
skye Dec 9, 2017
3764127
Replace StreamExecutorInterface::BlockHostUntilDone with BlockHostUnt…
tensorflower-gardener Dec 9, 2017
bbfe1d3
Add IsDataTypeComplex helper function. In numerous places, we needles…
tensorflower-gardener Dec 10, 2017
373b425
Add a library for the cwise ops headers and common source.
tensorflower-gardener Dec 10, 2017
22e0870
Make control_flow_op_py_test "medium" to avoid ASAN timeouts.
skye Dec 11, 2017
309f7e2
Fix mismatched argument comments to match parameter names
tensorflower-gardener Dec 11, 2017
218caf9
Cleanup: Remove unused declarations and unnecessary conversions
tensorflower-gardener Dec 11, 2017
6c33765
Add `grpc_enabled` optional argument to various Python test rules.
mrry Dec 11, 2017
4080299
[XLA] Improve dot strength reductions to support transposes of the ri…
blakehechtman Dec 11, 2017
0683cdb
Enriching some C64 test coverage.
tensorflower-gardener Dec 11, 2017
d423542
Remove using-directives
tensorflower-gardener Dec 11, 2017
865bef3
Use the Snapshot kernel to force a copy of global step instead of the…
tensorflower-gardener Dec 11, 2017
f022085
[XLA] Add window reversal to the parser.
blakehechtman Dec 11, 2017
bcdefca
Fix-ups for XLA docs.
eliben Dec 11, 2017
f783790
Make sure we serialize the creation and deletion of clusters from pyt…
benoitsteiner Dec 11, 2017
0117c7a
Use DataFormatVecPermute instead of Gather, which is very slow.
Dec 11, 2017
1c46590
Fix incorrect parameter order in recall_at_precision.
tensorflower-gardener Dec 11, 2017
c4ef1b9
Add hvx/nnapi supports for Tensorflow Lite benchmark
zhixianyan Dec 11, 2017
db3ab6c
Adds remaining tests to _shared_embedding_column.
tensorflower-gardener Dec 11, 2017
bc87158
Optimized specializations for 3-channel depthwise with multiplier 2 a…
tensorflower-gardener Dec 11, 2017
b4c447a
Be more conservative when optimizing full reductions
benoitsteiner Dec 11, 2017
916b0d5
Rename ABSL Macros
tensorflower-gardener Dec 11, 2017
dae1f7a
Add ReverseDFSFrom variant that works with const Node*.
tensorflower-gardener Dec 11, 2017
c290c47
Always inline functions when creating an item.
benoitsteiner Dec 11, 2017
0afb4bb
Support all unary ops.
Dec 11, 2017
c381794
Support different threading modes in GPU device.
zheng-xq Dec 11, 2017
a4b68b6
Enable optimizations of operations with neutral/absorbing elements by…
tensorflower-gardener Dec 11, 2017
80ac330
[XLA:CPU] Error on unsupported dot instructions
Dec 11, 2017
db198b8
Tune SQLite
jart Dec 11, 2017
dd77f38
[XLA] Move BatchDot unrolling from TF2XLA bridge to AlgebraicSimplifi…
tensorflower-gardener Dec 11, 2017
037f036
Mark a FunctionDef's signature as stateful when it contains a statefu…
mrry Dec 11, 2017
f86016e
Fix definition of tflite_smartreply
jart Dec 11, 2017
4d1d1a1
Fix for variable naming when executing eagerly
allenlavoie Dec 11, 2017
1b4c609
[tf.data] Use a more efficient dispatch mechanism for functions in da…
mrry Dec 11, 2017
ab08664
[XLA:CPU] Minor refactor to the CPU layout assignment code
Dec 11, 2017
da9ef31
Fix the remote call graph construction so that the created _Send ops …
tensorflower-gardener Dec 12, 2017
634515e
Strength reduce division by a constant to multiplication by the recip…
tensorflower-gardener Dec 12, 2017
ee09e0f
Merge commit for internal changes
Dec 12, 2017
e5246fb
Satisfy buildifier
Dec 12, 2017
1 change: 1 addition & 0 deletions tensorflow/BUILD
@@ -602,6 +602,7 @@ filegroup(
"//tensorflow/java/src/main/native:all_files",
"//tensorflow/python:all_files",
"//tensorflow/python/data:all_files",
"//tensorflow/python/data/kernel_tests:all_files",
"//tensorflow/python/data/ops:all_files",
"//tensorflow/python/data/util:all_files",
"//tensorflow/python/debug:all_files",
5 changes: 5 additions & 0 deletions tensorflow/c/c_api_function.cc
@@ -226,6 +226,11 @@ Status FillFunctionBody(
}
node_def->add_input(strings::StrCat("^", normalized));
}

// A function is stateful if any of its nodes are stateful.
if (node->op_def().is_stateful()) {
fdef->mutable_signature()->set_is_stateful(true);
}
}
return Status::OK();
}
45 changes: 45 additions & 0 deletions tensorflow/c/c_api_function_test.cc
@@ -1482,6 +1482,51 @@ TEST_F(CApiFunctionTest, GetOpDef) {
EXPECT_EQ(op_def.name(), func_name_);
EXPECT_EQ(op_def.input_arg_size(), 1);
EXPECT_EQ(op_def.output_arg_size(), 1);
EXPECT_FALSE(op_def.is_stateful());

TF_DeleteBuffer(buffer);
}

void DefineStatefulFunction(const char* name, TF_Function** func) {
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
TF_NewGraph(), TF_DeleteGraph);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
TF_DeleteStatus);

TF_Tensor* tensor_shape = Int32Tensor({37, 1});
TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape");
TF_Operation* random =
RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());

TF_Output inputs[] = {};
TF_Output outputs[] = {{random, 0}};
*func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/0, -1,
/*opers=*/nullptr, 0, inputs, 1, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, "", s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
ASSERT_NE(*func, nullptr);
TF_DeleteTensor(tensor_shape);
}

TEST_F(CApiFunctionTest, StatefulOpDef) {
DefineStatefulFunction(func_name_, &func_);
TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

// Test we can retrieve function OpDef from graph
TF_Buffer* buffer = TF_NewBuffer();
TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

// Sanity check returned OpDef
string data(static_cast<const char*>(buffer->data), buffer->length);
OpDef op_def;
op_def.ParseFromString(data);
EXPECT_EQ(op_def.name(), func_name_);
EXPECT_EQ(op_def.input_arg_size(), 0);
EXPECT_EQ(op_def.output_arg_size(), 1);
EXPECT_TRUE(op_def.is_stateful());

TF_DeleteBuffer(buffer);
}
9 changes: 9 additions & 0 deletions tensorflow/c/c_test_util.cc
@@ -193,6 +193,15 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
return TF_FinishOperation(desc, s);
}

TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype,
TF_Graph* graph, TF_Status* s) {
TF_OperationDescription* desc =
TF_NewOperation(graph, "RandomUniform", "random_uniform");
TF_AddInput(desc, {shape, 0});
TF_SetAttrType(desc, "dtype", dtype);
return TF_FinishOperation(desc, s);
}

void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name, TF_Operation** op) {
TF_Operation* zero = ScalarConst(
3 changes: 3 additions & 0 deletions tensorflow/c/c_test_util.h
@@ -74,6 +74,9 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,

TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);

TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype,
TF_Graph* graph, TF_Status* s);

// Split `input` along the first dimension into 3 tensors
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name = "split3");
12 changes: 12 additions & 0 deletions tensorflow/c/python_api.cc
@@ -87,4 +87,16 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
}
}

void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
mutex_lock l(graph->mu);
std::vector<const Edge*> control_edges;
for (const Edge* edge : op->node.in_edges()) {
if (!edge->IsControlEdge()) continue;
control_edges.push_back(edge);
}
for (const Edge* edge : control_edges) {
graph->graph.RemoveControlEdge(edge);
}
}

} // namespace tensorflow
2 changes: 2 additions & 0 deletions tensorflow/c/python_api.h
@@ -35,6 +35,8 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status);

void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op);

} // namespace tensorflow

#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
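
For illustration, a minimal usage sketch (not part of this PR's changes; the StripControlDeps wrapper name is hypothetical): RemoveAllControlInputs is the C++ backing for the Operation._remove_all_control_inputs method added in commit 28807c5, and a caller that already holds the usual C API handles could invoke it directly.

#include "tensorflow/c/python_api.h"

// Hypothetical helper: drops every incoming control edge ("^src") from `op`
// while leaving its data inputs untouched. `graph` and `op` are assumed to
// come from the usual C API calls (TF_NewGraph, TF_FinishOperation, ...).
void StripControlDeps(TF_Graph* graph, TF_Operation* op) {
  tensorflow::RemoveAllControlInputs(graph, op);
}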
100 changes: 32 additions & 68 deletions tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -27,7 +27,6 @@ namespace tensorflow {

// The current implementation simply unrolls the computation along the batch
// dimension.
// TODO(andydavis): add batching support to XLA's Dot operator.
xla::StatusOr<xla::ComputationDataHandle> BatchDot(
xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) {
@@ -52,26 +51,20 @@ xla::StatusOr<xla::ComputationDataHandle> BatchDot(

// The batch dimensions must be equal and the matrix dimensions must be
// valid.
std::vector<int64> dimensions;
int64 batch_count = 1;
std::vector<int64> batch_dimension_numbers;
for (int i = 0; i < ndims - 2; ++i) {
int64 x_size = x_shape->dimensions(i);
int64 y_size = y_shape->dimensions(i);
if (x_size != y_size) {
if (x_shape->dimensions(i) != y_shape->dimensions(i)) {
return errors::InvalidArgument(
"Dimension ", i, " of inputs to BatchedDot must be equal: ",
xla::ShapeUtil::HumanString(*x_shape), " vs ",
xla::ShapeUtil::HumanString(*y_shape));
}
dimensions.push_back(x_size);
batch_count *= x_size;
batch_dimension_numbers.push_back(i);
}

int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
int64 x_inner_dim_size = x_shape->dimensions(x_inner_dim);
int64 y_inner_dim_size = y_shape->dimensions(y_inner_dim);
if (x_inner_dim_size != y_inner_dim_size) {
if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) {
return errors::InvalidArgument(
"Dimensions ", x_inner_dim, " and ", y_inner_dim,
" of arguments to BatchedDot must be equal: ",
@@ -80,75 +73,46 @@ xla::StatusOr<xla::ComputationDataHandle> BatchDot(
" transpose: ", transpose_y);
}

// If there are no batch dimensions, use a regular Dot. This case exists
// to improve the readability of the emitted graphs.
if (dimensions.empty()) {
auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x;
auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y;
return builder->Dot(lhs, rhs);
// Check for zero lhs/rhs dim size.
if (xla::ShapeUtil::HasZeroElements(*x_shape) ||
xla::ShapeUtil::HasZeroElements(*y_shape)) {
std::vector<int64> dimensions(batch_dimension_numbers.size());
for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]);
}
int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
dimensions.push_back(x_shape->dimensions(x_outer_dim));
dimensions.push_back(y_shape->dimensions(y_outer_dim));
return builder->Broadcast(
builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())),
dimensions);
}

int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
dimensions.push_back(x_shape->dimensions(x_outer_dim));
dimensions.push_back(y_shape->dimensions(y_outer_dim));

if (x_shape->element_type() == xla::C64 && transpose_x) {
x = builder->Conj(x);
}
if (y_shape->element_type() == xla::C64 && transpose_y) {
y = builder->Conj(y);
}

// Reshape input tensors into 3D tensors by flattening the batch
// dimensions. This makes it easier to unroll the batch dimension.
auto x_flat =
builder->Reshape(x, {batch_count, x_shape->dimensions(ndims - 2),
x_shape->dimensions(ndims - 1)});
auto y_flat =
builder->Reshape(y, {batch_count, y_shape->dimensions(ndims - 2),
y_shape->dimensions(ndims - 1)});

// Slice batches into individual matrices and multiply them.
std::vector<xla::ComputationDataHandle> out_slices;
for (int64 i = 0; i < batch_count; ++i) {
// Slice off individual matrices and reshape to 2D tensors.
auto x_slice = builder->Slice(
x_flat, {i, 0, 0},
{i + 1, x_shape->dimensions(ndims - 2), x_shape->dimensions(ndims - 1)},
{1, 1, 1});
x_slice = builder->Reshape(x_slice, {x_shape->dimensions(ndims - 2),
x_shape->dimensions(ndims - 1)});
auto y_slice = builder->Slice(
y_flat, {i, 0, 0},
{i + 1, y_shape->dimensions(ndims - 2), y_shape->dimensions(ndims - 1)},
{1, 1, 1});
y_slice = builder->Reshape(y_slice, {y_shape->dimensions(ndims - 2),
y_shape->dimensions(ndims - 1)});

// Transpose if needed.
auto lhs = transpose_x ? builder->Transpose(x_slice, {1, 0}) : x_slice;
auto rhs = transpose_y ? builder->Transpose(y_slice, {1, 0}) : y_slice;

// Multiply matrices and add an outer singleton dimension to the output
// so we can concatenate along the flattened batch dimension later.
auto out = builder->Dot(lhs, rhs);
out = builder->Reshape(out,
{1, dimensions[ndims - 2], dimensions[ndims - 1]});
out_slices.push_back(out);
// If there are no batch dimensions, use a regular Dot.
// TODO(b/69062148) Remove this code when Dot emitters can be passed
// dimensions to transpose directly (i.e. without requiring a Transpose HLO).
if (batch_dimension_numbers.empty()) {
auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x;
auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y;
return builder->Dot(lhs, rhs);
}

// Concatenate output slices and reshape to original number of dimensions.
xla::ComputationDataHandle data;
if (out_slices.empty()) {
// It is illegal to pass an empty list to ConcatInDim.
// The batch count is empty, so both inputs must have zero elements.
// Arbitrarily use the left input as the argument to Reshape().
data = x;
} else {
data = builder->ConcatInDim(out_slices, 0);
xla::DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
for (auto batch_dimension_number : batch_dimension_numbers) {
dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
}
return builder->Reshape(data, dimensions);
return builder->DotGeneral(x, y, dot_dnums);
}

} // namespace tensorflow
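
To make the rewrite above concrete: assuming 3-D inputs x of shape [B, M, K] and y of shape [B, K, N] with no transposes, the DotDimensionNumbers that BatchDot now builds reduce to the following minimal sketch (illustrative values, not code from this PR).

// Contract x's last dimension against y's middle dimension, and treat
// dimension 0 of both operands as a shared batch dimension.
xla::DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(2);  // K in x
dot_dnums.add_rhs_contracting_dimensions(1);  // K in y
dot_dnums.add_lhs_batch_dimensions(0);        // B in x
dot_dnums.add_rhs_batch_dimensions(0);        // B in y
// builder->DotGeneral(x, y, dot_dnums) then yields a [B, M, N] result in a
// single HLO instruction, instead of B sliced Dot calls.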
55 changes: 47 additions & 8 deletions tensorflow/compiler/tf2xla/sharding_util.cc
@@ -14,34 +14,59 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace {
const char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE";
const char kShardingAttribute[] = "_XlaSharding";
} // namespace

static const char DEVICE_SUFFIX_REPLICATED_CORE[] = "REPLICATED_CORE";
namespace {
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
GetShardingFromNodeDef(const NodeDef& node_def) {
if (!HasNodeAttr(node_def, kShardingAttribute)) {
return tensorflow::gtl::optional<xla::OpSharding>();
}
string value;
xla::OpSharding sharding;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
if (!sharding.ParseFromString(value)) {
return xla::InvalidArgument(
"Experimental _XlaSharding attribute was not a valid encoded "
"xla::OpSharding proto.");
}
return tensorflow::gtl::optional<xla::OpSharding>(sharding);
}

static Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
return errors::InvalidArgument(
"Invalid replicated core id: ", core,
"; num_cores_per_replica=", num_cores_per_replica);
}
} // namespace

xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) {
ParseShardingFromDevice(
const string& device_name, int num_cores_per_replica,
tensorflow::gtl::optional<xla::OpSharding> explicit_sharding) {
if (device_name.empty()) {
return tensorflow::gtl::optional<xla::OpSharding>();
}

DeviceNameUtils::ParsedName parsed_device;
if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) {
return errors::InvalidArgument("Malformed assigned device '", device_name,
"'");
}
if (!parsed_device.has_type ||
!StringPiece(parsed_device.type)
.ends_with(DEVICE_SUFFIX_REPLICATED_CORE)) {

if (explicit_sharding.has_value()) {
return explicit_sharding;
} else if (!parsed_device.has_type || !parsed_device.has_id ||
!StringPiece(parsed_device.type)
.contains(kDeviceSuffixReplicatedCore)) {
return tensorflow::gtl::optional<xla::OpSharding>();
} else {
const int core = parsed_device.id;
@@ -53,20 +78,34 @@ ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) {
}
}

xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica) {
const string& device_name = node_def.device();
TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional<xla::OpSharding> sharding,
GetShardingFromNodeDef(node_def));
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
}

xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const Node& node, int num_cores_per_replica) {
string device_name = node.assigned_device_name();
if (device_name.empty()) {
device_name = node.requested_device();
}
return ParseShardingFromDevice(device_name, num_cores_per_replica);
TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional<xla::OpSharding> sharding,
GetShardingFromNodeDef(node.def()));
return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding);
}

void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
string device_name = src.assigned_device_name();
if (device_name.empty()) {
device_name = src.requested_device();
}
dst->set_assigned_device_name(device_name);
if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) {
dst->AddAttr(kShardingAttribute, *attr);
}
}

} // namespace tensorflow
13 changes: 10 additions & 3 deletions tensorflow/compiler/tf2xla/sharding_util.h
@@ -29,14 +29,21 @@ namespace tensorflow {
// - if the device name is invalid.
// - the core is parsed and is out of the range [0, num_cores_per_replica).
//
// Otherwise, returns either a non-value or a sharding set as per
// xla:ShardingBuilder::AssignDevice.
// Otherwise, returns either:
// - explicit_sharding if explicit_sharding.has_value()
// - a non-value if there is no assigned core or
// - a sharding set as per xla::ShardingBuilder::AssignDevice.
xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const string& device_name, int num_cores_per_replica);
ParseShardingFromDevice(const string& device_name, int num_cores_per_replica,
tensorflow::gtl::optional<xla::OpSharding>
explicit_sharding = tensorflow::gtl::nullopt);

xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const Node& node, int num_cores_per_replica);

xla::StatusOr<tensorflow::gtl::optional<xla::OpSharding>>
ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica);

void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst);

} // namespace tensorflow
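
Taken together with the header above, the device-string path can be sketched as follows (illustrative, not from this PR; the device name is assumed to parse via DeviceNameUtils::ParseFullName). A device whose type contains REPLICATED_CORE yields a single-core sharding, while an explicit_sharding argument, when provided, takes precedence over the device string.

// Core 2 of an 8-core replica: on success this returns an optional
// xla::OpSharding equivalent to xla::ShardingBuilder::AssignDevice(2);
// a core id outside [0, 8) returns an InvalidArgument error instead.
auto sharding_or = tensorflow::ParseShardingFromDevice(
    "/device:TPU_REPLICATED_CORE:2", /*num_cores_per_replica=*/8);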