Pre-trained model #19

Closed
raimis opened this issue May 17, 2021 · 32 comments

Comments

@raimis
Collaborator

raimis commented May 17, 2021

We are writing a paper about NNP/MM in ACEMD. So far, we have used ANI-2x for protein-ligand simulations, but to demonstrate a general utility, it would be good to include one more NNP.

Would it be possible to have a pre-trained TorchMD-NET model?

@giadefa
Contributor

giadefa commented May 17, 2021 via email

@PhilippThoelke
Collaborator

So you need a checkpoint file for a graph network trained e.g. on aspirin from the MD17 dataset?
Would it work for you if the model is trained with the next version of the code, i.e. once #20 is merged? That change adds some features that improve the performance on MD17.

@giadefa
Contributor

giadefa commented May 18, 2021 via email

@PhilippThoelke
Collaborator

PhilippThoelke commented May 19, 2021

I have added a graph network checkpoint pretrained on aspirin from the MD17 dataset. It used 950 samples for training, 50 for validation and the remaining samples for testing, which is the standard benchmark for this dataset. I trained on both energies and forces; the exact hyperparameters can be found in the hparams.yaml file. You can find the model checkpoint, hyperparameters and splits here.

I also included the metrics.csv, which contains losses and learning rate for each epoch during training. The model checkpoint comes from epoch 1269 and reached an MAE of 0.224 for the energy and 0.630 for the forces on the test set.

The model was trained on version 0.1.0.

@raimis
Collaborator Author

raimis commented Jul 22, 2021

The pre-trained model cannot be loaded:

import torch
from torchmdnet.models import load_model

model = load_model('examples/pretrained/md17-aspirin-graph-network/epoch=1269-val_loss=0.8859-test_loss=0.5893.ckpt')
Traceback (most recent call last):
  File "tmn_load.py", line 10, in <module>
    model = load_model('examples/pretrained/md17-aspirin-graph-network/epoch=1269-val_loss=0.8859-test_loss=0.5893.ckpt')
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torchmdnet/models/model.py", line 78, in load_model
    model = create_model(args)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torchmdnet/models/model.py", line 27, in create_model
    max_num_neighbors=args['max_num_neighbors'],
KeyError: 'max_num_neighbors'

@raimis raimis reopened this Jul 22, 2021
@PhilippThoelke
Collaborator

It should be possible to load it using version 0.1.x, under which it was trained. I haven't updated the model since then, but I can do that now. I'll have to retrain it using the current version, which will take roughly half a day.
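
The KeyError in the traceback above is what one would expect when a checkpoint's saved hyperparameters predate an argument that newer code reads. A minimal sketch of how such a mismatch could be detected before the model is built, using plain dictionaries. `missing_hparams` is a hypothetical helper, not part of torchmdnet, and every key and value other than `max_num_neighbors` is made up for illustration.

```python
def missing_hparams(saved, required):
    """Return the required keys that are absent from the saved hyperparameters."""
    return [key for key in required if key not in saved]


# Hyperparameters as an older checkpoint might have stored them
# (illustrative keys and values, not taken from the real hparams.yaml).
old_hparams = {"model": "graph-network", "cutoff_upper": 5.0}

# Keys the newer code is assumed to read; only 'max_num_neighbors'
# is confirmed by the traceback in this thread.
required = ["model", "cutoff_upper", "max_num_neighbors"]

missing = missing_hparams(old_hparams, required)
if missing:
    print(f"Checkpoint lacks hyperparameters: {missing}")
```

A check like this turns a bare KeyError into a message that points at the version mismatch.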

@raimis
Collaborator Author

raimis commented Jul 22, 2021

Thanks! It would be most useful if I could test the simulations with the latest version.

@PhilippThoelke
Collaborator

Never mind, I still had a model checkpoint from the most recent version; I just pushed it.

@raimis
Collaborator Author

raimis commented Jul 22, 2021

Thanks! Now it works.

@raimis
Collaborator Author

raimis commented Jul 26, 2021

I tried to run MD simulations with the NNP model:

  • Solvated aspirin with NNP/MM
  • Aspirin in vacuum with NNP

In both cases, the simulations "explode" within 1 ps. I tried reducing the timestep to 0.5 fs, but it doesn't help. The same simulations with ANI-2x run without problems.

@raimis raimis reopened this Jul 26, 2021
@raimis
Collaborator Author

raimis commented Jul 26, 2021

@PhilippThoelke would it be possible to run the system with TorchMD to verify the problem?

@PhilippThoelke
Collaborator

Yes that is possible. You can use torchmdnet.calculators.External as an external force inside TorchMD.

@raimis
Collaborator Author

raimis commented Jul 26, 2021

@stefdoerr could you help to set up the simulations?

@stefdoerr
Collaborator

Don't you already have the input files since you ran them?

@raimis
Collaborator Author

raimis commented Jul 26, 2021

@stefdoerr I do have input files (PDB and PRMTOP), but I haven't used TorchMD.

@stefdoerr
Collaborator

https://github.com/torchmd/torchmd/blob/master/examples/tutorial.ipynb
It's relatively simple, but if you don't want to try it, send me the input files and I can take a look.

@raimis
Collaborator Author

raimis commented Jul 26, 2021

Where do I need to add torchmdnet.calculators.External?

@PhilippThoelke
Collaborator

You have to enter that as the --external arg in run.py: https://github.com/torchmd/torchmd/blob/3e12a6858c603af6c2b76696ff4edf0956dc0ea5/torchmd/run.py#L50
It additionally requires the path of the model checkpoint and embedding indices in this arg. I'm not sure if it's possible to pass all of that via the command line or whether you have to use a yaml config file for that. I believe you would have to enter it as

external:
  module: torchmdnet.calculators.External
  embeddings: [ 1, 1, 6, 6, ...]
  file: path/to/checkpoint
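
The embeddings in the snippet above look like atomic numbers, one per atom. Assuming that reading is right, here is a rough sketch of building the list from element symbols and printing the `external` section. The helper names are hypothetical, and the element order below is only an assumption: it has to match the atom order of the actual structure file.

```python
# Atomic numbers for the elements that occur in aspirin (C9H8O4).
ATOMIC_NUMBERS = {"H": 1, "C": 6, "O": 8}


def embeddings_from_elements(elements):
    """Map element symbols to atomic numbers, keeping the atom order."""
    return [ATOMIC_NUMBERS[symbol] for symbol in elements]


def external_section(embeddings, checkpoint_path,
                     module="torchmdnet.calculators.External"):
    """Format the 'external' block using the key names from the snippet above."""
    return "\n".join([
        "external:",
        f"  module: {module}",
        f"  embeddings: {embeddings}",
        f"  file: {checkpoint_path}",
    ])


# Assumed atom order: 4 oxygens, 9 carbons, 8 hydrogens.
elements = ["O"] * 4 + ["C"] * 9 + ["H"] * 8
embeddings = embeddings_from_elements(elements)
text = external_section(embeddings, "path/to/checkpoint")
print(text)
```

Generating the list from the structure avoids typing 21 numbers by hand and makes an atom-order mistake easier to spot.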

@raimis
Collaborator Author

raimis commented Jul 28, 2021

I tried a simulation of aspirin with TorchMD:

coordinates: aspirin.pdb
cutoff: null
device: cuda
extended_system: null
external:
  embeddings:
  - 8
  - 8
  - 8
  - 8
  - 6
  - 6
  - 6
  - 6
  - 6
  - 6
  - 6
  - 6
  - 6
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  file: ../torchmd-net.git/examples/pretrained/md17-aspirin-graph-network/epoch=1359-val_loss=0.5227-test_loss=0.4333.ckpt
  module: torchmdnet.calculators
forcefield: aspirin.prmtop
forceterms: null
langevin_gamma: 0.1
langevin_temperature: 300
log_dir: ./
minimize: 100
output: output
output_period: 1
precision: single
replicas: 1
rfa: false
save_period: 1
seed: 1
steps: 100
structure: null
switch_dist: null
temperature: 300
timestep: 1
topology: aspirin.prmtop

All the input files: aspirin_torchmd.zip

The simulation is unstable: the temperature rises uncontrollably from the first steps.

iter,ns,epot,ekin,etot,T,t
1,1e-06,-406880.0887773633,15.103168487548828,-406864.98560887575,241.27809871894112,4.027363538742065
2,2e-06,-406864.92625647783,14.164608001708984,-406850.7616484761,226.28428535170937,4.054379224777222
3,3e-06,-406861.4632962942,14.159696578979492,-406847.30359971523,226.20582375345904,4.08146595954895
4,4e-06,-406852.1209000349,16.686145782470703,-406835.43475425243,266.566683187086,4.108911752700806
5,4.9999999999999996e-06,-406845.173047781,35.16876983642578,-406810.00427794456,561.8326993711507,4.13638710975647
6,6e-06,-406819.02439904213,19.055503845214844,-406799.9688951969,304.417959827123,4.16302752494812
7,7e-06,-406835.88994038105,36.33468246459961,-406799.55525791645,580.4585382095441,4.190194845199585
8,8e-06,-406816.4535654783,69.41360473632812,-406747.039960742,1108.905233349976,4.2174201011657715
9,9e-06,-406790.12160658836,100.08391571044922,-406690.0376908779,1598.87356847484,4.24408483505249
10,9.999999999999999e-06,-406718.0202502012,50.58804702758789,-406667.43220317364,808.1607389061002,4.270519733428955
11,1.1e-05,-406711.58965563774,57.65319061279297,-406653.93646502495,921.0287382811944,4.2969443798065186
12,1.2e-05,-406745.146247983,97.73880767822266,-406647.40744030476,1561.4096940717563,4.3239805698394775
13,1.3e-05,-406780.00586628914,137.84088134765625,-406642.1649849415,2202.053549539873,4.350756406784058
14,1.4e-05,-406737.1390030384,140.08062744140625,-406597.058375597,2237.8342322197154,4.377088785171509
15,1.4999999999999999e-05,-406667.7502512932,89.35147857666016,-406578.3987727165,1427.4193449192999,4.403895616531372
16,1.6e-05,-406643.7843030691,93.51447296142578,-406550.2698301077,1493.924553476163,4.430569410324097
17,1.7e-05,-406690.91243588924,150.41690063476562,-406540.4955352545,2402.959606143029,4.457834005355835
18,1.8e-05,-406717.71087527275,181.7862091064453,-406535.9246661663,2904.094656871925,4.500119924545288
19,1.8999999999999998e-05,-406734.17486667633,194.3318328857422,-406539.8430337906,3104.5151352111134,4.535091400146484
20,1.9999999999999998e-05,-406659.9773603678,175.61720275878906,-406484.360157609,2805.5427454783216,4.569720983505249
21,2.1e-05,-406585.90433883667,176.90951538085938,-406408.9948234558,2826.1878659151803,4.612438201904297
22,2.2e-05,-406574.9152857065,188.20639038085938,-406386.70889532566,3006.659170576359,4.6477577686309814
23,2.3e-05,-406620.7030599117,219.14315795898438,-406401.55990195274,3500.884237847075,4.679394721984863
24,2.4e-05,-406645.82324945927,184.114501953125,-406461.70874750614,2941.289903139019,4.707014799118042
25,2.4999999999999998e-05,-406701.47101426125,153.501708984375,-406547.96930527687,2452.2404371236057,4.734936475753784
26,2.6e-05,-406696.4496754408,133.06582641601562,-406563.3838490248,2125.770471844318,4.761986255645752
27,2.7e-05,-406683.0426847935,122.48103332519531,-406560.5616514683,1956.6749105790145,4.788716793060303
28,2.8e-05,-406666.3019833565,107.2762451171875,-406559.0257382393,1713.77340330412,4.815651178359985
29,2.9e-05,-406697.5885055065,154.06419372558594,-406543.52431178093,2461.2263164130854,4.843580484390259
30,2.9999999999999997e-05,-406699.77242171764,160.65704345703125,-406539.1153782606,2566.549265677288,4.872596979141235
31,3.1e-05,-406650.5109888315,117.29191589355469,-406533.21907293797,1873.7770478578418,4.89995551109314
32,3.2e-05,-406690.5119087696,119.58175659179688,-406570.9301521778,1910.3580083693134,4.927629232406616
33,3.2999999999999996e-05,-406737.56370294094,149.2154998779297,-406588.348203063,2383.7668327426763,4.95459771156311
34,3.4e-05,-406735.5253405571,144.68548583984375,-406590.83985471725,2311.3983641540776,4.980782508850098
35,3.5e-05,-406708.95912611485,159.289794921875,-406549.66933119297,2544.707019309533,5.007444620132446
36,3.6e-05,-406693.1858738661,140.61151123046875,-406552.5743626356,2246.315275874317,5.037170886993408
37,3.7e-05,-406639.2426587343,125.33792877197266,-406513.90472996235,2002.3147577544935,5.06561803817749
38,3.7999999999999995e-05,-406632.5847103596,132.43565368652344,-406500.14905667305,2115.7032546135924,5.09315037727356
39,3.9e-05,-406643.4117741585,185.24386596679688,-406458.1679081917,2959.3318658043354,5.120187997817993
40,3.9999999999999996e-05,-406671.7958230972,248.03517150878906,-406423.76065158844,3962.443685005843,5.147157669067383
41,4.1e-05,-406553.23693335056,179.92616271972656,-406373.31077063084,2874.3798022646606,5.175485134124756
42,4.2e-05,-406463.92002630234,114.83243560791016,-406349.0875906944,1834.486038978066,5.201632022857666
43,4.2999999999999995e-05,-406518.50567913055,158.01731872558594,-406360.48836040497,2524.378792318035,5.228895425796509
44,4.4e-05,-406609.36943364143,244.32044982910156,-406365.04898381233,3903.0997807857493,5.256029367446899
45,4.4999999999999996e-05,-406683.1779911518,327.4761047363281,-406355.7018864155,5231.538798749737,5.285285234451294
46,4.6e-05,-406645.70803165436,328.3685607910156,-406317.33947086334,5245.796078620695,5.312484264373779
47,4.7e-05,-406546.0222530365,272.8658447265625,-406273.15640830994,4359.121880633125,5.339726686477661
48,4.8e-05,-406524.3602576256,271.5589294433594,-406252.8013281822,4338.243477867644,5.367532968521118
49,4.9e-05,-406589.9699988365,378.48468017578125,-406211.48531866074,6046.417617756432,5.394550561904907
50,4.9999999999999996e-05,-406630.7346057892,405.6400146484375,-406225.09459114075,6480.233043773889,5.421330451965332
51,5.1e-05,-406475.4479403496,313.1947326660156,-406162.25320768356,5003.389168884751,5.448347568511963
52,5.2e-05,-406316.43815875053,332.8667907714844,-405983.57136797905,5317.65678640415,5.476334810256958
53,5.3e-05,-406198.64964079857,332.2574768066406,-405866.3921639919,5307.92279481943,5.5077526569366455
54,5.4e-05,-406406.16769218445,666.53564453125,-405739.6320476532,10648.126793625168,5.534675359725952
55,5.4999999999999995e-05,-406424.89574217796,873.1506958007812,-405551.7450463772,13948.870394421729,5.561588525772095
56,5.6e-05,-406168.2948439121,687.5379028320312,-405480.7569410801,10983.644798062065,5.588784694671631
57,5.6999999999999996e-05,-405880.53273034096,471.45465087890625,-405409.07807946205,7531.643568039544,5.615290641784668
58,5.8e-05,-405950.1033626795,561.5723876953125,-405388.53097498417,8971.304141106837,5.642324209213257
59,5.9e-05,-406326.1232010126,890.3369140625,-405435.7862869501,14223.425900425305,5.66890287399292
60,5.9999999999999995e-05,-406499.3240991831,944.90673828125,-405554.41736090183,15095.196843441758,5.69551157951355
61,6.1e-05,-406246.7270579338,662.7271728515625,-405583.99988508224,10587.285202229465,5.7228147983551025
62,6.2e-05,-405939.76182341576,383.782470703125,-405555.97935271263,6131.051569029803,5.749572992324829
63,6.3e-05,-405940.30504751205,388.75714111328125,-405551.5479063988,6210.52356984244,5.77661657333374
64,6.4e-05,-406131.46995961666,551.8641357421875,-405579.6058238745,8816.211613663636,5.803932428359985
65,6.5e-05,-406212.5055809021,601.13427734375,-405611.37130355835,9603.318378647316,5.8413918018341064
66,6.599999999999999e-05,-406231.4842660427,620.8232421875,-405610.6610238552,9917.856086887838,5.8763813972473145
67,6.7e-05,-406165.6405694485,587.7647705078125,-405577.87579894066,9389.736096701363,5.913886547088623
68,6.8e-05,-406179.55393338203,581.9989013671875,-405597.55503201485,9297.624435174239,5.952528476715088
69,6.9e-05,-406209.120762825,632.543701171875,-405576.57706165314,10105.094285428406,5.990658521652222
70,7e-05,-406223.5871543884,640.5198974609375,-405583.0672569275,10232.516652279484,6.025377988815308
71,7.099999999999999e-05,-406243.78583192825,633.56640625,-405610.21942567825,10121.432336414438,6.0626561641693115
72,7.2e-05,-406195.6852698326,616.900390625,-405578.7848792076,9855.187239133347,6.09675407409668
73,7.3e-05,-406080.0879020691,685.86767578125,-405394.22022628784,10956.962369962712,6.13096809387207
74,7.4e-05,-405992.6495513916,722.2652587890625,-405270.38429260254,11538.425765099284,6.16308069229126
75,7.5e-05,-405902.5705215931,666.3377075195312,-405236.23281407356,10644.964684568848,6.1934545040130615
76,7.599999999999999e-05,-405925.06811475754,673.291015625,-405251.77709913254,10756.046075263628,6.223395109176636
77,7.7e-05,-406178.0604672432,936.1538696289062,-405241.9065976143,14955.366879383875,6.253476858139038
78,7.8e-05,-406207.9582891464,1001.4998779296875,-405206.45841121674,15999.290917884982,6.283278226852417
79,7.9e-05,-406096.5415635109,953.51416015625,-405143.02740335464,15232.703247252584,6.313279867172241
80,7.999999999999999e-05,-406192.39066147804,1030.43701171875,-405161.9536497593,16461.57117574985,6.343334436416626
81,8.099999999999999e-05,-406306.50572681427,1111.22900390625,-405195.276722908,17752.249902057247,6.3734519481658936
82,8.2e-05,-406053.14411354065,925.3079833984375,-405127.8361301422,14782.10026908506,6.403440713882446
83,8.3e-05,-405683.14057159424,582.029052734375,-405101.11151885986,9298.106113211188,6.4332404136657715
84,8.4e-05,-405481.014734745,430.88470458984375,-405050.1300301552,6883.525293134223,6.4636194705963135
85,8.499999999999999e-05,-405629.4579527378,558.5642700195312,-405070.8936827183,8923.248468938862,6.494351148605347
86,8.599999999999999e-05,-405920.31214261055,825.638671875,-405094.67347073555,13189.850139264538,6.524912118911743
87,8.7e-05,-406069.9568309784,1007.77734375,-405062.1794872284,16099.575505132141,6.554839134216309
88,8.8e-05,-406112.59860801697,1110.46728515625,-405002.1313228607,17740.081193755406,6.585700750350952
89,8.9e-05,-406089.90273809433,1081.825927734375,-405008.07681035995,17282.52606092505,6.615431308746338
90,8.999999999999999e-05,-405868.9691205025,820.49609375,-405048.4730267525,13107.69575731902,6.645550966262817
91,9.099999999999999e-05,-405675.4124674797,607.3097534179688,-405068.10271406174,9701.973646056946,6.675546169281006
92,9.2e-05,-405572.9032449722,525.2464599609375,-405047.6567850113,8390.985462600582,6.705816984176636
93,9.3e-05,-405769.4916372299,714.1212158203125,-405055.3704214096,11408.32199217208,6.737977981567383
94,9.4e-05,-406103.68017959595,1042.67578125,-405061.00439834595,16657.0895562535,6.768180847167969
95,9.499999999999999e-05,-406297.6377902031,1278.121826171875,-405019.5159640312,20418.41779121806,6.7977800369262695
96,9.6e-05,-406258.2769627571,1284.951904296875,-404973.32505846024,20527.530542324588,6.827901363372803
97,9.7e-05,-405986.5238389969,1074.440673828125,-404912.08316516876,17164.54419357545,6.857760190963745
98,9.8e-05,-405746.92120170593,938.0362548828125,-404808.8849468231,14985.438604763453,6.887917518615723
99,9.9e-05,-405668.9798822403,967.4277954101562,-404701.55208683014,15454.978160168965,6.917438983917236
100,9.999999999999999e-05,-405585.2989048958,924.967529296875,-404660.3313755989,14776.661402505919,6.9477856159210205
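
A quick way to quantify the runaway is to parse the log and flag the first step whose temperature is far above the thermostat target. The sketch below uses three rows copied from the log above. It also recomputes the temperature from the kinetic energy as T = 2 Ekin / (3 N kB), assuming energies in kcal/mol and 3N degrees of freedom for the 21 atoms, which appears consistent with the logged values. The function names and the factor of 5 are arbitrary choices for illustration.

```python
import csv
import io

KB = 0.001987191  # Boltzmann constant in kcal/(mol K); units are an assumption

# Rows for steps 1, 50 and 100, copied from the log above.
LOG = """iter,ns,epot,ekin,etot,T,t
1,1e-06,-406880.0887773633,15.103168487548828,-406864.98560887575,241.27809871894112,4.027363538742065
50,4.9999999999999996e-05,-406630.7346057892,405.6400146484375,-406225.09459114075,6480.233043773889,5.421330451965332
100,9.999999999999999e-05,-405585.2989048958,924.967529296875,-404660.3313755989,14776.661402505919,6.9477856159210205
"""


def temperature(ekin, n_atoms):
    """Instantaneous temperature from kinetic energy, T = 2 Ekin / (3 N kB)."""
    return 2.0 * ekin / (3 * n_atoms * KB)


def first_runaway(rows, target, factor=5.0):
    """Return the first logged step with T > factor * target, or None."""
    for row in rows:
        if float(row["T"]) > factor * target:
            return int(row["iter"])
    return None


rows = list(csv.DictReader(io.StringIO(LOG)))
print(temperature(float(rows[0]["ekin"]), n_atoms=21))
print(first_runaway(rows, target=300.0))
```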

@raimis
Collaborator Author

raimis commented Jul 28, 2021

I observe the same with ACEMD. So the problem is probably the pre-trained model or a bug in TorchMD-Net.

@raimis
Collaborator Author

raimis commented Jul 28, 2021

I visualised the trajectory: the molecule literally explodes.

@raimis
Collaborator Author

raimis commented Jul 28, 2021

@PhilippThoelke would it be possible to train a less "explosive" model?

@PhilippThoelke
Collaborator

PhilippThoelke commented Jul 28, 2021

I tried some small simulations myself, and while visualizing them I found that the step just before the "explosion" usually has two hydrogens very close to each other. In one of the runs I checked the distance, which turned out to be 0.11 Å. I then compared this to the minimum distance between any two atoms in the MD17 aspirin dataset, which is 0.89 Å. The dataset I trained the model on might just not be well suited for simulation, since the model never saw such short distances during training. I can start training a model on the ANI dataset, which probably also makes it easier to compare to the ANI model. This will, however, take some time, as the ANI dataset is much larger.
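
The comparison described above can be sketched as a check on each frame: compute the smallest interatomic distance and flag the frame if it falls below the minimum seen in the training data. The function names and the three-atom frame are made up for illustration; only the 0.89 and 0.11 values come from this comment.

```python
import itertools
import math

TRAINING_MIN_DISTANCE = 0.89  # Å, minimum reported above for MD17 aspirin


def min_pair_distance(coords):
    """Smallest distance between any two atoms in a frame."""
    return min(math.dist(a, b) for a, b in itertools.combinations(coords, 2))


def outside_training_range(coords, threshold=TRAINING_MIN_DISTANCE):
    """True if the frame has a pair closer than anything in the training set."""
    return min_pair_distance(coords) < threshold


# Made-up frame (Å) with two atoms 0.11 Å apart, as in the run described above.
frame = [(0.0, 0.0, 0.0), (0.11, 0.0, 0.0), (1.5, 0.0, 0.0)]
print(min_pair_distance(frame), outside_training_range(frame))
```

Running such a check during a simulation would show when the system first leaves the region the model was trained on.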

@raimis
Collaborator Author

raimis commented Jul 29, 2021

@PhilippThoelke yes, I think that training with the ANI data is the easiest solution. Anyway, I don't need anything very accurate, just good enough that the simulation stays stable and looks physical.

@PhilippThoelke
Collaborator

I just merged a couple of changes into main, including two new model checkpoints from the ANI1 dataset. One is from a Transformer model and the other one is an equivariant Transformer checkpoint. The equivariant model currently only works with TorchScript on the PyTorch Geometric main branch, as they had a bug that was only recently fixed; however, it has a lower loss than the Transformer checkpoint. So for testing I recommend using the Transformer checkpoint rather than the equivariant one, so you don't have to install PyTorch Geometric from GitHub.

I tested simulating with both models using TorchMD and both are capable of simulating aspirin without "explosions".

Since the ANI1 dataset only includes energies and not forces, the model checkpoints have the derivative flag set to False. In TorchMD this does not change anything because the torchmdnet.calculators.External module overwrites the flag during loading. If you want to load the model yourself and enable force computation, you will have to pass derivative=True to the load_model function.
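
A model trained only on energies can still provide forces because the force is the negative derivative of the energy with respect to position. The toy sketch below shows that relation with a central finite difference on a harmonic bond. It is not how TorchMD-Net computes forces (that would presumably be automatic differentiation, which is an assumption here), and the energy function and its parameters are invented for the example.

```python
def harmonic_energy(r, k=100.0, r0=1.0):
    """Toy bond energy E(r) = 0.5 * k * (r - r0)^2 with made-up parameters."""
    return 0.5 * k * (r - r0) ** 2


def force_from_energy(energy_fn, r, h=1e-5):
    """Force as the negative derivative of the energy, by central difference."""
    return -(energy_fn(r + h) - energy_fn(r - h)) / (2.0 * h)


# A bond stretched beyond r0 is pulled back, so the force is negative.
force = force_from_energy(harmonic_energy, 1.2)
print(force)
```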

@raimis
Collaborator Author

raimis commented Aug 4, 2021

I have tried to run aspirin with ANI1-transformer:

  • The MD simulations with TorchMD don't "explode" within 100 ps. However, it seems that TorchMD doesn't remove the translation of the COM, so the molecule just flies away. Also, there is no flag to disable the thermostat, so the molecule may fail to "explode" only because it is "frozen".
  • The MD simulations with ACEMD "explode" within ~10 ps, which is better than with the previous networks. It seems that some bonds start to elongate non-physically. It is also possible that the problem is caused by an OpenMM-Torch/PyTorch-Geometric incompatibility.
  • The NNP/MM simulations with ACEMD "explode" immediately (< 1 ps).
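
For the centre-of-mass drift mentioned in the first point, the usual remedy is to subtract the mass-weighted mean velocity from every atom so that the total momentum is zero. A minimal sketch in plain Python, with made-up masses and velocities. This is a generic illustration, not TorchMD or ACEMD code.

```python
def remove_com_velocity(masses, velocities):
    """Subtract the centre-of-mass velocity from each atom's velocity."""
    total_mass = sum(masses)
    com_velocity = [
        sum(m * v[axis] for m, v in zip(masses, velocities)) / total_mass
        for axis in range(3)
    ]
    return [
        tuple(v[axis] - com_velocity[axis] for axis in range(3))
        for v in velocities
    ]


# Made-up three-atom system (arbitrary units).
masses = [16.0, 12.0, 1.0]
velocities = [(1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (0.0, 0.0, -3.0)]

corrected = remove_com_velocity(masses, velocities)
# Total momentum after the correction should vanish up to rounding.
momentum = [
    sum(m * v[axis] for m, v in zip(masses, corrected)) for axis in range(3)
]
print(momentum)
```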

@PhilippThoelke
Collaborator

I just pushed the most recent checkpoints from training on ANI1, which at least have better loss than the ones you tested. The models are also still training and haven't converged yet. It might also make sense to try the equivariant Transformer as it has better loss. There hasn't been a new torch-geometric release yet so the TorchScript fix is still only on their main branch.

Do you have any ideas why it might explode? What is the difference between the ACEMD MD simulation and ACEMD NNP/MD simulation?

What do you mean by OpenMM-Torch/PyTorch-Geometric incompatibility, how did you write the interface?

@raimis
Collaborator Author

raimis commented Aug 5, 2021

Thanks @PhilippThoelke, I'll try with the new model.

> Do you have any ideas why it might explode? What is the difference between the ACEMD MD simulation and ACEMD NNP/MD simulation?

MD is just a molecule in vacuum. NNP/MM adds solvent at MM level.

> What do you mean by OpenMM-Torch/PyTorch-Geometric incompatibility, how did you write the interface?

Current PyTorch Geometric packages are not compatible with conda-forge packages, so I had to rebuild PyG.

@raimis
Collaborator Author

raimis commented Sep 28, 2021

I have managed to run the latest checkpoint of ANI1-transformer with ACEMD on GPU.

  • The simulation of aspirin is still stable after ~0.1 ns and keeps running
  • Speed ~10 ns/day on GTX 1080 Ti

Meanwhile ANI1-equivariant_transformer fails with the following error:

The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The size of tensor a (128) must match the size of tensor b (384) at non-singleton dimension 2

@raimis
Collaborator Author

raimis commented Sep 28, 2021

For some reason that issue does not manifest outside of ACEMD.

@giadefa
Contributor

giadefa commented Sep 28, 2021 via email

@PhilippThoelke
Collaborator

At what point does the error occur? During loading or when running inference? Could you maybe share the code snippet where the error occurred?

4 participants