
Layers.Flatten( ) fails under TF 2.16.1 but not under TF 2.15 #64177

Closed
JuanVargas opened this issue Mar 21, 2024 · 6 comments
Labels: comp:apis (High-level API related issues), stale (to be closed automatically if no activity), stat:awaiting response (awaiting response from author), TF 2.16, type:bug (Bug)

Comments

@JuanVargas

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

binary

TensorFlow version

TF 2.16.1

Custom code

No

OS platform and distribution

Linux Ubuntu 22.04.4 LTS

Mobile device

No response

Python version

3.10.12

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

CUDA 12.4; CUDNN 8

GPU model and memory

NVIDIA GeForce RTX 3060 12 GBs

Current behavior?

Code from p. 289 of Chollet's Deep Learning with Python, 2nd ed., which uses a Flatten layer, runs fine under TF 2.15 and Keras 2.15. The exact same code fails under TF 2.16.1 when it reaches the model.fit() call that produces the history object. The code is reproduced in the "Standalone code to reproduce the issue" section below, and the error message is in the "Relevant log output" section.


Standalone code to reproduce the issue

# Model from Chollet, Deep Learning with Python, 2nd ed., p. 289 (Jena temperature example).
from tensorflow import keras

inputs = keras.Input(shape=(sequence_length, raw_data.shape[-1]))
x = keras.layers.Flatten()(inputs)
x = keras.layers.Dense(16, activation="relu")(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

callbacks = [
    keras.callbacks.ModelCheckpoint("jena_dense.keras.x.keras",
                                    save_best_only=True)
]

model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])

history = model.fit(train_dataset,
                    epochs=10,
                    validation_data=val_dataset,
                    callbacks=callbacks)
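
The snippet above references sequence_length, raw_data, train_dataset, and val_dataset without defining them. For context only, here is a minimal sketch of the kind of windowed-dataset pipeline the book builds around the Jena climate data; every name, split index, and parameter value below is an assumption for illustration, not the reporter's actual setup.

import numpy as np
from tensorflow import keras

# Hypothetical stand-in for the Jena climate table: rows of float features,
# with one column used as the prediction target.
raw_data = np.random.rand(10000, 14).astype("float32")
temperature = raw_data[:, 1]

sequence_length = 120
sampling_rate = 6
delay = sampling_rate * (sequence_length + 24 - 1)

# Windowed datasets built with keras.utils.timeseries_dataset_from_array,
# roughly following the book's recipe; the split indices are illustrative only.
train_dataset = keras.utils.timeseries_dataset_from_array(
    raw_data[:-delay],
    targets=temperature[delay:],
    sampling_rate=sampling_rate,
    sequence_length=sequence_length,
    shuffle=True,
    batch_size=256,
    start_index=0,
    end_index=7000,
)
val_dataset = keras.utils.timeseries_dataset_from_array(
    raw_data[:-delay],
    targets=temperature[delay:],
    sampling_rate=sampling_rate,
    sequence_length=sequence_length,
    shuffle=True,
    batch_size=256,
    start_index=7000,
    end_index=9000,
)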

Relevant log output

2024-03-21 11:16:28.427720: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1711034188.704423   18284 service.cc:145] XLA service 0x7ad818006930 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1711034188.704465   18284 service.cc:153]   StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
2024-03-21 11:16:28.713034: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at xla_ops.cc:580 : INVALID_ARGUMENT: only one input size may be -1, not both 0 and 1

Stack trace for op definition: 
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in process_one
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 359, in execute_request
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 446, in do_execute
File "/tmp/ipykernel_18158/437540976.py", line 29, in wrapper
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
File "<ipython-input-1-6a9ff972cd60>", line 250, in <module>
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 325, in fit
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 118, in one_step_on_iterator
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 106, in one_step_on_data
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 57, in train_step
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/layers/layer.py", line 814, in __call__
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/ops/operation.py", line 48, in __call__
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/models/functional.py", line 194, in call
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/ops/function.py", line 151, in _run_through_graph
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/models/functional.py", line 578, in call
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/layers/layer.py", line 814, in __call__
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/ops/operation.py", line 48, in __call__
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/layers/reshaping/flatten.py", line 54, in call
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/ops/numpy.py", line 4507, in reshape
File "/drv3/hm3/code/python/tf.2.16.1/.tf.2.16.1/lib/python3.10/site-packages/keras/src/backend/tensorflow/numpy.py", line 1545, in reshape

	 [[{{node functional_1_1/flatten_1/Reshape}}]]
	tf2xla conversion failed while converting __inference_one_step_on_data_19335[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
2024-03-21 11:16:28.713066: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: only one input size may be -1, not both 0 and 1
@sgkouzias

Set the input_shape argument of your Dense layer equal to (sequence_length * raw_data.shape[-1],).

@Di-Is

Di-Is commented Mar 24, 2024

Hi @JuanVargas

I also encountered a similar error while training an image classification model.
Below is a list of measures that were effective in suppressing the error in my environment:

  1. Use tf.ensure_shape to explicitly define the shape of the tensors in the training and validation datasets.
  2. Switch from graph execution to eager execution.
  3. Use a Reshape layer instead of a Flatten layer.
    • In your case, tf.keras.layers.Reshape((sequence_length, raw_data.shape[-1])) or tf.keras.layers.Reshape((None, sequence_length, raw_data.shape[-1])) (a sketch of workarounds 2 and 3 follows below).
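
For illustration, a minimal sketch of how workarounds 2 and 3 might look when applied to the model from the report; the explicit flattened target size and the run_eagerly / jit_compile settings are my assumptions about the intent, not the commenter's exact snippet.

from tensorflow import keras

n_features = raw_data.shape[-1]
inputs = keras.Input(shape=(sequence_length, n_features))

# Variant of workaround 3: replace Flatten with a Reshape whose target size is
# fully known, so the underlying Reshape op never has to infer two -1 dimensions.
x = keras.layers.Reshape((sequence_length * n_features,))(inputs)
x = keras.layers.Dense(16, activation="relu")(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

# Workaround 2: run eagerly (or at least keep XLA jit off), which sidesteps
# the tf2xla conversion step that raises the INVALID_ARGUMENT error.
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"],
              run_eagerly=True)  # or jit_compile=False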

@sushreebarsa
Contributor

@JuanVargas
To expedite the troubleshooting process, please provide a complete code snippet that reproduces the issue reported here. Thank you!

@sushreebarsa sushreebarsa added comp:apis Highlevel API related issues TF 2.16 stat:awaiting response Status - Awaiting response from author labels Mar 27, 2024

github-actions bot commented Apr 4, 2024

This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Apr 4, 2024

This issue was closed because it has been inactive for 7 days since being marked as stale. Please reopen if you'd like to work on this further.

