Size mismatch when loading pre-trained models #37

Open
Rich2333 opened this issue May 6, 2022 · 6 comments

Comments

Rich2333 commented May 6, 2022

Hi,

When I try to load the pre-trained models to test predict.py, I get the following error:

python predict.py pre-trained/final-energy-per-atom.pth.tar mp/
=> loading model params 'pre-trained/final-energy-per-atom.pth.tar'
=> loaded model params 'pre-trained/final-energy-per-atom.pth.tar'
=> loading model 'pre-trained/final-energy-per-atom.pth.tar'
Traceback (most recent call last):
  File "E:\cgcnn-master\predict.py", line 298, in <module>
    main()
  File "E:\cgcnn-master\predict.py", line 94, in main
    model.load_state_dict(checkpoint['state_dict'])
  File "C:\ProgramData\Anaconda3\envs\cgcnn1\lib\site-packages\torch\nn\modules\module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CrystalGraphConvNet:
    size mismatch for convs.0.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
    size mismatch for convs.1.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
    size mismatch for convs.2.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
    size mismatch for convs.3.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
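
Every mismatch in the traceback is in the second dimension of `fc_full.weight`. A PyTorch `nn.Linear` weight has shape `[out_features, in_features]`, so the checkpoint was saved with 169 input features while the current model was built with 179. The sketch below compares two name-to-shape mappings and backs out the feature lengths. It assumes that `fc_full` takes `2 * atom_fea_len + nbr_fea_len` inputs and produces `2 * atom_fea_len` outputs; that reading of the model code is an assumption, and `diff_shapes` and `infer_feature_lengths` are hypothetical helper names. With real objects, the mappings could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
def diff_shapes(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for shared names whose shapes differ."""
    return {
        name: (shape, model_shapes[name])
        for name, shape in checkpoint_shapes.items()
        if name in model_shapes and shape != model_shapes[name]
    }


def infer_feature_lengths(weight_shape):
    """Back out (atom_fea_len, nbr_fea_len) from an fc_full weight shape.

    Assumption: in_features = 2 * atom_fea_len + nbr_fea_len and
    out_features = 2 * atom_fea_len.
    """
    out_features, in_features = weight_shape
    atom_fea_len = out_features // 2
    nbr_fea_len = in_features - 2 * atom_fea_len
    return atom_fea_len, nbr_fea_len


# Shapes copied from the traceback (all four conv layers report the same pair).
checkpoint_shapes = {f"convs.{i}.fc_full.weight": (128, 169) for i in range(4)}
model_shapes = {f"convs.{i}.fc_full.weight": (128, 179) for i in range(4)}

mismatches = diff_shapes(checkpoint_shapes, model_shapes)
for name, (ckpt_shape, model_shape) in mismatches.items():
    print(name,
          "checkpoint:", infer_feature_lengths(ckpt_shape),
          "model:", infer_feature_lengths(model_shape))
```

Under that assumption, both sides agree on an atom feature length of 64 and differ only in the neighbor feature length (41 in the checkpoint, 51 in the current model). That would point at how the neighbor features are generated rather than at the installed packages.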

By the way, I then trained my own model and used it to predict. The errors above didn't show up, but I got a far too large MAE.

(cgcnn) E:\cgcnn-master>python predict.py E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar mp/
=> loading model params 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loaded model params 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loading model 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loaded model 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar' (epoch 484, validation 0.05862389877438545)
C:\ProgramData\Anaconda3\envs\cgcnn\lib\site-packages\pymatgen\io\cif.py:1155: UserWarning: Issues encountered while parsing CIF: Some fractional coordinates rounded to ideal values to avoid issues with finite precision.
warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings))
Test: [0/74] Time 26.633 (26.633) Loss inf (inf) MAE 5.977 (5.977)
Test: [10/74] Time 24.787 (27.052) Loss inf (inf) MAE 6.005 (6.013)
Test: [20/74] Time 28.383 (28.096) Loss inf (inf) MAE 5.941 (6.010)
Test: [30/74] Time 31.305 (28.518) Loss inf (inf) MAE 6.081 (6.008)
Test: [40/74] Time 30.491 (29.037) Loss inf (inf) MAE 5.860 (6.010)
Test: [50/74] Time 35.822 (29.651) Loss inf (inf) MAE 6.035 (6.008)
Test: [60/74] Time 33.488 (30.191) Loss inf (inf) MAE 6.033 (6.012)
Test: [70/74] Time 34.823 (30.565) Loss inf (inf) MAE 5.955 (6.008)
** MAE 6.009

Thanks for your attention!

Owner

txie-93 commented May 7, 2022

Thanks for reaching out. Can you provide more details? What is your train/test data? What are other hyperparameters?

Author

Rich2333 commented May 8, 2022

My train/test data are crystal structures downloaded from the Materials Project database (using the MPRester API).
In the case of my own trained model, I use the same data to train and to predict. I'm quite confused that the prediction MAE (e.g. 5.977) is about 100 times larger than the MAE reported during training (e.g. 0.0586).

As for parameters, I don't think I changed any of the defaults other than epochs:
python main.py --epochs 500 --train-ratio 0.6 --val-ratio 0.2 --test-ratio 0.2 mp/
python predict.py E:/cgcnn-master/trained_files/from_cmd/mp-2/mp_model_best.pth.tar mp/
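
The comparison above is between the MAE printed by predict.py and the validation value stored in the checkpoint. A minimal sketch of that arithmetic follows; the `mae` helper and its sample values are hypothetical and only illustrate the metric, while the two reported numbers are taken from the logs earlier in this thread.

```python
def mae(predictions, targets):
    """Mean absolute error between two equal-length sequences of numbers."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


# Hypothetical values, only to show how the metric is computed.
example_mae = mae([1.0, 2.0, 3.0], [1.5, 2.0, 2.0])

# Numbers reported in the logs: the value stored with the checkpoint
# and the final "** MAE" line printed by predict.py.
checkpoint_validation_value = 0.05862389877438545
prediction_mae = 6.009

ratio = prediction_mae / checkpoint_validation_value
print(f"prediction MAE is roughly {ratio:.0f} times the checkpoint's validation value")
```

A ratio this large, together with the `Loss inf` column in the log, suggests a systematic problem rather than ordinary generalization error.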

Author

Rich2333 commented May 9, 2022

As for the size mismatch problem, I'm wondering whether my environment differs from the one the pre-trained models were created in.

The packages in my virtual environment are listed below:

#Name                    Version                   Build  Channel
ase                       3.22.1             pyhd8ed1ab_1    conda-forge
blas                      2.114                       mkl    conda-forge
blas-devel                3.9.0              14_win64_mkl    conda-forge
brotli                    1.0.9                h8ffe710_7    conda-forge
brotli-bin                1.0.9                h8ffe710_7    conda-forge
brotlipy                  0.7.0           py310he2412df_1004    conda-forge
bzip2                     1.0.8                h8ffe710_4    conda-forge
ca-certificates           2021.10.8            h5b45459_0    conda-forge
certifi                   2021.10.8       py310h5588dad_2    conda-forge
cffi                      1.15.0          py310hcbf9ad4_0    conda-forge
cftime                    1.6.0           py310h2873277_1    conda-forge
charset-normalizer        2.0.12             pyhd8ed1ab_0    conda-forge
click                     8.1.3           py310h5588dad_0    conda-forge
colorama                  0.4.4              pyh9f0ad1d_0    conda-forge
cryptography              36.0.2          py310ha857299_1    conda-forge
cudatoolkit               11.3.1               h59b6b97_2
curl                      7.83.0               h789b8ee_0    conda-forge
cycler                    0.11.0             pyhd8ed1ab_0    conda-forge
cython                    0.29.28         py310h8a704f9_2    conda-forge
double-conversion         3.2.0                h0e60522_0    conda-forge
eigen                     3.4.0                h2d74725_0    conda-forge
expat                     2.4.8                h39d44d4_0    conda-forge
ffmpeg                    4.3.1                ha925a31_0    conda-forge
flask                     2.1.2              pyhd8ed1ab_1    conda-forge
fonttools                 4.33.3          py310he2412df_0    conda-forge
freetype                  2.10.4               h546665d_1    conda-forge
future                    0.18.2          py310h5588dad_5    conda-forge
gl2ps                     1.4.2                h0597ee9_0    conda-forge
glew                      2.1.0                h39d44d4_2    conda-forge
hdf4                      4.2.15               h0e5069d_3    conda-forge
hdf5                      1.12.1          nompi_h2a0e4a3_104    conda-forge
icu                       69.1                 h0e60522_0    conda-forge
idna                      3.3                pyhd8ed1ab_0    conda-forge
importlib-metadata        4.11.3          py310h5588dad_1    conda-forge
intel-openmp              2022.0.0          h57928b3_3663    conda-forge
itsdangerous              2.1.2              pyhd8ed1ab_0    conda-forge
jbig                      2.1               h8d14728_2003    conda-forge
jinja2                    3.1.2              pyhd8ed1ab_0    conda-forge
joblib                    1.1.0              pyhd8ed1ab_0    conda-forge
jpeg                      9e                   h8ffe710_1    conda-forge
jsoncpp                   1.9.5                h2d74725_1    conda-forge
kiwisolver                1.4.2           py310h476a331_1    conda-forge
krb5                      1.19.3               h1176d77_0    conda-forge
latexcodec                2.0.1              pyh9f0ad1d_0    conda-forge
lcms2                     2.12                 h2a16943_0    conda-forge
lerc                      3.0                  h0e60522_0    conda-forge
libblas                   3.9.0              14_win64_mkl    conda-forge
libbrotlicommon           1.0.9                h8ffe710_7    conda-forge
libbrotlidec              1.0.9                h8ffe710_7    conda-forge
libbrotlienc              1.0.9                h8ffe710_7    conda-forge
libcblas                  3.9.0              14_win64_mkl    conda-forge
libclang                  13.0.1          default_h81446c8_0    conda-forge
libcurl                   7.83.0               h789b8ee_0    conda-forge
libdeflate                1.10                 h8ffe710_0    conda-forge
libffi                    3.4.2                h8ffe710_5    conda-forge
libiconv                  1.16                 he774522_0    conda-forge
liblapack                 3.9.0              14_win64_mkl    conda-forge
liblapacke                3.9.0              14_win64_mkl    conda-forge
libnetcdf                 4.8.1           nompi_h1cc8e9d_102    conda-forge
libogg                    1.3.4                h8ffe710_1    conda-forge
libpng                    1.6.37               h1d00b33_2    conda-forge
libssh2                   1.10.0               h680486a_2    conda-forge
libtheora                 1.1.1             h8d14728_1005    conda-forge
libtiff                   4.3.0                hc4061b1_3    conda-forge
libuv                     1.43.0               h8ffe710_0    conda-forge
libwebp                   1.2.2                h57928b3_0    conda-forge
libwebp-base              1.2.2                h8ffe710_1    conda-forge
libxcb                    1.13              hcd874cb_1004    conda-forge
libxml2                   2.9.14               hf5bbc77_0    conda-forge
libzip                    1.8.0                hfed4ece_1    conda-forge
libzlib                   1.2.11            h8ffe710_1014    conda-forge
loguru                    0.6.0           py310h5588dad_1    conda-forge
lz4-c                     1.9.3                h8ffe710_1    conda-forge
m2w64-gcc-libgfortran     5.3.0                         6    conda-forge
m2w64-gcc-libs            5.3.0                         7    conda-forge
m2w64-gcc-libs-core       5.3.0                         7    conda-forge
m2w64-gmp                 6.1.0                         2    conda-forge
m2w64-libwinpthread-git   5.0.0.4634.697f757               2    conda-forge
markupsafe                2.1.1           py310he2412df_1    conda-forge
matplotlib-base           3.5.2           py310h79a7439_0    conda-forge
mkl                       2022.0.0           h0e2418a_796    conda-forge
mkl-devel                 2022.0.0           h57928b3_797    conda-forge
mkl-include               2022.0.0           h0e2418a_796    conda-forge
monty                     2022.4.26          pyhd8ed1ab_0    conda-forge
mpmath                    1.2.1              pyhd8ed1ab_0    conda-forge
msys2-conda-epoch         20160418                      1    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
netcdf4                   1.5.8           nompi_py310h5489b47_101    conda-forge
networkx                  2.8                pyhd8ed1ab_0    conda-forge
numpy                     1.22.3          py310hed7ac4c_2    conda-forge
openjpeg                  2.4.0                hb211442_1    conda-forge
openssl                   1.1.1o               h8ffe710_0    conda-forge
packaging                 21.3               pyhd8ed1ab_0    conda-forge
palettable                3.3.0                      py_0    conda-forge
pandas                    1.4.2           py310hf5e1058_1    conda-forge
pillow                    9.1.0           py310h767b3fd_2    conda-forge
pip                       22.0.4             pyhd8ed1ab_0    conda-forge
plotly                    5.7.0              pyhd8ed1ab_0    conda-forge
proj                      9.0.0                h1cfcee9_1    conda-forge
pthread-stubs             0.4               hcd874cb_1001    conda-forge
pugixml                   1.11.4               h0e60522_0    conda-forge
pybtex                    0.24.0             pyhd8ed1ab_2    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pymatgen                  2022.4.26       py310h476a331_0    conda-forge
pyopenssl                 22.0.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.0.8              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1           py310h5588dad_5    conda-forge
python                    3.10.4          h9a09f29_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python_abi                3.10                    2_cp310    conda-forge
pytorch                   1.11.0          py3.10_cuda11.3_cudnn8_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2022.1             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0             py310he2412df_4    conda-forge
qt                        5.12.9               h556501e_6    conda-forge
requests                  2.27.1             pyhd8ed1ab_0    conda-forge
ruamel.yaml               0.17.21         py310he2412df_1    conda-forge
ruamel.yaml.clib          0.2.6           py310he2412df_1    conda-forge
scikit-learn              1.0.2           py310h4dafddf_0    conda-forge
scipy                     1.8.0           py310h33db832_1    conda-forge
setuptools                62.1.0          py310h5588dad_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
spglib                    1.16.4          py310h2873277_0    conda-forge
sqlite                    3.38.4               h8ffe710_0    conda-forge
sympy                     1.10.1          py310h5588dad_0    conda-forge
tabulate                  0.8.9              pyhd8ed1ab_0    conda-forge
tbb                       2021.5.0             h2d74725_1    conda-forge
tbb-devel                 2021.5.0             h2d74725_1    conda-forge
tenacity                  8.0.1              pyhd8ed1ab_0    conda-forge
threadpoolctl             3.1.0              pyh8a188c0_0    conda-forge
tk                        8.6.12               h8ffe710_0    conda-forge
torchvision               0.12.0              py310_cu113    pytorch
tqdm                      4.64.0             pyhd8ed1ab_0    conda-forge
typing_extensions         4.2.0              pyha770c72_1    conda-forge
tzdata                    2022a                h191b570_0    conda-forge
ucrt                      10.0.20348.0         h57928b3_0    conda-forge
uncertainties             3.1.6              pyhd8ed1ab_0    conda-forge
unicodedata2              14.0.0          py310he2412df_1    conda-forge
urllib3                   1.26.9             pyhd8ed1ab_0    conda-forge
utfcpp                    3.2.1                h57928b3_0    conda-forge
vc                        14.2                 hb210afc_6    conda-forge
vs2015_runtime            14.29.30037          h902a5da_6    conda-forge
vtk                       9.1.0           qt_py310h99a8838_207    conda-forge
werkzeug                  2.1.2              pyhd8ed1ab_1    conda-forge
wheel                     0.37.1             pyhd8ed1ab_0    conda-forge
win32_setctime            1.1.0              pyhd8ed1ab_0    conda-forge
win_inet_pton             1.1.0           py310h5588dad_4    conda-forge
xorg-libxau               1.0.9                hcd874cb_0    conda-forge
xorg-libxdmcp             1.1.3                hcd874cb_0    conda-forge
xz                        5.2.5                h62dcd97_1    conda-forge
yaml                      0.2.5                h8ffe710_2    conda-forge
zipp                      3.8.0              pyhd8ed1ab_0    conda-forge
zlib                      1.2.11            h8ffe710_1014    conda-forge
zstd                      1.5.2                h6255e5f_0    conda-forge
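
To compare this environment with another one, it can help to record only the packages the code depends on rather than the full listing. Below is a minimal sketch using the standard library's `importlib.metadata`. The choice of packages is an assumption about what matters here, and `installed_versions` is a hypothetical helper name.

```python
from importlib import metadata


def installed_versions(package_names):
    """Map each distribution name to its installed version, or None if it is not installed."""
    versions = {}
    for name in package_names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names as pip reports them; "torch" is the pip name
# for the package that conda lists as "pytorch".
report = installed_versions(["torch", "pymatgen", "numpy", "scikit-learn"])
for name, version in report.items():
    print(f"{name}: {version if version is not None else 'not installed'}")
```

Running the same snippet in both environments gives two short reports that can be compared line by line.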

SANTKJD commented Aug 20, 2022

Dear author,
I have trained on 40,000+ CIFs with “train size 0.8” and “epoch 1000”. I expected to get the same result as yours, but I only get an MAE of 0.049. Only 300~400 of the CIFs differ from yours. Is my result correct within the margin of error?

@liaokkkkk

Yeah, I also ran into this problem.

@liaokkkkk

I figured it out: it's because the prediction code differs between models trained on a GPU and models trained on a CPU. There are tutorials online, and it's fairly easy to solve.
