/
rnn.py
1683 lines (1444 loc) · 68.7 KB
/
rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RNN helpers for TensorFlow models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
# Module-level alias for the private `rnn_cell_impl._concat` helper so the rest
# of this file can refer to it as `_concat`.
_concat = rnn_cell_impl._concat
# pylint: enable=protected-access
def _transpose_batch_time(x):
"""Transposes the batch and time dimensions of a Tensor.
If the input tensor has rank < 2 it returns the original tensor. Retains as
much of the static shape information as possible.
Args:
x: A Tensor.
Returns:
x transposed along the first two dimensions.
"""
x_static_shape = x.get_shape()
if x_static_shape.rank is not None and x_static_shape.rank < 2:
return x
x_rank = array_ops.rank(x)
x_t = array_ops.transpose(
x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0))
x_t.set_shape(
tensor_shape.TensorShape(
[x_static_shape.dims[1].value,
x_static_shape.dims[0].value]).concatenate(x_static_shape[2:]))
return x_t
def _best_effort_input_batch_size(flat_input):
"""Get static input batch size if available, with fallback to the dynamic one.
Args:
flat_input: An iterable of time major input Tensors of shape `[max_time,
batch_size, ...]`. All inputs should have compatible batch sizes.
Returns:
The batch size in Python integer if available, or a scalar Tensor otherwise.
Raises:
ValueError: if there is any input with an invalid shape.
"""
for input_ in flat_input:
shape = input_.shape
if shape.rank is None:
continue
if shape.rank < 2:
raise ValueError("Input tensor should have rank >= 2. Received input="
f"{input_} of rank {shape.rank}")
batch_size = shape.dims[1].value
if batch_size is not None:
return batch_size
# Fallback to the dynamic batch size of the first input.
return array_ops.shape(flat_input[0])[1]
def _infer_state_dtype(explicit_dtype, state):
"""Infer the dtype of an RNN state.
Args:
explicit_dtype: explicitly declared dtype or None.
state: RNN's hidden state. Must be a Tensor or a nested iterable containing
Tensors.
Returns:
dtype: inferred dtype of hidden state.
Raises:
ValueError: if `state` has heterogeneous dtypes or is empty.
"""
if explicit_dtype is not None:
return explicit_dtype
elif nest.is_sequence(state):
inferred_dtypes = [element.dtype for element in nest.flatten(state)]
if not inferred_dtypes:
raise ValueError(f"Unable to infer dtype from argument state={state}.")
all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes)
if not all_same:
raise ValueError(
f"Argument state={state} has tensors of different inferred dtypes. "
"Unable to infer a single representative dtype. Dtypes received: "
f"{inferred_dtypes}")
return inferred_dtypes[0]
else:
return state.dtype
def _maybe_tensor_shape_from_tensor(shape):
  """Converts a shape given as a Tensor into a `TensorShape`.

  Args:
    shape: Either an `ops.Tensor` describing a shape, or any other shape
      representation.

  Returns:
    For a Tensor argument, `tensor_shape.as_shape` applied to the value
    obtained from `tensor_util.constant_value(shape)`. Any other argument is
    returned unchanged.
  """
  if not isinstance(shape, ops.Tensor):
    return shape
  static_value = tensor_util.constant_value(shape)
  return tensor_shape.as_shape(static_value)
def _should_cache():
  """Returns True if a default caching device should be set, otherwise False.

  Caching is never enabled when executing eagerly, nor when the default graph
  is currently inside a v1 or v2 while loop.
  """
  if context.executing_eagerly():
    return False

  # Don't set a caching device when running in a loop, since it is possible that
  # train steps could be wrapped in a tf.while_loop. In that scenario caching
  # prevents forward computations in loop iterations from re-reading the
  # updated weights.
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  control_flow_context = graph._get_control_flow_context()
  # pylint: enable=protected-access
  enclosing_while = control_flow_util.GetContainingWhileContext(
      control_flow_context)
  in_v1_while_loop = enclosing_while is not None
  in_v2_while_loop = control_flow_util_v2.in_while_loop_defun(graph)
  return not (in_v1_while_loop or in_v2_while_loop)
# pylint: disable=unused-argument
def _rnn_step(time,
              sequence_length,
              min_sequence_length,
              max_sequence_length,
              zero_output,
              state,
              call_cell,
              state_size,
              skip_conditionals=False):
  """Calculate one step of a dynamic RNN minibatch.

  Returns an (output, state) pair conditioned on `sequence_length`.
  When skip_conditionals=False, the pseudocode is something like:

  if t >= max_sequence_length:
    return (zero_output, state)
  if t < min_sequence_length:
    return call_cell()

  # Selectively output zeros or output, old state or new state depending
  # on whether we've finished calculating each row.
  new_output, new_state = call_cell()
  final_output = np.vstack([
    zero_output if time >= sequence_length[r] else new_output_r
    for r, new_output_r in enumerate(new_output)
  ])
  final_state = np.vstack([
    state[r] if time >= sequence_length[r] else new_state_r
    for r, new_state_r in enumerate(new_state)
  ])
  return (final_output, final_state)

  Args:
    time: int32 `Tensor` scalar.
    sequence_length: int32 `Tensor` vector of size [batch_size].
    min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
    max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
    zero_output: `Tensor` vector of shape [output_size].
    state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
      or a list/tuple of such tensors.
    call_cell: lambda returning tuple of (new_output, new_state) where
      new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
      new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
    state_size: The `cell.state_size` associated with the state. Not read by
      this function; kept for interface compatibility.
    skip_conditionals: Python bool, whether to skip using the conditional
      calculations.  This is useful for `dynamic_rnn`, where the input tensor
      matches `max_sequence_length`, and using conditionals just slows
      everything down.

  Returns:
    A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
      final_output is a `Tensor` matrix of shape [batch_size, output_size]
      final_state is either a single `Tensor` matrix, or a tuple of such
        matrices (matching length and shapes of input `state`).

  Raises:
    ValueError: If the flattened result of the step does not contain exactly
      as many tensors as the flattened `zero_output` and `state` combined.
      The structure of the cell's output and state is additionally checked
      against `zero_output` and `state` with `nest.assert_same_structure`.
  """

  # Convert state to a list for ease of use
  flat_state = nest.flatten(state)
  flat_zero_output = nest.flatten(zero_output)

  # Vector describing which batch entries are finished.
  copy_cond = time >= sequence_length

  def _copy_one_through(output, new_output):
    # TensorArray and scalar get passed through.
    if isinstance(output, tensor_array_ops.TensorArray):
      return new_output
    if output.shape.rank == 0:
      return new_output
    # Otherwise propagate the old or the new value.
    with ops.colocate_with(new_output):
      return array_ops.where(copy_cond, output, new_output)

  def _copy_some_through(flat_new_output, flat_new_state):
    # Use broadcasting select to determine which values should get
    # the previous state & zero output, and which values should get
    # a calculated state & output.
    flat_new_output = [
        _copy_one_through(zero_output, new_output)
        for zero_output, new_output in zip(flat_zero_output, flat_new_output)
    ]
    flat_new_state = [
        _copy_one_through(state, new_state)
        for state, new_state in zip(flat_state, flat_new_state)
    ]
    return flat_new_output + flat_new_state

  def _maybe_copy_some_through():
    """Run RNN step.  Pass through either no or some past state."""
    new_output, new_state = call_cell()

    nest.assert_same_structure(zero_output, new_output)
    nest.assert_same_structure(state, new_state)

    flat_new_state = nest.flatten(new_state)
    flat_new_output = nest.flatten(new_output)
    return control_flow_ops.cond(
        # if t < min_seq_len: calculate and return everything
        time < min_sequence_length,
        lambda: flat_new_output + flat_new_state,
        # else copy some of it through
        lambda: _copy_some_through(flat_new_output, flat_new_state))

  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
  # but benefits from removing cond() and its gradient.  We should
  # profile with and without this switch here.
  if skip_conditionals:
    # Instead of using conditionals, perform the selective copy at all time
    # steps.  This is faster when max_seq_len is equal to the number of unrolls
    # (which is typical for dynamic_rnn).
    new_output, new_state = call_cell()
    nest.assert_same_structure(zero_output, new_output)
    nest.assert_same_structure(state, new_state)
    new_state = nest.flatten(new_state)
    new_output = nest.flatten(new_output)
    final_output_and_state = _copy_some_through(new_output, new_state)
  else:
    empty_update = lambda: flat_zero_output + flat_state
    final_output_and_state = control_flow_ops.cond(
        # if t >= max_seq_len: copy all state through, output zeros
        time >= max_sequence_length,
        empty_update,
        # otherwise calculation is required: copy some or all of it through
        _maybe_copy_some_through)

  # The flat result is [outputs..., states...]; its length must be the sum of
  # both parts, otherwise the split below would be wrong.
  expected_length = len(flat_zero_output) + len(flat_state)
  if len(final_output_and_state) != expected_length:
    raise ValueError("Internal error: state and output were not concatenated "
                     f"correctly. Received state length: {len(flat_state)}, "
                     f"output length: {len(flat_zero_output)}. Expected "
                     f"concatenated length: {expected_length}, but got: "
                     f"{len(final_output_and_state)}.")
  final_output = final_output_and_state[:len(flat_zero_output)]
  final_state = final_output_and_state[len(flat_zero_output):]

  # Restore static shapes from the inputs onto the results.
  for output, flat_output in zip(final_output, flat_zero_output):
    output.set_shape(flat_output.get_shape())
  for substate, flat_substate in zip(final_state, flat_state):
    if not isinstance(substate, tensor_array_ops.TensorArray):
      substate.set_shape(flat_substate.get_shape())

  final_output = nest.pack_sequence_as(
      structure=zero_output, flat_sequence=final_output)
  final_state = nest.pack_sequence_as(
      structure=state, flat_sequence=final_state)

  return final_output, final_state
def _reverse_seq(input_seq, lengths):
"""Reverse a list of Tensors up to specified lengths.
Args:
input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features)
or nested tuples of tensors.
lengths: A `Tensor` of dimension batch_size, containing lengths for each
sequence in the batch. If "None" is specified, simply reverses the list.
Returns:
time-reversed sequence
"""
if lengths is None:
return list(reversed(input_seq))
flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)
flat_results = [[] for _ in range(len(input_seq))]
for sequence in zip(*flat_input_seq):
input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank)
for input_ in sequence:
input_shape.assert_is_compatible_with(input_.get_shape())
input_.set_shape(input_shape)
# Join into (time, batch_size, depth)
s_joined = array_ops.stack(sequence)
# Reverse along dimension 0
s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
# Split again into list
result = array_ops.unstack(s_reversed)
for r, flat_result in zip(result, flat_results):
r.set_shape(input_shape)
flat_result.append(r)
results = [
nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
for input_, flat_result in zip(input_seq, flat_results)
]
return results
@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
                        "keras.layers.RNN(cell))`, which is equivalent to "
                        "this API")
@tf_export(v1=["nn.bidirectional_dynamic_rnn"])
@dispatch.add_dispatch_support
def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
  """Creates a dynamic version of bidirectional recurrent neural network.

  Builds two independent RNNs with `dynamic_rnn`: a forward one that reads
  `inputs` as given, and a backward one that reads `inputs` reversed along the
  time axis and whose output is reversed back afterwards. The input_size of
  the forward and backward cell must match. The initial state for both
  directions is zero by default (but can be set optionally) and no
  intermediate states are ever returned -- the network is fully unrolled for
  the given (passed in) length(s) of the sequence(s) or completely unrolled if
  length(s) is not given.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: The RNN inputs.
      If time_major == False (default), this must be a tensor of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If time_major == True, this must be a tensor of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch. If
      not provided, all batch entries are assumed to be full sequences; and time
      reversal is applied from time `0` to `max_time` for each sequence.
    initial_state_fw: (optional) An initial state for the forward RNN. This must
      be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
      tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial states and expected output.
      Required if initial_states are not provided or RNN states have a
      heterogeneous dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation.  However, most TensorFlow
      data is batch-major, so by default this function accepts input and emits
      output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn"

  Returns:
    A tuple (outputs, output_states) where:
      outputs: A tuple (output_fw, output_bw) containing the forward and
        the backward rnn output `Tensor`.
        If time_major == False (default),
          output_fw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_bw.output_size]`.
        If time_major == True,
          output_fw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_bw.output_size]`.
        It returns a tuple instead of a single concatenated `Tensor`, unlike
        in the `bidirectional_rnn`. If the concatenated one is preferred,
        the forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
      output_states: A tuple (output_state_fw, output_state_bw) containing
        the forward and the backward final states of bidirectional rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

  # The layout of `inputs` decides which axis is time and which is batch.
  time_axis, batch_axis = (0, 1) if time_major else (1, 0)

  def _reverse_in_time(tensor):
    """Reverses `tensor` along time, honouring `sequence_length` if given."""
    if sequence_length is None:
      return array_ops.reverse(tensor, axis=[time_axis])
    return array_ops.reverse_sequence(
        input=tensor,
        seq_lengths=sequence_length,
        seq_axis=time_axis,
        batch_axis=batch_axis)

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw,
          inputs=inputs,
          sequence_length=sequence_length,
          initial_state=initial_state_fw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=fw_scope)

    # Backward direction: feed the time-reversed inputs through `cell_bw`.
    with vs.variable_scope("bw") as bw_scope:
      inputs_reverse = nest.map_structure(_reverse_in_time, inputs)
      reversed_output_bw, output_state_bw = dynamic_rnn(
          cell=cell_bw,
          inputs=inputs_reverse,
          sequence_length=sequence_length,
          initial_state=initial_state_bw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=bw_scope)

  # Undo the time reversal so the backward output lines up with the inputs.
  output_bw = _reverse_in_time(reversed_output_bw)

  return ((output_fw, output_bw), (output_state_fw, output_state_bw))
@deprecation.deprecated(
    None,
    "Please use `keras.layers.RNN(cell)`, which is equivalent to this API")
@tf_export(v1=["nn.dynamic_rnn"])
@dispatch.add_dispatch_support
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=False,
                scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  Performs fully dynamic unrolling of `inputs`.

  Example:

  ```python
  # create a BasicRNNCell
  rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)

  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

  # defining initial state
  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

  # 'state' is a tensor of shape [batch_size, cell_state_size]
  outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data,
                                     initial_state=initial_state,
                                     dtype=tf.float32)
  ```

  ```python
  # create 2 LSTMCells
  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

  # create a RNN cell composed sequentially of a number of RNNCells
  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)

  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
  # 'state' is a N-tuple where N is the number of LSTMCells containing a
  # tf.nn.rnn_cell.LSTMStateTuple for each cell
  outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
                                     inputs=data,
                                     dtype=tf.float32)
  ```


  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.
      If `time_major == False` (default), this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements. This may also be
        a (possibly nested) tuple of Tensors satisfying this property.  The
        first two dimensions must match across all the inputs, but otherwise the
        ranks and other shape components may differ. In this case, input to
        `cell` at each time-step will replicate the structure of these tuples,
        except for the time dimension (from which the time is taken). The input
        to `cell` at each time step will be a `Tensor` or (possibly nested)
        tuple of Tensors each with dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
      to copy-through state and zero-out outputs when past a batch element's
      sequence length.  This parameter enables users to extract the last valid
      state and properly padded outputs, so it is provided for correctness.
      It is cast to int32 before use.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation.  However, most TensorFlow
      data is batch-major, so by default this function accepts input and emits
      output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`.

      If time_major == False (default), this will be a `Tensor` shaped:
        `[batch_size, max_time, cell.output_size]`.

      If time_major == True, this will be a `Tensor` shaped:
        `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`
      `state` will be a tuple containing a `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
    ValueError: If `sequence_length` has a statically known rank other than 1.
    ValueError: If `initial_state` is None and `dtype` is not specified.

  @compatibility(TF2)
  `tf.compat.v1.nn.dynamic_rnn` is not compatible with eager execution and
  `tf.function`. Please use `tf.keras.layers.RNN` instead for TF2 migration.
  Take LSTM as an example, you can instantiate a `tf.keras.layers.RNN` layer
  with `tf.keras.layers.LSTMCell`, or directly via `tf.keras.layers.LSTM`. Once
  the keras layer is created, you can get the output and states by calling
  the layer with input and states. Please refer to [this
  guide](https://www.tensorflow.org/guide/keras/rnn) for more details about
  Keras RNN. You can also find more details about the difference and comparison
  between Keras RNN and TF compat v1 rnn in [this
  document](https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md)

  #### Structural Mapping to Native TF2

  Before:

  ```python
  # create 2 LSTMCells
  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

  # create a RNN cell composed sequentially of a number of RNNCells
  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)

  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
  # 'state' is a N-tuple where N is the number of LSTMCells containing a
  # tf.nn.rnn_cell.LSTMStateTuple for each cell
  outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
                                               inputs=data,
                                               dtype=tf.float32)
  ```

  After:

  ```python
  # RNN layer can take a list of cells, which will then stack them together.
  # By default, keras RNN will only return the last timestep output and will not
  # return states. If you need whole time sequence output as well as the states,
  # you can set `return_sequences` and `return_state` to True.
  rnn_layer = tf.keras.layers.RNN([tf.keras.layers.LSTMCell(128),
                                   tf.keras.layers.LSTMCell(256)],
                                  return_sequences=True,
                                  return_state=True)
  outputs, output_states = rnn_layer(inputs, states)
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                             |
  | :-------------------- | :-------------- | :------------------------------- |
  | `cell`                | `cell`          | In the RNN layer constructor     |
  | `inputs`              | `inputs`        | In the RNN layer `__call__`      |
  | `sequence_length`     | Not used        | Adding masking layer before RNN  :
  :                       :                 : to achieve the same result.      :
  | `initial_state`       | `initial_state` | In the RNN layer `__call__`      |
  | `dtype`               | `dtype`         | In the RNN layer constructor     |
  | `parallel_iterations` | Not supported   |                                  |
  | `swap_memory`         | Not supported   |                                  |
  | `time_major`          | `time_major`    | In the RNN layer constructor     |
  | `scope`               | Not supported   |                                  |
  @end_compatibility
  """
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  with vs.variable_scope(scope or "rnn") as varscope:
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    # By default, time_major==False and inputs are batch-major: shaped
    #   [batch, time, depth]
    # For internal calculations, we transpose to [time, batch, depth]
    flat_input = nest.flatten(inputs)

    if not time_major:
      # (B,T,D) => (T,B,D)
      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

    # A falsy value (None or 0) selects the documented default of 32.
    parallel_iterations = parallel_iterations or 32
    if sequence_length is not None:
      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      # An unknown static rank (None) is accepted here.
      if sequence_length.get_shape().rank not in (None, 1):
        raise ValueError(
            f"Argument sequence_length must be a vector of length batch_size."
            f" Received sequence_length={sequence_length} of shape: "
            f"{sequence_length.get_shape()}")
      sequence_length = array_ops.identity(  # Just to find it in the graph.
          sequence_length,
          name="sequence_length")

    # Python int when statically known, otherwise a scalar Tensor.
    batch_size = _best_effort_input_batch_size(flat_input)

    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, argument `dtype` "
                         "must be specified")
      # Prefer the cell's `get_initial_state` when it has one; otherwise fall
      # back to `zero_state`.
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    def _assert_has_shape(x, shape):
      # Builds an Assert op that fails unless the dynamic shape of `x` equals
      # `shape` element-wise.
      x_shape = array_ops.shape(x)
      packed_shape = array_ops.stack(shape)
      return control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
              "Expected shape for Tensor %s is " % x.name, packed_shape,
              " but saw shape: ", x_shape
          ])

    if not context.executing_eagerly() and sequence_length is not None:
      # Perform some shape validation
      with ops.control_dependencies(
          [_assert_has_shape(sequence_length, [batch_size])]):
        sequence_length = array_ops.identity(
            sequence_length, name="CheckSeqLen")

    # Rebuild the caller's input structure from the (possibly transposed)
    # flat tensors.
    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    (outputs, final_state) = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    # If we are performing batch-major calculations, transpose output back
    # to shape [batch, time, depth]
    if not time_major:
      # (T,B,D) => (B,T,D)
      outputs = nest.map_structure(_transpose_batch_time, outputs)

    return (outputs, final_state)
def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
  """Internal implementation of Dynamic RNN.

  Builds a `while_loop` over the leading (time) dimension of `inputs`. Each
  iteration reads one time slice of every flattened input, calls `cell` on it
  (through `_rnn_step` when `sequence_length` is given), and records the
  flattened cell output for that step.

  In graph mode the inputs and outputs are held in `TensorArray`s and the loop
  stops at `min(time_steps, max(1, max(sequence_length)))`. When executing
  eagerly, the inputs are indexed directly, plain Python lists stand in for
  the output arrays, and the loop always runs for `time_steps` iterations.

  Args:
    cell: An instance of RNNCell.
    inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
      tuple of such elements.
    initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
      `cell.state_size` is a tuple, then this should be a tuple of tensors
      having shapes `[batch_size, s] for s in cell.state_size`.
    parallel_iterations: Positive Python int.
    swap_memory: A Python boolean.
    sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
    dtype: (optional) Expected dtype of output. If not specified, inferred from
      initial_state.

  Returns:
    Tuple `(final_outputs, final_state)`.
    final_outputs:
      A `Tensor` of shape `[time, batch_size, cell.output_size]`. If
      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
      objects, then this returns a (possibly nested) tuple of Tensors matching
      the corresponding shapes.
    final_state:
      A `Tensor`, or possibly nested tuple of Tensors, matching in length
      and shapes to `initial_state`.

  Raises:
    AssertionError: If `parallel_iterations` is not a Python int.
    ValueError: If the input depth cannot be inferred via shape inference
      from the inputs.
    ValueError: If time_step is not the same for all the elements in the
      inputs.
    ValueError: If batch_size is not the same for all the elements in the
      inputs.
  """
  state = initial_state
  assert isinstance(parallel_iterations, int), "parallel_iterations must be int"

  state_size = cell.state_size

  flat_input = nest.flatten(inputs)
  flat_output_size = nest.flatten(cell.output_size)

  # Dynamic time/batch sizes, taken from the first flattened input.
  input_shape = array_ops.shape(flat_input[0])
  time_steps = input_shape[0]
  batch_size = _best_effort_input_batch_size(flat_input)

  # Static shapes of every flattened input; each is required to have rank >= 3.
  inputs_got_shape = tuple(
      input_.get_shape().with_rank_at_least(3) for input_ in flat_input)

  # Static (possibly None) time and batch dimensions of the first input, used
  # as the reference that all other inputs must agree with.
  const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]

  for i, shape in enumerate(inputs_got_shape):
    # Everything past [time, batch] is the input depth and must be static.
    if not shape[2:].is_fully_defined():
      raise ValueError(
          "Input size (depth of inputs) must be accessible via shape inference,"
          f" but saw value None for input={flat_input[i]}.")
    got_time_steps = shape.dims[0].value
    got_batch_size = shape.dims[1].value
    if const_time_steps != got_time_steps:
      raise ValueError(
          "Time steps is not the same for all the elements in the input in a "
          f"batch. Received time steps={got_time_steps} for input="
          f"{flat_input[i]}.")
    if const_batch_size != got_batch_size:
      raise ValueError(
          "Batch_size is not the same for all the elements in the input. "
          f"Received batch size={got_batch_size} for input={flat_input[i]}.")

  # Prepare dynamic conditional copying of state & output
  def _create_zero_arrays(size):
    """Returns zeros shaped by `_concat(batch_size, size)` for one output."""
    size = _concat(batch_size, size)
    return array_ops.zeros(
        array_ops.stack(size), _infer_state_dtype(dtype, state))

  # One zero tensor per flattened output size, re-nested to match
  # `cell.output_size`; handed to `_rnn_step` as `zero_output`.
  flat_zero_output = tuple(
      _create_zero_arrays(output) for output in flat_output_size)
  zero_output = nest.pack_sequence_as(
      structure=cell.output_size, flat_sequence=flat_zero_output)

  # `min_sequence_length` is only defined when `sequence_length` is given; it
  # is only read in `_time_step` under that same condition.
  if sequence_length is not None:
    min_sequence_length = math_ops.reduce_min(sequence_length)
    max_sequence_length = math_ops.reduce_max(sequence_length)
  else:
    max_sequence_length = time_steps

  # Loop counter, starting at step 0.
  time = array_ops.constant(0, dtype=dtypes.int32, name="time")

  # Capture the scope name so TensorArray names can be prefixed with it.
  with ops.name_scope("dynamic_rnn") as scope:
    base_name = scope

  def _create_ta(name, element_shape, dtype):
    """Creates a `TensorArray` with `time_steps` slots named `base_name + name`."""
    return tensor_array_ops.TensorArray(
        dtype=dtype,
        size=time_steps,
        element_shape=element_shape,
        tensor_array_name=base_name + name)

  in_graph_mode = not context.executing_eagerly()
  if in_graph_mode:
    # One output TensorArray per flattened output size; each element has the
    # static batch dimension followed by that output's size.
    output_ta = tuple(
        _create_ta(
            "output_%d" % i,
            element_shape=(
                tensor_shape.TensorShape([const_batch_size]).concatenate(
                    _maybe_tensor_shape_from_tensor(out_size))),
            dtype=_infer_state_dtype(dtype, state))
        for i, out_size in enumerate(flat_output_size))
    # One input TensorArray per flattened input; each element is a single
    # time slice (the input's shape without its leading dimension).
    input_ta = tuple(
        _create_ta(
            "input_%d" % i,
            element_shape=flat_input_i.shape[1:],
            dtype=flat_input_i.dtype)
        for i, flat_input_i in enumerate(flat_input))
    # Split every input along its leading (time) dimension into its array.
    input_ta = tuple(
        ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input))
  else:
    # Eager mode: a Python list of `time_steps` placeholder zeros per flat
    # output (entries are overwritten in `_time_step`); inputs are indexed
    # directly instead of going through TensorArrays.
    output_ta = tuple([0 for _ in range(time_steps.numpy())]
                      for i in range(len(flat_output_size)))
    input_ta = flat_input

  def _time_step(time, output_ta_t, state):
    """Take a time step of the dynamic RNN.

    Args:
      time: int32 scalar Tensor.
      output_ta_t: Tuple of `TensorArray`s (graph mode) or Python lists (eager
        mode) that hold the outputs.
      state: nested tuple of vector tensors that represent the state.

    Returns:
      The tuple (time + 1, output_ta_t with updated flow, new_state).
    """

    if in_graph_mode:
      input_t = tuple(ta.read(time) for ta in input_ta)
      # Restore some shape information
      for input_, shape in zip(input_t, inputs_got_shape):
        input_.set_shape(shape[1:])
    else:
      input_t = tuple(ta[time.numpy()] for ta in input_ta)

    # Re-nest the per-step slices to match the structure of `inputs`.
    input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
    # Wrap the cell invocation in a thunk: it is handed to `_rnn_step` as
    # `call_cell` when sequence lengths are given, and called directly
    # otherwise.
    call_cell = lambda: cell(input_t, state)

    if sequence_length is not None:
      (output, new_state) = _rnn_step(
          time=time,
          sequence_length=sequence_length,
          min_sequence_length=min_sequence_length,
          max_sequence_length=max_sequence_length,
          zero_output=zero_output,
          state=state,
          call_cell=call_cell,
          state_size=state_size,
          skip_conditionals=True)
    else:
      (output, new_state) = call_cell()

    # Flatten the (possibly nested) cell output so that it lines up
    # one-to-one with the flat tuple of output arrays.
    output = nest.flatten(output)

    if in_graph_mode:
      output_ta_t = tuple(
          ta.write(time, out) for ta, out in zip(output_ta_t, output))
    else:
      # Eager mode: store the step output in place in the Python lists.
      for ta, out in zip(output_ta_t, output):
        ta[time.numpy()] = out

    return (time + 1, output_ta_t, new_state)

  if in_graph_mode:
    # Make sure that we run at least 1 step, if necessary, to ensure
    # the TensorArrays pick up the dynamic shape.
    loop_bound = math_ops.minimum(time_steps,
                                  math_ops.maximum(1, max_sequence_length))
  else:
    # Using max_sequence_length isn't currently supported in the Eager branch.
    loop_bound = time_steps

  # Run `_time_step` while `time < loop_bound`, capped at `time_steps`
  # iterations; only the output arrays and the final state are kept.
  _, output_final_ta, final_state = control_flow_ops.while_loop(
      cond=lambda time, *_: time < loop_bound,
      body=_time_step,
      loop_vars=(time, output_ta, state),
      parallel_iterations=parallel_iterations,
      maximum_iterations=time_steps,
      swap_memory=swap_memory)

  # Turn the per-step output storage back into whole-sequence outputs.
  if in_graph_mode:
    final_outputs = tuple(ta.stack() for ta in output_final_ta)
    # Restore some shape information
    for output, output_size in zip(final_outputs, flat_output_size):
      shape = _concat([const_time_steps, const_batch_size],
                      output_size,
                      static=True)
      output.set_shape(shape)
  else:
    final_outputs = output_final_ta

  # Re-nest the flat outputs to match the structure of `cell.output_size`.
  final_outputs = nest.pack_sequence_as(
      structure=cell.output_size, flat_sequence=final_outputs)
  if not in_graph_mode:
    # Eager mode: each output is still a Python list of per-step values;
    # stack each list along a new leading axis.
    final_outputs = nest.map_structure_up_to(
        cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)

  return (final_outputs, final_state)
@tf_export(v1=["nn.raw_rnn"])
@dispatch.add_dispatch_support
def raw_rnn(cell,
loop_fn,
parallel_iterations=None,
swap_memory=False,
scope=None):
"""Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.
**NOTE: This method is still in testing, and the API may change.**
This function is a more primitive version of `dynamic_rnn` that provides
more direct access to the inputs each iteration. It also provides more
control over when to start and finish reading the sequence, and
what to emit for the output.