
Unable to freeze graph with Elmo Embedding layers with graph_util.convert_variables_to_constants #208

@ashokramadass

Problem
TensorFlow Serving throws the error below when freezing a trained graph.
The error occurs only when the graph uses the TensorFlow Hub ELMo character embedding module; the same code works for the Universal Sentence Encoder.

Error on the server

'elmo/module_apply_tokens/bilm/embedding_lookup' expects to be colocated with unknown node 'elmo/module_apply_tokens/bilm/embedding_lookup/Read/ReadVariableOp'

Code to freeze the model:
saved_model_dir - path to the directory that contains the checkpoints
output_node_names - all the output nodes from the graph, plus 'init_all_tables' (idea taken from here)

def freeze_model(saved_model_dir, output_node_names, output_filename):
    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(saved_model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path
    print(input_checkpoint)

    # We build the full path of the frozen graph file
    absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + "/frozen_model.pb"

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        gd = tf.get_default_graph().as_graph_def()
        
        for table_init_op in tf.get_collection(tf.GraphKeys.TABLE_INITIALIZERS):
            print(table_init_op.name)
            output_node_names.append(table_init_op.name)
                
        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            gd, # The graph_def is used to retrieve the nodes
            output_node_names # The output node names are used to select the useful nodes
        )

        tf.train.export_meta_graph(
                filename=output_graph,
                graph_def=output_graph_def,
                collection_list=[tf.GraphKeys.TABLE_INITIALIZERS])
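One possible workaround, offered as an untested assumption rather than a confirmed fix for this issue, is to drop colocation entries that point at nodes removed by freezing before the graph is exported. The sketch below removes only the dangling `loc:@` entries from each node's `_class` attribute and leaves valid ones alone. `strip_dangling_colocations` and `_fake_node` are hypothetical names, and the demo again uses stand-in objects so it runs without TensorFlow.

```python
from types import SimpleNamespace

PREFIX = b"loc:@"


def strip_dangling_colocations(nodes):
    """Remove colocation entries whose target is absent; return the count removed."""
    nodes = list(nodes)
    known = {node.name for node in nodes}
    removed = 0
    for node in nodes:
        if "_class" not in node.attr:
            continue
        entries = node.attr["_class"].list.s
        kept = [
            e for e in entries
            if not e.startswith(PREFIX) or e[len(PREFIX):].decode("utf-8") in known
        ]
        dropped = len(entries) - len(kept)
        if dropped == 0:
            continue
        removed += dropped
        if kept:
            # Rewrite the list in place so the same container is reused.
            del entries[:]
            entries.extend(kept)
        else:
            # No constraints left, so drop the attribute entirely.
            del node.attr["_class"]
    return removed


def _fake_node(name, colocated_with=()):
    """Hypothetical stand-in for a NodeDef, used only for this demo."""
    attr = {}
    if colocated_with:
        entries = [PREFIX + t.encode("utf-8") for t in colocated_with]
        attr["_class"] = SimpleNamespace(list=SimpleNamespace(s=entries))
    return SimpleNamespace(name=name, attr=attr)


lookup = "elmo/module_apply_tokens/bilm/embedding_lookup"
frozen_nodes = [
    _fake_node(lookup, colocated_with=[lookup + "/Read/ReadVariableOp"]),
    _fake_node("kept", colocated_with=[lookup]),
]
removed_count = strip_dangling_colocations(frozen_nodes)
print(removed_count)
```

In `freeze_model`, the call would go between `convert_variables_to_constants` and `export_meta_graph`, as `strip_dangling_colocations(output_graph_def.node)`. Whether the ELMo graph then imports and serves correctly is not established by this report.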

Code to load the frozen graph

def load_frozen_graph(export_dir, frozen_graph_filepath):
    tf.reset_default_graph()
    with tf.Graph().as_default() as graph:
        tf.train.import_meta_graph(frozen_graph_filepath)
        with tf.Session(graph=graph) as sess:
            try:
                sess.run(graph.get_operation_by_name('init_all_tables'))
            except KeyError:
                pass
            inputs = {node.name: sess.graph.get_tensor_by_name('{}:0'.format(node.name))
                               for node in sess.graph_def.node if node.op=='IteratorGetNext'}
            outputs = {'class_ids': sess.graph.get_tensor_by_name(
                                   'dnn/head/predictions/class_ids:0')}
            builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
            signature = tf.saved_model.signature_def_utils.predict_signature_def(
                inputs=inputs, outputs=outputs)

            builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING],
                                                 signature_def_map={'predict': signature},
                                                 main_op=tf.tables_initializer())
            builder.save()

This is the line that causes the error.
tf.train.import_meta_graph(frozen_graph_filepath)

The same code works when using the Universal Sentence Encoder.
