
memory leak in tf.keras.Model.predict #44711

Closed
plooney opened this issue Nov 9, 2020 · 25 comments
Assignees
Labels
comp:keras (Keras related issues), stale (this label marks the issue/PR stale - to be closed automatically if no activity), stat:awaiting response (Status - Awaiting response from author), TF 2.3 (Issues related to TF 2.3), type:performance (Performance Issue)

Comments

@plooney

plooney commented Nov 9, 2020

https://stackoverflow.com/questions/64199384/tf-keras-model-predict-results-in-memory-leak

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary):
  • TensorFlow version (use command below):
  • Python version:
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

You can collect some of this information using our environment capture script. You can also obtain the TensorFlow version with:

  1. TF 1.0: python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
  2. TF 2.0: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"

Describe the current behavior

Describe the expected behavior

Standalone code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab/Jupyter/any notebook.

Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

@plooney
Author

plooney commented Nov 9, 2020

https://stackoverflow.com/questions/64199384/tf-keras-model-predict-results-in-memory-leak

Seems to be an issue with tf.keras.Model.predict

@plooney plooney changed the title from "memory leak" to "memory leak in tf.keras.Model.predict" Nov 9, 2020
@ravikyram
Contributor

I have tried this on GPU with TF 2.3 (gist) and with the nightly version 2.5.0-dev20201109 (gist), and was able to reproduce the issue. Thanks!

@ravikyram ravikyram added the comp:keras, TF 2.3, and type:performance labels and removed the type:bug label Nov 10, 2020
@jvishnuvardhan
Contributor

jvishnuvardhan commented Nov 10, 2020

@plooney model.predict is a high-level API designed for batch prediction outside of any loop. It automatically wraps your model in a tf.function and maintains graph-based execution. This means that if the input signature (shape and dtype) passed to that function (here model.predict) changes, it traces multiple graphs instead of the single graph you expect.

In your case, inImm is a numpy input, which is treated as a different signature each time you pass it inside a for loop to a function wrapped in tf.function. Providing inImm as a tensor instead results in the same input signature every time, so there is a single graph to which the inputs are fed. In the numpy case there are 60 static graphs (which is not what you want), and with that many static graphs the memory grows with each loop iteration.

When I added one line to your code, it no longer crashes. Please check the gist here. Thanks!

inImm=tf.convert_to_tensor(inImm)
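
For context, a minimal sketch of the whole fixed loop, assuming the toy model and input shapes from the Stack Overflow reproduction (the layer stack there is larger; this is only illustrative):

import numpy as np
import tensorflow as tf

inputL = tf.keras.Input([512, 512, 12])
out = tf.keras.layers.Conv2D(1, 1, padding='same')(inputL)
model = tf.keras.Model(inputs=inputL, outputs=out)

inImm = np.zeros((64, 512, 512, 12), dtype=np.float32)
inImm = tf.convert_to_tensor(inImm)  # convert once, outside the loop
for i in range(60):
    outImm = model.predict(inImm)  # same input signature every iteration -> single graph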

Please read 1, 2, 3, and 4. These resources will help you more. Thanks!

Please close the issue if this was resolved for you. Thanks!

@jvishnuvardhan jvishnuvardhan added the stat:awaiting response label Nov 11, 2020
@bhack
Contributor

bhack commented Nov 12, 2020

@jvishnuvardhan Predict in a loop is quite a recurrent issue; I remember triaging two of these tickets just a few weeks ago.
How could we better expose this in the docs? /cc @lamberta @MarkDaoust

@jvishnuvardhan
Contributor

@bhack Good point. I think updating one of the docs (tutorials/guides) would help resolve this kind of issue. Thanks!

@bhack
Contributor

bhack commented Nov 12, 2020

@jvishnuvardhan Yes, we need to find a fairly popular entry point in the docs; perhaps an internal team member has some stats on docs page views.

@bhack
Contributor

bhack commented Nov 12, 2020

Also, a more specific "entry level" warning (instead of the generic function-retracing one) could be very useful for newcomers.

@MarkDaoust
Member

@jvishnuvardhan thanks for the clear explanation. If this were calling a tf.function in a loop, that would be 100% the correct answer. But Model.predict manages some of this to avoid that problem (in general, Keras fit/evaluate/predict never require the user to convert inputs to tensors). It looks like something more complicated is happening.

The first two clues that suggest this are:

  1. It's not printing the frequent-retracing warning.
  2. You're creating a single numpy array and passing it multiple times, and it still goes OOM. Except for constants, the caching logic is based on object identity, so reusing the object should reuse the same function trace.

Investigating a little further, you can find that model.predict_function is the tf.function that runs here.
Inspecting it, both its ._list_all_concrete_functions() and .pretty_printed_concrete_signatures() show that there is only one graph, and that predict is handling the conversion of the numpy array to a Tensor.

So I agree that this is leaking memory somewhere. But I've confirmed that it's not the tf.function cache causing it.
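
For reference, a minimal sketch of that inspection (model.predict_function and ._list_all_concrete_functions() are internal details and may differ between TF versions; the small model here is only illustrative):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
x = np.zeros((256, 4), dtype=np.float32)

for _ in range(3):
    model.predict(x)  # repeated numpy input, as in the report

# model.predict_function is the tf.function Keras built for predict;
# a single printed signature means the trace cache is not growing.
print(model.predict_function.pretty_printed_concrete_signatures())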

@tomerk, you're pretty familiar with this code, do you have any ideas?

@bhack
Contributor

bhack commented Nov 13, 2020

I've just modified the code from the Stack Overflow question mentioned above to run in Colab:

!pip install --upgrade memory_profiler
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input,Conv2D, Activation

%load_ext memory_profiler

matrixSide = 512 #define a big enough matrix to give memory issues

inputL = Input([matrixSide,matrixSide,12]) #create a toy model
l1 = Conv2D(32,3,activation='relu',padding='same') (inputL) #120
l1 = Conv2D(64,1,activation='relu',padding='same')(l1)
l1 = Conv2D(64,3,activation='relu',padding='same')(l1)
l1 = Conv2D(1,1,padding='same')(l1)
l1 = Activation('linear')(l1)
model = tf.keras.Model(inputs= inputL,outputs = l1)


#run predictions
inImm = np.zeros((64,matrixSide,matrixSide,12))
for i in range (60):
  print(i)
  %memit outImm = model.predict(inImm)

@bhack
Contributor

bhack commented Nov 13, 2020

peak memory: 5729.88 MiB, increment: 3190.20 MiB
1
peak memory: 6498.93 MiB, increment: 769.06 MiB
2
peak memory: 6499.14 MiB, increment: 0.21 MiB
3
peak memory: 6499.14 MiB, increment: 0.00 MiB
4
peak memory: 6499.14 MiB, increment: 0.00 MiB
5
peak memory: 6499.15 MiB, increment: 0.00 MiB
6
peak memory: 6499.17 MiB, increment: 0.02 MiB
7
peak memory: 6499.17 MiB, increment: 0.00 MiB
8
peak memory: 6499.17 MiB, increment: 0.00 MiB
9
peak memory: 6499.17 MiB, increment: 0.00 MiB
10
peak memory: 6499.17 MiB, increment: 0.00 MiB
11
peak memory: 6499.17 MiB, increment: 0.00 MiB
12
peak memory: 6499.18 MiB, increment: 0.01 MiB
13
peak memory: 6499.18 MiB, increment: 0.00 MiB
14
peak memory: 6499.19 MiB, increment: 0.00 MiB
15
peak memory: 6499.19 MiB, increment: 0.00 MiB
16
peak memory: 6499.19 MiB, increment: 0.00 MiB
17
peak memory: 6499.19 MiB, increment: 0.00 MiB
18
peak memory: 6499.19 MiB, increment: 0.00 MiB
19
peak memory: 6499.19 MiB, increment: 0.00 MiB
20
peak memory: 6499.19 MiB, increment: 0.00 MiB
21
peak memory: 6499.19 MiB, increment: 0.00 MiB
22
peak memory: 6499.19 MiB, increment: 0.00 MiB
23
peak memory: 6499.19 MiB, increment: 0.00 MiB
24
peak memory: 6499.19 MiB, increment: 0.00 MiB
25
peak memory: 6499.19 MiB, increment: 0.00 MiB
26
peak memory: 6499.19 MiB, increment: 0.00 MiB
27
peak memory: 6499.19 MiB, increment: 0.00 MiB
28
peak memory: 6499.19 MiB, increment: 0.00 MiB
29
peak memory: 6499.19 MiB, increment: 0.00 MiB
30
peak memory: 6499.19 MiB, increment: 0.00 MiB
31
peak memory: 6499.19 MiB, increment: 0.00 MiB
32
peak memory: 6499.19 MiB, increment: 0.00 MiB
33
peak memory: 6499.19 MiB, increment: 0.00 MiB
34
peak memory: 6499.19 MiB, increment: 0.00 MiB
35
peak memory: 6499.19 MiB, increment: 0.00 MiB
36
peak memory: 6499.19 MiB, increment: 0.00 MiB
37
peak memory: 6499.19 MiB, increment: 0.00 MiB
38
peak memory: 6499.19 MiB, increment: 0.00 MiB
39
peak memory: 6499.19 MiB, increment: 0.00 MiB
40
peak memory: 6499.19 MiB, increment: 0.00 MiB
41
peak memory: 6499.19 MiB, increment: 0.00 MiB
42
peak memory: 6499.19 MiB, increment: 0.00 MiB
43
peak memory: 6499.19 MiB, increment: 0.00 MiB
44
peak memory: 6499.19 MiB, increment: 0.00 MiB
45
peak memory: 6499.19 MiB, increment: 0.00 MiB
46
peak memory: 6499.19 MiB, increment: 0.00 MiB
47
peak memory: 6499.19 MiB, increment: 0.00 MiB
48
peak memory: 6499.19 MiB, increment: 0.00 MiB
49
peak memory: 6499.20 MiB, increment: 0.00 MiB
50
peak memory: 6499.20 MiB, increment: 0.00 MiB
51
peak memory: 6499.20 MiB, increment: 0.00 MiB
52
peak memory: 6499.20 MiB, increment: 0.00 MiB
53
peak memory: 6499.20 MiB, increment: 0.00 MiB
54
peak memory: 6499.20 MiB, increment: 0.00 MiB
55
peak memory: 6499.20 MiB, increment: 0.00 MiB
56
peak memory: 6499.20 MiB, increment: 0.00 MiB
57
peak memory: 6499.20 MiB, increment: 0.00 MiB
58
peak memory: 6499.20 MiB, increment: 0.00 MiB
59
peak memory: 6499.20 MiB, increment: 0.00 MiB

@MarkDaoust
Member

MarkDaoust commented Nov 13, 2020

It doesn't OOM with memory_profiler running because %memit calls gc.collect:

https://github.com/pythonprofilers/memory_profiler/blob/bd6da910f791cef725640fa207fb4ee433c93586/memory_profiler.py#L1060

If you add a gc.collect() call there, it continues as well. If you monkey-patch gc.collect out, you get the explosive memory growth.

So that points towards an issue with cyclical-garbage not getting cleaned up fast enough.
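
A minimal sketch of that workaround, assuming the reproduction above (shapes shrunk here to keep the sketch light): forcing a gc.collect() after each predict keeps the cyclic garbage from piling up.

import gc
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input([128, 128, 12])
outputs = tf.keras.layers.Conv2D(1, 1, padding='same')(inputs)
model = tf.keras.Model(inputs, outputs)

inImm = np.zeros((16, 128, 128, 12), dtype=np.float32)
for i in range(60):
    outImm = model.predict(inImm)
    gc.collect()  # explicitly collect cyclic garbage each iteration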

@bhack
Contributor

bhack commented Nov 13, 2020

So that points towards an issue with cyclical-garbage not getting cleaned up fast enough.

Yes, it was one of the suspects.

@bhack
Contributor

bhack commented Nov 13, 2020

Also, calling the model directly fails already at step 0:
outImm = model(inImm)

ResourceExhaustedError: OOM when allocating tensor with shape[64,512,512,64] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:BiasAdd]

@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Nov 20, 2020
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.

@tomerk tomerk reopened this Dec 1, 2020
@yichenj

yichenj commented Mar 17, 2021

So? I understand that the GPU OOM above happens because the GPU memory is too small, which means it is not the same issue as the leakage problem. The bug I hit is that after 7 hours of running model.predict() in a loop, there is a GPU OOM.

@herossa

herossa commented Oct 17, 2022

My attempts:

2.4.1: leak
2.7.1: leak

How the problem occurred for me:

def predict(self, img: np.ndarray) -> np.ndarray:
    return self._model.predict(np.expand_dims(img, axis=0))

How I solved it:

def predict(self, img: np.ndarray) -> np.ndarray:
    return self._model(tf.convert_to_tensor(np.expand_dims(img, axis=0)), training=False).numpy()

@Dobiasd

Dobiasd commented Oct 18, 2022

@jvishnuvardhan Predict in a loop is quite a recurrent issue; I remember triaging two of these tickets just a few weeks ago. How could we better expose this in the docs? /cc @lamberta @MarkDaoust

Lol, I've been fighting memory-leak problems in multiple TensorFlow services in PROD for years and have implemented different things, like watchers that check memory usage so we can gracefully restart our workers before they OOM-crash in a job, and adding tf.config.threading.set_inter_op_parallelism_threads(1); tf.config.threading.set_intra_op_parallelism_threads(1) to reduce the amount of leakage, etc.

Just yesterday I finally discovered this. Maybe one can prevent future users like me from wasting so much time/energy on this by adjusting "What's the difference between Model methods predict() and __call__()?" in the Keras FAQ, which currently recommends using the memory-leaking way of doing predictions:

You should use model(x) when you need to retrieve the gradients of the model call, and you should use predict() if you just need the output value. In other words, always use predict() unless you're in the middle of writing a low-level gradient descent loop (as we are now).

🙂

@lukeconibear

lukeconibear commented Jan 6, 2023

In case this helps:

If the dataset can fit in memory, then the following functions can replace the call to model.predict:

from __future__ import annotations

from collections.abc import Iterator

import numpy as np
import tensorflow as tf


def generate_batches(
    x: np.ndarray | tf.Tensor, batch_size: int = 32
) -> Iterator[np.ndarray | tf.Tensor]:
    """Generate batches of test data for inference.

    Args:
        x (np.ndarray | tf.Tensor):
            Full test data set.
        batch_size (int, default=32):
            Batch size.

    Yields:
        np.ndarray | tf.Tensor:
            Yielded batches of test data.
    """
    for index in range(0, x.shape[0], batch_size):
        yield x[index : index + batch_size]


def predict(
    model: tf.keras.Model,
    x: np.ndarray | tf.Tensor,
    batch_size: int = 32,
) -> np.ndarray:
    """Predict using generated batched of test data.

    - Used instead of model.predict() due to memory leaks.
    - https://github.com/tensorflow/tensorflow/issues/44711

    Args:
        model (tf.keras.Model):
            The model to use for prediction.
        x (np.ndarray | tf.Tensor):
            Full test data set.
        batch_size (int, default=32):
            Batch size.

    Returns:
        np.ndarray:
            Predictions on the test data.
    """
    y_batches = []
    for x_batch in generate_batches(x=x, batch_size=batch_size):
        y_batch = model(x_batch, training=False).numpy()
        y_batches.append(y_batch)

    return np.concatenate(y_batches)


# instead of
# y_pred = model.predict(x_test)

# use
y_pred = predict(model=model, x=x_test, batch_size=32)

Otherwise, if the dataset does not fit in memory, consider using tf.data:

# Note: set_random_seed, RANDOM_SEED, number_of_samples, and AUTOTUNE
# (tf.data.AUTOTUNE) are assumed to be defined elsewhere in the module.
def create_tf_dataset(
    data_split: str,
    x: np.ndarray,
    y: np.ndarray,
    batch_size: int,
    use_mixed_precision: bool,
) -> tf.data.Dataset:
    """Create a TensorFlow dataset.

    - Cache train data before shuffling for performance (consider full dataset size).
    - Shuffle train data to increase accuracy (not needed for validation or test data).
    - Batch train data after shuffling for unique batches at each epoch.
    - Cache test data after batching as batches can be the same between epochs.
    - End pipeline with prefetching for performance.
    
    Args:
        data_split (str):
            The data split to create the dataset for.
            Supported are "train", "validation", and "test".
        x (np.ndarray):
            The feature data.
        y (np.ndarray):
            The target data.
        batch_size (int):
            The batch size.
        use_mixed_precision (bool):
            Whether to use mixed precision.

    Raises:
        ValueError: If the data split is not supported.

    Returns:
        tf.data.Dataset:
            The TensorFlow dataset.
    """
    if data_split not in {"train", "validation", "test"}:
        raise ValueError(f"Invalid data split: {data_split}")

    if use_mixed_precision:
        tf.keras.mixed_precision.set_global_policy("mixed_float16")
        x = x.astype(np.float16)
        y = y.astype(np.float16)

    ds = tf.data.Dataset.from_tensor_slices((x, y))

    if data_split == "train":
        ds = ds.cache()
        set_random_seed(seed=RANDOM_SEED)
        ds = ds.shuffle(number_of_samples, seed=RANDOM_SEED)
        ds = ds.batch(batch_size)
    else:
        ds = ds.batch(batch_size)
        ds = ds.cache()

    ds = ds.prefetch(AUTOTUNE)

    return ds


# need to do this call separately on a machine with enough memory
ds_test = create_tf_dataset(
    data_split="test",
    x=x_test,
    y=y_test,
    batch_size=32,
    use_mixed_precision=True,
)

# then use it
y_pred = model.predict(ds_test)

@CryptoSurff

@lukeconibear worked for me, thank you!

@MichaelBreslavskyNintex

I had the same issue with .predict: running on about 50,000 inputs over several hours, I saw a leak of around 0.35 GB. I traced the leak back to the .predict method. Replacing it with the __call__ method solved the memory leak but was about 50% slower.
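
One possible way to recover some of that speed, anticipating the tf.function approach mentioned further down the thread, is to compile the direct call once and batch manually; a rough sketch (model, shapes, and batch size here are illustrative, not from the original report):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(16,))])

@tf.function
def forward(batch):
    # Direct call compiled once; avoids predict()'s per-call setup overhead.
    return model(batch, training=False)

x = np.zeros((50_000, 16), dtype=np.float32)
batch_size = 1024
# A final partial batch triggers one extra trace, which is harmless here.
preds = [forward(tf.convert_to_tensor(x[i:i + batch_size])).numpy()
         for i in range(0, len(x), batch_size)]
y = np.concatenate(preds)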

nhuet added a commit to nhuet/decomon that referenced this issue Nov 2, 2023
This is inefficient and can lead to memory leaks.
See https://keras.io/api/models/model_training_apis/#predict-method and
tensorflow/tensorflow#44711

The issue even leads to crashes in the test suite on GitHub for Keras 3.0
(maybe also because of the TensorFlow version used)
nhuet added a commit to nhuet/decomon that referenced this issue Nov 2, 2023
The idea is to avoid calling predict(), which is known to be
not designed for small arrays and leads to memory leaks when used in loops.

See https://keras.io/api/models/model_training_apis/#predict-method and
tensorflow/tensorflow#44711

Use it in the wrapper instead of predict().
@irreg

irreg commented Apr 12, 2024

Switching to the __call__ function significantly reduced the amount of leakage, but it still leaked about 70 bytes per __call__ in my environment.
Finally, converting the Keras model to a bare TensorFlow graph seems to have eliminated the leakage in my environment.

import numpy as np
import tensorflow as tf

model: tf.keras.Model  # an already-built Keras model
x: np.ndarray          # input data

graph = tf.function(model)  # wrap the model's forward pass as a single tf.function graph
# When processing large data, add logic to divide it into small batches.
result = graph(tf.convert_to_tensor(x))
