Restoring Trainable Variables from Saved Model #300

@danilojsl

I defined a model in Python using TensorFlow 2 and the Model Subclassing API:

class BiLSTMModel(tf.keras.Model):

    def __init__(self, lstm_dims):
        super().__init__()
        self.ldims = lstm_dims
        self.blockLstm = FirstBlockLSTMModule(lstm_dims)
        self.nextBlockLstm = NextBlockLSTM(lstm_dims)

    def call(self, inputs):
        # Forward pass
        block_lstm1_output = self.blockLstm(inputs)
        block_lstm2_output = self.nextBlockLstm(block_lstm1_output)
        return block_lstm2_output

I implemented a custom training loop, and after each epoch I save this Keras model:

biLSTMModel = BiLSTMModel(lstm_dims)
for epoch in range(epochs):
    # Training code
    tf.saved_model.save(biLSTMModel, export_dir)

This generates the expected structure: the assets and variables folders, along with saved_model.pb and the variables files.
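For reference, a minimal sketch of how that layout could be checked before loading, using only the Python standard library. The helper name `check_saved_model_layout` is hypothetical and not part of TensorFlow; the file names are the ones TensorFlow 2 normally writes for a SavedModel export.

```python
from pathlib import Path


def check_saved_model_layout(export_dir):
    """Return the expected SavedModel entries that are missing from export_dir.

    Hypothetical helper for illustration only; it inspects file names and
    does not parse saved_model.pb or the checkpoint contents.
    """
    root = Path(export_dir)
    missing = []

    # Serialized graph, functions, and signatures.
    if not (root / "saved_model.pb").is_file():
        missing.append("saved_model.pb")

    variables_dir = root / "variables"

    # Checkpoint index that maps variable names to the data shards.
    if not (variables_dir / "variables.index").is_file():
        missing.append("variables/variables.index")

    # One or more data shards, e.g. variables.data-00000-of-00001.
    has_shard = variables_dir.is_dir() and any(variables_dir.glob("variables.data-*"))
    if not has_shard:
        missing.append("variables/variables.data-*")

    # The assets folder is not checked: it may be empty or absent
    # when the model has no external asset files.
    return missing
```

An empty list means the variable files are present on disk, so the weights were written by `tf.saved_model.save`; it says nothing about how they are exposed through a particular loading API.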

Later, when I load the model in Python, I can restore the values of all the model's trainable variables like this:

loaded_bi_lstm = tf.saved_model.load(model_path)
infer_bi_lstm = loaded_bi_lstm.signatures["serving_default"]

# LSTM weights
w_first_lstm = tf.Variable(infer_bi_lstm.trainable_variables[0])
wig_first_lstm = tf.Variable(infer_bi_lstm.trainable_variables[1])
wfg_first_lstm = tf.Variable(infer_bi_lstm.trainable_variables[2])
wog_first_lstm = tf.Variable(infer_bi_lstm.trainable_variables[3])

In tensorflow-java, the model loads without errors:

@Test
public void shouldRestoreVariablesFromSavedModel() {
    SavedModelBundle model = SavedModelBundle.load(SAVED_MODEL_DP_PATH, SavedModelBundle.DEFAULT_TAG);
    ConcreteFunction concreteFunction = model.function("serving_default");

    // Trying to find an object that stores the trainable variables' data
    MetaGraphDef metaGraphDef = model.metaGraphDef();
    SignatureDef sig = metaGraphDef.getSignatureDefOrThrow("serving_default");
}

But I can't find a way to restore the trainable variables. When debugging, there seems to be no object that stores this data. Could you please guide me on how to get the values of all trainable variables from a saved model in Java?

Thanks
