/
dynamic_ragged_shape.py
3292 lines (2767 loc) · 121 KB
/
dynamic_ragged_shape.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shapes & broadcasting for RaggedTensors.
TODO(martinz): make this suitable for output for tf.shape
TODO(martinz): replace ragged_tensor_shape with this.
"""
import abc
from typing import Any, Iterable, Optional, Sequence, Tuple, Union
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import extension_type
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.ops.ragged.row_partition import RowPartitionSpec
from tensorflow.python.types import core
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
class _DynamicRaggedShapeBatchEncoder(extension_type.ExtensionTypeBatchEncoder):
  """A batch encoder for DynamicRaggedShape below."""

  def batch(self, spec: "DynamicRaggedShape.Spec",
            batch_size) -> "DynamicRaggedShape.Spec":
    """Returns the spec of a batch of `batch_size` shapes described by `spec`.

    Args:
      spec: the spec of a single (unbatched) shape.
      batch_size: the number of shapes in the batch, or None if unknown.

    Returns:
      A `DynamicRaggedShape.Spec` with one additional leading dimension.
    """
    if spec.num_row_partitions:
      # Add a head partition for the new batch axis, and batch every existing
      # row partition.
      new_head = _batch_rp_spec_head(spec._row_partitions[0], batch_size)  # pylint:disable=protected-access
      new_tail = [_batch_rp_spec(rp, batch_size) for rp in spec._row_partitions]  # pylint:disable=protected-access
      new_rp = [new_head] + new_tail
      new_static_inner_shape = _batch_static_inner_shape(
          spec._static_inner_shape, batch_size)  # pylint:disable=protected-access
      return DynamicRaggedShape.Spec(
          row_partitions=new_rp,
          static_inner_shape=new_static_inner_shape,
          dtype=spec.dtype)
    elif batch_size is None:
      if spec.inner_rank == 0:
        # A batch of scalars of unknown size: the spec of the dense shape
        # [None].
        return DynamicRaggedShape.Spec._from_tensor_shape(  # pylint:disable=protected-access
            [None],
            0,
            dtype=spec.dtype)
      else:
        # The batch axis becomes a uniform row partition whose row length is
        # the original first dimension. That dimension might be None.
        new_head = RowPartitionSpec(
            uniform_row_length=spec._dimension(0),  # pylint:disable=protected-access
            dtype=spec.dtype)
        new_static_inner_shape = _batch_static_inner_shape(
            spec._static_inner_shape, batch_size)  # pylint:disable=protected-access
        return DynamicRaggedShape.Spec(
            row_partitions=[new_head],
            static_inner_shape=new_static_inner_shape,
            dtype=spec.dtype)
    else:
      # A dense spec with a known batch size stays dense (no row partitions).
      return DynamicRaggedShape.Spec(
          row_partitions=[],
          static_inner_shape=_batch_tensor_shape(
              spec._static_inner_shape,  # pylint:disable=protected-access
              batch_size),
          dtype=spec.dtype)

  def unbatch(self,
              spec: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape.Spec":
    """Returns the spec of one element of a batch described by `spec`.

    The first row partition (the batch axis) is removed. While the partitions
    stay uniform, the static nrows/nvals of the remaining partitions are
    divided by `scale`: the head's nrows, provided the head is uniform. Once a
    partition without a uniform row length is seen, the scale is unknown and
    later partitions keep only their uniform_row_length.

    Args:
      spec: the spec of a batch of shapes.

    Returns:
      A `DynamicRaggedShape.Spec` with the leading dimension removed.
    """
    if spec.num_row_partitions:
      result = []
      head = spec._row_partitions[0]  # pylint:disable=protected-access
      scale = None if head.uniform_row_length is None else head.nrows
      for rp in spec._row_partitions[1:]:  # pylint:disable=protected-access
        if scale is None:
          result.append(
              RowPartitionSpec(
                  nrows=None,
                  nvals=None,
                  uniform_row_length=rp.uniform_row_length,
                  dtype=spec.dtype))
        else:
          nrows = None if rp.nrows is None else rp.nrows // scale
          if rp.uniform_row_length is None:
            scale = None
            result.append(
                RowPartitionSpec(
                    nrows=nrows,
                    nvals=None,
                    uniform_row_length=None,
                    dtype=spec.dtype))
          else:
            # Like nrows above, nvals may be statically unknown even when the
            # row length is uniform; `None // scale` would raise TypeError.
            nvals = None if rp.nvals is None else rp.nvals // scale
            result.append(
                RowPartitionSpec(
                    nrows=nrows,
                    nvals=nvals,
                    uniform_row_length=rp.uniform_row_length,
                    dtype=spec.dtype))
      return DynamicRaggedShape.Spec(
          row_partitions=result,
          static_inner_shape=_unbatch_static_inner_shape(
              spec._static_inner_shape, scale),  # pylint:disable=protected-access
          dtype=spec.dtype)
    else:  # spec.num_row_partitions == 0
      return DynamicRaggedShape.Spec(
          row_partitions=[],
          static_inner_shape=spec._static_inner_shape[1:],  # pylint:disable=protected-access
          dtype=spec.dtype)

  def decode(self, spec: "DynamicRaggedShape.Spec",
             encoding) -> "DynamicRaggedShape":
    """Rebuilds a shape from its encoding: the shape of `encoding`."""
    return DynamicRaggedShape.from_tensor(encoding, dtype=spec.dtype)

  def encode(
      self,
      spec: "DynamicRaggedShape.Spec",
      value,
      minimum_rank=0) -> Union[ragged_tensor.RaggedTensor, tensor_lib.Tensor]:
    """Encodes `value` as a bool (possibly ragged) tensor with that shape.

    `minimum_rank` is accepted for interface compatibility and is unused.
    """
    return ones(value, dtype=dtypes.bool)

  def encoding_specs(
      self, spec: "DynamicRaggedShape.Spec"
  ) -> Union[ragged_tensor.RaggedTensorSpec, tensor_lib.TensorSpec]:
    """Returns the spec of the bool tensor produced by `encode`."""
    if spec.rank != 0:
      ragged_rank = spec.num_row_partitions
    else:
      # special case: need to unbatch twice to get ragged tensor.
      ragged_rank = -1
    return ragged_tensor.RaggedTensorSpec(
        shape=spec._to_tensor_shape(),  # pylint:disable=protected-access
        dtype=dtypes.bool,
        ragged_rank=ragged_rank,
        row_splits_dtype=spec.dtype)
# TODO(martinz): allow inner_shape to be a fully defined TensorShape.
# A "fully defined TensorShape" means one where the rank and all dimensions are
# known.
# Allowing inner_shape might mean allowing inner_shape to be initialized by
# a fully defined TensorShape, or it might mean that you can actually store
# TensorShape in the inner_shape field. This could conceivably construct
# a DynamicRaggedShape that was dtype agnostic.
#
# TODO(martinz): unify the impl of the determination of index type across
# RowPartition and DynamicRaggedShape.
@tf_export("experimental.DynamicRaggedShape")
class DynamicRaggedShape(extension_type.BatchableExtensionType):
"""The shape of a ragged or dense tensor.
Ragged shapes are encoded using two fields:
* `inner_shape`: An integer vector giving the shape of a dense tensor.
* `row_partitions`: A list of `RowPartition` objects, describing how
that flat shape should be partitioned to add ragged axes.
If a DynamicRaggedShape is the shape of a RaggedTensor rt, then:
1. row_partitions = rt._nested_row_partitions
(and thus len(row_partitions) > 0)
2. inner_shape is the shape of rt.flat_values
If a DynamicRaggedShape is the shape of a dense tensor t, then:
1. row_partitions = []
2. inner_shape is the shape of t.
Examples:
The following table gives a few examples (where `RP(lengths)` is short
for `RowPartition.from_lengths(lengths)`):
Row Partitions | Inner Shape | Example Tensor
--------------------------- | ------------ | ----------------------------
[] | [2, 3] | `[[1, 2, 3], [4, 5, 6]]`
[RP([2, 0, 3])] | [5] | `[[1, 2], [], [3, 4, 5]]`
[RP([2, 1])] | [3, 2] | `[[[1, 2], [3, 4]], [[5, 6]]]`
[RP([2, 1]), RP([2, 1, 2])] | [5] | `[[[1, 2], [3]], [[4, 5]]]`
"""
_row_partitions: Tuple[RowPartition, ...]
_inner_shape: tensor_lib.Tensor
_static_inner_shape: tensor_shape.TensorShape
__batch_encoder__ = _DynamicRaggedShapeBatchEncoder()
__name__ = "tf.DynamicRaggedShape"
def __init__(self,
             row_partitions: Sequence[RowPartition],
             inner_shape: core.TensorLike,
             dtype: Optional[dtypes.DType] = None,
             validate: bool = False,
             static_inner_shape: ... = None):
  """Core constructor for a DynamicRaggedShape.

  Create a DynamicRaggedShape. This can be used to construct a
  DynamicRaggedShape representing a ragged or dense shape. If row_partitions
  is an empty list, then this is equivalent to a dense shape.

  If row_partitions is specified, then the num_row_partitions will be equal
  to len(row_partitions). There are several checks made.
  Specifically:
  1. Consecutive row_partitions must have consistent nvals and nrows.
  2. The last row_partitions must have nvals equal to the first element of
     inner_shape.

  Mismatches that are statically known raise a ValueError immediately.
  Otherwise, if `validate` is true, runtime assertions are attached to
  inner_shape and to each row partition.

  The inner_shape is converted to a tensor.
  All row_partitions and the inner_shape are converted to the same dtype
  (int64 or int32).

  Args:
    row_partitions: the row_partitions of the shape.
    inner_shape: if len(row_partitions) > 0, the shape of the flat_values.
      Otherwise, the shape of the tensor.
    dtype: tf.int64, tf.int32, or None representing the preferred dtype.
    validate: if true, dynamic validation is applied to the shape.
    static_inner_shape: if len(row_partitions) > 0, the static shape of the
      flat_values. Otherwise, the static shape of the tensor. Should be
      convertible to a TensorShape.

  Raises:
    TypeError: if row_partitions is not iterable or contains a non-RowPartition.
    ValueError: if statically known sizes are inconsistent, or inner_shape
      is not a vector.
  """
  if not isinstance(row_partitions, Iterable):
    raise TypeError(
        "row_partitions should be a list of row partitions. Instead, got " +
        str(row_partitions))
  for x in row_partitions:
    if not isinstance(x, RowPartition):
      raise TypeError("row_partitions contains " + str(x) +
                      " which is not a RowPartition")
  # Dtype resolution: an explicit dtype wins; otherwise take it from the row
  # partitions, then from inner_shape, then an int32 numpy array, and finally
  # fall back to int64.
  dtype = _find_dtype_iterable(row_partitions, dtype)
  dtype = _find_dtype(inner_shape, dtype)
  if (isinstance(inner_shape, np.ndarray) and
      inner_shape.dtype == np.int32 and dtype is None):
    dtype = dtypes.int32
  dtype = _find_dtype(dtypes.int64, dtype)

  row_partitions = tuple([rp.with_dtype(dtype) for rp in row_partitions])
  self._row_partitions = row_partitions
  self._inner_shape = ops.convert_to_tensor(
      inner_shape, dtype_hint=dtype, name="inner_dim_sizes")
  # dtype_hint is only a hint, so cast if the conversion chose another dtype.
  if self._inner_shape.dtype != dtype:
    self._inner_shape = math_ops.cast(self._inner_shape, dtype)

  checks = []
  # Validate shapes.
  if self._row_partitions:
    for axis, rp in enumerate(self._row_partitions):
      if axis > 0:
        # nvals of partition axis-1 must equal nrows of partition axis.
        previous_row_partition = self._row_partitions[axis - 1]
        msg = ("RowPartitions in DynamicRaggedShape do not align "
               f"between {axis - 1} and {axis}")
        static_nrows = rp.static_nrows
        static_nvals = previous_row_partition.static_nvals
        if (static_nrows is not None) and (static_nvals is not None):
          if static_nrows != static_nvals:
            raise ValueError(msg)
          else:
            # Statically verified; no runtime check needed.
            continue
        if validate:
          checks.append(
              check_ops.assert_equal(
                  previous_row_partition.nvals(), rp.nrows(), message=msg))

  self._inner_shape.shape.assert_has_rank(1)

  # Static knowledge comes from the (possibly constant) inner_shape tensor,
  # refined by the caller-provided static_inner_shape.
  self._static_inner_shape = tensor_util.constant_value_as_shape(
      self._inner_shape)
  if static_inner_shape is not None:
    self._static_inner_shape = self._static_inner_shape.merge_with(
        static_inner_shape)

  if row_partitions:
    # The last partition's nvals must equal the first inner dimension.
    last_row_partition = row_partitions[-1]
    static_nvals = last_row_partition.static_nvals
    static_inner_shape_nvals = tensor_shape.dimension_value(
        self._static_inner_shape[0])
    if static_nvals is not None and static_inner_shape_nvals is not None:
      if static_nvals != static_inner_shape_nvals:
        raise ValueError("Last row partition does not match inner_shape.")
    elif validate:
      checks.append(
          check_ops.assert_equal(
              last_row_partition.nvals(),
              self._inner_shape[0],
              message="Last row partition does not match inner_shape."))
  if checks:
    # Make every component depend on the runtime assertions so they execute.
    self._inner_shape = control_flow_ops.with_dependencies(
        checks, self._inner_shape, name="inner_shape_validated")
    self._row_partitions = [
        rp._with_dependencies(checks) for rp in self._row_partitions  # pylint: disable=protected-access
    ]
@classmethod
def from_lengths(cls,
                 lengths: Sequence[Union[Sequence[int], int]],
                 num_row_partitions=None,
                 dtype=dtypes.int64):
  """Creates a shape with the given lengths and num_row_partitions.

  Each element of lengths is either a nonnegative int (a uniform dimension)
  or a tuple of nonnegative ints (the row lengths of a ragged dimension).

  If num_row_partitions is None, then the minimal num_row_partitions is used.

  For example, [2, (3, 2)] is the shape of [[0, 0, 0], [0, 0]], and
  [2, 2] is the shape of [[0, 0], [0, 0]]

  This chooses the minimal num_row_partitions required (including zero).

  The following table gives a few examples (where `RP(lengths)` is short
  for `RowPartition.from_lengths(lengths)`):

  from_lengths           | row_partitions              | inner_shape
  ---------------------- | --------------------------- | -------------
  []                     | []                          | []
  [2, (3, 2)]            | [RP([3, 2])]                | [5]
  [2, 2]                 | []                          | [2, 2]
  [2, (3, 2), 7]         | [RP([3, 2])]                | [5, 7]
  [2, (2, 2), 3]         | [RP([2, 2])]                | [4, 3]
  [2, 2, 3]              | []                          | [2, 2, 3]
  [2, (2, 1), (2, 0, 3)] | [RP([2, 1]), RP([2, 0, 3])] | [5]

  If we want the row partitions to end with uniform row partitions, then
  we can set num_row_partitions.

  For example,
  below URP(3, 12) is RowPartition.from_uniform_row_length(3, 12)

  from_lengths   | num_row_partitions | row_partitions           | inner_shape
  ---------------| -------------------|--------------------------|------------
  [2, (3, 2), 2] | 2                  | [RP([3, 2]), URP(2, 10)] | [10]
  [2, 2]         | 1                  | [URP(2, 4)]              | [4]
  [2, 2, 3]      | 0                  | []                       | [2, 2, 3]
  [2, 2, 3]      | 1                  | [URP(2, 4)]              | [4, 3]
  [2, 2, 3]      | 2                  | [URP(2, 4), URP(3, 12)]  | [12]

  Representing the shapes from init():

  from_lengths             | Tensor Example
  ------------------------ | ------------------------------
  `[2, 3]`                 | `[[1, 2, 3], [4, 5, 6]]`
  `[3, (2, 0, 3)]`         | `[[1, 2], [], [3, 4, 5]]`
  `[2, (2, 1), 2]`         | `[[[1, 2], [3, 4]], [[5, 6]]]`
  `[2, (2, 1), (2, 1, 2)]` | `[[[1, 2], [3]], [[4, 5]]]`

  Args:
    lengths: the lengths of sublists along each axis.
    num_row_partitions: the num_row_partitions of the result or None
      indicating the minimum number of row_partitions.
    dtype: the dtype of the shape (tf.int32 or tf.int64).

  Returns:
    a new DynamicRaggedShape

  Raises:
    ValueError: if lengths is not a list of ints / tuples of ints, or if
      num_row_partitions is not a nonnegative int less than len(lengths)
      (it must be 0 for a scalar shape).
  """
  if not isinstance(lengths, list):
    raise ValueError("lengths should be a list")
  for x in lengths:
    if not _is_int_or_tuple_of_ints(x):
      raise ValueError(
          "element of lengths should be int or tuple of ints: instead %r" %
          (x,))

  if num_row_partitions is None:
    # The minimal num_row_partitions is the index of the last ragged (tuple)
    # axis, or 0 when every axis is uniform.
    ragged_axes = [i for i, x in enumerate(lengths) if not isinstance(x, int)]
    num_row_partitions = ragged_axes[-1] if ragged_axes else 0

  if not isinstance(num_row_partitions, int):
    raise ValueError("num_row_partitions should be an int or None")
  if num_row_partitions < 0:
    # Previously a negative value was silently treated as 0.
    raise ValueError("num_row_partitions should be nonnegative")

  if not lengths:
    if num_row_partitions > 0:
      raise ValueError("num_row_partitions==0 for a scalar shape")
    return DynamicRaggedShape([], [], dtype=dtype)

  if not num_row_partitions < len(lengths):
    raise ValueError("num_row_partitions should be less than `len(lengths)` "
                     "if shape is not scalar.")

  if num_row_partitions > 0:
    (row_partitions, nvals) = _to_row_partitions_and_nvals_from_lengths(
        lengths[:num_row_partitions + 1])
    inner_shape = [nvals] + lengths[num_row_partitions + 1:]
    return DynamicRaggedShape(row_partitions, inner_shape, dtype=dtype)
  else:
    return DynamicRaggedShape([], lengths, dtype=dtype)
@classmethod
def from_row_partitions(cls, row_partitions, dtype=None):
  """Builds a shape whose inner_shape is the nvals of the last partition.

  Args:
    row_partitions: a nonempty list of RowPartition objects.
    dtype: the dtype to use, or None to use the row_partitions dtype.

  Returns:
    a DynamicRaggedShape with inner_rank==1.

  Raises:
    ValueError: if row_partitions is empty.
  """
  if not row_partitions:
    raise ValueError("row_partitions cannot be empty")
  innermost = row_partitions[-1]
  return DynamicRaggedShape(row_partitions, [innermost.nvals()], dtype=dtype)
@classmethod
def _from_inner_shape(cls, inner_shape, dtype=None):
  """Builds a dense shape (no row partitions) from `inner_shape`."""
  no_partitions = []
  return DynamicRaggedShape(no_partitions, inner_shape, dtype=dtype)
# pylint: disable=protected-access
@classmethod
def from_tensor(cls, t, dtype=None):
  """Returns the (possibly ragged) shape of `t`.

  Dense inputs give a shape with no row partitions. Ragged inputs use their
  nested row partitions plus the shape of their flat values.
  """
  if not ragged_tensor.is_ragged(t):
    return DynamicRaggedShape._from_inner_shape(
        array_ops.shape(t), dtype=dtype)
  partitions = t._nested_row_partitions  # pylint: disable=protected-access
  return DynamicRaggedShape(partitions, _flat_values_shape(t), dtype=dtype)
@property
def row_partitions(self):
  """The RowPartitions that encode this shape's ragged axes."""
  partitions = self._row_partitions
  return partitions
@property
def num_row_partitions(self):
  """How many RowPartitions this shape holds."""
  partitions = self._row_partitions
  return len(partitions)
@property
def dtype(self):
  """The integer dtype (tf.int32 or tf.int64) shared by all components."""
  inner = self._inner_shape
  return inner.dtype
def _static_inner_shape_as_list(self, truncate_first):
  """Returns the static inner dims as a list, or [...] if rank is unknown.

  Args:
    truncate_first: if true, the first inner dimension is dropped.
  """
  static_shape = self._static_inner_shape
  if static_shape.rank is None:
    return [...]
  dims = static_shape.as_list()
  return dims[1:] if truncate_first else dims
def static_lengths(self, ragged_lengths=True):
  """Returns a list of statically known axis lengths.

  The first entry is the static nrows of the first row partition. Each row
  partition then contributes one entry:
  * its static uniform row length, if the partition is uniform;
  * otherwise, a tuple of its constant row lengths, provided `ragged_lengths`
    is true and those lengths are statically known;
  * otherwise None.

  The inner shape contributes each known dimension, or None for unknown ones.
  If the inner rank is unknown, the list ends with an Ellipsis instead.

  Args:
    ragged_lengths: If false, returns None for all ragged dimensions.

  Returns:
    A Sequence[Union[Sequence[int],int, None]] of lengths, with a possible
    Ellipsis at the end.
  """
  if not self.num_row_partitions:
    return self._static_inner_shape_as_list(False)

  outer = self.row_partitions[0].static_nrows
  if isinstance(outer, tensor_shape.Dimension):
    outer = outer.value
  dims = [outer]
  for partition in self.row_partitions:
    if partition.is_uniform():
      dims.append(partition.static_uniform_row_length)
      continue
    if not ragged_lengths:
      dims.append(None)
      continue
    known = tensor_util.constant_value(partition.row_lengths())
    dims.append(None if known is None else tuple(known.tolist()))
  return dims + self._static_inner_shape_as_list(True)
def __repr__(self):
  """Returns a summary showing the static lengths and num_row_partitions."""
  lengths_str = _list_with_ellipsis_to_str(self.static_lengths())
  template = "<DynamicRaggedShape lengths=%s num_row_partitions=%r>"
  return template % (lengths_str, self.num_row_partitions)
def _to_tensor_shape(self) -> tensor_shape.TensorShape:
  """Returns a TensorShape with ragged axes reported as None.

  Returns an unknown-rank TensorShape if the inner rank is unknown.
  """
  dims = self.static_lengths(ragged_lengths=False)
  if dims and dims[-1] == Ellipsis:
    return tensor_shape.TensorShape(None)
  return tensor_shape.TensorShape(dims if dims else ())
def _slice_shape(self, start, stop):
  """Returns a shape self[start:stop].

  If start == 0, then this truncates dimensions after stop.
  If start != 0, then this will return a shape with num_row_partitions == 0.

  See __getitem__.

  Args:
    start: the first dimension. 0 <= start <= rank
    stop: the last dimension (exclusive). 0 <= stop <= rank

  Returns:
    A DynamicRaggedShape with the same dtype as self.

  Raises:
    ValueError: if start > 0 and some row partition before stop is ragged.
    NotImplementedError: if start > 0 and the rank is unknown.
  """
  if stop <= start:
    # Empty slice -> scalar shape. Pass self.dtype explicitly: otherwise an
    # empty list falls back to int64 and an int32 shape would change dtype.
    return DynamicRaggedShape._from_inner_shape([], dtype=self.dtype)
  elif start == 0:
    if stop <= self.num_row_partitions:
      if stop == 1:
        # A list containing a tensor yields no dtype preference, so pass
        # self.dtype to avoid a cast to int64.
        return DynamicRaggedShape._from_inner_shape(
            [self.row_partitions[0].nrows()], dtype=self.dtype)
      new_row_partitions = self.row_partitions[:stop - 1]
      new_inner_shape = [new_row_partitions[-1].nvals()]
      return DynamicRaggedShape(new_row_partitions, new_inner_shape)
    else:
      if self.rank is None:
        new_inner_rank = stop - self.num_row_partitions
        new_inner_shape = self.inner_shape[:new_inner_rank]
        return DynamicRaggedShape(
            row_partitions=self.row_partitions,
            inner_shape=new_inner_shape,
            static_inner_shape=None,
            validate=False)
      elif self.rank <= stop:
        return self
      new_inner_rank = stop - self.num_row_partitions
      new_inner_shape = self.inner_shape[:new_inner_rank]
      # The rank is known here, so keep whatever static dims are already
      # known instead of discarding them.
      return DynamicRaggedShape(
          row_partitions=self.row_partitions,
          inner_shape=new_inner_shape,
          static_inner_shape=self._static_inner_shape[:new_inner_rank],
          validate=False)
  else:
    if self.rank is None or stop < self.rank:
      partial = self._slice_shape(0, stop)
    else:
      partial = self

    # Dropping leading axes is only representable if every partitioned axis
    # is uniform, so the result can be expressed as a dense inner shape.
    for x in partial.row_partitions:
      if not x.is_uniform():
        raise ValueError("All relevant dimensions must be uniform")
    if partial.rank is None:
      # TODO(martinz): Implement _with_num_row_partitions(0) if rank is
      # unknown, and remove.
      raise NotImplementedError(
          "__getitem__[start:stop] where start > 0 not implemented")

    return DynamicRaggedShape._from_inner_shape(
        partial._with_num_row_partitions(0).inner_shape[start:])
def _dimension(self, index):
  """Return a dimension, if the dimension is not ragged (see __getitem__).

  When the size is statically known, it is returned as a constant of this
  shape's dtype. Otherwise it is read from the row partitions or inner_shape.

  Args:
    index: a nonnegative int axis.

  Returns:
    A scalar tensor: the size of axis `index`.

  Raises:
    TypeError: if index is not an int.
    ValueError: if the rank is needed but unknown, or the axis is ragged.
    IndexError: if index is negative or (when checked) >= rank.
  """
  rank = self.rank
  if not isinstance(index, int):
    raise TypeError("index should be an int")
  if (self.num_row_partitions == 0 or index > self.num_row_partitions + 1):
    # If num_row_partitions > 0 and index <= num_row_partitions + 1, then
    # we are safe.
    # NOTE(review): for index == num_row_partitions + 1 with unknown rank,
    # this falls through to a dynamic inner_shape[1] lookup even if the
    # inner rank is 1 -- confirm that is intended.
    if rank is None:
      raise ValueError(
          "Rank must be known to use __getitem__ on a large index.")
    if index >= rank:
      raise IndexError("Index is too big: " + str(index) + ">=" + str(rank))
  if index < 0:
    raise IndexError("Index must be non-negative: " + str(index))
  elif not self.is_uniform(index):
    raise ValueError("Index " + str(index) + " is not uniform")
  elif index == 0 and self.num_row_partitions > 0:
    # Axis 0 of a ragged shape is the nrows of the first partition.
    static_nrows = self.row_partitions[0].static_nrows
    if static_nrows is not None:
      return constant_op.constant(static_nrows, dtype=self.dtype)
    return self.row_partitions[0].nrows()
  elif self.num_row_partitions == 0:
    # Dense shape: read directly from inner_shape.
    static_result = tensor_shape.dimension_value(
        self._static_inner_shape[index])
    if static_result is not None:
      return constant_op.constant(static_result, dtype=self.dtype)
    return self.inner_shape[index]
  elif index > self.num_row_partitions:
    # Axes past the partitions map to inner_shape[index - num_row_partitions].
    static_result = tensor_shape.dimension_value(
        self._static_inner_shape[index - self.num_row_partitions])
    if static_result is not None:
      return constant_op.constant(static_result, dtype=self.dtype)
    return self.inner_shape[index - self.num_row_partitions]
  else:
    # 0 < index <= num_row_partitions, and is_uniform(index) held above.
    return self.row_partitions[index - 1].uniform_row_length()
def __getitem__(self, index):
  """Returns a dimension or a slice of the shape.

  Ragged shapes can have ragged dimensions that depend upon other dimensions.
  Therefore, if you ask for a dimension that is ragged, this function raises
  a ValueError. For similar reasons, if a slice is selected that includes
  a ragged dimension without including the zero dimension, then this fails.

  Any slice that does not start at zero will return a shape
  with num_row_partitions == 0.

  Args:
    index: the index: can be an int or a slice.

  Returns:
    A scalar tensor for an int index, or a DynamicRaggedShape for a slice.

  Raises:
    IndexError: if the index is not in range, or the slice has a step other
      than 1.
    ValueError: if the rank is unknown, or a ragged rank is requested
      incorrectly.
    TypeError: if index is neither an int nor a slice.
  """
  rank = self.rank
  if isinstance(index, slice):
    step = index.step
    if (step is not None) and (step != 1):
      raise IndexError("Cannot stride through a shape")
    first = 0 if index.start is None else index.start
    first = _fix_start_index(first, rank, self.num_row_partitions)
    last = _fix_stop_index(index.stop, rank)
    return self._slice_shape(first, last)
  if not isinstance(index, int):
    raise TypeError("Argument is not an int or a slice")
  if index >= 0:
    return self._dimension(index)
  if rank is None:
    raise ValueError(
        "Rank must be known to use __getitem__ with a negative index.")
  return self._dimension(rank + index)
def _num_elements(self):
  """Returns the total number of elements described by this shape.

  Returns:
    A scalar tensor: the product of the entries of `inner_shape`.
  """
  flat_dims = self.inner_shape
  return math_ops.reduce_prod(flat_dims)
def _num_slices_in_dimension(self, axis):
  """The total size of a dimension (like nvals).

  Effectively, this is self[:axis+1]._num_elements()

  Example:
  shape = DynamicRaggedShape._from_inner_shape([2, 3, 4])
  shape._num_slices_in_dimension(0) = 2
  shape._num_slices_in_dimension(1) = 6
  shape._num_slices_in_dimension(2) = 24
  shape._num_slices_in_dimension(-1) = 24
  shape._num_slices_in_dimension(-2) = 6
  shape._num_slices_in_dimension(-3) = 2

  Args:
    axis: the last axis to include in the number of elements. If negative,
      then axis = axis + rank.

  Returns:
    The number of elements in the shape.

  Raises:
    TypeError: if axis is not an int.
    ValueError: if axis is negative and the rank is unknown.
    IndexError: if axis is still negative after adding the rank.
  """
  if not isinstance(axis, int):
    raise TypeError("axis must be an integer")
  if axis < 0:
    rank = self.rank
    if rank is None:
      raise ValueError(
          "You can't use negative values if the rank is undefined")
    axis = axis + rank
    if axis < 0:
      # Without this check, a still-negative axis would index row_partitions
      # or inner_shape from the end and silently return a wrong size.
      raise IndexError("axis=%d is out of range for rank=%d" %
                       (axis - rank, rank))
  if axis == 0:
    return self._dimension(0)
  if axis <= self.num_row_partitions:
    return self.row_partitions[axis - 1].nvals()
  # If self.num_row_partitions = 1, and
  # self.inner_shape=[3,5,6], and axis=2, then you want:
  # 15 = 3 * 5 = math_ops.reduce_prod(self.inner_shape[:2])
  # 2 = axis - (self.num_row_partitions - 1)
  # If num_row_partitions=0, and
  # self.inner_shape=[3,5,6] and axis=2, then you want:
  # 90 = 3 * 5 * 6 = math_ops.reduce_prod(self.inner_shape[:3])
  # 3 = axis - (self.num_row_partitions - 1)
  remainder = axis - (self.num_row_partitions - 1)
  return _reduce_prod_patch(self.inner_shape[:remainder])
def is_uniform(self, axis):
  """Returns true if the indicated dimension is uniform.

  Axis 0 and every axis past the row partitions are always uniform. A
  partitioned axis is uniform iff its row partition is uniform.

  Raises:
    TypeError: if axis is not an int.
    IndexError: if axis is negative or (when the rank is known) >= rank.
  """
  if not isinstance(axis, int):
    raise TypeError("axis must be an integer")
  rank = self.rank
  if axis < 0:
    raise IndexError("Negative axis values are not supported")
  if rank is not None and axis >= rank:
    raise IndexError("Expected axis=%s < rank=%s" % (axis, rank))
  if axis == 0 or axis > len(self._row_partitions):
    return True
  return self._row_partitions[axis - 1].is_uniform()
@property
def rank(self):
  """Total number of axes (partitions + inner rank), or None if unknown."""
  inner = self.inner_rank
  return None if inner is None else self.num_row_partitions + inner
@property
def inner_shape(self):
  """The dense (flat-values) dimension sizes of this shape.

  Returns:
    A 1-D integer `Tensor`.
  """
  result = self._inner_shape
  return result
@property
def inner_rank(self):
  """Static length of inner_shape, or None if it is not statically known."""
  static_rank = self._static_inner_shape.rank
  return tensor_shape.dimension_value(static_rank)
def _alt_inner_shape(self, new_inner_rank):
  """Get an alternative inner shape with higher or lower rank.

  Lowering the rank folds leading inner dimensions into one. Raising it
  unrolls the trailing row partitions into dense dimensions, which requires
  each of those partitions to have a uniform_row_length.

  Args:
    new_inner_rank: the new rank of the inner_shape

  Returns:
    A new inner_shape of rank new_inner_rank.

  Raises:
    ValueError: if either rank is zero, or a needed partition is ragged.
  """
  old_inner_rank = self.inner_rank
  if new_inner_rank == 0:
    raise ValueError("new_inner_rank cannot be zero")
  if old_inner_rank == 0:
    raise ValueError("old inner_rank cannot be zero")
  if new_inner_rank == old_inner_rank:
    return self.inner_shape

  if new_inner_rank < old_inner_rank:
    if self._static_inner_shape.is_fully_defined():
      return _alt_inner_shape_from_tensor_shape(self._static_inner_shape,
                                                self.dtype, new_inner_rank)
    # The new leading dim is the number of slices at the fold point.
    leading = array_ops.expand_dims(
        self._num_slices_in_dimension(-new_inner_rank), 0)
    if new_inner_rank == 1:
      return leading
    return array_ops.concat(
        [leading, self.inner_shape[1 - new_inner_rank:]], axis=0)

  assert new_inner_rank > old_inner_rank
  extra = new_inner_rank - old_inner_rank
  absorbed = self.row_partitions[-extra:]
  if not all(partition.is_uniform() for partition in absorbed):
    raise ValueError("Cannot get an inner shape over a ragged dimension")
  leading = self._num_slices_in_dimension(-new_inner_rank)
  dims = [leading]
  dims.extend(partition.uniform_row_length() for partition in absorbed)
  return array_ops.concat(
      [array_ops_stack.stack(dims), self.inner_shape[1:]], axis=0)
def _inner_shape_dim(self, dimension):
  """Returns _inner_shape[dimension]: a Python int if static, else a tensor."""
  static_dim = tensor_shape.dimension_value(
      self._static_inner_shape[dimension])
  if static_dim is None:
    return self._inner_shape[dimension]
  return static_dim
def _with_inner_rank(self, inner_rank):
  """Returns the same shape but a different inner_rank.

  All dimensions that are to be represented in the inner_shape must be dense.
  See inner_rank.

  Args:
    inner_rank: the new inner_rank of the shape.

  Returns:
    the same shape but a different inner_rank

  Raises:
    ValueError if the new dense rank is invalid, or the old rank is unknown.
  """
  rank = self.rank
  if rank is None:
    raise ValueError("Rank must be known to adjust inner_rank")
  if rank < 2:
    if inner_rank == rank:
      return self
    raise ValueError("Cannot change inner_rank if rank < 2")
  # rank == num_row_partitions + inner_rank, so solve for the partitions.
  return self._with_num_row_partitions(rank - inner_rank)
def _with_num_row_partitions(self, num_row_partitions):
  """Returns an equivalent shape with `num_row_partitions` row partitions.

  The shape must be statically refactorable to the requested number of row
  partitions:
    * the rank must be known;
    * `num_row_partitions` must be a nonnegative int, strictly less than the
      rank;
    * when the number of row partitions shrinks, the dimensions folded back
      into the inner shape must be uniform (`_alt_inner_shape` raises
      otherwise).

  If `num_row_partitions` already equals the current value, `self` is
  returned unchanged.

  Args:
    num_row_partitions: the desired number of row partitions (a nonnegative
      int).

  Returns:
    A shape with (possibly) a different num_row_partitions.

  Raises:
    ValueError: if the rank is unknown, the argument is not a nonnegative int,
      or there is a dimension that is nonuniform.
  """
  rank = self.rank
  if rank is None:
    raise ValueError("Rank must be known to adjust num_row_partitions")
  if not isinstance(num_row_partitions, int):
    raise ValueError("num_row_partitions must be an int")
  if num_row_partitions < 0:
    raise ValueError("num_row_partitions must be nonnegative")
  current = self.num_row_partitions
  if num_row_partitions == current:
    return self
  if num_row_partitions >= rank:
    raise ValueError("num_row_partitions must be less than rank")
  target_inner_rank = rank - num_row_partitions

  if num_row_partitions < current:
    # Keep only the leading partitions; the dimensions described by the
    # dropped ones are absorbed into the (larger) inner shape.
    return DynamicRaggedShape(
        self.row_partitions[:num_row_partitions],
        self._alt_inner_shape(target_inner_rank))

  # Convert leading inner dimensions into uniform row partitions. The nrows of
  # each new partition is the running product of the inner dimensions already
  # converted, starting from inner_shape[0].
  added = []
  nvals = self._inner_shape_dim(0)
  for axis in range(1, num_row_partitions - current + 1):
    nrows = nvals
    row_length = self._inner_shape_dim(axis)
    nvals = nrows * row_length
    added.append(
        RowPartition.from_uniform_row_length(
            row_length, nrows=nrows, dtype=self.dtype))
  inner = self._alt_inner_shape(target_inner_rank)
  return DynamicRaggedShape([*self.row_partitions, *added], inner)
def _merge_dims(self, outer_axis: int,
                inner_axis: int) -> "DynamicRaggedShape":
  """Merges outer_axis...inner_axis into a single dimension.

  Returns a copy of this shape with the specified range of dimensions
  flattened into a single dimension, with elements in row-major order.

  #### Examples:

  >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1),
  ... (1,2,3)])._merge_dims(0, 1)
  <DynamicRaggedShape lengths=[3, (1, 2, 3)] num_row_partitions=1>
  >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1),
  ... (1,2,3)])._merge_dims(1, 2)
  <DynamicRaggedShape lengths=[2, (3, 3)] num_row_partitions=1>
  >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1),
  ... (1,2,3)])._merge_dims(0, 2)
  <DynamicRaggedShape lengths=[6] num_row_partitions=0>

  To mimic the behavior of `np.flatten` (which flattens all dimensions), use
  `shape._merge_dims(0, -1)`. To mimic the behavior of `tf.layers.Flatten`
  (which flattens all dimensions except the outermost batch dimension), use
  `shape._merge_dims(1, -1)`.

  Args:
    outer_axis: `int`: The first dimension in the range of dimensions to
      merge. May be negative if `self.rank` is statically known.
    inner_axis: `int`: The last dimension in the range of dimensions to merge.
      May be negative if `self.rank` is statically known.

  Returns:
    A copy of this shape, with the specified dimensions merged into a
    single dimension. The returned shape consists of the dimensions before
    `outer_axis`, then `N`, then the dimensions after `inner_axis`, where `N`
    is the total number of slices in the merged dimensions. If
    `outer_axis == inner_axis`, `self` is returned.

  Raises:
    ValueError: if `outer_axis > inner_axis` (after normalization), or if the
      merged range spans both row partitions and the inner shape while the
      rank is unknown. Errors from `array_ops.get_positive_axis` (presumably
      for out-of-range axes) propagate as well.
  """
  # Normalize possibly-negative axes against the (static) rank.
  outer_axis = array_ops.get_positive_axis(
      outer_axis, self.rank, axis_name="outer_axis", ndims_name="rank(self)")
  inner_axis = array_ops.get_positive_axis(
      inner_axis, self.rank, axis_name="inner_axis", ndims_name="rank(self)")
  if not outer_axis <= inner_axis:
    raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or "
                     f"equal to inner_axis ({inner_axis}).")
  if outer_axis == inner_axis:
    # Merging a single dimension changes nothing.
    return self
  if self.num_row_partitions == 0:
    # A dense tensor.
    (new_inner_shape,
     new_static_inner_shape) = _merge_inner_shape(self._inner_shape,
                                                  self._static_inner_shape,
                                                  outer_axis, inner_axis)
    return DynamicRaggedShape([],
                              new_inner_shape,
                              dtype=self.dtype,
                              static_inner_shape=new_static_inner_shape)
  if inner_axis <= self.num_row_partitions:
    # Here, we are merging the row_partitions,
    # but the inner_shape is unchanged.
    if outer_axis == 0:
      # There is no need to merge axes before the first, just truncate them.
      return DynamicRaggedShape(
          self._row_partitions[inner_axis:],
          self.inner_shape,
          dtype=self.dtype,
          static_inner_shape=self._static_inner_shape)
    # Partitions [outer_axis - 1, inner_axis) link dimensions
    # outer_axis..inner_axis; they are collapsed into a single partition
    # (via _merge_row_partitions, assumed to compose them) while the
    # partitions before and after are kept as-is.
    prefix_rp = self._row_partitions[:outer_axis - 1]
    suffix_rp = self._row_partitions[inner_axis:]
    internal_rp = self._row_partitions[outer_axis - 1:inner_axis]
    new_rp = prefix_rp + (_merge_row_partitions(internal_rp),) + suffix_rp
    return DynamicRaggedShape(
        new_rp,
        self.inner_shape,
        dtype=self.dtype,
        static_inner_shape=self._static_inner_shape)
  elif outer_axis > self.num_row_partitions:
    # In this scenario, only the inner_shape is changed.
    # Example #1:
    # if [2, (1, 2), 5, 3], num_row_partitions=1, outer_axis=2, inner_axis=3.
    # Result: [2, (1, 2), 15], num_row_partitions=1, outer_axis=2,
    # inner_axis=3.
    # Axes are shifted by num_row_partitions to index into the inner shape.
    (new_inner_shape, new_static_inner_shape) = _merge_inner_shape(
        self._inner_shape, self._static_inner_shape,
        outer_axis - self.num_row_partitions,
        inner_axis - self.num_row_partitions)
    return DynamicRaggedShape(
        self._row_partitions,
        new_inner_shape,
        dtype=self.dtype,
        static_inner_shape=new_static_inner_shape)
  else:
    # Here, both inner_shape and row_partitions are changed.
    # The merged range starts within the row partitions
    # (outer_axis <= num_row_partitions) and ends inside the inner shape
    # (inner_axis > num_row_partitions).
    rank = self.rank
    if rank is None:
      raise ValueError("Cannot merge_dims of the inner shape if the " +
                       "dimension of inner_shape is unknown")
    if outer_axis == 0:
      # Everything up to inner_axis is merged, so the result is built from an
      # inner shape of rank (rank - inner_axis) whose first entry is the
      # merged size.
      new_inner_shape = self._alt_inner_shape(rank - inner_axis)
      return DynamicRaggedShape._from_inner_shape(new_inner_shape)
    else:
      # Keep the partitions before the merged range and merge all remaining
      # partitions into one.
      prefix = self._row_partitions[:outer_axis - 1]
      suffix = _merge_row_partitions(self._row_partitions[outer_axis - 1:])
      new_inner_shape = self._alt_inner_shape(rank - inner_axis)
      # Inner dimensions 1..num_merged_inner are merged as well, so every row
      # of the merged partition grows by their product (`_reduce_prod_patch`
      # is assumed to compute that product); scale the row splits to match.
      num_merged_inner = inner_axis - self.num_row_partitions
      prod = _reduce_prod_patch(self._inner_shape[1:num_merged_inner + 1])
      tail_suffix = RowPartition.from_row_splits(suffix.row_splits() * prod)
      return DynamicRaggedShape(prefix + (tail_suffix,), new_inner_shape)
def with_dtype(self, dtype):
  """Change the dtype of the shape.

  Args:
    dtype: the dtype to use for the returned shape.

  Returns:
    `self` if `dtype == self.dtype`; otherwise a new `DynamicRaggedShape`
    constructed from this shape's row partitions and inner shape with
    `dtype=dtype`, keeping the statically known inner shape.
  """
  if dtype == self.dtype:
    return self
  # The static inner shape records dimension sizes, not their dtype, so it is
  # still valid after a dtype change. Passing it through, as this class's
  # other constructors do, keeps statically known dimensions that might
  # otherwise not be recoverable from the converted inner_shape tensor.
  return DynamicRaggedShape(
      self.row_partitions,
      self.inner_shape,
      dtype=dtype,
      static_inner_shape=self._static_inner_shape)
def _merge_with(self, other: "DynamicRaggedShape") -> "DynamicRaggedShape":
"""Merge two shapes that are equal modulo num_row_partitions.
The resulting num_row_partitions is the maximum of the two
num_row_partitions.
Args:
other: a DynamicRaggedShape representing the same shape with a possibly
different number of row partitions.