@@ -132,15 +132,12 @@ def test_step(self, images):
132
132
133
133
def _createAlgorithmInterceptor (self , interceptor , image , pos ):
134
134
from NN .restorators .samplers .CWatcherWithExtras import CWatcherWithExtras
135
- def interceptorFactory (algorithm ):
136
- res = interceptor (algorithm )
137
- res = CWatcherWithExtras (
138
- watcher = res ,
139
- converter = self ._converter ,
140
- residuals = None # residuals applied in the renderer
141
- )
142
- return res
143
- return interceptorFactory
135
+ res = CWatcherWithExtras (
136
+ watcher = interceptor ,
137
+ converter = self ._converter ,
138
+ residuals = None # residuals applied in the renderer
139
+ )
140
+ return res .interceptor ()
144
141
#####################################################
145
142
@tf .function
146
143
def _inference (
@@ -156,15 +153,6 @@ def _inference(
156
153
tf .assert_equal (tf .shape (initialValues )[:1 ], (B , ))
157
154
initialValues = tf .reshape (initialValues , (B , N , C ))
158
155
159
- if 'algorithmInterceptor' in reverseArgs : # update algorithm interceptor if provided
160
- newParams = {k : v for k , v in encoderParams .items () if k != 'algorithmInterceptor' }
161
- newParams ['algorithmInterceptor' ] = self ._createAlgorithmInterceptor (
162
- interceptor = reverseArgs ['algorithmInterceptor' ],
163
- image = src , pos = tf .tile (pos [None ], [B , 1 , 1 ])
164
- )
165
- encoderParams = newParams
166
- pass
167
-
168
156
encoded = self ._encoder (src , training = False , params = encoderParams )
169
157
def getChunk (ind , sz ):
170
158
posC = pos [ind :ind + sz ]
@@ -233,14 +221,12 @@ def call(self,
233
221
reverseArgs = {k : v for k , v in reverseArgs .items () if k != 'encoder' }
234
222
# add interceptors if needed
235
223
if 'algorithmInterceptor' in reverseArgs :
236
- newParams = {k : v for k , v in encoderParams .items ()}
224
+ newParams = {k : v for k , v in reverseArgs .items ()}
237
225
newParams ['algorithmInterceptor' ] = self ._createAlgorithmInterceptor (
238
226
interceptor = reverseArgs ['algorithmInterceptor' ],
239
227
image = src , pos = tf .tile (pos [None ], [B , 1 , 1 ])
240
228
)
241
- encoderParams = newParams
242
- # remove the interceptor from the reverseArgs
243
- reverseArgs = {k : v for k , v in reverseArgs .items () if k != 'algorithmInterceptor' }
229
+ reverseArgs = newParams
244
230
pass
245
231
246
232
probes = self ._inference (
0 commit comments