Dataset with Keras Functional Model: tuple index out of range in steps_per_epoch #35925

Closed
jguhlin opened this issue Jan 16, 2020 · 23 comments
Labels
comp:keras, stale, stat:awaiting response, TF 2.1, type:bug

Comments

@jguhlin

jguhlin commented Jan 16, 2020

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
    Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu, Newest
  • TensorFlow installed from (source or binary): Pip (Binary)
  • TensorFlow version (use command below): 2.1
  • Python version: Python 3.7.6
  • CUDA/cuDNN version: 10.2
  • GPU model and memory: GeForce 1070 8Gb

Describe the current behavior
Passing a tf.data dataset to model.fit on a model built from tf.keras layers raises IndexError: tuple index out of range. The error occurs both with a custom TFRecord dataset and with datasets from tensorflow-datasets installed via pip. It appears to originate in the standardize_input_data function, but since the input is an instance of DatasetV2 it should not be hitting that if statement...

Describe the expected behavior
Keras models should accept tf.data Datasets.

Code to reproduce the issue

from tensorflow.keras.layers import Dense, Embedding, Flatten, Lambda, Subtract, Input, Concatenate, Average, Reshape, GlobalAveragePooling1D, Dot, Dropout
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.utils import Sequence
from tensorflow.keras import initializers

import tensorflow_datasets as tfds
tfds.list_builders()
dataset, info = tfds.load("mnist", with_info=True)
inputs = Input((28, 28, 1), name="image")
First = Dense(128, activation="relu")
Second = Dropout(0.2)
Third = Dense(10, activation="softmax", name="label")

first = First(inputs)
second = Second(first)
third = Third(second)
model = Model(inputs=[inputs], outputs=[third])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(dataset['train'].batch(4096))

Other info / logs

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    340                 mode=ModeKeys.TRAIN,
    341                 training_context=training_context,
--> 342                 total_epochs=epochs)
    343             cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
    344 

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
    126         step=step, mode=mode, size=current_batch_size) as batch_logs:
    127       try:
--> 128         batch_outs = execution_function(iterator)
    129       except (StopIteration, errors.OutOfRangeError):
    130         # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in execution_function(input_fn)
     96     # `numpy` translates Tensors to values in Eager mode.
     97     return nest.map_structure(_non_none_constant_value,
---> 98                               distributed_function(input_fn))
     99 
    100   return execution_function

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    566         xla_context.Exit()
    567     else:
--> 568       result = self._call(*args, **kwds)
    569 
    570     if tracing_count == self._get_tracing_count():

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    613       # This is the first call of __call__, so we have to initialize.
    614       initializers = []
--> 615       self._initialize(args, kwds, add_initializers_to=initializers)
    616     finally:
    617       # At this point we know that the initialization is complete (or less

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    495     self._concrete_stateful_fn = (
    496         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 497             *args, **kwds))
    498 
    499     def invalid_creator_scope(*unused_args, **unused_kwds):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2387       args, kwargs = None, None
   2388     with self._lock:
-> 2389       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2390     return graph_function
   2391 

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2701 
   2702       self._function_cache.missed.add(call_context_key)
-> 2703       graph_function = self._create_graph_function(args, kwargs)
   2704       self._function_cache.primary[cache_key] = graph_function
   2705       return graph_function, args, kwargs

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2591             arg_names=arg_names,
   2592             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2593             capture_by_value=self._capture_by_value),
   2594         self._function_attributes,
   2595         # Tell the ConcreteFunction to clean up its graph once it goes out of

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    976                                           converted_func)
    977 
--> 978       func_outputs = python_func(*func_args, **func_kwargs)
    979 
    980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in distributed_function(input_iterator)
     83     args = _prepare_feed_values(model, input_iterator, mode, strategy)
     84     outputs = strategy.experimental_run_v2(
---> 85         per_replica_function, args=args)
     86     # Out of PerReplica outputs reduce or pick values to return.
     87     all_outputs = dist_utils.unwrap_output_dict(

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py in experimental_run_v2(self, fn, args, kwargs)
    761       fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx(),
    762                                 convert_by_default=False)
--> 763       return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    764 
    765   def reduce(self, reduce_op, value, axis):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py in call_for_each_replica(self, fn, args, kwargs)
   1817       kwargs = {}
   1818     with self._container_strategy().scope():
-> 1819       return self._call_for_each_replica(fn, args, kwargs)
   1820 
   1821   def _call_for_each_replica(self, fn, args, kwargs):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py in _call_for_each_replica(self, fn, args, kwargs)
   2162         self._container_strategy(),
   2163         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
-> 2164       return fn(*args, **kwargs)
   2165 
   2166   def _reduce_to(self, reduce_op, value, destinations):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    290   def wrapper(*args, **kwargs):
    291     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
--> 292       return func(*args, **kwargs)
    293 
    294   if inspect.isfunction(func) or inspect.ismethod(func):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in train_on_batch(model, x, y, sample_weight, class_weight, reset_metrics, standalone)
    414   x, y, sample_weights = model._standardize_user_data(
    415       x, y, sample_weight=sample_weight, class_weight=class_weight,
--> 416       extract_tensors_from_dataset=True)
    417   batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0]
    418   # If `model._distribution_strategy` is True, then we are in a replica context

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2381         is_dataset=is_dataset,
   2382         class_weight=class_weight,
-> 2383         batch_size=batch_size)
   2384 
   2385   def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
   2467           shapes=None,
   2468           check_batch_axis=False,  # Don't enforce the batch size.
-> 2469           exception_prefix='target')
   2470 
   2471       # Generate sample-wise weight values given the `sample_weight` and

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    510                        'for each key in: ' + str(names))
    511   elif isinstance(data, (list, tuple)):
--> 512     if isinstance(data[0], (list, tuple)):
    513       data = [np.asarray(d) for d in data]
    514     elif len(names) == 1 and isinstance(data[0], (float, int)):

IndexError: tuple index out of range
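
The last frame suggests that the target data reaching standardize_input_data is an empty tuple: it passes the isinstance(data, (list, tuple)) check, and indexing element 0 then fails. A minimal pure-Python sketch of that branch is below. The function standardize_input_data_sketch is a hypothetical, simplified stand-in written for illustration, not the TensorFlow implementation, and modelling the missing targets as () is an assumption drawn from the exception_prefix='target' argument in the frame above.

```python
def standardize_input_data_sketch(data, names):
    """Simplified stand-in (hypothetical) for the branch shown in the traceback."""
    if isinstance(data, (list, tuple)):
        # Indexing element 0 assumes at least one array was supplied.
        if isinstance(data[0], (list, tuple)):
            data = [list(d) for d in data]
    return data


# Assumption: a dataset element that carries no separate targets ends up
# here as an empty tuple of target arrays.
try:
    standardize_input_data_sketch((), names=["label"])
except IndexError as err:
    message = str(err)

print(message)
```

If that reading is right, the error is less about steps_per_epoch and more about fit receiving no targets from the dataset.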
@amahendrakar amahendrakar self-assigned this Jan 16, 2020
@amahendrakar
Contributor

Was able to reproduce the issue. Please find the Gist here. Thanks!

@amahendrakar amahendrakar added comp:keras Keras related issues TF 2.1 for tracking issues in 2.1 release type:bug Bug labels Jan 16, 2020
@amahendrakar amahendrakar assigned ymodak and unassigned amahendrakar Jan 16, 2020
@jguhlin
Author

jguhlin commented Jan 16, 2020

I don't think this feature is supposed to work in 2.0 but can't confirm. But here is the error from there:

   1/Unknown - 0s 26ms/step
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-14-5486743d1f94> in <module>
      1 # Because it's confusing
      2 # This fn is made for loss here...
----> 3 model.fit(batched_dataset)

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    726         max_queue_size=max_queue_size,
    727         workers=workers,
--> 728         use_multiprocessing=use_multiprocessing)
    729 
    730   def evaluate(self,

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
    322                 mode=ModeKeys.TRAIN,
    323                 training_context=training_context,
--> 324                 total_epochs=epochs)
    325             cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
    326 

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
    121         step=step, mode=mode, size=current_batch_size) as batch_logs:
    122       try:
--> 123         batch_outs = execution_function(iterator)
    124       except (StopIteration, errors.OutOfRangeError):
    125         # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in execution_function(input_fn)
     84     # `numpy` translates Tensors to values in Eager mode.
     85     return nest.map_structure(_non_none_constant_value,
---> 86                               distributed_function(input_fn))
     87 
     88   return execution_function

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    455 
    456     tracing_count = self._get_tracing_count()
--> 457     result = self._call(*args, **kwds)
    458     if tracing_count == self._get_tracing_count():
    459       self._call_counter.called_without_tracing()

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    501       # This is the first call of __call__, so we have to initialize.
    502       initializer_map = object_identity.ObjectIdentityDictionary()
--> 503       self._initialize(args, kwds, add_initializers_to=initializer_map)
    504     finally:
    505       # At this point we know that the initialization is complete (or less

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    406     self._concrete_stateful_fn = (
    407         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 408             *args, **kwds))
    409 
    410     def invalid_creator_scope(*unused_args, **unused_kwds):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   1846     if self.input_signature:
   1847       args, kwargs = None, None
-> 1848     graph_function, _, _ = self._maybe_define_function(args, kwargs)
   1849     return graph_function
   1850 

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2148         graph_function = self._function_cache.primary.get(cache_key, None)
   2149         if graph_function is None:
-> 2150           graph_function = self._create_graph_function(args, kwargs)
   2151           self._function_cache.primary[cache_key] = graph_function
   2152         return graph_function, args, kwargs

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2039             arg_names=arg_names,
   2040             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041             capture_by_value=self._capture_by_value),
   2042         self._function_attributes,
   2043         # Tell the ConcreteFunction to clean up its graph once it goes out of

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    913                                           converted_func)
    914 
--> 915       func_outputs = python_func(*func_args, **func_kwargs)
    916 
    917       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    356         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    357         # the function a weak reference to itself to avoid a reference cycle.
--> 358         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    359     weak_wrapped_fn = weakref.ref(wrapped_fn)
    360 

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in distributed_function(input_iterator)
     71     strategy = distribution_strategy_context.get_strategy()
     72     outputs = strategy.experimental_run_v2(
---> 73         per_replica_function, args=(model, x, y, sample_weights))
     74     # Out of PerReplica outputs reduce or pick values to return.
     75     all_outputs = dist_utils.unwrap_output_dict(

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py in experimental_run_v2(self, fn, args, kwargs)
    758       fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx(),
    759                                 convert_by_default=False)
--> 760       return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    761 
    762   def reduce(self, reduce_op, value, axis):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py in call_for_each_replica(self, fn, args, kwargs)
   1785       kwargs = {}
   1786     with self._container_strategy().scope():
-> 1787       return self._call_for_each_replica(fn, args, kwargs)
   1788 
   1789   def _call_for_each_replica(self, fn, args, kwargs):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py in _call_for_each_replica(self, fn, args, kwargs)
   2130         self._container_strategy(),
   2131         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
-> 2132       return fn(*args, **kwargs)
   2133 
   2134   def _reduce_to(self, reduce_op, value, destinations):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    290   def wrapper(*args, **kwargs):
    291     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
--> 292       return func(*args, **kwargs)
    293 
    294   if inspect.isfunction(func) or inspect.ismethod(func):

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in train_on_batch(model, x, y, sample_weight, class_weight, reset_metrics)
    262       y,
    263       sample_weights=sample_weights,
--> 264       output_loss_metrics=model._output_loss_metrics)
    265 
    266   if reset_metrics:

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py in train_on_batch(model, inputs, targets, sample_weights, output_loss_metrics)
    309           sample_weights=sample_weights,
    310           training=True,
--> 311           output_loss_metrics=output_loss_metrics))
    312   if not isinstance(outs, list):
    313     outs = [outs]

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py in _process_single_batch(model, inputs, targets, output_loss_metrics, sample_weights, training)
    250               output_loss_metrics=output_loss_metrics,
    251               sample_weights=sample_weights,
--> 252               training=training))
    253       if total_loss is None:
    254         raise ValueError('The model cannot be run '

~/miniconda3/envs/keras/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py in _model_loss(model, inputs, targets, output_loss_metrics, sample_weights, training)
    164 
    165         if hasattr(loss_fn, 'reduction'):
--> 166           per_sample_losses = loss_fn.call(targets[i], outs[i])
    167           weighted_losses = losses_utils.compute_weighted_loss(
    168               per_sample_losses,

IndexError: list index out of range

@dendrondal

I tried changing my model to Sequential to circumvent this error, and the error still occurs. Has anyone found a way around this?

@jguhlin
Author

jguhlin commented Jan 26, 2020

@dendrondal The only way I've found around it is to avoid tf.data entirely and use a plain Python generator. @amahendrakar Should I submit a PR for the tf.keras.Model docs, since they say tf.data datasets are supported but that is incorrect?

@i-TechSoftware

Hi. I faced the same problem. Is there any solution or workaround?

@ymodak
Contributor

ymodak commented Feb 1, 2020

Perhaps this thread can help.
tensorflow/datasets#720

@ymodak ymodak added the stat:awaiting response Status - Awaiting response from author label Feb 1, 2020
@JarnoRFB
Contributor

JarnoRFB commented Feb 7, 2020

I experienced the same problem when trying to train an autoencoder with only a single input. In my case the problem was solved by mapping the function

def autoencoder_sample(x):
    return x, x

in a final step. Even though the second element of the tuple is basically ignored, Keras .fit seems to expect an input and a label.

In the OP's example the issue might be solved by using
dataset, info = tfds.load("mnist", with_info=True, as_supervised=True)
which leads the dataset to return tuples rather than dictionaries.
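
For datasets that cannot be loaded with as_supervised=True (for example a custom TFRecord pipeline), the same dictionary-to-tuple conversion can be written as a small mapping function. The sketch below is plain Python so it runs without TensorFlow. The name to_supervised is hypothetical, and the keys "image" and "label" are assumed to match the MNIST feature names used in the OP's example.

```python
def to_supervised(features):
    """Turn a feature dict into the (inputs, targets) pair that fit() expects."""
    return features["image"], features["label"]


# Stand-in for one dataset element; a real element would hold tensors.
example = {"image": [[0, 0], [0, 255]], "label": 7}
inputs, target = to_supervised(example)
print(inputs, target)
```

Applied to the OP's code this would presumably look like dataset['train'].map(to_supervised).batch(4096), which has not been verified here against TF 2.1. Note also that the model in the report applies Dense directly to the (28, 28, 1) input, so a Flatten layer is probably needed as well before the output shape lines up with the labels.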

However, it is clear that the current error message is cryptic and does not help the user to find a solution.

@bestpredicts

I faced the same problem. Is there any update on a solution or workaround?

@ymodak ymodak assigned rsepassi and unassigned ymodak Feb 14, 2020
@hctomkins

Also struggling with this! Any workarounds?

@jguhlin
Author

jguhlin commented Mar 8, 2020

No solution; I've fallen back to using Python generators instead of trying to use tf.data.

In reference to an earlier comment (thanks for the advice though!): I was using this with custom code, following the TensorFlow guide to creating datasets, but used the tfds library here to make a really small working example, so there was no need for extensive custom code in the issue.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Mar 10, 2020
@FabHub

FabHub commented Mar 18, 2020

@jguhlin Hi, I faced the same issue a few days ago. Could you please provide some info on how you managed to use a plain Python generator function? Did you need to create your own batching mechanism and yield batches from the generator? Thanks.

@jguhlin
Author

jguhlin commented Mar 18, 2020

@FabHub Yes, I'm just using a generator and skipping the tf datasets integration. In the generator, I'm creating the batches and yielding those. Have to handle shuffling and the like manually.
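
A generator of this kind might look roughly like the sketch below. It is not the code used in this thread: batch_generator and its parameters are hypothetical names, and it yields plain Python lists, which would normally be converted to NumPy arrays before being handed to Keras.

```python
import random


def batch_generator(features, labels, batch_size, shuffle=True, seed=None):
    """Yield (x_batch, y_batch) tuples covering the data once (one epoch)."""
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    indices = list(range(len(features)))
    if shuffle:
        # Shuffle indices rather than the data so x and y stay aligned.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        x_batch = [features[i] for i in chunk]
        y_batch = [labels[i] for i in chunk]
        yield x_batch, y_batch


features = [[0.0], [1.0], [2.0], [3.0], [4.0]]
labels = [0, 1, 2, 3, 4]
batches = list(batch_generator(features, labels, batch_size=2, seed=0))
print(len(batches))
```

Each call produces one pass over the data, so training for several epochs means creating the generator again per epoch or wrapping the loop in a while True and passing steps_per_epoch to fit.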

@FabHub

FabHub commented Mar 18, 2020

OK, I just wanted to make sure there isn't some more advanced solution :) Thanks a lot for the quick reply.

@jguhlin
Author

jguhlin commented Mar 18, 2020

@FabHub Check the keras docs, it will tell you what to provide from the generator.
https://keras.io/models/model/
Good luck!

@F-29

F-29 commented Apr 8, 2020

any progress?

@MihailMihaylov97

I also have the same issue.
Is there a way to integrate the tf.data API and tf.data.Dataset (with TFRecords) with the fit function of tf.keras.Model?

@FabHub

FabHub commented Apr 16, 2020

@F-29 @MihailMihaylov97, I only used @jguhlin's solution. I created a simple data generator function that iterates over the dataset and yields (x, y) tuples, and created the dataset like this: tf.data.Dataset.from_generator(generator_function ...). Unfortunately, in my case I had to create data of uniform shape and store it on disk prior to training; I couldn't cut the samples on the fly. I needed to do this in order to shuffle properly and use the full potential of the dataset. Hope it helps.

@MihailMihaylov97

@FabHub thank you for the reply. I switched to tensorflow estimators, and it works just fine.

@csttsn

csttsn commented Apr 29, 2020

Encountered the same issue with another tfds dataset. Adding as_supervised=True to tfds.load() solves the problem. Looks like this issue occurs when FeaturesDict is used.

@sachinprasadhs
Contributor

While trying to reproduce your issue in TF Nightly, I encountered a different error. Please find the gist here. Thanks!

@sachinprasadhs sachinprasadhs added the stat:awaiting response Status - Awaiting response from author label Jun 4, 2021
@rsepassi rsepassi removed their assignment Jun 7, 2021
@google-ml-butler

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Jun 14, 2021
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.
