Skip to content

Commit 6ea65fb

Browse files
misc
1 parent 0bf7637 commit 6ea65fb

File tree

1 file changed

+8
-22
lines changed

1 file changed

+8
-22
lines changed

NN/Nerf2D.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,12 @@ def test_step(self, images):
132132

133133
def _createAlgorithmInterceptor(self, interceptor, image, pos):
134134
from NN.restorators.samplers.CWatcherWithExtras import CWatcherWithExtras
135-
def interceptorFactory(algorithm):
136-
res = interceptor(algorithm)
137-
res = CWatcherWithExtras(
138-
watcher=res,
139-
converter=self._converter,
140-
residuals=None # residuals applied in the renderer
141-
)
142-
return res
143-
return interceptorFactory
135+
res = CWatcherWithExtras(
136+
watcher=interceptor,
137+
converter=self._converter,
138+
residuals=None # residuals applied in the renderer
139+
)
140+
return res.interceptor()
144141
#####################################################
145142
@tf.function
146143
def _inference(
@@ -156,15 +153,6 @@ def _inference(
156153
tf.assert_equal(tf.shape(initialValues)[:1], (B, ))
157154
initialValues = tf.reshape(initialValues, (B, N, C))
158155

159-
if 'algorithmInterceptor' in reverseArgs: # update algorithm interceptor if provided
160-
newParams = {k: v for k, v in encoderParams.items() if k != 'algorithmInterceptor'}
161-
newParams['algorithmInterceptor'] = self._createAlgorithmInterceptor(
162-
interceptor=reverseArgs['algorithmInterceptor'],
163-
image=src, pos=tf.tile(pos[None], [B, 1, 1])
164-
)
165-
encoderParams = newParams
166-
pass
167-
168156
encoded = self._encoder(src, training=False, params=encoderParams)
169157
def getChunk(ind, sz):
170158
posC = pos[ind:ind+sz]
@@ -233,14 +221,12 @@ def call(self,
233221
reverseArgs = {k: v for k, v in reverseArgs.items() if k != 'encoder'}
234222
# add interceptors if needed
235223
if 'algorithmInterceptor' in reverseArgs:
236-
newParams = {k: v for k, v in encoderParams.items()}
224+
newParams = {k: v for k, v in reverseArgs.items()}
237225
newParams['algorithmInterceptor'] = self._createAlgorithmInterceptor(
238226
interceptor=reverseArgs['algorithmInterceptor'],
239227
image=src, pos=tf.tile(pos[None], [B, 1, 1])
240228
)
241-
encoderParams = newParams
242-
# remove the interceptor from the reverseArgs
243-
reverseArgs = {k: v for k, v in reverseArgs.items() if k != 'algorithmInterceptor'}
229+
reverseArgs = newParams
244230
pass
245231

246232
probes = self._inference(

0 commit comments

Comments
 (0)