Skip to content

Commit 0bf7637

Browse files
some fixes
1 parent 0b9f38b commit 0bf7637

9 files changed

+95
-32
lines changed

NN/Renderer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ def batched(self, ittr, B, N, batchSize=None, training=False):
2828
for i in tf.range(NBatches):
2929
index = i * stepBy
3030
data = ittr(index, stepBy)
31-
V = self._restorator.reverse(**data, training=training)
31+
V = self._restorator.reverse(**data, training=training, index=index)
3232
C = tf.shape(V)[-1]
3333
res = res.write(i, tf.reshape(V, (B, stepBy, C)))
3434
continue
3535
#################
3636
index = NBatches * stepBy
3737

3838
data = ittr(index, N - index)
39-
V = self._restorator.reverse(**data, training=training)
39+
V = self._restorator.reverse(**data, training=training, index=index)
4040
C = tf.shape(V)[-1]
4141

4242
w = N - index

NN/RestorationModel/CRepeatedRestorator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ def call(self, latents, pos, T, V, residual, training=None):
2626
continue
2727
return V
2828

29-
def reverse(self, latents, pos, reverseArgs, training, value, residual):
29+
def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
3030
for i in range(self._N):
3131
if tf.is_tensor(value): value = tf.stop_gradient(value)
3232
value = self._restorator.reverse(
3333
latents=self._withID(latents, i, training),
3434
pos=pos, reverseArgs=reverseArgs,
3535
residual=residual,
36-
training=training, value=value
36+
training=training, value=value, index=index
3737
)
3838
continue
3939
return value

NN/RestorationModel/CRestorationModel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def call(self, latents, pos, T, V, residual, training=None):
4949
res = self._decoder(condition=latents, coords=EPos, timestep=t, V=V, training=training)
5050
return self._withResidual(res, residual)
5151

52-
def reverse(self, latents, pos, reverseArgs, training, value, residual):
52+
def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
5353
EPos = self._encodePos(pos, training=training, args=reverseArgs.get('decoder', {}))
5454

5555
def denoiser(x, t, mask=None):
@@ -65,7 +65,8 @@ def denoiser(x, t, mask=None):
6565
return self._restorator.reverse(
6666
value=value, denoiser=denoiser,
6767
modelT=lambda t: self._encodeTime(t, training=training),
68-
**reverseArgs
68+
**reverseArgs,
69+
index=index
6970
)
7071

7172
def train_step(self, x0, latents, positions, params, xT=None):

NN/RestorationModel/CSequentialRestorator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ def call(self, latents, pos, T, V, residual, training=None):
1515
continue
1616
return V
1717

18-
def reverse(self, latents, pos, reverseArgs, training, value, residual):
18+
def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
1919
for restorator in self._restorators:
2020
if tf.is_tensor(value): value = tf.stop_gradient(value)
2121
value = restorator.reverse(
2222
latents=latents, pos=pos, reverseArgs=reverseArgs,
2323
residual=residual,
24-
training=training, value=value
24+
training=training, value=value, index=index
2525
)
2626
continue
2727
return value

NN/restorators/CARProcess.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ def denoiser(x, t=None, **kwargs):
3939
return model(x=x, t=T, mask=kwargs.get('mask', None))[:, :self.predictions]
4040
return denoiser
4141

42-
def reverse(self, value, denoiser, modelT=None, **kwargs):
42+
def reverse(self, value, denoiser, modelT=None, index=0, **kwargs):
4343
if isinstance(value, tuple):
4444
value = self._sourceDistribution.initialValueFor(value + (self.predictions, ))
4545

4646
denoiser = self._makeDenoiser(denoiser, modelT)
47-
res = self._sampler.sample(value=value, model=denoiser, **kwargs)
47+
res = self._sampler.sample(value=value, model=denoiser, index=index, **kwargs)
4848
tf.assert_equal(tf.shape(res), tf.shape(value))
4949
return res
5050

NN/restorators/samplers/CBasicInterpolantSampler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ def __init__(self, interpolant, algorithm):
1212
def interpolant(self): return self._interpolant
1313

1414
@tf.function
15-
def sample(self, value, model, **kwargs):
16-
kwargs = dict(**kwargs, interpolant=self._interpolant) # add interpolant to kwargs
15+
def sample(self, value, model, index=0, **kwargs):
16+
# add interpolant to kwargs and index
17+
kwargs = dict(**kwargs, interpolant=self._interpolant, index=index)
1718
# wrap algorithm with hook, if provided
1819
algorithm = kwargs.get('algorithmInterceptor', lambda x: x)( self._algorithm )
1920
assert isinstance(algorithm, ISamplingAlgorithm), f'Algorithm must be an instance of ISamplingAlgorithm, but got {type(algorithm)}'

NN/restorators/samplers/CSamplerWatcher.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,27 @@ def tracked(self, name):
3939
res = tf.concat([self._initialValue[None], res], axis=0)
4040
return res
4141

42-
def _updateTracked(self, name, value, mask=None):
42+
def _updateTracked(self, name, value, mask=None, index=None):
4343
tracked = self._tracked.get(name, None)
4444
if tracked is None: return
45-
value = self._withIndices(value)
45+
value, idx = self._withIndices(value, index, mask=mask)
4646

4747
iteration = self._iteration
48-
if (mask is None) or ('value' == name): # 'value' is always unmasked
49-
tracked[iteration].assign(value)
50-
return
51-
52-
mask = self._withIndices(mask)
53-
prev = tracked[iteration - 1]
54-
# expand mask to match the value shape by copying values from the previous iteration
55-
indices = tf.where(mask)
56-
value = tf.tensor_scatter_nd_update(prev, indices, value)
57-
tf.assert_equal(tf.shape(prev), tf.shape(value), 'Must be the same shape')
58-
tracked[iteration].assign(value)
48+
if (mask is not None) and not ('value' == name): # 'value' is always unmasked
49+
mask, _ = self._withIndices(mask, index, mask=mask)
50+
prev = self._tracked[name][iteration]
51+
# expand mask to match the value shape by copying values from the previous iteration
52+
indices = tf.where(mask)
53+
sz = tf.shape(value)[0]
54+
value = tf.tensor_scatter_nd_update(prev[idx:idx+sz], indices, value)
55+
pass
56+
57+
sz = tf.shape(value)[0]
58+
tracked[iteration, idx:idx+sz].assign(value)
5959
return
6060

6161
def _onNextStep(self, iteration, kwargs):
62+
index = kwargs['index']
6263
self._iteration.assign(iteration)
6364
# track also solution
6465
solution = kwargs['solution']
@@ -68,17 +69,36 @@ def _onNextStep(self, iteration, kwargs):
6869
mask = step.mask if hasattr(step, 'mask') else None
6970
# iterate over all fields
7071
for name in solution._fields:
71-
self._updateTracked(name, getattr(solution, name), mask=mask)
72+
self._updateTracked(name, getattr(solution, name), mask=mask, index=index)
7273
continue
7374
return
7475

7576
def _onStart(self, value, kwargs):
77+
index = kwargs['index']
7678
self._iteration.assign(0)
7779
if 'value' in self._tracked: # save initial value
78-
self._initialValue.assign( self._withIndices(value) )
80+
value, idx = self._withIndices(value, index)
81+
# update slice [index:index+sz] with the value
82+
sz = tf.shape(value)[0]
83+
self._initialValue[idx:idx+sz].assign(value)
7984
return
8085

81-
def _withIndices(self, value):
82-
if self._indices is None: return value
83-
return tf.gather(value, self._indices, axis=0)
86+
def _withIndices(self, value, index, mask=None):
87+
if self._indices is None: return value, index
88+
# find subset of indices
89+
sz = tf.shape(value)[0]
90+
91+
validMask = tf.logical_and(index <= self._indices, self._indices < index + sz)
92+
if mask is not None:
93+
maskedIndices = tf.range(sz)
94+
maskedIndices = tf.boolean_mask(maskedIndices, mask) + index
95+
# exclude masked indices
96+
maskedIndices = tf.reduce_any(maskedIndices[:, None] == self._indices[None], axis=0)
97+
validMask = tf.logical_and(validMask, maskedIndices)
98+
pass
99+
100+
startIndex = tf.reduce_min(tf.where(validMask))
101+
startIndex = tf.cast(startIndex, tf.int32)
102+
indices = tf.boolean_mask(self._indices, validMask) - index
103+
return tf.gather(value, indices, axis=0), startIndex
84104
# End of CSamplerWatcher

huggingface/HF/NN/CInterpolantVisualization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _collectSteps(self, points, initialValues, kwargs):
158158
initialValues=initialValues,
159159
reverseArgs=dict(
160160
**reverseArgs,
161-
algorithmInterceptor=watcher.interceptor(),
161+
algorithmInterceptor=watcher,
162162
),
163163
)
164164
N = watcher.iteration + 1

tests/test_CSamplerWatcher.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _fake_AR(threshold, timesteps=10):
4141
shape = (32, 3)
4242
fakeNoise = tf.random.normal(shape)
4343
def fakeModel(x, t, mask, **kwargs):
44-
s = tf.boolean_mask(fakeNoise, mask)
44+
s = tf.boolean_mask(fakeNoise, mask) if mask is not None else fakeNoise
4545
return s + tf.cast(t, tf.float32) * x
4646

4747
x = tf.random.normal(shape)
@@ -153,4 +153,45 @@ def test_trackSolutionWithMask_value():
153153
)
154154
_ = fake.interpolant.sample(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
155155
_checkTracked(watcher.tracked('value'), N=11)
156+
return
157+
158+
# test multiple calls with index
159+
def test_multipleCallsWithIndex():
160+
fake = _fake_sampler()
161+
watcher = CSamplerWatcher(
162+
steps=10,
163+
tracked=dict(value=(32*3, 3))
164+
)
165+
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
166+
_ = fake.interpolant.sample(**arg, index=0)
167+
_ = fake.interpolant.sample(**arg, index=32)
168+
_ = fake.interpolant.sample(**arg, index=64)
169+
170+
collectedSteps = watcher.tracked('value')
171+
tf.debugging.assert_equal(tf.shape(collectedSteps)[1], 96, 'Must collect 96 values')
172+
tf.debugging.assert_equal(watcher.iteration, 10, 'Must collect 10 steps')
173+
# values must be same across (0..32), (32..64), (64..96)
174+
tf.debugging.assert_equal(collectedSteps[:, :32], collectedSteps[:, 32:64])
175+
tf.debugging.assert_equal(collectedSteps[:, 32:64], collectedSteps[:, 64:])
176+
return
177+
178+
# TODO: find out why this test fails, but previous one passes
179+
# test multiple calls with index and mask
180+
def test_multipleCallsWithIndexAndMask():
181+
fake = _fake_AR(threshold=0.1)
182+
watcher = CSamplerWatcher(
183+
steps=10,
184+
tracked=dict(value=(3,)),
185+
indices=[0, 32, 64, 65]
186+
)
187+
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
188+
_ = fake.interpolant.sample(**arg, index=0)
189+
_ = fake.interpolant.sample(**arg, index=32)
190+
_ = fake.interpolant.sample(**arg, index=64)
191+
192+
collectedSteps = watcher.tracked('value')
193+
tf.assert_equal(tf.shape(collectedSteps), (11, 4, 3), 'Must be (11, 4, 3)')
194+
tf.debugging.assert_equal(watcher.iteration, 10, 'Must collect 10 steps')
195+
tf.debugging.assert_equal(collectedSteps[:, 0:1], collectedSteps[:, 1:2])
196+
tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
156197
return

0 commit comments

Comments
 (0)