|
1 | 1 | import tensorflow as tf
|
2 | 2 | from NN.utils import is_namedtuple
|
| 3 | +import NN.utils as NNU |
3 | 4 | from .CSamplingInterceptor import CSamplingInterceptor
|
4 | 5 | from .ISamplerWatcher import ISamplerWatcher
|
5 | 6 |
|
@@ -44,9 +45,11 @@ def _updateTracked(self, name, value, mask=None, index=None, iteration=None):
|
44 | 45 | if tracked is None: return
|
45 | 46 | src, dest = self._withIndices(
|
46 | 47 | value, index, mask=mask,
|
47 |
| - masked=(mask is not None) and not ('value' == name) |
| 48 | + masked=not ('value' == name) |
48 | 49 | )
|
49 |
| - |
| 50 | + tf.print('-'*80) |
| 51 | + tf.print(name, src, dest, summarize=-1) |
| 52 | + tf.print('N', tf.reduce_sum(tf.cast(mask, tf.int32))) |
50 | 53 | self._move(value, src, tracked[iteration], dest)
|
51 | 54 | return
|
52 | 55 |
|
@@ -77,30 +80,44 @@ def _onStart(self, value, kwargs):
|
77 | 80 | return
|
78 | 81 |
|
79 | 82 | def _withIndices(self, value, index, mask=None, masked=False):
|
80 |
| - srcIndex = tf.range(tf.shape(value)[0]) |
81 |
| - if masked: |
82 |
| - srcIndex = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask) |
83 |
| - |
| 83 | + N = tf.shape(value)[0] |
| 84 | + srcIndex = tf.range(N) |
84 | 85 | destIndex = index + srcIndex
|
85 |
| - if mask is not None: |
86 |
| - srcIndex = tf.boolean_mask(srcIndex, mask) |
87 |
| - destIndex = tf.boolean_mask(destIndex, mask) |
| 86 | + if mask is not None: # use mask |
| 87 | + N = tf.reduce_sum(tf.cast(mask, tf.int32)) |
| 88 | + destIndex = index + tf.cast(tf.where(mask), tf.int32) |
| 89 | + destIndex = tf.reshape(destIndex, (N,)) |
| 90 | + if masked: |
| 91 | + rng = tf.range(tf.shape(mask)[0]) |
| 92 | + srcIndex = NNU.masked(rng, mask) |
| 93 | + pass |
88 | 94 | pass
|
89 | 95 |
|
90 | 96 | if self._indices is not None:
|
91 |
| - mask = tf.reduce_any(self._indices[None] == destIndex[:, None], axis=0) |
92 |
| - tf.assert_equal(tf.shape(mask), tf.shape(self._indices)) |
93 |
| - # collect only valid indices |
94 |
| - srcIndex = tf.boolean_mask(self._indices, mask) - index |
| 97 | + indices = tf.reshape(self._indices, (1, -1)) |
| 98 | + destIndex = tf.reshape(destIndex, (-1, 1)) |
| 99 | + correspondence = indices == destIndex |
| 100 | + tf.assert_rank(correspondence, 2) |
| 101 | + mask_ = tf.reduce_any(correspondence, axis=0) |
| 102 | + tf.assert_equal(tf.shape(mask_), tf.shape(self._indices)) |
95 | 103 | # collect destination indices
|
96 |
| - destIndex = tf.where(mask) |
| 104 | + destIndex = tf.where(mask_) |
| 105 | + destIndex = tf.cast(destIndex, tf.int32) |
| 106 | + # find corresponding source indices |
| 107 | + mask_ = tf.reduce_any(correspondence, axis=-1) |
| 108 | + tf.assert_equal(tf.shape(mask), tf.shape(srcIndex)) |
| 109 | + srcIndex = tf.where(mask_) |
| 110 | + srcIndex = tf.cast(srcIndex, tf.int32) |
| 111 | + N = tf.reduce_sum(tf.cast(mask_, tf.int32)) |
97 | 112 | pass
|
98 | 113 |
|
99 |
| - srcIndex = tf.reshape(srcIndex, (-1, 1)) |
100 |
| - destIndex = tf.reshape(destIndex, (-1, 1)) |
| 114 | + srcIndex = tf.reshape(srcIndex, (N, 1)) |
| 115 | + destIndex = tf.reshape(destIndex, (N, 1)) |
101 | 116 | return srcIndex, destIndex
|
102 | 117 |
|
103 | 118 | def _move(self, src, srcIndex, dest, destIndex):
|
| 119 | + tf.print(tf.shape(src), tf.shape(srcIndex), tf.shape(dest), tf.shape(destIndex)) |
| 120 | + tf.print(srcIndex, destIndex, summarize=-1) |
104 | 121 | src = tf.gather_nd(src, srcIndex) # collect only valid indices
|
105 | 122 | res = tf.tensor_scatter_nd_update(dest, destIndex, src)
|
106 | 123 | dest.assign(res)
|
|
0 commit comments