Skip to content

Commit

Permalink
rename "private_use1" to "privateuse1" and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
caizhi-mt committed Apr 25, 2023
1 parent d6624e1 commit d217ca1
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 6 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/DispatchStub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ void* DispatchStubImpl::get_call_ptr(
#endif

case DeviceType::PrivateUse1:
TORCH_INTERNAL_ASSERT(private_use1_dispatch_ptr, "DispatchStub: missing PrivateUse1 kernel");
return private_use1_dispatch_ptr;
TORCH_INTERNAL_ASSERT(privateuse1_dispatch_ptr, "DispatchStub: missing PrivateUse1 kernel");
return privateuse1_dispatch_ptr;

default:
AT_ERROR("DispatchStub: unsupported device type", device_type);
Expand Down
18 changes: 14 additions & 4 deletions aten/src/ATen/native/DispatchStub.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ struct TORCH_API DispatchStubImpl {
void* cuda_dispatch_ptr;
void* hip_dispatch_ptr;
void* mps_dispatch_ptr;
void* private_use1_dispatch_ptr;
void* privateuse1_dispatch_ptr;
#else
std::atomic<void*> cpu_dispatch_ptr{nullptr};
void* cuda_dispatch_ptr = nullptr;
void* hip_dispatch_ptr = nullptr;
void* mps_dispatch_ptr = nullptr;
void* private_use1_dispatch_ptr = nullptr;
void* privateuse1_dispatch_ptr = nullptr;
#endif
};

Expand Down Expand Up @@ -172,8 +172,8 @@ struct DispatchStub<rT (*)(Args...), T> {
impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}

void set_private_use1_dispatch_ptr(FnPtr fn_ptr) {
impl.private_use1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
}

static TORCH_API FnPtr DEFAULT;
Expand Down Expand Up @@ -216,6 +216,13 @@ struct RegisterHIPDispatch {
}
};

// Helper that installs a PrivateUse1 kernel on a DispatchStub at static
// initialization time. Instantiated through REGISTER_PRIVATEUSE1_DISPATCH;
// constructing the object is the registration.
template <typename DispatchStub>
struct RegisterPRIVATEUSE1Dispatch {
  RegisterPRIVATEUSE1Dispatch(DispatchStub& target_stub, typename DispatchStub::FnPtr kernel_fn) {
    target_stub.set_privateuse1_dispatch_ptr(kernel_fn);
  }
};

} // anonymous namespace
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
Expand Down Expand Up @@ -280,6 +287,9 @@ struct RegisterHIPDispatch {
#define REGISTER_MPS_DISPATCH(name, fn) \
static RegisterMPSDispatch<struct name> name ## __register(name, fn);

#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
static RegisterPRIVATEUSE1Dispatch<struct name> name ## __register(name, fn);

// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
Expand Down
41 changes: 41 additions & 0 deletions test/cpp_extensions/open_registration_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/ops/abs_native.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>

Expand All @@ -23,13 +26,50 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::

}} // namespace at::detail

namespace {

// Scalar absolute-value helper used by abs_kernel below.
template <typename T>
static T abs_impl(T v) {
  return std::abs(v);
}

// uint8_t is unsigned, so |v| == v and std::abs has no unsigned overload;
// this specialization is the identity. C10_UNUSED presumably suppresses
// unused-function warnings in builds where it is never instantiated.
template <>
C10_UNUSED uint8_t abs_impl(uint8_t v) {
  return v;
}

// CPU abs kernel registered for the PrivateUse1 test backend via
// REGISTER_PRIVATEUSE1_DISPATCH below; mirrors ATen's native abs_cpu kernel.
void abs_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (dtype == kComplexHalf) {
    // complex<Half> is handled separately through the non-vectorized
    // cpu_kernel, computing in the wider opmath type (complex<float>);
    // the scalar result converts back to scalar_t on return.
    using scalar_t = c10::complex<Half>;
    using opmath_t = at::opmath_type<scalar_t>;
    cpu_kernel(iter, [=](scalar_t a) -> scalar_t { return abs_impl(opmath_t{a}); });
  } else {
    // All other supported dtypes get both a scalar and a vectorized path.
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "abs_cpu", [&]() {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return abs_impl(a); },
          [=](Vectorized<scalar_t> a) { return a.abs(); });
    });
  }
}

REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);

} // namespace

// Basic dummy add function: bumps the module-level add_counter so tests can
// observe that dispatch reached this backend; `other` and `alpha` are ignored.
at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  add_counter += 1;
  // Since this custom device is just for testing, not bothering to implement kernels.
  return at::empty(self.sizes(), self.options());
}

// Basic abs.out kernel: forwards to ATen's native CPU implementation.
// Parameter order must follow the dispatcher schema
// "abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)" — self first,
// out last — or the m.impl("abs.out", ...) registration below fails its
// signature check. at::native::abs_out takes (self, out) in the same order.
at::Tensor & abs_out(const at::Tensor & self, at::Tensor & out) {
  return at::native::abs_out(self, out);
}

// A dummy allocator for our custom device, that secretly uses the CPU
struct DummyCustomAllocator final : at::Allocator {
DummyCustomAllocator() = default;
Expand Down Expand Up @@ -138,6 +178,7 @@ bool custom_is_pinned(const at::Tensor& self, c10::optional<at::Device> device)
// This macro registers your kernels to the PyTorch Dispatcher.
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("abs.out", &abs_out);
m.impl("add.Tensor", &custom_add_Tensor);
m.impl("empty.memory_format", &custom_empty_symint);
m.impl("fill_.Scalar", &custom_fill__scalar);
Expand Down
8 changes: 8 additions & 0 deletions test/test_cpp_extensions_open_device_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ def test_generator_registration():
"Only can register a generator to the PrivateUse1 dispatch key once"):
self.module.register_generator()

def test_open_device_dispatchstub():
    # Test that CPU kernels can be reused by the privateuse1 backend
    # through DispatchStub (REGISTER_PRIVATEUSE1_DISPATCH).
    # NOTE: `self` is captured from the enclosing test method's closure,
    # matching the other nested test functions in this file.
    torch.utils.rename_privateuse1_backend('foo')
    input_data = torch.randn(3, 4, 5, dtype=torch.float32, device='cpu')
    foo_input_data = input_data.to('foo')
    # `a == b` on tensors is elementwise and truth-testing a multi-element
    # bool tensor raises; compare the whole tensors with allclose instead.
    self.assertTrue(torch.allclose(torch.abs(input_data), torch.abs(foo_input_data).to('cpu')))

def test_open_device_random():
with torch.random.fork_rng(device_type="foo"):
pass
Expand Down Expand Up @@ -269,6 +276,7 @@ def test_open_device_serialization():
test_common_registration()
test_after_common_registration()
test_generator_registration()
test_open_device_dispatchstub()
test_open_device_random()
test_open_device_tensor()
test_open_device_storage()
Expand Down

0 comments on commit d217ca1

Please sign in to comment.