From b821a59e74e20560a1c868c645a03b7103d690c5 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Wed, 11 Nov 2020 16:28:44 -0800 Subject: [PATCH] [TorchBind] Support using lambda function as TorchBind constructor/init method --- .../jit/test_custom_class_registrations.cpp | 18 ++++++++++ test/jit/test_torchbind.py | 7 ++++ torch/custom_class.h | 33 +++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index f563120bbc6c..fc2d83d76409 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -33,6 +33,14 @@ struct Foo : torch::CustomClassHolder { } }; +struct LambdaInit : torch::CustomClassHolder { + int x, y; + LambdaInit(int x_, int y_) : x(x_), y(y_) {} + int64_t diff() { + return this->x - this->y; + } +}; + struct NoInit : torch::CustomClassHolder { int64_t x; }; @@ -202,6 +210,16 @@ TORCH_LIBRARY(_TorchScriptTesting, m) { .def("add", &Foo::add) .def("combine", &Foo::combine); + m.class_("_LambdaInit") + .def(torch::init([](int64_t x, int64_t y, bool swap) { + if (swap) { + return c10::make_intrusive(y, x); + } else { + return c10::make_intrusive(x, y); + } + })) + .def("diff", &LambdaInit::diff); + m.class_("_NoInit").def( "get_x", [](const c10::intrusive_ptr& self) { return self->x; }); diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py index c1ca50270197..866170545747 100644 --- a/test/jit/test_torchbind.py +++ b/test/jit/test_torchbind.py @@ -338,3 +338,10 @@ def test_torchbind_attr_exception(self): foo = torch.classes._TorchScriptTesting._StackString(["test"]) with self.assertRaisesRegex(AttributeError, 'does not have a field'): foo.bar + + def test_lambda_as_constructor(self): + obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False) + self.assertEqual(obj_no_swap.diff(), 1) + + obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True) + self.assertEqual(obj_swap.diff(), -1) diff --git a/torch/custom_class.h b/torch/custom_class.h index 571a584294db..080d9d9d3c95 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -27,6 +27,21 @@ detail::types init() { return detail::types{}; } +template +struct InitLambda { + Func f; +}; + +template +decltype(auto) init(Func&& f) { + using InitTraits = + c10::guts::infer_function_traits_t>; + using ParameterTypeList = typename InitTraits::parameter_types; + + InitLambda init{std::forward(f)}; + return init; +} + /// Entry point for custom C++ class registration. To register a C++ class /// in PyTorch, instantiate `torch::class_` with the desired class as the /// template parameter. Typically, this instantiation should be done in @@ -95,6 +110,24 @@ class class_ { return *this; } + // Used in combination with torch::init([]lambda(){......}) + template + class_& def( + InitLambda> init, + std::string doc_string = "") { + auto init_lambda_wrapper = [func = std::move(init.f)]( + c10::tagged_capsule self, + ParameterTypes... arg) { + c10::intrusive_ptr classObj = + at::guts::invoke(func, std::forward(arg)...); + auto object = self.ivalue.toObject(); + object->setSlot(0, c10::IValue::make_capsule(classObj)); + }; + defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string)); + + return *this; + } + /// This is the normal method registration API. `name` is the name that /// the method will be made accessible by in Python and TorchScript. /// `f` is a callable object that defines the method. Typically `f`