-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patharray_grad.py
1224 lines (1002 loc) · 44.1 KB
/
array_grad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in array_ops.py."""
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for pack op.

  The incoming gradient is unstacked along the op's `axis` attribute into
  `N` pieces, one per packed input.
  """
  packed_axis = op.get_attr("axis")
  num_packed = op.get_attr("N")
  return array_ops_stack.unstack(grad, num=num_packed, axis=packed_axis)
@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Gradient for unpack op: restack the per-output gradients along `axis`."""
  unpacked_axis = op.get_attr("axis")
  return array_ops_stack.stack(grads, axis=unpacked_axis)
def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradient with respect to
      the output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index one past the last value in op.inputs.
      It is used as an exclusive slice bound, so it may be negative (e.g. -1).
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    A list with one entry per element of op.inputs: the partial gradient for
    each concatenated value, and None for the concat_dim/axis input.

  Raises:
    ValueError: if `grad` is an `IndexedSlices` and concat_dim/axis is not
      statically known, or it is negative and the rank of the first value is
      not statically known.
    TypeError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.zeros(
            array_ops.expand_dims(concat_dim, 0), dtype=dtypes.int32), [1],
        array_ops.zeros(shape_of_shape - concat_dim - 1, dtype=dtypes.int32)
    ], 0)
    begin = array_ops.zeros(shape_of_shape, dtype=dtypes.int32)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation of a single value: the gradient passes through.
  # `grad` is a single Tensor/IndexedSlices, so it must be wrapped in a list.
  # The former `grad + [None]` would have attempted a tensor addition.
  if len(op.inputs) == 2:
    return [grad, None] if end_value_index <= dim_index else [None, grad]

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
                                                 non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops_stack.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, indexed_slices_lib.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  # The dim/axis input gets no gradient; place None where it sits in op.inputs.
  return (out_grads + [None] if end_value_index <= dim_index else [None] +
          out_grads)
@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  """Gradient for the Concat op, whose inputs are (concat_dim, *values)."""
  # concat_dim is input 0; every remaining input is a concatenated value.
  num_inputs = len(op.inputs)
  return _ConcatGradHelper(
      op, grad, start_value_index=1, end_value_index=num_inputs, dim_index=0)
@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  """Gradient for the ConcatV2 op, whose inputs are (*values, axis)."""
  # The axis is the trailing input, so the values are op.inputs[0:-1].
  return _ConcatGradHelper(
      op,
      grad,
      start_value_index=0,
      end_value_index=-1,
      dim_index=-1,
  )
# ConcatOffset is registered as having no gradient.
ops.NotDifferentiable("ConcatOffset")
@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op.

  Under XLA, `grad` is written into a zero tensor shaped like the input at
  offset `begin` via a dynamic-update-slice. Otherwise `grad` is zero-padded
  back to the input's shape:
    - `begin` zeros are prepended in each dimension;
    - `shape(input) - size - begin` zeros are appended.
  The begin and size inputs receive no gradient.
  """
  source = op.inputs[0]
  offsets = op.inputs[1]
  source_rank = array_ops.rank(source)
  idx_dtype = offsets.dtype
  out_size = array_ops.shape(op.outputs[0], out_type=idx_dtype)

  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    dx = gen_xla_ops.xla_dynamic_update_slice(
        array_ops.zeros_like(source), grad, offsets)
    return dx, None, None

  # Pad amounts as [rank, 1] columns, joined into an Nx2 paddings matrix.
  column_shape = array_ops_stack.stack([source_rank, 1])
  leading = array_ops.reshape(offsets, column_shape)
  trailing_amount = (
      array_ops.shape(source, out_type=idx_dtype) - out_size - offsets)
  trailing = array_ops.reshape(trailing_amount, column_shape)
  paddings = array_ops.concat([leading, trailing], 1)
  return array_ops.pad(grad, paddings), None, None
@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op.

  Delegates to StridedSliceGrad with the forward op's masks. The input shape,
  begin, end and strides are passed as constants when they can be computed
  statically. The begin, end and strides inputs receive no gradient.
  """

  def _prefer_static(value):
    # Use the constant-folded value when it is known at graph-build time.
    folded = tensor_util.constant_value(value)
    return value if folded is None else folded

  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  # StridedSliceGrad requires the shape, `begin`, `end` and `strides` to share
  # a dtype, so build the shape with begin's dtype. `begin` is an arbitrary
  # choice; end/strides are required to have the same dtype.
  input_shape = array_ops.shape(op.inputs[0], out_type=begin.dtype)
  masks = {
      attr: op.get_attr(attr)
      for attr in ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
                   "shrink_axis_mask")
  }
  dx = array_ops.strided_slice_grad(
      _prefer_static(input_shape),
      _prefer_static(begin),
      _prefer_static(end),
      _prefer_static(strides),
      grad,
      **masks)
  return dx, None, None, None
@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op.

  Only the last input (the upstream gradient) is differentiable. Its gradient
  is the same strided slice of `grad`, using the op's begin/end/strides
  inputs and mask attributes.
  """
  mask_attrs = ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
                "shrink_axis_mask")
  masks = {attr: op.get_attr(attr) for attr in mask_attrs}
  d_upstream = array_ops.strided_slice(
      grad, op.inputs[1], op.inputs[2], op.inputs[3], **masks)
  return None, None, None, None, d_upstream
@ops.RegisterGradient("TensorStridedSliceUpdate")
def _TensorStridedSliceUpdateGrad(op, grad):
  """Gradient for TensorStridedSliceUpdate.

  Returns:
    `(dx, None, None, None, dvalue)`, where:
      - `dx` is `grad` with the strided region zeroed;
      - `dvalue` is that strided region of `grad`, summed over the axes along
        which `op.inputs[4]` was broadcast and reshaped to that input's shape.
  """
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  masks = {
      "begin_mask": op.get_attr("begin_mask"),
      "end_mask": op.get_attr("end_mask"),
      "ellipsis_mask": op.get_attr("ellipsis_mask"),
      "new_axis_mask": op.get_attr("new_axis_mask"),
      "shrink_axis_mask": op.get_attr("shrink_axis_mask"),
  }

  # Gradient flowing into the strided region.
  region_grad = array_ops.strided_slice(grad, begin, end, strides, **masks)
  # Outside the region, grad passes through to the first input unchanged.
  dx = array_ops.tensor_strided_slice_update(
      grad, begin, end, strides, array_ops.zeros_like(region_grad), **masks)

  # The value may have been broadcast to the region's shape. Sum over the
  # broadcast axes and reshape back to the value's own shape.
  slice_shape = array_ops.shape(region_grad, out_type=begin.dtype)
  value_shape = array_ops.shape(op.inputs[4], out_type=slice_shape.dtype)
  _, value_axes = gen_array_ops.broadcast_gradient_args(
      slice_shape, value_shape)
  summed = math_ops.reduce_sum(region_grad, axis=value_axes, keepdims=True)
  dvalue = array_ops.reshape(summed, value_shape)
  return dx, None, None, None, dvalue
@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  """Gradient for Split: concatenate output grads along split_dim (input 0)."""
  split_dim = op.inputs[0]
  value_grad = array_ops.concat(list(grads), split_dim)
  return None, value_grad
@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  """Gradient for SplitV: concatenate output grads along input 2's axis."""
  value_grad = array_ops.concat(list(grads), op.inputs[2])
  # Only the first input is differentiable; every other input gets None.
  num_non_differentiable = len(op.inputs) - 1
  return [value_grad] + [None] * num_non_differentiable
# Const has no inputs to propagate gradients to; register it as non-differentiable.
ops.NotDifferentiable("Const")
@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  """Gradient for Diag: the diagonal part of the incoming gradient."""
  diagonal_grad = array_ops.diag_part(grad)
  return diagonal_grad
@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  """Gradient for DiagPart: place the incoming gradient on a diagonal."""
  embedded_grad = array_ops.diag(grad)
  return embedded_grad
@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  """Gradient for MatrixDiag: extract the (batched) main diagonal of grad."""
  diagonal_grad = array_ops.matrix_diag_part(grad)
  return diagonal_grad
@ops.RegisterGradient("MatrixDiagV2")
def _MatrixDiagV2Grad(op, grad):
  """Gradient for MatrixDiagV2.

  Extracts from `grad` the diagonals selected by `k` (input 1). The four
  non-diagonal inputs receive no gradient.
  """
  diagonals = op.inputs[1]
  diag_grad = array_ops.matrix_diag_part(grad, k=diagonals)
  return diag_grad, None, None, None, None
@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op, grad):
  """Gradient for MatrixDiagV3.

  Same as MatrixDiagV2, but honours the op's `align` attribute when
  extracting the `k` diagonals from `grad`.
  """
  diagonals = op.inputs[1]
  alignment = op.get_attr("align")
  diag_grad = array_ops.matrix_diag_part(grad, k=diagonals, align=alignment)
  return diag_grad, None, None, None, None
@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  """Gradient for MatrixDiagPart.

  For statically known square matrices, `grad` is expanded with matrix_diag.
  Otherwise it is written onto the diagonal of a zero tensor shaped like the
  input.
  """
  input_matrix = op.inputs[0]
  matrix_shape = input_matrix.get_shape()[-2:]
  is_static_square = (
      matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1])
  if not is_static_square:
    return array_ops.matrix_set_diag(array_ops.zeros_like(input_matrix), grad)
  return array_ops.matrix_diag(grad)
@ops.RegisterGradient("MatrixDiagPartV2")
def _MatrixDiagPartV2Grad(op, grad):
  """Gradient for MatrixDiagPartV2.

  When the input's matrix dims are statically known, the `k` diagonals of
  `grad` are expanded to that size with matrix_diag. Otherwise they are
  written into a zero tensor shaped like the input. `k` and the padding value
  receive no gradient.
  """
  input_matrix = op.inputs[0]
  diagonals = op.inputs[1]
  matrix_shape = input_matrix.get_shape()[-2:]
  if not matrix_shape.is_fully_defined():
    dx = array_ops.matrix_set_diag(
        array_ops.zeros_like(input_matrix), grad, k=diagonals)
  else:
    dx = array_ops.matrix_diag(
        grad,
        k=diagonals,
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1])
  return dx, None, None
@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op, grad):
  """Gradient for MatrixDiagPartV3.

  Same as MatrixDiagPartV2, but forwards the op's `align` attribute.
  """
  input_matrix = op.inputs[0]
  diagonals = op.inputs[1]
  matrix_shape = input_matrix.get_shape()[-2:]
  alignment = op.get_attr("align")
  if not matrix_shape.is_fully_defined():
    dx = array_ops.matrix_set_diag(
        array_ops.zeros_like(input_matrix), grad, k=diagonals, align=alignment)
  else:
    dx = array_ops.matrix_diag(
        grad,
        k=diagonals,
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1],
        align=alignment)
  return dx, None, None
@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  Returns:
    A pair `(grad_input, grad_diag)`:
      - `grad_input` is `grad` with its main diagonal replaced by zeros;
      - `grad_diag` is the main diagonal of `grad`.
  """
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    # Static path: the diagonal has shape batch_shape + [min(rows, cols)].
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    # Dynamic path: compute the same diagonal shape from grad's runtime shape,
    # placing the shape ops on grad's device.
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  # Zero out the diagonal of grad for the input's gradient.
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)
@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op, grad):
  """Gradient for MatrixSetDiagV2.

  Returns:
    `(grad_input, grad_diag, None)`:
      - `grad_input` is `grad` with the diagonals selected by `k`
        (`op.inputs[2]`) replaced by zeros;
      - `grad_diag` holds those diagonals of `grad`;
      - `k` receives no gradient.
  """
  diag_shape = op.inputs[1].get_shape()
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    # Offsets that shorten the longest diagonal when the band lies entirely
    # below (d_upper < 0) or entirely above (d_lower > 0) the main diagonal.
    y_offset = cond.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = cond.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # A single diagonal has a trailing dim of [max_diag_len]; a band of
    # diagonals adds a leading [num_diags] dimension.
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = cond.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype), k=op.inputs[2])
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2])
  return (grad_input, grad_diag, None)
@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op, grad):
  """Gradient for MatrixSetDiagV3.

  Same as the MatrixSetDiagV2 gradient, but forwards the op's `align`
  attribute to matrix_set_diag and matrix_diag_part.

  Returns:
    `(grad_input, grad_diag, None)`: `grad` with the `k` diagonals
    (`op.inputs[2]`) zeroed, those diagonals of `grad`, and no gradient for `k`.
  """
  diag_shape = op.inputs[1].get_shape()
  align = op.get_attr("align")
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    # Offsets that shorten the longest diagonal when the band lies entirely
    # below (d_upper < 0) or entirely above (d_lower > 0) the main diagonal.
    y_offset = cond.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = cond.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # A single diagonal has a trailing dim of [max_diag_len]; a band of
    # diagonals adds a leading [num_diags] dimension.
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = cond.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad,
      array_ops.zeros(diag_shape, dtype=grad.dtype),
      k=op.inputs[2],
      align=align)
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
  return (grad_input, grad_diag, None)
@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  """Gradient for MatrixBandPart: apply the same band mask to `grad`.

  The num_lower / num_upper inputs receive no gradient.
  """
  band_grad = array_ops.matrix_band_part(grad, op.inputs[1], op.inputs[2])
  return (band_grad, None, None)
# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
ops.NotDifferentiable("EditDistance")
@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  """Gradient for Fill: no gradient for dims; the value gets sum(grad)."""
  # Every output element is a copy of the scalar value, so its gradient is the
  # sum of all incoming gradient entries.
  value_grad = math_ops.reduce_sum(grad)
  return None, value_grad
# ZerosLike / OnesLike are registered as having no gradient.
ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")
@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  """Always raises LookupError, reporting the op's `message` attribute."""
  reason = op.get_attr("message")
  raise LookupError("Gradient explicitly disabled. Reason: %s" % reason)
def _IndexedSlicesToTensorNoWarning(indexed_slices):
  """Converts an IndexedSlices to a Tensor without sparse->dense warnings.

  Inputs that are not IndexedSlices are returned unchanged; they are expected
  to already be Tensors. IndexedSlices are densified with unsorted_segment_sum
  into `dense_shape[0]` segments.

  Raises:
    ValueError: if the IndexedSlices has no `dense_shape`.
  """
  if isinstance(indexed_slices, indexed_slices_lib.IndexedSlices):
    if indexed_slices.dense_shape is None:
      raise ValueError(
          "Tensor conversion requested for IndexedSlices without dense_shape: %s"
          % str(indexed_slices))
    num_segments = indexed_slices.dense_shape[0]
    return math_ops.unsorted_segment_sum(indexed_slices.values,
                                         indexed_slices.indices,
                                         num_segments)
  return indexed_slices
@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op, returned as IndexedSlices over params' rows."""
  params = op.inputs[0]
  # params can be large, so colocate the shape calculation with it.
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  indices = op.inputs[1]
  # Flatten indices to 1-D and shape the values as [num_indices, *row_shape].
  num_indices = array_ops.expand_dims(array_ops.size(indices), 0)
  row_shape = params_shape[1:]
  values_shape = array_ops.concat([num_indices, row_shape], 0)
  dense_grad = _IndexedSlicesToTensorNoWarning(grad)
  values = array_ops.reshape(dense_grad, values_shape)
  flat_indices = array_ops.reshape(indices, num_indices)
  params_grad = indexed_slices_lib.IndexedSlices(values, flat_indices,
                                                 params_shape)
  return [params_grad, None]
def _GetBatchIndices(params_shape, indices, batch_dims):
  """Adds the batch offsets to the given indices and returns the results.

  For each batch axis `d` in `[0, batch_dims)`, adds to `indices` the position
  along axis `d` multiplied by `prod(params_shape[d + 1 : batch_dims + 1])`.

  Args:
    params_shape: 1-D shape of the gathered params. It is cast to the dtype of
      `indices`.
    indices: Integer index tensor. Assumes its leading `batch_dims` axes line
      up (broadcast-compatibly) with `params_shape[:batch_dims]` — TODO
      confirm against callers.
    batch_dims: Python int, the number of leading batch dimensions.

  Returns:
    A tensor with the dtype of `indices`, holding the offset indices.
  """
  batch_indices = indices
  indices_dtype = indices.dtype.base_dtype
  casted_params_shape = math_ops.cast(params_shape, indices_dtype)
  # Running product of params_shape[dim : batch_dims + 1], built innermost first.
  accum_dim_value = array_ops.ones((), dtype=indices_dtype)
  for dim in range(batch_dims, 0, -1):
    dim_value = casted_params_shape[dim - 1]
    accum_dim_value *= casted_params_shape[dim]
    start = array_ops.zeros((), dtype=indices_dtype)
    step = array_ops.ones((), dtype=indices_dtype)
    dim_indices = math_ops.range(start, dim_value, step)
    dim_indices *= accum_dim_value
    # Shape [1]*(dim-1) + [dim_value] + [1]*(rank-dim), so the offsets
    # broadcast along axis dim-1 of `indices`.
    dim_shape = array_ops.concat([
        array_ops.tile([1], [dim - 1]), [dim_value],
        array_ops.tile([1], [array_ops.rank(indices) - dim])
    ], axis=0)
    batch_indices += array_ops.reshape(dim_indices, dim_shape)

  return batch_indices
def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions.

  Args:
    params_shape: Shape of params, with the gathered axis already moved to
      position `batch_dims` by the caller.
    values: Incoming gradient, already transposed to match `params_shape`'s
      layout. It is passed through `_IndexedSlicesToTensorNoWarning` only when
      `batch_dims` is non-zero.
    indices: Gather indices.
    batch_dims: Number of leading batch dimensions.
    gather_dim_size: Size of params along the gathered axis.

  Returns:
    The dense gradient w.r.t. params, computed with unsorted_segment_sum. When
    `batch_dims` is non-zero it is reshaped to
    `outer_shape + [-1] + inner_shape`.
  """

  # Axis is the first non-batch dimension.
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  if batch_dims:
    values_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = values_shape[:batch_dims]
    inner_shape = values_shape[batch_dims:][1:]
    batch_size = gen_math_ops.prod(outer_shape, [0], False)
    flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
    # After flattening, there is one segment per (batch, gathered-row) pair.
    gather_dim_size *= batch_size

    indices = _GetBatchIndices(params_shape, indices, batch_dims)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(values), flat_values_shape)

  indices = array_ops.reshape(indices, indices_size)
  params_grad = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)

  if batch_dims:
    # Put back the batch dimensions.
    params_grad = array_ops.reshape(
        params_grad, array_ops.concat([outer_shape, flat_values_shape], 0))

  return params_grad
@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op.

  If `axis` is statically 0, the gradient w.r.t. params is an `IndexedSlices`.
  Otherwise:
    1. the gathered axis is transposed to sit right after the batch dims;
    2. a dense gradient is computed by `_BatchGatherGrad`;
    3. the result is transposed back.
  `indices` and `axis` receive no gradient.

  Raises:
    ValueError: if `batch_dims` is negative and the rank of `indices` is not
      statically known.
  """
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse model, array_ops.shape raises
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)
  batch_dims = int(op.get_attr("batch_dims"))

  if batch_dims < 0:
    if indices.shape.ndims is None:
      raise ValueError(
          "Currently, it is unsupported to take the gradient of tf.gather "
          "when batch_dims < 0 and the rank of the indices is unknown. Please "
          "pass a positive batch_dims or use tf.ensure_shape to update the "
          "shape of indices when calling tf.gather. Got "
          f"batch_dims={batch_dims} and indices={indices}")
    batch_dims += indices.shape.ndims

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    if context.executing_eagerly():
      with ops.device(indices_size.device):
        params_tail_shape = array_ops.identity(params_shape)[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)
    indices = array_ops.reshape(indices, indices_size)
    params_grad = indexed_slices_lib.IndexedSlices(values, indices,
                                                   params_shape)
  else:
    # Handle axis by transposing the axis dimension to be the first non-batch
    # dimension, compute the gradient and transpose the result back.
    outer_shape = params_shape[:axis]
    inner_shape = params_shape[axis:][1:]
    values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)

    values_dims = array_ops.size(values_shape)
    axis_dims = array_ops.size(outer_shape)

    outer_batches_indices = math_ops.range(batch_dims)
    batch_axis_indices = math_ops.range(batch_dims, axis_dims)
    inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)

    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)

    # Move values[axis] up to values[batch_dims]
    transpose_dims = array_ops.concat([
        outer_batches_indices, [axis_dims], batch_axis_indices,
        inner_axes_indices
    ], 0)
    values_transpose = array_ops.transpose(values, transpose_dims)
    params_shape_transpose = array_ops.gather(params_shape, transpose_dims)

    params_grad = _BatchGatherGrad(params_shape_transpose, values_transpose,
                                   indices, batch_dims, params_shape[axis])

    # Inverts the above transpose by moving dimension batch_dims back to its
    # original position.
    invert_transpose_dims = array_ops.concat([
        outer_batches_indices, batch_axis_indices + 1, [batch_dims],
        inner_axes_indices
    ], 0)
    params_grad = array_ops.transpose(params_grad, invert_transpose_dims)

  return [params_grad, None, None]
@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  """Gradient for GatherNd.

  If `indices` is statically rank 2 with a trailing dimension of 1, returns
  an IndexedSlices keyed by the squeezed indices. Otherwise scatters `grad`
  into a dense tensor shaped like params. `indices` gets no gradient.
  """
  params = op.inputs[0]
  indices = op.inputs[1]
  params_shape = array_ops.shape(params, out_type=indices.dtype)
  is_row_lookup = (
      indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1)
  if not is_row_lookup:
    return [array_ops.scatter_nd(indices, grad, params_shape), None]
  row_indices = array_ops.squeeze(indices, axis=-1)
  return [
      indexed_slices_lib.IndexedSlices(grad, row_indices, params_shape), None
  ]
@ops.RegisterGradient("ResourceGatherNd")
def _ResourceGatherNdGrad(op, grad):
  """Gradient for ResourceGatherNd, w.r.t. the resource variable input.

  Identical to the GatherNd gradient, except that the target shape is read
  with variable_shape from the resource handle.
  """
  resource = op.inputs[0]
  indices = op.inputs[1]
  resource_shape = gen_resource_variable_ops.variable_shape(
      resource, indices.dtype)
  is_row_lookup = (
      indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1)
  if not is_row_lookup:
    return [array_ops.scatter_nd(indices, grad, resource_shape), None]
  row_indices = array_ops.squeeze(indices, axis=-1)
  return [
      indexed_slices_lib.IndexedSlices(grad, row_indices, resource_shape), None
  ]
@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op: runs the same NaN/Inf check on `grad`."""
  error_message = (
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))
  return array_ops.check_numerics(grad, error_message)
@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op, grad):
  """Gradient for check_numerics_v2 op: runs check_numerics_v2 on `grad`."""
  error_message = (
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))
  return array_ops.check_numerics_v2(grad, error_message)
@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  """Pass-through gradient shared by Identity and PlaceholderWithDefault."""
  passthrough = grad
  return passthrough
@ops.RegisterGradient("_EagerConst")
def _EagerConstGrad(_, grad):
  """Always raises AssertionError; _EagerConst must never reach gradient APIs."""
  del grad  # Unused.
  raise AssertionError(
      "This op should never interact with gradient APIs. Please file a bug.")
@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  """Pass-through gradient for RefIdentity."""
  passthrough = grad
  return passthrough
@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grads):
  """Pass-through gradient for IdentityN: returns the output grads as a tuple."""
  return grads
# StopGradient is registered as having no gradient.
ops.NotDifferentiable("StopGradient")
@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  """Gradient for Reshape.

  The densified gradient is reshaped back to the input's shape. The shape
  input gets no gradient.
  """
  dense_grad = _IndexedSlicesToTensorNoWarning(grad)
  input_shape = array_ops.shape(op.inputs[0])
  return [array_ops.reshape(dense_grad, input_shape), None]
# InvertPermutation is registered as having no gradient.
ops.NotDifferentiable("InvertPermutation")
def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input.

  An IndexedSlices gradient is densified first.
  """
  dense_grad = _IndexedSlicesToTensorNoWarning(grad)
  target_shape = array_ops.shape(op.inputs[0])
  return array_ops.reshape(dense_grad, target_shape)
@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  """Gradient for ExpandDims: reshape back to the input; `axis` gets None."""
  input_grad = _ReshapeToInput(op, grad)
  return [input_grad, None]
@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  """Gradient for Squeeze. The gradient is reshaped to the input's shape."""
  input_grad = _ReshapeToInput(op, grad)
  return input_grad
@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad).

  `grad` is transposed by the inverse of the forward permutation. The
  permutation input receives no gradient.
  """
  inverse_perm = array_ops.invert_permutation(op.inputs[1])
  return [array_ops.transpose(grad, inverse_perm), None]
@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad)).

  `grad` is transposed by the inverse of the forward permutation, with
  conjugation. The permutation input receives no gradient.
  """
  inverse_perm = array_ops.invert_permutation(op.inputs[1])
  unshuffled = array_ops.transpose(grad, inverse_perm, conjugate=True)
  return [unshuffled, None]
# Shape-metadata ops (Shape, ShapeN, Rank, Size) are registered as
# non-differentiable via ops.NotDifferentiable.
ops.NotDifferentiable("Shape")
ops.NotDifferentiable("ShapeN")
ops.NotDifferentiable("Rank")
ops.NotDifferentiable("Size")
@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions.

  Args:
    op: The forward Tile op. `op.inputs[0]` is the tiled input and
      `op.inputs[1]` is the `multiples` vector.
    grad: Gradient w.r.t. the Tile output, either a dense tensor or
      `IndexedSlices`.

  Returns:
    `[input_grad, None]`: the gradient for the input (reduced to the input's
    shape) and no gradient for `multiples`.
  """
  # Computed in the dtype of `multiples` so the two can be stacked below.
  input_shape = array_ops.shape(op.inputs[0], out_type=op.inputs[1].dtype)
  # We interleave multiples and input_shape to get split_shape,
  # reshape grad to split_shape, and reduce along all even
  # dimensions (the tiled dimensions) to get the result
  # with shape input_shape. For example
  # input_shape = [20, 30, 40]
  # multiples = [2, 3, 4]
  # split_shape = [2, 20, 3, 30, 4, 40]
  # axes = [0, 2, 4]
  split_shape = array_ops.reshape(
      array_ops.transpose(array_ops_stack.stack([op.inputs[1], input_shape])),
      [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)
  # Sum reduces grad along the first dimension for IndexedSlices
  if isinstance(grad, indexed_slices_lib.IndexedSlices):
    # Row r of the tiled output maps back to input row r mod input_shape[0].
    # Summing values by that row id collapses the first dimension to
    # input_shape[0], so the leading tile factor in split_shape becomes 1.
    input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
    grad = math_ops.unsorted_segment_sum(
        grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  # Fix shape inference: in graph mode, restore the input's static shape,
  # which the dynamic reshape/reduce_sum above does not carry through.
  if not context.executing_eagerly():
    input_grad.set_shape(op.inputs[0].get_shape())
  return [input_grad, None]
# BroadcastGradientArgs is registered via ops.NotDifferentiable; no gradient
# function is defined for it.
ops.NotDifferentiable("BroadcastGradientArgs")
def _PadGrad(op, grad):
  """Gradient for Pad (and PadV2).

  Pad only surrounds the original tensor with extra values. The gradient for
  the input is therefore the slice of `grad` aligned with the original tensor.
  The paddings input, and any third input (as PadV2 has), get no gradient.
  """
  x = op.inputs[0]
  paddings = op.inputs[1]  # [Rank(x), 2]
  # Column 0 of `paddings` is the padding before each dimension. Slice it out
  # as a [Rank(x), 1] matrix, then flatten it to a [Rank(x)] vector of start
  # offsets for slicing `grad`.
  column_size = array_ops_stack.stack([array_ops.rank(x), 1])
  pad_before = array_ops.slice(paddings, [0, 0], column_size)
  begin = array_ops.reshape(pad_before, [-1])
  x_grad = array_ops.slice(
      grad, begin, array_ops.shape(x, out_type=begin.dtype))
  return (x_grad, None, None) if len(op.inputs) == 3 else (x_grad, None)
# Pad and PadV2 share one gradient function. _PadGrad returns an extra None
# when the op has a third input.
ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)
# ReverseSequence only permutes elements, and applying it twice with the same
# seq_lengths undoes it. So the gradient is the same reversal applied to grad.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  """Reverses `grad` the same way; `seq_lengths` gets no gradient."""
  batch_axis = op.get_attr("batch_dim")
  seq_axis = op.get_attr("seq_dim")
  reversed_grad = array_ops.reverse_sequence(
      grad,
      batch_axis=batch_axis,
      seq_axis=seq_axis,
      seq_lengths=op.inputs[1])
  return [reversed_grad, None]
@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  """Gradient for Reverse.

  `grad` is reversed along the same dims; the dims input gets no gradient.
  """
  flipped = gen_array_ops.reverse(grad, op.inputs[1])
  return flipped, None
@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  """Gradient for ReverseV2.

  `grad` is reversed along the same axes; the axis input gets no gradient.
  """
  flipped = array_ops.reverse_v2(grad, op.inputs[1])
  return flipped, None
@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  """Gradient for SpaceToBatch is the opposite op, BatchToSpace.

  The forward op's second input and `block_size` attr are reused. The second
  input itself gets no gradient.
  """
  undone = array_ops.batch_to_space(
      grad, op.inputs[1], block_size=op.get_attr("block_size"))
  return [undone, None]
@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  """Gradient for SpaceToBatchND is the opposite op, BatchToSpaceND.

  The forward op's second and third inputs are reused and get no gradient.
  """
  undone = array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2])
  return [undone, None, None]
@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  """Gradient for BatchToSpace is the opposite op, SpaceToBatch.

  The forward op's second input and `block_size` attr are reused. The second
  input itself gets no gradient.
  """
  undone = array_ops.space_to_batch(
      grad, op.inputs[1], block_size=op.get_attr("block_size"))
  return [undone, None]
@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  """Gradient for BatchToSpaceND is the opposite op, SpaceToBatchND.

  The forward op's second and third inputs are reused and get no gradient.
  """
  undone = array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2])
  return [undone, None, None]
@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  """Gradient for SpaceToDepth is the opposite op, DepthToSpace.

  Args:
    op: The forward SpaceToDepth op.
    grad: Gradient w.r.t. the op's output.

  Returns:
    `grad` rearranged by DepthToSpace with the forward op's `block_size` and
    `data_format`.

  Raises:
    ValueError: if the forward op used the NCHW_VECT_C data format.
  """
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  # `get_attr` may return string attrs as bytes. Comparing bytes with a str
  # literal alone is always False in Python 3, which made this guard
  # unreachable, so accept both forms.
  if data_format in (b"NCHW_VECT_C", "NCHW_VECT_C"):
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)
@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  """Gradient for DepthToSpace is the opposite op, SpaceToDepth.

  Args:
    op: The forward DepthToSpace op.
    grad: Gradient w.r.t. the op's output.

  Returns:
    `grad` rearranged by SpaceToDepth with the forward op's `block_size` and
    `data_format`.

  Raises:
    ValueError: if the forward op used the NCHW_VECT_C data format.
  """
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  # `get_attr` may return string attrs as bytes. Comparing bytes with a str
  # literal alone is always False in Python 3, which made this guard
  # unreachable, so accept both forms.
  if data_format in (b"NCHW_VECT_C", "NCHW_VECT_C"):
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)
# OneHot is registered via ops.NotDifferentiable; no gradient function is
# defined for it.
ops.NotDifferentiable("OneHot")
@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  """Gradient for MirrorPad, computed by the dedicated MirrorPadGrad op.

  The forward op's paddings input and `mode` attr are reused. The paddings
  input gets no gradient.
  """
  padding_mode = op.get_attr("mode")
  input_grad = gen_array_ops.mirror_pad_grad(
      grad, op.inputs[1], mode=padding_mode)
  return [input_grad, None]