1
+ # a bit hacky way to make encoder head obtain input shape dynamically
2
+ import tensorflow as tf
3
+ import tensorflow .keras .layers as L
4
+ from NN .utils import sMLP
5
+ from NN .encoding import CCoordsGridLayer , CCoordsEncodingLayer
6
+ from NN .layers import MixerConvLayer , Patches , TransformerBlock
7
+ from Utils .utils import dumb_deepcopy
8
+
9
+ def block_params_from_config (config ):
10
+ layers = config .get ('layers' , None )
11
+ if not (layers is None ): return layers
12
+
13
+ defaultConvParams = {
14
+ 'kernel size' : config .get ('kernel size' , 3 ),
15
+ 'activation' : config .get ('activation' , 'relu' ),
16
+ 'name' : config .get ('name' , 'Conv2D' ),
17
+ }
18
+ convBefore = config ['conv before' ]
19
+ # if convBefore is an integer, then it's the same for all layers
20
+ if isinstance (convBefore , int ):
21
+ convParams = { 'channels' : config ['channels' ], ** defaultConvParams }
22
+ convBefore = [convParams ] * convBefore # repeat the same parameters
23
+ pass
24
+ assert isinstance (convBefore , list ), 'convBefore must be a list'
25
+ # if convBefore is a list of integers, then each integer is the number of channels
26
+ if (0 < len (convBefore )) and isinstance (convBefore [0 ], int ):
27
+ convBefore = [ {'channels' : sz , ** defaultConvParams } for sz in convBefore ]
28
+ pass
29
+
30
+ # add separately last layer
31
+ lastConvParams = {
32
+ 'channels' : config .get ('channels last' , config ['channels' ]),
33
+ 'kernel size' : config .get ('kernel size last' , defaultConvParams ['kernel size' ]),
34
+ 'activation' : config .get ('final activation' , defaultConvParams ['activation' ]),
35
+ 'name' : config .get ('last name' , 'Conv2D' ),
36
+ }
37
+ return convBefore + [lastConvParams ]
38
+
39
+ def conv_block_from_config (data , config , defaults , name = 'CB' ):
40
+ config = {** defaults , ** config } # merge defaults and config
41
+ convParams = block_params_from_config (config )
42
+ # apply convolutions to the data
43
+ for i , parameters in enumerate (convParams ):
44
+ parameters = dumb_deepcopy (parameters )
45
+ Name = parameters .get ('name' , 'Conv2D' )
46
+ if 'Conv2D' == Name :
47
+ data = L .Conv2D (
48
+ filters = parameters ['channels' ],
49
+ padding = 'same' ,
50
+ kernel_size = parameters ['kernel size' ],
51
+ activation = parameters ['activation' ],
52
+ name = '%s/conv-%d' % (name , i )
53
+ )(data )
54
+ continue
55
+
56
+ if 'MLP Mixer' == Name :
57
+ data = MixerConvLayer (
58
+ token_mixing = parameters .get ('token mixing' , 512 ),
59
+ channel_mixing = parameters .get ('channel mixing' , 512 ),
60
+ name = '%s/conv-mixer-%d' % (name , i )
61
+ )(data )
62
+ continue
63
+
64
+ if 'Patches' == Name :
65
+ data = Patches (
66
+ patch_size = parameters ['patch size' ],
67
+ name = '%s/patches-%d' % (name , i )
68
+ )(data )
69
+ continue
70
+
71
+ if 'CoordsGrid' == Name :
72
+ parameters = {k : v for k , v in parameters .items () if k not in ['name' ]}
73
+ parameters ['name' ] = '%s/coordsGrid-%d' % (name , i )
74
+ data = CCoordsGridLayer (
75
+ CCoordsEncodingLayer (
76
+ N = parameters .get ('N' , 32 ),
77
+ ** parameters
78
+ ),
79
+ name = '%s/coordsGrid-%d' % (name , i )
80
+ )(data )
81
+ continue
82
+
83
+ if 'Transformer' == Name :
84
+ parameters = {k : v for k , v in parameters .items ()}
85
+ parameters ['name' ] = '%s/transformer-%d' % (name , i )
86
+ parameters ['intermediate_dim' ] = parameters .pop ('intermediate dim' , 512 )
87
+ parameters ['num_heads' ] = parameters .pop ('num heads' , 8 )
88
+ data = TransformerBlock (** parameters )(data )
89
+ continue
90
+
91
+ if 'Reshape' == Name :
92
+ shape = list (parameters ['shape' ])
93
+ for j , sz in enumerate (shape ):
94
+ if sz <= - 2 :
95
+ sz = data .shape [sz + 1 ]
96
+ shape [j ] = sz
97
+ continue
98
+ data = L .Reshape (
99
+ shape ,
100
+ name = '%s/reshape-%d' % (name , i )
101
+ )(data )
102
+ continue
103
+
104
+ if 'MLP' == Name :
105
+ parameters ['name' ] = '%s/mlp-%d' % (name , i )
106
+ data = sMLP (** parameters )(data )
107
+ continue
108
+
109
+ raise NotImplementedError ('Unknown layer: {}' .format (Name ))
110
+ return data
111
+
112
+ def _createGCMv2 (dataShape , config , latentDim , name ):
113
+ data = L .Input (shape = dataShape )
114
+
115
+ res = data
116
+ for i , blockConfig in enumerate (config ['downsample steps' ]):
117
+ # downsample
118
+ res = L .Conv2D (
119
+ filters = blockConfig ['channels' ],
120
+ kernel_size = blockConfig ['kernel size' ],
121
+ strides = 2 ,
122
+ padding = 'same' ,
123
+ activation = 'relu' ,
124
+ name = name + '/downsample-%d' % (i + 1 ,)
125
+ )(res )
126
+ # convolutions
127
+ for layerId in range (blockConfig ['layers' ]):
128
+ res = L .Conv2D (
129
+ filters = blockConfig ['channels' ],
130
+ kernel_size = blockConfig ['kernel size' ],
131
+ padding = 'same' ,
132
+ activation = 'relu' ,
133
+ name = name + '/downsample-%d/layer-%d' % (i + 1 , layerId + 1 )
134
+ )(res )
135
+ continue
136
+ continue
137
+
138
+ return tf .keras .Model (inputs = [data ], outputs = res , name = name )
139
+
140
+ def _createGlobalContextModel (X , config , latentDim , name ):
141
+ model = config .get ('name' , 'v1' )
142
+ if 'v1' == model : # simple convolutional model
143
+ res = conv_block_from_config (
144
+ data = X , config = config , defaults = {
145
+ 'conv before' : 0 , # by default, no convolutions before the last layer
146
+ }
147
+ )
148
+ # calculate global context
149
+ latent = L .Flatten ()(res )
150
+ context = sMLP (sizes = config ['mlp' ], activation = 'relu' , name = name + '/globalMixer' )(latent )
151
+ context = L .Dense (latentDim , activation = config ['final activation' ], name = name + '/dense-latent' )(context )
152
+ return context # end of 'v1' model
153
+
154
+ if 'v2' == model :
155
+ res = data = L .Input (shape = X .shape [1 :])
156
+ res = _createGCMv2 (res .shape [1 :], config , latentDim , name )(res )
157
+ # calculate global context
158
+ latent = L .Flatten ()(res )
159
+ context = sMLP (sizes = config ['mlp' ], activation = 'relu' , name = name + '/globalMixer' )(latent )
160
+ context = L .Dense (latentDim , activation = config ['final activation' ], name = name + '/dense-latent' )(context )
161
+ model = tf .keras .Model (inputs = [data ], outputs = context , name = name )
162
+ return model (X ) # end of 'v2' model
163
+
164
+ raise NotImplementedError ('Unknown global context model: {}' .format (model ))
165
+
166
+ def _withPositionConfig (config , name ):
167
+ if config is None :
168
+ print ('[Encoder] Positions: No' )
169
+ return lambda x , _ : x
170
+
171
+ print ('[Encoder] Positions: Yes' )
172
+
173
+ if isinstance (config , bool ): config = { 'N' : 32 }
174
+ assert isinstance (config , dict ), 'config must be a dictionary'
175
+
176
+ def withPosition (x , i ):
177
+ if not config .get ('stage-%d' % i , True ): return x
178
+
179
+ encoding = config .get ('encoding' , {})
180
+ encoding = dict (** encoding )
181
+ encoding ['N' ] = config .get ('stage-%d N' % i , config .get ('N' , 32 ))
182
+ return CCoordsGridLayer (
183
+ CCoordsEncodingLayer (** encoding , name = '%s/coordsGrid-%d/encoding' % (name , i )),
184
+ name = '%s/coordsGrid-%d' % (name , i )
185
+ )(x )
186
+ return withPosition
187
+
188
+ ##################
189
+ def createEncoderHead_full (
190
+ imgWidth ,
191
+ config ,
192
+ channels , downsampleSteps , latentDim ,
193
+ ConvBeforeStage , ConvAfterStage ,
194
+ localContext , globalContext ,
195
+ positionsConfigs ,
196
+ name
197
+ ):
198
+ assert config is not None , 'config must be a dictionary'
199
+ assert isinstance (downsampleSteps , list ) and (0 < len (downsampleSteps )), 'downsampleSteps must be a list of integers'
200
+ data = L .Input (shape = (imgWidth , imgWidth , channels ))
201
+
202
+ withPosition = _withPositionConfig (positionsConfigs , name )
203
+ res = data
204
+ intermediate = []
205
+ for i , sz in enumerate (downsampleSteps ):
206
+ if config .get ('use downsampling' , True ):
207
+ res = L .Conv2D (sz , 3 , strides = 2 , padding = 'same' , activation = 'relu' )(res )
208
+ res = withPosition (res , i ) # add position encoding if needed
209
+ for _ in range (ConvBeforeStage ):
210
+ res = L .Conv2D (sz , 3 , padding = 'same' , activation = 'relu' )(res )
211
+
212
+ # local context
213
+ if not (localContext is None ):
214
+ intermediate .append (
215
+ conv_block_from_config (
216
+ data = res , config = localContext , defaults = {
217
+ 'channels' : sz ,
218
+ 'channels last' : latentDim , # last layer should have latentDim channels
219
+ },
220
+ name = '%s/intermediate-%d' % (name , i )
221
+ )
222
+ )
223
+ ################################
224
+ for _ in range (ConvAfterStage ):
225
+ res = L .Conv2D (sz , 3 , padding = 'same' , activation = 'relu' )(res )
226
+ continue
227
+
228
+ if not (globalContext is None ): # global context
229
+ res = withPosition (res , len (downsampleSteps ))
230
+ context = _createGlobalContextModel (res , globalContext , latentDim , name + '/globalContext' )
231
+ else : # no global context
232
+ # return dummy context to keep the interface consistent
233
+ context = L .Lambda (
234
+ lambda x : tf .zeros ((tf .shape (x )[0 ], 1 ), dtype = res .dtype )
235
+ )(res )
236
+
237
+ return tf .keras .Model (
238
+ inputs = [data ],
239
+ outputs = {
240
+ 'intermediate' : intermediate , # intermediate representations
241
+ 'context' : context , # global context
242
+ },
243
+ name = name
244
+ )
245
+
246
+ class CEncoderHead (tf .keras .Model ):
247
+ def __init__ (self ,
248
+ config ,
249
+ downsampleSteps , latentDim ,
250
+ ConvBeforeStage , ConvAfterStage ,
251
+ localContext , globalContext ,
252
+ positionsConfigs ,
253
+ ** kwargs
254
+ ):
255
+ super ().__init__ (** kwargs )
256
+ self ._config = config
257
+ self ._downsampleSteps = downsampleSteps
258
+ self ._latentDim = latentDim
259
+ self ._ConvBeforeStage = ConvBeforeStage
260
+ self ._ConvAfterStage = ConvAfterStage
261
+ self ._localContext = localContext
262
+ self ._globalContext = globalContext
263
+ self ._positionsConfigs = positionsConfigs
264
+ return
265
+
266
+ def build (self , inputShape ):
267
+ H , W , C = inputShape [1 :]
268
+ self ._encoderHead = createEncoderHead_full (
269
+ imgWidth = H , config = self ._config ,
270
+ channels = C , downsampleSteps = self ._downsampleSteps , latentDim = self ._latentDim ,
271
+ ConvBeforeStage = self ._ConvBeforeStage , ConvAfterStage = self ._ConvAfterStage ,
272
+ localContext = self ._localContext , globalContext = self ._globalContext ,
273
+ positionsConfigs = self ._positionsConfigs ,
274
+ name = self .name + '/EncoderHead'
275
+ )
276
+ self ._encoderHead .build (inputShape )
277
+ return super ().build (inputShape )
278
+
279
+ def call (self , src , training = None ):
280
+ return self ._encoderHead (src , training = training )
281
+ '''
282
+ Simple encoder that takes image as input and returns corresponding latent vector with intermediate representations
283
+ '''
284
+ def createEncoderHead (
285
+ config ,
286
+ downsampleSteps , latentDim ,
287
+ ConvBeforeStage , ConvAfterStage ,
288
+ localContext , globalContext ,
289
+ positionsConfigs ,
290
+ name
291
+ ):
292
+ return CEncoderHead (
293
+ config = config ,
294
+ downsampleSteps = downsampleSteps ,
295
+ latentDim = latentDim ,
296
+ ConvBeforeStage = ConvBeforeStage ,
297
+ ConvAfterStage = ConvAfterStage ,
298
+ localContext = localContext ,
299
+ globalContext = globalContext ,
300
+ positionsConfigs = positionsConfigs ,
301
+ name = name
302
+ )
0 commit comments