
Custom shuffle layer leaks memory when run on Apple M1 GPU with tensorflow-metal #60616

Closed
sirno opened this issue May 17, 2023 · 3 comments
Labels
comp:gpu GPU related issues stat:awaiting response Status - Awaiting response from author subtype:macOS macOS Build/Installation issues TF 2.12 For issues related to Tensorflow 2.12 type:bug Bug

sirno commented May 17, 2023

Issue Type

Bug

Have you reproduced the bug with TF nightly?

No

Source

binary

Tensorflow Version

2.12.0

Custom Code

Yes

OS Platform and Distribution

macOS 13.0 (22A380)

Mobile device

Apple M1

Python version

3.10.10

Bazel version

No response

GCC/Compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current Behaviour?

The following layer leaks memory during training with Keras when run on an Apple M1 GPU with tensorflow-macos and tensorflow-metal installed:

```python
class Shuffle(keras.layers.Layer):
    def call(self, inputs):
        shape = tf.concat([tf.shape(inputs)[:-1], [1]], axis=0)
        rnd = tf.argsort(tf.random.uniform(shape), axis=1)
        return tf.gather_nd(inputs, rnd, batch_dims=2)
```

When run on the CPU alone, it does not leak memory.
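The layer relies on a common trick: argsort of i.i.d. uniform noise along an axis yields an independent, uniformly random permutation for every slice along that axis. A minimal NumPy sketch of that trick follows, assuming NumPy as a stand-in for `tf.random.uniform` and `tf.argsort`. It gathers along the same axis it sorts, so it is not a line-for-line port of the layer's `tf.gather_nd(..., batch_dims=2)` call; `shuffle_along_axis` is a hypothetical helper.

```python
import numpy as np


def shuffle_along_axis(x: np.ndarray, axis: int, rng: np.random.Generator) -> np.ndarray:
    """Independently permute every slice of `x` along `axis` (hypothetical helper)."""
    noise = rng.random(x.shape)                     # i.i.d. U[0, 1) noise, same shape as x
    perm = np.argsort(noise, axis=axis)             # sorting noise gives one random permutation per slice
    return np.take_along_axis(x, perm, axis=axis)   # reorder x along the same axis


rng = np.random.default_rng(seed=0)
batch = np.arange(12).reshape(3, 4)
shuffled = shuffle_along_axis(batch, axis=1, rng=rng)  # each row is a permutation of the original row
```

Because the noise is drawn fresh on every call, each training step sees a different ordering, which is the same behavior the Keras layer gets from calling `tf.random.uniform` inside `call`.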

Standalone code to reproduce the issue

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class Shuffle(keras.layers.Layer):
    def call(self, inputs):
        # Noise shape: same as `inputs` except the last (channel) dimension is 1.
        shape = tf.concat([tf.shape(inputs)[:-1], [1]], axis=0)
        # Random orderings obtained by sorting uniform noise along axis 1.
        rnd = tf.argsort(tf.random.uniform(shape), axis=1)
        return tf.gather_nd(inputs, rnd, batch_dims=2)


def build_leaky_model(input_shape):
    inputs = keras.Input(input_shape)
    x = Shuffle()(inputs)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs)


epochs = 1000

# 200 constant samples of shape (1000, 5000, 1) with alternating 0/1 labels.
data_set = tf.data.Dataset.from_generator(
    lambda: (
        (np.ones(shape=(1000, 5000, 1), dtype=np.uint8), i % 2)
        for i in range(200)
    ),
    output_signature=(
        tf.TensorSpec(shape=(1000, 5000, 1), dtype=tf.uint8),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)

data_set = data_set.batch(10)

model = build_leaky_model((1000, 5000, 1))

model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

history = model.fit(
    data_set,
    epochs=epochs,
)
```
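To put numbers on "leaks memory" when comparing the Metal GPU run against a CPU-only run (for example, one started with `tf.config.set_visible_devices([], "GPU")` before any other TensorFlow work), one option is to log the process's peak resident set size at the end of every epoch. The following is a minimal stdlib-only sketch. `peak_rss_mib` and `PeakRssProbe` are hypothetical helpers, the `resource` module limits it to Unix-like systems, and whether memory allocated through tensorflow-metal is reflected in this figure is an unverified assumption.

```python
import resource
import sys


def peak_rss_mib() -> float:
    """Peak resident set size of this process so far, in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS but in kilobytes on Linux.
    scale = 1 if sys.platform == "darwin" else 1024
    return peak * scale / (1024 * 1024)


class PeakRssProbe:
    """Records peak RSS once per epoch (hypothetical helper, not part of Keras)."""

    def __init__(self) -> None:
        self.samples: list[float] = []

    def record(self, epoch: int, logs=None) -> None:
        # Signature matches LambdaCallback's on_epoch_end(epoch, logs).
        self.samples.append(peak_rss_mib())
        print(f"epoch {epoch}: peak RSS {self.samples[-1]:.1f} MiB")

    def growth_mib(self) -> float:
        """Increase in peak RSS between the first and the last recorded epoch."""
        if len(self.samples) < 2:
            return 0.0
        return self.samples[-1] - self.samples[0]
```

To use it, create `probe = PeakRssProbe()` and pass `callbacks=[keras.callbacks.LambdaCallback(on_epoch_end=probe.record)]` to the `model.fit` call. Because `ru_maxrss` is a high-water mark, a leak shows up as a peak that keeps climbing epoch after epoch instead of leveling off after the first few.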


Relevant log output

No response
@google-ml-butler google-ml-butler bot added the type:bug Bug label May 17, 2023
@tilakrayal tilakrayal added TF 2.12 For issues related to Tensorflow 2.12 comp:gpu GPU related issues labels May 18, 2023
@tilakrayal tilakrayal assigned SuryanarayanaY and unassigned synandi May 18, 2023
@SuryanarayanaY SuryanarayanaY added the subtype:macOS macOS Build/Installation issues label May 18, 2023
@SuryanarayanaY
Collaborator

Hi @sirno,

First, I want to clarify that tensorflow-macos is built and maintained by Apple itself.

Hence, I first tested the code with the regular TensorFlow package (tensorflow) to check whether the memory leak also occurs there. I executed the code on Colab in a GPU environment and observed no memory leak (ran up to 175 epochs); the gist is attached here for reference.

This indicates that the issue is specific to the tensorflow-macos package and should be addressed by Apple's tensorflow-metal developers. You can post the issue here.

Thanks!

@SuryanarayanaY SuryanarayanaY added the stat:awaiting response Status - Awaiting response from author label May 18, 2023
@sirno
Author

sirno commented May 19, 2023

Thanks for the response and the reference to the Apple forums.

I will resubmit the issue in the Apple forums.

@sirno sirno closed this as completed May 19, 2023