/
nn_ops.cc
2505 lines (2209 loc) · 95.3 KB
/
nn_ops.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
namespace {
// Shape-function helper: sets output 0 to the shape encoded by the 1-D shape
// tensor at <input_idx>.
//
// If the tensor's value is not statically available, output 0 is a shape of
// rank <ndims> whose dimensions are all unknown. If the value is available,
// the shape it describes must also have rank <ndims>; a mismatching shape
// tensor is rejected here rather than propagated downstream, so both branches
// agree on the output rank.
//
// Returns an error if MakeShapeFromShapeTensor rejects the input, or if the
// described shape does not have rank <ndims>.
Status InputTensorShapeOrUnknown(InferenceContext* c, int input_idx,
                                 int ndims) {
  ShapeHandle out;
  const Tensor* input = c->input_tensor(input_idx);
  if (input == nullptr) {
    out = c->UnknownShapeOfRank(ndims);
  } else {
    TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(input_idx, &out));
    // Enforce the promised rank; the unknown-value branch above already
    // produces exactly <ndims> dimensions.
    TF_RETURN_IF_ERROR(c->WithRank(out, ndims, &out));
  }
  c->set_output(0, out);
  return Status::OK();
}
// Shape function for fractional pooling.
//
// Input 0 must be rank 4. For each of the 4 dimensions, a known input size d
// maps to floor(d / pooling_ratio[i]), and an unknown size stays unknown.
// Output 0 is the pooled 4-D shape. Outputs 1 and 2 are vectors whose lengths
// are pooled dimensions 1 and 2 respectively (presumably the row/col pooling
// sequences -- TODO confirm against the op registrations).
//
// Returns InvalidArgument if `pooling_ratio` does not have exactly 4 entries,
// if any entry is smaller than 1 (or NaN), or if a computed size is negative.
Status FractionalPoolShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
  std::vector<float> pooling_ratio;
  TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio));
  if (pooling_ratio.size() != 4) {
    return errors::InvalidArgument(
        "pooling_ratio field must specify 4 dimensions");
  }
  for (int i = 0; i < 4; ++i) {
    // A ratio of 0 divides by zero, and a non-finite quotient cannot be
    // converted to int64 without undefined behavior. Ratios below 1 are not
    // pooling at all. Writing the test as !(x >= 1) also rejects NaN.
    if (!(pooling_ratio[i] >= 1.0f)) {
      return errors::InvalidArgument(
          "pooling_ratio cannot be smaller than 1, got: ", pooling_ratio[i]);
    }
  }
  std::vector<DimensionHandle> output_dims;
  output_dims.reserve(4);
  for (int i = 0; i < 4; ++i) {
    DimensionHandle d = c->Dim(input, i);
    if (c->ValueKnown(d)) {
      // This must match the same logic in the kernel function in
      // core/kernels/fractional_max_pool_op.cc.
      auto val = static_cast<int64>(floor(c->Value(d) / pooling_ratio[i]));
      if (val < 0) {
        return errors::InvalidArgument("Size computed for dim ", i,
                                       " is negative: ", val);
      }
      output_dims.push_back(c->MakeDim(val));
    } else {
      output_dims.push_back(c->UnknownDim());
    }
  }
  c->set_output(0, c->MakeShape(output_dims));
  c->set_output(1, c->Vector(output_dims[1]));
  c->set_output(2, c->Vector(output_dims[2]));
  return Status::OK();
}
} // namespace
// --------------------------------------------------------------------------
// AvgPool: average pooling over a 4-D `value`.
// Shape inference is delegated to shape_inference::AvgPoolShape, which is not
// defined in this file. Only float, half and double element types are
// registered. The builder-call order below defines the OpDef's input/attr
// order.
REGISTER_OP("AvgPool")
.Input("value: T")
.Output("output: T")
// ksize / strides: int lists with at least 4 entries.
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("T: {float, half, double}")
.SetShapeFn(shape_inference::AvgPoolShape)
.Doc(R"doc(
Performs average pooling on the input.
Each entry in `output` is the mean of the corresponding size `ksize`
window in `value`.
value: 4-D with shape `[batch, height, width, channels]`.
ksize: The size of the sliding window for each dimension of `value`.
strides: The stride of the sliding window for each dimension of `value`.
padding: The type of padding algorithm to use.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, in_height, in_width, in_channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
output: The average pooled output tensor.
)doc");
REGISTER_OP("AvgPoolGrad")
    .Input("orig_input_shape: int32")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {float, half, double}")
    .SetShapeFn([](InferenceContext* c) {
      // The gradient has the shape given by the value of `orig_input_shape`.
      // That shape could in principle be derived from `grad` and the window
      // attrs. However, when `orig_input_shape` is not statically known,
      // `grad`'s shape rarely is either, so a rank-4 unknown shape is the
      // fallback.
      constexpr int kOrigInputShapeIndex = 0;
      constexpr int kOutputRank = 4;
      return InputTensorShapeOrUnknown(c, kOrigInputShapeIndex, kOutputRank);
    })
    .Doc(R"doc(
Computes gradients of the average pooling function.
orig_input_shape: 1-D. Shape of the original input to `avg_pool`.
grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t.
the output of `avg_pool`.
ksize: The size of the sliding window for each dimension of the input.
strides: The stride of the sliding window for each dimension of the input.
padding: The type of padding algorithm to use.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, in_height, in_width, in_channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
output: 4-D. Gradients w.r.t. the input of `avg_pool`.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("BatchNormWithGlobalNormalization")
    .Input("t: T")
    .Input("m: T")
    .Input("v: T")
    .Input("beta: T")
    .Input("gamma: T")
    .Output("result: T")
    .Attr("T: numbertype")
    .Attr("variance_epsilon: float")
    .Attr("scale_after_normalization: bool")
    .Deprecated(9, "Use tf.nn.batch_normalization()")
    .SetShapeFn([](InferenceContext* c) {
      // `t` must be 4-D; its last dimension is the per-channel depth.
      ShapeHandle t_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &t_shape));
      DimensionHandle depth = c->Dim(t_shape, 3);
      // Inputs 1..4 (m, v, beta, gamma) must be vectors whose length agrees
      // with `depth`. Each merge may refine an unknown depth.
      for (int input_index = 1; input_index <= 4; ++input_index) {
        ShapeHandle per_channel;
        TF_RETURN_IF_ERROR(
            c->WithRank(c->input(input_index), 1, &per_channel));
        TF_RETURN_IF_ERROR(c->Merge(depth, c->Dim(per_channel, 0), &depth));
      }
      // The result is `t`'s shape with the (possibly refined) depth.
      ShapeHandle result_shape;
      TF_RETURN_IF_ERROR(c->ReplaceDim(t_shape, 3, depth, &result_shape));
      c->set_output(0, result_shape);
      return Status::OK();
    })
    .Doc(R"doc(
Batch normalization.
This op is deprecated. Prefer `tf.nn.batch_normalization`.
t: A 4D input Tensor.
m: A 1D mean Tensor with size matching the last dimension of t.
This is the first output from tf.nn.moments,
or a saved moving average thereof.
v: A 1D variance Tensor with size matching the last dimension of t.
This is the second output from tf.nn.moments,
or a saved moving average thereof.
beta: A 1D beta Tensor with size matching the last dimension of t.
An offset to be added to the normalized tensor.
gamma: A 1D gamma Tensor with size matching the last dimension of t.
If "scale_after_normalization" is true, this tensor will be multiplied
with the normalized tensor.
variance_epsilon: A small float number to avoid dividing by 0.
scale_after_normalization: A bool indicating whether the resulted tensor
needs to be multiplied with gamma.
)doc");
REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
    .Input("t: T")
    .Input("m: T")
    .Input("v: T")
    .Input("gamma: T")
    .Input("backprop: T")
    .Output("dx: T")
    .Output("dm: T")
    .Output("dv: T")
    .Output("db: T")
    .Output("dg: T")
    .Attr("T: numbertype")
    .Attr("variance_epsilon: float")
    .Attr("scale_after_normalization: bool")
    .Deprecated(9, "Use tf.nn.batch_normalization()")
    .SetShapeFn([](InferenceContext* c) {
      // `t` (input 0) must be 4-D and compatible with `backprop` (input 4).
      ShapeHandle t_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &t_shape));
      TF_RETURN_IF_ERROR(c->Merge(t_shape, c->input(4), &t_shape));
      DimensionHandle depth = c->Dim(t_shape, 3);
      // Inputs 1..3 (m, v, gamma) must be vectors of length `depth`.
      for (int input_index = 1; input_index <= 3; ++input_index) {
        ShapeHandle per_channel;
        TF_RETURN_IF_ERROR(
            c->WithRank(c->input(input_index), 1, &per_channel));
        TF_RETURN_IF_ERROR(c->Merge(depth, c->Dim(per_channel, 0), &depth));
      }
      // dx mirrors the merged input shape.
      ShapeHandle dx_shape;
      TF_RETURN_IF_ERROR(c->ReplaceDim(t_shape, 3, depth, &dx_shape));
      c->set_output(0, dx_shape);
      // dm, dv, db, dg (outputs 1..4) are all `depth`-length vectors.
      const ShapeHandle per_channel_out = c->Vector(depth);
      for (int output_index = 1; output_index <= 4; ++output_index) {
        c->set_output(output_index, per_channel_out);
      }
      return Status::OK();
    })
    .Doc(R"doc(
Gradients for batch normalization.
This op is deprecated. See `tf.nn.batch_normalization`.
t: A 4D input Tensor.
m: A 1D mean Tensor with size matching the last dimension of t.
This is the first output from tf.nn.moments,
or a saved moving average thereof.
v: A 1D variance Tensor with size matching the last dimension of t.
This is the second output from tf.nn.moments,
or a saved moving average thereof.
gamma: A 1D gamma Tensor with size matching the last dimension of t.
If "scale_after_normalization" is true, this Tensor will be multiplied
with the normalized Tensor.
backprop: 4D backprop Tensor.
variance_epsilon: A small float number to avoid dividing by 0.
scale_after_normalization: A bool indicating whether the resulted tensor
needs to be multiplied with gamma.
dx: 4D backprop tensor for input.
dm: 1D backprop tensor for mean.
dv: 1D backprop tensor for variance.
db: 1D backprop tensor for beta.
dg: 1D backprop tensor for gamma.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("FusedBatchNorm")
    .Input("x: T")
    .Input("scale: T")
    .Input("offset: T")
    .Input("mean: T")
    .Input("variance: T")
    .Output("y: T")
    .Output("batch_mean: T")
    .Output("batch_variance: T")
    .Output("reserve_space_1: T")
    .Output("reserve_space_2: T")
    .Attr("T: numbertype")
    .Attr("epsilon: float = 0.0001")
    .Attr("data_format: string = 'NHWC'")
    .Attr("is_training: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      // x must be 4-D. y has x's shape (with the channel dim merged against
      // the 1-D inputs), and all other outputs are channel-length vectors.
      ShapeHandle x;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
      bool is_training;
      TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
      // In training mode only scale and offset (inputs 1-2) are checked.
      // In inference mode mean and variance (inputs 3-4) are checked too.
      int number_inputs = (is_training) ? 3 : 5;
      string data_format;
      TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
      // Only NHWC and NCHW are documented. Any other string used to be
      // silently treated as NCHW, which produced wrong shapes.
      if (data_format != "NHWC" && data_format != "NCHW") {
        return errors::InvalidArgument("Invalid data format: ", data_format);
      }
      // Position of the channel (C) dimension in the 4-D tensors.
      const int channel_index = (data_format == "NHWC") ? 3 : 1;
      DimensionHandle channel_dim = c->Dim(x, channel_index);
      // covers scale, offset, and if is_training is false, mean, variance
      for (int i = 1; i < number_inputs; ++i) {
        ShapeHandle vec;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
        TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
      }
      ShapeHandle y;
      TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_index, channel_dim, &y));
      c->set_output(0, y);
      ShapeHandle vector_shape = c->Vector(channel_dim);
      c->set_output(1, vector_shape);
      c->set_output(2, vector_shape);
      c->set_output(3, vector_shape);
      c->set_output(4, vector_shape);
      return Status::OK();
    })
    .Doc(R"doc(
Batch normalization.
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
The size of 1D Tensors matches the dimension C of the 4D Tensors.
x: A 4D Tensor for input data.
scale: A 1D Tensor for scaling factor, to scale the normalized x.
offset: A 1D Tensor for offset, to shift to the normalized x.
mean: A 1D Tensor for population mean. Used for inference only;
must be empty for training.
variance: A 1D Tensor for population variance. Used for inference only;
must be empty for training.
y: A 4D Tensor for output data.
batch_mean: A 1D Tensor for the computed batch mean, to be used by TensorFlow
to compute the running mean.
batch_variance: A 1D Tensor for the computed batch variance, to be used by
TensorFlow to compute the running variance.
reserve_space_1: A 1D Tensor for the computed batch mean, to be reused
in the gradient computation.
reserve_space_2: A 1D Tensor for the computed batch variance (inverted variance
in the cuDNN case), to be used in the gradient computation.
T: The data type for the elements of input and output Tensors.
epsilon: A small float number added to the variance of x.
data_format: The data format for x and y. Either "NHWC" (default) or "NCHW".
is_training: A bool value to indicate the operation is for training (default)
or inference.
)doc");
REGISTER_OP("FusedBatchNormGrad")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: T")
    .Input("reserve_space_1: T")
    .Input("reserve_space_2: T")
    .Output("x_backprop: T")
    .Output("scale_backprop: T")
    .Output("offset_backprop: T")
    .Output("reserve_space_3: T")
    .Output("reserve_space_4: T")
    .Attr("T: numbertype")
    .Attr("epsilon: float = 0.0001")
    .Attr("data_format: string = 'NHWC'")
    .Attr("is_training: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      // y_backprop and x must both be 4-D, and their channel dims must agree.
      ShapeHandle y_backprop;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
      ShapeHandle x;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
      bool is_training;
      string data_format;
      TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
      TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
      // Only NHWC and NCHW are documented. Any other string used to be
      // silently treated as NCHW, which produced wrong shapes.
      if (data_format != "NHWC" && data_format != "NCHW") {
        return errors::InvalidArgument("Invalid data format: ", data_format);
      }
      // Position of the channel (C) dimension in the 4-D tensors.
      const int channel_index = (data_format == "NHWC") ? 3 : 1;
      DimensionHandle channel_dim = c->Dim(y_backprop, channel_index);
      TF_RETURN_IF_ERROR(
          c->Merge(channel_dim, c->Dim(x, channel_index), &channel_dim));
      // covers scale, mean (reserve_space_1), variance (reserve_space_2)
      for (int i = 2; i < 5; ++i) {
        ShapeHandle vec;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
        TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
      }
      ShapeHandle x_backprop;
      TF_RETURN_IF_ERROR(
          c->ReplaceDim(y_backprop, channel_index, channel_dim, &x_backprop));
      c->set_output(0, x_backprop);
      c->set_output(1, c->Vector(channel_dim));
      c->set_output(2, c->Vector(channel_dim));
      // Set the correct shapes for reserve_spaces so that gradients can be
      // performed when the op is in a symbolic condition. They are empty
      // vectors in training mode and channel-length vectors otherwise.
      if (is_training) {
        c->set_output(3, c->Vector(0));
        c->set_output(4, c->Vector(0));
      } else {
        c->set_output(3, c->Vector(channel_dim));
        c->set_output(4, c->Vector(channel_dim));
      }
      return Status::OK();
    })
    .Doc(R"doc(
Gradient for batch normalization.
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
The size of 1D Tensors matches the dimension C of the 4D Tensors.
y_backprop: A 4D Tensor for the gradient with respect to y.
x: A 4D Tensor for input data.
scale: A 1D Tensor for scaling factor, to scale the normalized x.
reserve_space_1: A 1D Tensor for the computed batch mean, to be reused
in the gradient computation.
reserve_space_2: A 1D Tensor for the computed batch variance (inverted variance
in the cuDNN case), to be used in the gradient computation.
x_backprop: A 4D Tensor for the gradient with respect to x.
scale_backprop: A 1D Tensor for the gradient with respect to scale.
offset_backprop: A 1D Tensor for the gradient with respect to offset.
reserve_space_3: Unused placeholder to match the mean input in FusedBatchNorm.
reserve_space_4: Unused placeholder to match the variance input
in FusedBatchNorm.
T: The data type for the elements of input and output Tensors.
epsilon: A small float number added to the variance of x.
data_format: The data format for y_backprop, x, x_backprop.
Either "NHWC" (default) or "NCHW".
is_training: A bool value to indicate the operation is for training (default)
or inference.
)doc");
// --------------------------------------------------------------------------
// BiasAdd: adds a 1-D `bias` to `value` along the dimension selected by the
// data_format attr (see the op doc below). Shape inference is delegated to
// shape_inference::BiasAddShape, which is not defined in this file.
// BiasAddV1 below is the same op without the data_format attr.
REGISTER_OP("BiasAdd")
.Attr("T: numbertype")
.Input("value: T")
.Input("bias: T")
.Attr(GetConvnetDataFormatAttrString())
.Output("output: T")
.SetShapeFn(shape_inference::BiasAddShape)
.Doc(R"doc(
Adds `bias` to `value`.
This is a special case of `tf.add` where `bias` is restricted to be 1-D.
Broadcasting is supported, so `value` may have any number of dimensions.
value: Any number of dimensions.
bias: 1-D with size the last dimension of `value`.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the bias tensor will be added to the last dimension
of the value tensor.
Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
The tensor will be added to "in_channels", the third-to-the-last
dimension.
output: Broadcasted sum of `value` and `bias`.
)doc");
// --------------------------------------------------------------------------
// BiasAddGrad: gradient of BiasAdd with respect to `bias`. Per the op doc
// below, it accumulates `out_backprop` into the feature dimension chosen by
// data_format. Shape inference is delegated to
// shape_inference::BiasAddGradShape, which is not defined in this file.
REGISTER_OP("BiasAddGrad")
.Attr("T: numbertype")
.Input("out_backprop: T")
.Attr(GetConvnetDataFormatAttrString())
.Output("output: T")
.SetShapeFn(shape_inference::BiasAddGradShape)
.Doc(R"doc(
The backward operation for "BiasAdd" on the "bias" tensor.
It accumulates all the values from out_backprop into the feature dimension.
For NHWC data format, the feature dimension is the last. For NCHW data format,
the feature dimension is the third-to-last.
out_backprop: Any number of dimensions.
output: 1-D with size the feature dimension of `out_backprop`.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the bias tensor will be added to the last dimension
of the value tensor.
Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
The tensor will be added to "in_channels", the third-to-the-last
dimension.
)doc");
// --------------------------------------------------------------------------
// BiasAddV1: legacy form of BiasAdd with no data_format attr (the op doc
// below marks it deprecated). It shares shape_inference::BiasAddShape with
// BiasAdd.
REGISTER_OP("BiasAddV1")
.Attr("T: numbertype")
.Input("value: T")
.Input("bias: T")
.Output("output: T")
.SetShapeFn(shape_inference::BiasAddShape)
.Doc(R"doc(
Adds `bias` to `value`.
This is a deprecated version of BiasAdd and will be soon removed.
This is a special case of `tf.add` where `bias` is restricted to be 1-D.
Broadcasting is supported, so `value` may have any number of dimensions.
value: Any number of dimensions.
bias: 1-D with size the last dimension of `value`.
output: Broadcasted sum of `value` and `bias`.
)doc");
// --------------------------------------------------------------------------
// Conv2D: 2-D convolution of a 4-D `input` with a 4-D `filter`.
// Shape inference is delegated to shape_inference::Conv2DShape, which is not
// defined in this file.
// Note: unlike AvgPool's `list(int) >= 4`, `strides` has no length
// constraint in the attr spec; the op doc below asks for length 4.
// NOTE(review): "vertices strides" in the doc text presumably means
// "vertical strides". It is left as-is because the doc string is part of
// the registered OpDef.
REGISTER_OP("Conv2D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
.Attr("T: {half, float, double}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn(shape_inference::Conv2DShape)
.Doc(R"doc(
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
and a filter / kernel tensor of shape
`[filter_height, filter_width, in_channels, out_channels]`, this op
performs the following:
1. Flattens the filter to a 2-D matrix with shape
`[filter_height * filter_width * in_channels, output_channels]`.
2. Extracts image patches from the input tensor to form a *virtual*
tensor of shape `[batch, out_height, out_width,
filter_height * filter_width * in_channels]`.
3. For each patch, right-multiplies the filter matrix and the image patch
vector.
In detail, with the default NHWC format,
output[b, i, j, k] =
sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
filter[di, dj, q, k]
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
input: A 4-D tensor. The dimension order is interpreted according to the value
of `data_format`, see below for details.
filter: A 4-D tensor of shape
`[filter_height, filter_width, in_channels, out_channels]`
output: A 4-D tensor. The dimension order is determined by the value of
`data_format`, see below for details.
strides: 1-D tensor of length 4. The stride of the sliding window for each
dimension of `input`. The dimension order is determined by the value of
`data_format`, see below for details.
padding: The type of padding algorithm to use.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
)doc");
REGISTER_OP("Conv2DBackpropInput")
    .Input("input_sizes: int32")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .SetShapeFn([](InferenceContext* c) {
      // The output shape is the value of `input_sizes` (input 0). It could in
      // principle be derived from `out_backprop` and the attrs. However, when
      // `input_sizes` is not statically known, the gradient's shape usually
      // is not either, so a rank-4 unknown shape is the fallback.
      constexpr int kInputSizesIndex = 0;
      constexpr int kOutputRank = 4;
      return InputTensorShapeOrUnknown(c, kInputSizesIndex, kOutputRank);
    })
    .Doc(R"doc(
Computes the gradients of convolution with respect to the input.
input_sizes: An integer vector representing the shape of `input`,
where `input` is a 4-D `[batch, height, width, channels]` tensor.
filter: 4-D with shape
`[filter_height, filter_width, in_channels, out_channels]`.
out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`.
Gradients w.r.t. the output of the convolution.
strides: The stride of the sliding window for each dimension of the input
of the convolution. Must be in the same order as the dimension specified with
format.
padding: The type of padding algorithm to use.
output: 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient
w.r.t. the input of the convolution.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, in_height, in_width, in_channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
)doc");
// TODO(jeff): Instead of 'use_cudnn_on_gpu', maybe we should have a
// more general string attribute ('kernel_impl'?) that can be used to
// select among several possible implementations.
REGISTER_OP("Conv2DBackpropFilter")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .SetShapeFn([](InferenceContext* c) {
      // The output shape is the value of `filter_sizes` (input 1). It could
      // in principle be derived from `input`, `out_backprop` and the attrs.
      // However, when `filter_sizes` is not statically known, the other
      // shapes usually are not either, so a rank-4 unknown shape is the
      // fallback.
      constexpr int kFilterSizesIndex = 1;
      constexpr int kOutputRank = 4;
      return InputTensorShapeOrUnknown(c, kFilterSizesIndex, kOutputRank);
    })
    .Doc(R"doc(
Computes the gradients of convolution with respect to the filter.
input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
filter_sizes: An integer vector representing the tensor shape of `filter`,
where `filter` is a 4-D
`[filter_height, filter_width, in_channels, out_channels]` tensor.
out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`.
Gradients w.r.t. the output of the convolution.
strides: The stride of the sliding window for each dimension of the input
of the convolution. Must be in the same order as the dimension specified with
format.
padding: The type of padding algorithm to use.
output: 4-D with shape
`[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
the `filter` input of the convolution.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, in_height, in_width, in_channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
)doc");
namespace {
// Shape function shared by FusedResizeAndPadConv2D and FusedPadConv2D.
//
// The 4-D input (treated as NHWC) goes through three stages in order:
//   1. An optional resize. When `has_resize` is true, input 1 is a 2-element
//      int32 `size` vector giving the new height and width.
//   2. Padding by an int32 `paddings` matrix of shape [4, 2], one
//      (before, after) row per input dimension.
//   3. A 2-D convolution with the 4-D filter using the `strides` and
//      `padding` attrs.
// Output 0 is set to [batch, out_rows, out_cols, out_depth].
//
// Input indices: 0 = input, [1 = size when has_resize], then paddings, then
// filter.
//
// Returns InvalidArgument for negative resize sizes, negative paddings, or a
// `strides` attr that does not have 4 entries. Rank and merge failures are
// propagated from the InferenceContext helpers.
Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
  ShapeHandle resized = input;
  int paddings_index = 1;
  int filter_index = 2;
  if (has_resize) {
    paddings_index = 2;
    filter_index = 3;
    ShapeHandle unused_size;
    TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size));
    const Tensor* size = c->input_tensor(1);
    DimensionHandle new_height = c->UnknownDim();
    DimensionHandle new_width = c->UnknownDim();
    if (size != nullptr) {
      const int32 height = size->flat<int32>()(0);
      const int32 width = size->flat<int32>()(1);
      // Negative sizes are invalid. Passed to MakeDim, -1 would silently
      // become an "unknown" dimension, and other negatives are not valid
      // dimension values.
      if (height < 0 || width < 0) {
        return errors::InvalidArgument(
            "Resize size must be non-negative, got: [", height, ", ", width,
            "]");
      }
      new_height = c->MakeDim(height);
      new_width = c->MakeDim(width);
    }
    TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized));
    TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized));
  }
  ShapeHandle paddings;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings));
  // `resized` is always rank 4 (input 0 was checked above), so paddings must
  // be a [4, 2] matrix. Merging directly also works when the number of
  // paddings rows is not statically known. The previous
  // WithRank(resized, Value(Dim(paddings, 0))) requested rank -1 in that case
  // and failed spuriously.
  TF_RETURN_IF_ERROR(c->Merge(paddings, c->Matrix(4, 2), &paddings));
  const Tensor* paddings_t = c->input_tensor(paddings_index);
  ShapeHandle padded;
  if (paddings_t != nullptr) {
    std::vector<DimensionHandle> output_dims;
    output_dims.reserve(4);
    for (int i = 0; i < 4; ++i) {
      DimensionHandle dim = c->Dim(resized, i);
      int64 p0 = static_cast<int64>(paddings_t->matrix<int32>()(i, 0));
      int64 p1 = static_cast<int64>(paddings_t->matrix<int32>()(i, 1));
      if (p0 < 0 || p1 < 0) {
        return errors::InvalidArgument("Paddings must be non-negative");
      }
      TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim));
      output_dims.push_back(dim);
    }
    padded = c->MakeShape(output_dims);
  } else {
    padded = c->UnknownShapeOfRank(4);
  }
  // Work out the convolution's effect with 'padded' as the input.
  ShapeHandle filter;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter));
  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "Operation requires the stride attribute to contain 4 values, but ",
        "got: ", strides.size());
  }
  int32 stride_rows = strides[1];
  int32 stride_cols = strides[2];
  DimensionHandle batch_size_dim = c->Dim(padded, 0);
  DimensionHandle in_rows_dim = c->Dim(padded, 1);
  DimensionHandle in_cols_dim = c->Dim(padded, 2);
  DimensionHandle filter_rows_dim = c->Dim(filter, 0);
  DimensionHandle filter_cols_dim = c->Dim(filter, 1);
  DimensionHandle output_depth_dim = c->Dim(filter, 3);
  // The padded input depth must match the filter's in_channels.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused));
  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
  DimensionHandle output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
  ShapeHandle output_shape = c->MakeShape(
      {batch_size_dim, output_rows, output_cols, output_depth_dim});
  c->set_output(0, output_shape);
  return Status::OK();
}
} // namespace
REGISTER_OP("FusedResizeAndPadConv2D")
    .Input("input: T")
    .Input("size: int32")
    .Input("paddings: int32")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("resize_align_corners: bool = false")
    .Attr(GetMirrorPadModeAttrString())
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      // This op carries the `size` input, so the shared helper must apply
      // the resize stage before padding and convolution.
      constexpr bool kHasResize = true;
      return CommonFusedConvCalculations(c, kHasResize);
    })
    .Doc(R"doc(
Performs a resize and padding as a preprocess during a convolution.
It's often possible to do spatial transformations more efficiently as part of
the packing stage of a convolution, so this op allows for an optimized
implementation where these stages are fused together. This prevents the need to
write out the intermediate results as whole tensors, reducing memory pressure,
and we can get some latency gains by merging the transformation calculations.
The data_format attribute for Conv2D isn't supported by this op, and defaults to
'NHWC' order.
Internally this op uses a single per-graph scratch buffer, which means that it
will block if multiple versions are being run in parallel. This is because this
operator is primarily an optimization to minimize memory usage.
input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
paddings: A two-column matrix specifying the padding sizes. The number of
rows must be the same as the rank of `input`.
filter: 4-D with shape
`[filter_height, filter_width, in_channels, out_channels]`.
resize_align_corners: If true, rescale input by (new_height - 1) / (height - 1),
which exactly aligns the 4 corners of images and resized images. If false, rescale
by new_height / height. Treat similarly the width dimension.
strides: 1-D of length 4. The stride of the sliding window for each dimension
of `input`. Must be in the same order as the dimension specified with format.
padding: The type of padding algorithm to use.
)doc");
REGISTER_OP("FusedPadConv2D")
    .Input("input: T")
    .Input("paddings: int32")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr(GetMirrorPadModeAttrString())
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      // No `size` input here: paddings is input 1 and filter is input 2, so
      // the shared helper skips the resize stage.
      constexpr bool kHasResize = false;
      return CommonFusedConvCalculations(c, kHasResize);
    })
    .Doc(R"doc(
Performs a padding as a preprocess during a convolution.
Similar to FusedResizeAndPadConv2d, this op allows for an optimized
implementation where the spatial padding transformation stage is fused with the
im2col lookup, but in this case without the bilinear filtering required for
resizing. Fusing the padding prevents the need to write out the intermediate
results as whole tensors, reducing memory pressure, and we can get some latency
gains by merging the transformation calculations.
The data_format attribute for Conv2D isn't supported by this op, and 'NHWC'
order is used instead.
Internally this op uses a single per-graph scratch buffer, which means that it
will block if multiple versions are being run in parallel. This is because this
operator is primarily an optimization to minimize memory usage.
input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
paddings: A two-column matrix specifying the padding sizes. The number of
rows must be the same as the rank of `input`.
filter: 4-D with shape
`[filter_height, filter_width, in_channels, out_channels]`.
strides: 1-D of length 4. The stride of the sliding window for each dimension
of `input`. Must be in the same order as the dimension specified with format.
padding: The type of padding algorithm to use.
)doc");
// --------------------------------------------------------------------------
// Registers "DepthwiseConv2dNative": each input channel is convolved with its
// own `channel_multiplier` filters, so the output has
// in_channels * channel_multiplier channels (see the doc text below).
// Only float/double are registered here. Shape inference is delegated to
// shape_inference::DepthwiseConv2DNativeShape.
REGISTER_OP("DepthwiseConv2dNative")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn(shape_inference::DepthwiseConv2DNativeShape)
.Doc(R"doc(
Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors.
Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
and a filter / kernel tensor of shape
`[filter_height, filter_width, in_channels, channel_multiplier]`, containing
`in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies
a different filter to each input channel (expanding from 1 channel to
`channel_multiplier` channels for each), then concatenates the results
together. Thus, the output has `in_channels * channel_multiplier` channels.
for k in 0..in_channels-1
for q in 0..channel_multiplier-1
output[b, i, j, k * channel_multiplier + q] =
sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
filter[di, dj, k, q]
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
strides: 1-D of length 4. The stride of the sliding window for each dimension
of `input`.
padding: The type of padding algorithm to use.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
)doc");
// Registers "DepthwiseConv2dNativeBackpropInput", the gradient of
// DepthwiseConv2dNative w.r.t. its `input`. The desired output shape is
// passed as the int32 vector `input_sizes` (input 0).
REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
.Input("input_sizes: int32")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
// NOTE(mrry): We could in principle work out the shape from the
// gradients and the attrs, but if we do not know `input_sizes`
// statically, then we are unlikely to know the shape of the
// gradients either. InputTensorShapeOrUnknown presumably uses the
// value of input 0 when it is a known constant and otherwise
// falls back to an unknown rank-4 shape.
return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
})
.Doc(R"doc(
Computes the gradients of depthwise convolution with respect to the input.
input_sizes: An integer vector representing the shape of `input`, based
on `data_format`. For example, if `data_format` is 'NHWC' then
`input` is a 4-D `[batch, height, width, channels]` tensor.
filter: 4-D with shape
`[filter_height, filter_width, in_channels, depthwise_multiplier]`.
out_backprop: 4-D with shape based on `data_format`.
For example, if `data_format` is 'NHWC' then
out_backprop shape is `[batch, out_height, out_width, out_channels]`.
Gradients w.r.t. the output of the convolution.
strides: The stride of the sliding window for each dimension of the input
of the convolution.
padding: The type of padding algorithm to use.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
output: 4-D with shape according to `data_format`. For example, if
`data_format` is 'NHWC', output shape is `[batch, in_height,
in_width, in_channels]`. Gradient w.r.t. the input of the
convolution.
)doc");
// Registers "DepthwiseConv2dNativeBackpropFilter", the gradient of
// DepthwiseConv2dNative w.r.t. its `filter`. The desired output shape is
// passed as the int32 vector `filter_sizes` (input 1), so the gradient has
// the filter's shape.
REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
.Input("input: T")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
// NOTE(mrry): We could in principle work out the shape from the
// gradients and the attrs, but if we do not know `filter_sizes`
// statically, then we are unlikely to know the shape of the
// gradients either. InputTensorShapeOrUnknown presumably uses the
// value of input 1 when it is a known constant and otherwise
// falls back to an unknown rank-4 shape.
return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
})
.Doc(R"doc(
Computes the gradients of depthwise convolution with respect to the filter.
input: 4-D with shape based on `data_format`. For example, if
`data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height,
in_width, in_channels]` tensor.
filter_sizes: An integer vector representing the tensor shape of `filter`,
where `filter` is a 4-D
`[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor.
out_backprop: 4-D with shape based on `data_format`.
For example, if `data_format` is 'NHWC' then
out_backprop shape is `[batch, out_height, out_width, out_channels]`.
Gradients w.r.t. the output of the convolution.
strides: The stride of the sliding window for each dimension of the input
of the convolution.
padding: The type of padding algorithm to use.
data_format: Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
output: 4-D with shape
`[filter_height, filter_width, in_channels, depthwise_multiplier]`. Gradient
w.r.t. the `filter` input of the convolution.
)doc");
// --------------------------------------------------------------------------
// Registers "Conv3D": a 3-D convolution (a form of cross-correlation, per the
// doc text below) over 5-D `input` and `filter` tensors of the same numeric
// type T. `strides` must hold at least 5 ints. No data_format attr is
// declared here, so presumably only the layout described in the doc
// (`[batch, depth, height, width, channels]`) is supported.
// Shape inference is delegated to shape_inference::Conv3DShape.
REGISTER_OP("Conv3D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
.Attr("T: numbertype")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.SetShapeFn(shape_inference::Conv3DShape)
.Doc(R"doc(
Computes a 3-D convolution given 5-D `input` and `filter` tensors.
In signal processing, cross-correlation is a measure of similarity of
two waveforms as a function of a time-lag applied to one of them. This
is also known as a sliding dot product or sliding inner-product.
Our Conv3D implements a form of cross-correlation.
input: Shape `[batch, in_depth, in_height, in_width, in_channels]`.
filter: Shape `[filter_depth, filter_height, filter_width, in_channels,
out_channels]`. `in_channels` must match between `input` and `filter`.
strides: 1-D tensor of length 5. The stride of the sliding window for each
dimension of `input`. Must have `strides[0] = strides[4] = 1`.
padding: The type of padding algorithm to use.
)doc");
// Registers the original "Conv3DBackpropInput" op: the gradient of Conv3D
// w.r.t. its input. Unlike the int32 `input_sizes`-style gradient ops, this
// version takes the full `input` tensor as input 0. It is marked deprecated
// at version 10 (presumably the GraphDef producer version) with the message
// "Use Conv3DBackpropInputV2".
REGISTER_OP("Conv3DBackpropInput")
.Input("input: T")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
.Attr("T: numbertype")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Deprecated(10, "Use Conv3DBackpropInputV2")
.SetShapeFn([](InferenceContext* c) {
// The gradient has the same shape as `input`; UnchangedShapeWithRank
// presumably forwards input 0's shape to output 0 while requiring rank 5.
return UnchangedShapeWithRank(c, 5);
})
.Doc(R"doc(
Computes the gradients of 3-D convolution with respect to the input.
input: Shape `[batch, depth, rows, cols, in_channels]`.
filter: Shape `[depth, rows, cols, in_channels, out_channels]`.
`in_channels` must match between `input` and `filter`.
out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols,
out_channels]`.
strides: 1-D tensor of length 5. The stride of the sliding window for each
dimension of `input`. Must have `strides[0] = strides[4] = 1`.
padding: The type of padding algorithm to use.
)doc");
REGISTER_OP("Conv3DBackpropFilter")
.Input("input: T")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
.Attr("T: numbertype")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Deprecated(10, "Use Conv3DBackpropFilterV2")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle out;