TensorFlow: Upstream changes to git.
Change 109730179
	Add support for selecting a partition strategy in tf.nn.embedding_lookup and related ops, and allow unequally sized shards to be used as input (an illustrative usage sketch follows this commit message).
Change 109729548
	TensorFlow: add RELEASE.md notes for 0.6.0.
Change 109728185
	Make seq2seq_test non-flaky by setting the Python and NumPy random seeds.
Change 109725913
	Refactor slot creation in optimizers and moving averages into a separate file.
Change 109718024
	TensorFlow: reduce runtime of seq2seq_test from ~30s to ~18s.
Change 109712251
	More performance improvement for convnet on GPU.
	+ Switch forward convolution format to NCHW.
	+ Allocate scratch space for forward and backward convolutions.
	+ Users can set "TF_CUDNN_WORKSPACE_LIMIT_IN_MB" to configure the scratch space
	limit. The default limit is 1 GB.
Change 109710898
	Added extract_sub_graph utility function

Base CL: 109731609
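Change 109730179 adds a partition_strategy argument to tf.nn.embedding_lookup and accepts shards of unequal size. Below is a minimal, hedged sketch of how the new argument might be called; the shard sizes, ids, and embedding width are illustrative assumptions, not taken from this commit.

```python
# Hedged illustration (not part of this commit's diff): looking up ids in
# unequally sized shards with an explicit partition strategy.
import tensorflow as tf

# Ten embedding rows split across shards of size 4, 3 and 3.
shards = [
    tf.Variable(tf.random_uniform([4, 8])),
    tf.Variable(tf.random_uniform([3, 8])),
    tf.Variable(tf.random_uniform([3, 8])),
]
ids = tf.constant([0, 5, 9])

# With the default "mod" strategy, id i is looked up in shard i % 3;
# "div" assigns contiguous id ranges to shards instead.
embedded = tf.nn.embedding_lookup(shards, ids, partition_strategy="mod")
```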
Vijay Vasudevan committed Dec 8, 2015
1 parent ddd4aaf commit 2c3738d
Showing 14 changed files with 764 additions and 138 deletions.
27 changes: 27 additions & 0 deletions RELEASE.md
@@ -1,3 +1,30 @@
# Release 0.6.0

## Major Features and Improvements

* Python 3.3+ support via changes to the Python codebase, and the ability
to specify the Python version via ./configure.

* Some improvements to GPU performance and memory usage: the
[convnet benchmarks](https://github.com/soumith/convnet-benchmarks/issues/66)
are now roughly equivalent to native cuDNN v2 performance. The improvements are
mostly due to moving to 32-bit indices and faster shuffling kernels. More
improvements to come in later releases.


## Bug fixes

* Lots of fixes to documentation and tutorials, many contributed
by the public.

* 271 issues closed on GitHub.

## Backwards-incompatible changes

* tf.nn.fixed_unigram_candidate_sampler changed its default 'distortion'
attribute from 0.0 to 1.0. This was a bug in the original release that is
now fixed (an illustrative sketch follows these notes).

# Release 0.5.0

Initial release of TensorFlow.
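For the fixed_unigram_candidate_sampler change noted above, a hedged sketch of passing distortion explicitly so code behaves the same under the old 0.0 default and the new 1.0 default; the unigram counts and sampler sizes are made-up illustration values, and the keyword names follow the public candidate-sampling API rather than this diff.

```python
# Hedged illustration: pin distortion explicitly to be independent of the
# default, which changed from 0.0 to 1.0 in 0.6.0. A distortion of 1.0 uses
# the raw unigram counts; 0.0 flattens them to a uniform distribution.
import tensorflow as tf

true_classes = tf.constant([[1], [3]], dtype=tf.int64)
sampled, true_expected, sampled_expected = tf.nn.fixed_unigram_candidate_sampler(
    true_classes=true_classes,
    num_true=1,
    num_sampled=4,
    unique=True,
    range_max=5,
    unigrams=[10.0, 8.0, 6.0, 4.0, 2.0],  # made-up counts for ids 0..4
    distortion=1.0,
)
```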
36 changes: 18 additions & 18 deletions tensorflow/core/kernels/conv_grad_ops.cc
@@ -35,6 +35,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#endif // GOOGLE_CUDA

namespace tensorflow {
@@ -756,17 +757,6 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")

// GPU definitions of both ops.
#if GOOGLE_CUDA
namespace {
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
uint64 size) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
size * sizeof(T));
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
}
} // namespace

// The slow version (but compiles for GPU)

// Backprop for input.
@@ -929,10 +919,15 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
pre_transformed_in_backprop.template flat<T>().size());

static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
bool cudnn_launch_status =
stream->ThenConvolveBackwardData(filter_desc, filter_ptr, output_desc,
out_backprop_ptr, conv_desc,
input_desc, &in_backprop_ptr)
stream->ThenConvolveBackwardDataWithScratch(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator)
.ok();

if (!cudnn_launch_status) {
@@ -1185,7 +1180,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
context->eigen_device<Device>(),
const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
transformed_input.tensor<T, 4>());

auto out_backprop_ptr =
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
transformed_out_backprop.template flat<T>().size());
@@ -1196,10 +1190,16 @@
AsDeviceMemory(transformed_input.template flat<T>().data(),
transformed_input.template flat<T>().size());

static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
context);
bool cudnn_launch_status =
stream->ThenConvolveBackwardFilter(input_desc, input_ptr, output_desc,
out_backprop_ptr, conv_desc,
filter_desc, &filter_backprop_ptr)
stream->ThenConvolveBackwardFilterWithScratch(
input_desc, input_ptr, output_desc, out_backprop_ptr,
conv_desc, filter_desc, &filter_backprop_ptr,
&scratch_allocator)
.ok();

if (!cudnn_launch_status) {
82 changes: 63 additions & 19 deletions tensorflow/core/kernels/conv_ops.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"
@@ -34,6 +35,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#endif // GOOGLE_CUDA

namespace tensorflow {
@@ -206,16 +208,22 @@ REGISTER_KERNEL_BUILDER(Name("Conv2D")

#if GOOGLE_CUDA

namespace {
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
uint64 size) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
size * sizeof(T));
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
if (strings::safe_strto64(workspace_limit_in_mb_str,
&scratch_limit_in_mb)) {
return scratch_limit_in_mb * (1 << 20);
} else {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
}
}
return default_value_in_bytes;
}
} // namespace

template <typename T>
struct LaunchConvOp<GPUDevice, T> {
@@ -287,18 +295,34 @@ struct LaunchConvOp<GPUDevice, T> {
input = transformed_input;
}

{
// Convert the input tensor from NHWC to NCHW.
Tensor transformed_input;
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({input.dim_size(0), input.dim_size(3),
input.dim_size(1), input.dim_size(2)}),
&transformed_input));
functor::NHWCToNCHW<GPUDevice, T>()(
ctx->eigen_device<GPUDevice>(),
const_cast<const Tensor&>(input).tensor<T, 4>(),
transformed_input.tensor<T, 4>());
input = transformed_input;
}

perftools::gputools::dnn::BatchDescriptor input_desc;
input_desc.set_count(input.dim_size(0))
.set_height(input.dim_size(1))
.set_width(input.dim_size(2))
.set_feature_map_count(input.dim_size(3))
.set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
.set_feature_map_count(input.dim_size(1))
.set_height(input.dim_size(2))
.set_width(input.dim_size(3))
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::BatchDescriptor output_desc;
output_desc.set_count(output->dim_size(0))
.set_height(output->dim_size(1))
.set_width(output->dim_size(2))
.set_feature_map_count(output->dim_size(3))
.set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(filter.dim_size(0))
.set_input_filter_width(filter.dim_size(1))
@@ -320,24 +344,44 @@ struct LaunchConvOp<GPUDevice, T> {
ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));

Tensor transformed_output;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({output->dim_size(0), output->dim_size(3),
output->dim_size(1), output->dim_size(2)}),
&transformed_output));

auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
input.template flat<T>().size());
auto filter_ptr =
AsDeviceMemory(transformed_filter.template flat<T>().data(),
transformed_filter.template flat<T>().size());
auto output_ptr = AsDeviceMemory(output->template flat<T>().data(),
output->template flat<T>().size());

auto output_ptr =
AsDeviceMemory(transformed_output.template flat<T>().data(),
transformed_output.template flat<T>().size());

static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
bool cudnn_launch_status =
stream->ThenConvolve(input_desc, input_ptr, filter_desc, filter_ptr,
conv_desc, output_desc, &output_ptr)
stream->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc,
filter_ptr, conv_desc, output_desc,
&output_ptr, &scratch_allocator)
.ok();

if (!cudnn_launch_status) {
ctx->SetStatus(errors::Internal(
"cuDNN launch failure : input shape(", input.shape().DebugString(),
") filter shape(", filter.shape().DebugString(), ")"));
}

// Convert the output tensor back from NHWC to NCHW.
functor::NCHWToNHWC<GPUDevice, T>()(
ctx->eigen_device<GPUDevice>(),
const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
output->tensor<T, 4>());
} else {
LaunchGeneric<GPUDevice, T>::launch(ctx, input_param, filter, stride,
padding, output);
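The conv_ops.cc hunk above switches the convolution to cuDNN's NCHW layout (kBatchDepthYX): the input is transposed from NHWC before ThenConvolveWithScratch, and the output is transposed back afterwards. As a layout illustration only, the same axis shuffle in plain NumPy; this is not the NHWCToNCHW/NCHWToNHWC functors used in the kernel.

```python
# Illustrative only: the NHWC <-> NCHW conversions performed around the
# cuDNN convolution call, expressed as NumPy transposes.
import numpy as np

nhwc = np.zeros((32, 28, 28, 3))           # batch, height, width, channels
nchw = nhwc.transpose(0, 3, 1, 2)          # batch, channels, height, width
assert nchw.shape == (32, 3, 28, 28)

back_to_nhwc = nchw.transpose(0, 2, 3, 1)  # NCHW -> NHWC, as done for the output
assert back_to_nhwc.shape == nhwc.shape
```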
84 changes: 84 additions & 0 deletions tensorflow/core/kernels/conv_ops_gpu.h
@@ -0,0 +1,84 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_

#if GOOGLE_CUDA

#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"

namespace tensorflow {

// TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is
// widespread.
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
uint64 size) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
size * sizeof(T));
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
}

// Get the Cudnn workspace limit from the environment variable, which is in MB.
// Return the workspace memory limit in bytes. If no value is set, return the
// default value.
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes);

// A class to provide scratch-space allocator for Stream-Executor Cudnn
// callback. TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
public:
virtual ~CudnnScratchAllocator() {}
CudnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
: memory_limit_(memory_limit), context_(context) {}
virtual int64 GetMemoryLimitInBytes(
perftools::gputools::Stream* stream) override {
return memory_limit_;
}
virtual perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>
AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override {
Tensor temporary_memory;

Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory));
if (!allocation_status.ok()) {
LOG(WARNING) << allocation_status;
context_->SetStatus(allocation_status);
return perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>();
}

return perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}

private:
int64 memory_limit_;
OpKernelContext* context_;
};

} // namespace tensorflow

#endif // GOOGLE_CUDA

#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
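GetCudnnWorkspaceLimit, declared in this header and defined in conv_ops.cc above, reads TF_CUDNN_WORKSPACE_LIMIT_IN_MB in megabytes and converts it to bytes, falling back to the caller-supplied default (1 GB in these kernels) when the variable is unset or unparsable. A hedged Python rendering of that lookup, for illustration only; the function name and behavior below mirror the C++ but are not part of the commit.

```python
# Hedged sketch of the workspace-limit lookup: MB from the environment,
# converted to bytes, with a fallback default.
import os

def cudnn_workspace_limit_bytes(default_bytes=1 << 30,
                                envvar="TF_CUDNN_WORKSPACE_LIMIT_IN_MB"):
    value = os.environ.get(envvar, "")
    if value:
        try:
            return int(value) * (1 << 20)  # MB -> bytes
        except ValueError:
            pass  # the C++ code logs a warning and falls through to the default
    return default_bytes

# e.g. exporting TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512 caps scratch space at 512 MB.
```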
62 changes: 62 additions & 0 deletions tensorflow/python/client/graph_util.py
@@ -19,6 +19,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy

import tensorflow.python.platform

@@ -155,3 +156,64 @@ def pin_to_cpu(op):
logging.info("Operation %s has been assigned to a non-CPU (%s), so "
"it will not be pinned to the CPU.", op.name, dev.device_type)
return device


def _node_name(n):
  if n.startswith("^"):
    return n[1:]
  else:
    return n.split(":")[0]


def extract_sub_graph(graph_def, dest_nodes):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    dest_nodes: A list of strings specifying the destination node names.

  Returns:
    The GraphDef of the sub-graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  edges = {}  # Keyed by the dest node name.
  name_to_node_map = {}  # Keyed by node name.

  # Keeps track of node sequences. It is important to still output the
  # operations in the original order.
  node_seq = {}  # Keyed by node name.
  seq = 0
  for node in graph_def.node:
    n = _node_name(node.name)
    name_to_node_map[n] = node
    edges[n] = [_node_name(x) for x in node.input]
    node_seq[n] = seq
    seq += 1

  for d in dest_nodes:
    assert d in name_to_node_map, "%s is not in graph" % d

  nodes_to_keep = set()
  # Breadth first search to find all the nodes that we should keep.
  next_to_visit = dest_nodes[:]
  while next_to_visit:
    n = next_to_visit[0]
    del next_to_visit[0]
    if n in nodes_to_keep:
      # Already visited this node.
      continue
    nodes_to_keep.add(n)
    next_to_visit += edges[n]

  nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n])
  # Now construct the output GraphDef.
  out = graph_pb2.GraphDef()
  for n in nodes_to_keep_list:
    out.node.extend([copy.deepcopy(name_to_node_map[n])])

  return out
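A hedged usage sketch for the new extract_sub_graph helper: build a small graph, then keep only the nodes that feed one output. The graph contents and node names below are placeholders for illustration; only graph_util.extract_sub_graph itself comes from this diff.

```python
# Hedged illustration: prune a GraphDef down to the ancestors of "logits".
import tensorflow as tf
from tensorflow.python.client import graph_util

with tf.Graph().as_default() as g:
  x = tf.placeholder(tf.float32, [None, 4], name="input")
  w = tf.Variable(tf.zeros([4, 2]), name="weights")
  logits = tf.matmul(x, w, name="logits")
  unused = tf.add(x, 1.0, name="unused_branch")  # pruned away below

sub_graph_def = graph_util.extract_sub_graph(g.as_graph_def(), ["logits"])
# sub_graph_def now contains only "logits" and the nodes it depends on;
# "unused_branch" is dropped.
```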
