TensorFlow: upstream latest changes to git.
Change 109537918
	TensorFlow pip setup: wheel >= 0.26 for python3 pip install
Change 109505848
	Fix distortion default value to 1.0 in fixed_unigram_candidate_sampler. This means we now default to the actual provided unigram distribution instead of to the uniform distribution (the previous behavior).
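	(For context, a note on the parameter's documented semantics: with
	distortion d the sampler picks class i with probability proportional
	to count_i^d, so d = 1.0 reproduces the supplied unigram counts
	as-is, while d = 0.0 flattens them to a uniform distribution.)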
Change 109470494
	Bugfix in gradient calculation when the ys depend on each other.
Change 109467619
	Fix CIFAR-10 model to train on all the training data instead of just 80% of it. Fixes #396.
Change 109467557
	Replaced checkpoint file with binary GraphDef.
Change 109467433
	Updates to C++ tutorial section.
Change 109465269
	TensorFlow: update documentation for tutorials to not assume use of bazel
	(when possible).
Change 109462916
	A tutorial for image recognition to coincide with the release of the latest Inception image classification model.
Change 109462342
	Clear control dependencies in variable_scope.get_variable() when creating
	ops for the initializer.

	Add tests of various error conditions.
Change 109461981
	Various performance improvements in low-level node execution code paths.

	Speeds up ptb_word_lm on my desktop with a Titan X from
	3638 words per second to 3751 words per second (3.1% speedup).

	Changes include:

	o Avoided many strcmp operations per node execution, and extra
	touches of cache lines in executor.cc, by basing the various
	IsMerge, IsSwitch, IsSend, etc. checks on an internal enum value
	that is pre-computed at Node construction time, rather than doing
	string comparisons against node->type_string().  We were doing
	about 6 such comparisons per executed node.
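
	A distilled sketch of that pattern (names abridged for illustration;
	the real enum setup added by this change is in the graph.cc hunk
	below):

	  #include <string>

	  enum NodeClass { NC_SWITCH, NC_MERGE, NC_SEND, NC_OTHER };  // abridged

	  class Node {
	   public:
	    // One integer compare per check, instead of a strcmp.
	    bool IsSwitch() const { return class_ == NC_SWITCH; }

	    void Initialize(const std::string& type_string) {
	      // Pay the string comparisons once, at construction time.
	      class_ = (type_string == "Switch" || type_string == "RefSwitch")
	                   ? NC_SWITCH
	                   : NC_OTHER;  // other classes elided
	    }

	   private:
	    NodeClass class_ = NC_OTHER;
	  };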

	o Removed mutex_lock in executor.cc in ExecutorState::Process.  The
	lock was not needed and the comment about the iterations array being
	potentially resized is not true (the iterations arrays are created
	with a fixed size).  Checked with yuanbyu to confirm this.

	o Added new two-argument port::Tracing::ScopedAnnotation constructor
	that takes two StringPiece arguments, and only concatenates them
	lazily if tracing is enabled.  Also changed the code in
	platform/tracing.{h,cc} so that the ScopedAnnotation constructor and
	the TraceMe constructor can be inlined.
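
	A self-contained sketch of the lazy-concatenation idea (illustrative
	only; the tracing_enabled flag and annotation_ member are stand-ins,
	not the exact platform/tracing API):

	  #include <atomic>
	  #include <string>

	  static std::atomic<bool> tracing_enabled{false};  // stand-in for the real check

	  class ScopedAnnotation {
	   public:
	    ScopedAnnotation(const std::string& part1, const std::string& part2) {
	      // Only build the "name:type" string if tracing is actually on.
	      if (tracing_enabled.load(std::memory_order_relaxed)) {
	        annotation_ = part1 + ":" + part2;
	      }
	    }
	   private:
	    std::string annotation_;
	  };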

	o In BaseGPUDevice::Compute, used the two-argument ScopedAnnotation
	constructor to avoid doing StrCat(op_kernel->name(), ":",
	op_kernel->type_string()) on every node execution on a GPU.

	o Introduced a new TensorReference class that just holds a reference to an
	underlying TensorBuffer, and requires an explicit Unref().
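
	Typical usage, mirroring the gpu_util.cc changes below:

	  // Keep the buffer alive across an asynchronous operation:
	  TensorReference ref(tensor);   // takes a ref on the underlying TensorBuffer
	  event_mgr->ThenExecute(stream, [ref]() {
	    // ... the buffer is guaranteed to still be alive here ...
	    ref.Unref();                 // required, or the tensor memory leaks
	  });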

	o Changed the EventMgr interface to take a vector of TensorReference objects
	for EventMgr::ThenDeleteTensors, rather than a vector of Tensor objects.

	o Used TensorReference in a few places in gpu_util.cc.

	o Minor: switched to using InlinedVectors in a few places to get better
	cache locality.
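
	For example, from the executor.cc hunk below:

	  gtl::InlinedVector<IterationState*, 12> iterations;

	The first 12 elements are stored inline in the containing object, so
	small vectors avoid a separate heap allocation and stay on nearby
	cache lines.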
Change 109456692
	Updated the label_image example to use the latest Inception model.
Change 109456545
	Provides classify_image, which performs image recognition on a 1000-object label set.

	  $ ./classify_image
	  giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (score = 0.88493)
	  indri, indris, Indri indri, Indri brevicaudatus (score = 0.00878)
	  lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (score = 0.00317)
	  custard apple (score = 0.00149)
	  earthstar (score = 0.00127)

Change 109455002
	TensorFlow: make the helper libraries for various models available
	in the pip package so that when users type:

	python translate.py ...

	the absolute import works.

	This change is supposed to help make our tutorials run without the
	*need* to use bazel.
Change 109450041
	TensorFlow: remove cifar and convolutional binary copies from pip install.
	Adds embedding and some other models to the list.
Change 109448520
	Move the description of a failing invariant from a comment into the DCHECK-fail message text.
Change 109447577
	TensorBoard now has release tagging (tensorboard/TAG).
	Also track TensorBoard changes (tensorboard/CHANGES).
Change 109444161
	Added ParseSingleSequenceExample + python wrappers + unit tests.
Change 109440864
	Update all the TensorFlow Dockerfiles, and simplify GPU containers.

	This change updates all four of our Dockerfiles to match the targets
	discussed in #149. The most notable change here is moving the GPU
	images to use the NVIDIA containers, which include cudnn and other
	build-time dependencies, dramatically simplifying both the build and
	run steps.

	A description of which tags exist and get pushed where will be in a follow-up.
Change 109432591
	Some pylint and pydoc changes in saver.
Change 109430127
	Remove unused hydrogen components.
Change 109419354
	The RNN API, although moved into python/ops/, remains undocumented.

	It may still change at any time.

Base CL: 109538006
Vijay Vasudevan committed Dec 6, 2015
1 parent 40d0d29 commit f9d3e9d
Showing 70 changed files with 2,784 additions and 1,274 deletions.
25 changes: 15 additions & 10 deletions tensorflow/core/common_runtime/executor.cc
@@ -494,7 +494,7 @@ class ExecutorState {
     int max_parallel_iterations = 1;
 
     // The iteration states of this frame.
-    std::vector<IterationState*> iterations;
+    gtl::InlinedVector<IterationState*, 12> iterations;
 
     // The NextIteration nodes to enter a new iteration. If the number of
     // outstanding iterations reaches the limit, we will defer the start of
@@ -672,6 +672,16 @@ class ExecutorState {
 
   // One thread of control finishes.
   void Finish();
+
+  // A standalone routine for this expression so that we can express
+  // that we don't want thread safety analysis on this reference (it's
+  // safe to do without the lock because the iterations array never
+  // resizes and this particular iteration's array element will not
+  // be changed out from under us because the iteration is still alive).
+  std::vector<Entry>* GetInputTensors(FrameState* input_frame, int64 input_iter)
+      const NO_THREAD_SAFETY_ANALYSIS {
+    return input_frame->GetIteration(input_iter)->input_tensors;
+  }
 };
 
 ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
@@ -891,13 +901,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
 
     VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def());
 
-    std::vector<Entry>* input_tensors;
-    {
-      // Need the lock because the iterations vector could be resized by
-      // another thread.
-      mutex_lock l(mu_);
-      input_tensors = input_frame->GetIteration(input_iter)->input_tensors;
-    }
+    std::vector<Entry>* input_tensors =
+        GetInputTensors(input_frame, input_iter);
     Entry* first_input = input_tensors->data() + item.input_start;
     outputs.clear();
     outputs.resize(node->num_outputs());
@@ -1081,9 +1086,9 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
 
   for (int i = 0; i < node->num_outputs(); ++i) {
     TensorValue val = ctx->release_output(i);
-    // Only Switch and Recv can generate new dead outputs.
     if (*ctx->is_output_dead() || val.tensor == nullptr) {
-      DCHECK(IsSwitch(node) || IsRecv(node));
+      DCHECK(IsSwitch(node) || IsRecv(node))
+          << "Only Switch and Recv can generate new dead outputs.";
     } else {
       Entry* out = &((*outputs)[i]);
       out->has_value = true;
22 changes: 11 additions & 11 deletions tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -279,9 +279,8 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
     context->SetStatus(errors::Internal(
         "Invalid synchronous 'Compute' on GPU for '_Recv' op"));
   } else {
-    const string label =
-        strings::StrCat(op_kernel->name(), ":", op_kernel->type_string());
-    port::Tracing::ScopedAnnotation annotation(label);
+    port::Tracing::ScopedAnnotation annotation(op_kernel->name(),
+                                               op_kernel->type_string());
 
     const auto num_streams = streams_.size();
     if (num_streams > 1) {
@@ -320,18 +319,19 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
     // Keep a copy of the inputs before Compute runs, in case they get
     // deleted. TODO(misard) this will be fixed when the tracking is
     // done right.
-    std::vector<Tensor>* tensor_refs = nullptr;
+    EventMgr::TensorReferenceVector* tensor_refs = nullptr;
     if (!FLAGS_brain_gpu_sync_every_op) {
-      tensor_refs = new std::vector<Tensor>;
-      tensor_refs->reserve(context->num_inputs() + context->num_outputs());
-      for (int ii = 0; ii < context->num_inputs(); ++ii) {
+      const int N_inputs = context->num_inputs();
+      tensor_refs = new EventMgr::TensorReferenceVector;
+      tensor_refs->reserve(N_inputs + context->num_outputs());
+      for (int ii = 0; ii < N_inputs; ++ii) {
         if (context->has_input(ii)) {
           if (IsRefType(context->input_dtype(ii))) {
             Tensor in = context->mutable_input(ii, false);
-            tensor_refs->push_back(in);
+            tensor_refs->push_back(TensorReference(in));
           } else {
             const Tensor& in = context->input(ii);
-            tensor_refs->push_back(in);
+            tensor_refs->push_back(TensorReference(in));
           }
         }
       }
@@ -353,12 +353,12 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
       for (int ii = 0; ii < context->num_temps(); ++ii) {
         Tensor* temp = context->temp(ii);
         VLOG(2) << "Saving ref to temp Tensor @ " << DMAHelper::base(temp);
-        tensor_refs->push_back(*temp);
+        tensor_refs->push_back(TensorReference(*temp));
       }
       for (int ii = 0; ii < context->num_outputs(); ++ii) {
         Tensor* temp = context->mutable_output(ii);
         if (nullptr != temp) {
-          tensor_refs->push_back(*temp);
+          tensor_refs->push_back(TensorReference(*temp));
         }
       }
       em_->ThenDeleteTensors(stream, tensor_refs);
16 changes: 12 additions & 4 deletions tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <deque>
 #include <vector>
 #include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/core/framework/tensor_reference.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -45,10 +46,12 @@ class EventMgr {
 
   ~EventMgr();
 
+  typedef gtl::InlinedVector<TensorReference, 4> TensorReferenceVector;
+
   // Takes ownership of *tensors and deletes it as soon as all events
   // currently enqueued on *stream have completed.
   inline void ThenDeleteTensors(perftools::gputools::Stream* stream,
-                                std::vector<Tensor>* tensors) {
+                                TensorReferenceVector* tensors) {
     ToFreeVector to_free;
     {
       mutex_lock l(mu_);
@@ -94,7 +97,7 @@
 
   struct InUse {
     perftools::gputools::Event* event;
-    std::vector<Tensor>* mem;
+    TensorReferenceVector* mem;
     BufRec bufrec;
     std::function<void()> func;
   };
@@ -103,7 +106,12 @@
 
   void FreeMemory(const ToFreeVector& to_free) {
     for (const auto& iu : to_free) {
-      delete iu.mem;
+      if (iu.mem != nullptr) {
+        for (auto& t : *(iu.mem)) {
+          t.Unref();
+        }
+        delete iu.mem;
+      }
       if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
       // The function must be called in another thread.
       if (iu.func != nullptr) threadpool_.Schedule(iu.func);
@@ -118,7 +126,7 @@
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   void QueueTensors(perftools::gputools::Stream* stream,
-                    std::vector<Tensor>* tensors)
+                    TensorReferenceVector* tensors)
       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
     QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr});
   }
20 changes: 10 additions & 10 deletions tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -41,7 +41,7 @@ class TEST_EventMgrHelper {
   }
 
   void QueueTensors(perftools::gputools::Stream* stream,
-                    std::vector<Tensor>* tensors) {
+                    EventMgr::TensorReferenceVector* tensors) {
     mutex_lock l(em_->mu_);
     em_->QueueTensors(stream, tensors);
   }
@@ -77,12 +77,12 @@ TEST(EventMgr, DelayedPolling) {
   EventMgr em(stream_exec);
   TEST_EventMgrHelper th(&em);
   EXPECT_EQ(0, th.queue_size());
-  std::vector<Tensor>* v = nullptr;
+  EventMgr::TensorReferenceVector* v = nullptr;
   std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
   CHECK(stream.get());
   stream->Init();
   for (int i = 0; i < 5; ++i) {
-    v = new std::vector<Tensor>;
+    v = new EventMgr::TensorReferenceVector;
     th.QueueTensors(stream.get(), v);
     EXPECT_EQ(i + 1, th.queue_size());
     EXPECT_EQ(0, th.free_size());
@@ -92,7 +92,7 @@
   EXPECT_EQ(5, th.free_size());
   for (int j = 0; j < 2; ++j) {
     for (int i = 0; i < 5; ++i) {
-      v = new std::vector<Tensor>;
+      v = new EventMgr::TensorReferenceVector;
       th.QueueTensors(stream.get(), v);
       EXPECT_EQ(i + 1, th.queue_size());
       EXPECT_EQ(4 - i, th.free_size());
@@ -110,12 +110,12 @@
   TEST_EventMgrHelper th(&em);
   EXPECT_EQ(0, th.queue_size());
   EXPECT_EQ(0, th.free_size());
-  std::vector<Tensor>* v = nullptr;
+  EventMgr::TensorReferenceVector* v = nullptr;
   std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
   CHECK(stream.get());
   stream->Init();
   for (int i = 0; i < 5; ++i) {
-    v = new std::vector<Tensor>;
+    v = new EventMgr::TensorReferenceVector;
     em.ThenDeleteTensors(stream.get(), v);
     EXPECT_EQ(0, th.queue_size());
     EXPECT_EQ(1, th.free_size());
@@ -130,12 +130,12 @@
   TEST_EventMgrHelper th(&em);
   EXPECT_EQ(0, th.queue_size());
   EXPECT_EQ(0, th.free_size());
-  std::vector<Tensor>* v = nullptr;
+  EventMgr::TensorReferenceVector* v = nullptr;
   std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
   CHECK(stream.get());
   stream->Init();
   for (int i = 0; i < 5; ++i) {
-    v = new std::vector<Tensor>;
+    v = new EventMgr::TensorReferenceVector;
     th.QueueTensors(stream.get(), v);
     EXPECT_EQ(1 + i, th.queue_size());
     EXPECT_EQ(0, th.free_size());
@@ -153,12 +153,12 @@
   TEST_EventMgrHelper th(&em);
   EXPECT_EQ(0, th.queue_size());
   EXPECT_EQ(0, th.free_size());
-  std::vector<Tensor>* v = nullptr;
+  EventMgr::TensorReferenceVector* v = nullptr;
   std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
   CHECK(stream.get());
   stream->Init();
   for (int i = 0; i < 5; ++i) {
-    v = new std::vector<Tensor>;
+    v = new EventMgr::TensorReferenceVector;
     th.QueueTensors(stream.get(), v);
     EXPECT_EQ(1 + i, th.queue_size());
     EXPECT_EQ(0, th.free_size());
13 changes: 7 additions & 6 deletions tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/common_runtime/gpu/process_state.h"
 #include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/framework/tensor_reference.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/refcount.h"
@@ -91,7 +92,7 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
   DeviceMemoryBase gpu_src_ptr(const_cast<char*>(src_ptr), num_bytes);
   stream->ThenMemcpy(mb, gpu_src_ptr, num_bytes);
   // Use of tensor may outlive stack scope, so keep a ref.
-  Tensor* tensor_ref = new Tensor(tensor);
+  TensorReference tensor_ref(tensor);
   dev->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
       stream, [stream, done, proto, mb, num_bytes, alloc, tensor_ref]() {
         if (!stream->ok()) {
@@ -104,7 +105,7 @@
           LOG(FATAL) << "SetProtoFromGPU: GPU Memcpy failed";
           return;
         }
-        delete tensor_ref;
+        tensor_ref.Unref();
        port::CopyFromArray(proto->mutable_tensor_content(), mb, num_bytes);
        alloc->Deallocate<char>(mb);
        done(Status::OK());
@@ -169,10 +170,10 @@ void GPUUtil::CopyViaDMA(const string& edge_name,
                            total_bytes);
     if (dst_device_type == DeviceType(DEVICE_GPU).type()) {
       // Use of input may outlive stack scope, so keep a ref.
-      Tensor* input_ref = new Tensor(*input);
+      TensorReference input_ref(*input);
       src_dev_info->event_mgr->ThenExecute(
           stream, [done, stream, input_ref]() {
-            delete input_ref;
+            input_ref.Unref();
            if (!stream->ok()) {
              done(errors::Internal("GPU->GPU Memcpy failed"));
            } else {
@@ -262,9 +263,9 @@
   stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
   auto* dev_info = gpu_device->tensorflow_gpu_device_info();
   // Use of cpu_tensor may outlive stack scope, so keep a ref.
-  Tensor* input_ref = new Tensor(*cpu_tensor);
+  TensorReference input_ref(*cpu_tensor);
   dev_info->event_mgr->ThenExecute(stream, [stream, done, input_ref]() {
-    delete input_ref;
+    input_ref.Unref();
     if (!stream->ok()) {
       done(errors::Internal("CopyCPUTensorToGPU: GPU Memcpy failed"));
     } else {
47 changes: 47 additions & 0 deletions tensorflow/core/framework/tensor_reference.h
@@ -0,0 +1,47 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
+
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+// An opaque class that holds a reference to an underlying TensorBuffer.
+// Unlike Tensor, it does not have any shape or type information, so
+// it is cheaper to construct/move, but the only thing you can really do
+// with it is Unref it, which releases one of the references to the underlying
+// TensorBuffer.
+// IMPORTANT: If you do not call Unref(), you will likely leak tensor memory.
+class TensorReference {
+ public:
+  explicit TensorReference(const Tensor& tensor) : buf_(tensor.buf_) {
+    if (buf_) buf_->Ref();
+  }
+
+  ~TensorReference() {}
+
+  void Unref() const {
+    if (buf_) buf_->Unref();
+  }
+
+ private:
+  TensorBuffer* buf_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_TENSOR_REFERENCE_H_
36 changes: 35 additions & 1 deletion tensorflow/core/graph/graph.cc
@@ -45,7 +45,11 @@ string Node::DebugString() const {
 }
 
 Node::Node()
-    : id_(-1), cost_id_(-1), props_(nullptr), assigned_device_name_() {}
+    : id_(-1),
+      cost_id_(-1),
+      class_(NC_UNINITIALIZED),
+      props_(nullptr),
+      assigned_device_name_() {}
 
 Node::~Node() {
   if (props_) {
@@ -65,13 +69,43 @@ void Node::Initialize(int id, int cost_id, Properties* props) {
     props_->Unref();
   }
   props_ = props;
+  // Initialize the class_ based on the type string
+  const string& ts = this->type_string();
+  class_ = NC_UNINITIALIZED;
+
+#define SET_CLASS(enum_val, ts, str1, str2)        \
+  do {                                             \
+    if ((((ts) == (str1)) || ((ts) == (str2)))) {  \
+      /* Cannot be member of more than one class*/ \
+      CHECK(class_ == NC_UNINITIALIZED);           \
+      class_ = (enum_val);                         \
+    }                                              \
+  } while (0)
+
+  SET_CLASS(NC_SWITCH, ts, "Switch", "RefSwitch");
+  SET_CLASS(NC_MERGE, ts, "Merge", "");
+  SET_CLASS(NC_ENTER, ts, "Enter", "RefEnter");
+  SET_CLASS(NC_EXIT, ts, "Exit", "");
+  SET_CLASS(NC_NEXT_ITERATION, ts, "NextIteration", "");
+  SET_CLASS(NC_LOOP_COND, ts, "LoopCond", "");
+  SET_CLASS(NC_CONTROL_TRIGGER, ts, "ControlTrigger", "");
+  SET_CLASS(NC_SEND, ts, "_Send", "_HostSend");
+  SET_CLASS(NC_RECV, ts, "_Recv", "_HostRecv");
+  SET_CLASS(NC_CONSTANT, ts, "Const", "HostConst");
+  SET_CLASS(NC_VARIABLE, ts, "Variable", "");
+  SET_CLASS(NC_IDENTITY, ts, "Identity", "RefIdentity");
+  if (class_ == NC_UNINITIALIZED) {
+    class_ = NC_OTHER;  // Catch all
+  }
+#undef SET_CLASS
 }
 
 void Node::Clear() {
   in_edges_.clear();
   out_edges_.clear();
   id_ = -1;
   cost_id_ = -1;
+  class_ = NC_UNINITIALIZED;
 
   if (props_) {
     props_->Unref();
(Diff truncated: the remaining 63 of the 70 changed files are not shown.)