@@ -149,6 +149,18 @@ def _createAlgorithmInterceptor(self, interceptor, image, pos):
149
149
)
150
150
return res .interceptor ()
151
151
#####################################################
152
+ def _withBlur (self , reverseArgs , B ):
153
+ R = None
154
+ if 'blurRadius' in reverseArgs :
155
+ R = reverseArgs ['blurRadius' ]
156
+ if not tf .is_tensor (R ):
157
+ R = tf .convert_to_tensor (R , dtype = tf .float32 )
158
+ R = tf .reshape (R , [1 , 1 ])
159
+ R = tf .tile (R , [B , 1 ])
160
+ else :
161
+ R = tf .zeros ((B , 1 ), dtype = tf .float32 ) # let encoder to decide needed it or not
162
+ return R
163
+
152
164
@tf .function
153
165
def _inference (
154
166
self , src , pos ,
@@ -163,16 +175,7 @@ def _inference(
163
175
tf .assert_equal (tf .shape (initialValues )[:1 ], (B , ))
164
176
initialValues = tf .reshape (initialValues , (B , N , C ))
165
177
166
- R = None
167
- if 'blurRadius' in reverseArgs :
168
- R = reverseArgs ['blurRadius' ]
169
- if not tf .is_tensor (R ):
170
- R = tf .convert_to_tensor (R , dtype = tf .float32 )
171
- R = tf .reshape (R , [1 , 1 ])
172
- R = tf .tile (R , [B , 1 ])
173
- else :
174
- R = tf .zeros ((B , 1 ), dtype = tf .float32 )
175
- pass
178
+ R = self ._withBlur (reverseArgs , B )
176
179
encoded = self ._encoder (src , training = False , params = encoderParams , R = R )
177
180
def getChunk (ind , sz ):
178
181
posC = pos [ind :ind + sz ]
0 commit comments