Skip to content

Commit 3a4c3e1

Browse files
misc
1 parent c97ee86 commit 3a4c3e1

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

NN/Nerf2D.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,18 @@ def _createAlgorithmInterceptor(self, interceptor, image, pos):
149149
)
150150
return res.interceptor()
151151
#####################################################
152+
def _withBlur(self, reverseArgs, B):
153+
R = None
154+
if 'blurRadius' in reverseArgs:
155+
R = reverseArgs['blurRadius']
156+
if not tf.is_tensor(R):
157+
R = tf.convert_to_tensor(R, dtype=tf.float32)
158+
R = tf.reshape(R, [1, 1])
159+
R = tf.tile(R, [B, 1])
160+
else:
161+
R = tf.zeros((B, 1), dtype=tf.float32) # let encoder to decide needed it or not
162+
return R
163+
152164
@tf.function
153165
def _inference(
154166
self, src, pos,
@@ -163,16 +175,7 @@ def _inference(
163175
tf.assert_equal(tf.shape(initialValues)[:1], (B, ))
164176
initialValues = tf.reshape(initialValues, (B, N, C))
165177

166-
R = None
167-
if 'blurRadius' in reverseArgs:
168-
R = reverseArgs['blurRadius']
169-
if not tf.is_tensor(R):
170-
R = tf.convert_to_tensor(R, dtype=tf.float32)
171-
R = tf.reshape(R, [1, 1])
172-
R = tf.tile(R, [B, 1])
173-
else:
174-
R = tf.zeros((B, 1), dtype=tf.float32)
175-
pass
178+
R = self._withBlur(reverseArgs, B)
176179
encoded = self._encoder(src, training=False, params=encoderParams, R=R)
177180
def getChunk(ind, sz):
178181
posC = pos[ind:ind+sz]

0 commit comments

Comments
 (0)