Stream Executor: perform a cudnn library vs. binary check, add relu6/x support. #1987

Merged
merged 1 commit into from Apr 19, 2016
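This change does two things in tensorflow/stream_executor/cuda/cuda_dnn.cc. First, CudnnSupport::Init() now compares the version reported by the loaded cudnn library (dynload::cudnnGetVersion()) against the CUDNN_VERSION the binary was compiled with, and fails initialization with a descriptive error on mismatch. Second, the kRelu6 and kReluX activation modes are now lowered to CUDNN_ACTIVATION_CLIPPED_RELU with the appropriate ceiling (6.0 and value_max, respectively) instead of silently falling back to plain CUDNN_ACTIVATION_RELU.

Below is a minimal standalone sketch of the same version check, stripped of the TensorFlow-specific dynload::/port:: plumbing. CheckCudnnVersion is a hypothetical helper name; cudnnGetVersion() and CUDNN_VERSION come straight from cudnn.h:

#include <cstdio>
#include <cudnn.h>  // defines CUDNN_VERSION; declares cudnnGetVersion()

// Hypothetical helper: returns true iff the cudnn runtime the loader picked
// up matches the headers this binary was compiled against.
bool CheckCudnnVersion() {
  size_t loaded_version = cudnnGetVersion();  // version of the loaded library
  if (loaded_version != CUDNN_VERSION) {      // version at compile time
    std::fprintf(stderr,
                 "Loaded cudnn library %zu but compiled against %d; "
                 "upgrade the library or rebuild to match.\n",
                 loaded_version, CUDNN_VERSION);
    return false;
  }
  return true;
}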
30 changes: 24 additions & 6 deletions tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -283,6 +283,22 @@ port::Status CudnnSupport::Init() {
   auto status = dynload::cudnnCreate(
       parent_, reinterpret_cast<cudnnHandle_t*>(&dnn_handle_));
   if (status == CUDNN_STATUS_SUCCESS) {
+    // Check whether loaded version of CuDNN matches what the source
+    // was built with.
+    size_t loaded_version = dynload::cudnnGetVersion();
+    bool library_loaded_matches_source = (loaded_version == CUDNN_VERSION);
+    if (!library_loaded_matches_source) {
+      const string error =
+          port::StrCat("Loaded cudnn library: ", loaded_version,
+                       " but source was compiled against ", CUDNN_VERSION,
+                       ". If using a binary install, upgrade your cudnn "
+                       "library to match. If building from sources, "
+                       "make sure the library loaded matches the "
+                       "version you specified during compile configuration.");
+      LOG(ERROR) << error;
+      return port::Status{port::error::INTERNAL, error};
+    }
+
     return port::Status::OK();
   }

@@ -304,6 +320,7 @@ port::Status CudnnSupport::Init() {
       }
     }
   }
+
   return port::Status{port::error::INTERNAL,
                       port::StrCat("cudnn library could not create a handle: ",
                                    ToString(status))};
@@ -561,7 +578,8 @@ class ScopedPoolingDescriptor {
 class ScopedActivationDescriptor {
  public:
   ScopedActivationDescriptor(CUDAExecutor* parent,
-                             dnn::ActivationMode activation_mode)
+                             dnn::ActivationMode activation_mode,
+                             double value_max)
       : parent_(parent), handle_(nullptr) {
     cudnnStatus_t status =
         dynload::cudnnCreateActivationDescriptor(parent_, &handle_);
@@ -575,12 +593,11 @@ class ScopedActivationDescriptor {
     switch (activation_mode) {
       case dnn::ActivationMode::kRelu6:
         relu_ceiling = 6.0;
-        mode = CUDNN_ACTIVATION_RELU;
+        mode = CUDNN_ACTIVATION_CLIPPED_RELU;
         break;
       case dnn::ActivationMode::kReluX:
-        // TODO(leary) should probably do a post-pass to clip at X?
-        LOG(WARNING) << "user requested ReluX, but providing Relu instead";
-        mode = CUDNN_ACTIVATION_RELU;
+        relu_ceiling = value_max;
+        mode = CUDNN_ACTIVATION_CLIPPED_RELU;
         break;
       case dnn::ActivationMode::kRelu:
         mode = CUDNN_ACTIVATION_RELU;
@@ -1272,7 +1289,8 @@ bool CudnnSupport::DoActivate(Stream* stream,
   }

 #if CUDNN_VERSION >= 5000
-  ScopedActivationDescriptor activation_desc{parent_, activation_mode};
+  ScopedActivationDescriptor activation_desc{parent_, activation_mode,
+                                             dimensions.value_max()};
 #else
   cudnnActivationMode_t mode;
   switch (activation_mode) {
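For reference on the activation change: CUDNN_ACTIVATION_CLIPPED_RELU computes min(max(x, 0), ceiling), so relu6 is just the ceiling == 6.0 special case, and ReluX now clips at the caller-supplied value_max() instead of approximating with plain Relu. A CPU sketch of the same semantics (clipped_relu and relu6 here are illustrative helpers, not code from this PR):

#include <algorithm>

// Reference semantics of CUDNN_ACTIVATION_CLIPPED_RELU: min(max(x, 0), ceiling).
float clipped_relu(float x, float ceiling) {
  return std::min(std::max(x, 0.0f), ceiling);
}

// relu6 is the ceiling == 6.0f special case; ReluX passes value_max instead.
float relu6(float x) { return clipped_relu(x, 6.0f); }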