Add Module::replace_module to C++ api #22546
Conversation
This adds a replace_module method to the C++ API so that registered submodules can be swapped out. The primary use case I am aware of is finetuning of models. Given that finetuning is fairly popular these days, I think it would be good to facilitate it in the C++ API as well. This was reported by Jean-Christophe Lombardo on the forums: https://discuss.pytorch.org/t/finetuning-a-model-on-multiple-gpu-in-c/49195
Thanks a lot @t-vi! Just one minor comment, and then it should be good to go.
test/cpp/api/module.cpp (outdated):

TEST_F(ModuleTest, ReplaceModule) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::register_module;
    using torch::nn::Module::replace_module;
Do we have to do this in order to use register_module or replace_module on TestModel?
So I took this from the register_module tests. My understanding is that we do this here as a convenience to use register_module and replace_module outside TestModel in the test. When you go the usual way of subclassing the module and then using replace_module, it will work. The obvious alternative would be to make replace_module public, but I don't necessarily think we should, as the member assignment needs to be coordinated with replace_module.
After more investigation, I realized there is a bigger problem with the replace_module API:
#include <torch/torch.h>
#include <iostream>
struct TestModel : public torch::nn::Module {
using torch::nn::Module::replace_module;
TestModel() {
l1 = register_module("l1", torch::nn::Linear(10, 3));
}
torch::nn::Linear l1{nullptr};
};
int main() {
TestModel model;
model.replace_module("l1", torch::nn::Linear(5, 6));
std::cout << model.l1 << std::endl; // this doesn't agree with what we just did
return 0;
}
The problem is that we always have two shared_ptrs to the same layer: the first one is torch::nn::Linear l1{nullptr} (defined as a class member of TestModel), and the second one is in TestModel.children_ after calling register_module("l1", torch::nn::Linear(10, 3)). They are in sync after the TestModel constructor, but if we implement replace_module only by changing the corresponding shared_ptr in TestModel.children_, the l1 class member doesn't know about the change and thus will not be updated. The two will be out of sync, which can cause user confusion.
One bigger idea is to deprecate the usage of explicitly defined class members for nn::Module, and instead overload the [] operator so that model["l1"] always looks up the stored value in model.children_. This way, there is only one value we ever need to update when replacing a submodule, which causes fewer surprises for users and developers.
I documented the requirement as:
/// This takes care of the registration; you should assign the submodule
/// as well. It only works when a module of that name is already registered.
My take on this is:
- Adding the submodules as class members is a convention suggested in the tutorials, not a requirement (or is it?). It adds the convenience of accessing them as m->mychild.
- While I can see that using only replace_module and not changing the class member (if it exists) is bad, I would argue that currently people will try to replace the class member and be left alone with it not working, with no way to change it short of redefining the entire module cascade.
- I do think that finetuning with a different number of classes is an important enough case in vision to justify adding this imperfect API, in particular since going from "if you followed the tutorials, you need to change two places, you can change one but not the other" to "if you followed the tutorials, you need to change two places" is an improvement.

That said, if you think it's a bad idea, we could just skip the PR, and anyone desperately needing the feature can add their own replace_module function.
In general I think we should provide this API so that people can replace a module easily. As for how we should keep the submodule class member and the submodule registered in children_ in sync, I think we should encourage people to always do model.l1 = model.replace_module("l1", torch::nn::Linear(3, 4)) if they have a submodule class member. We should add a new test for this use case and have an assert to make sure model.l1 points to the same module we just registered in children_.
test/cpp/api/module.cpp (outdated):

@@ -99,6 +99,28 @@ TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
      "Submodule 'linear' already defined");
}

TEST_F(ModuleTest, ReplaceModuleThrowsForUnknownModuleName) {
  struct TestModel : public torch::nn::Module {
    using torch::nn::Module::replace_module;
I tested with:
#include <torch/torch.h>
#include <iostream>
struct TestModel : public torch::nn::Module {
TestModel() {
l1 = register_module("l1", torch::nn::Linear(10, 3));
}
torch::nn::Linear l1{nullptr};
};
int main() {
TestModel model;
model.l1 = model.replace_module("l1", torch::nn::Linear(5, 6));
return 0;
}
However the compiler complains that it isn't able to access the replace_module function. Since replacing a submodule outside of the model definition is a common use case, I think we should make replace_module public to support it. We should also update our test cases to no longer do using torch::nn::Module::replace_module.
I made replace_module public and updated the test and also the comment to propose assigning + replacing.
model.replace_module("linear", torch::nn::Linear(5, 6));
ASSERT_EQ(model.named_parameters()["linear.weight"].size(0), 6);
}
Could we add a test to check that model->l1 points to the same submodule as model->named_modules()["l1"] after model->l1 = model->replace_module("l1", torch::nn::Linear(5, 6))? I would like to illustrate this use case in the test so that people can have it as a reference.
Indeed, thank you for the suggestion, added that now.
@pytorchbot rebase this please
Thanks a lot @t-vi !
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.