Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support using lambda function as TorchBind constructor #47819

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 18 additions & 0 deletions test/cpp/jit/test_custom_class_registrations.cpp
Expand Up @@ -33,6 +33,14 @@ struct Foo : torch::CustomClassHolder {
}
};

struct LambdaInit : torch::CustomClassHolder {
int x, y;
LambdaInit(int x_, int y_) : x(x_), y(y_) {}
int64_t diff() {
return this->x - this->y;
}
};

struct NoInit : torch::CustomClassHolder {
int64_t x;
};
Expand Down Expand Up @@ -202,6 +210,16 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
.def("add", &Foo::add)
.def("combine", &Foo::combine);

m.class_<LambdaInit>("_LambdaInit")
.def(torch::init([](int64_t x, int64_t y, bool swap) {
if (swap) {
return c10::make_intrusive<LambdaInit>(y, x);
} else {
return c10::make_intrusive<LambdaInit>(x, y);
}
}))
.def("diff", &LambdaInit::diff);

m.class_<NoInit>("_NoInit").def(
"get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });

Expand Down
7 changes: 7 additions & 0 deletions test/jit/test_torchbind.py
Expand Up @@ -338,3 +338,10 @@ def test_torchbind_attr_exception(self):
foo = torch.classes._TorchScriptTesting._StackString(["test"])
with self.assertRaisesRegex(AttributeError, 'does not have a field'):
foo.bar

def test_lambda_as_constructor(self):
obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
self.assertEqual(obj_no_swap.diff(), 1)

obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
self.assertEqual(obj_swap.diff(), -1)
33 changes: 33 additions & 0 deletions torch/custom_class.h
Expand Up @@ -27,6 +27,21 @@ detail::types<void, Types...> init() {
return detail::types<void, Types...>{};
}

template <typename Func, typename... ParameterTypeList>
struct InitLambda {
Func f;
};

template <typename Func>
decltype(auto) init(Func&& f) {
using InitTraits =
c10::guts::infer_function_traits_t<std::decay_t<Func>>;
using ParameterTypeList = typename InitTraits::parameter_types;

InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
return init;
}

/// Entry point for custom C++ class registration. To register a C++ class
/// in PyTorch, instantiate `torch::class_` with the desired class as the
/// template parameter. Typically, this instantiation should be done in
Expand Down Expand Up @@ -95,6 +110,24 @@ class class_ {
return *this;
}

// Used in combination with torch::init([]lambda(){......})
template <typename Func, typename... ParameterTypes>
class_& def(
InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
std::string doc_string = "") {
auto init_lambda_wrapper = [func = std::move(init.f)](
c10::tagged_capsule<CurClass> self,
ParameterTypes... arg) {
c10::intrusive_ptr<CurClass> classObj =
at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
auto object = self.ivalue.toObject();
object->setSlot(0, c10::IValue::make_capsule(classObj));
};
defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string));

return *this;
}

/// This is the normal method registration API. `name` is the name that
/// the method will be made accessible by in Python and TorchScript.
/// `f` is a callable object that defines the method. Typically `f`
Expand Down