Commit effa7de

modify unit test for custom dispatch_stub module

caizhi-mt committed May 1, 2023
1 parent 3040101 commit effa7de

Showing 2 changed files with 41 additions and 30 deletions.
test/cpp_extensions/open_registration_extension.cpp
63 changes: 36 additions & 27 deletions
@@ -1,5 +1,7 @@
 #include <c10/core/impl/alloc_cpu.h>
 #include <c10/core/Allocator.h>
+#include <c10/core/ScalarType.h>
+#include <c10/util/ArrayRef.h>

 #include <torch/csrc/Device.h>
 #include <c10/core/impl/DeviceGuardImplInterface.h>
@@ -11,13 +13,18 @@
 #include <ATen/native/Resize.h>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/native/cpu/Loops.h>
+#include <ATen/cpu/vec/vec_base.h>
+#include <ATen/ops/abs_native.h>
 #include <ATen/EmptyTensor.h>
 #include <ATen/core/GeneratorForPrivateuseone.h>
 #include <ATen/TensorIterator.h>
 #include <ATen/OpMathType.h>

 static uint64_t add_counter = 0;
 static uint64_t last_saved_value = 0;
+
+static uint64_t abs_counter = 0;
+static uint64_t last_abs_saved_value = 0;
 // register guard
 namespace at {
 namespace detail {
@@ -28,35 +35,18 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::

 namespace {

-template <typename T>
-static T abs_impl(T v) {
-  return std::abs(v);
-}
-
-template <>
-C10_UNUSED uint8_t abs_impl(uint8_t v) {
-  return v;
-}
-
-void abs_kernel(TensorIteratorBase& iter) {
-  auto dtype = iter.dtype();
-  if (dtype == kComplexHalf) {
-    using scalar_t = c10::complex<Half>;
-    using opmath_t = at::opmath_type<scalar_t>;
-    cpu_kernel(iter, [=](scalar_t a) -> scalar_t { return abs_impl(opmath_t{a}); });
-  } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "abs_cpu", [&]() {
-      cpu_kernel_vec(
-          iter,
-          [=](scalar_t a) -> scalar_t { return abs_impl(a); },
-          [=](Vectorized<scalar_t> a) { return a.abs(); });
-    });
-  }
-}
+void abs_kernel(::at::TensorIteratorBase& iter) {
+  // Since this custom device is just for testing, not bothering to implement kernels.
+  abs_counter += 1;
+}

 } // namespace

 namespace at::native {

 REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);

-} // namespace
+} // namespace at::native

 // basic dummy add function
 at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
@@ -66,10 +56,18 @@ at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other,
 }

 // basic abs function
-at::Tensor & abs_out(at::Tensor & out, const at::Tensor & self) {
+at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {
   return at::native::abs_out(self, out);
 }

+const at::Tensor& custom_resize_(
+    const at::Tensor& self,
+    c10::IntArrayRef size,
+    c10::optional<c10::MemoryFormat> optional_memory_format) {
+  // Since this custom device is just for testing, not bothering to implement kernels.
+  return self;
+}
+
 // A dummy allocator for our custom device, that secretly uses the CPU
 struct DummyCustomAllocator final : at::Allocator {
   DummyCustomAllocator() = default;
@@ -192,7 +190,8 @@ const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
 // This macro registers your kernels to the PyTorch Dispatcher.
 // More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
 TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
-  m.impl("abs.out", &abs_out);
+  m.impl("abs.out", &custom_abs_out);
+  m.impl("resize_", &custom_resize_);
   m.impl("add.Tensor", &custom_add_Tensor);
   m.impl("empty.memory_format", &custom_empty_symint);
   m.impl("fill_.Scalar", &custom_fill__scalar);
@@ -222,6 +221,15 @@ bool custom_add_called() {
   return called;
 }

+bool custom_abs_called() {
+  bool called = false;
+  if (abs_counter > last_abs_saved_value) {
+    called = true;
+    last_abs_saved_value = abs_counter;
+  }
+  return called;
+}
+
 class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
  public:
  // Constructors
@@ -248,5 +256,6 @@ void register_generator() {
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("custom_device", &get_custom_device, "get custom device object");
   m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
+  m.def("custom_abs_called", &custom_abs_called, "check if our custom abs function was called");
   m.def("register_generator", &register_generator, "register generator for custom device");
 }
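Note: the test suite builds and imports this extension through its own harness. As a rough sketch, the file can also be JIT-compiled by hand with torch.utils.cpp_extension.load; the module name and in-tree source path below are illustrative assumptions, not part of this commit.

import torch
import torch.utils.cpp_extension

# Assumption: running from a PyTorch checkout, so this test source path resolves.
module = torch.utils.cpp_extension.load(
    name="custom_device_extension",  # hypothetical extension name
    sources=["test/cpp_extensions/open_registration_extension.cpp"],
    verbose=True,
)

# Give the PrivateUse1 backend a friendlier name, as the unit test below does.
torch.utils.rename_privateuse1_backend("foo")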
test/test_cpp_extensions_open_device_registration.py
8 changes: 5 additions & 3 deletions
@@ -173,9 +173,11 @@ def test_generator_registration():
         def test_open_device_dispatchstub():
             # test kernels could be reused by privateuse1 backend through dispatchstub
             torch.utils.rename_privateuse1_backend('foo')
-            input_data = torch.randn(3, 4, 5, dtype=torch.float32, device='cpu')
-            foo_input_data = input_data.to('foo')
-            self.assertTrue(torch.abs(input_data) == torch.abs(foo_input_data).to('cpu'))
+            input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu")
+            foo_input_data = input_data.to("foo")
+            self.assertFalse(self.module.custom_abs_called())
+            torch.abs(foo_input_data)
+            self.assertTrue(self.module.custom_abs_called())

         def test_open_device_random():
             with torch.random.fork_rng(device_type="foo"):
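A closing note on the design: abs_kernel now does no arithmetic, so the test asserts only that dispatch actually reached the PrivateUse1 stub, and the abs_counter / last_abs_saved_value handshake lets each check consume its own signal, mirroring custom_add_called. A sketch of the flow the new test exercises, assuming module was loaded and the backend renamed to "foo" as in the snippet above:

import torch

x = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu")
foo_x = x.to("foo")                    # tensor now lives on the custom device

assert not module.custom_abs_called()  # abs_counter == last_abs_saved_value

# abs.out for PrivateUse1 -> custom_abs_out -> at::native::abs_out -> abs_stub,
# whose PrivateUse1 entry (REGISTER_PRIVATEUSE1_DISPATCH) bumps abs_counter.
torch.abs(foo_x)

assert module.custom_abs_called()      # the stub kernel ran, as the test expects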