Add multi-algorithm deterministic cuDNN convolutions #34951

Merged
82 changes: 55 additions & 27 deletions tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"

@@ -309,6 +310,35 @@ StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
return result_or;
}

// The following function allows deterministic ops to be implemented relatively
// quickly using environment variables. It is intended to be temporary. The
// longer-term intention is to enable deterministic ops via tf.config and
// appropriate plumbing. See the discussion on PR 34951 for more information:
// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
// This function and associated comment are replicated in the following three
// places:
// 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
// 2. tensorflow/core/kernels/gpu_utils.cc
// 3. tensorflow/stream_executor/cuda/cuda_dnn.cc
// When implementing the plumbing, you should also search for the use of
// TF_DETERMINISTIC_OPS on its own.
// TODO(duncanriach): move to an API that uses tf.config and implement the first
// phase of plumbing.
bool RequireCudnnDeterminism() {
static bool require_cudnn_determinism = [] {
bool deterministic_ops = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
/*default_val=*/false,
&deterministic_ops));
bool cudnn_deterministic = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
/*default_val=*/false,
&cudnn_deterministic));
return deterministic_ops || cudnn_deterministic;
}();
return require_cudnn_determinism;
}
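
For reference, a minimal Python sketch of what this gate computes (illustrative only, not part of the diff; it treats '1' and 'true' as the enabling values, mirroring how ReadBoolFromEnvVar is used here):

import os

def require_cudnn_determinism():
  # Either variable enables deterministic behavior; both default to off.
  def _enabled(name):
    return os.environ.get(name, 'false').lower() in ('1', 'true')
  return (_enabled('TF_DETERMINISTIC_OPS') or
          _enabled('TF_CUDNN_DETERMINISTIC'))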

StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
@@ -536,43 +566,41 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
}
}

// For now, we ignore WRONG_RESULT failures because false-positives are
// possible (e.g. perhaps the reference algorithm is the one that's
// incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
// quite severe and can be detected with high accuracy.
auto has_failure = [](const AutotuneResult& r) {
return r.has_failure() &&
r.failure().kind() != AutotuneResult::WRONG_RESULT;
};

// Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED
// error.
//
// TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs
// in the output of the conv and skip those.
//
// The successful one should have a smaller key, since we are doing
// min_element. If they are both unsuccessful, keep the earlier one in
// the vector by comparing pointers.
auto result_comparison_key = [&has_failure](const AutotuneResult& r) {
return std::make_tuple(
has_failure(r),
tensorflow::proto_utils::FromDurationProto(r.run_time()));
};
const auto& best_result = absl::c_min_element(
profile_results,
[&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return result_comparison_key(lhs) < result_comparison_key(rhs);
// For now, we ignore WRONG_RESULT failures because false-positives are
// possible (e.g. perhaps the reference algorithm is the one that's
// incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
// quite severe and can be detected with high accuracy.
std::vector<AutotuneResult> filtered_results;
absl::c_copy_if(
profile_results, std::back_inserter(filtered_results),
[](const AutotuneResult& r) {
return !(r.has_failure() &&
r.failure().kind() != AutotuneResult::WRONG_RESULT);
});
if (filtered_results.empty()) {
return InternalError(
"All algorithms tried for convolution %s failed. Falling back to "
"default algorithm. ",
instr->ToString());
}

if (best_result != profile_results.end() && !has_failure(*best_result)) {
return *best_result;
auto selected_result = filtered_results.begin();
if (!RequireCudnnDeterminism()) {
selected_result = absl::c_min_element(
filtered_results,
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
tensorflow::proto_utils::FromDurationProto(rhs.run_time());
});
}

return InternalError(
"All algorithms tried for convolution %s failed. Falling back to "
"default algorithm.",
instr->ToString());
return *selected_result;
}
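
The effect of the new selection logic can be summarized in a short sketch (illustrative only; the attribute names stand in for the AutotuneResult proto fields read above):

def pick_conv_algorithm(profile_results, require_cudnn_determinism):
  # Tolerate WRONG_RESULT (the reference algorithm itself may be the wrong
  # one); drop results with any other failure, e.g. REDZONE_MODIFIED.
  ok = [r for r in profile_results
        if not (r.has_failure and r.failure_kind != 'WRONG_RESULT')]
  if not ok:
    raise RuntimeError('All algorithms tried for convolution failed.')
  if require_cudnn_determinism:
    return ok[0]  # first surviving algorithm: reproducible across runs
  return min(ok, key=lambda r: r.run_time)  # otherwise the fastest

Taking the first surviving entry ties the choice to cuDNN's algorithm enumeration order rather than to run-to-run timing noise, which is what makes it reproducible.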

75 changes: 53 additions & 22 deletions tensorflow/core/kernels/gpu_utils.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
@@ -211,6 +212,35 @@ void LogFusedConvForwardAutotuneResults(
Logger::GetSingleton()->LogProto(log);
}

// The following function allows deterministic ops to be implemented relatively
// quickly using environment variables. It is intended to be temporary. The
// longer-term intention is to enable deterministic ops via tf.config and
// appropriate plumbing. See the discussion on PR 34951 for more information:
// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
// This function and associated comment are replicated in the following three
// places:
// 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
// 2. tensorflow/core/kernels/gpu_utils.cc
// 3. tensorflow/stream_executor/cuda/cuda_dnn.cc
// When implementing the plumbing, you should also search for the use of
// TF_DETERMINISTIC_OPS on its own.
// TODO(duncanriach): move to an API that uses tf.config and implement the first
// phase of plumbing.
bool RequireCudnnDeterminism() {
static bool require_cudnn_determinism = [] {
bool deterministic_ops = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
/*default_val=*/false,
&deterministic_ops));
bool cudnn_deterministic = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
/*default_val=*/false,
&cudnn_deterministic));
return deterministic_ops || cudnn_deterministic;
}();
return require_cudnn_determinism;
}

Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
se::dnn::AlgorithmConfig* algo) {
std::vector<AutotuneResult> filtered_results;
@@ -220,31 +250,32 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
if (filtered_results.empty()) {
return errors::NotFound("No algorithm worked!");
}
std::vector<AutotuneResult> filtered_results_no_scratch;
absl::c_copy_if(
filtered_results, std::back_inserter(filtered_results_no_scratch),
[](const AutotuneResult& result) { return result.scratch_bytes() == 0; });

auto selected_result = filtered_results.begin();
auto selected_result_no_scratch = filtered_results_no_scratch.begin();
if (!RequireCudnnDeterminism()) {
auto compare_run_times = [](const AutotuneResult& lhs,
const AutotuneResult& rhs) {
return proto_utils::FromDurationProto(lhs.run_time()) <
proto_utils::FromDurationProto(rhs.run_time());
};
selected_result = absl::c_min_element(filtered_results, compare_run_times);
selected_result_no_scratch = absl::c_min_element(
filtered_results_no_scratch, compare_run_times);
}

const auto best_result = absl::c_min_element(
filtered_results,
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return proto_utils::FromDurationProto(lhs.run_time()) <
proto_utils::FromDurationProto(rhs.run_time());
});

const auto best_result_no_scratch = absl::c_min_element(
filtered_results,
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return std::make_tuple(lhs.scratch_bytes(),
proto_utils::FromDurationProto(lhs.run_time())) <
std::make_tuple(rhs.scratch_bytes(),
proto_utils::FromDurationProto(rhs.run_time()));
});

algo->set_algorithm({best_result->conv().algorithm(),
best_result->conv().tensor_ops_enabled()});
if (best_result_no_scratch != filtered_results.end() &&
best_result_no_scratch->scratch_bytes() == 0) {
algo->set_algorithm({selected_result->conv().algorithm(),
selected_result->conv().tensor_ops_enabled()});
if (selected_result_no_scratch != filtered_results_no_scratch.end()) {
algo->set_algorithm_no_scratch(
{best_result_no_scratch->conv().algorithm(),
best_result_no_scratch->conv().tensor_ops_enabled()});
{selected_result_no_scratch->conv().algorithm(),
selected_result_no_scratch->conv().tensor_ops_enabled()});
}

return Status::OK();
}
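
A matching sketch for the zero-scratch choice (same simplified records as the sketch above, plus an assumed scratch_bytes attribute):

def pick_no_scratch_algorithm(filtered_results, require_cudnn_determinism):
  no_scratch = [r for r in filtered_results if r.scratch_bytes == 0]
  if not no_scratch:
    return None  # the no-scratch algorithm is then simply left unset
  if require_cudnn_determinism:
    return no_scratch[0]  # first zero-scratch candidate, reproducibly
  return min(no_scratch, key=lambda r: r.run_time)

Building the zero-scratch subset up front (rather than comparing scratch_bytes inside the comparator, as the removed code did) lets deterministic mode take that subset's first element directly.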

95 changes: 66 additions & 29 deletions tensorflow/python/kernel_tests/cudnn_deterministic_base.py
@@ -27,22 +27,39 @@
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test

# Setting either of the two environment variables TF_CUDNN_DETERMINISTIC or
# TF_DETERMINISTIC_OPS to "true" or "1" will disable autotuning of cuDNN
# algorithms and cause deterministic cuDNN algorithms to be selected when both
# deterministic and non-deterministic algorithms are available. These tests are
# intended to confirm that deterministic algorithms are chosen when either
# environment variable is set to "true" or "1". The tested configurations were
# first confirmed to produce non-deterministic results when the environment
# variables are not set.
# Notes:
#
# Deterministic cuDNN operation is selected by setting either of the two
# environment variables TF_CUDNN_DETERMINISTIC or TF_DETERMINISTIC_OPS to 'true'
# or '1' while also not setting the environment variable TF_CUDNN_USE_AUTOTUNE
# to 'false' or '0'.
#
# Where both deterministic and non-deterministic cuDNN algorithms are available,
# selecting deterministic operation will lead to only the deterministic
# algorithms being chosen. Additionally, selecting deterministic operation will
# result in a deterministic, or reproducible, selection of algorithms (for any
# given layer configuration) for each of the forward and the two backward paths.
#
# These tests intend to confirm that deterministic algorithms are chosen (for
# the back-prop paths) when deterministic operation is selected. The tested
# configurations were first confirmed to produce non-deterministic results when
# the above-mentioned environment variables are not set.
#
# Even though selecting deterministic operation should ensure that the same
# algorithms, for a given layer configuration, are always used (i.e. that
# algorithm selection is deterministic / reproducible), this is not tested.
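
A usage sketch for the deterministic configuration these tests exercise (the variable names come from the notes above; since the variables are read once at first use, setting them before the TensorFlow import is the simplest safe pattern):

import os
os.environ['TF_DETERMINISTIC_OPS'] = '1'      # or 'true'
# os.environ['TF_CUDNN_DETERMINISTIC'] = '1'  # equivalent alternative
# Leave TF_CUDNN_USE_AUTOTUNE unset (or anything other than 'false'/'0').
import tensorflow as tf  # import after configuring the environment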

# TODO(duncanriach): Add test for deterministic cuDNN max-pooling

_PADDING = 'SAME'
_STRIDES = [1, 1, 1, 1]
LayerShapeNHWC = collections.namedtuple('LayerShapeNHWC',
'batch, height, width, channels')
FilterShape2D = collections.namedtuple(
'FilterShape2D', 'height, width, in_channels, out_channels')

LayerShape = collections.namedtuple('LayerShape',
'batch, height, width, channels')
FilterShape = collections.namedtuple(
'FilterShape', 'height, width, in_channels, out_channels')
LayerShapeNCDHW = collections.namedtuple(
'LayerShapeNCDHW', 'batch, channels, depth, height, width')
FilterShape3D = collections.namedtuple(
'FilterShape3D', 'depth, height, width, in_channels, out_channels')


class ConvolutionTest(test.TestCase):
@@ -53,14 +70,14 @@ def _random_data_op(self, shape):
return constant_op.constant(
2 * np.random.random_sample(shape) - 1, dtype=dtypes.float32)

def _random_out_op(self, in_shape, filter_shape):
def _random_out_op(self, in_shape, filter_shape, strides, padding):
# Choosing not to use array_op.zeros() to prevent possible removal by
# optimization
in_op = self._random_data_op(in_shape)
filter_op = self._random_data_op(filter_shape)
# Use the forward op's shape-inference
conv_op = nn_ops.conv2d(
in_op, filter_op, strides=_STRIDES, padding=_PADDING)
in_op, filter_op, strides=strides, padding=padding)
out_shape = conv_op.get_shape()
out_op = self._random_data_op(out_shape)
return out_op
@@ -71,29 +88,49 @@ def _assert_reproducible(self, operation):
result_2 = self.evaluate(operation)
self.assertAllEqual(result_1, result_2)

# The default forward algorithm choice, when using cuDNN 7, does not support
# the following layer configuration. This test case intends to confirm that
# an alternative algorithm is selected. Note that, in cuDNN 7, all forward
# algorithms are deterministic.
@test_util.run_cuda_only
def testForward(self):
np.random.seed(3)
in_shape = LayerShapeNCDHW(batch=2, channels=3, depth=5, height=7, width=6)
filter_shape = FilterShape3D(depth=3, height=3, width=3, in_channels=3,
out_channels=2)
in_op = self._random_data_op(in_shape)
filter_op = self._random_data_op(filter_shape)
strides = [1, 1, 1, 1, 1]
padding = 'VALID'
dilations = [1, 1, 2, 2, 2]
out_op = nn_ops.conv3d(in_op, filter_op, strides=strides, padding=padding,
data_format='NCDHW', dilations=dilations)
self._assert_reproducible(out_op)

@test_util.run_cuda_only
def testBackwardFilterGradient(self):
np.random.seed(1)
in_shape = LayerShape(batch=8, height=128, width=128, channels=8)
filter_shape = FilterShape(height=3, width=3, in_channels=8, out_channels=8)
in_shape = LayerShapeNHWC(batch=8, height=128, width=128, channels=8)
filter_shape = FilterShape2D(height=3, width=3, in_channels=8,
out_channels=8)
in_op = self._random_data_op(in_shape)
out_op = self._random_out_op(in_shape, filter_shape)
strides = [1, 1, 1, 1]
padding = 'SAME'
out_op = self._random_out_op(in_shape, filter_shape, strides, padding)
filter_gradient_op = nn_ops.conv2d_backprop_filter(
in_op, filter_shape, out_op, strides=_STRIDES, padding=_PADDING)
in_op, filter_shape, out_op, strides=strides, padding=padding)
self._assert_reproducible(filter_gradient_op)

@test_util.run_cuda_only
def testBackwardInputGradient(self):
np.random.seed(2)
in_shape = LayerShape(batch=8, height=32, width=32, channels=8)
filter_shape = FilterShape(
height=7, width=7, in_channels=8, out_channels=128)
in_shape = LayerShapeNHWC(batch=8, height=32, width=32, channels=8)
filter_shape = FilterShape2D(height=7, width=7, in_channels=8,
out_channels=128)
filter_op = self._random_data_op(filter_shape)
out_op = self._random_out_op(in_shape, filter_shape)
strides = [1, 1, 1, 1]
padding = 'SAME'
out_op = self._random_out_op(in_shape, filter_shape, strides, padding)
input_gradient_op = nn_ops.conv2d_backprop_input(
in_shape, filter_op, out_op, strides=_STRIDES, padding=_PADDING)
in_shape, filter_op, out_op, strides=strides, padding=padding)
self._assert_reproducible(input_gradient_op)

# TODO(duncanriach): (1) add test to confirm that forward autotuning is
# disabled for cuDNN convolution; (2) add test for deterministic cuDNN
# max-pooling