Make TensorNet compatible with TorchScript #186

Merged on Jun 27, 2023 (24 commits)

RaulPPelaez
Collaborator

No description provided.

@RaulPPelaez
Collaborator Author

I can torch.jit.script TensorNet (test_forward_torchscript passes), but when I torch.jit.script TorchMD_Net and run a training I get this error:

Traceback of TorchScript (most recent call last):
  File "/shared/raul/torchmd-net/torchmdnet/models/model.py", line 254, in forward
        if self.derivative:
            grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
            dy = grad(
                 ~~~~ <--- HERE
                [y],
                [pos],
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The size of tensor a (128) must match the size of tensor b (3) at non-singleton dimension 3

I don't really understand why this error arises here and not in the test. I also don't understand why it only happens when using jit.script.
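
For readers unfamiliar with the pattern in the traceback, here is a minimal sketch of calling autograd's grad inside a scripted forward. ToyEnergyModel and its quadratic energy are hypothetical stand-ins, not the TorchMD_Net code; only the typed grad_outputs list and the grad([y], [pos], ...) call mirror the lines shown above.

```python
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.autograd import grad


class ToyEnergyModel(nn.Module):
    """Hypothetical stand-in for a model: energy = scale * sum(pos**2)."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.5))

    def forward(self, pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pos.requires_grad_(True)
        y = self.scale * (pos * pos).sum()
        # TorchScript needs the explicit List[Optional[Tensor]] annotation here.
        grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
        dy = grad(
            [y],
            [pos],
            grad_outputs=grad_outputs,
            create_graph=True,  # keep the graph so the forces can be differentiated
            retain_graph=True,
        )[0]
        if dy is None:
            raise RuntimeError("Autograd returned None for the force prediction.")
        return y, -dy


model = torch.jit.script(ToyEnergyModel())
pos = torch.randn(5, 3)
energy, neg_dy = model(pos)
# Backpropagate through the forces (a double backward), as a force loss would.
neg_dy.pow(2).sum().backward()
```

With scale fixed at 0.5 the returned forces are simply -pos, which makes the sketch easy to check by hand.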

@guillemsimeon
Collaborator

Is the problem happening only when trying to train, or also when doing inference on both energies and forces?

@RaulPPelaez
Collaborator Author

I updated the TorchScript test to also do a double backward; it passes without issues.
I can do inference and backpropagation with a scripted TorchMD_Net module, but if I give LNNP a scripted module, or try to script LNNP itself (either calling jit.script on it or using the Lightning to_torchscript() member), I get that error.

@RaulPPelaez
Collaborator Author

Given that TorchScript is not really intended to be used for training within TorchMD-Net as of now, this PR in its current state leaves TensorNet at the same level of compatibility as ET: I cannot use a scripted model during training with any architecture, but the tests pass. One can script TorchMD-Net set up with TensorNet, store it, load it, and do inference and backprop on it.

Given this, I think we should merge now. Please review!
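
The script, store, load, inference and backprop workflow described above can be sketched as follows. TinyNet is a hypothetical placeholder for a TorchMD-Net model, and an in-memory buffer stands in for a file on disk; this is an illustration under those assumptions, not the project's test.

```python
import io

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical placeholder model mapping positions to a scalar energy."""

    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(3, 1)

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        return self.lin(pos).tanh().sum()


scripted = torch.jit.script(TinyNet())

# Store the scripted module and load it back (a file path works the same way).
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# Inference and backprop on the loaded module.
pos = torch.randn(4, 3, requires_grad=True)
energy = loaded(pos)
energy.backward()
```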

@guillemsimeon
Collaborator

I am OK with leaving it just compatible at this point, but accelerating training might be useful even if the gain is only 10%. Anyway, I remember that at some point we could train? That's when we saw the limited speedup of 10-20%, right?

@RaulPPelaez
Collaborator Author

I was not able to trace the model. I am leaving that aside for now... Just for the record, the error I get in test_train when trying to trace is this:

torchmd-net/torchmdnet/module.py:128: UserWarning: Using a target size (torch.Size([8, 1])) that is different to the input size (torch.Size([50, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
    loss_y = loss_fn(y, batch.y)

The batch size is 8 in this test. I believe the issue is that the batch size differs between training, validation, and testing. The scripted/traced model is fine during training, but it gets confused when the shape of the input changes for testing.
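
As a general illustration of how a traced function can specialize on its example input, here is a minimal sketch. The pooled function is hypothetical, and this is not necessarily the mechanism behind the warning above; it only shows that tracing records the Python branch taken for the example batch, whereas scripting keeps the branch.

```python
import warnings

import torch


def pooled(x: torch.Tensor) -> torch.Tensor:
    # Python-level branch on the batch dimension.
    if x.shape[0] > 10:
        return x.sum(dim=0)
    return x.mean(dim=0)


with warnings.catch_warnings():
    # Tracing emits a TracerWarning about the branch; silenced for brevity.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(pooled, torch.ones(50, 1))
scripted = torch.jit.script(pooled)

small = torch.ones(8, 1)
traced_out = traced(small)      # still takes the sum branch recorded with 50 rows
scripted_out = scripted(small)  # re-evaluates the branch and takes the mean
```

The traced function silently returns a different result than the eager one once the batch size changes, which is why shape-dependent logic is usually scripted rather than traced.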

I remember we were able to train with a scripted module @guillemsimeon, but that was without using LNNP as far as I can remember.

@RaulPPelaez
Collaborator Author

Interestingly this test passes without any issue:

@mark.parametrize("model_name", models.__all__)
def test_torchscript_dynamic_shapes(model_name):
    z, pos, batch = create_example_batch()
    model = torch.jit.script(
        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
    )
    # Repeat the input to make it dynamic
    for rep in range(0, 10):
        zi = z.repeat_interleave(rep+1, dim=0)
        posi = pos.repeat_interleave(rep+1, dim=0)
        batchi = torch.cat([batch + i for i in range(rep+1)])
        y, neg_dy = model(zi, posi, batch=batchi)
        grad_outputs = [torch.ones_like(neg_dy)]
        ddy = torch.autograd.grad(
            [neg_dy],
            [posi],
            grad_outputs=grad_outputs,
        )[0]

@RaulPPelaez
Collaborator Author

I found out why this test passes while training does not work.
The tests are CPU only; running the above test with GPU-stored tensors yields the error I posted initially.
It only happens with TensorNet: after a couple of differently sized inputs, the model gives that error.
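
One way to keep such a test from being CPU only is to run the same shape sweep on every available device. The sketch below assumes a hypothetical ToyModel and helper names (available_devices, sweep_shapes) that are not part of the TorchMD-Net test suite; it only shows the structure of the loop.

```python
from typing import List

import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical placeholder for a scripted model returning a scalar energy."""

    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(3, 1)

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        return self.lin(pos).pow(2).sum()


def available_devices() -> List[str]:
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    return devices


def sweep_shapes(device: str, max_rep: int = 10) -> int:
    """Run forward, backward and double backward for growing input sizes."""
    model = torch.jit.script(ToyModel()).to(device)
    base = torch.randn(5, 3, device=device)
    checked = 0
    for rep in range(1, max_rep + 1):
        pos = base.repeat_interleave(rep, dim=0).requires_grad_(True)
        energy = model(pos)
        (dy,) = torch.autograd.grad(energy, pos, create_graph=True)
        dy.pow(2).sum().backward()  # second backward pass through the forces
        assert dy.shape == pos.shape
        checked += 1
    return checked


results = {device: sweep_shapes(device) for device in available_devices()}
```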

@guillemsimeon
Collaborator

Does removing the enumerates fix the error?

@RaulPPelaez
Collaborator Author

> Does removing the enumerates fix the error?

No, you cannot do that in TorchScript:

E       RuntimeError: 
E       Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
E         File "torchmd-net/torchmdnet/models/tensornet.py", line 201
E               # Interaction layers
E               for i in range(self.num_layers):
E                   X = self.layers[i](X, edge_index, edge_weight, edge_attr)
E                       ~~~~~~~~~~~~~~ <--- HERE
E               I, A, S = decompose_tensor(X)
E               x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)

../../mambaforge/envs/test2/lib/python3.10/site-packages/torch/jit/_recursive.py:397: RuntimeError
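
For comparison, a minimal sketch of the form that TorchScript does accept, as the error message itself suggests: enumerating the ModuleList instead of indexing it with a loop variable. The Stack module is hypothetical and much simpler than the TensorNet interaction layers.

```python
import torch
from torch import nn


class Stack(nn.Module):
    """Hypothetical module holding its layers in a ModuleList."""

    def __init__(self, num_layers: int = 3) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # self.layers[i] with a loop variable i is rejected by the compiler;
        # enumerate (or plain iteration) over the ModuleList is supported.
        for i, layer in enumerate(self.layers):
            x = torch.tanh(layer(x))
        return x


eager = Stack()
scripted = torch.jit.script(eager)
x = torch.randn(2, 4)
out = scripted(x)
```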

@RaulPPelaez
Collaborator Author

I would like to merge this now, since my upcoming PR with TensorNet optimizations depends on it. @raimis @guillemsimeon, could you please review again?

@guillemsimeon
Collaborator

I understand that it has been made compatible, but in the end we cannot use this compatibility in training? Or has the situation changed?

@RaulPPelaez
Collaborator Author

I still cannot train with a TorchScript model, and not only with TensorNet: it fails for every model. I believe it is related to PyTorch Lightning...
I spoke to @PhilippThoelke and apparently he also had issues with this.

@guillemsimeon
Collaborator

Well, we can deal with this later, or even never, if other optimizations really make a larger difference. Should I proceed to close the PR?

Collaborator

@guillemsimeon guillemsimeon left a comment

👍

@RaulPPelaez RaulPPelaez merged commit a116847 into torchmd:main Jun 27, 2023
1 check passed
3 participants