Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch 163636676 #11893

Merged
merged 21 commits on Jul 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
df897dd
Update ops-related pbtxt files.
tensorflower-gardener Jul 28, 2017
f00a7a0
Go: Update generated wrapper functions for TensorFlow ops.
tensorflower-gardener Jul 28, 2017
036bba5
[XLA:CPU] Emit DynamicUpdateSlice in place (when feasible).
tensorflower-gardener Jul 28, 2017
e9c5f33
Remove XLA Client::ExecuteAsync API.
Jul 28, 2017
e5d8470
Create VariantTensorData to use instead of VariantTensorDataProto.
gunan Jul 28, 2017
022413e
TFTS: nomsan on estimators_test to avoid timeouts
allenlavoie Jul 28, 2017
92620df
Minor comment improvement.
tensorflower-gardener Jul 28, 2017
a78545d
tfdbg: clean up SessionRunHook code
caisq Jul 28, 2017
1773d56
In ShapeRefiner constant folding, when extracting subgraph, make sure to
tensorflower-gardener Jul 28, 2017
d49af9a
Disable constant folding for session debug tests.
Jul 28, 2017
39cd350
Make assert_rank_in visible.
vahidk Jul 28, 2017
e5c52bf
Remove a dead return statement in AlgebraicSimplifier.
tensorflower-gardener Jul 28, 2017
c3664d1
Fix the constructor parameter in InstructionList to be pass-by-refere…
tensorflower-gardener Jul 28, 2017
5270d48
Replace constants with variables in memory optimizer test, so that th…
Jul 29, 2017
c44c8d0
Remove unused BUILD dependencies
tensorflower-gardener Jul 29, 2017
25ed0f2
Remove "small" test size for window_ops_test.
rryan Jul 29, 2017
f8b575d
Add a tf.contrib.signal API guide.
rryan Jul 29, 2017
adf5d1b
Automated g4 rollback of changelist 163510186
tensorflower-gardener Jul 29, 2017
00d3126
Change const nodes to variables in the test, so that they are not opt…
Jul 30, 2017
e17650b
This adds the cluster_resolver module to contrib/__init__.py so that …
Jul 30, 2017
f80b081
Merge commit for internal changes
Jul 30, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions tensorflow/c/c_api.h
Expand Up @@ -363,9 +363,9 @@ typedef struct TF_Output {
// Sets the shape of the Tensor referenced by `output` in `graph` to
// the shape described by `dims` and `num_dims`.
//
// If the number of dimensions is unknown, `num_dims` must be
// set to -1 and dims can be null. If a dimension is unknown,
// the corresponding entry in the `dims` array must be -1.
// If the number of dimensions is unknown, `num_dims` must be set to
// -1 and `dims` can be null. If a dimension is unknown, the
// corresponding entry in the `dims` array must be -1.
//
// This does not overwrite the existing shape associated with `output`,
// but merges the input shape with the existing shape. For example,
Expand Down
55 changes: 0 additions & 55 deletions tensorflow/compiler/xla/client/client.cc
Expand Up @@ -294,61 +294,6 @@ StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
return device_handles;
}

// Launches 'computation' asynchronously on the service and returns a handle
// that can later be passed to WaitForExecution(). If 'execution_options' is
// null, default options are used.
StatusOr<ExecutionHandle> Client::ExecuteAsync(
    const Computation& computation,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    const ExecutionOptions* execution_options) {
  // Assemble the request proto: computation handle, one handle per argument,
  // and either the caller-supplied or the default execution options.
  ExecuteAsyncRequest req;
  *req.mutable_computation() = computation.handle();
  for (GlobalData* arg : arguments) {
    *req.add_arguments() = arg->handle();
  }
  *req.mutable_execution_options() = (execution_options != nullptr)
                                         ? *execution_options
                                         : CreateDefaultExecutionOptions();

  ExecuteAsyncResponse resp;
  VLOG(1) << "making execute async request: " << req.ShortDebugString();
  Status status = stub_->ExecuteAsync(&req, &resp);
  VLOG(1) << "done with request";

  if (!status.ok()) {
    return status;
  }
  return resp.execution();
}

// Blocks until the asynchronous execution identified by 'execution' finishes
// and returns its result. If 'execution_profile' is non-null, it is filled
// with profile data from the run.
StatusOr<std::unique_ptr<GlobalData>> Client::WaitForExecution(
    const Computation& computation, const ExecutionHandle& execution,
    ExecutionProfile* execution_profile) {
  WaitForExecutionRequest req;
  *req.mutable_execution() = execution;

  WaitForExecutionResponse resp;
  VLOG(1) << "making wait-for-execute request: " << req.ShortDebugString();
  Status status = stub_->WaitForExecution(&req, &resp);
  VLOG(1) << "done with request";
  if (!status.ok()) {
    return status;
  }

  // Surface profiling data to the caller if requested; additionally log a
  // human-readable rendering when verbose logging is enabled.
  if (execution_profile != nullptr) {
    *execution_profile = resp.profile();
    if (VLOG_IS_ON(1)) {
      TF_ASSIGN_OR_RETURN(auto stats_string,
                          ExecutionStatsAsString(computation, resp.profile()));
      VLOG(1) << stats_string;
    }
  }

  return MakeUnique<GlobalData>(stub_, resp.output());
}

Status Client::Unregister(const GlobalData& data) {
UnregisterRequest request;
*request.mutable_data() = data.handle();
Expand Down
18 changes: 0 additions & 18 deletions tensorflow/compiler/xla/client/client.h
Expand Up @@ -75,24 +75,6 @@ class Client {
// TransferToInfeed).
StatusOr<std::vector<DeviceHandle>> GetDeviceHandles(int64 device_count);

// Executes the given computation as above Execute(), but launches the
// computation asynchronously and returns before the execution is complete.
// Returns an ExecutionHandle that represents the launched execution, which is
// used to call WaitForExecution() to wait for the execution's completion.
StatusOr<ExecutionHandle> ExecuteAsync(
const Computation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const ExecutionOptions* execution_options = nullptr);

// Waits until the given asynchronously launched execution of the computation
// is complete and returns the execution result. Once this is called, the
// given execution handle is no longer valid. If execution_profile is not
// nullptr then the pointed-to ExecutionProfile will be filled with profile
// data from the execution.
StatusOr<std::unique_ptr<GlobalData>> WaitForExecution(
const Computation& computation, const ExecutionHandle& execution,
ExecutionProfile* execution_profile = nullptr);

// Transfer the global data provided to this client process, which is
// returned in the provided literal. Use sparingly to avoid transfer
// overheads.
Expand Down
1 change: 0 additions & 1 deletion tensorflow/compiler/xla/service/BUILD
Expand Up @@ -1254,7 +1254,6 @@ cc_library(
":hlo",
":hlo_cost_analysis",
":human_readable_profile_builder",
"//tensorflow/compiler/xla:metric_table_report",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
Expand Down
1 change: 0 additions & 1 deletion tensorflow/compiler/xla/service/algebraic_simplifier.cc
Expand Up @@ -1228,7 +1228,6 @@ Status AlgebraicSimplifierVisitor::HandleReduce(
return ReplaceWithNewInstruction(
reduce,
HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
return Status::OK();
}
// A Transpose feeding a reduce can simply permute the reduction dimensions
// field.
Expand Down
105 changes: 103 additions & 2 deletions tensorflow/compiler/xla/service/cpu/ir_emitter.cc
Expand Up @@ -2089,15 +2089,116 @@ Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice,
return DefaultAction(dynamic_slice);
}

namespace {

// Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
// (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
const HloInstruction* LatestNonGteAncestorAndIndex(const HloInstruction* hlo,
ShapeIndex* index) {
if (hlo->opcode() == HloOpcode::kGetTupleElement) {
const auto* operand = LatestNonGteAncestorAndIndex(hlo->operand(0), index);
index->push_back(hlo->tuple_index());
return operand;
}
return hlo;
}

// Checks if we can emit code for DynamicUpdateSlice to update data in-place.
// Returns true if operand 0 of DynamicUpdateSlice and its output buffer
// share the same buffer allocation.
// Returns false otherwise.
// TODO(b/64142684) Share code with GPU implementation.
bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment,
HloInstruction* dynamic_update_slice) {
CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());

// Walk DynamicUpdateSlice operand(0) to parameter and get its
// associated operand. See if it shares an allocation with this operand.
ShapeIndex index;
auto* operand =
LatestNonGteAncestorAndIndex(dynamic_update_slice->operand(0), &index);
if (operand->opcode() != HloOpcode::kParameter) {
return false;
}

BufferAllocation::Slice operand_slice =
assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie();

BufferAllocation::Slice dynamic_update_slice_slice =
assignment.GetUniqueTopLevelSlice(dynamic_update_slice)
.ConsumeValueOrDie();

return operand_slice == dynamic_update_slice_slice;
}

} // namespace

// Emits code for DynamicUpdateSlice. Three cases are handled:
//   1) Scalar output: the update fully overwrites the output, so a memcpy
//      of 'update' suffices.
//   2) operand(0) and the output share a buffer allocation slice: emit an
//      in-place update that writes only the slice region.
//   3) Otherwise: fall back to DefaultAction.
// NOTE(review): the scraped diff carried both the old commented-out
// parameter names and the new ones; this keeps the post-change signature
// (named 'operand'/'start_indices') only.
Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
                                           HloInstruction* operand,
                                           HloInstruction* update,
                                           HloInstruction* start_indices) {
  if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
    TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
                        EmitTargetAddressForOp(dynamic_update_slice));
    emitted_value_[dynamic_update_slice] = target_address;
    return EmitMemcpy(*update, *dynamic_update_slice);
  } else if (CanUpdateDynamicSliceInPlace(assignment_, dynamic_update_slice)) {
    VLOG(2) << "Emitting HandleDynamicUpdateSlice in-place.";
    // DynamicUpdateSlice's operand(0) and its output share the same
    // BufferAllocation::Slice, so it is safe to emit code to update the slice
    // 'in-place'. This avoids copying data outside of the slice update region.
    // TODO(b/64142684) Implement in-place update for fused DynamicUpdateSlice.

    // Emit IR to read dynamic start indices from 'start_indices'.
    const int64 rank = ShapeUtil::Rank(operand->shape());
    llvm_ir::IrArray::Index start_index(rank);
    for (int64 i = 0; i < rank; ++i) {
      llvm_ir::IrArray::Index dim_index({ir_builder_.getInt64(i)});
      llvm_ir::IrArray start_indices_array(GetIrArrayForOp(start_indices));
      start_index[i] =
          start_indices_array.EmitReadArrayElement(dim_index, &ir_builder_);
    }

    // Create loop body emitter which emits code to do the following:
    // *) Map requested 'index' and slice 'start_index' to input/output shape
    //    as 'output_index'.
    // *) Reads value from 'update'.
    // *) Writes value to input/output array at 'output_index'.
    auto loop_body_emitter =
        [&](const llvm_ir::IrArray::Index& index) -> Status {
      // Calculate 'output_index' at which to write value from update.
      llvm_ir::IrArray::Index output_index(rank);
      for (int64 i = 0; i < rank; ++i) {
        // Emit IR which computes:
        //   output_index = (start_index + index) % dim_size
        llvm::Value* dim_size = llvm::ConstantInt::get(
            index[i]->getType(), operand->shape().dimensions(i));
        llvm::Value* start_index0 = ir_builder_.CreateZExtOrBitCast(
            start_index[i], index[i]->getType());
        output_index[i] = ir_builder_.CreateURem(
            ir_builder_.CreateAdd(start_index0, index[i]), dim_size);
      }

      // Read value from 'update'.
      llvm_ir::IrArray update_array(GetIrArrayForOp(update));
      llvm::Value* update_data =
          update_array.EmitReadArrayElement(index, &ir_builder_);

      // Write value to output array.
      llvm_ir::IrArray(GetEmittedValueFor(operand), operand->shape())
          .EmitWriteArrayElement(output_index, update_data, &ir_builder_);
      return Status::OK();
    };

    // Loop over every element of 'update', writing each into the shared
    // operand/output buffer.
    TF_RETURN_IF_ERROR(
        llvm_ir::LoopEmitter(loop_body_emitter, update->shape(), &ir_builder_)
            .EmitLoop());

    TF_ASSIGN_OR_RETURN(llvm::Value * dynamic_update_slice_address,
                        EmitTargetAddressForOp(dynamic_update_slice));
    emitted_value_[dynamic_update_slice] = dynamic_update_slice_address;
    return Status::OK();
  }
  // Neither special case applies; use the generic emitter.
  return DefaultAction(dynamic_update_slice);
}
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/service/hlo_rematerialization.cc
Expand Up @@ -83,7 +83,7 @@ bool IsRematerializable(const HloInstruction* instruction) {
// before arbitrary elements.
class InstructionList {
public:
explicit InstructionList(const std::vector<const HloInstruction*> order) {
explicit InstructionList(const std::vector<const HloInstruction*>& order) {
int64 position = 0;
for (const HloInstruction* inst : order) {
instructions_.push_back(const_cast<HloInstruction*>(inst));
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/compiler/xla/service/service.h
Expand Up @@ -133,6 +133,10 @@ class Service : public ServiceInterface {
// Asynchronously executes a computation with provided arguments. Invokes
// the provided computation with the provided global data passed as
// immutable arguments. Returns a handle to the execution.
//
// (Note: The corresponding function in xla::Client was removed as part of
// b/64116060, in an attempt to simplify our API. We're keeping this around
// for now in case we want to expose this to clients in a different way.)
tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
ExecuteAsyncResponse* result) override;

Expand Down
6 changes: 0 additions & 6 deletions tensorflow/compiler/xla/tests/client_library_test_base.cc
Expand Up @@ -82,12 +82,6 @@ StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
return client_->Execute(computation, arguments, &execution_options_);
}

// Forwards to Client::ExecuteAsync, supplying the test fixture's stored
// execution options. Returns the handle of the launched execution.
StatusOr<ExecutionHandle> ClientLibraryTestBase::ExecuteAsync(
    const Computation& computation,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
  return client_->ExecuteAsync(computation, arguments, &execution_options_);
}

StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
const Computation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
Expand Down
3 changes: 0 additions & 3 deletions tensorflow/compiler/xla/tests/client_library_test_base.h
Expand Up @@ -77,9 +77,6 @@ class ClientLibraryTestBase : public ::testing::Test {
StatusOr<std::unique_ptr<GlobalData>> Execute(
ComputationBuilder* builder,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
StatusOr<ExecutionHandle> ExecuteAsync(
const Computation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
ComputationBuilder* builder,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
Expand Down
27 changes: 18 additions & 9 deletions tensorflow/compiler/xla/tests/while_test.cc
Expand Up @@ -784,8 +784,10 @@ void BM_WhileLoop(int num_iters) {
LocalClient* client =
ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();

const int64 seq_len = 100;
Shape loop_state_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})});
{ShapeUtil::MakeShape(S32, {}),
ShapeUtil::MakeShape(F32, {seq_len, 1024, 1024})});

// Create while condition computation with 'loop_limit'.
const int32 loop_limit = 100;
Expand All @@ -803,20 +805,27 @@ void BM_WhileLoop(int num_iters) {
{
ComputationBuilder builder(client, "body");
auto prev = builder.Parameter(0, loop_state_shape, "prev");
// TupleElement 0
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto one = builder.ConstantR0<int32>(1);
auto next_iteration = builder.Add(iteration, one);
auto one_vec = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, one_vec);
auto result = builder.Tuple({next_iteration, new_weights});
auto out0 = builder.Add(iteration, builder.ConstantR0<int32>(1));
// TupleElement 1
auto input = builder.GetTupleElement(prev, 1);
// Update.
auto one = builder.ConstantR0<float>(1.0);
auto update = builder.Broadcast(one, {1, 1024, 1024});
// Starts = iteration * 2;
auto starts = builder.ConstantR1<int32>({0, 0, 0});
// UpdateSlice.
auto out1 = builder.DynamicUpdateSlice(input, update, starts);
auto result = builder.Tuple({out0, out1});
body = builder.Build().ConsumeValueOrDie();
}

// Create a While instruction.
ComputationBuilder builder(client, "while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto zero = builder.ConstantR0<float>(0.0);
auto input = builder.Broadcast(zero, {seq_len, 1024, 1024});
auto init = builder.Tuple({builder.ConstantR0<int32>(0), input});
builder.While(condition, body, init);
auto computation = builder.Build().ConsumeValueOrDie();

Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/__init__.py
Expand Up @@ -21,6 +21,7 @@
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import bayesflow
from tensorflow.contrib import cloud
from tensorflow.contrib import cluster_resolver
from tensorflow.contrib import compiler
from tensorflow.contrib import copy_graph
from tensorflow.contrib import crf
Expand Down
1 change: 0 additions & 1 deletion tensorflow/contrib/signal/BUILD
Expand Up @@ -78,7 +78,6 @@ cuda_py_tests(

cuda_py_tests(
name = "window_ops_test",
size = "small",
srcs = ["python/kernel_tests/window_ops_test.py"],
additional_deps = [
":signal_py",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/contrib/signal/__init__.py
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Signal processing operations.

See the @{$python/contrib.signal} guide.

@@frame
@@hamming_window
@@hann_window
Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/timeseries/python/timeseries/BUILD
Expand Up @@ -103,6 +103,7 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_pip_gpu", # b/63391119
"nomsan", # Takes too long to run.
],
deps = [
":ar_model",
Expand Down
1 change: 0 additions & 1 deletion tensorflow/core/BUILD
Expand Up @@ -2972,7 +2972,6 @@ cc_test(
srcs = ["example/example_parser_configuration_test.cc"],
data = [":example_parser_configuration_testdata"],
deps = [
":core",
":core_cpu",
":core_cpu_internal",
":direct_session_internal",
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/common_runtime/executor.cc
Expand Up @@ -1720,8 +1720,8 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
// Only merge and transfer nodes can have no-value inputs.
if (!entry->has_value) {
if (!is_merge) {
DCHECK(IsTransferNode(node));
DCHECK(!entry->val_field_is_set);
DCHECK(IsTransferNode(node)) << node->name() << " - input " << i;
DCHECK(!entry->val_field_is_set) << node->name() << " - input " << i;
entry->has_value = true;
entry->val_field_is_set = true;
entry->val.Init(*kEmptyTensor);
Expand Down