
tf.data Dataset: Warning when caching validation set. "You should use dataset.take(k).cache().repeat() instead." #61160

@munsteraner

Description


Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

source

TensorFlow version

tf 2.9.0

Custom code

Yes

OS platform and distribution

Ubuntu 22.04 LTS

Mobile device

No response

Python version

3.9.5

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

cuda_11.3.r11.3/compiler.29920130_0

GPU model and memory

NVIDIA A100-SXM4-80GB

Current behavior?

Whenever I call cache() on my tf.data validation dataset, I get the warning below. When I cache only the training dataset and leave out the validation set, no warning appears.

Standalone code to reproduce the issue

dataset = tf.data.Dataset.from_generator(
    pygen.generator, args=[files, minmax],
    output_signature=(
        tf.TensorSpec(shape=s[0], dtype=tf.float32),
        tf.TensorSpec(shape=s[1], dtype=tf.float32)))

val_dataset = tf.data.Dataset.from_generator(
    pygen.generator, args=[val_files, minmax],
    output_signature=(
        tf.TensorSpec(shape=s[0], dtype=tf.float32),
        tf.TensorSpec(shape=s[1], dtype=tf.float32)))

dataset = dataset.take(len(files)).cache().batch(args.bs).repeat(args.epochs).prefetch(tf.data.AUTOTUNE)

val_dataset = val_dataset.take(len(val_files)).cache(filename=f'{tempfile.gettempdir()}/val').batch(64).repeat(args.epochs).prefetch(tf.data.AUTOTUNE)

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
    m = unet28.build(s[0])
    m.fit(dataset, validation_data=val_dataset, epochs=args.ep,
          steps_per_epoch=spe, validation_steps=vspe,
          callbacks=[model_checkpoint_callback, save50, model_csv_logger,
                     model_tensorboard, model_earlystopping_30],
          verbose=2)

Relevant log output

2023-07-03 13:47:14.532355: W tensorflow/core/kernels/data/cache_dataset_ops.cc:296] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
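For reference, this warning is emitted whenever an iterator over a cache() stage is discarded before it has read the underlying dataset to the end, so the partial cache cannot be reused. A minimal sketch of that mechanism (a hypothetical toy pipeline, not the reporter's code) is below: reading only some elements of a cached dataset leaves the cache incomplete, while one full pass populates it so later epochs can replay it.

```python
import tensorflow as tf

# Toy cached pipeline: 10 elements, cached, then batched into 5 batches of 2.
ds = tf.data.Dataset.range(10).cache().batch(2)

# Partial read: the iterator stops after 2 of 5 batches, so the cache is
# never fully populated; TF discards it and may log the warning above.
partial = list(ds.take(2).as_numpy_iterator())

# Full read: iterating to the end populates the cache completely, so a
# repeat() or a second epoch can replay the cached elements warning-free.
full = [b.tolist() for b in ds.as_numpy_iterator()]
```

In the reported pipeline the ordering take(k).cache().repeat() is already the one the warning recommends; the warning still fires on the validation set because model.fit stops its validation iterator after validation_steps batches, leaving the cache partially read (this explanation is an assumption based on the warning's documented trigger, not confirmed in the issue).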

Labels

TF 2.9, comp:data, stale, stat:awaiting response, type:bug
