modify unit test for custom dispatch_stub module
caizhi-mt committed May 6, 2023
1 parent 5b69691 commit aba8401
Showing 2 changed files with 27 additions and 92 deletions.
111 changes: 22 additions & 89 deletions test/cpp_extensions/open_registration_extension.cpp
@@ -1,5 +1,7 @@
 #include <c10/core/impl/alloc_cpu.h>
 #include <c10/core/Allocator.h>
+#include <c10/core/ScalarType.h>
+#include <c10/util/ArrayRef.h>
 
 #include <torch/csrc/Device.h>
 #include <c10/core/impl/DeviceGuardImplInterface.h>
@@ -10,14 +12,15 @@
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/UnaryOps.h>
-#include <ATen/native/cpu/Loops.h>
 #include <ATen/ops/abs_native.h>
 #include <ATen/EmptyTensor.h>
 #include <ATen/core/GeneratorForPrivateuseone.h>
 
 static uint64_t add_counter = 0;
 static uint64_t last_saved_value = 0;
 
+static uint64_t abs_counter = 0;
+static uint64_t last_abs_saved_value = 0;
 // register guard
 namespace at {
 namespace detail {
@@ -26,100 +29,20 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::
 
 }} // namespace at::detail
 
-struct CustomBackendMetadata : public c10::BackendMeta {
-  // for testing this field will mutate when clone() is called by shallow_copy_from.
-  int backend_version_format_{-1};
-  int format_number_{-1};
-  mutable bool cloned_{false};
-  // define the constructor
-  CustomBackendMetadata(int backend_version_format, int format_number): backend_version_format_(backend_version_format), format_number_(format_number) {}
-  c10::intrusive_ptr<c10::BackendMeta> clone(const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
-    cloned_ = true;
-    return c10::BackendMeta::clone(ptr);
-  }
-};
-
-// we need to register two functions for serialization
-void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
-  if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr() == nullptr) {
-    return;
-  }
-  CustomBackendMetadata* tmeta = dynamic_cast<CustomBackendMetadata*>(t.unsafeGetTensorImpl()->get_backend_meta());
-  if (tmeta->backend_version_format_ == 1) {
-    m["backend_version_format"] = true;
-  }
-  if (tmeta->format_number_ == 29) {
-    m["format_number"] = true;
-  }
-}
-
-void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
-  int backend_version_format{-1};
-  int format_number{-1};
-  if (m.find("backend_version_format") != m.end()) {
-    backend_version_format = 1;
-  }
-  if (m.find("format_number") != m.end()) {
-    format_number = 29;
-  }
-  c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata(backend_version_format, format_number))};
-  t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
-}
-
-void custom_serialization_registry(){
-  torch::jit::TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1, &for_serialization, &for_deserialization);
-}
-
-//check if BackendMeta serialization correctly
-bool check_backend_meta(const at::Tensor& t) {
-  if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr()) {
-    CustomBackendMetadata* tmeta = dynamic_cast<CustomBackendMetadata*>(t.unsafeGetTensorImpl()->get_backend_meta());
-    if (tmeta->backend_version_format_==1 && tmeta->format_number_==29) {
-      return true;
-    }
-  }
-  return false;
-}
-
-// a fake set function is exposed to the Python side
-void custom_set_backend_meta(const at::Tensor& t) {
-  int backend_version_format{1};
-  int format_number{29};
-  c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata(backend_version_format, format_number))};
-  t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
-}
-
 namespace {
 
-template <typename T>
-static T abs_impl(T v) {
-  return std::abs(v);
+void abs_kernel(::at::TensorIteratorBase& iter) {
+  // Since this custom device is just for testing, not bothering to implement kernels.
+  abs_counter += 1;
 }
 
-template <>
-C10_UNUSED uint8_t abs_impl(uint8_t v) {
-  return v;
-}
 } // namespace
 
-void abs_kernel(TensorIteratorBase& iter) {
-  auto dtype = iter.dtype();
-  if (dtype == kComplexHalf) {
-    using scalar_t = c10::complex<Half>;
-    using opmath_t = at::opmath_type<scalar_t>;
-    cpu_kernel(iter, [=](scalar_t a) -> scalar_t { return abs_impl(opmath_t{a}); });
-  } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "abs_cpu", [&]() {
-      cpu_kernel_vec(
-          iter,
-          [=](scalar_t a) -> scalar_t { return abs_impl(a); },
-          [=](Vectorized<scalar_t> a) { return a.abs(); });
-    });
-  }
-}
+
 namespace at::native {
 
 REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);
 
-} // namespace
+} // namespace at::native
 
 // basic dummy add function
 at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
@@ -129,7 +52,7 @@ at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other,
 }
 
 // basic abs function
-at::Tensor & abs_out(at::Tensor & out, const at::Tensor & self) {
+at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {
   return at::native::abs_out(self, out);
 }
 
@@ -255,7 +178,7 @@ const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
 // This macro registers your kernels to the PyTorch Dispatcher.
 // More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
 TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
-  m.impl("abs.out", &abs_out);
+  m.impl("abs.out", &custom_abs_out);
   m.impl("add.Tensor", &custom_add_Tensor);
   m.impl("empty.memory_format", &custom_empty_symint);
   m.impl("fill_.Scalar", &custom_fill__scalar);
@@ -285,6 +208,15 @@ bool custom_add_called() {
   return called;
 }
 
+bool custom_abs_called() {
+  bool called = false;
+  if (abs_counter > last_abs_saved_value) {
+    called = true;
+    last_abs_saved_value = abs_counter;
+  }
+  return called;
+}
+
 class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
  public:
   // Constructors
@@ -311,5 +243,6 @@ void register_generator() {
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("custom_device", &get_custom_device, "get custom device object");
   m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
+  m.def("custom_abs_called", &custom_abs_called, "check if our custom abs function was called");
   m.def("register_generator", &register_generator, "register generator for custom device");
 }
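
For orientation before the test-file diff, here is a minimal sketch of how this extension gets driven from Python. The load() call and relative source path are illustrative assumptions; rename_privateuse1_backend and custom_abs_called come straight from the code above.

import torch
import torch.utils.cpp_extension

# Build and load the extension in-process. The source path is an assumption
# for illustration; the real suite compiles
# test/cpp_extensions/open_registration_extension.cpp the same way.
module = torch.utils.cpp_extension.load(
    name="custom_device_extension",
    sources=["cpp_extensions/open_registration_extension.cpp"],
    verbose=True,
)

torch.utils.rename_privateuse1_backend("foo")  # PrivateUse1 now answers to "foo"

x = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to("foo")
torch.abs(x)                       # reaches abs_kernel via at::native::abs_out and abs_stub
assert module.custom_abs_called()  # abs_counter advanced past its snapshot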
8 changes: 5 additions & 3 deletions test/test_cpp_extensions_open_device_registration.py
@@ -173,9 +173,11 @@ def test_generator_registration():
         def test_open_device_dispatchstub():
             # test kernels could be reused by privateuse1 backend through dispatchstub
             torch.utils.rename_privateuse1_backend('foo')
-            input_data = torch.randn(3, 4, 5, dtype=torch.float32, device='cpu')
-            foo_input_data = input_data.to('foo')
-            self.assertTrue(torch.abs(input_data) == torch.abs(foo_input_data).to('cpu'))
+            input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu")
+            foo_input_data = input_data.to("foo")
+            self.assertFalse(self.module.custom_abs_called())
+            torch.abs(foo_input_data)
+            self.assertTrue(self.module.custom_abs_called())
 
         def test_open_device_random():
             with torch.random.fork_rng(device_type="foo"):
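
Note that custom_abs_called() is a consuming check: it compares abs_counter against a snapshot and then advances the snapshot, which is why the updated test can assert False before torch.abs and True after. A pure-Python model of that latch follows; the *_model names are illustrative, not part of the extension.

abs_counter = 0
last_abs_saved_value = 0

def abs_kernel_model():
    """Stands in for the C++ abs_kernel: just bump the counter."""
    global abs_counter
    abs_counter += 1

def custom_abs_called_model():
    """Report whether the kernel ran since the last query, then advance the snapshot."""
    global last_abs_saved_value
    called = abs_counter > last_abs_saved_value
    if called:
        last_abs_saved_value = abs_counter
    return called

assert not custom_abs_called_model()  # nothing has run yet
abs_kernel_model()
assert custom_abs_called_model()      # signal consumed here
assert not custom_abs_called_model()  # so a second query reads False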
