Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent CHECK-fail in LSTM/GRU with zero-length input.
PiperOrigin-RevId: 346239181
Change-Id: I5f233dbc076aab7bb4e31ba24f5abd4eaf99ea4f
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Dec 8, 2020
1 parent 042a692 commit 1475541
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tensorflow/stream_executor/cuda/cuda_dnn.cc
Expand Up @@ -1468,7 +1468,9 @@ class CudnnRnnSequenceTensorDescriptor
static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
cudnnDataType_t data_type) {
CHECK_GT(max_seq_length, 0);
if (max_seq_length <= 0) {
return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
}
int dims[] = {batch_size, data_size, 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
TensorDescriptor tensor_desc = CreateTensorDescriptor();
Expand All @@ -1486,7 +1488,9 @@ class CudnnRnnSequenceTensorDescriptor
GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
const absl::Span<const int>& seq_lengths, bool time_major,
cudnnDataType_t data_type) {
CHECK_GT(max_seq_length, 0);
if (max_seq_length <= 0) {
return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
}
int dims[] = {batch_size, data_size, 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
TensorDescriptor tensor_desc = CreateTensorDescriptor();
Expand Down

0 comments on commit 1475541

Please sign in to comment.