46 changes: 21 additions & 25 deletions aten/src/ATen/native/cudnn/LossCTC.cpp
@@ -11,12 +11,15 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_assert_async.h>
 #include <ATen/ops/_cudnn_ctc_loss.h>
 #include <ATen/ops/_cudnn_ctc_loss_native.h>
 #include <ATen/ops/_use_cudnn_ctc_loss.h>
 #include <ATen/ops/_use_cudnn_ctc_loss_native.h>
 #include <ATen/ops/empty.h>
 #include <ATen/ops/empty_like.h>
+#include <ATen/ops/le.h>
+#include <ATen/ops/lt.h>
 #endif
 
 #if (!AT_CUDNN_ENABLED())
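Note: the three new headers support the capture-safe path added in the last hunk below. at::lt and at::le build boolean tensors on the device, and at::_assert_async enqueues a kernel that traps if its argument is false, with no host synchronization, which is what makes it legal to record into a CUDA graph. A minimal sketch of that pattern, assuming a CUDA-enabled libtorch build (illustrative only, not part of this diff):

#include <ATen/ATen.h>

int main() {
  // Length tensors as _use_cudnn_ctc_loss_tensor would see them (CUDA, int).
  at::Tensor input_lengths =
      at::tensor({50, 60, 70}, at::device(at::kCUDA).dtype(at::kInt));
  at::Tensor target_lengths =
      at::tensor({10, 20, 30}, at::device(at::kCUDA).dtype(at::kInt));

  // Both asserts are recorded as ordinary kernels: they neither sync nor
  // copy to the host, and they fire on graph replay if the condition fails.
  at::_assert_async(at::lt(input_lengths.max(), 256));
  at::_assert_async(at::le(target_lengths, input_lengths).all());
  return 0;
}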
@@ -81,11 +84,6 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
 namespace at {
 namespace native {
 
-namespace {
-// "cache" whether we've previously failed the target lengths check
-static bool tensor_failed_target_lengths_check = false;
-} // namespace
-
 bool _use_cudnn_ctc_loss(
     const Tensor& log_probs,
     const Tensor& targets,
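Note: the block deleted above is the root of the problem this PR fixes. It latched the first failed length check into a process-wide static, after which every later capture-mode call was steered away from cuDNN, regardless of its own lengths. A hypothetical repro of that hazard, with illustrative names rather than PyTorch API:

#include <iostream>

namespace {
// Mirrors the removed "cache": one failure poisons the whole process.
static bool failed_length_check = false;
} // namespace

bool use_fast_backend(int target_len, bool capturing) {
  if (!capturing) {
    if (target_len >= 256) {
      failed_length_check = true; // remembered forever
      return false;
    }
    return true;
  }
  return !failed_length_check; // capture path trusts stale global state
}

int main() {
  use_fast_backend(300, /*capturing=*/false); // one oversized batch...
  // ...and this unrelated, perfectly valid capture-mode call is rejected:
  std::cout << use_fast_backend(10, /*capturing=*/true) << '\n'; // prints 0
  return 0;
}

The revision below drops the global entirely and validates at capture time with device-side asserts instead.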
@@ -132,29 +130,27 @@ bool _use_cudnn_ctc_loss_tensor(
       (log_probs.dim() == 3) && (input_lengths.scalar_type() == at::kInt) &&
       (target_lengths.scalar_type() == at::kInt);
 
-  if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
-    Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
-    IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
-    for (const auto b : c10::irange(tl.size())) {
-      // target length < 256 is documented, but we see illegal memory accesses
-      // when target lengths > input lengths for CuDNN
-      Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
+  if (use_cudnn) {
+    if (at::cuda::currentStreamCaptureStatus() ==
+        at::cuda::CaptureStatus::None) {
       Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
-      IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
       IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
-      use_cudnn = use_cudnn && (tl[b] < 256) && (tl[b] <= il[b]);
-      if (!use_cudnn) {
-        tensor_failed_target_lengths_check = true;
-        break;
+      for (const auto b : c10::irange(tl.size())) {
+        // target length < 256 is documented, but we see illegal memory accesses
+        // when target lengths > input lengths for CuDNN
+        Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
+        Tensor tlc =
+            target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
+        IntArrayRef il(ilc.const_data_ptr<int64_t>(), ilc.numel());
+        IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
+        use_cudnn = use_cudnn && (tl[b] < 256) && (tl[b] <= il[b]);
+        if (!use_cudnn) {
+          break;
+        }
       }
-    }
-  } else {
-    use_cudnn = use_cudnn && !tensor_failed_target_lengths_check;
-    if (tensor_failed_target_lengths_check) {
-      TORCH_WARN(
-          "cuDNN max target length restriction < 256 cannot be checked during graph capture,"
-          " but target length >= 256 was observed previously e.g., during warmup, so we"
-          " presume it is unsafe to dispatch to cuDNN ctc_loss.");
+    } else {
+      at::_assert_async(at::lt(input_lengths.max(), 256));
+      at::_assert_async(at::le(target_lengths, input_lengths).all());
     }
   }

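Note: the net effect of the last hunk is that the length check now runs only when use_cudnn is still plausible, and it splits into two paths. In eager mode the lengths are copied to the CPU and validated per batch element; under graph capture, where a device-to-host copy is illegal, the same constraints are instead enforced by device-side asserts that fire on replay. A condensed sketch of the two paths, assuming a CUDA libtorch build (lengths_ok is an illustrative helper, not PyTorch API):

#include <ATen/ATen.h>
#include <c10/util/irange.h>

bool lengths_ok(const at::Tensor& input_lengths,
                const at::Tensor& target_lengths,
                bool capturing) {
  if (!capturing) {
    // Eager path: copy lengths to the CPU and validate each batch entry
    // against the documented cuDNN limits (target < 256, target <= input).
    at::Tensor ilc = input_lengths.to(at::kCPU, at::kLong).contiguous();
    at::Tensor tlc = target_lengths.to(at::kCPU, at::kLong).contiguous();
    const int64_t* il = ilc.const_data_ptr<int64_t>();
    const int64_t* tl = tlc.const_data_ptr<int64_t>();
    for (const auto b : c10::irange(tlc.numel())) {
      if (!(tl[b] < 256 && tl[b] <= il[b])) {
        return false; // caller falls back to the native CTC loss
      }
    }
    return true;
  }
  // Capture path: host copies are forbidden while recording a CUDA graph,
  // so enqueue on-device asserts instead, exactly as the diff does.
  at::_assert_async(at::lt(input_lengths.max(), 256));
  at::_assert_async(at::le(target_lengths, input_lengths).all());
  return true;
}

One trade-off worth noting: the eager path can decline cuDNN gracefully, while the capture path can only assert, so an invalid length during replay surfaces as a device-side assertion failure rather than a silent fallback.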