diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
index bf175999c55412..95a5827bf44095 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/logger.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/proto/proto_utils.h"
 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"

@@ -309,6 +310,35 @@ StatusOr<tensorflow::AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
   return result_or;
 }

+// The following function allows deterministic ops to be implemented relatively
+// quickly using environment variables. It is intended to be temporary. The
+// longer-term intention is to enable deterministic ops via tf.config and
+// appropriate plumbing. See the discussion on PR 34951 for more information:
+// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
+// This function and associated comment are replicated in the following three
+// places:
+// 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+// 2. tensorflow/core/kernels/gpu_utils.cc
+// 3. tensorflow/stream_executor/cuda/cuda_dnn.cc
+// When implementing the plumbing, you should also search for the use of
+// TF_DETERMINISTIC_OPS on its own.
+// TODO(duncanriach): move to an API that uses tf.config and implement the first
+// phase of plumbing.
+bool RequireCudnnDeterminism() {
+  static bool require_cudnn_determinism = [] {
+    bool deterministic_ops = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
+                                               /*default_val=*/false,
+                                               &deterministic_ops));
+    bool cudnn_deterministic = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
+                                               /*default_val=*/false,
+                                               &cudnn_deterministic));
+    return deterministic_ops || cudnn_deterministic;
+  }();
+  return require_cudnn_determinism;
+}
+
 StatusOr<tensorflow::AutotuneResult>
 GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
     const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
@@ -536,43 +566,41 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
     }
   }

-  // For now, we ignore WRONG_RESULT failures because false-positives are
-  // possible (e.g. perhaps the reference algorithm is the one that's
-  // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
-  // quite severe and can be detected with high accuracy.
-  auto has_failure = [](const AutotuneResult& r) {
-    return r.has_failure() &&
-           r.failure().kind() != AutotuneResult::WRONG_RESULT;
-  };
-
   // Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED
   // error.
   //
   // TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs
   // in the output of the conv and skip those.
   //
-  // The successful one should have a smaller key, since we are doing
-  // min_element. If they are both unsuccessful, keep the earlier one in
-  // the vector by comparing pointers.
-  auto result_comparison_key = [&has_failure](const AutotuneResult& r) {
-    return std::make_tuple(
-        has_failure(r),
-        tensorflow::proto_utils::FromDurationProto(r.run_time()));
-  };
-  const auto& best_result = absl::c_min_element(
-      profile_results,
-      [&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return result_comparison_key(lhs) < result_comparison_key(rhs);
+  // For now, we ignore WRONG_RESULT failures because false-positives are
+  // possible (e.g. perhaps the reference algorithm is the one that's
+  // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
+  // quite severe and can be detected with high accuracy.
+  std::vector<AutotuneResult> filtered_results;
+  absl::c_copy_if(
+      profile_results, std::back_inserter(filtered_results),
+      [](const AutotuneResult& r) {
+        return !(r.has_failure() &&
+                 r.failure().kind() != AutotuneResult::WRONG_RESULT);
       });
+  if (filtered_results.empty()) {
+    return InternalError(
+        "All algorithms tried for convolution %s failed. Falling back to "
+        "default algorithm. ",
+        instr->ToString());
+  }

-  if (best_result != profile_results.end() && !has_failure(*best_result)) {
-    return *best_result;
+  auto selected_result = filtered_results.begin();
+  if (!RequireCudnnDeterminism()) {
+    selected_result = absl::c_min_element(
+        filtered_results,
+        [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
+          return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
+                 tensorflow::proto_utils::FromDurationProto(rhs.run_time());
+        });
   }

-  return InternalError(
-      "All algorithms tried for convolution %s failed. Falling back to "
-      "default algorithm.",
-      instr->ToString());
+  return *selected_result;
 }

 StatusOr<tensorflow::AutotuneResult>
diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc
index 4681c624eb4985..d62e6498376eae 100644
--- a/tensorflow/core/kernels/gpu_utils.cc
+++ b/tensorflow/core/kernels/gpu_utils.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/platform/logger.h"
 #include "tensorflow/core/protobuf/autotuning.pb.h"
 #include "tensorflow/core/protobuf/conv_autotuning.pb.h"
+#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/proto/proto_utils.h"
 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
@@ -211,6 +212,35 @@ void LogFusedConvForwardAutotuneResults(
   Logger::GetSingleton()->LogProto(log);
 }

+// The following function allows deterministic ops to be implemented relatively
+// quickly using environment variables. It is intended to be temporary. The
+// longer-term intention is to enable deterministic ops via tf.config and
+// appropriate plumbing. See the discussion on PR 34951 for more information:
+// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
+// This function and associated comment are replicated in the following three
+// places:
+// 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+// 2. tensorflow/core/kernels/gpu_utils.cc
+// 3. tensorflow/stream_executor/cuda/cuda_dnn.cc
+// When implementing the plumbing, you should also search for the use of
+// TF_DETERMINISTIC_OPS on its own.
+// TODO(duncanriach): move to an API that uses tf.config and implement the first
+// phase of plumbing.
+bool RequireCudnnDeterminism() {
+  static bool require_cudnn_determinism = [] {
+    bool deterministic_ops = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
+                                               /*default_val=*/false,
+                                               &deterministic_ops));
+    bool cudnn_deterministic = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
+                                               /*default_val=*/false,
+                                               &cudnn_deterministic));
+    return deterministic_ops || cudnn_deterministic;
+  }();
+  return require_cudnn_determinism;
+}
+
 Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
                               se::dnn::AlgorithmConfig* algo) {
   std::vector<AutotuneResult> filtered_results;
@@ -220,31 +250,32 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
   if (filtered_results.empty()) {
     return errors::NotFound("No algorithm worked!");
   }
+  std::vector<AutotuneResult> filtered_results_no_scratch;
+  absl::c_copy_if(
+      filtered_results, std::back_inserter(filtered_results_no_scratch),
+      [](const AutotuneResult& result) { return result.scratch_bytes() == 0; });
+
+  auto selected_result = filtered_results.begin();
+  auto selected_result_no_scratch = filtered_results_no_scratch.begin();
+  if (!RequireCudnnDeterminism()) {
+    auto compare_run_times = [](const AutotuneResult& lhs,
+                                const AutotuneResult& rhs) {
+      return proto_utils::FromDurationProto(lhs.run_time()) <
+             proto_utils::FromDurationProto(rhs.run_time());
+    };
+    selected_result = absl::c_min_element(filtered_results, compare_run_times);
+    selected_result_no_scratch = absl::c_min_element(
+        filtered_results_no_scratch, compare_run_times);
+  }

-  const auto best_result = absl::c_min_element(
-      filtered_results,
-      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return proto_utils::FromDurationProto(lhs.run_time()) <
-               proto_utils::FromDurationProto(rhs.run_time());
-      });
-
-  const auto best_result_no_scratch = absl::c_min_element(
-      filtered_results,
-      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return std::make_tuple(lhs.scratch_bytes(),
-                               proto_utils::FromDurationProto(lhs.run_time())) <
-               std::make_tuple(rhs.scratch_bytes(),
-                               proto_utils::FromDurationProto(rhs.run_time()));
-      });
-
-  algo->set_algorithm({best_result->conv().algorithm(),
-                       best_result->conv().tensor_ops_enabled()});
-  if (best_result_no_scratch != filtered_results.end() &&
-      best_result_no_scratch->scratch_bytes() == 0) {
+  algo->set_algorithm({selected_result->conv().algorithm(),
+                       selected_result->conv().tensor_ops_enabled()});
+  if (selected_result_no_scratch != filtered_results_no_scratch.end()) {
     algo->set_algorithm_no_scratch(
-        {best_result_no_scratch->conv().algorithm(),
-         best_result_no_scratch->conv().tensor_ops_enabled()});
+        {selected_result_no_scratch->conv().algorithm(),
+         selected_result_no_scratch->conv().tensor_ops_enabled()});
   }
+
   return Status::OK();
 }

diff --git a/tensorflow/python/kernel_tests/cudnn_deterministic_base.py b/tensorflow/python/kernel_tests/cudnn_deterministic_base.py
index 2b526f0ec6bb9f..07a4492a99936d 100644
--- a/tensorflow/python/kernel_tests/cudnn_deterministic_base.py
+++ b/tensorflow/python/kernel_tests/cudnn_deterministic_base.py
@@ -27,22 +27,39 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test

-# Setting either of the two environment variables TF_CUDNN_DETERMINISTIC or
-# TF_DETERMINISTIC_OPS to "true" or "1" will disable autotuning of cuDNN
-# algorithms and cause deterministic cuDNN algorithms to be selected when both
-# deterministic and non-deterministic algorithms are available. These tests are
-# intended to confirm that deterministic algorithms are chosen when either
-# environment variable is set to "true" or "1". The tested configurations were
-# first confirmed to produce non-deterministic results when the environment
-# variables are not set.
+# Notes:
+#
+# Deterministic cuDNN operation is selected by setting either of the two
+# environment variables TF_CUDNN_DETERMINISTIC or TF_DETERMINISTIC_OPS to 'true'
+# or '1' while also not setting the environment variable TF_CUDNN_USE_AUTOTUNE
+# to 'false' or '0'.
+#
+# Where both deterministic and non-deterministic cuDNN algorithms are available,
+# selecting deterministic operation will lead to only the deterministic
+# algorithms being chosen. Additionally, selecting deterministic operation will
+# result in a deterministic, or reproducible, selection of algorithms (for any
+# given layer configuration) for each of the forward and the two backward paths.
+#
+# These tests are intended to confirm that deterministic algorithms are chosen
+# (for the back-prop paths) when deterministic operation is selected. The tested
+# configurations were first confirmed to produce non-deterministic results when
+# the above-mentioned environment variables are not set.
+#
+# Even though selecting deterministic operation should ensure that the same
+# algorithms, for a given layer configuration, are always used (i.e. that
+# algorithm selection is deterministic / reproducible), this is not tested.
+
+# TODO(duncanriach): Add test for deterministic cuDNN max-pooling

-_PADDING = 'SAME'
-_STRIDES = [1, 1, 1, 1]
+LayerShapeNHWC = collections.namedtuple('LayerShapeNHWC',
+                                        'batch, height, width, channels')
+FilterShape2D = collections.namedtuple(
+    'FilterShape2D', 'height, width, in_channels, out_channels')

-LayerShape = collections.namedtuple('LayerShape',
-                                    'batch, height, width, channels')
-FilterShape = collections.namedtuple(
-    'FilterShape', 'height, width, in_channels, out_channels')
+LayerShapeNCDHW = collections.namedtuple(
+    'LayerShapeNCDHW', 'batch, channels, depth, height, width')
+FilterShape3D = collections.namedtuple(
+    'FilterShape3D', 'depth, height, width, in_channels, out_channels')


 class ConvolutionTest(test.TestCase):

@@ -53,14 +70,14 @@ def _random_data_op(self, shape):
     return constant_op.constant(
         2 * np.random.random_sample(shape) - 1, dtype=dtypes.float32)

-  def _random_out_op(self, in_shape, filter_shape):
+  def _random_out_op(self, in_shape, filter_shape, strides, padding):
     # Choosing not to use array_op.zeros() to prevent possible removal by
     # optimization
     in_op = self._random_data_op(in_shape)
     filter_op = self._random_data_op(filter_shape)
     # Use the forward op's shape-inference
     conv_op = nn_ops.conv2d(
-        in_op, filter_op, strides=_STRIDES, padding=_PADDING)
+        in_op, filter_op, strides=strides, padding=padding)
     out_shape = conv_op.get_shape()
     out_op = self._random_data_op(out_shape)
     return out_op

@@ -71,29 +88,49 @@ def _assert_reproducible(self, operation):
     result_2 = self.evaluate(operation)
     self.assertAllEqual(result_1, result_2)

+  # The default forward algorithm choice, when using cuDNN 7, does not support
+  # the following layer configuration. This test case is intended to confirm
+  # that an alternative algorithm is selected. Note that, in cuDNN 7, all
+  # forward algorithms are deterministic.
+  @test_util.run_cuda_only
+  def testForward(self):
+    np.random.seed(3)
+    in_shape = LayerShapeNCDHW(batch=2, channels=3, depth=5, height=7, width=6)
+    filter_shape = FilterShape3D(depth=3, height=3, width=3, in_channels=3,
+                                 out_channels=2)
+    in_op = self._random_data_op(in_shape)
+    filter_op = self._random_data_op(filter_shape)
+    strides = [1, 1, 1, 1, 1]
+    padding = 'VALID'
+    dilations = [1, 1, 2, 2, 2]
+    out_op = nn_ops.conv3d(in_op, filter_op, strides=strides, padding=padding,
+                           data_format='NCDHW', dilations=dilations)
+    self._assert_reproducible(out_op)
+
   @test_util.run_cuda_only
   def testBackwardFilterGradient(self):
     np.random.seed(1)
-    in_shape = LayerShape(batch=8, height=128, width=128, channels=8)
-    filter_shape = FilterShape(height=3, width=3, in_channels=8, out_channels=8)
+    in_shape = LayerShapeNHWC(batch=8, height=128, width=128, channels=8)
+    filter_shape = FilterShape2D(height=3, width=3, in_channels=8,
+                                 out_channels=8)
     in_op = self._random_data_op(in_shape)
-    out_op = self._random_out_op(in_shape, filter_shape)
+    strides = [1, 1, 1, 1]
+    padding = 'SAME'
+    out_op = self._random_out_op(in_shape, filter_shape, strides, padding)
     filter_gradient_op = nn_ops.conv2d_backprop_filter(
-        in_op, filter_shape, out_op, strides=_STRIDES, padding=_PADDING)
+        in_op, filter_shape, out_op, strides=strides, padding=padding)
     self._assert_reproducible(filter_gradient_op)

   @test_util.run_cuda_only
   def testBackwardInputGradient(self):
     np.random.seed(2)
-    in_shape = LayerShape(batch=8, height=32, width=32, channels=8)
-    filter_shape = FilterShape(
-        height=7, width=7, in_channels=8, out_channels=128)
+    in_shape = LayerShapeNHWC(batch=8, height=32, width=32, channels=8)
+    filter_shape = FilterShape2D(height=7, width=7, in_channels=8,
+                                 out_channels=128)
     filter_op = self._random_data_op(filter_shape)
-    out_op = self._random_out_op(in_shape, filter_shape)
+    strides = [1, 1, 1, 1]
+    padding = 'SAME'
+    out_op = self._random_out_op(in_shape, filter_shape, strides, padding)
     input_gradient_op = nn_ops.conv2d_backprop_input(
-        in_shape, filter_op, out_op, strides=_STRIDES, padding=_PADDING)
+        in_shape, filter_op, out_op, strides=strides, padding=padding)
     self._assert_reproducible(input_gradient_op)
-
-  # TODO(duncanriach): (1) add test to confirm that forward autotuning is
-  # disabled for cuDNN convolution; (2) add test for deterministic cuDNN
-  # max-pooling
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 70cc11a3e03148..2c640d4ff0b58b 100755
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -630,9 +630,22 @@ bool BatchnormSpatialPersistentEnabled() {
   return is_enabled;
 }

-// A helper function to decide whether to enable deterministic functionality.
-bool RequireDeterminism() {
-  static bool require_determinism = [] {
+// The following function allows deterministic ops to be implemented relatively
+// quickly using environment variables. It is intended to be temporary. The
+// longer-term intention is to enable deterministic ops via tf.config and
+// appropriate plumbing. See the discussion on PR 34951 for more information:
+// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
+// This function and associated comment are replicated in the following three
+// places:
+// 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+// 2. tensorflow/core/kernels/gpu_utils.cc
+// 3. tensorflow/stream_executor/cuda/cuda_dnn.cc
+// When implementing the plumbing, you should also search for the use of
+// TF_DETERMINISTIC_OPS on its own.
+// TODO(duncanriach): move to an API that uses tf.config and implement the first
+// phase of plumbing.
+bool RequireCudnnDeterminism() {
+  static bool require_cudnn_determinism = [] {
     bool deterministic_ops = false;
     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
                                                /*default_val=*/false,
@@ -643,7 +656,7 @@ bool RequireDeterminism() {
                                                &cudnn_deterministic));
     return deterministic_ops || cudnn_deterministic;
   }();
-  return require_determinism;
+  return require_cudnn_determinism;
 }

 std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
@@ -744,7 +757,7 @@ class CudnnPoolingDescriptor {
     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
                    &CheckedNarrowing<int64, int>);
     bool propagate_nans = pooling_descriptor.propagate_nans();
-    const auto cudnn_max_pooling_mode = RequireDeterminism()
+    const auto cudnn_max_pooling_mode = RequireCudnnDeterminism()
                                             ? CUDNN_POOLING_MAX_DETERMINISTIC
                                             : CUDNN_POOLING_MAX;
     CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
@@ -3247,16 +3260,10 @@ bool CudnnSupport::GetConvolveAlgorithms(
   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
   out_algorithms->clear();

-  if (RequireDeterminism()) {
-    out_algorithms->push_back({CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
-                               tensor_op_math_available});
-    return true;
-  }
-
   std::vector<cudnnConvolutionFwdAlgo_t> algo_types = {
       // clang-format off
-    CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
+    CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
     CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
     CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
     CUDNN_CONVOLUTION_FWD_ALGO_FFT,
@@ -3270,11 +3277,12 @@ bool CudnnSupport::GetConvolveAlgorithms(
     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
   }

+  // The algorithms are intentionally ordered for deterministic operation
   for (auto i : algo_types) {
-    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
     if (tensor_op_math_available) {
       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
     }
+    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
   }

   return true;
@@ -3308,15 +3316,8 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
   out_algorithms->clear();

-  if (RequireDeterminism()) {
-    out_algorithms->push_back(
-        {CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, tensor_op_math_available});
-    return true;
-  }
-
   std::vector<cudnnConvolutionBwdDataAlgo_t> algo_types = {
       // clang-format off
-    CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
@@ -3326,12 +3327,16 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
   }
+  if (!RequireCudnnDeterminism()) {
+    algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
+  }

+  // The algorithms are intentionally ordered for deterministic operation
   for (auto i : algo_types) {
-    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
     if (tensor_op_math_available) {
       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
     }
+    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
   }

   return true;
@@ -3343,18 +3348,10 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
   out_algorithms->clear();

-  if (RequireDeterminism()) {
-    out_algorithms->push_back(
-        {CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, tensor_op_math_available});
-    return true;
-  }
-
   std::vector<cudnnConvolutionBwdFilterAlgo_t> algo_types = {
       // clang-format off
-    CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
     CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
     CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
-    CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
     // Based on cudnn.h, the following is not implemented.
     // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
@@ -3366,12 +3363,17 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
   }
+  if (!RequireCudnnDeterminism()) {
+    algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
+    algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
+  }

+  // The algorithms are intentionally ordered for deterministic operation
   for (auto i : algo_types) {
-    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
     if (tensor_op_math_available) {
       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
     }
+    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
   }

   return true;
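How the change is exercised from user code: the following is a minimal sketch, not part of the patch, assuming a CUDA build of TensorFlow that includes this change. The shapes, the eager-mode GradientTape usage, and the helper name filter_grad are illustrative choices only; the only requirement is that TF_DETERMINISTIC_OPS (or TF_CUDNN_DETERMINISTIC) is set before the first cuDNN convolution executes, since the environment variable is read once and cached.

import os
os.environ['TF_DETERMINISTIC_OPS'] = '1'  # read lazily and cached on first conv

import numpy as np
import tensorflow as tf

np.random.seed(0)
x = tf.constant(np.random.random_sample((8, 128, 128, 8)), dtype=tf.float32)
f = tf.constant(np.random.random_sample((3, 3, 8, 8)), dtype=tf.float32)

def filter_grad():
  # Forward conv plus a scalar loss; the gradient with respect to the filter
  # exercises the conv2d_backprop_filter path whose algorithm list this patch
  # restricts and reorders when determinism is requested.
  with tf.GradientTape() as tape:
    tape.watch(f)
    y = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding='SAME')
    loss = tf.reduce_sum(y * y)
  return tape.gradient(loss, f)

# With the environment variable set, repeated evaluations select the same
# deterministic cuDNN back-prop algorithm and produce bit-identical gradients.
g1 = filter_grad().numpy()
g2 = filter_grad().numpy()
assert np.array_equal(g1, g2)

This mirrors the pattern used by cudnn_deterministic_base.py above: run the same back-prop op twice and require exact equality, which only holds when a deterministic algorithm is chosen.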