Got errors when calling get_initial_state with attention mechanism #673
Comments
@kazemnejad Do you spot an obvious error in the above code?
In terms of obvious errors, I don't think so. However, this line should be changed from
`decoder_initial_state = decoder_cell.get_initial_state(batch_size=tf.constant(batch_size), dtype=tf.float32)`
to
`decoder_initial_state = decoder_cell.get_initial_state(batch_size=tf.shape(masked_input)[0], dtype=tf.float32)`
Nonetheless, even the new version still produces the same error. It is also worth mentioning that the above code seems to raise another error.
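For reference, here is a minimal sketch of that suggested change. The original `decoder_cell` definition is not quoted in this thread, so the `LuongAttention`/`LSTMCell` wiring, the tensor names, and all shapes below are assumptions for illustration only; note that, per the comment above, the same error reportedly persisted even after this change.

```python
import tensorflow as tf
import tensorflow_addons as tfa

units = 64

# Illustrative setup; the issue's actual cell construction is not shown.
attention_mechanism = tfa.seq2seq.LuongAttention(units)
decoder_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(units), attention_mechanism)

memory = tf.random.normal([4, 10, units])    # encoder outputs (assumed shape)
attention_mechanism.setup_memory(memory)
masked_input = tf.random.normal([4, 10, 8])  # decoder inputs (assumed shape)

# Derive the batch size from the tensor itself rather than from a
# Python-level tf.constant(batch_size), so it stays a dynamic value
# consistent with the actual input batch.
decoder_initial_state = decoder_cell.get_initial_state(
    batch_size=tf.shape(masked_input)[0], dtype=tf.float32)
```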
Thanks for looking into this!
@CryMasK Do you have any updates on your issue?
I didn't keep trying with tfa.
The workaround works well even with long sequences, and it can handle local attention and masking by the same concept.
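The workaround itself is not quoted in this thread. Purely as an illustration of the idea it describes (plain Keras attention with explicit masks instead of `tfa.seq2seq`), a hypothetical sketch might look like this; the layer choice, shapes, and mask handling below are all assumptions, not the commenter's actual code:

```python
import tensorflow as tf

# Hypothetical sketch, not the commenter's actual workaround.
query = tf.keras.Input(shape=(None, 64))                   # decoder states
value = tf.keras.Input(shape=(None, 64))                   # encoder outputs
value_mask = tf.keras.Input(shape=(None,), dtype=tf.bool)  # padding mask

# tf.keras.layers.Attention applies the value mask so padded steps are
# ignored; narrowing the mask to a window around each query position
# would give a local-attention effect by the same mechanism.
context = tf.keras.layers.Attention()([query, value],
                                      mask=[None, value_mask])

model = tf.keras.Model([query, value, value_mask], context)
```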
System information
Running tfa-nightly (0.7.0.dev20191105) and TF 2.0 on Colab.
Describe the bug
I want to implement an encoder-decoder structure with an attention mechanism.
But when I call
`decoder_initial_state = decoder_cell.get_initial_state(masked_input)`
I get the error:
ValueError: Duplicate node name in graph: 'AttentionWrapperZeroState/zeros/packed'
If I switch to the other calling convention,
`decoder_initial_state = decoder_cell.get_initial_state(batch_size=tf.constant(batch_size), dtype=tf.float32)`
there is also an error:
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Did I do something wrong?
Code to reproduce the issue
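The original snippet is not preserved in this thread. Below is a hypothetical minimal reproduction reconstructed from the description above; the `LuongAttention`/`LSTMCell` wiring, the `tf.function` wrapping, and all names and shapes are assumptions.

```python
import tensorflow as tf
import tensorflow_addons as tfa

units, batch_size = 64, 4

attention_mechanism = tfa.seq2seq.LuongAttention(units)
decoder_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(units), attention_mechanism)

@tf.function
def build_initial_state(masked_input, memory):
    attention_mechanism.setup_memory(memory)
    # Variant 1 -- reported to raise:
    #   ValueError: Duplicate node name in graph:
    #   'AttentionWrapperZeroState/zeros/packed'
    return decoder_cell.get_initial_state(masked_input)
    # Variant 2 -- reported to raise OperatorNotAllowedInGraphError
    # (a tf.Tensor used as a Python bool in graph execution):
    # return decoder_cell.get_initial_state(
    #     batch_size=tf.constant(batch_size), dtype=tf.float32)

masked_input = tf.random.normal([batch_size, 10, 8])
memory = tf.random.normal([batch_size, 10, units])
state = build_initial_state(masked_input, memory)
```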