3 changes: 2 additions & 1 deletion src/libtorchaudio/forced_align/cpu/compute.cpp
@@ -208,7 +208,8 @@ std::tuple<Tensor, Tensor> compute(
       ScalarType::Long);
   const auto B = logProbs.size(0);
   const auto T = logProbs.size(1);
-  Tensor paths = torchaudio::stable::new_zeros(targets, {B, T});
+  Tensor paths = torch::stable::empty({B, T}, targets.scalar_type());
+  torch::stable::zero_(paths);
   THO_DISPATCH_V2(
       logProbs.scalar_type(),
       "forced_align_impl",
3 changes: 2 additions & 1 deletion src/libtorchaudio/forced_align/gpu/compute.cu
@@ -294,7 +294,8 @@ std::tuple<Tensor, Tensor> compute(
   auto B = logProbs.size(0);
   auto T = logProbs.size(1); // num frames
 
-  Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torch::stable::DeviceType::CPU);
+  Tensor paths = torch::stable::empty({B, T}, targets.scalar_type());
+  torch::stable::zero_(paths);
 
   THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
     if (targets.scalar_type() == ScalarType::Long) {
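Both forced_align kernels now build the `paths` tensor with `torch::stable::empty` followed by an in-place `torch::stable::zero_`, replacing the removed `torchaudio::stable::new_zeros` helper. The snippet below is a minimal sketch that packages the same two-call pattern into a standalone helper; it is not code from this PR, the header paths are assumptions, and `zeros_2d_like` is an illustrative name.

```cpp
// Sketch only: the empty() + zero_() allocation pattern used in the
// forced_align changes above, wrapped in a hypothetical helper.
// Assumes the stable-ABI headers exist under these include paths and that
// torch::stable::empty accepts an initializer-list shape plus a dtype,
// as the diff uses it.
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>

#include <cstdint>

// Allocate a {B, T} tensor with the same dtype as `like`, cleared to zero.
inline torch::stable::Tensor zeros_2d_like(
    const torch::stable::Tensor& like,
    int64_t B,
    int64_t T) {
  // empty() returns uninitialized storage; zero_() clears it in place.
  torch::stable::Tensor out = torch::stable::empty({B, T}, like.scalar_type());
  torch::stable::zero_(out);
  return out;
}
```

In the kernels above the same two calls are simply written inline, so `paths` stays a local two-liner rather than going through a helper.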
5 changes: 3 additions & 2 deletions src/libtorchaudio/rnnt/cpu/compute.cpp
@@ -114,9 +114,10 @@ std::tuple<Tensor, Tensor> compute(
   // when stable ABI Tensor supports mutable_data_ptr templates.
   Workspace<float> workspace(
       /*options=*/options,
-      /*dtype_data=*/reinterpret_cast<float*>(float_workspace.data_ptr()),
+      /*dtype_data=*/
+      reinterpret_cast<float*>(float_workspace.mutable_data_ptr()),
       /*dtype_size=*/float_workspace.numel(),
-      /*int_data=*/reinterpret_cast<int*>(int_workspace.data_ptr()),
+      /*int_data=*/reinterpret_cast<int*>(int_workspace.mutable_data_ptr()),
       /*int_size=*/int_workspace.numel());
 
   THO_DISPATCH_V2(
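The RNNT change swaps `data_ptr()` for `mutable_data_ptr()` when handing raw workspace memory to `Workspace<float>`. A small sketch of that cast pattern follows; `as_typed` is an illustrative name, the include path is an assumption, and the only thing taken from the diff is that `mutable_data_ptr()` returns an untyped pointer that gets `reinterpret_cast` to the element type.

```cpp
// Sketch of the typed-view pattern used for the RNNT workspace above.
#include <torch/csrc/stable/tensor.h>  // assumed location of torch::stable::Tensor

template <typename T>
T* as_typed(torch::stable::Tensor& workspace) {
  // mutable_data_ptr() hands back an untyped pointer to the tensor's
  // storage; the caller must ensure the buffer is large enough and
  // suitably aligned for T.
  return reinterpret_cast<T*>(workspace.mutable_data_ptr());
}
```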
113 changes: 12 additions & 101 deletions src/libtorchaudio/stable/ops.h
@@ -17,52 +17,15 @@
 #include <c10/cuda/CUDAException.h>
 #endif
 
-using torch::stable::Tensor;
-
 namespace torchaudio::stable {
 
-using Layout = int32_t;
-
-// TODO: When sizes and strides are implemented in torch::stable,
-// eliminate sizes and strides function below.
-inline std::vector<int64_t> sizes(const Tensor& t) {
-  int64_t* ptr;
-  TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(t.get(), &ptr));
-  std::vector<int64_t> r(ptr, ptr + t.dim());
-  return r;
-}
-
-inline std::vector<int64_t> strides(const Tensor& t) {
-  int64_t* ptr;
-  TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(t.get(), &ptr));
-  std::vector<int64_t> r(ptr, ptr + t.dim());
-  return r;
-}
-
-// TODO: When https://github.com/pytorch/pytorch/pull/161891 lands,
-// eliminate mutable_data_ptr and const_data_ptr templates.
-#define aoti_torch_get_mutable_data_ptr aoti_torch_get_data_ptr
-#define aoti_torch_get_const_data_ptr aoti_torch_get_data_ptr
-template <typename T>
-T* mutable_data_ptr(const Tensor& t) {
-  void* data_ptr{};
-  TORCH_ERROR_CODE_CHECK(aoti_torch_get_mutable_data_ptr(t.get(), &data_ptr));
-  return reinterpret_cast<T*>(data_ptr);
-}
-
-template <typename T>
-const T* const_data_ptr(const Tensor& t) {
-  const void* data_ptr{};
-  TORCH_ERROR_CODE_CHECK(
-      aoti_torch_get_const_data_ptr(t.get(), const_cast<void**>(&data_ptr)));
-  return reinterpret_cast<const T*>(data_ptr);
-}
+using torch::stable::Tensor;
 
-// TODO: When cpu is implemented in torch::stable, eliminate
-// cpu function below.
+// TODO: When cpu op is implemented in torch::stable, eliminate cpu
+// function below.
 inline Tensor cpu(const Tensor& self) {
-  auto sizes_ = sizes(self);
-  auto cpu_type = aoti_torch_device_type_cpu();
+  auto sizes_ = self.sizes();
+  int32_t cpu_type = static_cast<int32_t>(torch::stable::DeviceType::CPU);
   int32_t dtype;
   TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype));
   int32_t layout;
@@ -83,10 +46,11 @@ inline Tensor cpu(const Tensor& self) {
   return result;
 }
 
-// TODO:
+// TODO: When cuda op is implemented in torch::stable, eliminate cuda
+// function below.
 inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
-  auto sizes_ = sizes(self);
-  auto cuda_type = aoti_torch_device_type_cuda();
+  auto sizes_ = self.sizes();
+  int32_t cuda_type = static_cast<int32_t>(torch::stable::DeviceType::CUDA);
   int32_t dtype;
   TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &dtype));
   int32_t layout;
@@ -107,69 +71,16 @@ inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
   return result;
 }
 
-// TODO: remove when torch::stable provides new_zeros
-inline Tensor new_zeros(
-    const Tensor& self,
-    std::vector<int64_t> size,
-    std::optional<c10::ScalarType> dtype = std::nullopt,
-    std::optional<Layout> layout = std::nullopt,
-    std::optional<torch::stable::Device> device = std::nullopt,
-    std::optional<bool> pin_memory = std::nullopt) {
-  int32_t target_dtype{};
-  if (dtype.has_value()) {
-    target_dtype = torch::stable::detail::to<int32_t>(
-        torch::stable::detail::from(dtype.value()));
-  } else {
-    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
-  }
-
-  Layout layout_;
-  if (layout.has_value()) {
-    layout_ = layout.value();
-  } else {
-    TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout_));
-  }
-
-  int32_t device_type;
-  torch::stable::DeviceIndex device_index = 0;
-  if (device.has_value()) {
-    auto device_ = device.value();
-    device_type = static_cast<int32_t>(device_.type());
-    device_index = device_.index();
-  } else {
-    TORCH_ERROR_CODE_CHECK(
-        aoti_torch_get_device_type(self.get(), &device_type));
-    TORCH_ERROR_CODE_CHECK(
-        aoti_torch_get_device_index(self.get(), &device_index));
-  }
-
-  // TODO: pin_memory
-
-  AtenTensorHandle ret0;
-  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
-      self.get(),
-      size.data(),
-      static_cast<int64_t>(size.size()),
-      &target_dtype,
-      &layout_,
-      &device_type,
-      device_index,
-      nullptr, // pin_memory (nullptr for default)
-      &ret0));
-
-  auto result = Tensor(ret0);
-  torch::stable::zero_(result);
-  return result;
-}
-
 // An analog of item template function defined in
 // ATen/templates/TensorBody.h
 template <typename T>
 T item(const Tensor& self) {
   STD_TORCH_CHECK(
       self.numel() == 1, "item requires single element tensor input");
   if (self.is_cpu()) {
-    return torchaudio::stable::const_data_ptr<T>(self)[0];
+    // TODO: use `return self.const_data_ptr<T>()[0];` after torch
+    // stable supports const_data_ptr templates.
+    return reinterpret_cast<const T*>(self.const_data_ptr())[0];
 #ifdef USE_CUDA
   } else if (self.is_cuda()) {
     T value;
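The `item<T>` helper kept in `ops.h` now reads CPU scalars through the stable Tensor's own `const_data_ptr()`. Below is a usage sketch, assuming a single-element CPU tensor; the CUDA branch shown above copies the value to host instead, and `read_scalar` is an illustrative name rather than anything in libtorchaudio.

```cpp
// Usage sketch for torchaudio::stable::item<T>; item<T> enforces
// numel() == 1 with STD_TORCH_CHECK before reading the value.
#include <libtorchaudio/stable/ops.h>

#include <cstdint>

int64_t read_scalar(const torch::stable::Tensor& t) {
  // On CPU this reads the value straight through const_data_ptr();
  // with USE_CUDA defined, a device-to-host copy is performed instead.
  return torchaudio::stable::item<int64_t>(t);
}
```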
9 changes: 5 additions & 4 deletions src/libtorchaudio/utils.h
@@ -4,7 +4,7 @@
 
 // TODO: replace the include libtorchaudio/stable/ops.h with
 // torch/stable/ops.h when torch::stable provides all required
-// features (torch::stable::item<T> or similar):
+// features (torch::stable::item<T> et al):
 #include <libtorchaudio/stable/ops.h>
 
 namespace torchaudio {
@@ -25,7 +25,7 @@ using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
 // TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
 // after Tensor::accessor is supported in stable ABI
 template <typename T, size_t N>
-inline TensorAccessor<T, N> accessor(Tensor t) {
+inline TensorAccessor<T, N> accessor(torch::stable::Tensor t) {
   return TensorAccessor<T, N>(
       reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
 }
@@ -42,7 +42,7 @@ using PackedTensorAccessor32 =
 // TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
 // after Tensor::accessor is supported in stable ABI
 template <typename T, size_t N>
-inline PackedTensorAccessor32<T, N> packed_accessor32(Tensor t) {
+inline PackedTensorAccessor32<T, N> packed_accessor32(torch::stable::Tensor t) {
   return PackedTensorAccessor32<T, N>(
       static_cast<typename PackedTensorAccessor32<T, N>::PtrType>(t.data_ptr()),
       t.sizes().data(),
@@ -58,7 +58,8 @@ using PackedTensorAccessorSizeT =
     size_t>;
 
 template <typename T, size_t N>
-inline PackedTensorAccessorSizeT<T, N> packed_accessor_size_t(Tensor t) {
+inline PackedTensorAccessorSizeT<T, N> packed_accessor_size_t(
+    torch::stable::Tensor t) {
   return PackedTensorAccessorSizeT<T, N>(
       static_cast<typename PackedTensorAccessorSizeT<T, N>::PtrType>(
           t.data_ptr()),
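The accessor helpers in `utils.h` now spell out `torch::stable::Tensor` in their signatures. Below is a usage sketch for the plain `accessor<T, N>` overload; it assumes the enclosing namespace is `torchaudio` (as the first `utils.h` hunk suggests) and uses a 2-D float tensor purely for illustration.

```cpp
// Usage sketch for the header-only accessor<T, N> helper shown above.
// sum_2d is illustrative, not part of libtorchaudio.
#include <libtorchaudio/utils.h>

#include <cstdint>

float sum_2d(torch::stable::Tensor t) {
  // accessor<T, N>(t) wraps t's data pointer, sizes, and strides in a
  // HeaderOnlyTensorAccessor, so elements can be indexed as a[i][j].
  auto a = torchaudio::accessor<float, 2>(t);
  float total = 0.0f;
  for (int64_t i = 0; i < t.size(0); ++i) {
    for (int64_t j = 0; j < t.size(1); ++j) {
      total += a[i][j];
    }
  }
  return total;
}
```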