Size mismatch when loading pre-trained models #37
Comments
Thanks for reaching out. Can you provide more details? What is your train/test data? What are other hyperparameters?
My train/test data is crystal structures downloaded from the Materials Project database (using the MPRester API). As for parameters, I don't think I changed any defaults other than epochs:
As for the size mismatch problem, I'm wondering whether my environment differs from the one used for the pre-trained models. The packages in my virtual environment are listed as follows:
Dear author,
Yeah, I also ran into this problem.
I think this happens because the prediction code differs between models trained on a GPU and models trained on a CPU. There are tutorials online that make this easier to solve.
Hi,
When I try to load the pre-trained models to test predict.py, I get the following error:
python predict.py pre-trained/final-energy-per-atom.pth.tar mp/
=> loading model params 'pre-trained/final-energy-per-atom.pth.tar'
=> loaded model params 'pre-trained/final-energy-per-atom.pth.tar'
=> loading model 'pre-trained/final-energy-per-atom.pth.tar'
Traceback (most recent call last):
  File "E:\cgcnn-master\predict.py", line 298, in <module>
    main()
  File "E:\cgcnn-master\predict.py", line 94, in main
    model.load_state_dict(checkpoint['state_dict'])
  File "C:\ProgramData\Anaconda3\envs\cgcnn1\lib\site-packages\torch\nn\modules\module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CrystalGraphConvNet:
size mismatch for convs.0.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
size mismatch for convs.1.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
size mismatch for convs.2.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
size mismatch for convs.3.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
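One way to narrow down a mismatch like the one above is to compare parameter shapes directly and see which input dimension differs. The sketch below is not a confirmed diagnosis. It assumes that `fc_full` is a linear layer whose weight has shape `(2*atom_fea_len, 2*atom_fea_len + nbr_fea_len)`, and that the neighbor features are a Gaussian expansion on a grid from `dmin` to `radius` in increments of `step`. The helper names are hypothetical and not part of the repository.

```python
def find_shape_mismatches(ckpt_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for parameters whose shapes differ.

    With PyTorch the inputs could be built like this (not executed here):
        state = torch.load(path, map_location="cpu")["state_dict"]
        ckpt_shapes = {k: tuple(v.shape) for k, v in state.items()}
        model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
    """
    return {
        name: (ckpt_shapes[name], model_shapes[name])
        for name in ckpt_shapes
        if name in model_shapes and ckpt_shapes[name] != model_shapes[name]
    }


def infer_nbr_fea_len(fc_full_shape):
    """ASSUMPTION: the weight shape is (2*atom_fea_len, 2*atom_fea_len + nbr_fea_len)."""
    out_features, in_features = fc_full_shape
    return in_features - out_features


def gaussian_filter_count(radius, dmin=0.0, step=0.2):
    """ASSUMPTION: one feature per grid point dmin, dmin + step, ..., radius."""
    return int(round((radius - dmin) / step)) + 1


# Shapes copied from the error message above.
ckpt = {"convs.0.fc_full.weight": (128, 169)}
model = {"convs.0.fc_full.weight": (128, 179)}

for name, (saved, current) in find_shape_mismatches(ckpt, model).items():
    print(name, "checkpoint:", infer_nbr_fea_len(saved), "current:", infer_nbr_fea_len(current))
```

Under these assumptions the checkpoint expects 41 neighbor features and the current model builds 51. With `dmin=0` and `step=0.2`, those counts would correspond to a radius of 8 and 10 respectively, so it may be worth checking whether the radius or step used when loading the dataset differs from the one the pre-trained model was trained with.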
By the way, I then tried training my own model and using it for prediction. The errors above did not show up, but the MAE was far too large.
(cgcnn) E:\cgcnn-master>python predict.py E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar mp/
=> loading model params 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loaded model params 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loading model 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loaded model 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar' (epoch 484, validation 0.05862389877438545)
C:\ProgramData\Anaconda3\envs\cgcnn\lib\site-packages\pymatgen\io\cif.py:1155: UserWarning: Issues encountered while parsing CIF: Some fractional coordinates rounded to ideal values to avoid issues with finite precision.
warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings))
Test: [0/74] Time 26.633 (26.633) Loss inf (inf) MAE 5.977 (5.977)
Test: [10/74] Time 24.787 (27.052) Loss inf (inf) MAE 6.005 (6.013)
Test: [20/74] Time 28.383 (28.096) Loss inf (inf) MAE 5.941 (6.010)
Test: [30/74] Time 31.305 (28.518) Loss inf (inf) MAE 6.081 (6.008)
Test: [40/74] Time 30.491 (29.037) Loss inf (inf) MAE 5.860 (6.010)
Test: [50/74] Time 35.822 (29.651) Loss inf (inf) MAE 6.035 (6.008)
Test: [60/74] Time 33.488 (30.191) Loss inf (inf) MAE 6.033 (6.012)
Test: [70/74] Time 34.823 (30.565) Loss inf (inf) MAE 5.955 (6.008)
** MAE 6.009
Thanks for your attention!