Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Module::replace_module to C++ api #22546

Closed
wants to merge 4 commits into from

Conversation

t-vi
Copy link
Collaborator

@t-vi t-vi commented Jul 5, 2019

This adds a replace_module method to the C++ api. This is needed to be able to replace modules.

The primary use case I am aware of is to enable finetuning of models.
Given that finetuning is fairly popular these days, I think it would be good to facilitate this in the C++ api as well.

This has been reported by Jean-Christophe Lombardo on the forums.

This adds a replace_module method to the C++ api. This is needed
to be able to replace modules.

The primary use case I am aware of is to enable finetuning of models.
Given that finetuning is fairly popular these days, I think it
would be good to facilitate this in the C++ api as well.

This has been reported by Jean-Christophe Lombardo on the forums.
https://discuss.pytorch.org/t/finetuning-a-model-on-multiple-gpu-in-c/49195
@pytorchbot pytorchbot added the module: cpp Related to C++ API label Jul 5, 2019
@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jul 8, 2019
Copy link
Contributor

@yf225 yf225 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @t-vi! Just one minor comment, and then it should be good to go.

TEST_F(ModuleTest, ReplaceModule) {
struct TestModel : public torch::nn::Module {
using torch::nn::Module::register_module;
using torch::nn::Module::replace_module;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to do this in order to use register_module or replace_module on TestModel?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I took this from the register_module tests.
My understanding is that we do this here as a convenience to use register_module and replace_module outside TestModel in the test. When you go the usual way of subclassing the module and then using replace_module, it will work. The obvious alternative would be to make replace_module public, but I don't necessarily think we should, as the member assignment needs to be coordinated with replace_module.

Copy link
Contributor

@yf225 yf225 Jul 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After more investigation, I realized there is a bigger problem with the replace_module API:

#include <torch/torch.h>
#include <iostream>

struct TestModel : public torch::nn::Module {
    using torch::nn::Module::replace_module;

    TestModel() {
      l1 = register_module("l1", torch::nn::Linear(10, 3));
    }

    torch::nn::Linear l1{nullptr};
};

int main() {
  TestModel model;
  model.replace_module("l1", torch::nn::Linear(5, 6));
  std::cout << model.l1 << std::endl;  // this doens't agree with what we just did
  return 0;
}

The problem is that we always have two shared_ptrs to the same layer: the first one is torch::nn::Linear l1{nullptr} (defined as class member of TestModel), and the second one is in TestModel.children_ after calling register_module("l1", torch::nn::Linear(10, 3)). They are in sync after the TestModel constructor, but if we implement replace_module only by changing the corresponding shared_ptr in TestModel.children_, the l1 class member doesn't know about the change and thus will not be updated. Thus they will be out-of-sync which can cause user confusion.

One bigger idea is to deprecate the usage of explicitly defined class members for nn::Module, and instead we should overload the [] operator so that model['l1'] always looks up the stored value in model.children_. This way, there is only one value we ever need to update when replacing a submodule, which causes less surprises to users and developers.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I documented the requirement as

  //// This takes care of the registration, you should assign the submodule
  /// as well. It only works when a module of the name is already registered.

My take on this is

  • Adding the submodules as class members is a convention suggested in the tutorials and not a requirement, or is it? It adds the convenience of accessing them as m->mychild.
  • While I can see that only using replace_module and not changing the class member if it exists is bad, I would argue that currently people will try to replace the class member and be left alone with it not working and no way to change it short of redefining the entire module cascade.
  • I do think that finetuning with a different number of classes is an important enough case in vision to justify adding this imperfect api, in particular since I would say going from "if you followed the tutorials, you need to change two places, you can change one but not the other" to "if you followed the tutorials, you need to change two places" is an improvement.

That said if you think it's a bad idea, so we could just skip the PR, and anyone desperately needing the feature can add her own replace_module function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general I think we should provide this API so that people can replace a module easily. As for how we should keep the submodule class member and the submodule registered in children_ in sync, I think we should encourage people to always do model.l1 = model.replace_module("l1", torch::nn::Linear(3, 4)) if they have a submodule class member. We should add a new test for this use case and have an assert to make sure model.l1 points to the same module we just registered in children_.

TEST_F(ModuleTest, ReplaceModule) {
struct TestModel : public torch::nn::Module {
using torch::nn::Module::register_module;
using torch::nn::Module::replace_module;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general I think we should provide this API so that people can replace a module easily. As for how we should keep the submodule class member and the submodule registered in children_ in sync, I think we should encourage people to always do model.l1 = model.replace_module("l1", torch::nn::Linear(3, 4)) if they have a submodule class member. We should add a new test for this use case and have an assert to make sure model.l1 points to the same module we just registered in children_.

@@ -99,6 +99,28 @@ TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
"Submodule 'linear' already defined");
}

TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
struct TestModel : public torch::nn::Module {
using torch::nn::Module::replace_module;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested with:

#include <torch/torch.h>
#include <iostream>

struct TestModel : public torch::nn::Module {
    TestModel() {
      l1 = register_module("l1", torch::nn::Linear(10, 3));
    }

    torch::nn::Linear l1{nullptr};
};

int main() {
  TestModel model;
  model.l1 = model.replace_module("l1", torch::nn::Linear(5, 6));
  return 0;
}

However the compiler complains that it isn't able to access the replace_module function. Since replacing a submodule outside of the model definition is a common use case, I think we should make the replace_module function public to support this use case. And we should also update our test cases to not do using torch::nn::Module::replace_module.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made replace_module public and updated the test and also the comment to propose assigning + replacing.

torch/csrc/api/include/torch/nn/module.h Outdated Show resolved Hide resolved
torch/csrc/api/include/torch/nn/module.h Outdated Show resolved Hide resolved
model.replace_module("linear", torch::nn::Linear(5, 6));
ASSERT_EQ(model.named_parameters()["linear.weight"].size(0), 6);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a test to check that model->l1 points to the same submodule as model->named_modules()["l1"] after model->l1 = model.replace_module("l1", torch::nn::Linear(5, 6))? I would like to illustrate this use case in the test so that people can have it as a reference.

Copy link
Collaborator Author

@t-vi t-vi Jul 23, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, thank you for the suggestion, added that now.

@yf225
Copy link
Contributor

yf225 commented Jul 22, 2019

@pytorchbot rebase this please

Copy link
Contributor

@yf225 yf225 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @t-vi !

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@yf225 merged this pull request in 0dabaad.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Merged module: cpp Related to C++ API open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants