Skip to content

Commit

Permalink
[TorchBind] Support using lambda function as TorchBind constructor/in…
Browse files Browse the repository at this point in the history
…it method
  • Loading branch information
gmagogsfm committed Nov 12, 2020
1 parent febc76a commit b821a59
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
18 changes: 18 additions & 0 deletions test/cpp/jit/test_custom_class_registrations.cpp
Expand Up @@ -33,6 +33,14 @@ struct Foo : torch::CustomClassHolder {
}
};

struct LambdaInit : torch::CustomClassHolder {
int x, y;
LambdaInit(int x_, int y_) : x(x_), y(y_) {}
int64_t diff() {
return this->x - this->y;
}
};

struct NoInit : torch::CustomClassHolder {
int64_t x;
};
Expand Down Expand Up @@ -202,6 +210,16 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
.def("add", &Foo::add)
.def("combine", &Foo::combine);

m.class_<LambdaInit>("_LambdaInit")
.def(torch::init([](int64_t x, int64_t y, bool swap) {
if (swap) {
return c10::make_intrusive<LambdaInit>(y, x);
} else {
return c10::make_intrusive<LambdaInit>(x, y);
}
}))
.def("diff", &LambdaInit::diff);

m.class_<NoInit>("_NoInit").def(
"get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });

Expand Down
7 changes: 7 additions & 0 deletions test/jit/test_torchbind.py
Expand Up @@ -338,3 +338,10 @@ def test_torchbind_attr_exception(self):
foo = torch.classes._TorchScriptTesting._StackString(["test"])
with self.assertRaisesRegex(AttributeError, 'does not have a field'):
foo.bar

def test_lambda_as_constructor(self):
obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
self.assertEqual(obj_no_swap.diff(), 1)

obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
self.assertEqual(obj_swap.diff(), -1)
33 changes: 33 additions & 0 deletions torch/custom_class.h
Expand Up @@ -27,6 +27,21 @@ detail::types<void, Types...> init() {
return detail::types<void, Types...>{};
}

template <typename Func, typename... ParameterTypeList>
struct InitLambda {
Func f;
};

template <typename Func>
decltype(auto) init(Func&& f) {
using InitTraits =
c10::guts::infer_function_traits_t<std::decay_t<Func>>;
using ParameterTypeList = typename InitTraits::parameter_types;

InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
return init;
}

/// Entry point for custom C++ class registration. To register a C++ class
/// in PyTorch, instantiate `torch::class_` with the desired class as the
/// template parameter. Typically, this instantiation should be done in
Expand Down Expand Up @@ -95,6 +110,24 @@ class class_ {
return *this;
}

// Used in combination with torch::init([]lambda(){......})
template <typename Func, typename... ParameterTypes>
class_& def(
InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
std::string doc_string = "") {
auto init_lambda_wrapper = [func = std::move(init.f)](
c10::tagged_capsule<CurClass> self,
ParameterTypes... arg) {
c10::intrusive_ptr<CurClass> classObj =
at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
auto object = self.ivalue.toObject();
object->setSlot(0, c10::IValue::make_capsule(classObj));
};
defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string));

return *this;
}

/// This is the normal method registration API. `name` is the name that
/// the method will be made accessible by in Python and TorchScript.
/// `f` is a callable object that defines the method. Typically `f`
Expand Down

0 comments on commit b821a59

Please sign in to comment.