TorchScript support #33

Closed
raimis opened this issue Jul 22, 2021 · 2 comments

@raimis
Collaborator

raimis commented Jul 22, 2021

TorchMD_GN cannot be converted to TorchScript:

```python
import torch
from torchmdnet.models.torchmd_gn import TorchMD_GN

model = TorchMD_GN()
torch.jit.script(model)
```

```
Traceback (most recent call last):
  File "tmn_jit.py", line 5, in <module>
    torch.jit.script(model)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch/jit/_script.py", line 942, in script
    return torch.jit._recursive.create_script_module(
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch/jit/_recursive.py", line 391, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch/jit/_recursive.py", line 448, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch/jit/_script.py", line 391, in _construct
    init_fn(script_module)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch/jit/_recursive.py", line 428, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch/jit/_recursive.py", line 452, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch/jit/_recursive.py", line 335, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch/jit/_recursive.py", line 757, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch/jit/_script.py", line 989, in script
    fn = torch._C._jit_script_compile(
RuntimeError: 
python value of type 'module' cannot be used as a value. Perhaps it is a closed over global variable? If so, please consider passing it in as an argument or use a local varible instead.:
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torch_geometric/nn/pool/__init__.py", line 210
        edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
    """
    if torch_cluster is None:
       ~~~~~~~~~~~~~ <--- HERE
        raise ImportError('`radius_graph` requires `torch-cluster`.')
'radius_graph' is being compiled since it was called from 'Distance.forward'
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torchmdnet/models/utils.py", line 179
    def forward(self, pos, batch):
        edge_index = radius_graph(pos, r=self.cutoff_upper, batch=batch, loop=self.loop,
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                                  max_num_neighbors=self.max_num_neighbors)
                                  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
```

PyTorch Geometric supports TorchScript (https://pytorch-geometric.readthedocs.io/en/latest/notes/jit.html), but a few modifications are needed.
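
The traceback points at the cause: `radius_graph` in `torch_geometric.nn` tests the module-level global `torch_cluster` inside its body, and TorchScript cannot compile a Python module object used as a value. A minimal sketch of one workaround, assuming a dense O(N²) neighbor search is acceptable for small systems, is a distance module that builds the radius graph from plain tensor ops only. `DenseDistance` is a hypothetical name; it ignores `max_num_neighbors`, and it is not necessarily how torchmd-net resolved this issue.

```python
from typing import Tuple

import torch
from torch import nn


class DenseDistance(nn.Module):
    """Hypothetical TorchScript-compatible stand-in for a radius-graph distance module.

    Builds the neighbor list from the full N x N pairwise distance matrix, so it
    uses O(N^2) memory and has no max-neighbors limit.
    """

    def __init__(self, cutoff_upper: float = 5.0, loop: bool = False):
        super().__init__()
        self.cutoff_upper = cutoff_upper
        self.loop = loop

    def forward(
        self, pos: torch.Tensor, batch: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Pairwise displacement vectors, shape [N, N, 3], and squared distances, shape [N, N].
        diff = pos.unsqueeze(1) - pos.unsqueeze(0)
        d2 = (diff * diff).sum(dim=-1)
        # Keep pairs within the cutoff that belong to the same molecule in the batch.
        mask = (d2 <= self.cutoff_upper * self.cutoff_upper) & (
            batch.unsqueeze(1) == batch.unsqueeze(0)
        )
        if not self.loop:
            # Drop self-pairs (i == j).
            mask = mask & ~torch.eye(pos.size(0), dtype=torch.bool, device=pos.device)
        # nonzero() gives [E, 2] index pairs; transpose to the usual [2, E] edge_index layout.
        edge_index = mask.nonzero().t()
        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
        edge_weight = (edge_vec * edge_vec).sum(dim=-1).sqrt()
        return edge_index, edge_weight, edge_vec


model = torch.jit.script(DenseDistance(cutoff_upper=1.5))
```

Because it only touches tensors and typed attributes, the module compiles with `torch.jit.script`; the trade-off is memory that grows quadratically with the number of atoms, which a dedicated neighbor-list kernel avoids.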

@PhilippThoelke
Collaborator

I just pushed a couple of changes to add TorchScript support. Previously, the output module returned either a tensor containing energy predictions or a tuple with energy and force predictions, depending on self.derivative. I changed it to always return a tuple:
https://github.com/compsciencelab/torchmd-net/blob/dbd72b496987b08f5213f0e88a6de8eea4feab05/torchmdnet/models/output_modules.py#L84-L92
PyTorch was supposed to support Union types in TorchScript as of 1.8.0, but there were delays, and it seems it didn't make 1.9.0 either, since the pull request is still open: pytorch/pytorch#53180. Once that is merged, it will be easy to go back to returning only energy predictions when self.derivative is false.
Training with forces under TorchScript doesn't work for some obscure reason, but inference works fine. Let me know if you run into any problems with it.
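
Without Union support, a scripted `forward` needs a single static return type, so returning `Tuple[Tensor, Optional[Tensor]]` lets the same module report forces only when they are requested. The sketch below uses a hypothetical `ToyOutput` module with a toy energy E = sum(pos**2), so the expected forces are -2 * pos; it illustrates the typing pattern and is not the code at the linked commit.

```python
from typing import List, Optional, Tuple

import torch
from torch import nn


class ToyOutput(nn.Module):
    """Hypothetical output module: E = sum(pos**2), so forces are -dE/dpos = -2 * pos."""

    def __init__(self, derivative: bool = False):
        super().__init__()
        self.derivative = derivative

    def forward(self, pos: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.derivative:
            # Track gradients with respect to the input positions (sets the flag on the caller's tensor).
            pos.requires_grad_(True)
        energy = (pos * pos).sum()
        if self.derivative:
            # The scripted signature of autograd.grad expects List[Optional[Tensor]] here.
            grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
            dy = torch.autograd.grad(
                [energy], [pos], grad_outputs=grad_outputs, create_graph=self.training
            )[0]
            # autograd.grad returns List[Optional[Tensor]]; refine the type before using it.
            assert dy is not None
            return energy, -dy
        # No forces requested: the second element is None, but the return type is still a tuple.
        return energy, None


model = torch.jit.script(ToyOutput(derivative=True))
```

Callers always unpack two values, e.g. `energy, forces = model(pos)`, and check `forces is None` when `derivative` is off.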

@raimis
Collaborator Author

raimis commented Jul 22, 2021

Thanks! I'll test the changes.
