Skip to content

Commit 4d942fd

Browse files
remove tf.stop_gradient
1 parent 74630f9 commit 4d942fd

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

NN/RestorationModel/CRepeatedRestorator.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def call(self, latents, pos, T, V, residual, training=None):
2828

2929
def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
3030
for i in range(self._N):
31-
if tf.is_tensor(value): value = tf.stop_gradient(value)
3231
value = self._restorator.reverse(
3332
latents=self._withID(latents, i, training),
3433
pos=pos, reverseArgs=reverseArgs,
@@ -50,11 +49,11 @@ def train_step(self, x0, latents, positions, params, xT=None):
5049

5150
for i in range(1, self._N):
5251
trainStep = self._restorator.train_step(
53-
x0=tf.stop_gradient(x0),
54-
xT=tf.stop_gradient(xT),
55-
latents=self._withID(tf.stop_gradient(latents), i, training=True),
56-
positions=tf.stop_gradient(positions),
57-
params={k: tf.stop_gradient(v) if tf.is_tensor(v) else v for k, v in params.items()}
52+
x0=x0,
53+
xT=xT,
54+
latents=self._withID(latents, i, training=True),
55+
positions=positions,
56+
params=params
5857
)
5958
loss += trainStep['loss']
6059
xT = tf.stop_gradient(trainStep['value'])

NN/RestorationModel/CSequentialRestorator.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def call(self, latents, pos, T, V, residual, training=None):
1717

1818
def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
1919
for restorator in self._restorators:
20-
if tf.is_tensor(value): value = tf.stop_gradient(value)
2120
value = restorator.reverse(
2221
latents=latents, pos=pos, reverseArgs=reverseArgs,
2322
residual=residual,
@@ -37,11 +36,11 @@ def train_step(self, x0, latents, positions, params, xT=None):
3736
# the rest of the restorators are trained sequentially
3837
for restorator in self._restorators:
3938
trainStep = restorator.train_step(
40-
x0=tf.stop_gradient(x0),
41-
xT=tf.stop_gradient(xT),
42-
latents=tf.stop_gradient(latents),
43-
positions=tf.stop_gradient(positions),
44-
params={k: tf.stop_gradient(v) if isinstance(v, tf.Tensor) else v for k, v in params.items()}
39+
x0=x0,
40+
xT=xT,
41+
latents=latents,
42+
positions=positions,
43+
params=params
4544
)
4645
loss += trainStep['loss']
4746
xT = tf.stop_gradient(trainStep['value'])

0 commit comments

Comments (0)