[feature]requesting for pretrained weights #51

Closed
jianyin2016 opened this issue Apr 26, 2020 · 3 comments
Labels
enhancement New feature or request

Comments

@jianyin2016

Is your feature request related to a problem? Please describe.

I would like to train a CNN classifier on my custom data using widely used models such as the ResNet series. I found it useful to initialize the model with ImageNet-pretrained weights, and this is easy to do with the torch::load API when my dataset's images have 3 channels, the same as ImageNet; in that case the Conv1 layer needs no change.
Things are different when I train on grayscale images, because the pretrained Conv1 weights expect in_channels=3. In the Python frontend, I guess this can be solved by immediately replacing model.conv1 like this:

model.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)

but in the C++ frontend, replacing the layer does not seem to work:

model->conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, 64, 7).stride(2).padding(3).bias(false).dilation(1));

Describe the solution you'd like

The correct API usage for loading the pretrained weights properly.

Describe alternatives you've considered

Maybe a model pretrained on a grayscale image dataset would bypass the problem.

Additional context

An exception occurs during the model's forward pass.

@jianyin2016 jianyin2016 added the enhancement New feature or request label Apr 26, 2020
@mfl28
Collaborator

mfl28 commented Apr 26, 2020

Hi @jianyin2016 !
I think the easiest way to achieve what you want to do - if I understand correctly - is:

  1. Create your model as you want it in Python:
model = torchvision.models.resnet...(pretrained=True)
model.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
  2. Save it as a ScriptModule:
example = torch.rand(1, 1, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("my_resnet.pt")
  3. Load it in C++ via auto model = torch::jit::load("path/to/my_resnet.pt")

There is a nice and much more detailed description of this process in the official PyTorch tutorials:
Loading a TorchScript Model in C++
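
The Python side of the steps above can be sketched as one script that also reloads the saved file and feeds it a 1-channel batch as a sanity check. This is only a sketch under assumptions: `TinyNet` is a hypothetical stand-in for the torchvision ResNet so that the script runs without torchvision or a weight download, and `export_grayscale` is an illustrative helper name, not part of PyTorch.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a ResNet; only the attribute name `conv1` matters here."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        return self.fc(torch.flatten(self.pool(x), 1))


def export_grayscale(model, path):
    """Replace conv1 with a 1-channel layer, trace the model, and save it as TorchScript."""
    # The new layer is freshly initialized; every other layer keeps its weights.
    model.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7,
                            stride=2, padding=3, bias=False)
    example = torch.rand(1, 1, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module.save(path)


# Write to a temporary directory so the sketch leaves no file behind in the working tree.
path = os.path.join(tempfile.mkdtemp(), "my_resnet.pt")
export_grayscale(TinyNet(), path)

# Reload the file the same way C++ would (torch::jit::load) and run a grayscale batch.
reloaded = torch.jit.load(path)
with torch.no_grad():
    out = reloaded(torch.rand(2, 1, 224, 224))
```

With the real model, `TinyNet()` would be replaced by the torchvision constructor from step 1. Note that tracing records layers such as BatchNorm in whatever mode the model is in at trace time, so calling `model.eval()` before tracing may be appropriate if the exported file is only used for inference.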

@jianyin2016
Author

Hi @mfl28, thanks for your nice advice.

There is no doubt that your solution will work, and I fully believe it, but it seems to bypass the problem rather than solve it. I don't think it is right to use the C++ API while still relying so heavily on the Python side. So I am still asking for a solution entirely within libtorch.

Thanks!

@mfl28
Collaborator

mfl28 commented Apr 26, 2020

To my knowledge it is currently not possible to do this completely within C++, as torch::load() cannot load pickled weight files and there is also no load_state_dict() function for models in libtorch (see this recent issue in the official PyTorch repo). Maybe you can find more information on this in the PyTorch forum or the official repo's issues.
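
For comparison, here is a minimal sketch of what the Python-side `load_state_dict()` makes possible and what is described above as missing from libtorch: copying every pretrained tensor except the first convolution into a 1-channel model. The network built by `make_net` is a hypothetical stand-in, not a real ResNet, and the "pretrained" model is simply a second randomly initialized copy playing the role of an ImageNet checkpoint.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def make_net(in_channels):
    """Hypothetical stand-in network; the layer names conv1 and fc mimic a ResNet."""
    return nn.Sequential(OrderedDict([
        ("conv1", nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)),
        ("pool", nn.AdaptiveAvgPool2d(1)),
        ("flatten", nn.Flatten()),
        ("fc", nn.Linear(64, 10)),
    ]))


pretrained = make_net(in_channels=3)  # plays the role of the 3-channel checkpoint
gray = make_net(in_channels=1)        # the model that will see grayscale images

# Drop the tensor whose shape no longer matches; strict=False tolerates missing
# keys but still raises on a size mismatch, so the key has to be removed first.
state = {k: v for k, v in pretrained.state_dict().items() if k != "conv1.weight"}
result = gray.load_state_dict(state, strict=False)

# result.missing_keys lists the parameters that kept their fresh initialization.
print("left at random init:", result.missing_keys)
```

Here only `conv1.weight` is left untouched, while `fc` receives the copied weights. This is the partial-loading behaviour that would have to be reproduced by hand in C++.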
