
Restored SavedModel + saved_model_cli raise exception when the object is deleted #44774

Closed
galeone opened this issue Nov 11, 2020 · 5 comments
Labels: comp:apis (high-level API related issues), stat:awaiting tensorflower (awaiting response from TensorFlower), TF 2.3 (issues related to TF 2.3), type:bug
galeone commented Nov 11, 2020

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Arch Linux
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v2.3.0-54-gfcc4b966f1 2.3.1
  • Python version: 3.8.6
  • CUDA/cuDNN version: none
  • GPU model and memory: none (CPU only)

Current and expected behavior

I'm using a tf.Module to export two graphs, created by decorating two methods with @tf.function. I expect the SavedModel to be exported correctly, without crashes. Instead:

  • I suspect the SavedModel is not created correctly, since in the "serving_default" I can find the information of only one method, and I don't know how to call the other method I'm exporting (see the sketch after this list).
  • When I run saved_model_cli show --all --dir at, I get an exception (see below).
  • I get the same exception if I re-load the model (using tf.saved_model.load("at")) and then delete it: the exception is raised when the object goes out of scope, not when load is invoked.
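
As an aside, a minimal sketch of how both methods could be exported as explicit, named signatures (my assumption of a workaround for the first point, not a fix for the exception): since learn and predict already carry an input_signature, their concrete functions can be passed to the signatures argument of tf.saved_model.save.

# Hypothetical workaround sketch: export both methods as named SignatureDefs,
# so that saved_model_cli should list them under the 'serve' tag-set.
tf.saved_model.save(
    at,
    "at",
    signatures={
        "learn": at.learn.get_concrete_function(),
        "predict": at.predict.get_concrete_function(),
    },
)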

Standalone code to reproduce the issue

import sys

import tensorflow as tf
import tensorflow.keras as k


class ActivityTracker(tf.Module):
    def __init__(self):
        super().__init__()

        self.num_classes = 6  # activities in the training set
        self.mapping = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                keys=tf.range(self.num_classes, dtype=tf.int32),
                values=[
                    "Walking",
                    "Jogging",
                    "Upstairs",
                    "Downstairs",
                    "Sitting",
                    "Standing",
                ],
            ),
            "Unknown",
        )

        self.num_features = 3  # sensor (x,y,z)
        self.batch_size = 32

        # 33,Jogging,49106062271000,5.012288,11.264028,0.95342433;
        self._model = k.Sequential(
            [
                k.layers.Input(
                    shape=(1, self.num_features), batch_size=self.batch_size
                ),
                # Note the stateful=True
                k.layers.LSTM(64, stateful=True),
                k.layers.Dense(self.num_classes),
            ]
        )

        self._global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        self._optimizer = k.optimizers.SGD(learning_rate=1e-4)
        # Sparse, so we can feed the scalar and get the one hot representation
        # From logits so we can feed the unscaled (linear activation fn)
        # directly to the loss
        self._loss = k.losses.SparseCategoricalCrossentropy(from_logits=True)

        self._last_tracked_activity = tf.Variable(-1, dtype=tf.int32, trainable=False)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=(None, 1, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None,), dtype=tf.int32),
        ]
    )
    def learn(self, sensor_data, labels):
        # All the sensor data should be about the same activity
        tf.assert_equal(labels, tf.zeros_like(labels) + labels[0])

        # If the activity changes, we must reset the RNN state since the last update
        # and the current update are not related.

        if tf.not_equal(self._last_tracked_activity, labels[0]):
            tf.print(
                "Resetting states. Was: ",
                self._last_tracked_activity,
                " is ",
                labels[0],
            )
            self._last_tracked_activity.assign(labels[0])
            self._model.reset_states()

        self._global_step.assign_add(1)
        with tf.GradientTape() as tape:
            loss = self._loss(labels, self._model(sensor_data))
            tf.print(self._global_step, ": loss: ", loss)

        gradient = tape.gradient(loss, self._model.trainable_variables)
        self._optimizer.apply_gradients(zip(gradient, self._model.trainable_variables))
        return loss

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, 1, 3), dtype=tf.float32)])
    def predict(self, sensor_data):
        predictions = self._model(sensor_data)
        predicted = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
        tf.print(self.mapping.lookup(predicted))
        return predicted


def main() -> int:
    at = ActivityTracker()

    # Executing an invocation of every graph we want to export is mandatory
    at.learn(
        tf.zeros((at.batch_size, 1, 3), dtype=tf.float32),
        tf.zeros((at.batch_size,), dtype=tf.int32),
    )
    at.predict(tf.zeros((at.batch_size, 1, 3), dtype=tf.float32))

    tf.saved_model.save(at, "at")

    restored = tf.saved_model.load("at")
    return 0


if __name__ == "__main__":
    sys.exit(main())
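
The same exception can also be triggered in isolation, without saved_model_cli; a minimal sketch (equivalent to what happens at the end of main above):

restored = tf.saved_model.load("at")
del restored  # the exception is reported here, when the object is destroyed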

Exception and SavedModel output (which I suspect is wrong)

saved_model_cli show --all --dir at/           
2020-11-11 18:16:54.735051: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory
2020-11-11 18:16:54.735089: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 
2020-11-11 18:16:57.323153: W tensorflow/core/common_runtime/graph_constructor.cc:808] Node 'while' has 11 outputs but the _output_shapes attribute specifies shapes for 20 outputs. Output shapes may be inaccurate.
2020-11-11 18:16:57.817962: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2020-11-11 18:16:57.817991: W tensorflow/stream_executor/cuda/cuda_driver.cc:312] failed call to cuInit: UNKNOWN ERROR (303)
2020-11-11 18:16:57.818017: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (i3): /proc/driver/nvidia/version does not exist

Defined Functions:
  Function Name: 'learn'
    Option #1
      Callable with:
        Argument #1
          sensor_data: TensorSpec(shape=(None, 1, 3), dtype=tf.float32, name='sensor_data')
        Argument #2
          labels: TensorSpec(shape=(None,), dtype=tf.int32, name='labels')

  Function Name: 'predict'
    Option #1
      Callable with:
        Argument #1
          sensor_data: TensorSpec(shape=(None, 1, 3), dtype=tf.float32, name='sensor_data')
Exception ignored in: <function CapturableResourceDeleter.__del__ at 0x7f7b9f8af700>
Traceback (most recent call last):
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/training/tracking/tracking.py", line 202, in __del__
    self._destroy_resource()
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 696, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3065, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 237, in restored_function_body
    return _call_concrete_function(function, inputs)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 74, in _call_concrete_function
    result = function._call_flat(tensor_inputs, function._captured_inputs)  # pylint: disable=protected-access
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 105, in _call_flat
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1938, in _call_flat
    flat_outputs = forward_function.call(ctx, args_with_tangents)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 573, in call
    outputs = functional_ops.partitioned_call(
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/ops/functional_ops.py", line 1192, in partitioned_call
    f.add_to_graph(graph)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 495, in add_to_graph
    g._add_function(self)
  File "/home/paolo/projects/venv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 3344, in _add_function
    pywrap_tf_session.TF_GraphCopyFunction(self._c_graph, function._c_func.func,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 'func' argument to TF_GraphCopyFunction cannot be null
WARNING:tensorflow:Unresolved object in checkpoint: (root).mapping._initializer
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
@galeone added the type:bug label on Nov 11, 2020
@ravikyram added the comp:apis and TF 2.3 labels on Nov 12, 2020
@ravikyram (Contributor) commented:

I tried this in Colab with TF 2.3 and the nightly version (2.5.0-dev20201111), and was able to reproduce the issue. Please find the gist here. Thanks!

@jvishnuvardhan added the stat:awaiting tensorflower label on Nov 12, 2020
@dellis23 (Contributor) commented:

Thanks for reporting. This was a fun one to investigate. I have a fix out that should make it into 2.5. I'll update this issue with the commit ID once it makes it over to GitHub.

@galeone (Author) commented Feb 19, 2021

> Thanks for reporting. This was a fun one to investigate. I have a fix out that should make it into 2.5. I'll update this issue with the commit ID once it makes it over to GitHub.

Thank you @dellis23 😄

@dellis23 (Contributor) commented:

Fixed in c29e9f2


@geetachavan1 added this to Done in TensorFlow 2.5 on Mar 10, 2021