Skip to content

Commit aefa449

Browse files
residual conditional
1 parent 4d942fd commit aefa449

File tree

4 files changed

+12
-0
lines changed

4 files changed

+12
-0
lines changed

NN/RestorationModel/CRepeatedRestorator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def train_step(self, x0, latents, positions, params, xT=None):
4848
xT = tf.stop_gradient(trainStep['value'])
4949

5050
for i in range(1, self._N):
51+
params['residual'] = xT
5152
trainStep = self._restorator.train_step(
5253
x0=x0,
5354
xT=xT,

NN/RestorationModel/CRestorationModel.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ class CRestorationModel(tf.keras.Model):
99
def __init__(self,
1010
decoder, restorator,
1111
posEncoder, timeEncoder,
12+
residualCondition=False,
1213
**kwargs
1314
):
1415
assert posEncoder is not None, "posEncoder is not provided"
@@ -18,6 +19,7 @@ def __init__(self,
1819
self._restorator = restorator
1920
self._posEncoder = posEncoder
2021
self._timeEncoder = timeEncoder
22+
self._residualCondition = residualCondition
2123
return
2224

2325
def _encodeTime(self, t, training):
@@ -60,6 +62,13 @@ def denoiser(x, t, mask=None):
6062
args = {k: masked(v, mask) for k, v in args.items()}
6163
residuals = masked(residual, mask)
6264

65+
if self._residualCondition: # add residuals to the condition
66+
B = tf.shape(residuals)[0]
67+
args['condition'] = tf.reshape( # ensure that we know the shape of the condition
68+
tf.concat([args['condition'], residuals], axis=-1),
69+
[B, -1]
70+
)
71+
6372
res = self._decoder(**args, training=training)
6473
return self._withResidual(res, residuals)
6574

NN/RestorationModel/CSequentialRestorator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def train_step(self, x0, latents, positions, params, xT=None):
3535
xT = tf.stop_gradient(trainStep['value'])
3636
# the rest of the restorators are trained sequentially
3737
for restorator in self._restorators:
38+
params['residual'] = xT
3839
trainStep = restorator.train_step(
3940
x0=x0,
4041
xT=xT,

NN/RestorationModel/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def restorationModel_from_config(config):
3131
restorator=restorator,
3232
posEncoder=encoding_from_config(config['position encoding']),
3333
timeEncoder=encoding_from_config(config['time encoding']),
34+
residualCondition=config.get('residual condition', False),
3435
)
3536

3637
if 'repeated' == name:

0 commit comments

Comments
 (0)