
Conversation

@goldsborough commented May 18, 2018

To solve the problem of references to parameters/submodules in the base module getting invalidated by clone() and by reassignment of variables in submodules, this PR implements a change discussed by @ezyang, @ebetica, and me: instead of storing references, we store functions that can be passed the this pointer to access the members. You can think of this as a fancy way of storing offsets into the subclass.

Posting as an MVP to discuss. It's a little batshit crazy, but only a little.
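To make the idea concrete, here is a rough sketch of what storing "functions that can be passed the this pointer" could look like (names and details are illustrative only; the real implementation is in torch/detail/member_ref.h):

#include <functional>

// Sketch only: instead of caching a T& into the concrete module (which
// dangles after clone()), cache a getter that is handed the module
// pointer at call time.
template <typename T>
class MemberRef {
 public:
  template <typename Class>
  explicit MemberRef(T Class::*member)
      : getter_([member](void* self) -> T& {
          return static_cast<Class*>(self)->*member;
        }) {}

  // `self` must point at the same concrete class the member pointer
  // came from; passing anything else is undefined behavior (see the
  // "unsafe" con below).
  T& operator()(void* self) const { return getter_(self); }

 private:
  std::function<T&(void*)> getter_;
};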

Review order:

  1. torch/detail/member_ref.h
  2. torch/nn/module.h
  3. torch/nn/modules/linear.cpp and other modules
  4. Other stuff

Pros:

  • Solves reference invalidation
  • Allows use of constructors instead of a reset() function (not implemented yet!)

Cons:

  • Complicated code
  • Unsafe: I accidentally passed the wrong object when accessing a member and fortunately (!) it segfaulted. This could be improved by using our std::any shim instead of void* in the implementation of MemberRef
  • Registration is a little verbose: register_parameter("weight", &MyModule::weight, at::ones(at::CPU(at::kFloat), {3, 3}))

@ezyang @ebetica @apaszke

@goldsborough changed the title Using new registration mechanism → [C]Using new registration mechanism May 18, 2018
@goldsborough changed the title [C]Using new registration mechanism → [C++ API] Using new registration mechanism May 18, 2018
@apaszke left a comment

This syntax is quite verbose. What if we did it more like

register_module(l3, Linear(3, 5).build());

Also, a helper for building modules might end up being useful:

build_module(l3, Linear(3, 5));

Ideally we would do a Catch-style thing like

add(l3 = Linear(3, 5));

This is almost good, because the expression actually evaluates to an lvalue, so we can get access to the module too (and e.g. detect if it's been constructed, and call build() if it hasn't). This would still work if we changed MemberRef to use pointer offsets. I think we could just store lvalue_addr - this in that case.
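For illustration, a minimal sketch of that pointer-offset variant (hypothetical names; the offset is computed from the lvalue the assignment expression evaluates to):

#include <cstddef>

// Sketch only: `l3 = Linear(3, 5)` evaluates to an lvalue living inside
// *this, so its address minus the module's base address is a byte
// offset that stays valid across clone().
template <typename Module, typename Member>
std::ptrdiff_t member_offset(const Module* self, const Member& lvalue) {
  return reinterpret_cast<const char*>(&lvalue) -
         reinterpret_cast<const char*>(self);
}

// On a clone, the member is recovered by applying the same offset:
// auto& member = *reinterpret_cast<Member*>(
//     reinterpret_cast<char*>(clone_ptr) + offset);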

The only problem with it is that we don't know how to assign a name to it. This should still be solvable, although I don't have a pretty solution in mind just yet. A dirty one would be to turn the expression into a string and slice out the part that stands before the equal sign at run time.

Note that this macro could evaluate to a different function based on the lvalue type, so you could also have

add(weight = torch::CPU(torch::kFloat).ones({num_features}));
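As a sketch of that dispatch (types and names here are hypothetical): the macro expands to an overloaded call, and the type of the lvalue picks which registration runs.

// Sketch only:
void add_impl(torch::nn::Module& submodule); // chosen for add(l3 = Linear(3, 5))
void add_impl(at::Tensor& parameter);        // chosen for add(weight = torch::CPU(torch::kFloat).ones({num_features}))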

WDYT?


@goldsborough

Where would you get the parameter name from in all of your suggestions? We could just name them with progressive numbers, but it might be nice to have names for parameters, at least in our core modules.

@ebetica commented May 21, 2018

And progressive numbers are not very PyTorch-like, and not super backwards compatible: you always have to add new parameters to the end of your list. I'm generally happy with a macro for ease of naming for users, like

#define REGISTER(x, expr) register_parameter(#x, &std::remove_reference_t<decltype(*this)>::x, (expr))

usage:
REGISTER(weight, at::ones(at::CPU(at::kFloat), {3, 3}));

though it does add a little bit of magic, and you can't have unbound parameters. However, I think the same is true in PyTorch.

@apaszke commented May 21, 2018

Yep, I've been thinking of doing #x to turn the first argument into a string at compile time. The

REGISTER(weight = at::ones(...))

syntax is nice, although that would be more hacky because it would create a compile-time string of weight = at::ones(...), and we'd need to slice it at the first equal sign at run time. That doesn't seem too bad either.
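A sketch of that slicing (hypothetical helper; assumes a register_parameter(name, tensor) overload exists):

#include <string>

// Sketch only: #expr stringifies "weight = at::ones(...)" at compile
// time; at run time we keep what precedes the first '='.
inline std::string name_before_equals(const char* expr) {
  std::string name(expr);
  name = name.substr(0, name.find('='));
  while (!name.empty() && name.back() == ' ') {
    name.pop_back();  // drop the space before '='
  }
  return name;
}

#define REGISTER(expr) register_parameter(name_before_equals(#expr), (expr))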

@goldsborough

I will add a nice macro to make registration less verbose in another PR. I want to merge this now to make progress on my other in-flight PRs.

@goldsborough merged commit 549b406 into pytorch:master May 22, 2018
@goldsborough deleted the register branch May 22, 2018 00:59
petrex pushed a commit to petrex/pytorch that referenced this pull request May 23, 2018
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018
* Using new registration mechanism

* Fix signature of param() in module.cpp

* Remove ParameterList

* Fix tests