-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtensor_util.py
1158 lines (986 loc) · 41.2 KB
/
tensor_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to create TensorProtos."""
import typing
from typing import Protocol
import numpy as np
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import tensor_shape
from tensorflow.python.types import core
from tensorflow.python.types import internal
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# Fallback in case fast_tensor_util is not properly compiled.
# pylint: disable=g-import-not-at-top
try:
from tensorflow.python.framework import fast_tensor_util
_FAST_TENSOR_UTIL_AVAILABLE = True
except ImportError:
_FAST_TENSOR_UTIL_AVAILABLE = False
# pylint: enable=g-import-not-at-top
def ExtractBitsFromFloat16(x):
  """Returns the IEEE 754 binary16 bit pattern of `x` as a Python int.

  `x` is first converted to float16; the resulting 16 bits are then
  reinterpreted as an unsigned integer.
  """
  as_half = np.asarray(x, dtype=np.float16)
  return as_half.view(np.uint16).item()
def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Pure-Python fallback that appends float16 values to `half_val`.

  Each value is stored as its raw 16-bit pattern, as produced by
  ExtractBitsFromFloat16.
  """
  tensor_proto.half_val.extend(
      list(map(ExtractBitsFromFloat16, proto_values)))
def _MediumAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values through the compiled fast_tensor_util helper.

  The helper receives the values' raw uint16 bit patterns rather than float16
  numbers.
  """
  append = fast_tensor_util.AppendFloat16ArrayToTensorProto
  # TODO: Remove the conversion if cython supports np.float16_t
  half_bits = np.asarray(proto_values, dtype=np.float16).view(np.uint16)
  append(tensor_proto, half_bits)
def ExtractBitsFromBFloat16(x):
  """Returns the 16-bit bfloat16 bit pattern of `x` as a Python int."""
  bfloat16_type = dtypes.bfloat16.as_numpy_dtype
  return np.asarray(x, dtype=bfloat16_type).view(np.uint16).item()
def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Pure-Python fallback that appends bfloat16 values to `half_val`.

  Each value is stored as its raw 16-bit pattern, as produced by
  ExtractBitsFromBFloat16.
  """
  tensor_proto.half_val.extend(
      list(map(ExtractBitsFromBFloat16, proto_values)))
def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values through the compiled fast_tensor_util helper.

  The values are converted to bfloat16 and handed over as raw uint16 bits.
  """
  append = fast_tensor_util.AppendBFloat16ArrayToTensorProto
  raw_bits = np.asarray(
      proto_values, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16)
  append(tensor_proto, raw_bits)
def ExtractBitsFromFloat8e5m2(x):
  """Returns the 8-bit float8_e5m2 bit pattern of `x` as a Python int."""
  e5m2_type = dtypes.float8_e5m2.as_numpy_dtype
  return np.asarray(x, dtype=e5m2_type).view(np.uint8).item()
def SlowAppendFloat8e5m2ArrayToTensorProto(tensor_proto, proto_values):
  """Pure-Python fallback that appends float8_e5m2 values to `float8_val`.

  Each value is stored as one raw byte, as produced by
  ExtractBitsFromFloat8e5m2. The values must go into `float8_val`: that is
  the field MakeNdarray decodes float8 tensors from. Writing them to
  `half_val` (as this function previously did) made slow-path float8 values
  round-trip as zeros.

  `float8_val` is declared as a `bytes` field in tensor.proto, not a
  repeated integer field, so the new bytes are concatenated onto it.

  Args:
    tensor_proto: The TensorProto to append to.
    proto_values: Iterable of values convertible to float8_e5m2.
  """
  tensor_proto.float8_val += bytes(
      [ExtractBitsFromFloat8e5m2(x) for x in proto_values])
def FastAppendFloat8e5m2ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float8_e5m2 values through the compiled fast_tensor_util helper.

  The values are converted to float8_e5m2 and handed over as raw uint8 bits.
  """
  append = fast_tensor_util.AppendFloat8ArrayToTensorProto
  raw_bits = np.asarray(
      proto_values, dtype=dtypes.float8_e5m2.as_numpy_dtype).view(np.uint8)
  append(tensor_proto, raw_bits)
def ExtractBitsFromFloat8e4m3fn(x):
  """Returns the 8-bit float8_e4m3fn bit pattern of `x` as a Python int."""
  e4m3fn_type = dtypes.float8_e4m3fn.as_numpy_dtype
  return np.asarray(x, dtype=e4m3fn_type).view(np.uint8).item()
def SlowAppendFloat8e4m3fnArrayToTensorProto(tensor_proto, proto_values):
  """Pure-Python fallback that appends float8_e4m3fn values to `float8_val`.

  Each value is stored as one raw byte, as produced by
  ExtractBitsFromFloat8e4m3fn. The values must go into `float8_val`: that is
  the field MakeNdarray decodes float8 tensors from. Writing them to
  `half_val` (as this function previously did) made slow-path float8 values
  round-trip as zeros.

  `float8_val` is declared as a `bytes` field in tensor.proto, not a
  repeated integer field, so the new bytes are concatenated onto it.

  Args:
    tensor_proto: The TensorProto to append to.
    proto_values: Iterable of values convertible to float8_e4m3fn.
  """
  tensor_proto.float8_val += bytes(
      [ExtractBitsFromFloat8e4m3fn(x) for x in proto_values])
def FastAppendFloat8e4m3fnArrayToTensorProto(tensor_proto, proto_values):
  """Appends float8_e4m3fn values through the compiled fast_tensor_util helper.

  The values are converted to float8_e4m3fn and handed over as raw uint8 bits.
  """
  append = fast_tensor_util.AppendFloat8ArrayToTensorProto
  raw_bits = np.asarray(
      proto_values, dtype=dtypes.float8_e4m3fn.as_numpy_dtype).view(np.uint8)
  append(tensor_proto, raw_bits)
# _NP_TO_APPEND_FN maps a numpy scalar type to the function that appends a
# flat array of that type to a TensorProto.
# - When fast_tensor_util imported successfully, its compiled helpers are
#   used.
# - Otherwise the pure-Python Slow* fallbacks defined in the `else` branch
#   are used.
# Lookups go through GetFromNumpyDTypeDict, which scans the entries in order
# and returns the first key that compares equal. Insertion order here is
# therefore the match order.
if _FAST_TENSOR_UTIL_AVAILABLE:
  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype:
          FastAppendBFloat16ArrayToTensorProto,
      dtypes.float8_e5m2.as_numpy_dtype:
          FastAppendFloat8e5m2ArrayToTensorProto,
      dtypes.float8_e4m3fn.as_numpy_dtype:
          FastAppendFloat8e4m3fnArrayToTensorProto,
      np.float16:
          _MediumAppendFloat16ArrayToTensorProto,
      np.float32:
          fast_tensor_util.AppendFloat32ArrayToTensorProto,
      np.float64:
          fast_tensor_util.AppendFloat64ArrayToTensorProto,
      np.int32:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      np.int64:
          fast_tensor_util.AppendInt64ArrayToTensorProto,
      np.uint8:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      np.uint16:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      np.uint32:
          fast_tensor_util.AppendUInt32ArrayToTensorProto,
      np.uint64:
          fast_tensor_util.AppendUInt64ArrayToTensorProto,
      np.int8:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      np.int16:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      np.complex64:
          fast_tensor_util.AppendComplex64ArrayToTensorProto,
      np.complex128:
          fast_tensor_util.AppendComplex128ArrayToTensorProto,
      np.object_:
          fast_tensor_util.AppendObjectArrayToTensorProto,
      np.bool_:
          fast_tensor_util.AppendBoolArrayToTensorProto,
      # Quantized types reuse the append function of the integer type with
      # the same width and signedness.
      dtypes.qint8.as_numpy_dtype:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      # NOTE(review): an older note here said there is intentionally no way
      # to feed a DT_BFLOAT16, yet a bfloat16 entry exists above; confirm
      # whether that note is obsolete.
  }
else:
  # Pure-Python fallbacks. Each converts numpy scalars to Python scalars with
  # `.item()` and extends the matching repeated field of the TensorProto.

  def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.float_val.extend([x.item() for x in proto_values])

  def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.double_val.extend([x.item() for x in proto_values])

  def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    # Shared by int8/int16/int32/uint8/uint16, which all use `int_val`.
    tensor_proto.int_val.extend([x.item() for x in proto_values])

  def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int64_val.extend([x.item() for x in proto_values])

  def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    # `.item()` on a quantized element presumably yields a 1-tuple (quantized
    # values are handled as tuples elsewhere in this file); `[0]` unwraps the
    # integer.
    tensor_proto.int_val.extend([x.item()[0] for x in proto_values])

  def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint32_val.extend([x.item() for x in proto_values])

  def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint64_val.extend([x.item() for x in proto_values])

  def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    # Stored interleaved: real0, imag0, real1, imag1, ...
    tensor_proto.scomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    # Stored interleaved: real0, imag0, real1, imag1, ...
    tensor_proto.dcomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])

  def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.bool_val.extend([x.item() for x in proto_values])

  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype:
          SlowAppendBFloat16ArrayToTensorProto,
      dtypes.float8_e5m2.as_numpy_dtype:
          SlowAppendFloat8e5m2ArrayToTensorProto,
      dtypes.float8_e4m3fn.as_numpy_dtype:
          SlowAppendFloat8e4m3fnArrayToTensorProto,
      np.float16:
          SlowAppendFloat16ArrayToTensorProto,
      np.float32:
          SlowAppendFloat32ArrayToTensorProto,
      np.float64:
          SlowAppendFloat64ArrayToTensorProto,
      np.int32:
          SlowAppendIntArrayToTensorProto,
      np.int64:
          SlowAppendInt64ArrayToTensorProto,
      np.uint8:
          SlowAppendIntArrayToTensorProto,
      np.uint16:
          SlowAppendIntArrayToTensorProto,
      np.uint32:
          SlowAppendUInt32ArrayToTensorProto,
      np.uint64:
          SlowAppendUInt64ArrayToTensorProto,
      np.int8:
          SlowAppendIntArrayToTensorProto,
      np.int16:
          SlowAppendIntArrayToTensorProto,
      np.complex64:
          SlowAppendComplex64ArrayToTensorProto,
      np.complex128:
          SlowAppendComplex128ArrayToTensorProto,
      np.object_:
          SlowAppendObjectArrayToTensorProto,
      np.bool_:
          SlowAppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
      # NOTE(review): an older note here said there is intentionally no way
      # to feed a DT_BFLOAT16, yet a bfloat16 entry exists above; confirm
      # whether that note is obsolete.
  }
def GetFromNumpyDTypeDict(dtype_dict, dtype):
  """Returns the value of the first key in `dtype_dict` equal to `dtype`.

  A hashed lookup (`dtype_dict.get(dtype)`) does not work here. The keys are
  numpy scalar types such as `np.float32`, while `dtype` is typically an
  `np.dtype` instance. The two compare equal but do not hash alike, so the
  entries are scanned linearly, in insertion order.

  Args:
    dtype_dict: Mapping whose keys are numpy scalar types or dtypes.
    dtype: The numpy dtype to look up.

  Returns:
    The matching value, or None if no key compares equal to `dtype`.
  """
  return next(
      (value for candidate, value in dtype_dict.items() if candidate == dtype),
      None)
def GetNumpyAppendFn(dtype):
  """Returns the TensorProto append function for numpy `dtype`, or None.

  Byte and unicode string dtypes are variable-length (e.g. "S5", "<U3"), so
  no single table key can match them. They are recognized by their scalar
  type and routed to the object/string appender. Every other dtype is looked
  up in _NP_TO_APPEND_FN.
  """
  if dtype.type in (np.bytes_, np.str_):
    return (fast_tensor_util.AppendObjectArrayToTensorProto
            if _FAST_TENSOR_UTIL_AVAILABLE
            else SlowAppendObjectArrayToTensorProto)
  return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
def TensorShapeProtoToList(shape):
  """Convert a TensorShapeProto to a list of dimension sizes.

  Args:
    shape: A TensorShapeProto (anything with a `dim` sequence whose items
      have a `size` attribute).

  Returns:
    List of integers representing the dimensions of the tensor, outermost
    first.
  """
  dims = shape.dim
  return [entry.size for entry in dims]
def _GetDenseDimensions(list_of_lists):
  """Returns the inferred dense dimensions of a list of lists.

  Only the first element at each nesting level is inspected. Ragged inputs
  are therefore not detected here; callers compare against numpy's shape. A
  non-list/tuple yields [], and an empty list/tuple contributes a trailing 0.
  """
  dims = []
  node = list_of_lists
  while isinstance(node, (list, tuple)):
    if not node:
      dims.append(0)
      break
    dims.append(len(node))
    node = node[0]
  return dims
def _FlattenToStrings(nested_strings):
  """Yields the leaves of nested lists/tuples in depth-first order.

  Any object that is not a list or tuple (including a str) is treated as a
  single leaf and yielded unchanged.
  """
  if not isinstance(nested_strings, (list, tuple)):
    yield nested_strings
    return
  for child in nested_strings:
    yield from _FlattenToStrings(child)
# Dtypes that make_tensor_proto may serialize as one packed byte buffer in
# `TensorProto.tensor_content`. This happens only when the array fills the
# full shape and has more than one element. Dtypes not listed here (e.g.
# uint16, bfloat16, bool, string, complex) always go through the per-element
# append functions instead.
_TENSOR_CONTENT_TYPES = frozenset({
    # Floating point.
    dtypes.float8_e4m3fn, dtypes.float8_e5m2,
    dtypes.float16, dtypes.float32, dtypes.float64,
    # Signed and unsigned integers.
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    dtypes.uint8, dtypes.uint32, dtypes.uint64,
    # Quantized integers.
    dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16, dtypes.qint32,
})
# pylint: disable=invalid-name
def _check_failed(v):
  """Signals a failed compatibility check by raising ValueError(v).

  None of the _check_* validators raise ValueError for any other reason, so
  _AssertCompatible can safely treat a caught ValueError as "value `v` did not
  match".
  """
  failure = ValueError(v)
  raise failure
def _check_quantized(values):
  """Validates quantized input: nested lists whose leaves are int tuples.

  Raises ValueError (via _check_failed or _check_int) carrying the first
  offending value.
  """
  # `nest` cannot be used here because the leaves themselves are tuples.
  if isinstance(values, tuple):
    for element in values:
      _check_int(element)
  elif isinstance(values, list):
    for element in values:
      _check_quantized(element)
  else:
    _check_failed(values)
def _generate_isinstance_check(expected_types):
  """Builds a validator that accepts only leaves of `expected_types`.

  The returned function flattens its argument with `nest`. It then calls
  _check_failed on the first leaf that is neither an instance of
  `expected_types` nor a numpy array whose scalar type subclasses one of
  them.
  """

  def _leaf_ok(leaf):
    if isinstance(leaf, expected_types):
      return True
    return (isinstance(leaf, np.ndarray) and
            issubclass(leaf.dtype.type, expected_types))

  def inner(values):
    for leaf in nest.flatten(values):
      if not _leaf_ok(leaf):
        _check_failed(leaf)

  return inner
# Leaf validators used by _AssertCompatible, one per family of Python types.
# Integer checks also accept TensorShape dimensions.
_check_int = _generate_isinstance_check(
    (compat.integral_types, tensor_shape.Dimension))
(_check_float, _check_complex, _check_str, _check_bool) = map(
    _generate_isinstance_check,
    (compat.real_types, compat.complex_types, compat.bytes_or_text_types,
     bool))
def _check_not_tensor(values):
  """Fails (via _check_failed) on the first symbolic tensor in `values`."""
  for leaf in nest.flatten(values):
    if isinstance(leaf, core.Symbol):
      _check_failed(leaf)
# pylint: enable=invalid-name
# Dedicated input validator per TF dtype, used by _AssertCompatible. Dtypes
# missing here fall back to a validator chosen from the dtype's category.
_TF_TO_IS_OK = {
    dtypes.bool: _check_bool,
    dtypes.string: _check_str,
    **dict.fromkeys((dtypes.complex64, dtypes.complex128), _check_complex),
    **dict.fromkeys((dtypes.float16, dtypes.float32, dtypes.float64),
                    _check_float),
    **dict.fromkeys((dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
                     dtypes.uint8, dtypes.uint16, dtypes.uint32,
                     dtypes.uint64), _check_int),
    **dict.fromkeys((dtypes.qint8, dtypes.qint16, dtypes.qint32,
                     dtypes.quint8, dtypes.quint16), _check_quantized),
}
def _AssertCompatible(values, dtype):
  """Raises TypeError if `values` is not compatible with `dtype`.

  With `dtype=None`, the only requirement is that `values` contains no
  symbolic tensors. Otherwise the validator is chosen in this order:
    1. The validator registered for `dtype` in _TF_TO_IS_OK.
    2. One picked from the dtype's category (integer, floating, complex,
       quantized).
    3. The non-tensor check.

  Args:
    values: A (possibly nested) Python value.
    dtype: A `DType`, or None.

  Raises:
    TypeError: describing the first incompatible element.
  """
  if dtype is None:
    checker = _check_not_tensor
  else:
    checker = _TF_TO_IS_OK.get(dtype)
    if checker is None:
      # No dedicated validator; do the best we can from the dtype category.
      for category, category_checker in (("is_integer", _check_int),
                                         ("is_floating", _check_float),
                                         ("is_complex", _check_complex),
                                         ("is_quantized", _check_quantized)):
        if getattr(dtype, category):
          checker = category_checker
          break
      else:
        checker = _check_not_tensor

  try:
    checker(values)
  except ValueError as e:
    # Validators report the offending element as the sole exception arg.
    [mismatch] = e.args
    if dtype is None:
      raise TypeError("Expected any non-tensor type, but got a tensor instead.")
    raise TypeError(f"Expected {dtype.name}, but got {mismatch} of type "
                    f"'{type(mismatch).__name__}'.")
def _is_array_like(obj):  # pylint: disable=invalid-name
  """Check if a given object is array-like.

  An object counts as array-like if any of the following holds:
    - it exposes a callable `__array__`;
    - it exposes an `__array_interface__` dict;
    - it supports the buffer protocol and is not `bytes`.
  Symbolic (non-eager) tensors are never array-like.
  """
  # Tensor implements __array__ only so it can inform the user that it is not
  # a valid array.
  if isinstance(obj, core.Symbol) and not isinstance(obj, core.Value):  # pylint: disable=protected-access
    return False

  # TODO(slebedev): an object could also implement C-level array interface.
  if callable(getattr(obj, "__array__", None)):
    return True
  if isinstance(getattr(obj, "__array_interface__", None), dict):
    return True

  # Buffer-protocol fallback; `bytes` supports it but is a scalar string here.
  try:
    memoryview(obj)
  except TypeError:
    return False
  return not isinstance(obj, bytes)
# pylint: disable=invalid-name
@tf_export("make_tensor_proto")
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False,
                      allow_broadcast=False):
  """Create a TensorProto.

  In TensorFlow 2.0, representing tensors as protos should no longer be a
  common workflow. That said, this utility function is still useful for
  generating TF Serving request protos:

  ```python
    request = tensorflow_serving.apis.predict_pb2.PredictRequest()
    request.model_spec.name = "my_model"
    request.model_spec.signature_name = "serving_default"
    request.inputs["images"].CopyFrom(tf.make_tensor_proto(X_new))
  ```

  `make_tensor_proto` accepts "values" of a python scalar, a python list, a
  numpy ndarray, or a numpy scalar.

  If "values" is a python scalar or a python list, make_tensor_proto
  first converts it to a numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data
  type. In that case float64 results are stored as float32, and int64
  results are stored as int32 when that loses no information. Otherwise,
  the resulting numpy array has a data type compatible with the given
  dtype.

  In either case above, the numpy ndarray (either the caller provided
  or the auto-converted) must have a type compatible with dtype.

  `make_tensor_proto` then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  can not have more elements than what "shape" specifies.

  Args:
    values:         Values to put in the TensorProto.
    dtype:          Optional tensor_pb2 DataType value.
    shape:          List of integers representing the dimensions of tensor.
    verify_shape:   Boolean that enables verification of a shape of values.
    allow_broadcast:  Boolean that enables allowing scalars and 1 length vector
        broadcasting. Cannot be true when verify_shape is true.

  Returns:
    A `TensorProto`. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python programs.
    To access the values you should convert the proto back to a numpy ndarray
    with `tf.make_ndarray(proto)`.

    If `values` is a `TensorProto`, it is immediately returned; `dtype` and
    `shape` are ignored.

  Raises:
    TypeError:  if unsupported types are provided, if values do not match
        dtype, or if the shape of values does not match `shape` (with
        verify_shape=True, or with allow_broadcast=True and a non-scalar,
        non-length-1 input of the wrong size).
    ValueError: if arguments have inappropriate values, e.g. both
        verify_shape and allow_broadcast are set, `values` is None, more
        elements are given than `shape` holds, values are not a dense
        (rectangular) nesting, or the packed content would reach 2GB.
  """
  if allow_broadcast and verify_shape:
    raise ValueError("allow_broadcast and verify_shape are not both allowed.")
  if isinstance(values, tensor_pb2.TensorProto):
    return values

  if dtype:
    dtype = dtypes.as_dtype(dtype)

  # Quantized values are supplied as tuples, which exempts them from the
  # dense-shape check below and re-tags the numpy dtype afterwards.
  is_quantized = (
      dtype in [
          dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16,
          dtypes.qint32
      ])

  # Objects exposing the array or buffer protocol are materialized as numpy
  # arrays up front.
  if _is_array_like(values):
    values = np.asarray(values)

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype and dtype.is_numpy_compatible:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # if dtype is provided, forces numpy array to be the type
    # provided if possible.
    if dtype and dtype.is_numpy_compatible:
      np_dt = dtype.as_numpy_dtype
    else:
      np_dt = None
    # If shape is None, numpy.prod returns None when dtype is not set, but
    # raises exception when dtype is set to np.int64
    if shape is not None and np.prod(shape, dtype=np.int64) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      # We need to pass in quantized values as tuples, so don't apply the shape
      # check to them.
      if (list(nparray.shape) != _GetDenseDimensions(values) and
          not is_quantized):
        raise ValueError(f"Expected values {values} to be a dense tensor with "
                         f"shape {_GetDenseDimensions(values)}, but got shape "
                         f"{list(nparray.shape)}.")

    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      downcasted_array = nparray.astype(np.int32)
      # Do not down cast if it leads to precision loss.
      if np.array_equal(downcasted_array, nparray):
        nparray = downcasted_array

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = dtypes.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    raise TypeError(f"Unrecognized data type: {nparray.dtype}.")

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if is_quantized:
    numpy_dtype = dtype

  if dtype is not None and (not hasattr(dtype, "base_dtype") or
                            dtype.base_dtype != numpy_dtype.base_dtype):
    raise TypeError(f"`dtype` {dtype} is not compatible with {values} of "
                    f"dtype {nparray.dtype}.")

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape, dtype=np.int64)
    is_same_size = shape_size == nparray.size

    if allow_broadcast:
      # Scalars and length-1 vectors may be broadcast to any shape.
      if nparray.shape == (1,) or nparray.shape == tuple():
        pass
      elif nparray.size != shape_size:
        raise TypeError(f"Expected Tensor's shape: {tuple(shape)}, but got "
                        f"{nparray.shape}.")

    else:
      if verify_shape and nparray.shape != tuple(shape):
        raise TypeError(f"Expected Tensor's shape: {tuple(shape)}, but got "
                        f"{nparray.shape}.")

      # Fewer elements than the shape needs is allowed (MakeNdarray pads by
      # repeating the last value); more is not.
      if nparray.size > shape_size:
        raise ValueError("Too many elements provided. Takes at most "
                         f"{shape_size:d}, but got {nparray.size:d}.")

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=tensor_shape.as_shape(shape).as_proto())

  # Packed byte encoding: only when the array fills the whole shape and has
  # more than one element. Scalars and broadcast/partial values use the
  # per-element fields below.
  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    if nparray.size * nparray.itemsize >= (1 << 31):
      raise ValueError(
          "Cannot create a tensor proto whose content is larger than 2GB.")
    tensor_proto.tensor_content = nparray.tobytes()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)

    # At this point, values may be a list of objects that we could not
    # identify a common type for (hence it was inferred as
    # np.object_/dtypes.string).  If we are unable to convert it to a
    # string, we raise a more helpful error message.
    #
    # Ideally, we'd be able to convert the elements of the list to a
    # common type, but this type inference requires some thinking and
    # so we defer it for now.
    try:
      str_values = [compat.as_bytes(x) for x in proto_values]
    except TypeError:
      raise TypeError(f"Failed to convert elements of {values} to Tensor. "
                      "Consider casting elements to a supported type. See "
                      "https://www.tensorflow.org/api_docs/python/tf/dtypes "
                      "for supported TF dtypes.")
    tensor_proto.string_val.extend(str_values)
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError(
        f"Element type not supported in TensorProto: {numpy_dtype.name}.")
  append_fn(tensor_proto, proto_values)

  return tensor_proto
# pylint: enable=invalid-name
@tf_export("make_ndarray")
def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor.

  For example:

  ```python
  # Tensor a has shape (2,3)
  a = tf.constant([[1,2,3],[4,5,6]])
  proto_tensor = tf.make_tensor_proto(a)  # convert `tensor a` to a proto tensor
  tf.make_ndarray(proto_tensor) # output: array([[1, 2, 3],
  #                                              [4, 5, 6]], dtype=int32)
  # output has shape (2,3)
  ```

  Values not stored in `tensor_content` may be abbreviated in the proto:
    - missing trailing elements are filled by repeating the last stored value;
    - an empty numeric value list yields zeros;
    - an empty string list yields empty strings.

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents.

  Raises:
    TypeError: if tensor has unsupported type.

  """
  shape = [dim.size for dim in tensor.tensor_shape.dim]
  num_elements = np.prod(shape, dtype=np.int64)
  tensor_dtype = dtypes.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  # Packed byte buffer: decode directly (copy so the result is writable).
  if tensor.tensor_content:
    packed = np.frombuffer(tensor.tensor_content, dtype=dtype)
    return packed.copy().reshape(shape)

  if tensor_dtype == dtypes.string:
    # np.pad throws on these arrays of type np.object_, so pad by hand.
    strings = list(tensor.string_val)
    shortfall = num_elements - len(strings)
    if shortfall > 0:
      filler = strings[-1] if strings else ""
      strings.extend([filler] * shortfall)
    return np.array(strings, dtype=dtype).reshape(shape)

  if tensor_dtype in (dtypes.float16, dtypes.bfloat16):
    # half_val stores raw 16-bit patterns; reinterpret them as the float type.
    values = np.fromiter(tensor.half_val, dtype=np.uint16)
    values.dtype = tensor_dtype.as_numpy_dtype
  elif tensor_dtype in (dtypes.float8_e5m2, dtypes.float8_e4m3fn):
    # float8_val stores raw 8-bit patterns; reinterpret them likewise.
    values = np.fromiter(tensor.float8_val, dtype=np.uint8)
    values.dtype = tensor_dtype.as_numpy_dtype
  elif tensor_dtype in (dtypes.complex64, dtypes.complex128):
    # Complex values are stored as interleaved (real, imag) scalars.
    interleaved = (tensor.scomplex_val if tensor_dtype == dtypes.complex64
                   else tensor.dcomplex_val)
    parts = iter(interleaved)
    values = np.array(
        [complex(real, imag) for real, imag in zip(parts, parts)],
        dtype=dtype)
  else:
    # Every other supported dtype is read verbatim from one repeated field.
    field_for_dtypes = (
        ((dtypes.float32,), "float_val"),
        ((dtypes.float64,), "double_val"),
        ((dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16,
          dtypes.int8, dtypes.qint32, dtypes.quint8, dtypes.qint8,
          dtypes.qint16, dtypes.quint16), "int_val"),
        ((dtypes.int64,), "int64_val"),
        ((dtypes.uint32,), "uint32_val"),
        ((dtypes.uint64,), "uint64_val"),
        ((dtypes.bool,), "bool_val"),
    )
    for candidates, field_name in field_for_dtypes:
      if tensor_dtype in candidates:
        values = np.fromiter(getattr(tensor, field_name), dtype=dtype)
        break
    else:
      raise TypeError(f"Unsupported tensor type: {tensor.dtype}. See "
                      "https://www.tensorflow.org/api_docs/python/tf/dtypes "
                      "for supported TF dtypes.")

  if values.size == 0:
    return np.zeros(shape, dtype)

  # Abbreviated protos: repeat the last value to fill the full shape.
  if values.size != num_elements:
    values = np.pad(values, (0, num_elements - values.size), "edge")

  return values.reshape(shape)
def ShapeEquals(tensor_proto, shape):
  """Returns True if "tensor_proto" has the given "shape".

  Shapes match only if they have the same rank and equal sizes in every
  dimension. For example, a [2, 3] proto does not match [2].

  Args:
    tensor_proto: A TensorProto.
    shape: The expected shape, expressed as a TensorShapeProto, list, or
      tuple of dimension sizes.

  Returns:
    True if "tensor_proto" has the given "shape", otherwise False.

  Raises:
    TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
      TensorShapeProto, list, or tuple.
  """
  if not isinstance(tensor_proto, tensor_pb2.TensorProto):
    raise TypeError("`tensor_proto` must be a tensor_pb2.TensorProto object, "
                    f"but got type {type(tensor_proto)}.")
  if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
    shape = [d.size for d in shape.dim]
  elif not isinstance(shape, (list, tuple)):
    raise TypeError("`shape` must be a list or tuple, but got type "
                    f"{type(shape)}.")
  tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
  # zip() stops at the shorter input, so compare ranks explicitly. Otherwise
  # shapes of different rank that share a prefix would compare equal.
  if len(tensor_shape_list) != len(shape):
    return False
  return all(x == y for x, y in zip(tensor_shape_list, shape))
def _ConstantValue(tensor, partial):
  """Statically evaluates `tensor` from the op that produces it, if possible.

  Only a fixed set of op types is understood. Inputs are evaluated
  recursively through `constant_value`. None is returned for any other op
  type, or whenever a required input cannot itself be evaluated.

  Args:
    tensor: A symbolic tensor (`core.Symbol`).
    partial: Forwarded to `constant_value` for the inputs of Pack, Unpack,
      Split (its value input), StopGradient and the identity-like ops.

  Returns:
    A numpy value, or None if the value cannot be determined statically.

  Raises:
    TypeError: if `tensor` is not a `core.Symbol`.
  """
  # TODO(touts): Support Variables?
  if not isinstance(tensor, core.Symbol):
    raise TypeError(f"{tensor!r} must be a Tensor, but got {type(tensor)}.")

  op = tensor.op
  op_type = op.type

  if op_type == "Const":
    return MakeNdarray(op.get_attr("value"))

  if op_type == "Shape":
    input_shape = op.inputs[0].get_shape()
    if not input_shape.is_fully_defined():
      return None
    return np.array([dim.value for dim in input_shape.dims],
                    dtype=tensor.dtype.as_numpy_dtype)

  if op_type == "Size":
    input_shape = op.inputs[0].get_shape()
    if not input_shape.is_fully_defined():
      return None
    return np.prod([dim.value for dim in input_shape.dims], dtype=np.int32)

  if op_type == "Rank":
    input_shape = op.inputs[0].get_shape()
    if input_shape.ndims is None:
      return None
    return np.ndarray(
        shape=(),
        buffer=np.array([input_shape.ndims], dtype=np.int32),
        dtype=np.int32)

  if op_type == "Range":
    # Evaluate start, limit, delta in order, stopping at the first unknown.
    bounds = []
    for position in range(3):
      bound = constant_value(op.inputs[position])
      if bound is None:
        return None
      bounds.append(bound)
    start, limit, delta = bounds
    return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)

  if op_type == "Cast":
    pre_cast = constant_value(op.inputs[0])
    if pre_cast is None:
      return None
    cast_dtype = dtypes.as_dtype(op.get_attr("DstT"))
    return pre_cast.astype(cast_dtype.as_numpy_dtype)

  if op_type in ("Concat", "ConcatV2"):
    # Concat takes the axis as its first input; ConcatV2 as its last.
    if op_type == "Concat":
      axis_input, piece_inputs = op.inputs[0], op.inputs[1:]
    else:
      axis_input, piece_inputs = op.inputs[-1], op.inputs[:-1]
    axis = constant_value(axis_input)
    if axis is None:
      return None
    pieces = []
    for piece_input in piece_inputs:
      piece = constant_value(piece_input)
      if piece is None:
        return None
      pieces.append(piece)
    return np.concatenate(pieces, axis=axis)

  if op_type == "Pack":
    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    # and shouldn't be produced, but to deal sensibly with them here we check
    # and return None.
    if not op.inputs:
      return None
    # We can't handle axis != 0 Packs at the moment.
    if op.get_attr("axis") != 0:
      return None
    packed = []
    for item in op.inputs:
      item_value = constant_value(item, partial)
      if item_value is None and not partial:
        return None
      packed.append(item_value)
    try:
      return np.array(packed)
    except ValueError:
      # If partial=True, some of the elements of values may be None.
      return np.array(packed, dtype=object)

  if op_type == "Unpack":
    # We can't handle axis != 0 Unpacks at the moment.
    if op.get_attr("axis") != 0:
      return None
    source = constant_value(op.inputs[0], partial)
    if source is None:
      return None
    return source[tensor.value_index]

  if op_type == "Split":
    axis = constant_value(op.inputs[0])
    source = constant_value(op.inputs[1], partial)
    if source is None or axis is None:
      return None
    chunks = np.split(source, op.get_attr("num_split"), axis)
    return chunks[tensor.value_index]

  if op_type == "Fill":
    fill_shape = tensor.shape
    fill_value = constant_value(op.inputs[1])
    if fill_shape.is_fully_defined() and fill_value is not None:
      return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
    return None

  if op_type in ("Equal", "NotEqual"):
    lhs = constant_value(op.inputs[0])
    if lhs is None:
      return None
    rhs = constant_value(op.inputs[1])
    if rhs is None:
      return None
    compare = np.equal if op_type == "Equal" else np.not_equal
    return compare(lhs, rhs)

  # Ops that forward their first input unchanged.
  if op_type in ("StopGradient", "CheckNumericsV2", "DebugIdentityV2",
                 "Identity"):
    return constant_value(op.inputs[0], partial)

  return None
@tf_export("get_static_value")
def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  Example usage:

  >>> a = tf.constant(10)
  >>> tf.get_static_value(a)
  10
  >>> b = tf.constant(20)
  >>> tf.get_static_value(tf.add(a, b))
  30

  >>> # `tf.Variable` is not supported.
  >>> c = tf.Variable(30)
  >>> print(tf.get_static_value(c))
  None

  The `partial` option is most relevant when calling `get_static_value` inside
  a `tf.function`. Setting it to `True` returns the partially evaluated result,
  with `None` in place of the values that cannot be evaluated. For example:

  ```python
  class Foo:
    def __init__(self):
      self.a = tf.Variable(1)
      self.b = tf.constant(2)

    @tf.function
    def bar(self, partial):
      packed = tf.raw_ops.Pack(values=[self.a, self.b])
      static_val = tf.get_static_value(packed, partial=partial)
      tf.print(static_val)

  f = Foo()
  f.bar(partial=True)  # `array([None, array(2, dtype=int32)], dtype=object)`
  f.bar(partial=False)  # `None`
  ```

  Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it
  will no longer be possible to feed a different value for `tensor`. This allows
  the result of this function to influence the graph that is constructed, and
  permits static shape optimizations.

  Args:
    tensor: The Tensor to be evaluated. Inputs that `is_tensor` does not
      recognize as tensors are returned unchanged (no error is raised).
    partial: If True, the returned numpy array is allowed to have partially
      evaluated values. Values that can't be evaluated will be None. This is
      only consulted for symbolic (graph) tensors; eager values are converted
      directly with `.numpy()`.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`,
    or None if it cannot be calculated (including eager values whose
    `.numpy()` is unimplemented, and tensor-like objects that are neither
    eager values nor graph symbols). If `tensor` is not a tensor at all, it is
    returned as-is.
  """
  if isinstance(tensor, core.Value):
    # Eager values already hold concrete data; `partial` is irrelevant here.
    try:
      return tensor.numpy()
    except errors_impl.UnimplementedError:
      # Some EagerTensors may not implement .numpy/resolve, e.g. parallel
      # tensors with multiple components on different devices.
      return None
  if not is_tensor(tensor):
    # Non-tensor inputs (e.g. Python scalars/lists, numpy arrays) are passed
    # through unchanged rather than rejected.
    return tensor
  if not isinstance(tensor, core.Symbol):
    # Tensor-like, but neither an eager value nor a graph symbol: nothing we
    # can statically evaluate here.
    return None
  ret = _ConstantValue(tensor, partial)
  if ret is not None:
    # The caller may now depend on the constant value of `tensor`, so we
    # conservatively prevent it from being fed.
    tensor.graph.prevent_feeding(tensor)
  return ret
def constant_value_as_shape(tensor): # pylint: disable=invalid-name
"""A version of `constant_value()` that returns a `TensorShape`.
This version should be used when a constant tensor value is
interpreted as a (possibly partial) shape, e.g. in the shape
function for `tf.reshape()`. By explicitly requesting a
`TensorShape` as the return value, it is possible to represent
unknown dimensions; by contrast, `constant_value()` is
all-or-nothing.
Args:
tensor: The rank-0 or rank-1 Tensor to be evaluated.
Returns:
A `TensorShape` based on the constant value of the given `tensor`.
Raises:
ValueError: If the shape is rank-0 and is not statically known to be -1.
"""
if isinstance(tensor, core.Value):
return tensor_shape.TensorShape(
[dim if dim != -1 else None for dim in tensor.numpy()])
if tensor.get_shape().ndims == 0:
value = constant_value(tensor)
if value is None:
raise ValueError(
"Received a scalar with unknown value as shape; require a statically "
"known scalar with value '-1' to describe an unknown shape.")
if value != -1:
raise ValueError(
f"Received a scalar value '{value}' as shape; require a statically "
"known scalar with value '-1' to describe an unknown shape.")
return tensor_shape.unknown_shape()
shape = tensor.get_shape().with_rank(1)
if shape == [0]:
return tensor_shape.TensorShape([])
elif tensor.op.type == "Cast":
pre_cast = constant_value_as_shape(tensor.op.inputs[0])
if pre_cast.dims is None:
# the input to cast has a totally undefined shape; just return that.
return pre_cast
cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))