
Enabling XLA in tensorflow 2.16 causes memory leaks #64170

Open
Di-Is opened this issue Mar 21, 2024 · 9 comments
Labels
comp:xla (XLA) · stat:awaiting tensorflower (Status - Awaiting response from tensorflower) · subtype: ubuntu/linux (Ubuntu/Linux Build/Installation Issues) · TF 2.16 · type:performance (Performance Issue)

Comments


Di-Is commented Mar 21, 2024

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

binary

TensorFlow version

tensorflow 2.16.1
(tensorflow[and-cuda] 2.16.1)

Custom code

Yes

OS platform and distribution

Linux Ubuntu 22.04

Mobile device

No response

Python version

3.12

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

12.3/8.9.7

GPU model and memory

RTX 3060 Ti, RTX 3060

Current behavior?

Executing the tf.keras.Model.fit method with XLA enabled causes a memory leak.
Note that XLA appears to be enabled by default since TensorFlow 2.16.1.
Setting tf.keras.Model.jit_compile to False disables XLA and eliminates the memory leak (see the sketch below).

I think updating to TensorFlow 2.16 will introduce memory leaks into almost all existing programs, so this problem should be fixed as soon as possible, or users should be alerted in documents such as the release notes or the README.
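For reference, a minimal sketch of the workaround (this assumes the Keras Model.compile signature, which accepts a jit_compile argument; the default "auto" is what turns XLA on):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation=None)])
# Passing jit_compile=False opts out of XLA compilation; the default
# "auto" enables XLA on supported hardware in TF 2.16 / Keras 3.
model.compile(optimizer="sgd", loss="mse", jit_compile=False)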

Standalone code to reproduce the issue

import gc

import numpy as np
import tensorflow as tf


def call(epochs: int):
    # Fit a one-layer linear model on a toy regression problem.
    x = np.arange(-1, 1, 0.01)
    y = 0.8 * x + 0.2
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation=None)])
    model.compile("sgd", "mse")
    # Uncommenting the next line disables XLA and eliminates the leak:
    # model.jit_compile = False
    model.build(input_shape=(None, 1))
    model.fit(x, y, epochs=epochs)
    del model


# Build, train for one epoch, and discard the model repeatedly; with XLA
# enabled (the default since TF 2.16), host memory grows on every iteration
# even though the session is cleared and garbage is collected.
num_iterations = 10000
for i in range(num_iterations):
    call(1)
    tf.keras.backend.clear_session()
    gc.collect()
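To watch the leak directly, the driver loop can log the process's resident set size each iteration. A minimal sketch (psutil is an extra dependency, not part of the original repro; call is the function defined above):

import gc

import psutil
import tensorflow as tf

proc = psutil.Process()
for i in range(100):
    call(1)  # the repro function defined above
    tf.keras.backend.clear_session()
    gc.collect()
    # With XLA enabled, RSS grows monotonically instead of staying flat.
    print(f"iteration {i}: rss = {proc.memory_info().rss / 2**20:.1f} MiB")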

Relevant log output

No response

@google-ml-butler google-ml-butler bot added the type:bug Bug label Mar 21, 2024
@Di-Is Di-Is changed the title Memory leak when XLA is enabled Enabling XLA in tensorflow 2.16 causes memory leaks Mar 21, 2024

Di-Is commented Mar 22, 2024

@NBCBM
Please do not post low-quality LLM-generated text.

@Venkat6871 Venkat6871 added comp:xla XLA subtype: ubuntu/linux Ubuntu/Linux Build/Installation Issues TF 2.16 labels Mar 25, 2024

Venkat6871 commented Mar 25, 2024

Hi @Di-Is ,
I tried to run your code on Colab using TF v2.16 and faced the same issue. Please find the gist here for reference.

Thank you!

@Venkat6871 Venkat6871 added the stat:awaiting response Status - Awaiting response from author label Mar 25, 2024

Di-Is commented Mar 25, 2024

@Venkat6871
Thank you for your reply!
I am using the TensorFlow build compiled for NVIDIA GPUs (tensorflow[and-cuda] 2.16.1).
There also seems to be a difference from the Python version you tested with.

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label Mar 25, 2024

Di-Is commented Mar 25, 2024

I observed memory leaks in Google Colab with a T4 GPU.


juanma9613 commented Apr 4, 2024

@Venkat6871, do you have any progress on this issue?

I'm also affected by this issue. I tried all TF versions from 2.11 up to 2.16, and it seems this has been happening since TF 2.12; it was not happening in TF 2.11.

Thank you!

@sgkouzias

I face the exact same problem, using Ubuntu 20.04, an NVIDIA RTX 3060, and Python 3.11.

@Venkat6871, do you have any progress on this issue?

@AshwinAmbal

This document from NVIDIA on tweaking environment variables for XLA memory could help. We have yet to test this out, though, and will try to keep this thread updated on any findings:
https://docs.nvidia.com/deeplearning/frameworks/tensorflow-user-guide/index.html#xla-best-practices

By the way, XLA's cluster formation, operation fusing, and compilation increase memory usage for each distinct input shape the framework encounters. If we keep the shape constant, we've found that after some amount of time the memory growth stabilizes, even with TF 2.15 (see the sketch below).
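As an illustration of the per-shape compilation behavior described above, a minimal sketch using tf.function(jit_compile=True) (not the exact code path Keras takes internally):

import tensorflow as tf

@tf.function(jit_compile=True)
def step(x):
    return tf.reduce_sum(x * x)

# XLA compiles one executable per distinct static input shape, so each
# new length here triggers (and caches) a fresh compilation:
for n in (10, 11, 12):
    step(tf.ones([n]))

# Padding every input to a single fixed shape reuses one executable:
for n in (10, 11, 12):
    step(tf.pad(tf.ones([n]), [[0, 12 - n]]))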


BobbyWilt commented Apr 17, 2024

I'm also encountering the exact same problem using Ubuntu 20.04, an NVIDIA GTX 1070, and Python 3.10. It would be great to get this fixed, since the memory leakage can get pretty astronomical if left unchecked; I've had it grow to 12 GB before it crashed my kernel.

@SuryanarayanaY (Collaborator)

Hi @Di-Is ,

I have replicated the reported memory leak with XLA; a gist is attached for reference.

@SuryanarayanaY SuryanarayanaY added type:performance Performance Issue stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed type:bug Bug labels Apr 18, 2024