error occur when save and reload the model #19

Closed
zhengzetao opened this issue Jan 9, 2018 · 2 comments

Comments

@zhengzetao

Hi Thomas,
A problem occurred while I was learning GCN. I tried to save the model and reload it at the end of train.py:

model.save('gcn_model.h5')
print('save successfully')
model2 = load_model('gcn_model.h5')
preds = model2.predict(graph, batch_size=A.shape[0])  # predict with the reloaded model
print(preds)

and the result is as follows:


...
Epoch: 0196 train_loss= 0.4453 train_acc= 0.9714 val_loss= 0.8001 val_acc= 0.8233 time= 0.0692
Epoch: 0197 train_loss= 0.4434 train_acc= 0.9714 val_loss= 0.7983 val_acc= 0.8233 time= 0.0692
Epoch: 0198 train_loss= 0.4415 train_acc= 0.9714 val_loss= 0.7969 val_acc= 0.8200 time= 0.0672
Epoch: 0199 train_loss= 0.4394 train_acc= 0.9714 val_loss= 0.7952 val_acc= 0.8200 time= 0.0672
Epoch: 0200 train_loss= 0.4373 train_acc= 0.9714 val_loss= 0.7932 val_acc= 0.8167 time= 0.0762
Test set results: loss= 0.8567 accuracy= 0.8000
Traceback (most recent call last):
  File "train.py", line 107, in <module>
    model.save('gcn_model.h5')
  File "D:\Anaconda3\lib\site-packages\keras\engine\topology.py", line 2553, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "D:\Anaconda3\lib\site-packages\keras\models.py", line 107, in save_model
    'config': model.get_config()
  File "D:\Anaconda3\lib\site-packages\keras\engine\topology.py", line 2326, in get_config
    layer_config = layer.get_config()
  File "D:\Anaconda3\lib\site-packages\kegra-0.0.1-py3.5.egg\kegra\layers\graph.py", line 80, in get_config
AttributeError: 'VarianceScaling' object has no attribute '__name__'

I ran the code on Windows 10 with keras==2.0.5 and tensorflow==1.3.0. I have tried many other methods, but none of them worked.
Looking forward to your reply!

@tkipf
Owner

tkipf commented Jan 15, 2018

Thanks for catching this. Looks like Keras (again) updated its layer API. I have pushed the necessary changes to this repo, and model saving now works out of the box. To load a GCN model, you need to pass the custom layer to the load_model function through its custom_objects argument:

from keras.models import load_model
from kegra.layers.graph import GraphConvolution
model = load_model('gcn_model.h5', custom_objects={'GraphConvolution': GraphConvolution})
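
For context, here is a rough sketch of why the original save call probably failed. It assumes the old get_config in graph.py recorded the initializer by reading its `__name__`, which the traceback suggests but which is not confirmed here. The classes and helper functions below are hypothetical stand-ins written for illustration. They are not Keras or kegra code, and the fix actually pushed to the repo may look different.

```python
# Stand-ins only: these mimic the shape of Keras initializers, they are not Keras code.

def glorot_uniform(shape):
    """Old style: the initializer is a plain function, so it has __name__."""
    return [[0.0] * shape[1] for _ in range(shape[0])]


class VarianceScaling:
    """New style: the initializer is an object that carries its own config."""

    def __init__(self, scale=1.0, mode="fan_avg", distribution="uniform"):
        self.scale = scale
        self.mode = mode
        self.distribution = distribution

    def get_config(self):
        return {"scale": self.scale, "mode": self.mode,
                "distribution": self.distribution}


def config_by_name(init):
    """Hypothetical old logic: record the initializer by its function name."""
    return {"init": init.__name__}


def config_by_serializing(init):
    """Hypothetical fixed logic: record the class name plus the object's config."""
    return {"init": {"class_name": type(init).__name__,
                     "config": init.get_config()}}


# Works for a function-style initializer.
print(config_by_name(glorot_uniform))

# Fails for an object-style initializer, because instances have no __name__.
try:
    config_by_name(VarianceScaling())
except AttributeError as err:
    print("old logic fails:", err)

# Serializing the object instead of reading __name__ avoids the error.
print(config_by_serializing(VarianceScaling()))
```

In Keras 2 itself, keras.initializers.serialize plays roughly the role of config_by_serializing above, so a custom layer's get_config can call it instead of reading `__name__`.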

Hope this helps.

@zhengzetao
Author

Thanks! It works well
