Subclassing tf.keras.models.Model save_model saves h5py but not h5/hdf5 file types #29545

Closed
Ryandry1st opened this issue Jun 7, 2019 · 29 comments
Assignees
Labels
comp:keras Keras related issues TF 2.0 Issues relating to TensorFlow 2.0 type:bug Bug

Comments

@Ryandry1st

Ryandry1st commented Jun 7, 2019

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Google Colab
  • TensorFlow installed from (source or binary): Pip
  • TensorFlow version (use command below): 2.0.0b
  • Python version: 3.6.7
  • GPU model and memory: Tesla T4

Describe the current behavior
- Using tf 2.0.0b-gpu on Google Colab.

While using the subclassing API for a custom layer and a custom model, I was unable to save the model with model.save() using the .h5 or .hdf5 file extensions, but I could successfully save and load it when the file name ended in .h5py. In the toy example used here the reloaded model worked correctly, although that may not hold in general. Note that get_config was implemented in both the custom layer and the custom model.

Describe the expected behavior
Either saving should always work (I believe this is a feature goal) and the documentation should reflect this, or, if the save is likely to produce incorrect results, it should raise an error and the documentation should continue to say that custom models can only be saved with save_weights.

Code to reproduce the issue

import tensorflow as tf
from tensorflow import keras

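class resblock(keras.layers.Layer):
    def __init__(self, n_layers, n_neurons, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config can return them.
        self.n_layers = n_layers
        self.n_neurons = n_neurons
        self.hidden = [keras.layers.Dense(n_neurons,
                                          activation='elu',
                                          kernel_initializer='he_normal')
                       for _ in range(n_layers)]

    def call(self, inputs):
        # Residual connection: add the input to the output of the stacked Dense layers.
        z = inputs
        for layer in self.hidden:
            z = layer(z)
        return inputs + z

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "n_layers": self.n_layers, "n_neurons": self.n_neurons}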

class res_mod(keras.models.Model):
    def __init__(self, output_dim, activation=None, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config can return them.
        self.output_dim = output_dim
        self.activation = activation
        self.f1 = keras.layers.Flatten()
        self.hidden1 = keras.layers.Dense(100, activation='elu', kernel_initializer='he_normal')
        self.b1 = resblock(2, 100)
        self.b2 = resblock(2, 100)
        self.output1 = keras.layers.Dense(output_dim, activation=keras.activations.get(activation))

    def call(self, inputs):
        z = self.f1(inputs)
        z = self.hidden1(z)
        # Apply the same residual block four times (its weights are shared).
        for _ in range(4):
            z = self.b1(z)
        z = self.b2(z)
        return self.output1(z)

    def get_config(self):
        base_config = super().get_config()
        # Read the values from self; bare output_dim/activation are not in scope here.
        return {**base_config, "output_dim": self.output_dim, "activation": self.activation}
    

model = res_mod(10, activation='softmax')
model.compile(loss='sparse_categorical_crossentropy', optimizer='nadam', metrics=['accuracy'])

# `train` and `test` are the training and test data objects, which are not shown here.
model.fit(train, epochs=25, validation_data=test)
# Saving with the .h5py extension succeeds, and loading returns the trained model.
model.save('custom_model.h5py')
del model
model = keras.models.load_model('custom_model.h5py', custom_objects={'resblock': resblock})

Other info / logs
This raises an error saying that only Sequential or Functional models can be saved:
model.save('custom_model.h5')
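
A get_config method that reports constructor arguments has to read them from attributes stored in __init__; a bare name such as output_dim is not in scope inside the method and raises NameError when the method is called. The following is a minimal, framework-free sketch of that contract. The class name ConfigurableBlock is hypothetical and deliberately does not inherit from any Keras class, so it illustrates only the round trip through get_config and from_config, not how Keras itself saves a model.

```python
class ConfigurableBlock:
    """Framework-free stand-in for a Keras layer or model (hypothetical name)."""

    def __init__(self, output_dim, activation=None):
        # Keep the constructor arguments so get_config can report them later.
        self.output_dim = output_dim
        self.activation = activation

    def get_config(self):
        # Reading self.output_dim works; a bare `output_dim` would raise NameError.
        return {"output_dim": self.output_dim, "activation": self.activation}

    @classmethod
    def from_config(cls, config):
        # Rebuild an equivalent object from the dictionary produced by get_config.
        return cls(**config)


original = ConfigurableBlock(10, activation="softmax")
clone = ConfigurableBlock.from_config(original.get_config())
print(clone.get_config())
```

The round trip only restores the constructor arguments; anything defined in the body of call still has to come from the Python class itself.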

@gadagashwini-zz gadagashwini-zz self-assigned this Jun 11, 2019
@gadagashwini-zz gadagashwini-zz added 2.0.0-beta0 comp:keras Keras related issues labels Jun 11, 2019
@gadagashwini-zz
Contributor

@Ryandry1st I tried reproducing the issue, but it looks like the code snippet is incomplete. Please provide a minimal code snippet that reproduces the issue reported here, along with the error message you got. Thanks!

@gadagashwini-zz gadagashwini-zz added the stat:awaiting response Status - Awaiting response from author label Jun 11, 2019
@Ryandry1st
Author

The only code not provided was the actual data for the train and test objects. The issue can be reproduced with MNIST, Fashion MNIST, or really any other dataset; simply change the 10 in the line
model = res_mod(10, activation='softmax')
to whatever output dimension you need.
Assuming you have loaded train and test data, the code runs up to the save call. Attempting to save the model with either the .h5 or the .hdf5 extension (while .h5py works fine) produces the following error:

NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using save_weights.

The entire traceback is attached in the text file.

Custom_Model_Save_Error.txt

@Ryandry1st
Author

The point of this issue is this: if a subclassed model cannot be safely serialized, then either saving with the .h5py extension is handling it improperly and should raise an error similar to the one above, or it works as it should (based on only this one test case and limited testing beyond checking weights and accuracy), in which case the documentation should say so and probably recommend the method.

@goldiegadde goldiegadde assigned k-w-w and unassigned gadagashwini-zz Jun 11, 2019
@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jun 12, 2019
@Ryandry1st
Author

Any luck on this issue? One option for the data is of course to randomly generate x and y values. It does not matter what data is fed into the network, only that a trained subclassed network is being saved.

@ymodak
Contributor

ymodak commented Aug 9, 2019

The code snippet you have provided looks incomplete. Can you please check with the latest TF 2.0 nightly version?
pip install tf-nightly-2.0-preview

@Ryandry1st
Author

Ryandry1st commented Aug 9, 2019

Here is a more basic code snippet that throws the error. I do want to stress, though, that the only differences are that I am providing randomly generated x and y data and removing layers from the custom model for simplicity. There is no new code or new idea here: the previous code reaches the same error with random data, as was suggested a month ago.

import tensorflow as tf
from tensorflow import keras

x = tf.random.uniform((100,))
y = tf.random.uniform((100,))


class test_model(keras.models.Model):
    def __init__(self, output_dim, activation=None, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config can return them.
        self.output_dim = output_dim
        self.activation = activation
        self.f1 = keras.layers.Flatten()
        self.hidden1 = keras.layers.Dense(100, activation='elu', kernel_initializer='he_normal', name="h1")
        self.output1 = keras.layers.Dense(output_dim, activation=keras.activations.get(activation))

    def call(self, inputs):
        z = self.f1(inputs)
        z = self.hidden1(z)
        return self.output1(z)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "output_dim": self.output_dim, "activation": self.activation}
    
model = test_model(1)
model.compile(loss='mse', optimizer='adam', metrics=['mae'])
model.fit(x, y, epochs=5)
print(model.weights[0])
model.save('custom_model.hdf5')
del model
model = keras.models.load_model('custom_model.hdf5')
print(model.weights[0])

You will see that this throws the error, but the following does not.

import tensorflow as tf
from tensorflow import keras

x = tf.random.uniform((100,))
y = tf.random.uniform((100,))


class test_model(keras.models.Model):
    def __init__(self, output_dim, activation=None, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config can return them.
        self.output_dim = output_dim
        self.activation = activation
        self.f1 = keras.layers.Flatten()
        self.hidden1 = keras.layers.Dense(100, activation='elu', kernel_initializer='he_normal', name="h1")
        self.output1 = keras.layers.Dense(output_dim, activation=keras.activations.get(activation))

    def call(self, inputs):
        z = self.f1(inputs)
        z = self.hidden1(z)
        return self.output1(z)

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "output_dim": self.output_dim, "activation": self.activation}
    
model = test_model(1)
model.compile(loss='mse', optimizer='adam', metrics=['mae'])
model.fit(x, y, epochs=5)
print(model.weights[0])
model.save('custom_model.h5py')
del model
model = keras.models.load_model('custom_model.h5py')
print(model.weights[0])

So somewhere along the way, either saving subclassed models this way should be officially adopted, the error for .h5/.hdf5 should be removed, or saving with this file extension should raise the same error.

@Ryandry1st
Author

I also tried the nightly preview with the more complicated code and found that it throws a new error, which appears to be due to custom objects being restored improperly. This may be part of the reasoning behind the error message, but the code did work with the beta version and no longer works with the nightly preview.

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _convert_inputs_to_signature(inputs, input_signature, flat_input_signature)
   1701       flatten_inputs)):
   1702     raise ValueError("Python inputs incompatible with input_signature:\n%s" %
-> 1703                      format_error_message(inputs, input_signature))
   1704 
   1705   if need_packing:

ValueError: Python inputs incompatible with input_signature:
  inputs: (
    Tensor("res_mod_1_2/Cast:0", shape=(None, 16), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None, 16), dtype=tf.float64, name='input_1'))

@k-w-w
Contributor

k-w-w commented Aug 14, 2019

In TensorFlow 2.0, model.save defaults to the SavedModel format. To save to H5 instead, either pass save_format="h5" or give the file name an ".h5", ".hdf5" or ".keras" extension. The H5 format does not support saving subclassed models, so the initial error stating that only Sequential and Functional models may be saved is intended.

(The documentation describes the argument on this page: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model#save)


I have an internal change that should fix the ValueError (Python inputs incompatible with input_signature), but I can't be sure because the examples above do not raise the error. I'll update this post when the change is copied to GitHub.

k-w-w added a commit to k-w-w/tensorflow that referenced this issue Aug 15, 2019
@Ryandry1st
Author

I agree about the save format. My question comes down to the use of .h5py, which seems to allow the model to be saved and then loaded correctly too. So should saving with .h5py be disallowed for subclassed models, or, if it does work correctly, shouldn't it be the standard?

@k-w-w
Contributor

k-w-w commented Sep 19, 2019

Calling model.save('custom_model.h5py') still saves out in the SavedModel format, even though the extension says .h5py (so it's not actually using h5py).
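
The extension rule described in this thread can be sketched in plain Python. The helper below is hypothetical (infer_save_format is not a TensorFlow function) and only mirrors the behavior discussed here for TensorFlow 2.0: an explicit save_format wins, the ".h5", ".hdf5" and ".keras" extensions select HDF5, and anything else, including ".h5py", falls back to SavedModel.

```python
import os

# Extensions that, per the discussion above, select the HDF5 writer in TF 2.0.
HDF5_EXTENSIONS = {".h5", ".hdf5", ".keras"}


def infer_save_format(filepath, save_format=None):
    """Hypothetical helper mirroring the rule described in this thread."""
    if save_format is not None:
        return save_format  # an explicit argument always wins
    _, ext = os.path.splitext(filepath)
    return "h5" if ext in HDF5_EXTENSIONS else "tf"


for name in ("custom_model.h5", "custom_model.hdf5", "custom_model.h5py", "custom_model"):
    print(name, "->", infer_save_format(name))
```

Under this rule, 'custom_model.h5py' is treated like any other unrecognized name, which is why the subclassed model saved without the HDF5 error.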

@k-w-w k-w-w closed this as completed Sep 19, 2019
@Ryandry1st
Author

That is fine; however, it does not solve the actual issue. There is no error or warning when the model is saved this way, even though the error message explicitly states that you should not be able to save a subclassed model, yet saving clearly works if you change the extension to ".h5py". So regardless of what is occurring in the background, there is a problem in that subclassed models are being saved, unless this can be adopted for all cases.

@Ryandry1st
Author

This problem has not been solved in rc0, and should be reopened.

@k-w-w
Contributor

k-w-w commented Sep 19, 2019

I would like to get a better understanding of what the issue is.

Perhaps the docstring for save_format is unclear? Currently, it says:

save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

@k-w-w
Contributor

k-w-w commented Sep 19, 2019

The NotImplementedError from an earlier post states that HDF5 cannot save subclassed models. The SavedModel format is able to handle these models, so you should not get an error with it.

@k-w-w k-w-w reopened this Sep 19, 2019
@Ryandry1st
Author

I see, so the SavedModel format is able to handle these models. I believe this should then be documented in the warning, as the warning seems to suggest that there is no clear method for saving these models other than saving the weights separately from the model, which is obviously more complicated and prone to error.

@k-w-w
Contributor

k-w-w commented Sep 19, 2019

Specifying that the SavedModel format is able to handle custom models sounds reasonable and would remove the ambiguity. I'll add that to the warning message. Thanks for the suggestion!

@Ryandry1st
Author

Ryandry1st commented Sep 19, 2019

Thank you for clearing that up, appreciate it!

@jvishnuvardhan jvishnuvardhan added type:bug Bug TF 2.0 Issues relating to TensorFlow 2.0 and removed TF 2.0.0-beta0 labels Oct 2, 2019
@syw2014

syw2014 commented Oct 21, 2019

What is now the best way to save a model created by subclassing? Are there any examples of how to save and load such models? Thank you!

@Ryandry1st
Author

As I understand it, the way to do this is to use the SavedModel format, by calling
model.save("NameOfYourModel", save_format='tf')
This should be the clearest way to successfully save a subclassed model, while also being explicit about the format being used. If the file name has an HDF5 extension, e.g. .h5 or .keras, saving will not work for a subclassed model. If it has any other extension, as I was experiencing with .h5py, it will default to save_format='tf'.

@jvishnuvardhan
Contributor

jvishnuvardhan commented Jan 23, 2020

Closing this issue since I understand it to be resolved by updating the warning, but please let me know if I'm mistaken.

The current warning clearly mentions what to do to save a subclassed model. Please check the gist here.

The current warning message is as follows:
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using save_weights.

@hanzigs

hanzigs commented Jul 20, 2020

Hi @Ryandry1st, may I have some help with a similar issue, please? I am using tf==2.2.0.
#41543

I have tried

model.save("NameOfModel", save_format='tf')

and loaded back with

loaded_tfkmodel = tf.keras.models.load_model('./NameOfModel')

Still, I get

    ValueError: Python inputs incompatible with input_signature:
      inputs: (
        Tensor("IteratorGetNext:0", shape=(None, 2), dtype=int32))
      input_signature: (
        TensorSpec(shape=(None, 2), dtype=tf.int64, name='input_1'))

I also tried with tf==2.0.0 and get the same error. I need help, please.
Thanks

@hanzigs

hanzigs commented Jul 20, 2020

Found it, thanks.
Actually, recreating the model with

keras.models.load_model('path_to_my_model')

didn't work for me

First we have to call save_weights on the built model

model.save_weights('model_weights', save_format='tf')

Then we have to create a new instance of the subclassed Model, compile it, call train_on_batch with one record, and call load_weights with the weights of the built model

loaded_model = ClassifierModel(parameter)
loaded_model.compile(parameters)
loaded_model.train_on_batch(x_train[:1], y_train[:1])
loaded_model.load_weights('model_weights')

This works perfectly in TensorFlow==2.2.0
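
The save-weights-and-rebuild workflow above can be sketched end to end as follows. This is a minimal sketch under stated assumptions, not the exact code from this comment: TinyModel is a hypothetical stand-in for the subclassed model, the weights go to an HDF5 weights file rather than the TensorFlow checkpoint format, and a single forward call is used in place of train_on_batch to create the variables of the new instance before loading.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras


class TinyModel(keras.Model):
    """Hypothetical stand-in for the subclassed models in this thread."""

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden = keras.layers.Dense(8, activation="elu")
        self.out = keras.layers.Dense(output_dim)

    def call(self, inputs):
        return self.out(self.hidden(inputs))


x = tf.random.uniform((4, 3))
weights_path = os.path.join(tempfile.mkdtemp(), "tiny_model.weights.h5")

model = TinyModel(1)
before = model(x).numpy()  # the first call creates the variables
model.save_weights(weights_path)

restored = TinyModel(1)
restored(x)  # build the new instance so its variables exist before loading
restored.load_weights(weights_path)
after = restored(x).numpy()

# True when the restored model reproduces the original predictions.
print(np.allclose(before, after))
```

Comparing predictions on the same batch before and after loading is a quick way to check that the weights were actually restored rather than left at their fresh initial values.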

@Ryandry1st
Author

Glad you got that figured out, @hanzigs. The way you are doing it is one that I use as well when I have custom classes/layers that require subclassing. It seems like, with get_config (and its base_config) defined, saving the entire model in the SavedModel format should work, but this has not been my experience. As you said, I just save the weights and then instantiate a new model and load the weights.

@GabrielLaw

GabrielLaw commented Sep 30, 2020

@Ryandry1st Similar issue: I'm trying to save a subclassed model and convert it to the TensorFlow.js layers model format. This only works if I'm able to save the model as h5, but I get the error saying that only a Sequential or Functional model can be saved. I need the layers model format so that I can retrain using TFJS. Any suggestions?

@Ryandry1st
Author

I believe you cannot save the model in its entirety in h5 format because that would ignore the control flow/ops defined in the subclassed model. Instead, you need to save the weights in h5 format, then make a new model the same way, call train_on_batch as shown above, and then load the weights.

I have not had to deploy a subclassed model in TF JS, so I could not say whether this is usable there or not.

@meg261995

meg261995 commented Jan 22, 2022

Hi,

Adding to @hanzigs's comment: when I train the model and then save the weights using

model.save_weights(path)

and load them in another session using

model1 = DCN( )
model1.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))
model1.load_weights(path)

the predictions from model in the first session and from model1 in the second session are completely different. How do I solve this?

NOTE: If I save the model using,
model.save(os.path.join(model_path,'my_model'),save_format='tf')

Then I get the below warning.

WARNING:absl:Found untraced functions such as ranking_layer_call_fn, ranking_layer_call_and_return_conditional_losses, dense_layer_call_fn, dense_layer_call_and_return_conditional_losses, ranking_layer_call_fn while saving (showing 5 of 10). These functions will not be directly callable after loading.

This warning cannot be ignored because when I load the model and try to evaluate it with the same test set, I get the following error:
ValueError: Exception encountered when calling layer "dcn" (type DCN).
Could not find matching concrete function to call loaded from the SavedModel.

@ayalaall

ayalaall commented Jan 12, 2023

What should I do if I get this error when saving the model using ModelCheckpoint with save_weights_only=False ?

The error:
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using save_weights.

Thanks

@therc01

therc01 commented Apr 24, 2023

Quoting @hanzigs's comment above (call save_weights on the built model, then create a new instance of the subclassed Model, compile it, call train_on_batch with one record, and call load_weights):

I'm doing something similar, but it didn't work for me!
