Description
I defined a model in Python using TensorFlow 2 and the Model Subclassing API:
class BiLSTMModel(tf.keras.Model):
    def __init__(self, lstm_dims):
        super().__init__()
        self.ldims = lstm_dims
        self.blockLstm = FirstBlockLSTMModule(lstm_dims)
        self.nextBlockLstm = NextBlockLSTM(lstm_dims)

    def call(self, inputs):
        # Forward pass
        block_lstm1_output = self.blockLstm(inputs)
        block_lstm2_output = self.nextBlockLstm(block_lstm1_output)
        return block_lstm2_output
I implemented a custom training loop, and after each epoch I save this Keras model:
biLSTMModel = BiLSTMModel(lstm_dims)
for epoch in range(epochs):
    # Training code
    tf.saved_model.save(biLSTMModel, export_dir)
This generates the expected structure: the assets and variables folders, the saved_model.pb file, and the variables files.
Later on, when loading the model in Python, I can restore the values of all the model's trainable variables like this:
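For anyone reproducing this, here is a minimal sketch for checking that an export directory has the layout described above. It uses only the Python standard library. The function name `check_saved_model_layout` is hypothetical, and the checkpoint file names (`variables.index`, `variables.data-*`) assume TensorFlow's usual SavedModel naming. The assets folder is not checked because it may be empty.

```python
from pathlib import Path


def check_saved_model_layout(export_dir):
    """Return the expected SavedModel entries that are missing from export_dir.

    Hypothetical helper: it only inspects file names, not file contents.
    """
    root = Path(export_dir)
    missing = []

    # The serialized graph and signatures are stored in saved_model.pb.
    if not (root / "saved_model.pb").is_file():
        missing.append("saved_model.pb")

    # Variable values are stored as checkpoint files under variables/.
    variables_dir = root / "variables"
    if not (variables_dir / "variables.index").is_file():
        missing.append("variables/variables.index")
    has_data_shard = variables_dir.is_dir() and any(
        variables_dir.glob("variables.data-*")
    )
    if not has_data_shard:
        missing.append("variables/variables.data-*")

    return missing
```

An empty list means the files that hold the variable values are present, so a loader on the Java side should at least have the checkpoint data available to it.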
loaded_bi_lstm = tf.saved_model.load(model_path)
infer_bi_lstm = loaded_bi_lstm.signatures["serving_default"]
# LSTM weights
w_first_lstm = tf.Variable(infer_bi_lstm.trainable_variables[0])
wig_first_lstm = tf.Variable(infer_bi_lstm.trainable_variables[1])
wfg_first_lstm = tf.Variable(infer_bi_lstm.trainable_variables[2])
wog_first_lstm = tf.Variable(infer_bi_lstm.trainable_variables[3])
In tensorflow-java, the model loads without errors:
@Test
public void shouldRestoreVariablesFromSavedModel() {
    SavedModelBundle model = SavedModelBundle.load(SAVED_MODEL_DP_PATH, SavedModelBundle.DEFAULT_TAG);
    ConcreteFunction concreteFunction = model.function("serving_default");
    // Trying to find an object that stores the trainable variables' data
    MetaGraphDef metaGraphDef = model.metaGraphDef();
    SignatureDef sig = metaGraphDef.getSignatureDefOrThrow("serving_default");
}
But I can't find a way to restore the trainable variables. When debugging, I see no object that appears to store this data. Could you please guide me on how to get the values of all trainable variables from a saved model in Java?
Thanks