
Dataset.map() with tf.data.experimental.AUTOTUNE runs out of memory when using batch size=1 #33516

Closed
EduardoGRocha opened this issue Oct 18, 2019 · 9 comments
Labels: comp:data (tf.data related issues), TF 2.0 (Issues relating to TensorFlow 2.0), type:bug (Bug)


EduardoGRocha commented Oct 18, 2019

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): YES
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 19.04 and Ubuntu 18.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: Not tried on Mobile
  • TensorFlow installed from (source or binary): BINARY
  • TensorFlow version (use command below): 2.0.0
  • Python version: 3.7
  • Bazel version (if compiling from source): NO
  • GCC/Compiler version (if compiling from source): NO
  • CUDA/cuDNN version: CUDA 10.0 cuDNN 7.6
  • GPU model and memory: RTX2070 8GB

Describe the current behavior
I use Dataset.map to normalize images. When using tf.data.experimental.AUTOTUNE with a BATCH_SIZE of 1, memory consumption grows until the program is killed. The most intriguing part is that when BATCH_SIZE is set to a value greater than 1, the program works correctly.

This issue happens with both tensorflow 2.0.0 and tensorflow-gpu 2.0.0.

Describe the expected behavior
Code should work for a batch size of 1.

Code to reproduce the issue

import tensorflow as tf
from tensorflow.keras.layers import Flatten, Dense, Reshape
from tensorflow.keras.losses import MeanSquaredError


(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')

# Setting this to True will break the code
TOGGLE_ERROR = True
if TOGGLE_ERROR:
    BATCH_SIZE = 1
else:
    BATCH_SIZE = 3


def map_function(train_image):
    return (train_image - 127.5) / 127.5


train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.map(map_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(64)


class AutoEncoder(tf.keras.Model):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.flatten = Flatten()
        self.dense_1 = Dense(128, activation="relu")
        self.dense_2 = Dense(784, activation="relu")
        self.reshape = Reshape((28, 28, 1))

    @tf.function
    def call(self, inputs):
        flatten = self.flatten(inputs)
        encoded = self.dense_1(flatten)
        decoded = self.dense_2(encoded)
        return self.reshape(decoded)


auto_encoder = AutoEncoder()
mse = MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(1e-5)


@tf.function
def train_step(batch):
    with tf.GradientTape() as tape:
        auto_encoded = auto_encoder(batch)
        loss = mse(batch, auto_encoded)

    grads = tape.gradient(loss, auto_encoder.trainable_variables)
    optimizer.apply_gradients(zip(grads, auto_encoder.trainable_variables))
    return loss


for step, image_batch in enumerate(train_dataset):
    loss = train_step(image_batch)
    if step % 1000 == 0:
        print(loss)

Other info / logs
Might be related to Issue #32052

@chtheiss

System Information

  • Linux Ubuntu 18.04
  • TensorFlow installed from: pip install tensorflow-gpu
  • TensorFlow version: 2.0.0-beta1
  • Python version: 3.6.8
  • Bazel version (if compiling from source): NO
  • GCC/Compiler version (if compiling from source): NO
  • CUDA/cuDNN version: CUDA 10.0 cuDNN 7.6
  • GPU model and memory: GTX 1080 Ti

Describe the current behavior

I ran into the same problem as @EduardoGRocha: with batch size 1 and num_parallel_calls set to AUTOTUNE for map, the memory consumption of the program rises until the program is killed.

The problem seems to occur because the map function is called an unbounded number of times.

Below you can find a minimal example.

Code

import tensorflow as tf
import numpy as np
import time

data = np.random.normal(size=(200,32,32))
tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(data)

######################## WORKING ########################

# Batch size 1 works with any fixed num_parallel_calls
# greater than or equal to 1.
map_dataset = tensor_slice_dataset.map(
    lambda x: x*x, num_parallel_calls=1)
batch_dataset = map_dataset.batch(1)

for v in batch_dataset:
    time.sleep(0.1)

##################### ALSO WORKING #####################

# Batch size greater than 1 works with any fixed num_parallel_calls
# greater than or equal to 1, and also with tf.data.experimental.AUTOTUNE.
map_dataset = tensor_slice_dataset.map(
    lambda x: x*x, num_parallel_calls=tf.data.experimental.AUTOTUNE)
batch_dataset = map_dataset.batch(2)

for v in batch_dataset:
    time.sleep(0.1)

######################## BROKEN ########################

# Batch size 1 DOES NOT work with
# tf.data.experimental.AUTOTUNE.
map_dataset = tensor_slice_dataset.map(
    lambda x: x*x, num_parallel_calls=tf.data.experimental.AUTOTUNE)
batch_dataset = map_dataset.batch(1)

for v in batch_dataset:
    time.sleep(0.1)
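
################ POSSIBLE WORKAROUND (sketch) ################

A possible workaround sketch, not confirmed anywhere in this thread: because the lambda is elementwise, the same computation can run after batching, so the map stage never handles single-element inputs. Whether this actually sidesteps the leak is an assumption on my part.

# Batch first, then map. The elementwise lambda applies unchanged
# to a batched tensor of shape (1, 32, 32).
map_after_batch_dataset = tensor_slice_dataset.batch(1).map(
    lambda x: x * x, num_parallel_calls=tf.data.experimental.AUTOTUNE)

for v in map_after_batch_dataset:
    time.sleep(0.1)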

@rmothukuru self-assigned this on Oct 21, 2019
@rmothukuru added the labels TF 2.0 (Issues relating to TensorFlow 2.0), type:bug (Bug), and comp:data (tf.data related issues) on Oct 21, 2019
@rmothukuru (Contributor)

I could reproduce this issue with TF version 2.0 in Google Colab, with both CPU and GPU as the runtime. Here is the Gist.

@rmothukuru assigned rsepassi and unassigned rmothukuru on Oct 21, 2019
@rmothukuru added the stat:awaiting tensorflower (Status - Awaiting response from tensorflower) label on Oct 21, 2019

naturomics commented Oct 31, 2019

Found the same issue in my project.

Update:
When wrapping the tf.data.Dataset with tf.distribute.Strategy.experimental_distribute_dataset and using batch size = 1 * num_replicas (i.e., batch size 1 per replica), the problem still exists.

Sorry, I am not allowed to share the code for reproducing this problem, but my workflow is standard, like this:

batch_size = 1 * num_gpus

files = tf.data.Dataset.list_files(tfrecord_files_pattern)
if training:
    files = files.shuffle(buffer_size=1024)
dataset = files.interleave(tf.data.TFRecordDataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.repeat(1)
num_parallel_calls = tf.data.experimental.AUTOTUNE if batch_size > 1 else 8
dataset = dataset.map(parse_fn, num_parallel_calls=num_parallel_calls)
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

...
strategy = ... # tf.distribute.Strategy here
dataset = strategy.experimental_distribute_dataset(dataset)

For people who come across the same issue: it is not recommended to use tf.data.experimental.AUTOTUNE at this time (TF 2.0); manage buffer_size/num_parallel_calls yourself.
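
Applied to the original repro from @EduardoGRocha, that recommendation would look roughly like the sketch below; the fixed value 8 is an arbitrary example, not a tuned setting.

# Pin num_parallel_calls to a fixed value instead of AUTOTUNE
# whenever the batch size is 1 (the constant 8 is arbitrary).
num_parallel_calls = tf.data.experimental.AUTOTUNE if BATCH_SIZE > 1 else 8

train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.map(map_function, num_parallel_calls=num_parallel_calls)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(64)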


domsew commented Nov 11, 2019

Hello, I faced the same issue. I can confirm that the problem occurs on:

  • Windows 10, 32 GB RAM, tf version: v2.0.0-rc2-26-g64c3d382ca 2.0.0
  • Windows 7, 8 GB RAM, tf version: v2.0.0-beta1-5101-gc75bb66a99 2.0.0-rc0

The Python version is 3.7.3.
It consumes all available RAM, causing an out-of-memory error on both machines.

@naturomics

Update again: this is really a serious bug! With the pipeline from my previous comment and batch size = 1024, RAM usage crept from ~20% at the start of training to out of memory after one night of training (total memory is 48 GB). So:

  1. With batch size = 1 per GPU, the bug is triggered and memory runs out after a few training steps.
  2. With batch size > 1 per GPU, memory usage increases slowly.
  3. Without any AUTOTUNE, at any batch size: still testing (see the monitoring sketch below).
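
To watch the growth directly, here is a minimal monitoring sketch against the repro from the first comment; it assumes psutil is installed (pip install psutil), which is a third-party package and not part of TensorFlow.

import os
import psutil  # assumed third-party dependency, used only for monitoring

process = psutil.Process(os.getpid())
for step, image_batch in enumerate(train_dataset):
    train_step(image_batch)
    if step % 1000 == 0:
        # Resident set size in MiB; it should stay roughly flat if nothing leaks.
        print(step, process.memory_info().rss / 2**20)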

@rsepassi assigned jsimsa and unassigned rsepassi on Nov 14, 2019
@jsimsa (Contributor) commented Nov 14, 2019

@rachellim could you please take a look?

@jsimsa assigned rachellim and unassigned jsimsa on Nov 14, 2019
@tensorflowbutler removed the stat:awaiting tensorflower (Status - Awaiting response from tensorflower) label on Nov 14, 2019
@rachellim (Contributor)

Thanks for the repro. Looking into it.


rachellim pushed a commit to rachellim/tensorflow that referenced this issue on Nov 21, 2019:
…l_calls = autotune, batch_size = 1.

Closes tensorflow#33516.

PiperOrigin-RevId: 281775472
Change-Id: Ie10cea0ef1515d5aff8e3dddadc069ddee1a5a76
@rachellim (Contributor)

This was indeed intriguing. I've submitted a fix, which will be in TF 2.1 :)
