diff --git a/src/libtorchaudio/forced_align/gpu/compute.cu b/src/libtorchaudio/forced_align/gpu/compute.cu index a78f694b51..e3ee88c639 100644 --- a/src/libtorchaudio/forced_align/gpu/compute.cu +++ b/src/libtorchaudio/forced_align/gpu/compute.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -115,6 +116,10 @@ void forced_align_impl( const torch::Tensor& targets, const int64_t blank, torch::Tensor& paths) { + + // only guard logProbs, since in L263 it'll be verified on targets + const at::cuda::OptionalCUDAGuard device_guard(logProbs.device()); + auto defaultStream = at::cuda::getCurrentCUDAStream(); auto cpuDataTranferStream = at::cuda::getStreamFromPool(); const scalar_t kNegInfinity = -std::numeric_limits::infinity();