/
conv.py
2220 lines (1812 loc) · 90.3 KB
/
conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import tensorflow as tf
import numpy as np
from math import ceil
import tflearn
from .. import variables as vs
from .. import activations
from .. import initializations
from .. import regularizers
from .. import utils
from ..layers.normalization import batch_normalization
def conv_2d(incoming, nb_filter, filter_size, strides=1, padding='same',
            activation='linear', bias=True, weights_init='uniform_scaling',
            bias_init='zeros', regularizer=None, weight_decay=0.001,
            trainable=True, restore=True, reuse=False, scope=None,
            name="Conv2D"):
    """ Convolution 2D.

    Input:
        4-D Tensor [batch, height, width, in_channels].

    Output:
        4-D Tensor [batch, new height, new width, nb_filter].

    Arguments:
        incoming: `Tensor`. Incoming 4-D Tensor.
        nb_filter: `int`. The number of convolutional filters.
        filter_size: `int` or `list of int`. Size of filters.
        strides: `int` or list of `int`. Strides of conv operation.
            Default: [1 1 1 1].
        padding: `str` from `"same", "valid"`. Padding algo to use.
            Default: 'same'.
        activation: `str` (name) or `function` (returning a `Tensor`) or None.
            Activation applied to this layer (see tflearn.activations).
            A falsy value (e.g. None) applies no activation.
            Default: 'linear'.
        bias: `bool`. If True, a bias is used.
        weights_init: `str` (name) or `Tensor`. Weights initialization.
            (see tflearn.initializations) Default: 'uniform_scaling'.
        bias_init: `str` (name) or `Tensor`. Bias initialization.
            (see tflearn.initializations) Default: 'zeros'.
        regularizer: `str` (name) or `Tensor`. Add a regularizer to this
            layer weights (see tflearn.regularizers). Default: None.
        weight_decay: `float`. Regularizer decay parameter. Default: 0.001.
        trainable: `bool`. If True, weights will be trainable.
        restore: `bool`. If True, this layer weights will be restored when
            loading a model.
        reuse: `bool`. If True and 'scope' is provided, this layer variables
            will be reused (shared).
        scope: `str`. Define this layer scope (optional). A scope can be
            used to share variables between layers. Note that scope will
            override name.
        name: A name for this layer (optional). Default: 'Conv2D'.

    Attributes:
        scope: `Scope`. This layer scope.
        W: `Variable`. Variable representing filter weights.
        b: `Variable`. Variable representing biases.

    Raises:
        ValueError: if `activation` is neither a string nor a callable.
    """
    input_shape = utils.get_incoming_shape(incoming)
    assert len(input_shape) == 4, "Incoming Tensor shape must be 4-D, not %d-D" % len(input_shape)
    filter_size = utils.autoformat_filter_conv2d(filter_size,
                                                 input_shape[-1],
                                                 nb_filter)
    strides = utils.autoformat_kernel_2d(strides)
    padding = utils.autoformat_padding(padding)

    with tf.variable_scope(scope, default_name=name, values=[incoming],
                           reuse=reuse) as scope:
        name = scope.name

        W_init = weights_init
        if isinstance(weights_init, str):
            W_init = initializations.get(weights_init)()
        elif isinstance(W_init, (tf.Tensor, np.ndarray, list)):
            # Explicit initial values carry their own shape.
            filter_size = None
        W_regul = None
        if regularizer is not None:
            W_regul = lambda x: regularizers.get(regularizer)(x, weight_decay)
        W = vs.variable('W', shape=filter_size, regularizer=W_regul,
                        initializer=W_init, trainable=trainable,
                        restore=restore)
        # Track per layer variables
        tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, W)

        b = None
        if bias:
            b_shape = [nb_filter]
            if isinstance(bias_init, str):
                bias_init = initializations.get(bias_init)()
            elif isinstance(bias_init, (tf.Tensor, np.ndarray, list)):
                # Explicit initial values carry their own shape.
                b_shape = None
            b = vs.variable('b', shape=b_shape, initializer=bias_init,
                            trainable=trainable, restore=restore)
            # Track per layer variables
            tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, b)

        inference = tf.nn.conv2d(incoming, W, strides, padding)
        if b is not None:
            inference = tf.nn.bias_add(inference, b)

        if activation:
            if isinstance(activation, str):
                inference = activations.get(activation)(inference)
            elif callable(activation):
                inference = activation(inference)
            else:
                raise ValueError("Invalid Activation.")

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, inference)

    # Add attributes to Tensor to easy access weights.
    inference.scope = scope
    inference.W = W
    inference.b = b

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, inference)

    return inference
def conv_2d_transpose(incoming, nb_filter, filter_size, output_shape,
                      strides=1, padding='same', activation='linear',
                      bias=True, weights_init='uniform_scaling',
                      bias_init='zeros', regularizer=None, weight_decay=0.001,
                      trainable=True, restore=True, reuse=False, scope=None,
                      name="Conv2DTranspose"):
    """ Convolution 2D Transpose.

    This operation is sometimes called "deconvolution" after (Deconvolutional
    Networks)[http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf], but is
    actually the transpose (gradient) of `conv_2d` rather than an actual
    deconvolution.

    Input:
        4-D Tensor [batch, height, width, in_channels].

    Output:
        4-D Tensor [batch, new height, new width, nb_filter].

    Arguments:
        incoming: `Tensor`. Incoming 4-D Tensor.
        nb_filter: `int`. The number of convolutional filters.
        filter_size: `int` or `list of int`. Size of filters.
        output_shape: `list` or `tuple` of `int`. Dimensions of the output
            tensor. Can optionally include the number of conv filters:
            [new height, new width, nb_filter] or [new height, new width].
            When given, the last value must equal `nb_filter`.
        strides: `int` or list of `int`. Strides of conv operation.
            Default: [1 1 1 1].
        padding: `str` from `"same", "valid"`. Padding algo to use.
            Default: 'same'.
        activation: `str` (name) or `function` (returning a `Tensor`) or None.
            Activation applied to this layer (see tflearn.activations).
            A falsy value (e.g. None) applies no activation.
            Default: 'linear'.
        bias: `bool`. If True, a bias is used.
        weights_init: `str` (name) or `Tensor`. Weights initialization.
            (see tflearn.initializations) Default: 'uniform_scaling'.
        bias_init: `str` (name) or `Tensor`. Bias initialization.
            (see tflearn.initializations) Default: 'zeros'.
        regularizer: `str` (name) or `Tensor`. Add a regularizer to this
            layer weights (see tflearn.regularizers). Default: None.
        weight_decay: `float`. Regularizer decay parameter. Default: 0.001.
        trainable: `bool`. If True, weights will be trainable.
        restore: `bool`. If True, this layer weights will be restored when
            loading a model.
        reuse: `bool`. If True and 'scope' is provided, this layer variables
            will be reused (shared).
        scope: `str`. Define this layer scope (optional). A scope can be
            used to share variables between layers. Note that scope will
            override name.
        name: A name for this layer (optional). Default: 'Conv2DTranspose'.

    Attributes:
        scope: `Scope`. This layer scope.
        W: `Variable`. Variable representing filter weights.
        b: `Variable`. Variable representing biases.

    Raises:
        ValueError: if `output_shape` does not have 2 or 3 entries, if its
            channel entry differs from `nb_filter`, or if `activation` is
            neither a string nor a callable.
    """
    input_shape = utils.get_incoming_shape(incoming)
    assert len(input_shape) == 4, "Incoming Tensor shape must be 4-D, not %d-D" % len(input_shape)
    # Transposed conv filters are laid out [h, w, out_channels, in_channels].
    filter_size = utils.autoformat_filter_conv2d(filter_size,
                                                 nb_filter,
                                                 input_shape[-1])
    strides = utils.autoformat_kernel_2d(strides)
    padding = utils.autoformat_padding(padding)

    # Validate the requested output shape before any variable is created.
    # Copying into a list makes tuples work with the list concatenations
    # below (`+ [nb_filter]` and `[None] + output_shape`).
    output_shape = list(output_shape)
    if len(output_shape) == 2:
        output_shape = output_shape + [nb_filter]
    elif len(output_shape) != 3:
        raise ValueError("output_shape length error: "
                         + str(len(output_shape))
                         + ", only a length of 2 or 3 is supported.")
    if output_shape[-1] != nb_filter:
        raise ValueError("output_shape channel dimension (%s) must equal "
                         "nb_filter (%s)." % (output_shape[-1], nb_filter))

    with tf.variable_scope(scope, default_name=name, values=[incoming],
                           reuse=reuse) as scope:
        name = scope.name

        W_init = weights_init
        if isinstance(weights_init, str):
            W_init = initializations.get(weights_init)()
        elif isinstance(W_init, (tf.Tensor, np.ndarray, list)):
            # Explicit initial values carry their own shape.
            filter_size = None
        W_regul = None
        if regularizer is not None:
            W_regul = lambda x: regularizers.get(regularizer)(x, weight_decay)
        W = vs.variable('W', shape=filter_size,
                        regularizer=W_regul, initializer=W_init,
                        trainable=trainable, restore=restore)
        # Track per layer variables
        tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, W)

        b = None
        if bias:
            b_shape = [nb_filter]
            if isinstance(bias_init, str):
                bias_init = initializations.get(bias_init)()
            elif isinstance(bias_init, (tf.Tensor, np.ndarray, list)):
                # Explicit initial values carry their own shape.
                b_shape = None
            b = vs.variable('b', shape=b_shape, initializer=bias_init,
                            trainable=trainable, restore=restore)
            # Track per layer variables
            tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, b)

        # The batch size is only known at runtime, so prepend it dynamically.
        batch_size = tf.gather(tf.shape(incoming), tf.constant([0]))
        complete_out_shape = tf.concat([batch_size, tf.constant(output_shape)], 0)
        inference = tf.nn.conv2d_transpose(incoming, W, complete_out_shape,
                                           strides, padding)

        # The dynamic output shape hides static dims; restore them.
        inference.set_shape([None] + output_shape)

        if b is not None:
            inference = tf.nn.bias_add(inference, b)

        # Same activation handling as conv_2d: falsy means "no activation".
        if activation:
            if isinstance(activation, str):
                inference = activations.get(activation)(inference)
            elif callable(activation):
                inference = activation(inference)
            else:
                raise ValueError("Invalid Activation.")

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, inference)

    # Add attributes to Tensor to easy access weights.
    inference.scope = scope
    inference.W = W
    inference.b = b

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, inference)

    return inference
def atrous_conv_2d(incoming, nb_filter, filter_size, rate=1, padding='same',
                   activation='linear', bias=True, weights_init='uniform_scaling',
                   bias_init='zeros', regularizer=None, weight_decay=0.001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="AtrousConv2D"):
    """ Atrous Convolution 2D.

    (a.k.a. convolution with holes or dilated convolution).

    Computes a 2-D atrous convolution, also known as convolution with holes or
    dilated convolution, given 4-D value and filters tensors. If the rate
    parameter is equal to one, it performs regular 2-D convolution. If the rate
    parameter is greater than one, it performs convolution with holes, sampling
    the input values every rate pixels in the height and width dimensions. This
    is equivalent to convolving the input with a set of upsampled filters,
    produced by inserting rate - 1 zeros between two consecutive values of the
    filters along the height and width dimensions, hence the name atrous
    convolution or convolution with holes (the French word trous means holes
    in English).

    More specifically
    ```
    output[b, i, j, k] = sum_{di, dj, q} filters[di, dj, q, k] *
          value[b, i + rate * di, j + rate * dj, q]
    ```

    Atrous convolution allows us to explicitly control how densely to compute
    feature responses in fully convolutional networks. Used in conjunction
    with bilinear interpolation, it offers an alternative to conv2d_transpose
    in dense prediction tasks such as semantic image segmentation,
    optical flow computation, or depth estimation. It also allows us to
    effectively enlarge the field of view of filters without increasing the
    number of parameters or the amount of computation.

    Input:
        4-D Tensor [batch, height, width, in_channels].

    Output:
        4-D Tensor [batch, new height, new width, nb_filter].

    Arguments:
        incoming: `Tensor`. Incoming 4-D Tensor.
        nb_filter: `int`. The number of convolutional filters.
        filter_size: `int` or `list of int`. Size of filters.
        rate: `int`. A positive int32. The stride with which we sample input
            values across the height and width dimensions. Equivalently, the
            rate by which we upsample the filter values by inserting zeros
            across the height and width dimensions. In the literature, the
            same parameter is sometimes called input `stride` or `dilation`.
            Default: 1.
        padding: `str` from `"same", "valid"`. Padding algo to use.
            Default: 'same'.
        activation: `str` (name) or `function` (returning a `Tensor`) or None.
            Activation applied to this layer (see tflearn.activations).
            A falsy value (e.g. None) applies no activation.
            Default: 'linear'.
        bias: `bool`. If True, a bias is used.
        weights_init: `str` (name) or `Tensor`. Weights initialization.
            (see tflearn.initializations) Default: 'uniform_scaling'.
        bias_init: `str` (name) or `Tensor`. Bias initialization.
            (see tflearn.initializations) Default: 'zeros'.
        regularizer: `str` (name) or `Tensor`. Add a regularizer to this
            layer weights (see tflearn.regularizers). Default: None.
        weight_decay: `float`. Regularizer decay parameter. Default: 0.001.
        trainable: `bool`. If True, weights will be trainable.
        restore: `bool`. If True, this layer weights will be restored when
            loading a model.
        reuse: `bool`. If True and 'scope' is provided, this layer variables
            will be reused (shared).
        scope: `str`. Define this layer scope (optional). A scope can be
            used to share variables between layers. Note that scope will
            override name.
        name: A name for this layer (optional). Default: 'AtrousConv2D'.

    Attributes:
        scope: `Scope`. This layer scope.
        W: `Variable`. Variable representing filter weights.
        b: `Variable`. Variable representing biases.

    Raises:
        ValueError: if `activation` is neither a string nor a callable.
    """
    input_shape = utils.get_incoming_shape(incoming)
    assert len(input_shape) == 4, "Incoming Tensor shape must be 4-D, not %d-D" % len(input_shape)
    filter_size = utils.autoformat_filter_conv2d(filter_size,
                                                 input_shape[-1],
                                                 nb_filter)
    padding = utils.autoformat_padding(padding)

    with tf.variable_scope(scope, default_name=name, values=[incoming],
                           reuse=reuse) as scope:
        name = scope.name

        W_init = weights_init
        if isinstance(weights_init, str):
            W_init = initializations.get(weights_init)()
        elif isinstance(W_init, (tf.Tensor, np.ndarray, list)):
            # Explicit initial values carry their own shape.
            filter_size = None
        W_regul = None
        if regularizer is not None:
            W_regul = lambda x: regularizers.get(regularizer)(x, weight_decay)
        W = vs.variable('W', shape=filter_size, regularizer=W_regul,
                        initializer=W_init, trainable=trainable,
                        restore=restore)
        # Track per layer variables
        tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, W)

        b = None
        if bias:
            b_shape = [nb_filter]
            if isinstance(bias_init, str):
                bias_init = initializations.get(bias_init)()
            elif isinstance(bias_init, (tf.Tensor, np.ndarray, list)):
                # Explicit initial values carry their own shape.
                b_shape = None
            b = vs.variable('b', shape=b_shape, initializer=bias_init,
                            trainable=trainable, restore=restore)
            # Track per layer variables
            tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, b)

        inference = tf.nn.atrous_conv2d(incoming, W, rate, padding)
        if b is not None:
            inference = tf.nn.bias_add(inference, b)

        if activation:
            if isinstance(activation, str):
                inference = activations.get(activation)(inference)
            elif callable(activation):
                inference = activation(inference)
            else:
                raise ValueError("Invalid Activation.")

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, inference)

    # Add attributes to Tensor to easy access weights.
    inference.scope = scope
    inference.W = W
    inference.b = b

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, inference)

    return inference
def grouped_conv_2d(incoming, channel_multiplier, filter_size, strides=1,
                    padding='same', activation='linear', bias=False,
                    weights_init='uniform_scaling', bias_init='zeros',
                    regularizer=None, weight_decay=0.001, trainable=True,
                    restore=True, reuse=False, scope=None,
                    name="GroupedConv2D"):
    """ Grouped Convolution 2D.

    a.k.a DepthWise Convolution 2D.

    Given a 4-D input tensor, a filter_size and a channel_multiplier,
    grouped_conv_2d applies a different filter to each input channel
    (expanding from 1 channel to channel_multiplier channels for each), then
    concatenates the results together. The output has
    in_channels * channel_multiplier channels.

    In detail,
    ```
    output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
         filter[di, dj, k, q] * input[b, strides[1] * i + di,
                                         strides[2] * j + dj, k]
    ```
    The layer is built on `tf.nn.depthwise_conv2d` with its default data
    format; no dilation rate is exposed.

    Input:
        4-D Tensor [batch, height, width, in_channels].

    Output:
        4-D Tensor [batch, new height, new width, in_channels * channel_multiplier].

    Arguments:
        incoming: `Tensor`. Incoming 4-D Tensor.
        channel_multiplier: `int`. The number of channels to expand to.
        filter_size: `int` or `list of int`. Size of filters.
        strides: `int` or list of `int`. Strides of conv operation.
            Default: [1 1 1 1].
        padding: `str` from `"same", "valid"`. Padding algo to use.
            Default: 'same'.
        activation: `str` (name) or `function` (returning a `Tensor`) or None.
            Activation applied to this layer (see tflearn.activations).
            A falsy value (e.g. None) applies no activation.
            Default: 'linear'.
        bias: `bool`. If True, a bias of size
            in_channels * channel_multiplier is used. Default: False.
        weights_init: `str` (name) or `Tensor`. Weights initialization.
            (see tflearn.initializations) Default: 'uniform_scaling'.
        bias_init: `str` (name) or `Tensor`. Bias initialization.
            (see tflearn.initializations) Default: 'zeros'.
        regularizer: `str` (name) or `Tensor`. Add a regularizer to this
            layer weights (see tflearn.regularizers). Default: None.
        weight_decay: `float`. Regularizer decay parameter. Default: 0.001.
        trainable: `bool`. If True, weights will be trainable.
        restore: `bool`. If True, this layer weights will be restored when
            loading a model.
        reuse: `bool`. If True and 'scope' is provided, this layer variables
            will be reused (shared).
        scope: `str`. Define this layer scope (optional). A scope can be
            used to share variables between layers. Note that scope will
            override name.
        name: A name for this layer (optional). Default: 'GroupedConv2D'.

    Attributes:
        scope: `Scope`. This layer scope.
        W: `Variable`. Variable representing filter weights.
        b: `Variable`. Variable representing biases.

    Raises:
        ValueError: if `activation` is neither a string nor a callable.
    """
    input_shape = utils.get_incoming_shape(incoming)
    assert len(input_shape) == 4, "Incoming Tensor shape must be 4-D, not %d-D" % len(input_shape)
    # Number of output channels; only used to size the bias.
    nb_filter = channel_multiplier * input_shape[-1]
    strides = utils.autoformat_kernel_2d(strides)
    # Depthwise filters are laid out [h, w, in_channels, channel_multiplier].
    filter_size = utils.autoformat_filter_conv2d(filter_size,
                                                 input_shape[-1],
                                                 channel_multiplier)
    padding = utils.autoformat_padding(padding)

    with tf.variable_scope(scope, default_name=name, values=[incoming],
                           reuse=reuse) as scope:
        name = scope.name

        W_init = weights_init
        if isinstance(weights_init, str):
            W_init = initializations.get(weights_init)()
        elif isinstance(W_init, (tf.Tensor, np.ndarray, list)):
            # Explicit initial values carry their own shape.
            filter_size = None
        W_regul = None
        if regularizer is not None:
            W_regul = lambda x: regularizers.get(regularizer)(x, weight_decay)
        W = vs.variable('W', shape=filter_size, regularizer=W_regul,
                        initializer=W_init, trainable=trainable,
                        restore=restore)
        # Track per layer variables
        tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, W)

        b = None
        if bias:
            b_shape = [nb_filter]
            if isinstance(bias_init, str):
                bias_init = initializations.get(bias_init)()
            elif isinstance(bias_init, (tf.Tensor, np.ndarray, list)):
                # Explicit initial values carry their own shape.
                b_shape = None
            b = vs.variable('b', shape=b_shape, initializer=bias_init,
                            trainable=trainable, restore=restore)
            # Track per layer variables
            tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, b)

        inference = tf.nn.depthwise_conv2d(incoming, W, strides, padding)
        if b is not None:
            inference = tf.nn.bias_add(inference, b)

        if activation:
            if isinstance(activation, str):
                inference = activations.get(activation)(inference)
            elif callable(activation):
                inference = activation(inference)
            else:
                raise ValueError("Invalid Activation.")

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, inference)

    # Add attributes to Tensor to easy access weights.
    inference.scope = scope
    inference.W = W
    inference.b = b

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, inference)

    return inference
def max_pool_2d(incoming, kernel_size, strides=None, padding='same',
                name="MaxPool2D"):
    """ Max Pooling 2D.

    Slides a `kernel_size` window over the height and width of `incoming`
    and keeps the maximum value of each window (`tf.nn.max_pool`).

    Input:
        4-D Tensor [batch, height, width, in_channels].

    Output:
        4-D Tensor [batch, pooled height, pooled width, in_channels].

    Arguments:
        incoming: `Tensor`. Incoming 4-D Layer.
        kernel_size: `int` or `list of int`. Pooling kernel size.
        strides: `int` or `list of int`. Strides of the pooling window.
            When None (or any falsy value) the kernel size is reused.
        padding: `str` from `"same", "valid"`. Padding algo to use.
            Default: 'same'.
        name: A name for this layer (optional). Default: 'MaxPool2D'.

    Attributes:
        scope: `Scope`. This layer scope.
    """
    shape_in = utils.get_incoming_shape(incoming)
    assert len(shape_in) == 4, \
        "Incoming Tensor shape must be 4-D, not %d-D" % len(shape_in)

    window = utils.autoformat_kernel_2d(kernel_size)
    if strides:
        step = utils.autoformat_kernel_2d(strides)
    else:
        step = window
    pad_mode = utils.autoformat_padding(padding)

    with tf.name_scope(name) as layer_scope:
        pooled = tf.nn.max_pool(incoming, window, step, pad_mode)
        # Register the result as an activation.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, pooled)

    # Expose the scope on the tensor, then register the output under the
    # collection keyed by the `name` argument.
    pooled.scope = layer_scope
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, pooled)
    return pooled
def avg_pool_2d(incoming, kernel_size, strides=None, padding='same',
                name="AvgPool2D"):
    """ Average Pooling 2D.

    Slides a `kernel_size` window over the height and width of `incoming`
    and replaces each window with its mean (`tf.nn.avg_pool`).

    Input:
        4-D Tensor [batch, height, width, in_channels].

    Output:
        4-D Tensor [batch, pooled height, pooled width, in_channels].

    Arguments:
        incoming: `Tensor`. Incoming 4-D Layer.
        kernel_size: `int` or `list of int`. Pooling kernel size.
        strides: `int` or `list of int`. Strides of the pooling window.
            When None (or any falsy value) the kernel size is reused.
        padding: `str` from `"same", "valid"`. Padding algo to use.
            Default: 'same'.
        name: A name for this layer (optional). Default: 'AvgPool2D'.

    Attributes:
        scope: `Scope`. This layer scope.
    """
    dims = utils.get_incoming_shape(incoming)
    assert len(dims) == 4, \
        "Incoming Tensor shape must be 4-D, not %d-D" % len(dims)

    ksize = utils.autoformat_kernel_2d(kernel_size)
    step = ksize if not strides else utils.autoformat_kernel_2d(strides)
    pad_mode = utils.autoformat_padding(padding)

    with tf.name_scope(name) as layer_scope:
        averaged = tf.nn.avg_pool(incoming, ksize, step, pad_mode)
        # Register the result as an activation.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, averaged)

    averaged.scope = layer_scope
    # Output tensor is registered under the raw `name` argument.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, averaged)
    return averaged
def upsample_2d(incoming, kernel_size, name="UpSample2D"):
    """ UpSample 2D.

    Nearest-neighbour upsampling (`tf.image.resize_nearest_neighbor`): the
    height and width of `incoming` are multiplied by the kernel size.

    Input:
        4-D Tensor [batch, height, width, in_channels].

    Output:
        4-D Tensor [batch, upsampled height, upsampled width, in_channels].

    Arguments:
        incoming: `Tensor`. Incoming 4-D Layer to upsample. Height and width
            may be unknown (None) when the graph is built.
        kernel_size: `int` or `list of int`. Upsampling kernel size.
        name: A name for this layer (optional). Default: 'UpSample2D'.

    Attributes:
        scope: `Scope`. This layer scope.
    """
    input_shape = utils.get_incoming_shape(incoming)
    assert len(input_shape) == 4, "Incoming Tensor shape must be 4-D, not %d-D" % len(input_shape)
    kernel = utils.autoformat_kernel_2d(kernel_size)
    height, width = input_shape[1], input_shape[2]
    with tf.name_scope(name) as scope:
        if height is not None and width is not None:
            new_size = input_shape[1:3] * tf.constant(kernel[1:3])
        else:
            # Static spatial dims are unknown: multiplying None would raise,
            # so compute the target size from the runtime shape instead.
            new_size = tf.shape(incoming)[1:3] * tf.constant(kernel[1:3])
        inference = tf.image.resize_nearest_neighbor(incoming, size=new_size)

    # Record whatever static spatial size is known (None stays unknown).
    new_height = height * kernel[1] if height is not None else None
    new_width = width * kernel[2] if width is not None else None
    inference.set_shape((None, new_height, new_width, None))

    # Add attributes to Tensor to easy access weights
    inference.scope = scope

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, inference)

    return inference
# Shortcut alias. Note that `deconv_2d` is plain nearest-neighbour upsampling
# (`upsample_2d`, built on tf.image.resize_nearest_neighbor) with no weights;
# it is not a learned transposed convolution (see `conv_2d_transpose`).
deconv_2d = upsample_2d
def _bilinear_upsample_filter_2d(f_shape):
    """ Build bilinear-upsampling weights for `tf.nn.conv2d_transpose`.

    Arguments:
        f_shape: `list of int`. Filter shape
            [height, width, out_channels, in_channels].

    Returns:
        A float64 `numpy.ndarray` of shape `f_shape`. For every
        `i < min(out_channels, in_channels)`, output channel `i` reads input
        channel `i` through a bilinear kernel; every other weight is zero.
    """
    kernels = []
    for size in f_shape[:2]:
        # 1-D bilinear interpolation kernel of length `size`.
        factor = ceil(size / 2.0)
        center = (2 * factor - 1 - factor % 2) / (2.0 * factor)
        kernels.append(1 - np.abs(np.arange(size) / factor - center))
    # Separable kernel: bilinear[x, y] = k_h[x] * k_w[y].
    bilinear = kernels[0][:, None] * kernels[1][None, :]
    weights = np.zeros(f_shape)
    # Only channels present on both sides can be paired; iterating over the
    # output channels alone would index past the input channels.
    for i in range(min(f_shape[2], f_shape[3])):
        weights[:, :, i, i] = bilinear
    return weights


def upscore_layer(incoming, num_classes, shape=None, kernel_size=4,
                  strides=2, trainable=True, restore=True,
                  reuse=False, scope=None, name='Upscore'):
    """ Upscore.

    This implements the upscore layer as used in
    (Fully Convolutional Networks)[http://arxiv.org/abs/1411.4038].
    The upscore layer is initialized as bilinear upsampling filter.

    Input:
        4-D Tensor [batch, height, width, in_channels].

    Output:
        4-D Tensor [batch, new height, new width, num_classes].

    Arguments:
        incoming: `Tensor`. Incoming 4-D Layer to upsample.
        num_classes: `int`. Number of output feature maps.
        shape: `list of int`. Dimension of the output map
            [batch_size, new height, new width]. For convenience four values
            are allowed [batch_size, new height, new width, X], where X
            is ignored. The batch size is always taken from `incoming`.
            If None, each spatial output size is
            `(input size - 1) * stride + 1`.
        kernel_size: `int` or `list of int`. Upsampling kernel size.
        strides: `int` or `list of int`. Strides of conv operation.
            Default: [1 2 2 1].
        trainable: `bool`. If True, weights will be trainable.
        restore: `bool`. If True, this layer weights will be restored when
            loading a model.
        reuse: `bool`. If True and 'scope' is provided, this layer variables
            will be reused (shared).
        scope: `str`. Define this layer scope (optional). A scope can be
            used to share variables between layers. Note that scope will
            override name.
        name: A name for this layer (optional). Default: 'Upscore'.

    Attributes:
        scope: `Scope`. This layer scope.

    Links:
        (Fully Convolutional Networks)[http://arxiv.org/abs/1411.4038]
    """
    input_shape = utils.get_incoming_shape(incoming)
    assert len(input_shape) == 4, "Incoming Tensor shape must be 4-D, not %d-D" % len(input_shape)
    strides = utils.autoformat_kernel_2d(strides)
    # Transposed conv filters are laid out [h, w, out_channels, in_channels].
    filter_size = utils.autoformat_filter_conv2d(kernel_size,
                                                 num_classes,
                                                 input_shape[-1])

    with tf.variable_scope(scope, default_name=name, values=[incoming],
                           reuse=reuse) as scope:
        name = scope.name

        in_shape = tf.shape(incoming)
        if shape is None:
            # Compute shape out of Bottom, per spatial axis.
            h = ((in_shape[1] - 1) * strides[1]) + 1
            w = ((in_shape[2] - 1) * strides[2]) + 1
            new_shape = [in_shape[0], h, w, num_classes]
        else:
            # `shape` is [batch_size, new height, new width(, X)]: skip the
            # batch entry, as documented.
            new_shape = [in_shape[0], shape[1], shape[2], num_classes]
        output_shape = tf.stack(new_shape)

        filter_values = _bilinear_upsample_filter_2d(filter_size)
        init = tf.constant_initializer(value=filter_values,
                                       dtype=tf.float32)
        weights = vs.variable(name="up_filter", initializer=init,
                              shape=filter_values.shape, trainable=trainable,
                              restore=restore)
        # Track per layer variables
        tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name,
                             weights)

        deconv = tf.nn.conv2d_transpose(incoming, weights, output_shape,
                                        strides=strides, padding='SAME')

        deconv.scope = scope

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, deconv)

    return deconv
def _bilinear_upsample_filter_3d(f_shape):
    """ Build trilinear-upsampling weights for `tf.nn.conv3d_transpose`.

    Arguments:
        f_shape: `list of int`. Filter shape
            [dim1, dim2, dim3, out_channels, in_channels].

    Returns:
        A float64 `numpy.ndarray` of shape `f_shape`. For every
        `i < min(out_channels, in_channels)`, output channel `i` reads input
        channel `i` through a separable linear-interpolation kernel; every
        other weight is zero.
    """
    kernels = []
    for size in f_shape[:3]:
        # 1-D linear interpolation kernel of length `size`.
        factor = ceil(size / 2.0)
        center = (2 * factor - 1 - factor % 2) / (2.0 * factor)
        kernels.append(1 - np.abs(np.arange(size) / factor - center))
    # Separable kernel: value[x, y, z] = k0[x] * k1[y] * k2[z].
    trilinear = (kernels[0][:, None, None] * kernels[1][None, :, None]
                 * kernels[2][None, None, :])
    weights = np.zeros(f_shape)
    # Only channels present on both sides can be paired; iterating over the
    # output channels alone would index past the input channels.
    for i in range(min(f_shape[3], f_shape[4])):
        weights[:, :, :, i, i] = trilinear
    return weights


def upscore_layer3d(incoming, num_classes, shape=None, kernel_size=4,
                    strides=2, trainable=True, restore=True,
                    reuse=False, scope=None, name='Upscore'):
    """ Upscore.

    This implements the upscore layer as used in
    (Fully Convolutional Networks)[http://arxiv.org/abs/1411.4038].
    The upscore layer is initialized as bilinear upsampling filter.

    Input:
        5-D Tensor [batch, height, width, depth, in_channels].

    Output:
        5-D Tensor [batch, new height, new width, new depth, num_classes].

    Arguments:
        incoming: `Tensor`. Incoming 5-D Layer to upsample.
        num_classes: `int`. Number of output feature maps.
        shape: `list of int`. Dimension of the output map
            [new height, new width, new depth]. For convenience four values
            are allowed [new height, new width, new depth, X], where X
            is ignored. If None, each spatial output size is
            `(input size - 1) * stride + 1`.
        kernel_size: `int` or `list of int`. Upsampling kernel size.
        strides: `int` or `list of int`. Strides of conv operation.
            Default: [1 2 2 2 1].
        trainable: `bool`. If True, weights will be trainable.
        restore: `bool`. If True, this layer weights will be restored when
            loading a model.
        reuse: `bool`. If True and 'scope' is provided, this layer variables
            will be reused (shared).
        scope: `str`. Define this layer scope (optional). A scope can be
            used to share variables between layers. Note that scope will
            override name.
        name: A name for this layer (optional). Default: 'Upscore'.

    Attributes:
        scope: `Scope`. This layer scope.

    Links:
        (Fully Convolutional Networks)[http://arxiv.org/abs/1411.4038]
    """
    input_shape = utils.get_incoming_shape(incoming)
    assert len(input_shape) == 5, "Incoming Tensor shape must be 5-D, not %d-D" % len(input_shape)
    strides = utils.autoformat_kernel_3d(strides)
    filter_size = utils.autoformat_filter_conv3d(kernel_size,
                                                 num_classes,
                                                 input_shape[-1])

    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name

        in_shape = tf.shape(incoming)
        if shape is None:
            # Compute shape out of Bottom, using each axis' own stride.
            h = ((in_shape[1] - 1) * strides[1]) + 1
            w = ((in_shape[2] - 1) * strides[2]) + 1
            d = ((in_shape[3] - 1) * strides[3]) + 1
            new_shape = [in_shape[0], h, w, d, num_classes]
        else:
            new_shape = [in_shape[0], shape[0], shape[1], shape[2], num_classes]
        output_shape = tf.stack(new_shape)

        filter_values = _bilinear_upsample_filter_3d(filter_size)
        init = tf.constant_initializer(value=filter_values,
                                       dtype=tf.float32)
        weights = vs.variable(name="up_filter", initializer=init,
                              shape=filter_values.shape, trainable=trainable,
                              restore=restore)
        # Track per layer variables
        tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name,
                             weights)

        deconv = tf.nn.conv3d_transpose(incoming, weights, output_shape,
                                        strides=strides, padding='SAME')

        deconv.scope = scope

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, deconv)

    return deconv
def conv_1d(incoming, nb_filter, filter_size, strides=1, padding='same',
            activation='linear', bias=True, weights_init='uniform_scaling',
            bias_init='zeros', regularizer=None, weight_decay=0.001,
            trainable=True, restore=True, reuse=False, scope=None,
            name="Conv1D"):
    """ Convolution 1D.

    Implemented with a 2-D convolution. A dummy axis is inserted at
    position 2 of the input, which is then convolved with `tf.nn.conv2d`
    using a kernel whose second spatial dimension is 1. The dummy axis is
    squeezed away again afterwards.

    Input:
        3-D Tensor [batch, steps, in_channels].

    Output:
        3-D Tensor [batch, new steps, nb_filters].

    Arguments:
        incoming: `Tensor`. Incoming 3-D Tensor.
        nb_filter: `int`. The number of convolutional filters.
        filter_size: `int` or `list of int`. Size of filters. After
            formatting, only the first spatial entry (kernel length along
            the steps axis) is kept; the second spatial entry is forced
            to 1.
        strides: `int` or `list of int`. Strides of conv operation. Only
            the steps-axis entry (index 1 of the formatted 4-D stride) is
            used; the dummy axis always uses stride 1. Default: 1.
        padding: `str` from `"same", "valid"`. Padding algo to use.
            Default: 'same'.
        activation: `str` (name) or `function` (returning a `Tensor`).
            Activation applied to this layer (see tflearn.activations).
            Default: 'linear'.
        bias: `bool`. If True, a bias is used.
        weights_init: `str` (name) or `Tensor`. Weights initialization.
            (see tflearn.initializations) Default: 'uniform_scaling'.
        bias_init: `str` (name) or `Tensor`. Bias initialization.
            (see tflearn.initializations) Default: 'zeros'.
        regularizer: `str` (name) or `Tensor`. Add a regularizer to this
            layer weights (see tflearn.regularizers). Default: None.
        weight_decay: `float`. Regularizer decay parameter. Default: 0.001.
        trainable: `bool`. If True, weights will be trainable.
        restore: `bool`. If True, this layer weights will be restored when
            loading a model.
        reuse: `bool`. If True and 'scope' is provided, this layer variables
            will be reused (shared).
        scope: `str`. Define this layer scope (optional). A scope can be
            used to share variables between layers. Note that scope will
            override name.
        name: A name for this layer (optional). Default: 'Conv1D'.

    Attributes:
        scope: `Scope`. This layer scope.
        W: `Variable`. Variable representing filter weights.
        b: `Variable`. Variable representing biases (None when `bias` is
            False).

    Raises:
        ValueError: if `activation` is neither a `str` nor a callable.
    """
    input_shape = utils.get_incoming_shape(incoming)
    assert len(input_shape) == 3, "Incoming Tensor shape must be 3-D, not %d-D" % len(input_shape)
    filter_size = utils.autoformat_filter_conv2d(filter_size,
                                                 input_shape[-1],
                                                 nb_filter)
    # The input gets a size-1 dummy axis at position 2 (see expand_dims
    # below), so the kernel must be 1 wide along that axis and only slide
    # along the steps axis.
    filter_size[1] = 1
    strides = utils.autoformat_kernel_2d(strides)
    # Keep only the steps-axis stride; batch, dummy and channel axes use 1.
    strides = [1, strides[1], 1, 1]
    padding = utils.autoformat_padding(padding)

    with tf.variable_scope(scope, default_name=name, values=[incoming],
                           reuse=reuse) as scope:
        name = scope.name

        W_init = weights_init
        if isinstance(weights_init, str):
            W_init = initializations.get(weights_init)()
        elif type(W_init) in [tf.Tensor, np.ndarray, list]:
            # An explicit initial value is given: pass no shape and let
            # vs.variable take it from the value (presumably inferred).
            filter_size = None
        W_regul = None
        if regularizer is not None:
            W_regul = lambda x: regularizers.get(regularizer)(x, weight_decay)
        W = vs.variable('W', shape=filter_size, regularizer=W_regul,
                        initializer=W_init, trainable=trainable,
                        restore=restore)
        # Track per layer variables
        tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, W)

        b = None
        if bias:
            b_shape = [nb_filter]
            if isinstance(bias_init, str):
                bias_init = initializations.get(bias_init)()
            elif type(bias_init) in [tf.Tensor, np.ndarray, list]:
                # Same as for W: the explicit value defines the shape.
                b_shape = None
            b = vs.variable('b', shape=b_shape, initializer=bias_init,
                            trainable=trainable, restore=restore)
            # Track per layer variables
            tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/' + name, b)

        # Adding dummy dimension to fit with Tensorflow conv2d
        inference = tf.expand_dims(incoming, 2)
        inference = tf.nn.conv2d(inference, W, strides, padding)
        if b is not None:
            inference = tf.nn.bias_add(inference, b)
        # Drop the dummy axis again to return a 3-D tensor.
        inference = tf.squeeze(inference, [2])

        if isinstance(activation, str):
            inference = activations.get(activation)(inference)
        elif callable(activation):
            inference = activation(inference)
        else:
            raise ValueError("Invalid Activation.")

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, inference)

    # Add attributes to Tensor to easy access weights.
    inference.scope = scope
    inference.W = W
    inference.b = b

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, inference)

    return inference
def max_pool_1d(incoming, kernel_size, strides=None, padding='same',