Loading jittable models #41

Closed
nec4 opened this issue Sep 3, 2021 · 11 comments
Labels
help wanted Extra attention is needed

Comments

@nec4
Contributor

nec4 commented Sep 3, 2021

Hello. After catching up to main, I am no longer able to load my models after training them. When calling torch.load() on a .pt file, I get the following error:

ModuleNotFoundError: No module named 'CFConvJittable_07e26a'

Is there a new procedure for loading models for prediction/simulation?

nec4 added the help wanted (Extra attention is needed) label Sep 3, 2021
@nec4
Contributor Author

nec4 commented Sep 3, 2021

Running torch.jit.load() on my trained model results in the following error as well:

RuntimeError: [enforce fail at inline_container.cc:222] . file not found: archive/constants.pkl

@PhilippThoelke
Collaborator

PhilippThoelke commented Sep 3, 2021

The problem is due to the way PyTorch Geometric constructs jittable modules (https://pytorch-geometric.readthedocs.io/en/latest/notes/jit.html#converting-gnn-models). It dynamically creates the code for the jitted MessagePassing modules at runtime, which causes problems with torch.load() as this requires the classes of the saved model to be accessible. I'm not sure if PyTorch Geometric has some best practice for this problem but torchmd-net implements the torchmdnet.models.model.load_model function for loading a pytorch-lightning checkpoint and extracting the network from it: https://github.com/compsciencelab/torchmd-net/blob/2e2fbcca2e4d5c818b94c1f0ea4589ff067ed3cf/torchmdnet/models/model.py#L99-L112

This works because the pytorch-lightning checkpoint stores only the state dict (together with the hyperparameters), and load_model creates a new model from the saved args and then loads the state dict into it. Are you able to use this function for your purposes?
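To illustrate why torch.load() fails on a fully pickled model, here is a rough stdlib-only analogue (not PyTorch Geometric's actual code-generation mechanism): a class is generated from source at runtime inside a module that exists only in the current process. pickle, which torch.save() relies on for arbitrary Python objects, records only the module name and class name, not the class's code, so unpickling fails once that module is no longer importable. The module name CFConvJittable_demo, the class body, and the helper make_dynamic_module are hypothetical.

```python
import pickle
import sys
import types

# Hypothetical module name mimicking the one in the error above.
MODULE_NAME = "CFConvJittable_demo"
SOURCE = """
class CFConvJittable:
    def __init__(self, cutoff):
        self.cutoff = cutoff
"""


def make_dynamic_module(name, source):
    """Create a module from source code at runtime and register it in sys.modules."""
    module = types.ModuleType(name)
    exec(source, module.__dict__)  # classes defined here get __module__ == name
    sys.modules[name] = module
    return module


module = make_dynamic_module(MODULE_NAME, SOURCE)
obj = module.CFConvJittable(cutoff=5.0)

# Pickling succeeds while the generated module is importable; the payload
# only references "CFConvJittable_demo.CFConvJittable" by name.
payload = pickle.dumps(obj)

# Simulate a fresh process (or updated library) in which that generated
# module no longer exists.
del sys.modules[MODULE_NAME]

try:
    pickle.loads(payload)
except ModuleNotFoundError as err:
    print(err)  # prints: No module named 'CFConvJittable_demo'
```

A state dict avoids this problem because it holds only tensors keyed by parameter name, so loading it requires no generated classes.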

@nec4
Contributor Author

nec4 commented Sep 3, 2021

I see. This does not work because torch.load() throws the first error for me. I just save my models (which are plain nn.Modules) normally at each epoch using torch.save(). Is there a new way to save models?

@PhilippThoelke
Collaborator

No, currently load_model unfortunately only works with pytorch-lightning checkpoints. If you have all the args required for creating a new model you could load models analogously though. Save the args together with the model's state dict and when loading the model, create the model from scratch using the stored args and then call model.load_state_dict() and pass the stored state dict.
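A minimal sketch of this approach for a plain nn.Module trained without pytorch-lightning, assuming the model can be rebuilt from a dict of constructor keyword arguments. TinyNet, save_checkpoint, load_checkpoint, and the "args"/"state_dict" keys are hypothetical placeholders; in practice the constructor would be the real model class or factory.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real network (e.g. one with PyG jittable layers)."""

    def __init__(self, in_features, hidden_channels):
        super().__init__()
        self.linear = nn.Linear(in_features, hidden_channels)

    def forward(self, x):
        return self.linear(x)


def save_checkpoint(model, args, path):
    # Store only plain Python values and tensors, never the module object itself,
    # so loading does not depend on any classes generated at runtime.
    torch.save({"args": args, "state_dict": model.state_dict()}, path)


def load_checkpoint(path):
    checkpoint = torch.load(path, map_location="cpu")
    # Rebuild the architecture from scratch using the stored args...
    model = TinyNet(**checkpoint["args"])
    # ...then copy the trained weights into it.
    model.load_state_dict(checkpoint["state_dict"])
    return model


args = {"in_features": 3, "hidden_channels": 8}
trained = TinyNet(**args)
path = os.path.join(tempfile.mkdtemp(), "epoch_0.pt")
save_checkpoint(trained, args, path)
restored = load_checkpoint(path)
```

Loading this file only needs the code that defines the architecture to be importable, not any module that a library generated during the original training run.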

@nec4
Contributor Author

nec4 commented Sep 3, 2021

I see. I do not use PyTorch Lightning, so it looks like I will just have to retrain my models.

@PhilippThoelke
Collaborator

This seems to be a problem of PyTorch Geometric's jit functionalities so it might make sense to ask them how they are planning to load models with the dynamically generated code.

@nec4
Contributor Author

nec4 commented Sep 3, 2021

Yeah, for now I will stick to saving the state dictionary and the args.

@PhilippThoelke
Collaborator

I will ask in their discussions forum.

@nec4
Contributor Author

nec4 commented Sep 3, 2021

Thanks for the help!

@PhilippThoelke
Collaborator

The people from PyTorch Geometric currently don't have a solution for this and recommend saving and loading just the state dict as I suggested. Maybe it will work in a future release. See this discussion for details: pyg-team/pytorch_geometric#3075 (comment)

@nec4
Contributor Author

nec4 commented Sep 6, 2021

Thanks. I am currently just saving the state dict as a pickle.


2 participants