Skip to content

Commit f9d52c5

Browse files
fix
1 parent 6ea65fb commit f9d52c5

File tree

2 files changed

+58
-48
lines changed

2 files changed

+58
-48
lines changed

NN/restorators/samplers/CSamplerWatcher.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,15 @@ def tracked(self, name):
3939
res = tf.concat([self._initialValue[None], res], axis=0)
4040
return res
4141

42-
def _updateTracked(self, name, value, mask=None, index=None):
42+
def _updateTracked(self, name, value, mask=None, index=None, iteration=None):
4343
tracked = self._tracked.get(name, None)
4444
if tracked is None: return
45-
value, idx = self._withIndices(value, index, mask=mask)
45+
src, dest = self._withIndices(
46+
value, index, mask=mask,
47+
masked=(mask is not None) and not ('value' == name)
48+
)
4649

47-
iteration = self._iteration
48-
if (mask is not None) and not ('value' == name): # 'value' is always unmasked
49-
mask, _ = self._withIndices(mask, index, mask=mask)
50-
prev = self._tracked[name][iteration]
51-
# expand mask to match the value shape by copying values from the previous iteration
52-
indices = tf.where(mask)
53-
sz = tf.shape(value)[0]
54-
value = tf.tensor_scatter_nd_update(prev[idx:idx+sz], indices, value)
55-
pass
56-
57-
sz = tf.shape(value)[0]
58-
tracked[iteration, idx:idx+sz].assign(value)
50+
self._move(value, src, tracked[iteration], dest)
5951
return
6052

6153
def _onNextStep(self, iteration, kwargs):
@@ -69,36 +61,48 @@ def _onNextStep(self, iteration, kwargs):
6961
mask = step.mask if hasattr(step, 'mask') else None
7062
# iterate over all fields
7163
for name in solution._fields:
72-
self._updateTracked(name, getattr(solution, name), mask=mask, index=index)
64+
self._updateTracked(
65+
name, getattr(solution, name),
66+
mask=mask, index=index, iteration=iteration
67+
)
7368
continue
7469
return
7570

7671
def _onStart(self, value, kwargs):
7772
index = kwargs['index']
7873
self._iteration.assign(0)
7974
if 'value' in self._tracked: # save initial value
80-
value, idx = self._withIndices(value, index)
81-
# update slice [index:index+sz] with the value
82-
sz = tf.shape(value)[0]
83-
self._initialValue[idx:idx+sz].assign(value)
75+
src, dest = self._withIndices(value, index)
76+
self._move(value, src, self._initialValue, dest)
8477
return
8578

86-
def _withIndices(self, value, index, mask=None):
87-
if self._indices is None: return value, index
88-
# find subset of indices
89-
sz = tf.shape(value)[0]
79+
def _withIndices(self, value, index, mask=None, masked=False):
80+
srcIndex = tf.range(tf.shape(value)[0])
81+
if masked:
82+
srcIndex = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask)
9083

91-
validMask = tf.logical_and(index <= self._indices, self._indices < index + sz)
84+
destIndex = index + srcIndex
9285
if mask is not None:
93-
maskedIndices = tf.range(sz)
94-
maskedIndices = tf.boolean_mask(maskedIndices, mask) + index
95-
# exclude masked indices
96-
maskedIndices = tf.reduce_any(maskedIndices[:, None] == self._indices[None], axis=0)
97-
validMask = tf.logical_and(validMask, maskedIndices)
86+
srcIndex = tf.boolean_mask(srcIndex, mask)
87+
destIndex = tf.boolean_mask(destIndex, mask)
9888
pass
99-
100-
startIndex = tf.reduce_min(tf.where(validMask))
101-
startIndex = tf.cast(startIndex, tf.int32)
102-
indices = tf.boolean_mask(self._indices, validMask) - index
103-
return tf.gather(value, indices, axis=0), startIndex
89+
90+
if self._indices is not None:
91+
mask = tf.reduce_any(self._indices[None] == destIndex[:, None], axis=0)
92+
tf.assert_equal(tf.shape(mask), tf.shape(self._indices))
93+
# collect only valid indices
94+
srcIndex = tf.boolean_mask(self._indices, mask) - index
95+
# collect destination indices
96+
destIndex = tf.where(mask)
97+
pass
98+
99+
srcIndex = tf.reshape(srcIndex, (-1, 1))
100+
destIndex = tf.reshape(destIndex, (-1, 1))
101+
return srcIndex, destIndex
102+
103+
def _move(self, src, srcIndex, dest, destIndex):
104+
src = tf.gather_nd(src, srcIndex) # collect only valid indices
105+
res = tf.tensor_scatter_nd_update(dest, destIndex, src)
106+
dest.assign(res)
107+
return res
104108
# End of CSamplerWatcher

tests/test_CSamplerWatcher.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def fakeModel(x, T, **kwargs):
2727
def _fake_AR(threshold, timesteps=10):
2828
interpolant = sampler_from_config({
2929
"name": "autoregressive",
30-
"noise provider": "normal",
30+
"noise provider": "zero",
3131
"threshold": threshold,
3232
"steps": {
3333
"start": 1.0,
@@ -146,26 +146,31 @@ def test_trackSolutionWithMask():
146146
return
147147

148148
def test_trackSolutionWithMask_value():
149-
fake = _fake_AR(threshold=0.1)
149+
fake = _fake_AR(threshold=0.5)
150150
watcher = CSamplerWatcher(
151151
steps=10,
152-
tracked=dict(value=(32, 3))
152+
tracked=dict(value=(32, 3), x0=(32, 3), x1=(32, 3))
153153
)
154154
_ = fake.interpolant.sample(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
155155
_checkTracked(watcher.tracked('value'), N=11)
156+
_checkTracked(watcher.tracked('x0'), N=10)
157+
_checkTracked(watcher.tracked('x1'), N=10)
156158
return
157159

158160
# test multiple calls with index
159161
def test_multipleCallsWithIndex():
160-
fake = _fake_sampler()
162+
fake = _fake_AR(threshold=0.1)
161163
watcher = CSamplerWatcher(
162164
steps=10,
163165
tracked=dict(value=(32*3, 3))
164166
)
165167
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
166-
_ = fake.interpolant.sample(**arg, index=0)
167-
_ = fake.interpolant.sample(**arg, index=32)
168-
_ = fake.interpolant.sample(**arg, index=64)
168+
A = fake.interpolant.sample(**arg, index=0)
169+
B = fake.interpolant.sample(**arg, index=32)
170+
C = fake.interpolant.sample(**arg, index=64)
171+
172+
tf.debugging.assert_equal(A, B)
173+
tf.debugging.assert_equal(A, C)
169174

170175
collectedSteps = watcher.tracked('value')
171176
tf.debugging.assert_equal(tf.shape(collectedSteps)[1], 96, 'Must collect 96 values')
@@ -182,16 +187,17 @@ def test_multipleCallsWithIndexAndMask():
182187
watcher = CSamplerWatcher(
183188
steps=10,
184189
tracked=dict(value=(3,)),
185-
indices=[0, 32, 64, 65]
190+
indices=[0, 32]
186191
)
187192
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
188-
_ = fake.interpolant.sample(**arg, index=0)
189-
_ = fake.interpolant.sample(**arg, index=32)
190-
_ = fake.interpolant.sample(**arg, index=64)
193+
A = fake.interpolant.sample(**arg, index=0)
194+
B = fake.interpolant.sample(**arg, index=32)
195+
C = fake.interpolant.sample(**arg, index=64)
191196

192197
collectedSteps = watcher.tracked('value')
193-
tf.assert_equal(tf.shape(collectedSteps), (11, 4, 3), 'Must be (11, 4, 3)')
194-
tf.debugging.assert_equal(watcher.iteration, 10, 'Must collect 10 steps')
198+
tf.assert_equal(tf.shape(collectedSteps), (11, 2, 3), 'Must be (11, 4, 3)')
199+
tf.assert_equal(A, B)
200+
tf.assert_equal(A, C)
195201
tf.debugging.assert_equal(collectedSteps[:, 0:1], collectedSteps[:, 1:2])
196-
tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
202+
# tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
197203
return

0 commit comments

Comments
 (0)