7
7
class CSamplerWatcher (ISamplerWatcher ):
8
8
def __init__ (self , steps , tracked , indices = None ):
9
9
super ().__init__ ()
10
- self ._indices = tf .constant (indices , dtype = tf .int32 ) if not (indices is None ) else None
11
10
self ._tracked = {}
12
11
prefix = [steps ]
13
- if not (self ._indices is None ):
12
+ self ._indices = None
13
+ if not (indices is None ):
14
+ self ._indices = tf .reshape (tf .constant (indices , dtype = tf .int32 ), (1 , - 1 ))
14
15
prefix = [steps , tf .size (self ._indices )]
15
16
16
17
for name , shape in tracked .items ():
@@ -20,6 +21,7 @@ def __init__(self, steps, tracked, indices=None):
20
21
21
22
if 'value' in self ._tracked : # value has steps + 1 shape, so we need extra variable
22
23
shp = prefix + list (tracked ['value' ])
24
+ print ('Initial value shape:' , shp )
23
25
self ._initialValue = tf .Variable (tf .zeros (shp [1 :]), trainable = False )
24
26
pass
25
27
@@ -43,13 +45,13 @@ def tracked(self, name):
43
45
def _updateTracked (self , name , value , mask = None , index = None , iteration = None ):
44
46
tracked = self ._tracked .get (name , None )
45
47
if tracked is None : return
46
- src , dest = self ._withIndices (
48
+ src , dest , unchangedIdx = self ._withIndices (
47
49
value , index , mask = mask ,
48
50
masked = not ('value' == name )
49
51
)
50
- tf . print ( '-' * 80 )
51
- tf . print ( name , src , dest , summarize = - 1 )
52
- tf .print ('N' , tf . reduce_sum ( tf . cast ( mask , tf . int32 )) )
52
+ prev = self . tracked ( name )[ iteration - 1 ]
53
+ self . _move ( prev , unchangedIdx , tracked [ iteration ], unchangedIdx )
54
+ tf .print (unchangedIdx )
53
55
self ._move (value , src , tracked [iteration ], dest )
54
56
return
55
57
@@ -75,51 +77,49 @@ def _onStart(self, value, kwargs):
75
77
index = kwargs ['index' ]
76
78
self ._iteration .assign (0 )
77
79
if 'value' in self ._tracked : # save initial value
78
- src , dest = self ._withIndices (value , index )
80
+ src , dest , _ = self ._withIndices (value , index )
79
81
self ._move (value , src , self ._initialValue , dest )
80
82
return
81
83
82
84
def _withIndices (self , value , index , mask = None , masked = False ):
83
- N = tf .shape (value )[0 ]
84
- srcIndex = tf .range (N )
85
- destIndex = index + srcIndex
85
+ unchanged = tf .constant ([], dtype = tf .int32 )
86
+ srcIndex = tf .range (tf .shape (value )[0 ])
86
87
if mask is not None : # use mask
87
- N = tf .reduce_sum (tf .cast (mask , tf .int32 ))
88
- destIndex = index + tf .cast (tf .where (mask ), tf .int32 )
89
- destIndex = tf .reshape (destIndex , (N ,))
90
- if masked :
91
- rng = tf .range (tf .shape (mask )[0 ])
92
- srcIndex = NNU .masked (rng , mask )
93
- pass
94
- pass
88
+ unchanged = tf .logical_not (mask )
89
+ unchanged = tf .cast (tf .where (unchanged ), tf .int32 ) + index
95
90
91
+ if not masked :
92
+ whereIdx = tf .where (mask )
93
+ srcIndex = tf .cast (whereIdx , tf .int32 )
94
+ pass
95
+ destIndex = index + srcIndex
96
+
96
97
if self ._indices is not None :
97
- indices = tf .reshape (self ._indices , (1 , - 1 ))
98
- destIndex = tf .reshape (destIndex , (- 1 , 1 ))
99
- correspondence = indices == destIndex
100
- tf .assert_rank (correspondence , 2 )
101
- mask_ = tf .reduce_any (correspondence , axis = 0 )
102
- tf .assert_equal (tf .shape (mask_ ), tf .shape (self ._indices ))
103
- # collect destination indices
104
- destIndex = tf .where (mask_ )
105
- destIndex = tf .cast (destIndex , tf .int32 )
106
- # find corresponding source indices
107
- mask_ = tf .reduce_any (correspondence , axis = - 1 )
108
- tf .assert_equal (tf .shape (mask ), tf .shape (srcIndex ))
109
- srcIndex = tf .where (mask_ )
110
- srcIndex = tf .cast (srcIndex , tf .int32 )
111
- N = tf .reduce_sum (tf .cast (mask_ , tf .int32 ))
98
+ unchanged = self ._index2index (unchanged , axis = 0 )
99
+ srcIndex = self ._index2index (destIndex , axis = 1 )
100
+ destIndex = self ._index2index (destIndex , axis = 0 )
112
101
pass
113
-
114
- srcIndex = tf .reshape (srcIndex , (N , 1 ))
115
- destIndex = tf .reshape (destIndex , (N , 1 ))
116
- return srcIndex , destIndex
102
+ return srcIndex , destIndex , unchanged
117
103
118
104
def _move (self , src , srcIndex , dest , destIndex ):
119
- tf .print (tf .shape (src ), tf .shape (srcIndex ), tf .shape (dest ), tf .shape (destIndex ))
120
- tf .print (srcIndex , destIndex , summarize = - 1 )
105
+ # tensor_scatter_nd_update can't handle empty indices, so we need to check it
106
+ if tf .size (srcIndex ) == 0 : return dest
107
+
108
+ srcIndex = tf .reshape (srcIndex , (- 1 , 1 ))
109
+ destIndex = tf .reshape (destIndex , (- 1 , 1 ))
121
110
src = tf .gather_nd (src , srcIndex ) # collect only valid indices
122
111
res = tf .tensor_scatter_nd_update (dest , destIndex , src )
123
112
dest .assign (res )
124
113
return res
114
+
115
+ def _index2index (self , indices , axis = 0 ):
116
+ N = tf .size (indices )
117
+ NN = tf .size (self ._indices )
118
+ indices = tf .reshape (indices , (- 1 , 1 ))
119
+ indices = tf .cast (indices , tf .int32 )
120
+ correspondence = self ._indices == indices
121
+ tf .assert_equal (tf .shape (correspondence ), (N , NN ))
122
+ mask = tf .reduce_any (correspondence , axis = axis )
123
+ res = tf .where (mask )
124
+ return tf .cast (res , tf .int32 )
125
125
# End of CSamplerWatcher
0 commit comments