Memory leak when repeatedly loading and deleting keras models #40171

Closed
idfah opened this issue Jun 4, 2020 · 21 comments
Assignees
Labels
comp:keras Keras related issues TF 2.2 Issues related to TF 2.2 type:performance Performance Issue

Comments

@idfah

idfah commented Jun 4, 2020

If a Keras model is saved with tf.saved_model.save and then repeatedly loaded with tf.saved_model.load and deleted with del, memory usage grows slowly but steadily, indicating a memory leak. keras.backend.clear_session does not resolve the issue. See the attached gist for an example that reproduces the issue in TensorFlow 2.2 on Google Colab.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
    I have attached a custom repro case, but this appears to happen for various types of typical keras models.
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    Can reproduce in Google Colab and Docker RedHat images
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
    not tested
  • TensorFlow installed from (source or binary):
    binary (from pip)
  • TensorFlow version (use command below):
    ('2.2.0', 'v2.2.0-0-g2b96f3662b')
  • Python version:
    3.6.9 (google colab)
  • Bazel version (if compiling from source):
    N/A
  • GCC/Compiler version (if compiling from source):
    N/A
  • CUDA/cuDNN version:
    default in google colab
  • GPU model and memory:
    default in google colab

Describe the current behavior
When Keras models are repeatedly saved and loaded, memory usage grows steadily over time. For dynamic model servers that load and unload models over time, this can eventually lead to a crash due to memory exhaustion.

Describe the expected behavior
All memory should be recovered after a keras model instance is deleted with del and the garbage collector is run with gc.collect().

Standalone code to reproduce the issue
The following GitHub gist demonstrates the issue (can also be run in Colab):
https://gist.github.com/idfah/dff83de8d2a6406c9b92221e6282a8d6
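
For reference, a minimal load/delete loop along the lines described above (a sketch only; the model architecture, path, and iteration count are illustrative and not copied from the gist) looks like this:

import gc
import tensorflow as tf

# Build and save a small Keras model once (illustrative architecture).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(32,)),
    tf.keras.layers.Dense(1),
])
tf.saved_model.save(model, "/tmp/leak_repro_model")

# Repeatedly load and delete the model. Resident memory keeps growing even
# though the Python objects themselves are collected.
for i in range(100):
  loaded = tf.saved_model.load("/tmp/leak_repro_model")
  del loaded
  gc.collect()
  tf.keras.backend.clear_session()  # does not recover the leaked memory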

@idfah idfah added the type:bug Bug label Jun 4, 2020
@amahendrakar amahendrakar added comp:keras Keras related issues TF 2.2 Issues related to TF 2.2 type:performance Performance Issue and removed type:bug Bug labels Jun 5, 2020
@amahendrakar
Contributor

Was able to reproduce the issue with TF v2.2 and TF-nightly. Please find the attached gist. Thanks!

@gowthamkpr gowthamkpr assigned k-w-w and unassigned gowthamkpr Jun 8, 2020
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 8, 2020
@idfah
Author

idfah commented Jun 9, 2020

Verified that the same behavior is present in 2.0, 2.1, and 2.2.

@kmh4321
Contributor

kmh4321 commented Jun 19, 2020

Hello,
I have some more info to add to this issue. I have run a bunch of experiments and have observed the following:

  • There is a small but persistent memory leak whenever models are loaded multiple times.
  • The leak seems to depend on the number of layers.
  • The leak seems to be independent of number of variables/parameters.

I have uploaded my experiments here. I have 3 models with roughly the same number of parameters, spread across different numbers of layers. The memory leak is smallest when there is only 1 hidden layer, and grows by almost 10x when I increase from 1 hidden layer to 15 hidden layers, all while keeping the number of parameters/variables roughly fixed. With just one layer I observed a memory increase of about 210 KB per load cycle, versus about 1.7 MB per load cycle with 15 hidden layers.

Maybe there is something in the way Keras instantiates layers that is causing this issue?

cc: @frreiss

@frreiss
Contributor

frreiss commented Jun 23, 2020

Preliminary results from profiling the example program @kmh4321 provided with pympler and gperftools:

  • The leak does not appear to be in the Python heap (or at least the portion of the Python heap that pympler can see); see the pympler sketch after this list.
  • There's a pretty substantial amount of memory leakage of objects allocated from various subroutines of TF_GraphImportGraphDefWithResults. I'm guessing that some part of the Python code is leaking handles to TF_Graph objects located on the C++ heap.
  • The memory that gperftools' profiler can see adds up to about 75% of the actual memory usage of my test process. That's probably good enough for this case.
  • The stack traces I'm getting from pprof --text --stacks are kind of garbled right now, probably because I'm using TensorFlow 2.2.0 binaries from PyPI. I'll try building a version of TensorFlow with debug symbols overnight.
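
Here is roughly what such a pympler check can look like (a sketch under my own assumptions about the test loop; the path is illustrative). If the leak were in the Python heap, the summary diff would keep growing across iterations:

import gc
import tensorflow as tf
from pympler import muppy, summary

before = summary.summarize(muppy.get_objects())

for _ in range(20):
  loaded = tf.saved_model.load("/tmp/leak_repro_model")  # illustrative path
  del loaded
  gc.collect()

after = summary.summarize(muppy.get_objects())
# Show only object types whose count or size changed between the two snapshots.
summary.print_(summary.get_diff(before, after))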

@frreiss
Contributor

frreiss commented Jun 23, 2020

Here's the output of pprof --text --stacks in case anyone else can make more sense of it: pprof.out.txt

@frreiss
Contributor

frreiss commented Jun 25, 2020

I've identified the cause of this leak. While loading functions from the SavedModel file, function_deserialization.load_function_def_library() executes these two lines:

    if context.executing_eagerly():
      func.add_to_graph(ops.get_default_graph())

See source file here.

These lines create a second copy of the body of each function and stick that second copy into a dictionary hanging off a tf.Graph that is in turn hanging off a global variable. That second copy of the function is never deleted.
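
A quick way to see this from user code (a rough diagnostic sketch; it pokes at the private `_functions` attribute, the same one the workarounds below rely on) is to watch the default graph's function dictionary grow across load/delete cycles:

import gc
import tensorflow as tf

for i in range(5):
  loaded = tf.saved_model.load("/tmp/leak_repro_model")  # illustrative path
  del loaded
  gc.collect()
  # The second copies accumulate here and are never removed.
  n_funcs = len(tf.compat.v1.get_default_graph()._functions)
  print(f"iteration {i}: {n_funcs} functions registered on the default graph")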

This problem appears to have been patched on the master branch four months ago; see this commit.

It would be a really good idea to backport this fix to the 2.2 branch. I'm not sure why it wasn't backported in the first place. TensorFlow 2.2.0 came out almost two months after this fix was pushed into the master branch.

It would also be a good idea to backport this fix to the 2.1 branch.

@frreiss
Contributor

frreiss commented Jun 26, 2020

I've identified a workaround that cuts this memory leakage by about 80%: Load the model from a temporary background thread.

Here's some code to copy and paste:

import tensorflow as tf

def load_model_hack(saved_model_path: str):
  """
  Load a SavedModel from a background thread so that most of the garbage 
  that saved_model.load() leaves on the calling thread's environment will
  be collected. Still leaks memory, but at a lower rate.
  """
  import threading
  from typing import Any, Dict
  def callback(path: str, result_holder: Dict[str, Any]) -> None:
    try:
      result_holder["model"] = tf.saved_model.load(path)
    except Exception as e:
      result_holder["error"] = e

  # Call saved_model.load() in a background thread 
  thread_results = {}
  t = threading.Thread(target=callback, args=[saved_model_path, thread_results])
  t.start()
  t.join()

  # Forward any exceptions thrown by the background thread
  if "error" in thread_results:
    raise thread_results["error"]
  return thread_results["model"]
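
Usage is a drop-in replacement for tf.saved_model.load (the path below is just a placeholder):

model = load_model_hack("/path/to/saved_model")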

After putting this workaround in place, the test case described above still leaks memory on TensorFlow 2.2.0. However, the leakage occurs at a much slower rate. I suspect that there is a second memory leak in TensorFlow 2.2.0's model loading path, and that this second leak has also been patched on the master branch without backporting the patch to the 2.2 branch.

@kmh4321
Contributor

kmh4321 commented Jun 29, 2020

@k-w-w would it be possible to add the fix from the current master branch to the 2.2 branch?

(Quoting @frreiss's analysis from Jun 25, 2020 above.)

@frreiss
Contributor

frreiss commented Jul 13, 2020

Some updates:

  • I have tracked down the second memory leak.
  • I have identified a workaround for the second leak.
  • There is a third leak.

Details follow.

The root cause of the second memory leak is the lines immediately before the lines I pointed out in my previous comment:

(in tensorflow/python/saved_model/function_deserialization.py)

316    func = function_lib.ConcreteFunction(func_graph)
317    func.add_to_graph()
318    if context.executing_eagerly():
319      func.add_to_graph(ops.get_default_graph())

Line 316 creates a ConcreteFunction object. The initializer for ConcreteFunction adds the function to the current eager execution context and creates an _EagerDefinedFunctionDeleter callback object that will remove the function from the context when the ConcreteFunction goes out of scope.

Then line 317 calls add_to_graph() on the ConcreteFunction, implicitly passing None for the g (graph) parameter. ConcreteFunction.add_to_graph() passes its g argument down to _EagerDefinedFunction.add_to_graph(). And _EagerDefinedFunction.add_to_graph() will call TFE_ContextAddFunctionDef if its g parameter is None. TFE_ContextAddFunctionDef is actually a thin wrapper around TFE_ContextAddFunction.

So every function in the SavedModel is passed twice to TFE_ContextAddFunction.

When TFE_ContextAddFunction is called with a function definition whose name happens to match an existing function in the context, TFE_ContextAddFunction ignores the function definition that was passed to it (a disastrous design flaw IMO -- what if the second copy of the function isn't the same as the first one?) and increments the existing function's reference count.

So every function in the SavedModel ends up instantiated in the current context with a reference count of 2.

When the model's surrogate Python object goes out of scope, the ConcreteFunction objects under it also go out of scope, which causes the _EagerDefinedFunctionDeleter objects attached to them to go out of scope, which triggers a single call to TFE_ContextRemoveFunction for each function. In spite of its name, the TFE_ContextRemoveFunction API call does not actually remove the function whose name is passed to it unless the reference count of that function is 1. And every function in the SavedModel has a reference count of 2.

This second memory leak is also fixed in commit 3421416.

After that commit, line 317 is replaced with:

  func.add_to_graph(graph)

where graph is a temporary Graph object. Having the g parameter of ConcreteFunction.add_to_graph() set to a value other than None prevents the leaky code in _EagerDefinedFunction.add_to_graph() from running.
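
For anyone who wants to check whether a given TensorFlow installation already contains this fix, one option is to inspect the loader source directly (a rough sketch; the exact string being matched is an assumption based on the snippet above):

import inspect
from tensorflow.python.saved_model import function_deserialization

src = inspect.getsource(function_deserialization.load_function_def_library)
# Patched builds pass an explicit temporary graph; unpatched 2.2.x builds
# still call add_to_graph() with no argument at this point.
print("looks patched" if "func.add_to_graph(graph)" in src else "looks unpatched")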

Here's an updated version of my workaround that corrects for both memory leaks:

def load_model_hack(saved_model_path: str):
  """
  Load a SavedModel without leaking as much memory as usual.
  
  This function applies two workarounds: 
  * Load the model from a temporary background thread so that 
    `saved_model.load()` won't leave garbage hanging off of the global 
    default graph
  * Unpin functions that `saved_model.load()` pins twice, so that the 
    garbage collection logic in `ConcreteFunction` will correctly remove
    these functions when the restored model goes out of scope.
  """
  import threading
  from typing import Any, Dict
  def callback(path: str, result_holder: Dict[str, Any]) -> None:
    """Callback function to be executed in a background thread"""
    try:
      result_holder["model"] = tf.saved_model.load(path)

      # Every function that was pinned twice (and hopefully exactly those
      # functions) should now be in the background thread's global default
      # graph. Unpin these functions once.
      default_graph = tf.compat.v1.get_default_graph()
      for function_name in default_graph._functions.keys():
        tf.python.context.remove_function(function_name)
    except Exception as e:
      result_holder["error"] = e

  # Call saved_model.load() in a background thread 
  thread_results = {}
  t = threading.Thread(target=callback, args=[saved_model_path, thread_results])
  t.start()
  t.join()

  # Forward any exceptions thrown by the background thread
  if "error" in thread_results:
    raise thread_results["error"]
  return thread_results["model"]

This workaround further reduces the amount of memory that my test program leaks. However it does not completely eliminate memory leakage.

@frreiss
Contributor

frreiss commented Jul 21, 2020

Updates:

  • I have found the third leak.
  • The patch for the first two leaks also patched the third leak.
  • I have updated my workaround to cover all three leaks.
  • There is a fourth leak.

Details follow.

Loader._load_nodes() walks through the graph, reconstituting each part of the graph from its serialized Protocol Buffers representation.

(See line 264 in tensorflow/python/saved_model/load.py)

258    # Re-create everything except slot variables.
259    for node_id, proto in enumerate(self._proto.nodes):
260      if node_id in slot_variable_node_ids:
261        # Defer recreating slot variables so we can use the public Optimizer
262        # interface.
263        continue
264      node, setter = self._recreate(proto, node_id)   <<<<<<<<<<<
265      nodes[node_id] = node
266      node_setters[node_id] = setter

The proto variable here is a surrogate object for a SavedObject message:

(in tensorflow/core/protobuf/saved_object_graph.proto)

message SavedObject {
  // Objects which this object depends on: named edges in the dependency
  // graph.
  //
  // Note: currently only valid if kind == "user_object".
  repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1;

  // Removed when forking SavedObject from TrackableObjectGraph.
  reserved "attributes";
  reserved 2;

  // Slot variables owned by this object. This describes the three-way
  // (optimizer, variable, slot variable) relationship; none of the three
  // depend on the others directly.
  //
  // Note: currently only valid if kind == "user_object".
  repeated TrackableObjectGraph.TrackableObject.SlotVariableReference
      slot_variables = 3;

  oneof kind {
    SavedUserObject user_object = 4;
    SavedAsset asset = 5;
    SavedFunction function = 6;
    SavedVariable variable = 7;
    SavedBareConcreteFunction bare_concrete_function = 8;
    SavedConstant constant = 9;
    SavedResource resource = 10;
  }
}

Loader._recreate() is an unnecessarily-complex switch statement. Here are the relevant snippets:

(in tensorflow/python/saved_model/load.py)

353  def _recreate(self, proto, node_id):
354    """Creates a Python object from a SavedObject protocol buffer."""
355    factory = {
356        "user_object": (
357            lambda: self._recreate_user_object(proto.user_object, node_id)),
358        "asset": lambda: self._recreate_asset(proto.asset),
359        "function": lambda: self._recreate_function(proto.function),
360        "bare_concrete_function": functools.partial(
361            self._recreate_bare_concrete_function,
362            proto.bare_concrete_function),
363        "variable": lambda: self._recreate_variable(proto.variable),
364        "constant": lambda: self._recreate_constant(proto.constant),
365        "resource": lambda: self._recreate_resource(proto.resource),
366    }
367    kind = proto.WhichOneof("kind")
368    if kind not in factory:
369      raise ValueError("Unknown SavedObject type: %r" % kind)
370    return factory[kind]()

...

396  def _recreate_function(self, proto):
397    return function_deserialization.recreate_function(
398        proto, self._concrete_functions), setattr

In the case of this memory leak, the SavedObject message in the variable proto has its kind field set to bare_concrete_function. So the above translates into:

if proto.kind is bare_concrete_function:
    node = function_deserialization.setup_bare_concrete_function(proto,
        self._concrete_functions)
    setter = getattr
    return node, setter
else ... # code that doesn't execute for this case

(The reference to getattr feeds another bit of complex code elsewhere in load.py.)

function_deserialization.setup_bare_concrete_function() looks up the already-deserialized ConcreteFunction object, then calls that object's add_to_graph() method (line 172 below).

(in tensorflow/python/saved_model/function_deserialization.py)

159 def setup_bare_concrete_function(saved_bare_concrete_function,
160                                 concrete_functions):
161  """Makes a restored bare concrete function callable."""
162  # Bare concrete functions accept only flat lists of Tensors with unique
163  # names.
164  concrete_function = concrete_functions[
165      saved_bare_concrete_function.concrete_function_name]
166  # pylint: disable=protected-access
167  concrete_function._arg_keywords = (
168      saved_bare_concrete_function.argument_keywords)
169  concrete_function._num_positional_args = (
170      saved_bare_concrete_function.allowed_positional_arguments)
171  # pylint: enable=protected-access
172  concrete_function.add_to_graph()     <<<<<<<<<<
173  return concrete_function

Of course, the code that created the ConcreteFunction object in the first place has already added it to the graph twice.

And due to the same root cause as leaks 1 and 2, the ConcreteFunction's graph pointer is set to None, so ConcreteFunction.add_to_graph() calls _EagerDefinedFunction.add_to_graph(), which adds the function to the eager execution context a third time.

Unfortunately, the Loader class discards all information about what functions it has given this special treatment to, so a workaround along the lines of what I posted in my previous comment is not going to work. Here's my new workaround, which involves live-patching the function deletion callbacks in the background thread's global default graph:

class DeleteWithExtremePrejudice(object):
  """
  A version of _EagerDefinedFunctionDeleter (see
  tensorflow/python/eager/function.py) that keeps deleting the target function
  until an InvalidArgumentError exception is thrown.  Checking for that
  exception is the only way to ensure that a function really has been deleted
  and is not, in fact, still taking up memory.
  """

  def __init__(self, func_name):
    self.func_name = func_name

  def __del__(self):
    MAX_ATTEMPTS = 10
    try:
      for i in range(MAX_ATTEMPTS):
        tf.python.context.remove_function(self.func_name)
      # If we get here, removal did *not* fail as expected.
      print(f"WARNING: Failed to remove function "
            f"'{self.func_name}' after {MAX_ATTEMPTS} attempts. "
            f"This problem may result in a memory leak.")
    except tf.errors.InvalidArgumentError:
      # tf.python.context.remove_function() throws this exception when
      # you try to remove a function that has already been removed.
      # In the case of this `try` clause, such behavior is "normal".
      pass
    except Exception as e:
      print(f"WARNING: {e} thrown when attempting to delete function "
            f"'{self.func_name}'. This problem may result in a memory leak.")
  

def load_model_hack(saved_model_path: str):
  """
  Load a SavedModel without leaking memory.
  
  This function applies two workarounds: 
  * Load the model from a temporary background thread so that 
    `saved_model.load()` won't leave garbage hanging off of the global 
    default graph
  * Patch the garbage collection callbacks of all `ConcreteFunction`s
    in the returned model so that these functions will be properly removed
    when the restored model goes out of scope.
  """
  import threading
  from typing import Any, Dict
  def callback(path: str, result_holder: Dict[str, Any]) -> None:
    """Callback function to be executed in a background thread"""
    try:
      model = tf.saved_model.load(path)

      # Every function that was pinned two or more times should now be
      # in the current thread's global default graph variable.
      # Replace the deletion callbacks of these functions with a more
      # effective version.
      default_graph = tf.compat.v1.get_default_graph()
      for func in default_graph._functions.values():
        func._function_deleter = DeleteWithExtremePrejudice(func.name)
        # NOTE: This assignment will trigger the previous deletion callback.
        # That's ok, because every function in this list has been pinned 
        # at least twice.

      result_holder["model"] = model

    except Exception as e:
      result_holder["error"] = e

  # Call saved_model.load() in a background thread 
  thread_results = {}
  t = threading.Thread(target=callback, args=[saved_model_path, thread_results])
  t.start()
  t.join()

  # Forward any exceptions thrown by the background thread
  if "error" in thread_results:
    raise thread_results["error"]
  return thread_results["model"]

After applying this workaround, my test program leaks significantly less memory than before, but it still leaks memory.

@jvishnuvardhan
Contributor

@idfah Recently there were some updates that reduce the memory leak. I ran your code with a recent tf-nightly and see a leak of ~12 MB over 500 iterations. Can you please check this gist and let us know what you think?

Please verify and close the issue if this was resolved for you. Thanks!

@jvishnuvardhan jvishnuvardhan added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Aug 7, 2020
@frreiss
Contributor

frreiss commented Aug 11, 2020

@jvishnuvardhan, I think this issue needs to be kept open a while longer.

So far, we have verified that there is a serious memory leak in tf.saved_model.load() in TensorFlow 2.0.x, 2.1.x, and 2.2.x.

I would categorize this leak as a blocker for any application that needs to cycle models in and out of memory -- for example, to process a corpus of documents that span multiple languages; or to serve multiple versions of the same model. Our simple test program "only" leaks a few megabytes each time the model is loaded, but larger models with weights embedded in their graphs can leak hundreds of megabytes per load/unload cycle.

The leak is actually three leaks, all of which were patched in master back in March (in commit 3421416). However, the fix was not included in the May release of TensorFlow 2.2.0. As of today, five months later, the fix has not been ported to the 2.2.x, 2.1.x, or 2.0.x branches of TensorFlow.

TensorFlow 2.3.0 includes the fix for these three memory leaks. However, fixing this bug in 2.3.0 does not resolve this issue for us. My colleagues are currently using TensorFlow 2.2.x.

In addition to the three large leaks, there is a fourth leak that is not currently patched in the master branch. You can see the presence of this fourth leak in the output of the notebook linked in your previous comment:
[plot from the linked notebook: memory usage continues to grow with each iteration]

With the simple 2-layer model in your example notebook, each call to saved_model.load() leaks about 25kb of memory. Larger models leak more, probably a megabyte or two for a medium-sized deep learning model. This level of memory leakage is something that one could plausibly work around with periodic reboots; but I would submit that tf.saved_model.load() ought not to leak any memory at all. Authors of long-running applications should be able to load and unload TensorFlow models without worrying about running out of memory.

I have tracked the root cause of the fourth leak to a problem in TensorFlow's mechanism for caching kernels.

In addition to creating graphs, tf.saved_model.load() executes operations in those graphs, primarily for the purpose of restoring variable values. The code that executes these operations is EagerExecute(), which calls EagerLocalExecute(), which calls GetOrCreateKernelAndDevice(), which asks the current EagerContext to look for the kernel for each operation in its kernel cache.

The EagerContext class maintains a cache of kernel instances:
(in tensorflow/core/common_runtime/eager/context.h, direct link here)

612  std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>,
613                     Fprint128Hasher>
614      kernel_cache_ TF_GUARDED_BY(cache_mu_);

The kernel cache does not have a size limit. There is an API call to clear the cache, but the Python side of TensorFlow only uses that API call when resetting the global random seed.

Each entry in the cache is parameterized by the "fingerprint" of the associated operation. This "fingerprint" is a hash value computed from multiple parameters of the operation, including all of the operation's attributes.

saved_model.load() restores variables through a process that involves invoking multiple instances of the VarHandleOp operation. Due to the graph rewriting that happens during model loading, each of these operations has a unique value in the shared_name field of its attributes. These unique values cause the operations to have unique fingerprints, even across multiple load operations on the same model. Each unique fingerprint causes the creation of a new entry in the cache. The cache is of unlimited size and is never cleared, so memory usage for these cached kernels grows in an unbounded fashion.

The best workaround I've found for this problem is to have your Python application periodically clear the cache via the internal API. Here's some Python code to do so:

import tensorflow.python as tf_internals
context_handle = tf_internals.eager.context.context()._context_handle
if context_handle is not None:
  tf_internals.pywrap_tfe.TFE_ContextClearCaches(context_handle)
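
In practice I would wrap this in a small helper and call it after each unload; the helper name below is mine, not part of TensorFlow's API:

import tensorflow.python as tf_internals

def clear_eager_kernel_cache() -> None:
  """Drop all cached kernels from the global eager context (internal API)."""
  context_handle = tf_internals.eager.context.context()._context_handle
  if context_handle is not None:
    tf_internals.pywrap_tfe.TFE_ContextClearCaches(context_handle)

# Example: del model; gc.collect(); clear_eager_kernel_cache()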

A more permanent fix would be to evict stale entries from the cache following a least recently used policy. I'm working on a PR to apply such a fix.

The above workaround reduces but does not eliminate the memory leakage of my test program. Before the workaround, each call to saved_model.load() leaked about 125kb on TensorFlow 2.3.0 and 115kb on tf-nightly. After the workaround, each call leaks about 30kb and 20kb on TensorFlow 2.3.0 and nightly, respectively. I haven't checked whether this remaining leakage is constant or whether it scales with model size. However, since the leakage appears to be coming from a graph rewrite, I would expect the amount of memory leaked to scale with model size.
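
For anyone who wants to reproduce these per-call numbers, one simple approach (a sketch; psutil is my tool of choice here, and the linked notebooks may measure memory differently) is to track the process RSS across load/delete cycles:

import gc
import psutil
import tensorflow as tf

proc = psutil.Process()

def rss_mb() -> float:
  """Resident set size of the current process, in megabytes."""
  return proc.memory_info().rss / 1e6

baseline = rss_mb()
for i in range(50):
  loaded = tf.saved_model.load("/tmp/leak_repro_model")  # illustrative path
  del loaded
  gc.collect()
  print(f"iteration {i}: RSS grew by {rss_mb() - baseline:.2f} MB")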

@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Aug 18, 2020
@frreiss
Contributor

frreiss commented Aug 18, 2020

Not sure why tensorflow-butler thinks this issue hasn't had recent activity. There is an open pull request for part of the problem described here.

@jvishnuvardhan jvishnuvardhan removed stat:awaiting response Status - Awaiting response from author stale This label marks the issue/pr stale - to be closed automatically if no activity labels Aug 18, 2020
@goldiegadde
Contributor

@frreiss regarding the backport, we typically don't backport bug fixes into previous releases. Is it possible for you to use a later version like TF 2.3.0, which has the fixes?
Regarding the 4th leak, I don't see any pull request that is currently open; is this something you are working on?

@frreiss
Contributor

frreiss commented Nov 18, 2020

Hi @goldiegadde. I see from the release notes that TensorFlow 2.3.1 includes 25 bug fixes and TensorFlow 2.2.1 includes 19 bug fixes. Perhaps you meant to say that you typically don't backport fixes for memory leaks?

Moving to 2.3.x is the only viable option for us at this point, and my colleagues will be doing so in spite of the disruption that this entails.

The fourth leak is currently low on my priority list because #33412 is a much more severe problem for us. Hopefully that fourth leak fixes itself.

@JoshEZiegler

JoshEZiegler commented Dec 14, 2020

I just tested the above Colab notebook with the latest tf-nightly (version 2.5.0-dev20201214), and the memory increase across iterations looks reduced by a fair bit. It's unclear to me whether this is fully fixed, but perhaps the fourth memory leak you mentioned has been fixed.

[plot: memory usage per iteration with tf-nightly 2.5.0-dev20201214]

Edit:
Also want to note that this appears to be fixed in TF 2.4.0, so there is no need to grab a tf-nightly build.

[plot: memory usage per iteration with TF 2.4.0]

@frreiss
Contributor

frreiss commented Jan 4, 2021

I think this issue is as fixed as it's going to be. What do you think, @idfah?

@goldiegadde
Contributor

goldiegadde commented Jan 8, 2021

@frreiss the commit you referred to has been cherry-picked into 2.2.2 as well. I am closing this bug for now; if any other issues linger, can you please open a new one?
Thanks!

@qwfy

qwfy commented Jun 9, 2021

(Quoting @frreiss's comment from Jul 21, 2020 above, which describes the third leak and the updated workaround.)

Wow, when reading about the third leak I thought it would be hilarious if there were a fourth (though I suspect there is probably a fifth). By any chance, have you documented the thinking and the tools used to track down these leaks?

@frreiss
Contributor

frreiss commented Jun 10, 2021

My main piece of advice in terms of tools is that the tools lie. TensorFlow is a very complex program with multiple heaps (Python heap, C heap, TensorFlow's own memory manager, and CUDA) and global data structures that can look like leaks at first glance. You need to use several leak checkers and cross-reference their outputs. For this particular issue, I used tcmalloc, pprof, Python's gc package, the PyCharm debugger, and some additional Python code to walk the heap from inside the program under test.
