/
decoder.py
executable file
·54 lines (44 loc) · 2.04 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# coding=utf-8
from keras.layers.advanced_activations import PReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Activation
from keras.layers.merge import add
from keras.layers.normalization import BatchNormalization
from ..layers.pooling import MaxUnpooling2D
def bottleneck(encoder, output, upsample=False, reverse_module=False):
    """Build one decoder bottleneck block and return its output tensor.

    Main branch: a 1x1 convolution down to ``output // 4`` filters, a 3x3
    stage at that width (a same-padded convolution, or a stride-2 transposed
    convolution when ``upsample`` is true), then a 1x1 convolution up to
    ``output`` filters. The first two convolutions are each followed by
    batch normalisation (momentum 0.1) and a PReLU shared over axes 1 and 2.

    Shortcut branch: ``encoder`` itself, or a bias-free 1x1 projection plus
    batch normalisation of it when the last dimension of its shape differs
    from ``output`` or when upsampling. When upsampling with a
    ``reverse_module`` supplied, the projected shortcut is also passed
    through ``MaxUnpooling2D`` together with ``reverse_module``.

    Args:
        encoder: Input tensor of the block.
            NOTE(review): presumably channels-last, since the last axis of
            its shape is compared against ``output`` -- confirm.
        output: Number of filters produced by the block.
        upsample: If true, the 3x3 stage is a stride-2 transposed
            convolution and the shortcut is always projected.
        reverse_module: Second input handed to ``MaxUnpooling2D``; ``False``
            (the default) means none is available.
            NOTE(review): presumably the pooling indices saved by the
            encoder -- verify against the caller.

    Returns:
        The main-branch tensor as is when ``upsample`` is true and
        ``reverse_module`` is ``False``; otherwise the PReLU of the sum of
        the batch-normalised main branch and the shortcut.
    """
    reduced = output // 4
    has_unpool_input = reverse_module is not False

    # --- main branch: reduce -> 3x3 / transposed 3x3 -> expand -------------
    main = Conv2D(reduced, (1, 1), use_bias=False)(encoder)
    main = BatchNormalization(momentum=0.1)(main)
    main = PReLU(shared_axes=[1, 2])(main)

    if upsample:
        main = Conv2DTranspose(
            filters=reduced,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding='same',
        )(main)
    else:
        main = Conv2D(reduced, (3, 3), padding='same', use_bias=True)(main)
    main = BatchNormalization(momentum=0.1)(main)
    main = PReLU(shared_axes=[1, 2])(main)

    main = Conv2D(output, (1, 1), padding='same', use_bias=False)(main)

    # --- shortcut branch ----------------------------------------------------
    # The projection layers are created even in the case below where the
    # shortcut ends up unused, so the set of layers built stays the same.
    shortcut = encoder
    width_differs = encoder.get_shape()[-1] != output
    if width_differs or upsample:
        shortcut = Conv2D(output, (1, 1), padding='same', use_bias=False)(shortcut)
        shortcut = BatchNormalization(momentum=0.1)(shortcut)
        if upsample and has_unpool_input:
            shortcut = MaxUnpooling2D()([shortcut, reverse_module])

    # Upsampling without an unpooling input: return the main branch alone,
    # with no merge, batch normalisation or final activation.
    if upsample and not has_unpool_input:
        return main

    # --- merge --------------------------------------------------------------
    main = BatchNormalization(momentum=0.1)(main)
    merged = add([main, shortcut])
    return PReLU(shared_axes=[1, 2])(merged)
def build(encoder, nc):
    """Stack the decoder on top of an encoder and return the output tensor.

    Two stages are built with ``bottleneck``: an upsampling block followed by
    two regular blocks at 64 filters (bottlenecks 4.0-4.2), then an
    upsampling block followed by one regular block at 16 filters
    (bottlenecks 5.0-5.1). A final 2x2, stride-2 transposed convolution
    produces ``nc`` filters.

    Args:
        encoder: A pair ``(network, index_stack)``. ``network`` is the tensor
            the decoder starts from; ``index_stack`` supplies one entry per
            upsampling block via ``pop()``. Two entries are popped, so the
            caller's object is mutated.
        nc: Number of filters of the final transposed convolution.
            NOTE(review): presumably the number of output classes -- confirm.

    Returns:
        The output tensor of the final transposed convolution.
    """
    features, pooling_stack = encoder

    # (filters, number of non-upsampling bottlenecks after the upsampling one)
    stages = (
        (64, 2),  # bottleneck 4.0, then 4.1 and 4.2
        (16, 1),  # bottleneck 5.0, then 5.1
    )

    net = features
    for width, plain_blocks in stages:
        # Each stage opens with an upsampling block that takes the most
        # recently pushed entry of the stack.
        net = bottleneck(
            net, width, upsample=True, reverse_module=pooling_stack.pop()
        )
        for _ in range(plain_blocks):
            net = bottleneck(net, width)

    head = Conv2DTranspose(
        filters=nc, kernel_size=(2, 2), strides=(2, 2), padding='same'
    )
    return head(net)