/
layers.py
1345 lines (1202 loc) · 57.2 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""Higher level ops for building layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages
# TODO(b/28426988): Replace legacy_* fns migrated from slim.
# TODO(b/28426988): Remove legacy_* when all uses have migrated to new API.
# Public API of this module: the names exported by `from ... import *`.
# Several entries (e.g. the short `conv2d`/`linear`/`relu` aliases and the
# legacy_* functions) are defined further down in this file.
__all__ = ['avg_pool2d',
           'batch_norm',
           'bias_add',
           'conv2d',
           'conv2d_in_plane',
           'conv2d_transpose',
           'convolution2d',
           'convolution2d_in_plane',
           'convolution2d_transpose',
           'dropout',
           'flatten',
           'fully_connected',
           'linear',
           'max_pool2d',
           'one_hot_encoding',
           'relu',
           'relu6',
           'repeat',
           'separable_conv2d',
           'separable_convolution2d',
           'softmax',
           'stack',
           'unit_norm',
           'legacy_fully_connected',
           'legacy_linear',
           'legacy_relu']
@add_arg_scope
def avg_pool2d(inputs,
               kernel_size,
               stride=2,
               padding='VALID',
               outputs_collections=None,
               scope=None):
  """Adds an average pooling op over the height and width of `inputs`.

  The pooling window and stride only span the two spatial dimensions; the
  batch and depth dimensions use a window and stride of 1, so pooling is done
  per image and never across depth or batch.

  Args:
    inputs: a tensor of size [batch_size, height, width, depth].
    kernel_size: a list of length 2: [kernel_height, kernel_width] of the
      pooling kernel over which the op is computed. Can be an int if both
      values are the same.
    stride: a list of length 2: [stride_height, stride_width].
      Can be an int if both strides are the same. Note that presently
      both strides must have the same value.
    padding: the padding method, either 'VALID' or 'SAME'.
    outputs_collections: collection to add the outputs.
    scope: Optional scope for op_scope.

  Returns:
    a tensor representing the results of the pooling operation.
  """
  with ops.op_scope([inputs], scope, 'AvgPool2D') as sc:
    inputs = ops.convert_to_tensor(inputs)
    window_h, window_w = utils.two_element_tuple(kernel_size)
    step_h, step_w = utils.two_element_tuple(stride)
    # Window/stride of 1 on axis 0 (batch) and axis 3 (depth).
    window = [1, window_h, window_w, 1]
    steps = [1, step_h, step_w, 1]
    pooled = nn.avg_pool(inputs, ksize=window, strides=steps, padding=padding)
    return utils.collect_named_outputs(outputs_collections, sc, pooled)
@add_arg_scope
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               scope=None):
  """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

  "Batch Normalization: Accelerating Deep Network Training by Reducing
  Internal Covariate Shift"

  Sergey Ioffe, Christian Szegedy

  Statistics are computed over every axis except the last one. All variables
  created here (`beta`, `gamma`, `moving_mean`, `moving_variance`) have the
  shape of that last dimension. Can be used as a normalizer function for
  conv2d and fully_connected.

  Args:
    inputs: a tensor of size `[batch_size, height, width, channels]`
      or `[batch_size, channels]`.
    decay: decay for the moving average.
    center: If True, subtract `beta`. If False, `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: Optional activation function.
    updates_collections: collections to collect the update ops for computation.
      When not None, the moving-average update ops are only added to these
      collections; this layer does not otherwise run them.
      If None, a control dependency would be added to make sure the updates are
      computed.
    is_training: whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`. May be a Python bool or a
      boolean `Tensor`/`Variable`/placeholder. If its value is not a
      graph-construction-time constant, both paths are built and selected
      with `utils.smart_cond`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). Only affects `beta`
      and `gamma`; the moving statistics are always created non-trainable.
    scope: Optional scope for `variable_op_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: if rank or last dimension of `inputs` is undefined.
  """
  with variable_scope.variable_op_scope([inputs],
                                        scope, 'BatchNorm', reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    # Reduce over all axes except the last (per-feature) one.
    axis = list(range(inputs_rank - 1))
    params_shape = inputs_shape[-1:]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined last dimension %s.' % (
          inputs.name, params_shape))
    # Allocate parameters for the beta and gamma of the normalization.
    # They stay None when disabled; nn.batch_normalization receives None then.
    beta, gamma = None, None
    if center:
      beta_collections = utils.get_variable_collections(variables_collections,
                                                        'beta')
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=init_ops.zeros_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(variables_collections,
                                                         'gamma')
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=init_ops.ones_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)
    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections. These are never trainable: they are only
    # changed by the moving-average update ops below.
    moving_mean_collections = utils.get_variable_collections(
        variables_collections, 'moving_mean')
    moving_mean = variables.model_variable(
        'moving_mean',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.zeros_initializer,
        trainable=False,
        collections=moving_mean_collections)
    moving_variance_collections = utils.get_variable_collections(
        variables_collections, 'moving_variance')
    moving_variance = variables.model_variable(
        'moving_variance',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.ones_initializer,
        trainable=False,
        collections=moving_variance_collections)
    # If `is_training` doesn't have a constant value, because it is a `Tensor`,
    # a `Variable` or `Placeholder` then is_training_value will be None and
    # `needs_moments` will be true.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      # Calculate the moments based on the individual batch.
      # NOTE(review): `moving_mean` is passed as `shift`; presumably to improve
      # numerical stability of the variance computation -- confirm against
      # the nn.moments documentation.
      mean, variance = nn.moments(inputs, axis, shift=moving_mean)
      moving_vars_fn = lambda: (moving_mean, moving_variance)
      if updates_collections is None:
        def _force_updates():
          """Runs the moving-average updates before returning batch moments.

          The batch mean/variance are wrapped in `identity` under a control
          dependency on both update ops, so evaluating the outputs forces
          the updates to run.
          """
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay)
          with ops.control_dependencies([update_moving_mean,
                                         update_moving_variance]):
            return array_ops.identity(mean), array_ops.identity(variance)
        # Training: batch moments (with forced updates); else moving stats.
        mean, variance = utils.smart_cond(is_training,
                                          _force_updates,
                                          moving_vars_fn)
      else:
        def _delay_updates():
          """Builds the moving-average update ops without running them.

          The returned ops are registered in `updates_collections` by the
          caller below.
          """
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay)
          return update_moving_mean, update_moving_variance
        update_mean, update_variance = utils.smart_cond(is_training,
                                                        _delay_updates,
                                                        moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)
        # Use computed moments during training and moving_vars otherwise.
        vars_fn = lambda: (mean, variance)
        mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
    else:
      # Statically known to be inference: no batch moments are computed.
      mean, variance = moving_mean, moving_variance
    # Compute batch_normalization.
    outputs = nn.batch_normalization(
        inputs, mean, variance, beta, gamma, epsilon)
    # Re-attach the input's static shape to the output.
    outputs.set_shape(inputs_shape)
    if activation_fn:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
@add_arg_scope
def bias_add(inputs,
             activation_fn=None,
             initializer=init_ops.zeros_initializer,
             regularizer=None,
             reuse=None,
             variables_collections=None,
             outputs_collections=None,
             trainable=True,
             scope=None):
  """Adds a learned per-feature bias to `inputs`.

  A `biases` variable sized to the last dimension of `inputs` is created
  (or reused) and added with `nn.bias_add`. The optional `activation_fn` is
  then applied. Can be used as a normalizer function for conv2d and
  fully_connected.

  Args:
    inputs: a tensor of with at least rank 2 and value for the last dimension,
      e.g. `[batch_size, depth]`, `[None, None, None, depth]`.
    activation_fn: Optional activation function.
    initializer: An initializer for the bias, defaults to 0.
    regularizer: A regularizer like the result of
      `l1_regularizer` or `l2_regularizer`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for variable_op_scope.

  Returns:
    a tensor representing the result of adding biases to the inputs.
  """
  with variable_scope.variable_op_scope([inputs],
                                        scope, 'BiasAdd', reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    base_type = inputs.dtype.base_dtype
    depth = utils.last_dimension(inputs.get_shape(), min_rank=2)
    bias_var = variables.model_variable(
        'biases',
        shape=[depth],
        dtype=base_type,
        initializer=initializer,
        regularizer=regularizer,
        collections=utils.get_variable_collections(variables_collections,
                                                   'biases'),
        trainable=trainable)
    result = nn.bias_add(inputs, bias_var)
    if activation_fn:
      result = activation_fn(result)
    return utils.collect_named_outputs(outputs_collections, sc.name, result)
@add_arg_scope
def convolution2d(inputs,
                  num_outputs,
                  kernel_size,
                  stride=1,
                  padding='SAME',
                  rate=1,
                  activation_fn=nn.relu,
                  normalizer_fn=None,
                  normalizer_params=None,
                  weights_initializer=initializers.xavier_initializer(),
                  weights_regularizer=None,
                  biases_initializer=init_ops.zeros_initializer,
                  biases_regularizer=None,
                  reuse=None,
                  variables_collections=None,
                  outputs_collections=None,
                  trainable=True,
                  scope=None):
  """Adds a 2D convolution followed by an optional normalization layer.

  A `weights` variable of shape
  `[kernel_height, kernel_width, in_channels, num_outputs]` is convolved with
  `inputs`. Then exactly one of the following happens:
    * if `normalizer_fn` is given, it is applied to the result;
    * otherwise, if `biases_initializer` is not None, a `biases` variable of
      shape `[num_outputs]` is created and added.
  Finally `activation_fn`, if given, is applied.

  When `rate > 1` an a'trous (dilated) convolution with that rate is used
  instead of a strided one.

  Args:
    inputs: a 4-D tensor `[batch_size, height, width, channels]`.
    num_outputs: integer, the number of output filters.
    kernel_size: a list of length 2 `[kernel_height, kernel_width]` of
      the filters. Can be an int if both values are the same.
    stride: a list of length 2 `[stride_height, stride_width]`.
      Can be an int if both strides are the same. Note that presently
      both strides must have the same value.
    padding: one of `VALID` or `SAME`.
    rate: integer. If less than or equal to 1, a standard convolution is used.
      If greater than 1, then the a'trous convolution is applied and `stride`
      must be set to 1.
    activation_fn: activation function.
    normalizer_fn: normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
    normalizer_params: normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional list of collections for all the variables or
      a dictionary containing a different list of collection per variable.
    outputs_collections: collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_op_scope`.

  Returns:
    a tensor representing the output of the operation.

  Raises:
    ValueError: if both `rate` and `stride` are larger than one.
  """
  with variable_scope.variable_op_scope([inputs],
                                        scope, 'Conv', reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    base_type = inputs.dtype.base_dtype
    kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
    stride_h, stride_w = utils.two_element_tuple(stride)
    use_atrous = rate > 1
    if use_atrous and max(stride_h, stride_w) > 1:
      raise ValueError('Only one of rate or stride can be larger than one')
    in_channels = utils.last_dimension(inputs.get_shape(), min_rank=4)
    kernel = variables.model_variable(
        'weights',
        shape=[kernel_h, kernel_w, in_channels, num_outputs],
        dtype=base_type,
        initializer=weights_initializer,
        regularizer=weights_regularizer,
        collections=utils.get_variable_collections(variables_collections,
                                                   'weights'),
        trainable=trainable)

    if use_atrous:
      net = nn.atrous_conv2d(inputs, kernel, rate, padding=padding)
    else:
      net = nn.conv2d(inputs, kernel, [1, stride_h, stride_w, 1],
                      padding=padding)

    # A normalizer replaces the bias entirely.
    if normalizer_fn:
      net = normalizer_fn(net, **(normalizer_params or {}))
    elif biases_initializer is not None:
      bias_var = variables.model_variable(
          'biases',
          shape=[num_outputs],
          dtype=base_type,
          initializer=biases_initializer,
          regularizer=biases_regularizer,
          collections=utils.get_variable_collections(variables_collections,
                                                     'biases'),
          trainable=trainable)
      net = nn.bias_add(net, bias_var)

    if activation_fn:
      net = activation_fn(net)
    return utils.collect_named_outputs(outputs_collections, sc.name, net)
@add_arg_scope
def convolution2d_in_plane(
    inputs,
    kernel_size,
    stride=1,
    padding='SAME',
    activation_fn=nn.relu,
    normalizer_fn=None,
    normalizer_params=None,
    weights_initializer=initializers.xavier_initializer(),
    weights_regularizer=None,
    biases_initializer=init_ops.zeros_initializer,
    biases_regularizer=None,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    scope=None):
  """Performs the same in-plane convolution to each channel independently.

  A single `weights` filter of shape `[kernel_height, kernel_width, 1, 1]` is
  tiled across all input channels and applied with a depthwise convolution.
  Every channel is therefore convolved with the same 2-D kernel, and the
  output has as many channels as the input. The filter values come from
  `weights_initializer` (there is no argument that takes them directly).

  This is useful for performing various simple channel-independent convolution
  operations such as image gradients:

    image = tf.constant(..., shape=(16, 240, 320, 3))
    vert_gradients = layers.conv2d_in_plane(image, kernel_size=[2, 1])
    horz_gradients = layers.conv2d_in_plane(image, kernel_size=[1, 2])

  Args:
    inputs: a 4-D tensor with dimensions [batch_size, height, width, channels].
    kernel_size: a list of length 2 holding the [kernel_height, kernel_width] of
      the filter. Can be an int if both values are the same.
    stride: a list of length 2 `[stride_height, stride_width]`.
      Can be an int if both strides are the same. Note that presently
      both strides must have the same value.
    padding: the padding type to use, either 'SAME' or 'VALID'.
    activation_fn: activation function.
    normalizer_fn: normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
    normalizer_params: normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional list of collections for all the variables or
      a dictionary containing a different list of collection per variable.
    outputs_collections: collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_op_scope`.

  Returns:
    A `Tensor` representing the output of the operation.
  """
  with variable_scope.variable_op_scope(
      [inputs], scope, 'ConvInPlane', reuse=reuse) as sc:
    # Match the sibling layers: accept anything convertible to a Tensor
    # (e.g. numpy arrays), which is required for `.dtype.base_dtype` below.
    inputs = ops.convert_to_tensor(inputs)
    dtype = inputs.dtype.base_dtype
    kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
    stride_h, stride_w = utils.two_element_tuple(stride)
    num_filters_in = utils.last_dimension(inputs.get_shape(), min_rank=4)
    # One shared single-channel filter; it is replicated per channel below.
    weights_shape = [kernel_h, kernel_w, 1, 1]
    weights_collections = utils.get_variable_collections(
        variables_collections, 'weights')
    weights = variables.model_variable('weights',
                                       shape=weights_shape,
                                       dtype=dtype,
                                       initializer=weights_initializer,
                                       regularizer=weights_regularizer,
                                       collections=weights_collections,
                                       trainable=trainable)
    # Tile to [kernel_h, kernel_w, num_filters_in, 1] so each input channel
    # gets an identical copy (channel multiplier of 1).
    depthwise_weights = array_ops.tile(weights, [1, 1, num_filters_in, 1])
    outputs = nn.depthwise_conv2d(inputs, depthwise_weights,
                                  [1, stride_h, stride_w, 1], padding)
    if normalizer_fn:
      normalizer_params = normalizer_params or {}
      outputs = normalizer_fn(outputs, **normalizer_params)
    else:
      if biases_initializer is not None:
        biases_collections = utils.get_variable_collections(
            variables_collections, 'biases')
        biases = variables.model_variable('biases',
                                          shape=[num_filters_in,],
                                          dtype=dtype,
                                          initializer=biases_initializer,
                                          regularizer=biases_regularizer,
                                          collections=biases_collections,
                                          trainable=trainable)
        outputs = nn.bias_add(outputs, biases)
    if activation_fn:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
@add_arg_scope
def convolution2d_transpose(
    inputs,
    num_outputs,
    kernel_size,
    stride=1,
    padding='SAME',
    activation_fn=nn.relu,
    normalizer_fn=None,
    normalizer_params=None,
    weights_initializer=initializers.xavier_initializer(),
    weights_regularizer=None,
    biases_initializer=init_ops.zeros_initializer,
    biases_regularizer=None,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    scope=None):
  """Adds a convolution2d_transpose with an optional normalization layer.

  The function creates a variable called `weights`, representing the
  kernel, that is convolved with the input. If `normalizer_fn` is given it is
  applied to the result. Otherwise, if `biases_initializer` is not None, a
  second variable called `biases` is added. Finally, if `activation_fn` is not
  `None`, it is applied as well.

  Each spatial output dimension is `dim * stride` for 'SAME' padding and
  `dim * stride + max(kernel - stride, 0)` for 'VALID' padding.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_outputs: integer, the number of output filters.
    kernel_size: a list of length 2 holding the [kernel_height, kernel_width] of
      the filters. Can be an int if both values are the same.
    stride: a list of length 2: [stride_height, stride_width].
      Can be an int if both strides are the same. Note that presently
      both strides must have the same value.
    padding: one of 'VALID' or 'SAME'.
    activation_fn: activation function.
    normalizer_fn: normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
    normalizer_params: normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional list of collections for all the variables or
      a dictionary containing a different list of collection per variable.
    outputs_collections: collection to add the outputs.
    trainable: whether or not the variables (both `weights` and `biases`)
      should be trainable.
    scope: Optional scope for variable_op_scope.

  Returns:
    a tensor representing the output of the operation.

  Raises:
    ValueError: if 'kernel_size' is not a list of length 2 (presumably raised
      by `utils.two_element_tuple`; the same would apply to `stride`).
  """
  with variable_scope.variable_op_scope(
      [inputs], scope, 'Conv2d_transpose', reuse=reuse) as sc:
    # Consistent with the other layers: accept Tensor-convertible inputs.
    inputs = ops.convert_to_tensor(inputs)
    dtype = inputs.dtype.base_dtype
    kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
    stride_h, stride_w = utils.two_element_tuple(stride)
    num_filters_in = utils.last_dimension(
        inputs.get_shape(), min_rank=4)
    # Note the channel order is [.., out, in], swapped relative to the
    # `weights` of convolution2d.
    weights_shape = [kernel_h, kernel_w, num_outputs, num_filters_in]
    weights_collections = utils.get_variable_collections(
        variables_collections, 'weights')
    weights = variables.model_variable(
        'weights',
        shape=weights_shape,
        dtype=dtype,
        initializer=weights_initializer,
        regularizer=weights_regularizer,
        trainable=trainable,
        collections=weights_collections)

    inputs_shape = array_ops.shape(inputs)
    batch_size = inputs_shape[0]
    height, width = inputs_shape[1], inputs_shape[2]

    def get_deconv_dim(dim_size, stride_size, kernel_dim, padding):
      """Returns the transposed-conv output size of one spatial dimension.

      `dim_size` may be a scalar `Tensor` (dynamic shape), a Python int
      (static shape) or None (unknown static dim, returned unchanged).
      """
      if isinstance(dim_size, ops.Tensor):
        dim_size = math_ops.mul(dim_size, stride_size)
      elif dim_size is not None:
        dim_size *= stride_size
      if padding == 'VALID' and dim_size is not None:
        dim_size += max(kernel_dim - stride_size, 0)
      return dim_size

    # Infer the dynamic output shape:
    out_height = get_deconv_dim(height, stride_h, kernel_h, padding)
    out_width = get_deconv_dim(width, stride_w, kernel_w, padding)
    output_shape = array_ops.pack(
        [batch_size, out_height, out_width, num_outputs])
    outputs = nn.conv2d_transpose(inputs, weights, output_shape,
                                  [1, stride_h, stride_w, 1],
                                  padding=padding)

    # Infer the static output shape (the dynamic `output_shape` above does
    # not carry static information):
    out_shape = inputs.get_shape().as_list()
    out_shape[-1] = num_outputs
    out_shape[1] = get_deconv_dim(out_shape[1], stride_h, kernel_h, padding)
    out_shape[2] = get_deconv_dim(out_shape[2], stride_w, kernel_w, padding)
    outputs.set_shape(out_shape)

    if normalizer_fn:
      normalizer_params = normalizer_params or {}
      outputs = normalizer_fn(outputs, **normalizer_params)
    else:
      if biases_initializer is not None:
        biases_collections = utils.get_variable_collections(
            variables_collections, 'biases')
        # `trainable` was previously not forwarded here, so biases were
        # always trainable even when trainable=False.
        biases = variables.model_variable('biases',
                                          shape=[num_outputs,],
                                          dtype=dtype,
                                          initializer=biases_initializer,
                                          regularizer=biases_regularizer,
                                          collections=biases_collections,
                                          trainable=trainable)
        outputs = nn.bias_add(outputs, biases)
    if activation_fn:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
@add_arg_scope
def dropout(inputs,
            keep_prob=0.5,
            noise_shape=None,
            is_training=True,
            outputs_collections=None,
            scope=None):
  """Returns a dropout op applied to the input.

  With probability `keep_prob`, outputs the input element scaled up by
  `1 / keep_prob`, otherwise outputs `0`. The scaling is so that the expected
  sum is unchanged. When `is_training` evaluates to False, `inputs` is
  returned unchanged; the choice is made with `utils.smart_cond`.

  Args:
    inputs: the tensor to pass to the nn.dropout op.
    keep_prob: A scalar `Tensor` with the same type as x. The probability
      that each element is kept.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the
      shape for randomly generated keep/drop flags.
    is_training: A bool `Tensor` indicating whether or not the model
      is in training mode. If so, dropout is applied and values scaled.
      Otherwise, inputs is returned.
    outputs_collections: collection to add the outputs.
    scope: Optional scope for op_scope.

  Returns:
    a tensor representing the output of the operation.
  """
  with ops.op_scope([inputs], scope, 'Dropout') as sc:
    inputs = ops.convert_to_tensor(inputs)

    def _train_branch():
      return nn.dropout(inputs, keep_prob, noise_shape)

    def _eval_branch():
      return inputs

    result = utils.smart_cond(is_training, _train_branch, _eval_branch)
    return utils.collect_named_outputs(outputs_collections, sc, result)
@add_arg_scope
def flatten(inputs,
            outputs_collections=None,
            scope=None):
  """Flattens the input while maintaining the batch_size.

  Assumes that the first dimension represents the batch. Every other
  dimension must be statically known; they are collapsed into one.

  Args:
    inputs: a tensor of size [batch_size, ...].
    outputs_collections: collection to add the outputs.
    scope: Optional scope for op_scope.

  Returns:
    a flattened tensor with shape [batch_size, k].

  Raises:
    ValueError: if the rank of `inputs` is unknown or less than 2, or if any
      non-batch dimension is not statically defined.
  """
  with ops.op_scope([inputs], scope, 'Flatten') as sc:
    inputs = ops.convert_to_tensor(inputs)
    static_shape = inputs.get_shape()
    rank = static_shape.ndims
    if rank is None or rank < 2:
      raise ValueError('Inputs must have a least 2 dimensions.')
    trailing = static_shape[1:]
    if not trailing.is_fully_defined():
      raise ValueError('Inputs 2nd dimension must be defined.')
    # -1 lets the (possibly unknown) batch dimension be inferred at run time.
    flat = array_ops.reshape(inputs, [-1, trailing.num_elements()])
    return utils.collect_named_outputs(outputs_collections, sc, flat)
@add_arg_scope
def fully_connected(inputs,
                    num_outputs,
                    activation_fn=nn.relu,
                    normalizer_fn=None,
                    normalizer_params=None,
                    weights_initializer=initializers.xavier_initializer(),
                    weights_regularizer=None,
                    biases_initializer=init_ops.zeros_initializer,
                    biases_regularizer=None,
                    reuse=None,
                    variables_collections=None,
                    outputs_collections=None,
                    trainable=True,
                    scope=None):
  """Adds a fully connected layer.

  `fully_connected` creates a variable called `weights`, representing a fully
  connected weight matrix, which is multiplied by the `inputs` to produce a
  `Tensor` of hidden units. If a `normalizer_fn` is provided (such as
  `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is
  None and a `biases_initializer` is provided then a `biases` variable is
  created and added to the hidden units. Finally, if `activation_fn` is not
  `None`, it is applied to the hidden units as well.

  Note: if `inputs` has a rank greater than 2, it is flattened to
  `[-1, depth]` prior to the matrix multiply by `weights`, and the result is
  reshaped back so that the leading dimensions of `inputs` are preserved and
  the last dimension becomes `num_outputs`.

  Args:
    inputs: A tensor of at least rank 2 with a known value for the last
      dimension, i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
    num_outputs: Integer, the number of output units in the layer.
    activation_fn: activation function; not applied if falsy.
    normalizer_fn: normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
    normalizer_params: normalization function parameters, passed to
      `normalizer_fn` as keyword arguments.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional list of collections for all the variables or
      a dictionary containing a different list of collections per variable.
    outputs_collections: collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for variable_op_scope.

  Returns:
    the tensor variable representing the result of the series of operations.

  Raises:
    ValueError: if `num_outputs` is not an integer, or if `inputs` has rank
      less than 2 or its last dimension is not set.
  """
  if not isinstance(num_outputs, int):
    # Interpolate eagerly: ValueError(msg, arg) would only store a tuple of
    # arguments and never substitute the '%s' placeholder.
    raise ValueError('num_outputs should be integer, got %s.' % (num_outputs,))
  with variable_scope.variable_op_scope([inputs],
                                        scope,
                                        'fully_connected',
                                        reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    dtype = inputs.dtype.base_dtype
    inputs_shape = inputs.get_shape()
    num_input_units = utils.last_dimension(inputs_shape, min_rank=2)

    # Static output shape: same as the input's, with the last dim replaced.
    static_shape = inputs_shape.as_list()
    static_shape[-1] = num_outputs
    needs_reshape = len(static_shape) > 2
    if needs_reshape:
      # Dynamic counterpart of `static_shape`, used to restore the leading
      # dimensions after the rank-2 matmul. Only built when it is used, so
      # rank-2 inputs do not add unused shape/unpack ops to the graph.
      out_shape = array_ops.unpack(array_ops.shape(inputs))
      out_shape[-1] = num_outputs

    weights_collections = utils.get_variable_collections(
        variables_collections, 'weights')
    weights = variables.model_variable('weights',
                                       shape=[num_input_units, num_outputs],
                                       dtype=dtype,
                                       initializer=weights_initializer,
                                       regularizer=weights_regularizer,
                                       collections=weights_collections,
                                       trainable=trainable)
    if needs_reshape:
      # Collapse all leading dimensions so matmul sees a rank-2 operand.
      inputs = array_ops.reshape(inputs, [-1, num_input_units])
    outputs = standard_ops.matmul(inputs, weights)
    if normalizer_fn:
      normalizer_params = normalizer_params or {}
      outputs = normalizer_fn(outputs, **normalizer_params)
    elif biases_initializer is not None:
      biases_collections = utils.get_variable_collections(
          variables_collections, 'biases')
      biases = variables.model_variable('biases',
                                        shape=[num_outputs,],
                                        dtype=dtype,
                                        initializer=biases_initializer,
                                        regularizer=biases_regularizer,
                                        collections=biases_collections,
                                        trainable=trainable)
      outputs = nn.bias_add(outputs, biases)
    if activation_fn:
      outputs = activation_fn(outputs)
    if needs_reshape:
      # Restore the original leading dimensions.
      outputs = array_ops.reshape(outputs, array_ops.pack(out_shape))
      # Re-attach the statically known output shape.
      outputs.set_shape(static_shape)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
@add_arg_scope
def max_pool2d(inputs,
               kernel_size,
               stride=2,
               padding='VALID',
               outputs_collections=None,
               scope=None):
  """Adds a 2D max pooling op over the height and width of `inputs`.

  The pooling window and stride are both fixed to 1 along the batch and depth
  axes, so pooling happens independently per image and per channel.

  Args:
    inputs: a tensor of size [batch_size, height, width, depth].
    kernel_size: a list of length 2: [kernel_height, kernel_width] of the
      pooling kernel over which the op is computed. Can be an int if both
      values are the same.
    stride: a list of length 2: [stride_height, stride_width].
      Can be an int if both strides are the same.
      NOTE(review): earlier docs stated both strides must be equal; this
      wrapper forwards them independently and does not enforce that —
      confirm against `nn.max_pool`.
    padding: the padding method, either 'VALID' or 'SAME'.
    outputs_collections: collection to add the outputs.
    scope: Optional scope for op_scope.

  Returns:
    a tensor representing the results of the pooling operation.

  Raises:
    ValueError: if 'kernel_size' is not a 2-D list
  """
  with ops.op_scope([inputs], scope, 'MaxPool2D') as sc:
    inputs = ops.convert_to_tensor(inputs)
    window_h, window_w = utils.two_element_tuple(kernel_size)
    step_h, step_w = utils.two_element_tuple(stride)
    # NHWC layout: batch and channel entries stay at 1.
    window = [1, window_h, window_w, 1]
    steps = [1, step_h, step_w, 1]
    pooled = nn.max_pool(inputs, ksize=window, strides=steps, padding=padding)
    return utils.collect_named_outputs(outputs_collections, sc, pooled)
@add_arg_scope
def one_hot_encoding(labels,
                     num_classes,
                     on_value=1.0,
                     off_value=0.0,
                     outputs_collections=None,
                     scope=None):
  """Transform numeric labels into onehot_labels using tf.one_hot.

  Labels arriving as int32 are widened to int64 before encoding; other dtypes
  are passed through unchanged.

  Args:
    labels: [batch_size] target labels.
    num_classes: total number of classes.
    on_value: A scalar defining the on-value.
    off_value: A scalar defining the off-value.
    outputs_collections: collection to add the outputs.
    scope: Optional scope for op_scope.

  Returns:
    one hot encoding of the labels.
  """
  with ops.op_scope([labels, num_classes], scope, 'OneHotEncoding') as sc:
    label_tensor = ops.convert_to_tensor(labels)
    is_int32 = label_tensor.dtype == dtypes.int32
    if is_int32:
      label_tensor = standard_ops.to_int64(label_tensor)
    encoded = standard_ops.one_hot(
        label_tensor, num_classes, on_value=on_value, off_value=off_value)
    return utils.collect_named_outputs(outputs_collections, sc, encoded)
def _apply_activation(y, activation_fn, output_collections):
  """Optionally applies `activation_fn` to `y` and records the result.

  The resulting tensor (activated, or `y` itself when `activation_fn` is
  falsy) is added to each collection in `output_collections` and to
  `ops.GraphKeys.ACTIVATIONS`.

  Args:
    y: the tensor to (optionally) activate.
    activation_fn: callable applied to `y`, or a falsy value to skip it.
    output_collections: iterable of collection names, or None.
      NOTE(review): a bare string would be split into single characters by
      `list()` — callers presumably pass a list; verify.

  Returns:
    The (possibly activated) tensor.
  """
  result = activation_fn(y) if activation_fn else y
  collection_names = list(output_collections or [])
  collection_names.append(ops.GraphKeys.ACTIVATIONS)
  ops.add_to_collections(collection_names, result)
  return result
def repeat(inputs, repetitions, layer, *args, **kwargs):
  """Applies the same layer with the same arguments repeatedly.

  ```python
    y = repeat(x, 3, conv2d, 64, [3, 3], scope='conv1')
    # It is equivalent to:

    x = conv2d(x, 64, [3, 3], scope='conv1/conv1_1')
    x = conv2d(x, 64, [3, 3], scope='conv1/conv1_2')
    y = conv2d(x, 64, [3, 3], scope='conv1/conv1_3')
  ```

  If the `scope` argument is not given in `kwargs`, the outer scope uses the
  default name 'Repeat'. The per-call scope prefix is then `layer.__name__`,
  or `layer.func.__name__` (for `functools.partial` objects). If neither is
  available, the prefix 'repeat' is used. Call `i` (1-based) gets
  `scope=<prefix>_<i>`.

  Args:
    inputs: A `Tensor` suitable for layer.
    repetitions: Int, number of repetitions.
    layer: A layer with arguments `(inputs, *args, **kwargs)`
    *args: Extra args for the layer.
    **kwargs: Extra kwargs for the layer.

  Returns:
    a tensor result of applying the layer, repetitions times.

  Raises:
    Any exception raised by `layer` is propagated unchanged.
  """
  scope = kwargs.pop('scope', None)
  with variable_scope.variable_op_scope([inputs], scope, 'Repeat'):
    current = ops.convert_to_tensor(inputs)
    if scope is None:
      # Derive a readable per-call prefix from the layer callable.
      if hasattr(layer, '__name__'):
        scope = layer.__name__
      elif hasattr(layer, 'func') and hasattr(layer.func, '__name__'):
        scope = layer.func.__name__  # `layer` is a functools.partial.
      else:
        scope = 'repeat'
    for index in range(1, repetitions + 1):
      kwargs['scope'] = scope + '_' + str(index)
      current = layer(current, *args, **kwargs)
    return current
@add_arg_scope
def separable_convolution2d(
inputs,
num_outputs,
kernel_size,
depth_multiplier,
stride=1,
padding='SAME',
activation_fn=nn.relu,
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer,
biases_regularizer=None,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):
"""Adds a depth-separable 2D convolution with optional batch_norm layer.
This op first performs a depthwise convolution that acts separately on
channels, creating a variable called `depthwise_weights`. If `num_outputs`
is not None, it adds a pointwise convolution that mixes channels, creating a
variable called `pointwise_weights`. Then, if `batch_norm_params` is None,
it adds bias to the result, creating a variable called 'biases', otherwise
it adds a batch normalization layer. It finally applies an activation function
to produce the end result.
Args:
inputs: a tensor of size [batch_size, height, width, channels].
num_outputs: the number of pointwise convolution output filters. If is
None, then we skip the pointwise convolution stage.
kernel_size: a list of length 2: [kernel_height, kernel_width] of
of the filters. Can be an int if both values are the same.
depth_multiplier: the number of depthwise convolution output channels for