Memory Leak in tf.data.Dataset.from_generator #37653

Closed
hankcs opened this issue Mar 17, 2020 · 29 comments
Labels
comp:data tf.data related issues stat:awaiting response Status - Awaiting response from author TF 2.7 Issues related to TF 2.7.0 type:performance Performance Issue

Comments

@hankcs

hankcs commented Mar 17, 2020

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04.3 LTS
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version: v2.1.0-rc2-17-ge5bf8de 2.1.0
  • Python version: 3.6.6
  • CUDA/cuDNN version: CUDA 10.1, cudnn-10.1
  • GPU model and memory: TITAN RTX, 24190 MiB

Describe the current behavior
tf.data.Dataset.from_generator leaks memory after each call even if followed by gc.collect().

Describe the expected behavior
Memory should be released when no reference exists for the dataset.

Standalone code to reproduce the issue

import gc
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
import tracemalloc
import linecache


def display_top(snapshot, key_type='lineno', limit=3):
    snapshot = snapshot.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    top_stats = snapshot.statistics(key_type)

    print("Top %s lines" % limit)
    for index, stat in enumerate(top_stats[:limit], 1):
        frame = stat.traceback[0]
        # replace "/path/to/module/file.py" with "module/file.py"
        filename = os.sep.join(frame.filename.split(os.sep)[-2:])
        print("#%s: %s:%s: %.1f KiB"
              % (index, filename, frame.lineno, stat.size / 1024))
        line = linecache.getline(frame.filename, frame.lineno).strip()
        if line:
            print('    %s' % line)

    other = top_stats[limit:]
    if other:
        size = sum(stat.size for stat in other)
        print("%s other: %.1f KiB" % (len(other), size / 1024))
    total = sum(stat.size for stat in top_stats)
    print("Total allocated size: %.1f KiB" % (total / 1024))


def generator():
    yield tf.zeros(2, 3)


tracemalloc.start()
for i in range(1000):
    dataset = tf.data.Dataset.from_generator(generator, output_types=tf.int32, output_shapes=[None])
    del dataset
    gc.collect()
    snapshot = tracemalloc.take_snapshot()
    display_top(snapshot)

Other info / logs

Top 3 lines
#1: python3.6/_weakrefset.py:84: 159.5 KiB
    self.data.add(ref(item, self._remove))
#2: python3.6/_weakrefset.py:37: 38.2 KiB
    self.data = set()
#3: python3.6/_weakrefset.py:48: 32.4 KiB
    self._iterating = set()
461 other: 306.4 KiB
Total allocated size: 536.4 KiB
Top 3 lines
#1: python3.6/_weakrefset.py:84: 159.5 KiB
    self.data.add(ref(item, self._remove))
#2: python3.6/_weakrefset.py:37: 38.2 KiB
    self.data = set()
#3: python3.6/_weakrefset.py:48: 32.4 KiB
    self._iterating = set()
516 other: 343.1 KiB
Total allocated size: 573.1 KiB

...

Top 3 lines
#1: python3.6/weakref.py:335: 257.8 KiB
    self = ref.__new__(type, ob, callback)
#2: debug/tf_dataset_memory_leak.py:45: 189.7 KiB
    dataset = tf.data.Dataset.from_generator(generator, output_types=tf.int32, output_shapes=[None])
#3: ops/script_ops.py:257: 174.7 KiB
    return "pyfunc_%d" % uid
519 other: 2423.3 KiB
Total allocated size: 3045.5 KiB

It leaks about 3 MB over 1000 calls. In some real projects it can leak as much as 5 GB and keeps growing.

@gadagashwini-zz
Contributor

I was able to replicate the issue with TF 2.1.
Please find the gist here. Thanks!

@gadagashwini-zz gadagashwini-zz added type:performance Performance Issue and removed type:bug Bug labels Mar 17, 2020
@gowthamkpr gowthamkpr assigned jsimsa and unassigned gowthamkpr Mar 17, 2020
@gowthamkpr gowthamkpr added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Mar 17, 2020
@jsimsa
Contributor

jsimsa commented Mar 17, 2020

@kkimdev could you please take a look? Thank you.

@jsimsa jsimsa assigned kkimdev and unassigned jsimsa Mar 17, 2020
@kkimdev
Contributor

kkimdev commented Mar 17, 2020

Memory debug Colab link: https://colab.corp.google.com/drive/1TYW_QWcJ6j6FfuepdTu6nfj1TP2Cw5PU#scrollTo=ViLA0wnRmblE

Leak per 100 iterations:

============================================================================
Type                          Old_ids  Current_ids      New_ids Count_Deltas
============================================================================
cell                            74412        75112         +700         +700
dict                           100309       100909         +604         +600
tuple                           84895        85395         +500         +500
function                       113257       113757         +500         +500
list                            50680        50980         +302         +300
KeyedRef                        26530        26830         +300         +300
method                           9305         9405         +100         +100
_GeneratorState                  8845         8945         +100         +100
TensorShape                      8849         8949         +100         +100
Dimension                        8844         8944         +100         +100

Reference graph from _GeneratorState https://graphviz.corp.google.com/svg?graph_id=2afe255c1644cc79fe98e01ab09c6be8

It seems the leaking edge is _py_funcs_used_in_graph.

@kkimdev
Contributor

kkimdev commented Mar 17, 2020

Probably one of the script_ops.numpy_function(...) calls in from_generator(...).

_py_funcs_used_in_graph exists only to maintain the lifetime of the function, so I think the correct fix is attaching it to something else, not the graph. @mdanatg @akshaym
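
As a quick way to observe this edge (a diagnostic sketch of mine, not from the thread; _py_funcs_used_in_graph is a private, version-dependent attribute), reusing the generator from the repro script above:

import tensorflow as tf

graph = tf.compat.v1.get_default_graph()
for i in range(3):
    dataset = tf.data.Dataset.from_generator(
        generator, output_types=tf.int32, output_shapes=[None])
    del dataset
    # Each from_generator call registers py_funcs on the global graph and
    # nothing removes them, so this count keeps growing.
    print(len(getattr(graph, "_py_funcs_used_in_graph", [])))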

@mdanatg

mdanatg commented Mar 18, 2020

The function needs to be attached to the graph, otherwise the py_func Op would have nothing to call. There should be no reference from the function to the graph though - it should have no knowledge of it. That said, it's easy for it to close over something that points to it.
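
To make the "closing over" point concrete, here is a small illustration of mine (not code from the thread): if the generator closes over a large object, the py_funcs attached to the global graph keep the closure, and therefore the object, alive even after the dataset is deleted.

import numpy as np
import tensorflow as tf

def make_dataset():
    big = np.random.rand(1_000_000)  # large object captured by the closure

    def generator():
        yield big

    # The py_funcs wrapping `generator` are registered on the global default
    # graph, so `big` stays reachable even after the returned dataset is gone.
    return tf.data.Dataset.from_generator(
        generator, output_types=tf.float64, output_shapes=[None])

ds = make_dataset()
del ds  # `big` is still retained via Graph._py_funcs_used_in_graph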

@kkimdev
Contributor

kkimdev commented Mar 18, 2020

btw, FYI: in this case it was the global graph.

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Mar 20, 2020
@luvwinnie

luvwinnie commented Jul 11, 2020

Any progress on this issue? It still seems to occur in tensorflow-gpu==2.2.0.
Please provide some information on how to prevent this problem; my process has been killed by OOM because of it.

@hankcs
Author

hankcs commented Jul 11, 2020

Four months have passed without any progress. I can do nothing but rewrite my entire project in PyTorch. I really love Keras and hope that one day the user experience of TensorFlow will match that of Keras.

@luvwinnie

Four months have passed without any progress. I can do nothing but rewrite my entire project in PyTorch. I really love Keras and hope that one day the user experience of TensorFlow will match that of Keras.

I have a pretrained model for my product, so I can't switch to PyTorch easily unless I am confident that converting the weights gives the same performance, given the precision difference between float and double (I think the double data type is used in Keras).

What if we use only a normal Python generator? Wouldn't that solve it? Has anyone tried that instead of tf.data.Dataset.from_generator?

@kkimdev
Contributor

kkimdev commented Jul 12, 2020

@luvwinnie @hankcs I deeply apologize for the issues. Would you mind trying this workaround?

# Before calling `tf.data.Dataset.from_generator`.
tf.compat.v1.get_default_graph()._py_funcs_used_in_graph = []
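
For context, a minimal sketch (mine, not part of the comment) of applying this inside the repro loop above, using the generator from the repro script; note that resetting the list discards the registrations of any other datasets still in use on the same graph:

import gc
import tensorflow as tf

for i in range(1000):
    # Reset the graph-level py_func registry before creating the dataset,
    # so entries from previous iterations can be garbage collected.
    tf.compat.v1.get_default_graph()._py_funcs_used_in_graph = []
    dataset = tf.data.Dataset.from_generator(
        generator, output_types=tf.int32, output_shapes=[None])
    del dataset
    gc.collect()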

@luvwinnie

luvwinnie commented Jul 12, 2020

@kkimdev I tried adding the _py_funcs_used_in_graph reset, but I got the error below instead. By the way, I use tf.distribute.MirroredStrategy() for multi-GPU; strategy.make_dataset_iterator and strategy.experimental_distribute_dataset both give the same result.

Traceback (most recent call last):
  File "train.py", line 1228, in <module>
    main(args)
  File "train.py", line 1001, in main
    ditributred_train(train_iterator)
  File "/misc/home/usr16/cheesiang_leow/.virtualenvs/tensorflow-gpu-2.3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/misc/home/usr16/cheesiang_leow/.virtualenvs/tensorflow-gpu-2.3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 644, in _call
    return self._stateless_fn(*args, **kwds)
  File "/misc/home/usr16/cheesiang_leow/.virtualenvs/tensorflow-gpu-2.3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2420, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/misc/home/usr16/cheesiang_leow/.virtualenvs/tensorflow-gpu-2.3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1665, in _filtered_call
    self.captured_inputs)
  File "/misc/home/usr16/cheesiang_leow/.virtualenvs/tensorflow-gpu-2.3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1746, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/misc/home/usr16/cheesiang_leow/.virtualenvs/tensorflow-gpu-2.3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 598, in call
    ctx=ctx)
  File "/misc/home/usr16/cheesiang_leow/.virtualenvs/tensorflow-gpu-2.3/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  ValueError: callback pyfunc_0 is not found
Traceback (most recent call last):

  File "/misc/home/usr16/cheesiang_leow/.virtualenvs/tensorflow-gpu-2.3/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py", line 232, in __call__
    raise ValueError("callback %s is not found" % token)

ValueError: callback pyfunc_0 is not found


	 [[{{node PyFunc}}]]
	 [[MultiDeviceIteratorGetNextFromShard]]
	 [[RemoteCall]]
	 [[IteratorGetNextAsOptional_1]]
  (1) Cancelled:  Function was cancelled before it was started
0 successful operations.
1 derived errors ignored. [Op:__inference_ditributred_train_28590]

Function call stack:
ditributred_train -> ditributred_train

@hankcs
Author

hankcs commented Jul 12, 2020

Hi @kkimdev, thanks for your workaround. I tried it on 2.1 and 2.2; neither works.

@luvwinnie The parallelization and caching can't be easily implemented with NumPy. Besides, I have a whole pile of data pipelines built on top of the tf.data API. Abandoning all of that was a hard decision for me. But I guess that's life... I'm happy with PyTorch so far; its data API is not as good as TF's, but at least it is stable.

One user of my project said he ended up using Docker and has to restart the container once TF exceeds the memory quota. I don't think that is the right approach for a decent project.

@luvwinnie

luvwinnie commented Jul 12, 2020

@hankcs Yes, I agree with you. I am in exactly the same situation: my pile of data pipelines is also built on top of the tf.data API. One of my users said he ended up using a normal Python generator instead, but then he can't use multiple GPUs, for which TensorFlow needs the tf.data API.

@kkimdev I think this is an urgent issue for every heavy user of tf.data. How come this hasn't been solved in 4 months? I think this is a problem with eager execution; currently the multi-GPU strategy still depends on the old Session machinery.

@kkimdev
Contributor

kkimdev commented Jul 12, 2020

Another workaround to try:

# Cleanup utility class
class TfDataset(object):
    def __init__(self):
        self.py_func_set_to_cleanup = set()

    def from_generator(self, generator, output_types, output_shapes=None, args=None):
        if not hasattr(tf.compat.v1.get_default_graph(), '_py_funcs_used_in_graph'):
            tf.compat.v1.get_default_graph()._py_funcs_used_in_graph = []
        py_func_set_before = set(tf.compat.v1.get_default_graph()._py_funcs_used_in_graph)
        result = tf.data.Dataset.from_generator(generator, output_types, output_shapes, args)
        py_func_set_after = set(tf.compat.v1.get_default_graph()._py_funcs_used_in_graph) - py_func_set_before
        self.py_func_set_to_cleanup |= py_func_set_after
        return result
  
    def cleanup(self):
        # Drop only the py_funcs registered through this wrapper; keep any others.
        new_py_funcs = set(tf.compat.v1.get_default_graph()._py_funcs_used_in_graph) - self.py_func_set_to_cleanup
        tf.compat.v1.get_default_graph()._py_funcs_used_in_graph = list(new_py_funcs)
        self.py_func_set_to_cleanup = set()

# Usage example
tf_dataset = TfDataset()
dataset = tf_dataset.from_generator(generator, output_types=tf.int32, output_shapes=[None])
del dataset
tf_dataset.cleanup()  # Call this when you are done using the dataset.

Sample Colab run: https://colab.research.google.com/gist/kkimdev/0047ce8444a14197c60c19bba0349156/copy-of-untitled463.ipynb

@luvwinnie

luvwinnie commented Jul 12, 2020

@kkimdev Thank you for your workaround! Just another question: you do a del dataset followed by tf_dataset.cleanup(). I have a loop like the one below; when do you suggest cleaning up the dataset?

strategy = tf.distribute.MirroredStrategy()
train_iterator = strategy.make_dataset_iterator(train_dataset)
def train_step(inputs):
    images, labels = inputs
    with tf.GradientTape() as tape:
        y_pred = recognizer(images, training=True)
        loss = compute_loss(
                   y_pred, labels
        )
    losses.update_state(loss)
    gradients = tape.gradient(loss, recognizer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, recognizer.trainable_variables))
@tf.function
def distributred_train(dataset):
        return strategy.experimental_run(train_step, dataset)

for epoch in range(epochs):
    for steps in range(train_steps_per_epoch):
        distributred_train(train_iterator)
        # This uses train_iterator to run the distributed training step.
        # When should I reset the dataset?

@kkimdev
Contributor

kkimdev commented Jul 12, 2020

@luvwinnie Actually, your snippet could be a different issue. This issue is for memory leaks when tf.data.Dataset.from_generator is called repeatedly. Could you file another issue? (ideally with a reproducible Colab example like this bug)

@luvwinnie

@kkimdev I am trying to track down the problem, but it seems to be an issue with the multi-GPU setup. I will open an issue once I am able to produce a reproducible example, since the Colab environment can't use two GPUs.

@alexpyattaev

The issue is unrelated to GPU and is perfectly reproducible with CPU training. I am using flat NumPy arrays as training input, so it is not a dataset/generator issue either. The memory leak rate seems to depend on training set size, though; I am collecting some stats and will update this post in a couple of hours once I have a nice plot of RAM use vs. time.

@heilwood

heilwood commented Nov 2, 2020

Found a workaround; del together with gc.collect() works fine for me:

images_list = np.vstack(images_to_load)
pred_result = self.model.predict(images_list, batch_size=10)

del pred_result
gc.collect()

@jjrugui

jjrugui commented Jan 11, 2021

@kkimdev I am going to check your workaround right now, since I am experiencing memory leaks using the from_generator method.

If I understand the snippet you provided in Colab correctly, the dataset is cleared at every iteration. Would you do this periodically rather than at every step? I'm asking because if the dataset is cleared at every step there seems to be no gain in terms of caching (say, the dataset is cleared every x epochs instead).

@kkimdev
Contributor

kkimdev commented Jan 11, 2021

@jjrugui Yes that sounds reasonable, but let's first confirm that the memory leak could be resolved by using the wrapper.
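
A rough sketch of the periodic variant (my own illustration, assuming the TfDataset wrapper above; train_one_epoch and the cleanup interval are hypothetical):

CLEANUP_EVERY_N_EPOCHS = 5  # assumed interval, tune as needed

tf_dataset = TfDataset()
dataset = tf_dataset.from_generator(
    generator, output_types=tf.int32, output_shapes=[None])

for epoch in range(epochs):
    train_one_epoch(dataset)  # hypothetical training helper
    if (epoch + 1) % CLEANUP_EVERY_N_EPOCHS == 0:
        # Rebuild the dataset only every N epochs, keeping caching benefits
        # in between while still bounding the py_func registry growth.
        del dataset
        tf_dataset.cleanup()
        dataset = tf_dataset.from_generator(
            generator, output_types=tf.int32, output_shapes=[None])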

@SysuJayce

@jjrugui Yes that sounds reasonable, but let's first confirm that the memory leak could be resolved by using the wrapper.

@kkimdev Actually, from_tensor_slices() also has a memory leak problem.

@SodaGremlin

I am also seeing this occur when using TF Agents' tf_uniform_replay_buffer and calling the as_dataset method. After each call (regardless of replay size or variables) I am getting:

after episode 2:

dict                               49318      +128
list                               20866       +60
_Listener                           4494       +42
weakproxy                           4494       +42
RepeatedCompositeFieldContainer      615       +12
AttrValue                           1502       +10
MessageMap                           647        +8
set                                 7087        +7
TensorShape                          679        +7
ArgDef                               276        +7

after episode 3:

dict                               49446      +128
list                               20926       +60
_Listener                           4536       +42
weakproxy                           4536       +42
RepeatedCompositeFieldContainer      627       +12
AttrValue                           1512       +10
MessageMap                           655        +8
TensorShape                          686        +7
ArgDef                               283        +7
tuple                              49100        +6

I have removed all my training code. The only method being called in this loop is:

dataset = replay_buffer.as_dataset(
        ...
      single_deterministic_pass=False
    )

However, if I change it to:

dataset = replay_buffer.as_dataset(
        ...
      single_deterministic_pass=True
    )

the memory leak goes away 100%. I'm not sure what the TF Agents code is doing with that flag, but it might help someone track down the problem, or help someone else using TF Agents.

@tilakrayal
Contributor

@hankcs ,
Can you please try this workaround on the latest TF v2.7 and let us know if the issue still persists? Thanks!

@tilakrayal tilakrayal added the stat:awaiting response Status - Awaiting response from author label Dec 23, 2021
@tilakrayal tilakrayal self-assigned this Dec 23, 2021
@hankcs
Author

hankcs commented Dec 23, 2021

@hankcs , Can you please try this workaround on the latest TF v2.7 and let us know if the issue still persists? Thanks!

Thank you for the workaround; I can confirm it works with the latest v2.7: https://colab.research.google.com/drive/1xNMnqzM0Zrfr3I8XwcS0E9kdhhslxuZd?usp=sharing

@tilakrayal
Contributor

@hankcs ,
Please feel free to move this issue to closed status, as this answers your question. Thanks!

@tilakrayal tilakrayal added TF 2.7 Issues related to TF 2.7.0 stat:awaiting response Status - Awaiting response from author and removed stat:awaiting response Status - Awaiting response from author TF 2.1 for tracking issues in 2.1 release labels Dec 24, 2021
@hankcs
Author

hankcs commented Dec 24, 2021

Thank you so much!

@Kornel

Kornel commented Aug 9, 2023

Why is this issue closed, @tilakrayal? I'm still experiencing the same issue, with TF 2.11 (v2.11.0-rc2-17-gd5b57ca93e5, 2.11.0 specifically) in my case. Is the proposed solution to use the workaround?

To reproduce, I'm using mostly the same code as @hankcs. In my 'real' scenario the leak is on the order of 40 GB, which is a bit problematic, to be honest.

import gc

import psutil
import tensorflow as tf


def generator():
    yield tf.random.normal((1000, 1000))


gc.collect()

process = psutil.Process()

deltas = []

for i in range(1000):
    mem_used_0 = process.memory_info().rss
    # create a dataset from a generator, and clean up immediately
    dataset = tf.data.Dataset.from_generator(
        generator, output_types=tf.float64, output_shapes=[None]
    )
    del dataset
    gc.collect()

    # collect memory usage
    mem_used_1 = process.memory_info().rss

    # store the memory usage delta
    delta = mem_used_1 - mem_used_0
    deltas.append(delta)

# How much memory is leaking?
delta_sum_mb = sum(deltas) / 1024**2
print(f"Sum of memory leaking: {delta_sum_mb:.2f}MB")

Output:

Sum of memory leaking: 125.18MB

@nfergu

nfergu commented Apr 26, 2024

If anyone is experiencing this, one workaround is to create the dataset in a new thread each time.

This is because the object that's holding on to the data is the thread-local global default graph. As I understand it this is an object that exists to provide compatibility with TensorFlow version 1.

Here's the chain of references that causes the leak:

╙── ndarray instance (id=6222109296)
    └─╼ get_shuffled_data_generator.<locals>.generator.input_data (closure) (id=6986368624)
        └─╼  dict (object) (id=6844999744)
            └─╼ _GeneratorState (object) (id=6083798848)
                └─╼ DatasetV2.from_generator.<locals>.finalize_fn.<locals>.finalize_py_func.generator_state (closure) (id=6986368240)
                    └─╼ DatasetV2.from_generator.<locals>.finalize_fn.<locals>.finalize_py_func.func (closure) (id=6986368048)
                        └─╼ list (object) (id=6175396416)
                            └─╼  Graph._py_funcs_used_in_graph (instance attribute) (id=6079720320)
                                └─╼  Graph (object) (id=6079694016)
                                    └─╼  dict[_global_default_graph] (id=4988446144)
                                        └─╼  dict[<weakref at 0x129560770; to '_thread._localdummy' at 0x129419a50>] (id=4988445824)

In my case my generator function is a closure (get_shuffled_data_generator.<locals>.generator) that has a reference to my data. So the thread-local global default graph keeps a reference to my closure, which keeps a reference to my data. Creating a new thread for each dataset avoids the leak by clearing the thread-local state.

It may also be possible to avoid this by using a regular (non-closure) function, and passing the data via from_generator's args argument, but I haven't tried this.
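
For illustration (my own sketch, not code from the comment), creating the dataset in a short-lived thread and passing the data through from_generator's args argument might look like this; build_and_train is a hypothetical function that consumes the dataset, and the cleanup relies on the thread-local default graph going away with the thread, as described above:

import threading
import numpy as np
import tensorflow as tf

def generator(input_data):
    # A regular (non-closure) generator: the data arrives via `args`,
    # so nothing heavy is captured in a closure held by the graph.
    for row in input_data:
        yield row

def build_and_train(input_data):
    dataset = tf.data.Dataset.from_generator(
        generator,
        output_types=tf.float32,
        output_shapes=[None],
        args=(input_data,),
    )
    for batch in dataset:  # placeholder for the real training loop
        pass

input_data = np.random.rand(100, 10).astype(np.float32)

# Run each dataset's lifetime in its own thread so the thread-local
# default graph (and whatever it retains) is released with the thread.
t = threading.Thread(target=build_and_train, args=(input_data,))
t.start()
t.join()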
