Skip to content

Commit 3809acc

Browse files
Finally its works!
1 parent c33735f commit 3809acc

File tree

2 files changed

+73
-49
lines changed

2 files changed

+73
-49
lines changed

NN/restorators/samplers/CSamplerWatcher.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
class CSamplerWatcher(ISamplerWatcher):
88
def __init__(self, steps, tracked, indices=None):
99
super().__init__()
10-
self._indices = tf.constant(indices, dtype=tf.int32) if not(indices is None) else None
1110
self._tracked = {}
1211
prefix = [steps]
13-
if not(self._indices is None):
12+
self._indices = None
13+
if not(indices is None):
14+
self._indices = tf.reshape(tf.constant(indices, dtype=tf.int32), (1, -1))
1415
prefix = [steps, tf.size(self._indices)]
1516

1617
for name, shape in tracked.items():
@@ -20,6 +21,7 @@ def __init__(self, steps, tracked, indices=None):
2021

2122
if 'value' in self._tracked: # value has steps + 1 shape, so we need extra variable
2223
shp = prefix + list(tracked['value'])
24+
print('Initial value shape:', shp)
2325
self._initialValue = tf.Variable(tf.zeros(shp[1:]), trainable=False)
2426
pass
2527

@@ -43,13 +45,13 @@ def tracked(self, name):
4345
def _updateTracked(self, name, value, mask=None, index=None, iteration=None):
4446
tracked = self._tracked.get(name, None)
4547
if tracked is None: return
46-
src, dest = self._withIndices(
48+
src, dest, unchangedIdx = self._withIndices(
4749
value, index, mask=mask,
4850
masked=not ('value' == name)
4951
)
50-
tf.print('-'*80)
51-
tf.print(name, src, dest, summarize=-1)
52-
tf.print('N', tf.reduce_sum(tf.cast(mask, tf.int32)))
52+
prev = self.tracked(name)[iteration - 1]
53+
self._move(prev, unchangedIdx, tracked[iteration], unchangedIdx)
54+
tf.print(unchangedIdx)
5355
self._move(value, src, tracked[iteration], dest)
5456
return
5557

@@ -75,51 +77,49 @@ def _onStart(self, value, kwargs):
7577
index = kwargs['index']
7678
self._iteration.assign(0)
7779
if 'value' in self._tracked: # save initial value
78-
src, dest = self._withIndices(value, index)
80+
src, dest, _ = self._withIndices(value, index)
7981
self._move(value, src, self._initialValue, dest)
8082
return
8183

8284
def _withIndices(self, value, index, mask=None, masked=False):
83-
N = tf.shape(value)[0]
84-
srcIndex = tf.range(N)
85-
destIndex = index + srcIndex
85+
unchanged = tf.constant([], dtype=tf.int32)
86+
srcIndex = tf.range(tf.shape(value)[0])
8687
if mask is not None: # use mask
87-
N = tf.reduce_sum(tf.cast(mask, tf.int32))
88-
destIndex = index + tf.cast(tf.where(mask), tf.int32)
89-
destIndex = tf.reshape(destIndex, (N,))
90-
if masked:
91-
rng = tf.range(tf.shape(mask)[0])
92-
srcIndex = NNU.masked(rng, mask)
93-
pass
94-
pass
88+
unchanged = tf.logical_not(mask)
89+
unchanged = tf.cast(tf.where(unchanged), tf.int32) + index
9590

91+
if not masked:
92+
whereIdx = tf.where(mask)
93+
srcIndex = tf.cast(whereIdx, tf.int32)
94+
pass
95+
destIndex = index + srcIndex
96+
9697
if self._indices is not None:
97-
indices = tf.reshape(self._indices, (1, -1))
98-
destIndex = tf.reshape(destIndex, (-1, 1))
99-
correspondence = indices == destIndex
100-
tf.assert_rank(correspondence, 2)
101-
mask_ = tf.reduce_any(correspondence, axis=0)
102-
tf.assert_equal(tf.shape(mask_), tf.shape(self._indices))
103-
# collect destination indices
104-
destIndex = tf.where(mask_)
105-
destIndex = tf.cast(destIndex, tf.int32)
106-
# find corresponding source indices
107-
mask_ = tf.reduce_any(correspondence, axis=-1)
108-
tf.assert_equal(tf.shape(mask), tf.shape(srcIndex))
109-
srcIndex = tf.where(mask_)
110-
srcIndex = tf.cast(srcIndex, tf.int32)
111-
N = tf.reduce_sum(tf.cast(mask_, tf.int32))
98+
unchanged = self._index2index(unchanged, axis=0)
99+
srcIndex = self._index2index(destIndex, axis=1)
100+
destIndex = self._index2index(destIndex, axis=0)
112101
pass
113-
114-
srcIndex = tf.reshape(srcIndex, (N, 1))
115-
destIndex = tf.reshape(destIndex, (N, 1))
116-
return srcIndex, destIndex
102+
return srcIndex, destIndex, unchanged
117103

118104
def _move(self, src, srcIndex, dest, destIndex):
119-
tf.print(tf.shape(src), tf.shape(srcIndex), tf.shape(dest), tf.shape(destIndex))
120-
tf.print(srcIndex, destIndex, summarize=-1)
105+
# tensor_scatter_nd_update can't handle empty indices, so we need to check it
106+
if tf.size(srcIndex) == 0: return dest
107+
108+
srcIndex = tf.reshape(srcIndex, (-1, 1))
109+
destIndex = tf.reshape(destIndex, (-1, 1))
121110
src = tf.gather_nd(src, srcIndex) # collect only valid indices
122111
res = tf.tensor_scatter_nd_update(dest, destIndex, src)
123112
dest.assign(res)
124113
return res
114+
115+
def _index2index(self, indices, axis=0):
116+
N = tf.size(indices)
117+
NN = tf.size(self._indices)
118+
indices = tf.reshape(indices, (-1, 1))
119+
indices = tf.cast(indices, tf.int32)
120+
correspondence = self._indices == indices
121+
tf.assert_equal(tf.shape(correspondence), (N, NN))
122+
mask = tf.reduce_any(correspondence, axis=axis)
123+
res = tf.where(mask)
124+
return tf.cast(res, tf.int32)
125125
# End of CSamplerWatcher

tests/test_CSamplerWatcher.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from Utils.utils import CFakeObject
33
from NN.restorators.samplers import sampler_from_config
44
from NN.restorators.samplers.CSamplerWatcher import CSamplerWatcher
5+
from NN.utils import masked
56

67
def _fake_sampler(stochasticity=1.0, timesteps=10):
78
interpolant = sampler_from_config({
@@ -24,7 +25,7 @@ def fakeModel(x, T, **kwargs):
2425
x = tf.random.normal(shape)
2526
return CFakeObject(x=x, model=fakeModel, interpolant=interpolant)
2627

27-
def _fake_AR(threshold, timesteps=10):
28+
def _fake_AR(threshold, timesteps=10, scale=1.0):
2829
interpolant = sampler_from_config({
2930
"name": "autoregressive",
3031
"noise provider": "zero",
@@ -41,8 +42,13 @@ def _fake_AR(threshold, timesteps=10):
4142
shape = (32, 3)
4243
fakeNoise = tf.random.normal(shape)
4344
def fakeModel(x, t, mask, **kwargs):
44-
s = tf.boolean_mask(fakeNoise, mask) if mask is not None else fakeNoise
45-
return s + tf.cast(t, tf.float32) * x
45+
s = fakeNoise
46+
if mask is not None:
47+
s = masked(fakeNoise, mask)
48+
t = masked(t, mask)
49+
x = masked(x, mask)
50+
51+
return s + tf.cast(t, tf.float32) * x * scale
4652

4753
x = tf.random.normal(shape)
4854
return CFakeObject(x=x, model=fakeModel, interpolant=interpolant)
@@ -159,7 +165,7 @@ def test_trackSolutionWithMask_value():
159165

160166
# test multiple calls with index
161167
def test_multipleCallsWithIndex():
162-
fake = _fake_AR(threshold=0.1)
168+
fake = _fake_sampler()
163169
watcher = CSamplerWatcher(
164170
steps=10,
165171
tracked=dict(value=(32*3, 3))
@@ -180,24 +186,42 @@ def test_multipleCallsWithIndex():
180186
tf.debugging.assert_equal(collectedSteps[:, 32:64], collectedSteps[:, 64:])
181187
return
182188

183-
# TODO: find out why this test fails, but previous one passes
184189
# test multiple calls with index and mask
185190
def test_multipleCallsWithIndexAndMask():
186191
fake = _fake_AR(threshold=0.1)
187192
watcher = CSamplerWatcher(
188193
steps=10,
189194
tracked=dict(value=(3,)),
190-
indices=[0, 32]
195+
indices=[0, 32, 64, 65]
191196
)
192197
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
193198
A = fake.interpolant.sample(**arg, index=0)
194199
B = fake.interpolant.sample(**arg, index=32)
195200
C = fake.interpolant.sample(**arg, index=64)
196-
197-
collectedSteps = watcher.tracked('value')
198-
tf.assert_equal(tf.shape(collectedSteps), (11, 2, 3), 'Must be (11, 4, 3)')
199201
tf.assert_equal(A, B)
200202
tf.assert_equal(A, C)
203+
204+
collectedSteps = watcher.tracked('value')[:watcher.iteration]
205+
tf.assert_equal(tf.shape(collectedSteps)[1:], (4, 3), 'Must be (4, 3)')
201206
tf.debugging.assert_equal(collectedSteps[:, 0:1], collectedSteps[:, 1:2])
202-
# tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
207+
tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
208+
return
209+
210+
# test that masked values aren't zeroed
211+
def test_maskedValues():
212+
fake = _fake_AR(threshold=1e+5, scale=0.0)
213+
watcher = CSamplerWatcher(
214+
steps=10,
215+
tracked=dict(value=(32, 3)),
216+
)
217+
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
218+
_ = fake.interpolant.sample(**arg, index=0)
219+
220+
collectedSteps = watcher.tracked('value')
221+
afterMask = collectedSteps[3:]
222+
beforeMask = collectedSteps[2]
223+
224+
tf.debugging.assert_greater(3, watcher.iteration, 'Must collect 3 steps')
225+
for i in range(3, watcher.iteration + 1):
226+
tf.debugging.assert_equal(afterMask[i], beforeMask, 'Must be equal')
203227
return

0 commit comments

Comments
 (0)