import os
import numpy as np
import NN.networks as networks
import tensorflow as tf
import tensorflow_probability as tfp
import NN.Utils as NNU
import time
from tensorflow.keras import layers as L

# TODO: Implement the standard diffusion process (with the prediction of the noise, proper sampling, etc.)
class CModelDiffusion:
  '''
  Wrapper for a diffusion model that predicts the gaze point.

  The diffusion step T is mapped to the stddev of the Gaussian noise: during
  training the target points are noised with the stddev of a sampled step, and
  during inference the model iteratively denoises its own output, starting
  from the largest step.
  '''
  def __init__(self, timesteps, model='simple', user=None, stats=None, use_encoders=False, **kwargs):
    if user is None:
      user = {
        'userId': 0,
        'placeId': 0,
        'screenId': 0,
      }
    else:
      # map the raw ids onto their indices in the stats lists
      user = {
        'userId': stats['userId'].index(user['userId']),
        'placeId': stats['placeId'].index(user['placeId']),
        'screenId': stats['screenId'].index(user['screenId']),
      }
    self._user = user

    self._modelID = model
    self._timesteps = timesteps
    embeddings = {
      'userId': len(stats['userId']),
      'placeId': len(stats['placeId']),
      'screenId': len(stats['screenId']),
      'size': 64,
    }
    self._modelRaw = networks.Face2LatentModel(
      steps=timesteps, latentSize=64, embeddings=embeddings,
      diffusion=True
    )
    self._model = self._modelRaw['main']
    self._embeddings = {
      'userId': L.Embedding(len(stats['userId']), embeddings['size']),
      'placeId': L.Embedding(len(stats['placeId']), embeddings['size']),
      'screenId': L.Embedding(len(stats['screenId']), embeddings['size']),
    }
    self._intermediateEncoders = {}
    if use_encoders:
      shapes = self._modelRaw['intermediate shapes']
      for name, shape in shapes.items():
        enc = networks.IntermediatePredictor(name='%s-encoder' % name)
        enc.build(shape)
        self._intermediateEncoders[name] = enc
        continue

    self._maxDiffusionT = 100.0
    if 'weights' in kwargs:
      self.load(**kwargs['weights'])
    self.compile()
    # add signatures to help tensorflow optimize the graph
    specification = self._modelRaw['inputs specification']
    self._trainStep = tf.function(
      self._trainStep,
      input_signature=[
        (
          {'clean': specification, 'augmented': specification, },
          (tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
        )
      ]
    )
    self._eval = tf.function(
      self._eval,
      input_signature=[(
        specification,
        (tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
      )]
    )
    return

  def _step2mean(self, step):
    # map a discrete diffusion step onto the noise stddev, clipped to [1e-3, 1.0]
    step = tf.cast(step, tf.float32) / self._maxDiffusionT
    step = tf.cast(step, tf.float32) + 1e-6
    # step = tf.pow(step, 2.0) # make it decrease faster
    return tf.clip_by_value(step, 1e-3, 1.0)

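  # For reference, an illustrative calculation (assuming _maxDiffusionT == 100.0):
  #   step 0   -> 0.0 + 1e-6, clipped up to 1e-3 (the smallest noise level)
  #   step 50  -> 0.5 + 1e-6 ~= 0.5
  #   step 100 -> 1.0 + 1e-6, clipped down to 1.0 (the largest noise level)
  # i.e. the stddev grows linearly with the diffusion step.
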
  def _replaceByEmbeddings(self, data):
    data = dict(**data) # copy
    for name, emb in self._embeddings.items():
      # replace the integer id (..., 1) by its embedding vector (..., size)
      data[name] = emb(data[name][..., 0])
      continue
    return data

  def _makeGaussian(self, mean, stddev):
    # stddev has a single channel; duplicate it for the x and y coordinates
    stddev = tf.concat([stddev, stddev], axis=-1)
    return tfp.distributions.MultivariateNormalDiag(mean, stddev)

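  # A minimal usage sketch (shapes are illustrative only): for predicted means
  # of shape (B, N, 2) and per-point stddevs of shape (B, N, 1), the resulting
  # distribution treats x and y as independent Gaussians, so e.g.
  #   gaussian = self._makeGaussian(mean, stddev)
  #   nll = -gaussian.log_prob(target)  # (B, N) negative log-likelihood
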
  @tf.function
  def _infer(self, data, training=False):
    print('Instantiate _infer')
    data = self._replaceByEmbeddings(data)
    shp = tf.shape(data['userId'])
    B, N = shp[0], self.timesteps
    result = tf.zeros((B, N, 2), dtype=tf.float32)
    # iteratively denoise, starting from the largest diffusion step
    # (float bounds to match the dtype of self._maxDiffusionT)
    for step in tf.range(self._maxDiffusionT, -1.0, -5.0):
      mean = self._step2mean(
        tf.fill((B, N, 1), step)
      )
      stepData = dict(**data)
      stepData['diffusionT'] = mean
      # noise the current estimate with the stddev of the current step
      stepData['diffusionPoints'] = tf.random.normal((B, N, 2), mean=result, stddev=mean)
      result = self._model(stepData, training=training)['result']
    return result

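  # Note: this sampler is a simplification, not the standard DDPM procedure
  # (see the TODO above): rather than predicting the noise and applying the
  # posterior update, the model directly re-predicts the clean points at each
  # step, walking the 100 steps down with a stride of 5.
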
  def predict(self, data, **kwargs):
    N = self._timesteps
    userId = kwargs.get('userId', self._user['userId'])
    placeId = kwargs.get('placeId', self._user['placeId'])
    screenId = kwargs.get('screenId', self._user['screenId'])
    # put them as (1, N, 1)
    data['userId'] = np.full((1, N, 1), userId, dtype=np.int32)
    data['placeId'] = np.full((1, N, 1), placeId, dtype=np.int32)
    data['screenId'] = np.full((1, N, 1), screenId, dtype=np.int32)

    # NOTE: _infer replaces the ids by embeddings itself, so doing it here as
    # well would apply the embedding layers twice
    result = self._infer(data)
    return result.numpy()

  def __call__(self, data, startPos=None):
    predictions = self.predict(data)
    return {
      'coords': predictions[0, -1, :],
    }

  def compile(self):
    self._optimizer = NNU.createOptimizer()
    return

  def _modelFilename(self, folder, postfix=''):
    postfix = '-' + postfix if postfix else ''
    return os.path.join(folder, '%s-%s%s.h5' % (self._modelID, 'model', postfix))

  def save(self, folder=None, postfix=''):
    path = self._modelFilename(folder, postfix)
    self._model.save_weights(path)
    embeddings = {}
    for nm in self._embeddings.keys():
      weights = self._embeddings[nm].get_weights()[0]
      embeddings[nm] = weights
      continue
    np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
    # save intermediate encoders
    if self._intermediateEncoders:
      encoders = {}
      for nm, encoder in self._intermediateEncoders.items():
        # save each variable separately
        for ww in encoder.trainable_variables:
          encoders['%s-%s' % (nm, ww.name)] = ww.numpy()
        continue
      np.savez_compressed(path.replace('.h5', '-intermediate-encoders.npz'), **encoders)
    return

  def load(self, folder=None, postfix='', embeddings=False):
    path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
    self._model.load_weights(path)
    if embeddings:
      embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
      for nm, emb in self._embeddings.items():
        w = embeddings[nm]
        if not emb.built: emb.build((None, w.shape[0]))
        emb.set_weights([w]) # replace embeddings
        continue

    if self._intermediateEncoders:
      encodersName = path.replace('.h5', '-intermediate-encoders.npz')
      if os.path.isfile(encodersName):
        encoders = np.load(encodersName)
        for nm, encoder in self._intermediateEncoders.items():
          for ww in encoder.trainable_variables:
            w = encoders['%s-%s' % (nm, ww.name)]
            ww.assign(w)
          continue
    return

  def lock(self, isLocked):
    self._model.trainable = not isLocked
    return

  @property
  def timesteps(self):
    return self._timesteps

  def trainable_variables(self):
    parts = list(self._embeddings.values()) + [self._model] + list(self._intermediateEncoders.values())
    return sum([p.trainable_variables for p in parts], [])

  def _pointLoss(self, ytrue, ypred):
    # pseudo-Huber loss
    delta = 0.01
    tf.assert_equal(tf.shape(ytrue), tf.shape(ypred))
    diff = tf.square(ytrue - ypred)
    loss = tf.sqrt(diff + delta ** 2) - delta
    tf.assert_equal(tf.shape(loss), tf.shape(ytrue))
    return tf.reduce_mean(loss, axis=-1)

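  # The per-coordinate loss above is sqrt(err^2 + delta^2) - delta, which
  # behaves like err^2 / (2 * delta) for |err| << delta and like |err| for
  # |err| >> delta, i.e. a smooth L1 (pseudo-Huber up to a 1/delta scale).
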
  def _trainStep(self, Data):
    print('Instantiate _trainStep')
    ###############
    x, (y, ) = Data
    y = y[..., 0, :]
    losses = {}
    with tf.GradientTape() as tape:
      data = x['augmented']
      data = self._replaceByEmbeddings(data)
      # add sampled T
      B = tf.shape(y)[0]
      N = self.timesteps
      maxT = 100
      diffusionT = tf.random.uniform((B, 1), minval=0, maxval=maxT, dtype=tf.int32)
      # (B, 1) -> (B, N, 1)
      diffusionT = tf.tile(diffusionT, (1, N))[..., None]
      diffusionT = self._step2mean(diffusionT)
      tf.assert_equal(tf.shape(diffusionT), (B, N, 1))

      # store the diffusion parameters
      data['diffusionT'] = diffusionT
      # sample the points: noise the targets with the stddev of the sampled step
      data['diffusionPoints'] = tf.random.normal((B, N, 2), mean=y, stddev=diffusionT)
      predictions = self._model(data, training=True)
      # intermediate = predictions['intermediate']
      # assert len(intermediate) == 0, 'Intermediate predictions are not supported'

      predictedMean = predictions['result']
      gaussian = self._makeGaussian(predictedMean, diffusionT)
      losses['log_prob'] = tf.reduce_mean(
        -gaussian.log_prob(y)
      )
      # reduce to a scalar so both loss terms are weighted consistently
      losses['points'] = tf.reduce_mean(self._pointLoss(y, predictedMean))
      loss = sum(losses.values())
      losses['loss'] = loss

    self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
    ###############
    return losses

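  # In other words, a single train step minimizes (a sketch of the objective,
  # using the names above):
  #   L = mean(-log N(y; predictedMean, diag(T^2))) + mean(pointLoss(y, predictedMean))
  # where T is the per-sample noise stddev produced by _step2mean, so the
  # model learns to recover the clean gaze points from their noised versions.
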
  def fit(self, data):
    t = time.time()
    losses = self._trainStep(data)
    losses = {k: v.numpy() for k, v in losses.items()}
    return {'time': int((time.time() - t) * 1000), 'losses': losses}

  def _eval(self, xy):
    print('Instantiate _eval')
    x, (y, ) = xy
    y = y[:, :, 0]
    B, N = tf.shape(y)[0], tf.shape(y)[1]

    predictions = self._infer(x)

    # score the final predictions at the smallest noise level (step 0)
    mean = self._step2mean(tf.fill((B, N, 1), 0))
    gaussian = self._makeGaussian(predictions, mean)
    loss = tf.nn.sigmoid(-gaussian.log_prob(y))
    points = predictions
    _, dist = NNU.normVec(y - predictions)
    return loss, points, dist

  def eval(self, data):
    loss, sampled, dist = self._eval(data)
    return loss.numpy(), sampled.numpy(), dist.numpy()