@@ -39,23 +39,15 @@ def tracked(self, name):
39
39
res = tf .concat ([self ._initialValue [None ], res ], axis = 0 )
40
40
return res
41
41
42
- def _updateTracked (self , name , value , mask = None , index = None ):
42
+ def _updateTracked (self , name , value , mask = None , index = None , iteration = None ):
43
43
tracked = self ._tracked .get (name , None )
44
44
if tracked is None : return
45
- value , idx = self ._withIndices (value , index , mask = mask )
45
+ src , dest = self ._withIndices (
46
+ value , index , mask = mask ,
47
+ masked = (mask is not None ) and not ('value' == name )
48
+ )
46
49
47
- iteration = self ._iteration
48
- if (mask is not None ) and not ('value' == name ): # 'value' is always unmasked
49
- mask , _ = self ._withIndices (mask , index , mask = mask )
50
- prev = self ._tracked [name ][iteration ]
51
- # expand mask to match the value shape by copying values from the previous iteration
52
- indices = tf .where (mask )
53
- sz = tf .shape (value )[0 ]
54
- value = tf .tensor_scatter_nd_update (prev [idx :idx + sz ], indices , value )
55
- pass
56
-
57
- sz = tf .shape (value )[0 ]
58
- tracked [iteration , idx :idx + sz ].assign (value )
50
+ self ._move (value , src , tracked [iteration ], dest )
59
51
return
60
52
61
53
def _onNextStep (self , iteration , kwargs ):
@@ -69,36 +61,48 @@ def _onNextStep(self, iteration, kwargs):
69
61
mask = step .mask if hasattr (step , 'mask' ) else None
70
62
# iterate over all fields
71
63
for name in solution ._fields :
72
- self ._updateTracked (name , getattr (solution , name ), mask = mask , index = index )
64
+ self ._updateTracked (
65
+ name , getattr (solution , name ),
66
+ mask = mask , index = index , iteration = iteration
67
+ )
73
68
continue
74
69
return
75
70
76
71
def _onStart (self , value , kwargs ):
77
72
index = kwargs ['index' ]
78
73
self ._iteration .assign (0 )
79
74
if 'value' in self ._tracked : # save initial value
80
- value , idx = self ._withIndices (value , index )
81
- # update slice [index:index+sz] with the value
82
- sz = tf .shape (value )[0 ]
83
- self ._initialValue [idx :idx + sz ].assign (value )
75
+ src , dest = self ._withIndices (value , index )
76
+ self ._move (value , src , self ._initialValue , dest )
84
77
return
85
78
86
- def _withIndices (self , value , index , mask = None ):
87
- if self . _indices is None : return value , index
88
- # find subset of indices
89
- sz = tf .shape (value )[0 ]
79
+ def _withIndices (self , value , index , mask = None , masked = False ):
80
+ srcIndex = tf . range ( tf . shape ( value )[ 0 ])
81
+ if masked :
82
+ srcIndex = tf .boolean_mask ( tf . range ( tf . shape (mask )[0 ]), mask )
90
83
91
- validMask = tf . logical_and ( index <= self . _indices , self . _indices < index + sz )
84
+ destIndex = index + srcIndex
92
85
if mask is not None :
93
- maskedIndices = tf .range (sz )
94
- maskedIndices = tf .boolean_mask (maskedIndices , mask ) + index
95
- # exclude masked indices
96
- maskedIndices = tf .reduce_any (maskedIndices [:, None ] == self ._indices [None ], axis = 0 )
97
- validMask = tf .logical_and (validMask , maskedIndices )
86
+ srcIndex = tf .boolean_mask (srcIndex , mask )
87
+ destIndex = tf .boolean_mask (destIndex , mask )
98
88
pass
99
-
100
- startIndex = tf .reduce_min (tf .where (validMask ))
101
- startIndex = tf .cast (startIndex , tf .int32 )
102
- indices = tf .boolean_mask (self ._indices , validMask ) - index
103
- return tf .gather (value , indices , axis = 0 ), startIndex
89
+
90
+ if self ._indices is not None :
91
+ mask = tf .reduce_any (self ._indices [None ] == destIndex [:, None ], axis = 0 )
92
+ tf .assert_equal (tf .shape (mask ), tf .shape (self ._indices ))
93
+ # collect only valid indices
94
+ srcIndex = tf .boolean_mask (self ._indices , mask ) - index
95
+ # collect destination indices
96
+ destIndex = tf .where (mask )
97
+ pass
98
+
99
+ srcIndex = tf .reshape (srcIndex , (- 1 , 1 ))
100
+ destIndex = tf .reshape (destIndex , (- 1 , 1 ))
101
+ return srcIndex , destIndex
102
+
103
+ def _move (self , src , srcIndex , dest , destIndex ):
104
+ src = tf .gather_nd (src , srcIndex ) # collect only valid indices
105
+ res = tf .tensor_scatter_nd_update (dest , destIndex , src )
106
+ dest .assign (res )
107
+ return res
104
108
# End of CSamplerWatcher
0 commit comments