AttentionMechanism is not compatible with Eager Execution #535
Describe the bug
In the context of eager execution, we need to re-set up the memory on each step of training. However, the current API does not seem to support this behavior. The following code snippet is from `_BaseAttentionMechanism.__call__`:
```python
if self._memory_initialized:
    if len(inputs) not in (2, 3):
        raise ValueError(
            "Expect the inputs to have 2 or 3 tensors, got %d" % len(inputs))
    if len(inputs) == 2:
        # We append the calculated memory here so that the graph will be
        # connected.
        inputs.append(self.values)
return super(_BaseAttentionMechanism, self).__call__(inputs, **kwargs)
```
As you can see, once the memory has been initialized, the method assumes all future inputs are queries against the memory. Therefore a second call to this method (to re-set up the memory) will raise an error.
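A minimal reproduction sketch of the failure, assuming `tfa.seq2seq.LuongAttention` as the concrete mechanism (shapes are arbitrary):

```python
import tensorflow as tf
import tensorflow_addons as tfa

mechanism = tfa.seq2seq.LuongAttention(units=128)
memory = tf.random.normal([4, 10, 128])  # [batch, max_time, depth]
mask = tf.ones([4, 10], dtype=tf.bool)   # [batch, max_time]

# First call: initializes the memory, as intended.
mechanism(memory, memory_mask=mask, setup_memory=True)

# Second call, e.g. at the start of the next training batch: the
# _memory_initialized branch above rejects it with the ValueError.
mechanism(memory, memory_mask=mask, setup_memory=True)
```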
Other info / logs
Ideas to solve:
The problem I had with the first option was this comment in the
I'm not sure how using the
For the 3rd option, I meant calling the method with `setup_memory=True` on each step.
Furthermore, two consecutive calls to this method will cause an error too (this is the case with the Keras training loop, in which the model's `call()` runs once per batch):
```python
mechanism(memory, memory_mask=mask, setup_memory=True)
mechanism(memory, memory_mask=mask, setup_memory=True)
```
will raise an error due to the following condition:
```python
if self._memory_initialized:
    if len(inputs) not in (2, 3):
        (...)
```
To be clear, by "the step" I meant one step of training (one batch flowing through the model's graph), not one step of dynamic decoding in RNNs.
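To illustrate, here is a sketch of the per-batch usage I have in mind. The `Decoder` model, its inputs, and the query/state handling are hypothetical; with the current API, the second training step would hit the error above:

```python
import tensorflow as tf
import tensorflow_addons as tfa

class Decoder(tf.keras.Model):
    """Hypothetical model: one call() invocation == one training step."""

    def __init__(self, units):
        super().__init__()
        self.mechanism = tfa.seq2seq.LuongAttention(units)

    def call(self, inputs):
        memory, query, state = inputs
        # The memory belongs to this batch, so it must be (re)initialized
        # here, once per step -- which is what currently raises.
        self.mechanism(memory, setup_memory=True)
        # Subsequent calls within this batch only query the memory.
        alignments, next_state = self.mechanism([query, state])
        return alignments, next_state
```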
I hypothesize that under eager execution the graph is dynamic, so there are no symbolic tensors; tensors are created eagerly upon op invocation. This means that, if we don't re-set up the memory on every step, the mechanism will keep querying the memory values cached from a previous batch.
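A toy illustration of that hypothesis (deliberately unrelated to the actual tfa classes): in eager mode the cached attribute holds a concrete tensor frozen at setup time, not a symbolic node that tracks new inputs.

```python
import tensorflow as tf

class ToyMechanism:
    """Toy stand-in for the memory-caching behavior described above."""

    def setup(self, memory):
        # Under eager execution this stores a concrete EagerTensor,
        # frozen at the values of the batch that was passed in.
        self.values = memory

toy = ToyMechanism()
toy.setup(tf.constant([1.0, 2.0]))   # batch 1
new_batch = tf.constant([3.0, 4.0])  # batch 2 arrives...
print(toy.values.numpy())            # ...but the cache still holds [1. 2.]
```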
I see. "run_eagerly=True" is more like a debug mode where all the tensor input/output to layer will just be eager tensor (numpy array like). It will have a bad performance, but will allow user to debug and trace the numeric value if needed.
In that case, I agree that the cached value for the memory will be incorrect and should be reset/populated per batch. I think `call()` should take that into consideration, which is the 3rd option you stated above.
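A sketch of what that could look like, building on the snippet quoted at the top. Only the flag reset at the start is new, and it is a hypothetical change, not the library's actual fix:

```python
def __call__(self, inputs, setup_memory=False, **kwargs):
    if setup_memory and self._memory_initialized:
        # Hypothetical change: treat a repeated setup call as a per-batch
        # re-initialization instead of falling through to the query branch.
        self._memory_initialized = False
    if self._memory_initialized:
        if len(inputs) not in (2, 3):
            raise ValueError(
                "Expect the inputs to have 2 or 3 tensors, got %d"
                % len(inputs))
        if len(inputs) == 2:
            # We append the calculated memory here so that the graph will
            # be connected.
            inputs.append(self.values)
    return super(_BaseAttentionMechanism, self).__call__(inputs, **kwargs)
```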