Make TensorNet compatible with TorchScript #186
Conversation
I can torch.jit.script TensorNet (the
I don't really understand why this error arises here and not in the test. I also do not understand why this error only happens when using jit.script.
Is the problem happening only when trying to train, or also when doing inference on both energies and forces?
I updated the TorchScript test to try a double backward; it passes without issues.
Given that TorchScript is not really intended to be used for training within TorchMD-Net as of now, this PR in its current state leaves TensorNet at the same level of compatibility as ET (I cannot use script during training with any model); that is, the tests pass. One can script TorchMD-Net set up with TensorNet, store it, load it, and run inference and backpropagation on it. Given this, I think we should merge it now. Please review!
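For reference, the script/store/load/inference round trip described above looks roughly like the following sketch. EnergyModel is a hypothetical stand-in for a TorchMD-Net model, not the actual TensorNet class:

```python
import io

import torch
from torch import nn


class EnergyModel(nn.Module):
    # Hypothetical stand-in: maps atomic positions to a scalar energy.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        return self.linear(pos).sum()


scripted = torch.jit.script(EnergyModel())
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)  # store it (a file path works the same way)
buffer.seek(0)
loaded = torch.jit.load(buffer)   # load it

pos = torch.randn(8, 3, requires_grad=True)
energy = loaded(pos)                           # inference
forces = -torch.autograd.grad(energy, pos)[0]  # backprop to get forces
print(forces.shape)  # torch.Size([8, 3])
```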
I am OK with leaving it just compatible at this point, but accelerating training might be useful even if it is only 10%. Anyway, I remember that at some point we could train? That's when we saw the limited speedup of 10-20%, right?
I was not able to trace the model. I am leaving that aside for now... Just for the record, the error I get in test_train when trying to trace is this:
Batch size is 8 in this test; I believe the issue is that the batch size differs between training, validation, and testing. The scripted/traced model is fine with training, but when the shape of the input changes for testing it gets confused. I remember we were able to train with a scripted module @guillemsimeon, but that was without using LNNP, as far as I can remember.
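The shape-specialization behavior of tracing can be illustrated with a toy module (not TensorNet itself): torch.jit.trace records whichever branch the example input takes, while torch.jit.script compiles the control flow itself, so only the scripted version stays correct when the batch size changes:

```python
import torch
from torch import nn


class BatchDependent(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Python control flow on the input shape: tracing freezes whichever
        # branch the example input takes; scripting compiles both branches.
        if x.shape[0] > 6:
            return x.sum()
        return x.mean()


example = torch.randn(8, 3)  # "training" batch size
traced = torch.jit.trace(BatchDependent(), example)
scripted = torch.jit.script(BatchDependent())

small = torch.randn(4, 3)  # a different "test" batch size
print(torch.allclose(traced(small), small.sum()))     # True: trace baked in the sum branch
print(torch.allclose(scripted(small), small.mean()))  # True: script follows the if
```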
Interestingly this test passes without any issue:

@mark.parametrize("model_name", models.__all__)
def test_torchscript_dynamic_shapes(model_name):
    z, pos, batch = create_example_batch()
    model = torch.jit.script(
        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
    )
    # Repeat the input to make it dynamic
    for rep in range(0, 10):
        zi = z.repeat_interleave(rep + 1, dim=0)
        posi = pos.repeat_interleave(rep + 1, dim=0)
        batchi = torch.cat([batch + i for i in range(rep + 1)])
        y, neg_dy = model(zi, posi, batch=batchi)
        grad_outputs = [torch.ones_like(neg_dy)]
        ddy = torch.autograd.grad(
            [neg_dy],
            [posi],
            grad_outputs=grad_outputs,
        )[0]
I found out why this test passes and training does not work.
Does removing the enumerate calls fix the error?
No, you cannot do that in TorchScript.
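For context, the restriction being discussed is (as far as I understand it) TorchScript's handling of nn.ModuleList: a ModuleList can be iterated directly or through enumerate, but indexing it with a loop variable does not compile, so the enumerate cannot simply be replaced by index-based access. A minimal sketch, with Stack as a hypothetical module:

```python
import torch
from torch import nn


class Stack(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TorchScript supports iterating a ModuleList directly or via enumerate,
        # but indexing it with a loop variable (self.layers[i]) fails to compile.
        for i, layer in enumerate(self.layers):
            x = layer(x) + i  # the index is still available through enumerate
        return x


scripted = torch.jit.script(Stack())
out = scripted(torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])
```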
I would like to merge this now, since my upcoming PR with TensorNet optimization depends on this. @raimis @guillemsimeon , could you please review again? |
I understand that it has been made compatible, but we cannot use this compatibility for training in the end? Or has the situation changed?
I still cannot train with a TorchScript model. And not only TensorNet: no model at all. I believe it is PyTorch Lightning related...
Well, we can deal with this later, or even never, if other optimizations really make a larger difference. Should I proceed to close the PR?
👍