Skip to content

Commit c33735f

Browse files
xxxxxxx
1 parent f9d52c5 commit c33735f

File tree

4 files changed

+45
-19
lines changed

4 files changed

+45
-19
lines changed

NN/RestorationModel/CRestorationModel.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import tensorflow as tf
2+
from NN.utils import masked
23

34
'''
45
wrapper which combines the encoders, decoder and restorator together
@@ -56,8 +57,8 @@ def denoiser(x, t, mask=None):
5657
args = dict(condition=latents, coords=EPos, timestep=t, V=x)
5758
residuals = residual
5859
if mask is not None:
59-
args = {k: tf.boolean_mask(v, mask) for k, v in args.items()}
60-
residuals = tf.boolean_mask(residual, mask)
60+
args = {k: masked(v, mask) for k, v in args.items()}
61+
residuals = masked(residual, mask)
6162

6263
res = self._decoder(**args, training=training)
6364
return self._withResidual(res, residuals)

NN/restorators/samplers/CARSampler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _currentValueHandler(threshold):
4545
return lambda step: step.xt
4646

4747
def currentValueF(step, **kwargs):
48-
return tf.boolean_mask(step.xt, step.mask, axis=0)
48+
return NNU.masked(step.xt, step.mask)
4949
return currentValueF
5050

5151
# crate a closure that will be used to postprocess the value

NN/restorators/samplers/CSamplerWatcher.py

+33-16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import tensorflow as tf
22
from NN.utils import is_namedtuple
3+
import NN.utils as NNU
34
from .CSamplingInterceptor import CSamplingInterceptor
45
from .ISamplerWatcher import ISamplerWatcher
56

@@ -44,9 +45,11 @@ def _updateTracked(self, name, value, mask=None, index=None, iteration=None):
4445
if tracked is None: return
4546
src, dest = self._withIndices(
4647
value, index, mask=mask,
47-
masked=(mask is not None) and not ('value' == name)
48+
masked=not ('value' == name)
4849
)
49-
50+
tf.print('-'*80)
51+
tf.print(name, src, dest, summarize=-1)
52+
tf.print('N', tf.reduce_sum(tf.cast(mask, tf.int32)))
5053
self._move(value, src, tracked[iteration], dest)
5154
return
5255

@@ -77,30 +80,44 @@ def _onStart(self, value, kwargs):
7780
return
7881

7982
def _withIndices(self, value, index, mask=None, masked=False):
80-
srcIndex = tf.range(tf.shape(value)[0])
81-
if masked:
82-
srcIndex = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask)
83-
83+
N = tf.shape(value)[0]
84+
srcIndex = tf.range(N)
8485
destIndex = index + srcIndex
85-
if mask is not None:
86-
srcIndex = tf.boolean_mask(srcIndex, mask)
87-
destIndex = tf.boolean_mask(destIndex, mask)
86+
if mask is not None: # use mask
87+
N = tf.reduce_sum(tf.cast(mask, tf.int32))
88+
destIndex = index + tf.cast(tf.where(mask), tf.int32)
89+
destIndex = tf.reshape(destIndex, (N,))
90+
if masked:
91+
rng = tf.range(tf.shape(mask)[0])
92+
srcIndex = NNU.masked(rng, mask)
93+
pass
8894
pass
8995

9096
if self._indices is not None:
91-
mask = tf.reduce_any(self._indices[None] == destIndex[:, None], axis=0)
92-
tf.assert_equal(tf.shape(mask), tf.shape(self._indices))
93-
# collect only valid indices
94-
srcIndex = tf.boolean_mask(self._indices, mask) - index
97+
indices = tf.reshape(self._indices, (1, -1))
98+
destIndex = tf.reshape(destIndex, (-1, 1))
99+
correspondence = indices == destIndex
100+
tf.assert_rank(correspondence, 2)
101+
mask_ = tf.reduce_any(correspondence, axis=0)
102+
tf.assert_equal(tf.shape(mask_), tf.shape(self._indices))
95103
# collect destination indices
96-
destIndex = tf.where(mask)
104+
destIndex = tf.where(mask_)
105+
destIndex = tf.cast(destIndex, tf.int32)
106+
# find corresponding source indices
107+
mask_ = tf.reduce_any(correspondence, axis=-1)
108+
tf.assert_equal(tf.shape(mask), tf.shape(srcIndex))
109+
srcIndex = tf.where(mask_)
110+
srcIndex = tf.cast(srcIndex, tf.int32)
111+
N = tf.reduce_sum(tf.cast(mask_, tf.int32))
97112
pass
98113

99-
srcIndex = tf.reshape(srcIndex, (-1, 1))
100-
destIndex = tf.reshape(destIndex, (-1, 1))
114+
srcIndex = tf.reshape(srcIndex, (N, 1))
115+
destIndex = tf.reshape(destIndex, (N, 1))
101116
return srcIndex, destIndex
102117

103118
def _move(self, src, srcIndex, dest, destIndex):
119+
tf.print(tf.shape(src), tf.shape(srcIndex), tf.shape(dest), tf.shape(destIndex))
120+
tf.print(srcIndex, destIndex, summarize=-1)
104121
src = tf.gather_nd(src, srcIndex) # collect only valid indices
105122
res = tf.tensor_scatter_nd_update(dest, destIndex, src)
106123
dest.assign(res)

NN/utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
import tensorflow_probability as tfp
44
import tensorflow.keras.layers as L
55

6+
def masked(x, mask):
7+
'''
8+
very weird hack to apply a mask to the tensor and ensure that the shape is preserved
9+
'''
10+
N = tf.reduce_sum(tf.cast(mask, tf.int32))
11+
x = tf.boolean_mask(x, mask, axis=0)
12+
return tf.reshape(x, tf.concat([[N], tf.shape(x)[1:]], axis=0))
13+
614
def shuffleBatch(batch):
715
indices = tf.range(tf.shape(batch)[0])
816
indices = tf.random.shuffle(indices)

0 commit comments

Comments
 (0)