Keras has memory leak when passing in dataset object to predict(...) function #30448

Closed
davidparks21 opened this issue Jul 6, 2019 · 4 comments
Assignees
Labels
comp:keras · TF 1.14 · type:bug

Comments

davidparks21 commented Jul 6, 2019

Summary
Performance degrades quickly and memory increases consistently when calling the Keras predict function in a loop with a dataset object. This does not happen when passing predict a numpy array, or when passing in a tensor from a dataset iterator.

System information

  • Have I written custom code: No; the minimal reproducible example below uses only stock TF 1.14.0 code.
  • OS Platform and Distribution: Ubuntu 18.04 / Linux Mint 19.1
  • TensorFlow installed from (source or binary): pip install tensorflow-gpu (example not using GPU, CUDA_VISIBLE_DEVICES=-1)
  • TensorFlow version (use command below): v1.14.0-rc1-22-gaf24dc9 1.14.0
  • Python version: 3.7.3

Describe the current behavior
Calling model.predict(x=mydataset) in a continuous loop degrades in performance after a few hundred iterations. The minimal reproducible example below starts at ~0.04s per loop iteration and, within about a minute of running, is near 0.5s per iteration. Memory continues to climb.

This does not happen when passing a numpy array to model.predict(x=myndarray). The problem is also less severe when passing a tf.data.Iterator rather than a tf.data.Dataset: with an iterator, performance still degrades, but at roughly a fifth to a tenth of the rate.

The cause of the difference between the dataset performance and the iterator performance is likely training_utils.py:1314, where Keras creates a new iterator for each predict call.

The issue disappears entirely when predict is passed the tensor produced by tf.data.make_one_shot_iterator(mydataset).get_next(). In that case no additional dataset operations appear to be created by Keras in the predict loop.
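
For reference, here is a minimal sketch of that workaround, assuming the same model and dataset as the repro script below: the iterator is created once, outside the loop, and its get_next() tensor is what gets passed to predict.

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)

# Cast explicitly so the tensor dtype matches the model's float32 input.
np_data = np.random.rand(1, SIZE).astype(np.float32)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()

# Create the iterator once, outside the loop, and hand predict its get_next() tensor.
next_batch = tf.data.make_one_shot_iterator(ds).get_next()

debug_time = time.time()
while True:
    model.predict(x=next_batch, steps=1)  # no new iterator ops are added per call
    print('Processing time {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()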

Describe the expected behavior
Multiple calls to predict should not degrade in performance over time when passing in a dataset.

Code to reproduce the issue
This code reproduces the issue and is copy/paste runnable; performance degrades significantly within ~30 seconds of running it.

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()

debug_time = time.time()
while True:
    model.predict(x=ds, steps=1)
    print('Processing time {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()

This example demonstrates that passing a numpy array does not have the same issue.

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)

debug_time = time.time()
while True:
    model.predict(x=np_data)  # using numpy array directly
    print('Processing time {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()

This issue started on Stack Overflow: https://stackoverflow.com/questions/56910950/keras-predict-loop-memory-leak-using-tf-data-dataset-but-not-with-a-numpy-array

I decided to post it here when I realized that predict creates a new iterator on each loop iteration, and that the problem goes away when the get_next tensor is passed in directly.

@gadagashwini-zz gadagashwini-zz self-assigned this Jul 9, 2019
@gadagashwini-zz gadagashwini-zz added comp:keras Keras related issues type:bug Bug labels Jul 9, 2019
gadagashwini-zz (Contributor) commented:

I was able to reproduce the issue with the provided code snippet on Colab with TensorFlow 1.14.0. Thanks!

@goldiegadde goldiegadde added the TF 1.14 for issues seen with TF 1.14 label Jul 10, 2019
robieta commented Jul 18, 2019

Indeed, this is because in v1, making a dataset iterator adds ops to the graph. You can see this by printing the op count inside the loop:
print('Processing time {:.2f}, {} ops on graph'.format(time.time() - debug_time, len(inp.graph.get_operations())))

In v2 (tf.enable_v2_behavior()), you'll see that there is no accumulation of ops and run time does not increase over subsequent model.predict calls.

It's not obvious why the performance drops off so quickly, since the dataset iterator doesn't add that many ops. In any event, as long as you use 2.0 you should be fine.
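
For completeness, here is a minimal sketch of that check, assuming the repro script from the original report; tf.enable_v2_behavior() has to be called before the model is built.

import tensorflow as tf
import numpy as np
import time

# Opt in to v2 behavior (eager execution, v2 datasets) before building anything.
# Comment this line out to watch the op count grow in v1 graph mode instead.
tf.enable_v2_behavior()

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)
model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()

debug_time = time.time()
while True:
    model.predict(x=ds, steps=1)
    # In v1 graph mode the op count printed here grows on every predict() call;
    # with v2 behavior enabled it should stay flat.
    print('Processing time {:.2f}, {} ops on graph'.format(
        time.time() - debug_time, len(inp.graph.get_operations())))
    debug_time = time.time()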

@robieta robieta closed this as completed Jul 18, 2019

deadsoul44 commented:

I still have this problem with TF 2.0.
