Saving entire model #23

Closed
anasvaf opened this issue Jun 2, 2020 · 27 comments

anasvaf commented Jun 2, 2020

Hello, I tried to save the entire model for Tacotron-2 as an h5 file, instead of just the weights. However, I am getting the following error:

NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using save_weights.

I used the following code to successfully load the weights:

tacotron2 = TFTacotron2(config=fs_config, name="tacotron2", training=False)
tacotron2._build()  
tacotron2.load_weights("tacotron2.h5") 

Then I tried to call tacotron2.save("full_tacotron2.h5") and got the aforementioned error. Should I modify trainers/base_trainer.py as follows and re-train, or is there another way to save the entire model as an h5 file?

def save_checkpoint(self):
    """Save checkpoint."""
    self.ckpt.steps.assign(self.steps)
    self.ckpt.epochs.assign(self.epochs)
    self.ckp_manager.save(checkpoint_number=self.steps)
    self.model.save_weights(self.saved_path + 'model-{}.h5'.format(self.steps))
    self.model.save(self.saved_path + 'model-total{}.h5'.format(self.steps))  # proposed addition: save the entire model
dathudeptrai (Collaborator) commented Jun 2, 2020

@anasvaf I wouldn't suggest saving the entire model as h5; it isn't guaranteed to work. After you load the weights from the h5 file, you can save the model as a pb file and then run inference on a server. Or you can try save_format="tf". In TF 2 we don't save entire models as h5 anymore :)), see https://www.tensorflow.org/api_docs/python/tf/saved_model/save.
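A minimal sketch of that flow (my own summary of the full example posted further down; the config dict and checkpoint path are hypothetical, and the model is built with a dummy inference call because, as noted later in this thread, _build() can interfere with exporting):

import numpy as np
import tensorflow as tf

from tensorflow_tts.configs import Tacotron2Config
from tensorflow_tts.models import TFTacotron2

# Hypothetical config dict; in practice load "tacotron2_params" from the yaml file.
config = Tacotron2Config(**tacotron2_params)
tacotron2 = TFTacotron2(config=config, name="tacotron2", training=False)

# Run one dummy inference so the subclassed model creates all of its variables.
dummy_ids = np.array([[1, 2, 3, 4]], dtype=np.int32)
tacotron2.inference(
    input_ids=dummy_ids,
    input_lengths=np.array([4], dtype=np.int32),
    speaker_ids=np.array([0], dtype=np.int32),
    use_window_mask=False,
    win_front=6,
    win_back=6,
    maximum_iterations=50,
)

# Load the trained h5 weights, then export everything as a SavedModel (pb) directory.
tacotron2.load_weights("tacotron2.h5")
tf.saved_model.save(tacotron2, "./tacotron2_saved_model")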

@dathudeptrai dathudeptrai self-assigned this Jun 2, 2020
@dathudeptrai dathudeptrai added the question ❓ Further information is requested label Jun 2, 2020
@dathudeptrai dathudeptrai added this to In progress in Tacotron 2 Jun 2, 2020
anasvaf (Author) commented Jun 3, 2020

@dathudeptrai Thank you for the prompt response!
Unfortunately, it did not work when I added save_format="tf" either.
I have read the API docs for SavedModel, but I am still confused about how to go from the saved HDF5 weights to a frozen pb file. I would still need the names of the input and output tensors from the graph, is that right?

dathudeptrai (Collaborator) commented Jun 3, 2020

@anasvaf I will do it for you tonight :))). You just want to know how to save to pb?

anasvaf (Author) commented Jun 3, 2020

@dathudeptrai Yes, saving the model as pb would be really helpful, so I can use post-training quantization and try it on a Raspberry Pi to check the latency of the mel-prediction inference :)
Thank you for all the help!

dathudeptrai (Collaborator) commented Jun 4, 2020

import yaml
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from tensorflow_tts.processor.ljspeech import LJSpeechProcessor
from tensorflow_tts.processor.ljspeech import symbols, _symbol_to_id

from tensorflow_tts.configs import Tacotron2Config
from tensorflow_tts.models import TFTacotron2

with open('./tacotron2.v1.yaml') as f:
    config = yaml.load(f, Loader=yaml.Loader)

config = Tacotron2Config(**config["tacotron2_params"])
tacotron2 = TFTacotron2(config=config, training=False, name="tacotron2")

input_text = "i think it work"
input_ids = LJSpeechProcessor(None, "english_cleaners").text_to_sequence(input_text.lower())
input_ids = np.concatenate([input_ids, [len(symbols) - 1]], -1)

# pass input to build model.
decoder_output, mel_outputs, stop_token_prediction, alignment_history = tacotron2.inference(
        input_ids=np.expand_dims(input_ids, 0),
        input_lengths=np.array([len(input_ids)]),
        speaker_ids=np.array([0], dtype=np.int32),
        maximum_iterations=4000,
        use_window_mask=False,
        win_front=6,
        win_back=6,
)
# load weight and save to pb.
tacotron2.load_weights("./tacotron2.v1/checkpoints/model-120000.h5")
tf.saved_model.save(tacotron2, "./test_saved")

# load and inference again to check.
tacotron2 = tf.saved_model.load("./test_saved")
decoder_output, mel_outputs, stop_token_prediction, alignment_history = tacotron2.inference(
        input_ids=np.expand_dims(input_ids, 0),
        input_lengths=np.array([len(input_ids)]),
        speaker_ids=np.array([0], dtype=np.int32),
        maximum_iterations=4000,
        use_window_mask=False,
        win_front=6,
        win_back=6,
)

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)
ax.set_title(f'Alignment steps')
im = ax.imshow(
    alignment_history[0].numpy(),
    aspect='auto',
    origin='lower',
    interpolation='none')
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.tight_layout()
plt.show()
plt.close()


@anasvaf let's try :)). Somehow tacotron2._build() prevents the model from being saved to pb. :))

anasvaf (Author) commented Jun 4, 2020

@dathudeptrai Thank you so much! :) It works like a charm :). I can get the pb file. Do you know the name of the mel_outputs tensor? I mean, what should the name be as a string in the variables.data file? Something like: "post_net/tf_tacotron_conv_batch_norm_9/batch_norm_._4/moving_variance"

@dathudeptrai (Collaborator)

@anasvaf why do you need the name? You can use tf.saved_model.load and do inference as in the code above. You can print(mel_outputs) to get a name.
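For reference, a short sketch of how those names can be listed, where tacotron2 is the model built in the script above and mel_outputs is one of its inference outputs; this is generic TF2 introspection, not something specific to this repo:

# Variable names such as ".../post_net/.../moving_variance" live on the Keras model.
for v in tacotron2.variables:
    print(v.name, v.shape)

# An eager output tensor only exposes value/shape/dtype; it has no graph node name.
print(mel_outputs.shape, mel_outputs.dtype)

# For the exported directory, the command-line tool
#   saved_model_cli show --dir ./test_saved --all
# prints the traced functions and their input/output tensor specs.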

@manmay-nakhashi

tensorflow.python.saved_model.nested_structure_coder.NotEncodableError: No encoder for object [tf.Tensor(2000, shape=(), dtype=int32)] of type [<class 'tensorflow.python.framework.ops.EagerTensor'>].
I am getting this error; could it be a version issue?

@dathudeptrai (Collaborator)

@manmay-nakhashi I fixed it today :)) pls git pull :D

@manmay-nakhashi

ok thanks

@dathudeptrai (Collaborator)

@manmay-nakhashi @anasvaf I think you guys should "watch" my repo to be sure you won't miss any updates. I will add multi-band MelGAN soon; it's 3x faster than MelGAN and the quality is better.

@manmay-nakhashi

sure @dathudeptrai : ))

anasvaf (Author) commented Jun 5, 2020

@dathudeptrai I will try printing the tf.Tensor to check its node name. The reason I asked is that if you build TF from source and deploy it on Android, my guess is that you would need to specify the input/output node names for the .pb file (as in line 74 of https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/tfmobile/src/org/tensorflow/demo/ClassifierActivity.java).

Also, another question about the frozen file: when loading the model, it keeps the input_id length it was saved with and cannot accept shorter or longer sentences. I tried zero-padding for shorter ones but I get a weird wav file. Any thoughts on that?

@dathudeptrai (Collaborator)

@anasvaf

Also, another question about the frozen file: when loading the model, it keeps the input_id length it was saved with and cannot accept shorter or longer sentences. I tried zero-padding for shorter ones but I get a weird wav file. Any thoughts on that?

Send me the code that you are using.

@manmay-nakhashi

@anasvaf @dathudeptrai I am trying to convert this model to a tflite model, but the saved_model doesn't have any signatures. Do you know why, and how can I add them?

anasvaf (Author) commented Jun 5, 2020

@dathudeptrai This is the code for inference.

import yaml
import numpy as np
import matplotlib.pyplot as plt
import soundfile as sf

import tensorflow as tf

from tensorflow_tts.processor.ljspeech import LJSpeechProcessor
from tensorflow_tts.processor.ljspeech import symbols, _symbol_to_id

from tensorflow_tts.configs import Tacotron2Config
from tensorflow_tts.models import TFTacotron2

from tensorflow_tts.configs import MelGANGeneratorConfig
from tensorflow_tts.models import TFMelGANGenerator

input_text = "Hello! World"
input_ids = LJSpeechProcessor(None, "english_cleaners").text_to_sequence(input_text.lower())
input_ids = np.concatenate([input_ids, [len(symbols) - 1]], -1)

# load and inference again to check.
tacotron2 = tf.saved_model.load("test_saved")
decoder_output, mel_outputs, stop_token_prediction, alignment_history = tacotron2.inference(
        input_ids=np.expand_dims(input_ids, 0),
        input_lengths=np.array([len(input_ids)]),
        speaker_ids=np.array([0], dtype=np.int32),
        maximum_iterations=4000,
        use_window_mask=False,
        win_front=6,
        win_back=6,
)

print(mel_outputs)

And this is the output I am getting, since I had saved the pb file with a longer sentence:


ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (7 total):
    * Tensor("input_ids:0", shape=(1, 13), dtype=int64)
    * Tensor("input_lengths:0", shape=(1,), dtype=int64)
    * Tensor("speaker_ids:0", shape=(1,), dtype=int32)
    * False
    * 6
    * 6
    * 4000
  Keyword arguments: {}

Expected these arguments to match one of the following 1 option(s):

Option 1:
  Positional arguments (7 total):
    * TensorSpec(shape=(1, 58), dtype=tf.int64, name='input_ids')
    * TensorSpec(shape=(1,), dtype=tf.int64, name='input_lengths')
    * TensorSpec(shape=(1,), dtype=tf.int32, name='speaker_ids')
    * False
    * 6
    * 6
    * 4000
  Keyword arguments: {}

anasvaf (Author) commented Jun 5, 2020

@manmay-nakhashi I am not sure you can get a tflite model from the pb file, since there are multiple @tf.function definitions in the model, e.g. the call and inference functions in models/tacotron2.
Maybe I am wrong. This is the error I get when I try the following:

# Convert the model.
converter = tf.lite.TFLiteConverter.from_saved_model("test_saved")
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('model_tacotron2.tflite', 'wb') as f:
    f.write(tflite_model)

Traceback (most recent call last):
  File "convert_h5_to_pb.py", line 59, in <module>
    tflite_model = converter.convert()
  File "/home/anasvaf/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/lite/python/lite.py", line 452, in convert
    raise ValueError("This converter can only convert a single "
ValueError: This converter can only convert a single ConcreteFunction. Converting multiple functions is under development.
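For what it's worth, the TF 2.x converter can also be handed a single concrete function instead of the whole SavedModel, which avoids the "multiple functions" complaint; whether Tacotron-2's dynamic decoder loop then actually converts is another matter. A hedged sketch, assuming an inference tf.function with an input_signature like the wrapper sketched in the previous comment:

import tensorflow as tf

# Hand the converter exactly one traced function rather than a SavedModel
# that contains several tf.functions.
concrete_fn = inference.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
tflite_model = converter.convert()

# Save the TF Lite model.
with open("model_tacotron2.tflite", "wb") as f:
    f.write(tflite_model)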

dathudeptrai (Collaborator) commented Jun 5, 2020

OK, I will try to fix those issues tonight. Maybe we should merge call and inference into the call function only, or call the inference function inside the call function.

@manmay-nakhashi

(quoting the TFLite conversion code and error from the previous comment)
@anasvaf
yes, I have tried this, but it seems like you can convert any function to tflite if we add a signature and make it a concrete function.
tensorflow/tensorflow#34350
I was looking into this issue.

manmay-nakhashi commented Jun 5, 2020

@dathudeptrai @anasvaf I think it would be best if we could convert this to tflite, for faster inference on mobile and embedded devices.

anasvaf (Author) commented Jun 5, 2020

@manmay-nakhashi you can still quantize the weights in the .pb file. At the moment it is only 2.6 MB (consisting of variable ops). If you build TensorFlow for mobile from source you can still get quite fast inference on mobile using only the CPU. I am not sure how much you can speed up Tacotron-2 with TFLite using the GPU; note that the most computationally intensive operations, given the dynamic input, are entering and exiting the while loop in the encoder-decoder.
But we can give it a try with a TFLite file :)
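For the quantization step, the standard TF2 route is dynamic-range (weight) quantization through the TF Lite converter rather than operating on the pb file directly; a minimal sketch, assuming the model can be converted at all, which the rest of this thread shows was not straightforward for Tacotron-2:

import tensorflow as tf

# Post-training dynamic-range quantization: weights are stored as int8,
# activations stay float. Requires a convertible model (from_saved_model here,
# or from_concrete_functions as sketched above).
converter = tf.lite.TFLiteConverter.from_saved_model("test_saved")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

with open("model_tacotron2_quant.tflite", "wb") as f:
    f.write(tflite_quant_model)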

@manmay-nakhashi

@anasvaf tflite uses FlatBuffers while the TensorFlow pb file is Protobuf; FlatBuffers are faster, especially on low-end devices.

@dathudeptrai (Collaborator)

@anasvaf @manmay-nakhashi please close this if it solves your problem. I don't think we can convert Tacotron to tflite, and even if we could, there is no way to make it run in real time on mobile devices.

anasvaf (Author) commented Jun 5, 2020

@dathudeptrai thank you so much for all your help!! :))

@anasvaf anasvaf closed this as completed Jun 5, 2020
@dathudeptrai dathudeptrai moved this from In progress to Done in Tacotron 2 Jun 10, 2020
@gongchenghhu

@dathudeptrai Yes, saving the model as pb would be really helpful, so I can use post-training quantization and try it on a Raspberry Pi to check the latency of the mel-prediction inference :)
Thank you for all the help!
@anasvaf Thanks for your discussion. Would you mind telling me how to do post-training quantization on the pb file? I used TensorFlow Lite, but I encountered some errors: https://github.com/dathudeptrai/TensorflowTTS/issues/47#issue-639624372

anasvaf (Author) commented Jun 17, 2020

@gongchenghhu Unfortunately, I was not able to do it. There are also some missing ops for Tacotron2 that would need to be written in C++.
I will try with FastSpeech, though.
