40 changes: 40 additions & 0 deletions src/libtorchaudio/cuda_utils.h
@@ -0,0 +1,40 @@
#pragma once

#include <cuda_runtime_api.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/device.h>

namespace libtorchaudio::cuda {

inline cudaStream_t getCurrentCUDAStream(
torch::stable::DeviceIndex device_index = -1) {
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
return static_cast<cudaStream_t>(stream_ptr);
}

inline void setCurrentCUDAStream(
cudaStream_t stream,
torch::stable::DeviceIndex device_index = -1) {
TORCH_ERROR_CODE_CHECK(
torch_set_current_cuda_stream(static_cast<void*>(stream), device_index));
}

inline cudaStream_t getStreamFromPool(
const bool isHighPriority = false,
torch::stable::DeviceIndex device_index = -1) {
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(torch_get_cuda_stream_from_pool(
isHighPriority, device_index, &stream_ptr));
return static_cast<cudaStream_t>(stream_ptr);
}

inline void synchronize(
cudaStream_t stream,
torch::stable::DeviceIndex device_index = -1) {
TORCH_ERROR_CODE_CHECK(
torch_cuda_stream_synchronize(static_cast<void*>(stream), device_index));
}

} // namespace libtorchaudio::cuda
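For orientation, a minimal usage sketch of these helpers (not part of the diff): borrow a pooled side stream, make it current, then restore the previous one. It assumes a CUDA build and that the `torch_*` shim symbols the header calls are available.

```cpp
// Hypothetical usage sketch; every call here comes from cuda_utils.h above,
// but the function itself is illustrative only.
#include <libtorchaudio/cuda_utils.h>

void with_side_stream(torch::stable::DeviceIndex device_index) {
  // Stash the ambient stream so it can be restored afterwards.
  cudaStream_t prior = libtorchaudio::cuda::getCurrentCUDAStream(device_index);
  cudaStream_t side =
      libtorchaudio::cuda::getStreamFromPool(/*isHighPriority=*/false, device_index);
  libtorchaudio::cuda::setCurrentCUDAStream(side, device_index);
  // ... enqueue kernels or async copies on `side` ...
  libtorchaudio::cuda::synchronize(side, device_index);
  libtorchaudio::cuda::setCurrentCUDAStream(prior, device_index);
}
```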
21 changes: 12 additions & 9 deletions src/libtorchaudio/forced_align/gpu/compute.cu
@@ -1,5 +1,7 @@
#include <libtorchaudio/cuda_utils.h>
#include <libtorchaudio/utils.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/macros.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/ScalarType.h>

@@ -119,8 +121,9 @@ void forced_align_impl(
const Tensor& targets,
const int64_t blank,
Tensor& paths) {
auto defaultStream = at::cuda::getCurrentCUDAStream();
auto cpuDataTranferStream = at::cuda::getStreamFromPool();
auto device_index = logProbs.get_device_index();
auto defaultStream = libtorchaudio::cuda::getCurrentCUDAStream(device_index);
auto cpuDataTranferStream = libtorchaudio::cuda::getStreamFromPool(false, device_index);
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
@@ -204,29 +207,29 @@ void forced_align_impl(
backPtrBufferLen,
torchaudio::packed_accessor32<scalar_t, 2>(alphas),
torchaudio::packed_accessor32<int8_t, 2>(backPtrBuffer));
C10_CUDA_KERNEL_LAUNCH_CHECK();
STD_CUDA_KERNEL_LAUNCH_CHECK();
++backPtrBufferLen;
if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
cpuDataTranferStream.synchronize();
libtorchaudio::cuda::synchronize(cpuDataTranferStream, device_index);
// GPU -> GPU copy
bufferCopy = torch::stable::clone(backPtrBuffer);
STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()");
defaultStream.synchronize();
at::cuda::setCurrentCUDAStream(cpuDataTranferStream);
libtorchaudio::cuda::synchronize(defaultStream, device_index);
libtorchaudio::cuda::setCurrentCUDAStream(cpuDataTranferStream, device_index);
// Copy ASYNC from GPU to CPU
int64_t offset =
static_cast<int64_t>(t + 1 - backPtrBufferLen) * S * sizeof(int8_t);
C10_CUDA_CHECK(cudaMemcpyAsync(
STD_CUDA_CHECK(cudaMemcpyAsync(
static_cast<int8_t*>(backPtrCpu.data_ptr()) + offset,
bufferCopy.data_ptr(),
backPtrBufferLen * S * sizeof(int8_t),
cudaMemcpyDeviceToHost,
cpuDataTranferStream));
at::cuda::setCurrentCUDAStream(defaultStream);
libtorchaudio::cuda::setCurrentCUDAStream(defaultStream, device_index);
backPtrBufferLen = 0;
}
}
cpuDataTranferStream.synchronize();
libtorchaudio::cuda::synchronize(cpuDataTranferStream, device_index);
auto alphasCpu = torchaudio::stable::cpu(alphas);
auto alphasCpu_a = torchaudio::accessor<scalar_t, 2>(alphasCpu);
int curIdxOffset = ((T - 1) % 2);
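The hunk above is a double-buffered device-to-host drain: backpointers accumulate on the compute stream, and every `kBackPtrBufferSize` steps (or at the last step) a clone is copied to CPU memory on the side stream while compute continues. A distilled sketch of the pattern, with placeholder names (`bufferGpu`, `hostPtr`, `kBufferSize` are illustrative, not from the PR) and error checking elided:

```cpp
// Distilled from the loop above; the snapshot is declared outside the loop
// so the async copy's source outlives each iteration, as in the real code.
void drain_backptrs(torch::stable::Tensor bufferGpu, int8_t* hostPtr,
                    int64_t T, int64_t S,
                    cudaStream_t computeStream, cudaStream_t transferStream,
                    torch::stable::DeviceIndex device_index) {
  constexpr int64_t kBufferSize = 128;  // illustrative capacity
  torch::stable::Tensor snapshot = torch::stable::clone(bufferGpu);
  int64_t len = 0;
  for (int64_t t = 0; t < T; ++t) {
    // ... kernel writes step t into bufferGpu on computeStream ...
    if (++len == kBufferSize || t == T - 1) {
      // Wait for the previous async copy before reusing its source.
      libtorchaudio::cuda::synchronize(transferStream, device_index);
      snapshot = torch::stable::clone(bufferGpu);  // GPU -> GPU snapshot
      // Ensure the snapshot is complete before the side stream reads it.
      libtorchaudio::cuda::synchronize(computeStream, device_index);
      libtorchaudio::cuda::setCurrentCUDAStream(transferStream, device_index);
      int64_t offset = (t + 1 - len) * S;
      cudaMemcpyAsync(hostPtr + offset, snapshot.data_ptr(),
                      len * S * sizeof(int8_t), cudaMemcpyDeviceToHost,
                      transferStream);  // overlaps with subsequent compute
      libtorchaudio::cuda::setCurrentCUDAStream(computeStream, device_index);
      len = 0;
    }
  }
  // Final sync so hostPtr is safe to read.
  libtorchaudio::cuda::synchronize(transferStream, device_index);
}
```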
10 changes: 4 additions & 6 deletions src/libtorchaudio/iir_cuda.cu
@@ -1,9 +1,8 @@
#include <libtorchaudio/utils.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/macros.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/ScalarType.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/core/DeviceGuard.h>

using torch::headeronly::ScalarType;
using torch::stable::Tensor;
@@ -65,8 +64,7 @@ Tensor cuda_lfilter_core_loop(

STD_TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2));

const at::cuda::OptionalCUDAGuard device_guard(in.get_device_index());

const torch::stable::accelerator::DeviceGuard device_guard(in.get_device_index());
const dim3 threads(256);
const dim3 blocks((N * C + threads.x - 1) / threads.x);

@@ -76,7 +74,7 @@
torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out)));
C10_CUDA_KERNEL_LAUNCH_CHECK();
STD_CUDA_KERNEL_LAUNCH_CHECK();
}), AT_FLOATING_TYPES);
return padded_out;
}
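This file's change is the guard and macro swap: `at::cuda::OptionalCUDAGuard` becomes the stable-ABI `torch::stable::accelerator::DeviceGuard`, and the launch check loses its `c10` dependency. A minimal sketch of the resulting pattern (the kernel launch and shape arguments are placeholders):

```cpp
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/macros.h>
#include <torch/csrc/stable/tensor.h>

// Sketch: scope-bound device selection plus a ceiling-division launch.
// With N * C = 800 elements of work and 256 threads per block, the grid is
// (800 + 255) / 256 = 4 blocks.
void launch_on(const torch::stable::Tensor& in, int64_t N, int64_t C) {
  const torch::stable::accelerator::DeviceGuard guard(in.get_device_index());
  const dim3 threads(256);
  const dim3 blocks((N * C + threads.x - 1) / threads.x);
  // my_kernel<<<blocks, threads>>>(...);  // placeholder kernel launch
  STD_CUDA_KERNEL_LAUNCH_CHECK();  // stable replacement for C10_CUDA_KERNEL_LAUNCH_CHECK
}
```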
11 changes: 5 additions & 6 deletions src/libtorchaudio/rnnt/gpu/compute.cu
@@ -1,7 +1,7 @@
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
#include <libtorchaudio/cuda_utils.h>
#include <libtorchaudio/stable/ops.h>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/Dispatch_v2.h>
@@ -76,9 +76,8 @@ std::tuple<Tensor, Tensor> compute(
"blank must be within [0, logits.shape[-1])");

auto max_ivalue = [](const Tensor& t) {
int32_t value;
C10_CUDA_CHECK(cudaMemcpy(&value, torch::stable::amax(t, {}).data_ptr(), sizeof(int32_t), cudaMemcpyDeviceToHost));
return value;
auto mx = torchaudio::stable::cpu(torch::stable::amax(t, {}));
return reinterpret_cast<int32_t*>(mx.data_ptr())[0];
};

STD_TORCH_CHECK(
@@ -100,7 +99,7 @@
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_softmax;
options.stream_ = at::cuda::getCurrentCUDAStream();
options.stream_ = libtorchaudio::cuda::getCurrentCUDAStream(logits.get_device_index());
cudaSetDevice(logits.get_device());
options.device_ = GPU;

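The new `max_ivalue` avoids a raw `cudaMemcpy` by materializing the 0-dim `amax` result on CPU and reading it there. The same readout could also be expressed through the `item<T>` helper this PR updates in `stable/ops.h`; a sketch, assuming the CUDA branch shown at the end of this diff:

```cpp
// Equivalent device-scalar readout via torchaudio::stable::item<T>, whose
// CUDA path now routes through a CPU copy (see stable/ops.h below).
auto max_ivalue = [](const torch::stable::Tensor& t) {
  return torchaudio::stable::item<int32_t>(torch::stable::amax(t, {}));
};
```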
36 changes: 36 additions & 0 deletions src/libtorchaudio/shim_temporary.h
@@ -0,0 +1,36 @@
#pragma once
// TODO: remove this file once https://github.com/pytorch/pytorch/pull/169376
// has landed in nightly.

#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/stable/c/shim.h>

inline AOTITorchError tmp_torch_set_current_cuda_stream(
void* stream,
int32_t device_index) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::cuda::setCurrentCUDAStream(at::cuda::getStreamFromExternal(
static_cast<cudaStream_t>(stream), device_index));
});
}

inline AOTITorchError tmp_torch_get_cuda_stream_from_pool(
const bool isHighPriority,
int32_t device_index,
void** ret_stream) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
*(cudaStream_t*)(ret_stream) =
at::cuda::getStreamFromPool(isHighPriority, device_index);
});
}

inline AOTITorchError tmp_torch_cuda_stream_synchronize(
void* stream,
int32_t device_index) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::cuda::getStreamFromExternal(
static_cast<cudaStream_t>(stream), device_index)
.synchronize();
});
}
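The diff does not show how the unprefixed `torch_*` names called in `cuda_utils.h` resolve to these `tmp_torch_*` implementations. One plausible wiring, purely an assumption on my part, is a set of aliases kept until the upstream shim lands:

```cpp
// Hypothetical glue (an assumption, not part of this diff): alias the shim
// names used by cuda_utils.h to the temporary implementations above.
#define torch_set_current_cuda_stream tmp_torch_set_current_cuda_stream
#define torch_get_cuda_stream_from_pool tmp_torch_get_cuda_stream_from_pool
#define torch_cuda_stream_synchronize tmp_torch_cuda_stream_synchronize
```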
10 changes: 1 addition & 9 deletions src/libtorchaudio/stable/ops.h
@@ -12,11 +12,6 @@
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>

#ifdef USE_CUDA
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#endif

namespace torchaudio::stable {

using torch::stable::Tensor;
@@ -83,10 +78,7 @@ T item(const Tensor& self) {
return reinterpret_cast<const T*>(self.const_data_ptr())[0];
#ifdef USE_CUDA
} else if (self.is_cuda()) {
T value;
C10_CUDA_CHECK(cudaMemcpyAsync(
&value, self.data_ptr(), sizeof(T), cudaMemcpyDeviceToHost));
return value;
return torchaudio::stable::item<T>(torchaudio::stable::cpu(self));
#endif
} else {
STD_TORCH_CHECK(false, "unreachable"); // not implemented
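After this change, `item<T>` handles CUDA tensors by recursing on a `torchaudio::stable::cpu` copy, so both branches end in the same host-side read. A short usage sketch (`logits` is illustrative):

```cpp
// Usage sketch: read a 0-dim reduction result as a host scalar. On CUDA
// tensors, item<T> now goes through a CPU copy rather than cudaMemcpyAsync.
torch::stable::Tensor mx = torch::stable::amax(logits, {});  // 0-dim, any device
int32_t v = torchaudio::stable::item<int32_t>(mx);
```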