Memory leak when using py_function inside tf.data.Dataset #35084

Closed
QiJune opened this issue Dec 13, 2019 · 15 comments
Labels: comp:eager (Eager related issues), TF 2.0 (Issues relating to TensorFlow 2.0), type:performance (Performance Issue)


QiJune commented Dec 13, 2019

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
  • OS Platform and Distribution: Linux Ubuntu 16.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary):
  • TensorFlow version (use command below): 2.0
  • Python version: 3.6.8
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

Describe the current behavior

[Screenshot 2019-12-13 18:06: process memory usage increasing across steps]

Describe the expected behavior

The tf.data.Dataset instance should be freed after each step.

Code to reproduce the issue

import tensorflow as tf
import os
import numpy as np
import psutil

def _generator():
    for i in range(100):
        yield "1,2,3,4,5,6,7,8"

def _py_parse_data(record):
    # Runs eagerly via tf.py_function: decode the bytes record and add 1 to each field.
    record = record.numpy()
    record = bytes.decode(record)
    rl = record.split(",")
    rl = [str(int(r) + 1) for r in rl]
    return [",".join(rl)]

def parse_data(record, shape=10):
    sparse_data = tf.strings.split([record], sep=",")
    sparse_data = tf.strings.to_number(sparse_data[0], tf.int64)
    ids_num = tf.cast(tf.size(sparse_data), tf.int64)
    indices = tf.range(0, ids_num, dtype=tf.int64)
    indices = tf.reshape(indices, shape=(-1, 1))
    sparse_data = tf.sparse.SparseTensor(
                indices, sparse_data, dense_shape=(shape,)
    )
    return sparse_data

process = psutil.Process(os.getpid())

step = 0
while step < 10000:
    # A new dataset (and a newly traced py_function) is created on every step.
    t = tf.data.Dataset.from_generator(_generator, output_types=tf.string)
    t = t.map(lambda record: tf.py_function(_py_parse_data, [record], [tf.string]))
    t = t.map(parse_data)
    for d in t:
        a = 1
    if step % 10 == 0:
        print("Memory : ", process.memory_info().rss)
    step += 1
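
A minimal workaround sketch (reusing the helpers defined above, and assuming the goal is only to iterate over the same data repeatedly): build the pipeline once, outside the loop, so the py_function is traced a single time instead of once per step.

# Workaround sketch: one dataset, many passes over it.
dataset = tf.data.Dataset.from_generator(_generator, output_types=tf.string)
dataset = dataset.map(lambda record: tf.py_function(_py_parse_data, [record], [tf.string]))
dataset = dataset.map(parse_data)

for step in range(10000):
    for d in dataset:
        pass
    if step % 10 == 0:
        print("Memory : ", process.memory_info().rss)
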
ravikyram self-assigned this Dec 16, 2019
ravikyram added the comp:data (tf.data related issues), TF 2.0 (Issues relating to TensorFlow 2.0), and type:performance (Performance Issue) labels Dec 16, 2019
ravikyram (Contributor) commented:

@QiJune

I have tried this on Colab with TF versions 2.0 and 2.1.0-rc1. Please find the gist here. Is this the expected behavior?

ravikyram added the stat:awaiting response (Status - Awaiting response from author) label Dec 16, 2019

QiJune commented Dec 16, 2019

@ravikyram Yes, exactly, the memory is increasing quickly.
I am not sure why the tf.data.Dataset objects are not freed.


jsimsa commented Jan 3, 2020

@kkimdev this seems related to your work on properly garbage collecting traced functions that go out of scope ... could you please take a look? thank you

jsimsa assigned kkimdev and unassigned jsimsa Jan 3, 2020

kkimdev commented Jan 5, 2020

List of leaking objects per 100 iterations:

=======================================================================
Type                     Old_ids  Current_ids      New_ids Count_Deltas
=======================================================================
dict                       49562        50462         +904         +900
cell                       19873        20673         +800         +800
tuple                      39958        40658         +700         +700
function                   71183        71783         +600         +600
list                       27319        27819         +502         +500
KeyedRef                    3800         4200         +400         +400
EagerTensor                 1900         2100         +200         +200
method                      1413         1513         +100         +100
_GeneratorState              950         1050         +100         +100
TensorShape                  953         1053         +100         +100
Tape                         950         1050         +100         +100
GradientTape                 950         1050         +100         +100
EagerFunc                    950         1050         +100         +100
StringIO                       3            3           +1           +0
wrapper_descriptor          3782         3782           +0           +0
weekday                       14           14           +0           +0
weakref                    13226        13226           +0           +0
weakcallableproxy              1            1           +0           +0
vectorize                      4            4           +0           +0
validate_nseq_int              1            1           +0           +0
validate_nseq_float            5            5           +0           +0
uname_result                   1            1           +0           +0
tzutc                          1            1           +0           +0
tzUTC                          1            1           +0           +0
type                        6277         6277           +0           +0
staticmethod                1212         1212           +0           +0
slice                         72           72           +0           +0
set                         6633         6633           +0           +0
scputimes                    113          113           +0           +0
pybind11_type                 49           49           +0           +0
=======================================================================

[Image: leaking EagerFunc reference graph]

So it seems the problem is that a new py_function is created on every loop iteration and _py_funcs_used_in_graph keeps growing.
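
For reference, a minimal stdlib sketch (not necessarily the tool that produced the table above) that diffs live object counts by type across iterations:

import gc
from collections import Counter

def live_type_counts():
    # Only objects tracked by the garbage collector are counted, which is enough
    # to spot growing container types such as dict, tuple, and EagerFunc.
    gc.collect()
    return Counter(type(o).__name__ for o in gc.get_objects())

before = live_type_counts()
# ... run e.g. 100 iterations of the reproduction loop here ...
after = live_type_counts()
for name, delta in (after - before).most_common(20):
    print(f"{name:<20} {delta:+d}")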


jsimsa commented Jan 6, 2020

Thanks Kibeom. So the problem is that _py_funcs_used_in_graph is not garbage collected?

tensorflow-copybara pushed a commit that referenced this issue Jan 7, 2020
Eager mode can incorrectly have a global graph. Disabling the global graph
in eager mode breaks too many assumptions, so first introduce a flag indicating it.

Also, avoid attaching py_function to the eager-mode global graph, which is a leak.

Though this CL doesn't fix the leak yet, as there are two more references that lead
to the leak: `tape_cache` and `ag_dnc_wrapper__`.

#35084

PiperOrigin-RevId: 288415011
Change-Id: Ica53e29521320af22c10609857d0a0219a9596ce

kkimdev commented Jan 7, 2020

Actually, there shouldn't be a global graph in the first place since this is eager mode. With 3b74a63, it is correctly attached to a func graph's _py_funcs_used_in_graph and will be gone when the func graph is garbage collected.

There are still two more causes of this leak, though: tape_cache and ag_dnc_wrapper__. We haven't found a good solution for these references yet.

@mdanatg fyi

tensorflow-copybara pushed a commit that referenced this issue Jan 8, 2020
Eager mode can incorrectly have a global graph. Disabling the global graph
in eager mode breaks too many assumptions, so first introduce a flag indicating it.

Also, avoid attaching py_function to the eager-mode global graph, which is a leak.

Though this CL doesn't fix the leak yet, as there are two more references that lead
to the leak: `tape_cache` and `ag_dnc_wrapper__`.

#35084

PiperOrigin-RevId: 288728035
Change-Id: I27c254de4323e3fcac9966294e624dda61f91cd2

loretoparisi commented Jan 13, 2020

@kkimdev @jsimsa Hello, we are having the same problem with TF 1.14.
We use tf.py_function to load a wave file:

results = (
    tf.py_function(
        self.safe_load,
        [audio_descriptor, offset, duration, sample_rate, dtype],
        (tf.float32, tf.bool)),
)
waveform, error = results[0]

putting this into a tf.data.Dataset:

dataset = dataset.map(
        lambda sample: dict(
            sample,
            **audio_adapter.load_tf_waveform(
                sample['audio_id'],
                session=session,
                sample_rate=sample_rate,
                offset=sample['start'],
                duration=sample['end'] - sample['start'])),
        num_parallel_calls=2)

and we get a leak, where the leaked memory is roughly the size of the wave file being loaded:

after prediction traced memory: 28670 KiB  peak: 28673 KiB  overhead: 29677 KiB
after load traced memory: 28801 KiB  peak: 28808 KiB  overhead: 29755 KiB
after prediction traced memory: 53988 KiB  peak: 55396 KiB  overhead: 54529 KiB
after load traced memory: 54100 KiB  peak: 55396 KiB  overhead: 54604 KiB
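
For what it's worth, per-step numbers in this shape can be collected with Python's tracemalloc module; whether something like this produced the log above is an assumption:

import tracemalloc

tracemalloc.start()

def report(tag):
    current, peak = tracemalloc.get_traced_memory()
    overhead = tracemalloc.get_tracemalloc_memory()  # memory used by tracemalloc itself
    print(f"{tag} traced memory: {current // 1024} KiB"
          f"  peak: {peak // 1024} KiB  overhead: {overhead // 1024} KiB")

# hypothetical call sites around the steps being measured
report("after load")
report("after prediction")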

goldiegadde added the comp:eager (Eager related issues) label and removed the comp:data (tf.data related issues) label Jan 14, 2020

zaccharieramzi commented Apr 8, 2020

@loretoparisi do you also create the dataset in a for loop, or do you instantiate it only once?

I am asking because I suspect a memory leak as well, but I am only creating one dataset object and then training on it using fit.
On my side, I use tf.py_function to load HDF5 files because of this error in tfio.

I would also be interested in the script you used to get the last lines of your post.
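
For context, a rough sketch of that pattern (the file pattern, the use of h5py, and the "data" key are illustrative assumptions, not the actual loader):

import h5py
import numpy as np
import tensorflow as tf

def _load_h5(path):
    # path arrives as a scalar string tensor; py_function lets us call h5py eagerly
    with h5py.File(path.numpy().decode(), "r") as f:
        return f["data"][()].astype(np.float32)

paths = tf.data.Dataset.list_files("/tmp/volumes/*.h5")  # hypothetical location
dataset = paths.map(
    lambda p: tf.py_function(_load_h5, [p], tf.float32),
    num_parallel_calls=tf.data.experimental.AUTOTUNE)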

loretoparisi commented:

@zaccharieramzi Yes, sure. The source file, which you can also try yourself in the project, is here: https://github.com/deezer/spleeter/blob/master/spleeter/dataset.py

zaccharieramzi (Contributor) commented:

@loretoparisi I am sorry but I am not sure I understand how to use the file you linked to generate the memory leak evidence.

I also don't see how it answers the question of whether you create datasets in a loop (i.e. call to tf.data.Dataset in a loop).

mdanatg added and then removed the comp:data (tf.data related issues) label Apr 9, 2020
loretoparisi commented:

> @loretoparisi I am sorry but I am not sure I understand how to use the file you linked to generate the memory leak evidence.
>
> I also don't see how it answers the question of whether you create datasets in a loop (i.e. call to tf.data.Dataset in a loop).

The memory leak shows up when you run the framework over more samples, but yes, it's complex to test since it's a specialized tool (audio separation).

zaccharieramzi (Contributor) commented:

OK, got it. But so you don't instantiate the dataset in a loop?

amahendrakar self-assigned this Jul 9, 2020
amahendrakar (Contributor) commented:

@QiJune,
Is this still an issue? Running the code with TF v2.2, I did not observe much difference in memory usage between iterations.
Please find the gist of it here. Thanks!

amahendrakar added the stat:awaiting response (Status - Awaiting response from author) label Jul 9, 2020
google-ml-butler bot commented:

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

google-ml-butler bot added the stale (This label marks the issue/pr stale - to be closed automatically if no activity) label Jul 16, 2020
google-ml-butler bot commented:

Closing as stale. Please reopen if you'd like to work on this further.

amahendrakar removed the stale and stat:awaiting response labels Jul 24, 2020
copybara-service bot pushed a commit that referenced this issue Sep 22, 2021
…quested.

The memory leak test added for #35084 in tf.data still passes.

I renamed the confusing argument name eager= to use_eager_py_func=. The argument has little to do with eager mode; it only controls which C++ op to use.

There may be a better way of fixing this. A few things to consider on top of
this CL:

- can we record the data directly onto the tape instead of in our own cache?
  It feels doable, but it is not so clear to me how to do this.

PiperOrigin-RevId: 398261944
Change-Id: I941bd9f7032ddfab209db161475c78068a17cc52