
tf Dataset / Iterator console flood when using CUDA builds #12414

Closed
georgesterpu opened this issue Aug 19, 2017 · 48 comments
Labels
stat:awaiting tensorflower (Status - Awaiting response from tensorflower), type:bug (Bug)

@georgesterpu
Contributor

georgesterpu commented Aug 19, 2017

System information

  • Have I written custom code: Yes
  • OS Platform and Distribution: Manjaro Linux, Arch Linux repo
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version: 1.3.0
  • Python version: 3.6.2
  • Bazel version:
  • CUDA/cuDNN version: CUDA 8.0.61, cuDNN 7.0.1 & cuDNN 6.0.21
  • GPU model and memory: Nvidia GTX 1080, 8 GB
  • Exact command to reproduce:

The problem

When using a tensorflow wheel built with cuda support, my app prints the following warning message at the end of a training epoch:

2017-08-19 14:01:18.214060: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?,132], [?], [?,?], [?]], output_types=[DT_FLOAT, DT_INT64, DT_INT64, DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]

The code trains a seq2seq model, and I assume the message gets printed somewhere downstream of seq2seq.dynamic_decode. The message still gets printed even when the NN cells are not wrapped in a tf.contrib.rnn.DeviceWrapper whose device field points at a GPU; it only works fine (no warning) on non-CUDA builds.

All of this happens even though the code is protected with a try/except statement:

for epoch in range(num_epochs):
    session.run(iterator.initializer)
    while True:
        try:
            session.run([operation])
        except tf.errors.OutOfRangeError:
            break

Now the only cheeky thing is that I am using the binaries from Arch Linux repositories, but these are far from being dodgy.
python-tensorflow
python-tensorflow-cuda
The build script

This problem was present in TensorFlow 1.2 and persists in TensorFlow 1.3.
I also tested on a laptop without a dedicated GPU, but with the same OS and packages, and everything works fine.
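
A self-contained sketch of the kind of pipeline and loop that triggers the warning (toy in-memory data in place of the real TFRecord pipeline; tf.data naming from TF 1.4, tf.contrib.data in 1.2/1.3):

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(list(range(10))).batch(4)
iterator = dataset.make_initializable_iterator()
operation = iterator.get_next()

with tf.Session() as session:
    for epoch in range(2):
        session.run(iterator.initializer)
        while True:
            try:
                session.run([operation])
            except tf.errors.OutOfRangeError:
                # On CUDA builds, this is the point at which the
                # "Out of range: End of sequence" warning gets logged.
                break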

@ebrevdo
Contributor

ebrevdo commented Aug 19, 2017

This is a side effect of using tf.contrib.data iterators with small batches (as you are doing for inference). @mrry, is there any way to disable the logging in a configurable way?

@ebrevdo
Contributor

ebrevdo commented Aug 19, 2017

(nothing to do with dynamic_rnn).

@georgesterpu
Contributor Author

Hi @ebrevdo, by small batches do you refer to the size of most batches or the size of the last one?
The warning is there in training too, for apparently any batch size, including the case where the batch size evenly divides the number of training examples.

I thought it had something to do with dynamic_decode because simply iterating through the dataset and printing it to the console, for example, works fine.

@mrry
Contributor

mrry commented Aug 19, 2017

How is the dataset defined?

I think there's an issue in the 1.3 release when you have a Dataset.map() that contains a tf.py_func() that ultimately raises StopIteration: essentially it gets reported as an errors::OutOfRange() and logged, rather than setting end_of_sequence = true, which exits silently. (I recently noticed and fixed that bug in an internal branch, and the fix should appear in the nightlies soon.)
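
A contrived sketch of the pattern described above (tf.data naming; the _next_from_side_source helper is hypothetical): a tf.py_func inside Dataset.map whose Python body raises StopIteration, which in the affected releases surfaced as the logged out-of-range error rather than a silent end-of-sequence.

import numpy as np
import tensorflow as tf

side_source = iter([np.float32(1.0), np.float32(2.0)])  # exhausts after two pulls

def _next_from_side_source(x):
    # next() raises StopIteration once side_source is exhausted; inside a
    # tf.py_func this comes back as an out-of-range error from the map stage.
    return x * next(side_source)

dataset = tf.data.Dataset.from_tensor_slices(np.arange(5, dtype=np.float32))
dataset = dataset.map(lambda x: tf.py_func(_next_from_side_source, [x], tf.float32))
next_element = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(next_element))
        except tf.errors.OutOfRangeError:
            break  # hit after two elements instead of five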

@georgesterpu
Contributor Author

    dataset = tf.contrib.data.TFRecordDataset(record)
    dataset = dataset.map(lambda p: _parse_function(p, feature_size))

_parse_function is a Python function; does tf.py_func get called internally?
The problem also existed in TF 1.2, CUDA builds only.

@mrry
Contributor

mrry commented Aug 19, 2017

What's the source of _parse_function()?

@georgesterpu
Contributor Author

_parse_function takes an Example and an int, applies tf.parse_single_sequence_example, and returns a tuple of one tf.float32 and three tf.int64 tensors (the context and the features)

@ebrevdo
Contributor

ebrevdo commented Aug 20, 2017 via email

@georgesterpu
Contributor Author

georgesterpu commented Aug 20, 2017

I get it at the end of each training epoch.
If using BeamSearchDecoder with attention, it gets printed multiple times, depending on the beam width.

One workaround that I could find is stopping the loop after the last batch, if I know their number in advance (instead of using the infinite loop and stopping on the exception).

Here is the content of my _parse_function(); I couldn't make it work as a tf.py_func, though, with the tuple trick indicated by @mrry:

def _parse_function(example, feature_size):

    context_features = {
        "input_length": tf.FixedLenFeature(shape=[], dtype=tf.int64),
        "labels_length": tf.FixedLenFeature(shape=[], dtype=tf.int64),
    }
    sequence_features = {
        "inputs": tf.FixedLenSequenceFeature(shape=[feature_size], dtype=tf.float32),
        "labels": tf.FixedLenSequenceFeature(shape=[], dtype=tf.int64)
    }

    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    return sequence_parsed["inputs"], context_parsed["input_length"], \
           sequence_parsed["labels"], context_parsed["labels_length"]

@ali01 ali01 added stat:awaiting tensorflower Status - Awaiting response from tensorflower type:support Support issues labels Aug 28, 2017
@rightaditya

I can confirm the same issue on a couple of different Ubuntu versions (and some unknown Linux distro), Python 3.5 and 3.6, TensorFlow 1.2 and 1.3, cuDNN 5.1 and cuDNN 6, and on Tesla P100, GeForce GTX 1070 and Pascal Titan X.

I get the message at the end of each epoch, on both training and dev set. Here's my dataset definition:

def setup_dataset(filename: str, batch_size: int, shuffle_seed=None):
    n_recs, labels = tfrecord_metadata(filename)
    dataset = tfdata.TFRecordDataset(filename)
    dataset = dataset.skip(1)  # I store some metadata in the first record
    dataset = dataset.map(parse_label_sequence_example if label_seq
                          else parse_example)
    if shuffle_seed is not None:
        dataset = dataset.shuffle(n_recs, shuffle_seed)

    dataset = dataset.map(lambda x, y: (x, y, build_mask(y)))
    padded_shapes = (tf.TensorShape([None]), tf.TensorShape([None],),
                     tf.TensorShape([None]))
    dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes)

    return dataset, labels

I use a batch size of 8, and shuffle the training set but not the dev set (but the warning shows up in all cases).

@georgesterpu
Contributor Author

What happens to your data once it is fetched and parsed by the iterator?

@rightaditya

How much do you need? Here's how it starts:

shuf_seed = tf.placeholder(tf.int64, shape=[])
trn_data, n_cat = setup_dataset(args.train, args.batch_size, shuf_seed)
dev_data, _ = setup_dataset(args.dev, args.batch_size)

it = tfdata.Iterator.from_structure(trn_data.output_types,
                                    trn_data.output_shapes)
x, y, y_mask = it.get_next()
y = tf.reshape(y, [-1])
y_mask = tf.cast(tf.reshape(y_mask, [-1]), tf.float32)
temporal_padding = args.filter_size[0] - 1
t_pad_before = temporal_padding // 2
t_pad_after = temporal_padding - t_pad_before
x = tf.pad(x, [[0, 0], [t_pad_before, t_pad_after]])
trn_init = it.make_initializer(trn_data)
dev_init = it.make_initializer(dev_data)
<snip>
emb_layer = tf.Variable(embeddings, trainable=args.trainable_embeddings,
                        name='embedding_matrix')
x_embedded = tf.nn.embedding_lookup(emb_layer, x)

@georgesterpu
Contributor Author

Never mind, I made a wrong assumption. Looking at the first comment of @mrry, the issue might be fixed by now in tf nightly 1.4.

@rightaditya

I'd be willing to test it out---I'm not sure what the git model is that's used for TF, but if I compile from master and test that, will that suffice?

@georgesterpu
Contributor Author

georgesterpu commented Sep 20, 2017 via email

@rightaditya

This issue is specific to the GPU builds, and the TF README.md says the nightlies are CPU builds only---has that changed?

@georgesterpu
Contributor Author

georgesterpu commented Sep 20, 2017 via email

@rightaditya

I just cloned master yesterday, compiled it, and gave it a shot, but the error still shows up for me with that latest build. Not sure how it was for @georgesterpu, but I actually get multiple copies of the message:

2017-09-23 10:21:59.360135: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence
	 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?], [?,?], [?,?]], output_types=[DT_INT64, DT_INT64, DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]
2017-09-23 10:21:59.360135: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence
	 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?], [?,?], [?,?]], output_types=[DT_INT64, DT_INT64, DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]
2017-09-23 10:21:59.360160: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence
	 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?], [?,?], [?,?]], output_types=[DT_INT64, DT_INT64, DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]
2017-09-23 10:21:59.360195: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence
	 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?], [?,?], [?,?]], output_types=[DT_INT64, DT_INT64, DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]

On my dev set, which is smaller, I just get one:

2017-09-23 10:22:01.758969: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence
	 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?], [?,?], [?,?]], output_types=[DT_INT64, DT_INT64, DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator)]]

@fumihwh
Contributor

fumihwh commented Sep 29, 2017

I met the same problem and fixed it as follows:
tr_data = Dataset.from_tensor_slices((train_images, train_labels)).repeat().batch(100)
Just add .repeat().
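
A minimal TF 1.x sketch of what .repeat() changes (toy numpy arrays stand in for train_images/train_labels; this is an illustration, not the code above): with repeat, the iterator cycles forever and get_next() never reaches end-of-sequence, so nothing is logged, but the training loop then has to be bounded explicitly.

import numpy as np
import tensorflow as tf

train_images = np.zeros([1000, 28, 28], dtype=np.float32)
train_labels = np.zeros([1000], dtype=np.int64)

tr_data = (tf.data.Dataset.from_tensor_slices((train_images, train_labels))
           .repeat()    # cycle indefinitely: no OutOfRange, hence no warning
           .batch(100))
images, labels = tr_data.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for step in range(30):  # bound the loop yourself instead of catching OutOfRangeError
        sess.run([images, labels])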

@rightaditya

@fumihwh That won't work for me, as I need to know when the epoch is done.

@fumihwh
Contributor

fumihwh commented Oct 5, 2017

@rightaditya How about the following, together with .repeat()? I understand your problem now and use this to avoid it. If I don't use it, I get the same error log as yours.

for epoch in range(num_epochs):
    session.run(iterator.initializer)
    for _ in range(dataset_size // batch_size):  # integer division; assumes batch_size evenly divides dataset_size
        session.run([operation])

@georgesterpu
Contributor Author

I have just cloned the master repo and compiled it from source, this time with cuDNN 7 :)
The warning is still there.

@rightaditya

@fumihwh Is your dataset size divisible by your batch size? Mine isn't, so I'd probably have to ceil the division (and make sure to use float division if using python2).

Still, it's not an ideal solution given the way the API seems to have been designed.
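
A minimal sketch of that ceiling division, reusing the names from the earlier snippets (dataset_size, batch_size, num_epochs, iterator, session and operation are assumed to be defined as above):

import math

# The final, smaller batch still costs one get_next() call, so round up.
steps_per_epoch = math.ceil(dataset_size / batch_size)  # Python 3 true division
# Integer-only equivalent that also works on Python 2: -(-dataset_size // batch_size)

for epoch in range(num_epochs):
    session.run(iterator.initializer)
    for _ in range(steps_per_epoch):
        session.run([operation])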

@ybsave

ybsave commented Oct 13, 2017

@fumihwh Is the .repeat() function only suitable for training? I suppose it's not suitable for testing, as I only want to evaluate once. Will make_one_shot_iterator() run through the dataset only once even if my dataset size is NOT divisible by my batch size? In my test, the evaluation seems to go into an infinite loop when using .repeat().

I encountered the same error as described by previous users during testing, but not in training. During evaluation, there are thousands of "Out of range: End of sequence" errors (I guess the number of errors is the same as the number of my evaluation samples). But correct evaluation results are still printed out after those errors, and the program does not crash and can still continue training. Does anyone know the reason and how to fix it? Thank you.

I used estimator, and the input functions are

def input_fn(mode):
  """Input function which provides a single batch for train or eval."""
  dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames(mode))
  if mode == tf.estimator.ModeKeys.TRAIN:
    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)

  if mode == tf.estimator.ModeKeys.TRAIN:
    dataset = dataset.repeat()

  dataset = dataset.map(lambda value: dataset_parser(value, mode),
                        num_threads=FLAGS.map_threads,
                        output_buffer_size=FLAGS.batch_size)

  if mode == tf.estimator.ModeKeys.TRAIN:
    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)

  iterator = dataset.batch(FLAGS.batch_size).make_one_shot_iterator()
  images, labels = iterator.get_next()
  return images, labels

My model function is

_DEVICE_LIST = ['/gpu:0', '/gpu:1']
def imagenet_model_fn(features, labels, mode):
  """Our model_fn for ResNet to be used with our Estimator."""
  tf.summary.image('images', features, max_outputs=6)

  with tf.device('/cpu:0'):
    split_batch = tf.split(features, len(_DEVICE_LIST))
    split_labels = tf.split(labels, len(_DEVICE_LIST))

    if mode == tf.estimator.ModeKeys.TRAIN:
      global_step = tf.train.get_or_create_global_step()

      # Multiply the learning rate by 0.1 at 30, 60, 120, and 150 epochs.
      boundaries = [
          int(batches_per_epoch * epoch) for epoch in [30, 60, 120, 150]]
      values = [
          _INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 1e-3, 1e-4]]
      learning_rate = tf.train.piecewise_constant(
          tf.cast(global_step, tf.int32), boundaries, values)

      # Create a tensor named learning_rate for logging purposes.
      tf.identity(learning_rate, name='learning_rate')
      tf.summary.scalar('learning_rate', learning_rate)

      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate,
          momentum=_MOMENTUM)

    tower_grads = []
    tower_cross_entropy = []
    tower_reg_loss = []
    tower_preds = []

    with tf.variable_scope(tf.get_variable_scope()):
      for dev_idx, (device, device_features, device_labels) in enumerate(zip(
          _DEVICE_LIST, split_batch, split_labels)):
        with tf.device(device):
          with tf.name_scope('device_%d' % dev_idx):
            logits = network(inputs=device_features,
                             is_training=(mode == tf.estimator.ModeKeys.TRAIN))

            tf.get_variable_scope().reuse_variables()

            tower_pred = {
                'classes': tf.argmax(logits, axis=1),
                'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
            }
            tower_preds.append(tower_pred)

            cross_entropy = tf.losses.softmax_cross_entropy(
                logits=logits, onehot_labels=device_labels)
            tower_cross_entropy.append(cross_entropy)

            reg_loss = FLAGS.weight_decay / len(_DEVICE_LIST) * tf.add_n(
                [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
            tower_reg_loss.append(reg_loss)

            loss = cross_entropy + reg_loss

            if mode == tf.estimator.ModeKeys.TRAIN:
              grads = optimizer.compute_gradients(loss)
              tower_grads.append(grads)

    predictions = {
        'classes': tf.concat([p['classes'] for p in tower_preds], axis=0),
        'probabilities': tf.concat([p['probabilities'] for p in tower_preds], axis=0)
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
      return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    cross_entropy = tf.add_n(tower_cross_entropy)
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    reg_loss = tf.add_n(tower_reg_loss)
    tf.summary.scalar('reg_loss', reg_loss)

    loss = cross_entropy + reg_loss
    tf.summary.scalar('total_loss', loss)

    accuracy = tf.metrics.accuracy(
        tf.argmax(labels, axis=1), predictions['classes'])
    metrics = {'accuracy': accuracy}

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf.identity(accuracy[1], name='train_accuracy')
      tf.summary.scalar('train_accuracy', accuracy[1])

      # Batch norm requires update_ops to be added as a train_op dependency.
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads, global_step=global_step)
    else:
      train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics)
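
For context, a sketch of how these two functions would typically be wired into an Estimator (the model_dir path and step count are placeholders, not taken from the actual run):

estimator = tf.estimator.Estimator(model_fn=imagenet_model_fn, model_dir='/tmp/imagenet_model')

# Training uses the repeated dataset, so the number of steps has to be given explicitly.
estimator.train(input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN), steps=1000)

# Evaluation has no .repeat(), so it simply runs until the one-shot iterator
# raises OutOfRange, which is exactly what produces the logged warnings.
estimator.evaluate(input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL))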

@mmuneebs

mmuneebs commented Oct 15, 2017

I get flooded the same way using today's tf-nightly-gpu wheel binary on Ubuntu 16.04, Python 2.7, cudnn 6.

@post2web

I get the same warning with tensorflow 1.3

@monomon

monomon commented Oct 18, 2017

I am experiencing this with tensorflow-gpu on Windows 10, Python 3.6.1, cudnn64_6, using tf.slim.

Still need to verify this, but it seems the message only comes up when training for the first time, without a checkpoint. The dataset size is divisible by the batch size.

@thjashin
Contributor

Same problem with TF1.3

@knn1989

knn1989 commented Nov 1, 2017

Same problem. I got a bunch of error logs at the end of each epoch. Very annoying.

"2017-10-31 18:59:22.473230: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence [[Node: IteratorGetNext = IteratorGetNextoutput_shapes=[[?,120,120,3], [?,120,120,1]], output_types=[DT_FLOAT, DT_UINT8], _device="/job:localhost/replica:0/task:0/cpu:0"]]"

@JeremyCCHsu

Same problem with TF 1.4
(built from source, CUDA 9.0 cuDNN 7.0, Ubuntu 16.04, Python 3.5)

@ming616

ming616 commented Nov 4, 2017

I think I can guess what causes this problem, but I do not know how to resolve it.
That is:
There are multiple threads in our training dataset instance (num_parallel_calls=4 in my config),
but during our training process we catch tf.errors.OutOfRangeError.
I guess the try block deals with this exception only once, so we still get another three warnings, as below.

2017-11-04 10:47:47.713841: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence
2017-11-04 10:47:47.713915: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence
2017-11-04 10:47:47.714404: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence

Maybe we should run the training process with multiple threads, so that each training thread only consumes the data returned from a single dataset thread.

@JeremyCCHsu

JeremyCCHsu commented Nov 5, 2017

Maybe there are other causes, because setting the num_parallel_calls argument to 1 in tf.data.Dataset.map doesn't prevent the warning from flooding.

@ebrevdo
Contributor

ebrevdo commented Nov 5, 2017 via email

@bantmensc

I can also confirm this problem and I am also not passing a py_func to dataset.map.

@mrry mrry added type:bug Bug and removed type:support Support issues labels Nov 14, 2017
@mrry mrry self-assigned this Nov 14, 2017
@jhseu jhseu closed this as completed in 301a6c4 Nov 16, 2017
@libulin

libulin commented Nov 17, 2017

@ybsave
I am also hitting this problem, the same as you, but it is not a bug.
You can see the documentation for train() at https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator:
steps: Number of steps for which to train model. If None, train forever or train until input_fn generates the OutOfRange error or StopIteration exception. 'steps' works incrementally. If you call two times train(steps=10) then training occurs in total 20 steps. If OutOfRange or StopIteration occurs in the middle, training stops before 20 steps. If you don't want to have incremental behavior please set max_steps instead. If set, max_steps must be None.

@mrry
Contributor

mrry commented Nov 17, 2017

Thanks for clarifying that @libulin!

While it's true that the log messages are not strictly a bug, we appreciate that it might be confusing or annoying to have a flood of messages in your console from code that's operating correctly :). With the fix in 301a6c4 (and a related fix for the StopIteration logging in c154d47), the logs should be much quieter when using tf.data.

k-w-w pushed a commit that referenced this issue Nov 17, 2017
…ernel.

A failing call to `Send()` indicates that the step has been aborted by
a corresponding call to `Rendezvous::StartAbort()`. As a result, the
error logged by `Send()` is not particularly informative, and creates
a non-deterministic amount of extra log spam for each step that fails
as `Send()` calls are being issued. The failure that causes the step to be
aborted is logged separately by the kernel that failed, unless that kernel
deliberately does not log on failure.

In particular, this change reduces log spam when using `Iterator.get_next()`
in a multi-device setting. The `Iterator.get_next()` op deliberately does not
log when an `OutOfRange` error (indicated the end of the dataset) is raised,
because this is common and expected behavior, especially when using an
initializable iterator that is reinitialized at the end of an epoch. Previously,
when running in distributed mode or using a GPU, pending `Send()` calls may
cause unwanted log messages to be printed.

Fixes #12414.

PiperOrigin-RevId: 175716290
@betterenvi
Contributor

betterenvi commented Dec 18, 2017

I am using TF 1.4.0. The following works, though it is a weird workaround.

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 3 filters out INFO, WARNING and ERROR messages from the C++ backend
import tensorflow as tf

@maxfiedler

I am also getting this warning: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: StopIteration: Iteration finished.
I am using a feedable iterator, and the warning appears when the iterator is switching from the train iterator handle to the validation iterator handle (however, not every time it switches).

@mrry
Contributor

mrry commented Jan 15, 2018

@maxfiedler Can you reproduce this with TF 1.5.0? If so, please open a new issue with code to reproduce the warning and we'll take a look.

@maxfiedler

@mrry I am still working with TF 1.4.0. I will let you know if the problem persists once we switch to TF 1.5.0, and will produce a code example if so.

@maxfiedler

maxfiedler commented Jan 17, 2018

A few more things I notice (still on TF 1.4).
I am also getting this warning:

2018-01-17 13:35:41.716773: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: End of sequence
	 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,90,120,3], [?,10800,21]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

plus a tf.errors.OutOfRangeError, when I am using a one_shot_iterator without repeat() that runs over my test_set AND a batch size that DOES cleanly divide the number of elements in the test data
(and I am getting way more than double those warnings when running on 2 GPUs instead of 1 GPU).
Relevant snippets of my code:

self.test_set = self.create_test_dataset()
if self.parse_fn:
    self.test_set = self.test_set.map(self.parse_fn, num_parallel_calls=self.num_parallel_processes)
if self.preprocessing_fn:
    self.test_set = self.test_set.map(self.preprocessing_fn, num_parallel_calls=self.num_parallel_processes)
# self.test_set = self.test_set.repeat(1)  # this line is not needed, but makes the behavior more explicit
self.test_set = self.test_set.batch(self.batch_size)
iterator = self.test_set.make_one_shot_iterator()
features, targets = iterator.get_next()

and

with self._session as sess:
    while not self._stop_testing:
        # Start the actual calculation
        loss, accuracy = sess.run(
            [self._net.loss, self._net.perf_metrics[0]],
            feed_dict={self._net.net_is_training: False})

Is there a way to signal the while loop that the iterator has reached the end of the dataset instead of throwing an OutOfRangeError?
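
(For reference, the pattern the rest of this thread uses is to catch tf.errors.OutOfRangeError rather than look for a separate end-of-dataset signal; a self-contained TF 1.x sketch, not the snippet above:)

import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(3)
next_batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(next_batch))
        except tf.errors.OutOfRangeError:
            break  # the iterator signals the end of the dataset this way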

I am following the "load on batch (i.e. one get_next() op) and distribute it over the GPUs via tf.split" scheme. When the last batch is not cleanly splittable I get understandably an tf.InvalidArgumentError.

But before I get a ton of the following warnings:

[[Node: split_batch = Split[T=DT_FLOAT, num_split=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](split_batch_1/split_dim, IteratorGetNext)]]
2018-01-17 14:35:58.198566: W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 10) and num_split 3

@scsanad

scsanad commented Mar 14, 2018

I have the same Out of range: End of sequence warning with tf 1.4.
Any experience with 1.5 or higher?

@rightaditya

@tanabics it was fixed for me with 1.5+

@terminsen

Hi - I have a related problem. I am able to run several epochs, but then it suddenly stops ...

Start training process ...
Iter: 0, Loss: 4.2690
Iter: 1, Loss: 0.6611
Iter: 2, Loss: 0.3918
Iter: 3, Loss: 0.2462
Iter: 4, Loss: 0.1588
Iter: 5, Loss: 0.1042
Iter: 6, Loss: 0.0817
Iter: 7, Loss: 0.0698
Iter: 8, Loss: 0.0631
Iter: 9, Loss: 0.0607
Iter: 10, Loss: 0.0573
Iter: 11, Loss: 0.0561

The process is stopped and I get the error message:

OutOfRangeError: End of sequence
[[Node: IteratorGetNext = IteratorGetNextoutput_shapes=[[?,20,1], [?,20,1]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

I have asked for 70 epochs (and the data is prepared for that), but something is not working. I use TensorFlow 1.10. The model is not saved, so there must be a strange bug.

It looks like some of the same issues as in this thread, but my model is not trained to completion by TF. I get the impression that some people in this thread get a trained model but additionally get an error message; I don't get a model back from TF at all.

Anyone?

@feihugis
Member

I got a similar problem when running several test cases, such as tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py and tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py, although the test cases pass.

The error messages are as follows:

ERROR:tensorflow:End of sequence
	 [[node IteratorGetNext (defined at /anaconda3/lib/python3.6/unittest/case.py:605)  = IteratorGetNext[output_shapes=[[]], output_types=[DT_STRING], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

Caused by op 'IteratorGetNext', defined at:
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py", line 242, in <module>
    test.main()
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/platform/test.py", line 64, in main
    return _googletest.main(argv)
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/platform/googletest.py", line 100, in main
    benchmark.benchmarks_main(true_main=main_wrapper)
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/platform/benchmark.py", line 344, in benchmarks_main
    true_main()
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/platform/googletest.py", line 99, in main_wrapper
    return app.run(main=g_main, argv=args)
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/platform/googletest.py", line 70, in g_main
    return unittest_main(argv=argv)
  File "/anaconda3/lib/python3.6/unittest/main.py", line 95, in __init__
    self.runTests()
  File "/anaconda3/lib/python3.6/unittest/main.py", line 256, in runTests
    self.result = testRunner.run(self.test)
  File "/anaconda3/lib/python3.6/unittest/runner.py", line 176, in run
    test(result)
  File "/anaconda3/lib/python3.6/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/anaconda3/lib/python3.6/unittest/suite.py", line 122, in run
    test(result)
  File "/anaconda3/lib/python3.6/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/anaconda3/lib/python3.6/unittest/suite.py", line 122, in run
    test(result)
  File "/anaconda3/lib/python3.6/unittest/case.py", line 653, in __call__
    return self.run(*args, **kwds)
  File "/anaconda3/lib/python3.6/unittest/case.py", line 605, in run
    testMethod()
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py", line 49, in testEmptyDirectory
    next_element = itr.get_next()
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/data/ops/iterator_ops.py", line 420, in get_next
    name=name)), self._output_types,
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/ops/gen_dataset_ops.py", line 2069, in iterator_get_next
    output_shapes=output_shapes, name=name)
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 3274, in create_op
    op_def=op_def)
  File "/private/var/tmp/_bazel_fei/5dc7b372a7c45427ff30500d3e22fe26/execroot/org_tensorflow/bazel-out/darwin-opt/bin/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 1770, in __init__
    self._traceback = tf_stack.extract_stack()

@mrry
Contributor

mrry commented Sep 22, 2018

I think that was an effect of 2c97ee3. I agree that this makes the output of tf.data tests hard to read, because they use OutOfRangeError to signal completion.

If you'd like to send a PR that disables the logging for OutOfRangeError, I'd be happy to approve it.

@feihugis
Member

@mrry Thanks for your quick reply! Will submit a PR for this.

@gdoras

gdoras commented Oct 17, 2018

Hello,
I had opened a Stack Overflow question related to this thread before finding it: https://stackoverflow.com/questions/52832625/outofrangeerror-logged-at-each-epoch-after-upgrade-from-tensorflow-1-8-0-to-1-11
@mrry does this mean that this is only a logging issue, and that the raised error should not be worried about?
