[C++ API] Using new registration mechanism #7663
This syntax is quite verbose. What if we did it more like
register_module(l3, Linear(3, 5).build());
Also, a helper for building modules might end up being useful as well:
build_module(l3, Linear(3, 5));
Ideally we would do a Catch-style thing like
add(l3 = Linear(3, 5));
This is almost good, because the expression actually evaluates to an lvalue, so we can get access to the module too (and e.g. detect whether it's been constructed, and call build() if it hasn't). This would still work if we changed MemberRef to use pointer offsets; I think we could just store lvalue_addr - this in that case.
The only problem is that we don't know how to assign a name to it. This should still be solvable, although I don't have a pretty solution in mind just yet. A dirty one would be to turn the expression into a string and slice out the part that stands before the equals sign at run time.
Note that this macro could evaluate to a different function based on the lvalue type, so you could also have
add(weight = torch::CPU(torch::kFloat).ones({num_features}));
WDYT?
torch/csrc/autograd/function.cpp (outdated)
@@ -152,7 +152,7 @@ thread_local size_t deleteFunctionRecursionDepth = 0;
 #ifdef _WIN32
 constexpr size_t kDeleteFunctionMaxRecursionDepth = 3000;
 #else
-constexpr size_t kDeleteFunctionMaxRecursionDepth = 10000;
+constexpr size_t kDeleteFunctionMaxRecursionDepth = 5000;
-ihb_.emplace_back();
-hhb_.emplace_back();
+// ihb_.emplace_back();
+// hhb_.emplace_back();
}

template <typename Derived>
void register_parameters(ParameterList Derived::*parameter_list) {
}),
const_getter_([member](ConstThis self) -> const T& {
  return static_cast<const Class*>(self)->*member;
}) {}
Where would you get the parameter name from in all your suggestions? We could just name them with progressive numbers, but it might be nice to have names for parameters, at least in our core modules.
And progressive numbers are not very PyTorch-like, and not super backwards compatible: you always have to add new parameters to the end of your list. I'm generally happy with a macro for ease of naming for users, like
#define REGISTER(x, expr) register(#x, &(decltype(*this):: ## x), expr);
usage:
REGISTER(weight, at::ones(at::CPU(at::kFloat), {3, 3}));
though it does add a little bit of magic, and you can't have unbound parameters. However, I think this is the same case in PyTorch.
Yep, I've been thinking about that. The REGISTER(weight = at::ones(...)) syntax is nice, although it would be more hacky, because it would create a compile-time string of the whole expression.
I will add a nice macro to make registration less verbose in another PR. I want to merge this now to make progress on my other in-flight PRs.
…e2_core_hip

* 'caffe2_core_hip' of github.com:petrex/pytorch: (24 commits)
  Allow empty storage for the 'Edge' class. (pytorch#7595)
  Process group base class and Gloo implementation (pytorch#7628)
  _LRSchedulers getstate include optimizer info (pytorch#7757)
  [PyTorch] [gradcheck] change backward() to grad() (pytorch#7710)
  Update test_nn.py (pytorch#7787)
  Define general default scheduler for TBB and fix ppc64le bug (pytorch#7761)
  Add support for accepting Tensor as input in clip_grad_* functions. (pytorch#7769)
  [Easy] Remove unused code (pytorch#7782)
  Update tbb (pytorch#7734)
  Add @generated annotation (pytorch#7780)
  fix legacy comment after variable tensor merge (pytorch#7771)
  Revert pytorch#7750 and pytorch#7762 to fix Windows CI on master (pytorch#7772)
  Temporarily disable build env check (pytorch#7768)
  Add missing brace (pytorch#7762)
  [C++ API] Add backward() to Tensor and Variable (pytorch#7750)
  [auto] Update onnx to d43b550 - Fix .gitignore and add missing files (onnx/onnx#1005) onnx/onnx@d43b550
  [auto] Update onnx to ea1aa13 - add tests for reduce ops (onnx/onnx#675) onnx/onnx@ea1aa13
  include cudnn_h (pytorch#7749)
  [C++ API] Using new registration mechanism (pytorch#7663)
  [auto] Update onnx to 5dd68e6 - Add a util function: polish_model (onnx/onnx#1000) onnx/onnx@5dd68e6
  ...
* Using new registration mechanism
* Fix signature of param() in module.cpp
* Remove ParameterList
* Fix tests
To solve the problem of references to parameters/submodules in the base module getting invalidated on clone() and reassignment of variables in submodules, this PR implements a change discussed by @ezyang, @ebetica and me: instead of storing references, we store functions that can be passed the this pointer to access the members. You can think of this as a fancy way of storing the offsets into the subclass.

Posting as MVP to discuss. It's a little batshit crazy, but only a little.
Review order:
- torch/detail/member_ref.h
- torch/nn/module.h
- torch/nn/modules/linear.cpp and other modules

Pros:
- reset() function (not implemented yet!)

Cons:
- std::any shim instead of void* in the implementation of MemberRef
- Verbose registration calls like register_parameter("weight", &MyModule::weight, at::ones(at::CPU(at::kFloat), {3, 3}))
@ezyang @ebetica @apaszke