
Commit 6d66496

simplified diffusion model
1 parent 84b5493 commit 6d66496

3 files changed: +304 -6 lines changed


Core/CModelDiffusion.py

Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
import os
import numpy as np
import NN.networks as networks
import tensorflow as tf
import tensorflow_probability as tfp
import NN.Utils as NNU
import time
from tensorflow.keras import layers as L

# TODO: Implement the standard diffusion process (with prediction of the noise, proper sampling, etc.)
class CModelDiffusion:
  '''
  Wrapper around the diffusion model that predicts the gaze point.
  The diffusion step T is mapped directly to the stddev of the Gaussian noise.
  '''
  def __init__(self, timesteps, model='simple', user=None, stats=None, use_encoders=False, **kwargs):
    if user is None:
      user = {
        'userId': 0,
        'placeId': 0,
        'screenId': 0,
      }
    else:
      user = {
        'userId': stats['userId'].index(user['userId']),
        'placeId': stats['placeId'].index(user['placeId']),
        'screenId': stats['screenId'].index(user['screenId']),
      }
    self._user = user

    self._modelID = model
    self._timesteps = timesteps
    embeddings = {
      'userId': len(stats['userId']),
      'placeId': len(stats['placeId']),
      'screenId': len(stats['screenId']),
      'size': 64,
    }
    self._modelRaw = networks.Face2LatentModel(
      steps=timesteps, latentSize=64, embeddings=embeddings,
      diffusion=True
    )
    self._model = self._modelRaw['main']
    self._embeddings = {
      'userId': L.Embedding(len(stats['userId']), embeddings['size']),
      'placeId': L.Embedding(len(stats['placeId']), embeddings['size']),
      'screenId': L.Embedding(len(stats['screenId']), embeddings['size']),
    }
    self._intermediateEncoders = {}
    if use_encoders:
      shapes = self._modelRaw['intermediate shapes']
      for name, shape in shapes.items():
        enc = networks.IntermediatePredictor(name='%s-encoder' % name)
        enc.build(shape)
        self._intermediateEncoders[name] = enc
        continue

    self._maxDiffusionT = 100.0
    if 'weights' in kwargs:
      self.load(**kwargs['weights'])
    self.compile()
    # add signatures to help TensorFlow optimize the graph
    specification = self._modelRaw['inputs specification']
    self._trainStep = tf.function(
      self._trainStep,
      input_signature=[
        (
          { 'clean': specification, 'augmented': specification, },
          ( tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
        )
      ]
    )
    self._eval = tf.function(
      self._eval,
      input_signature=[(
        specification,
        ( tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
      )]
    )
    return

  def _step2mean(self, step):
    # map a discrete diffusion step to the stddev of the Gaussian noise
    step = tf.cast(step, tf.float32) / self._maxDiffusionT
    step = tf.cast(step, tf.float32) + 1e-6
    # step = tf.pow(step, 2.0) # make it decrease faster
    return tf.clip_by_value(step, 1e-3, 1.0)

  def _replaceByEmbeddings(self, data):
    data = dict(**data) # copy
    for name, emb in self._embeddings.items():
      data[name] = emb(data[name][..., 0])
      continue
    return data

  def _makeGaussian(self, mean, stddev):
    stddev = tf.concat([stddev, stddev], axis=-1)
    return tfp.distributions.MultivariateNormalDiag(mean, stddev)

  @tf.function
  def _infer(self, data, training=False):
    print('Instantiate _infer')
    data = self._replaceByEmbeddings(data)
    shp = tf.shape(data['userId'])
    B, N = shp[0], self.timesteps
    result = tf.zeros((B, N, 2), dtype=tf.float32)
    for step in tf.range(self._maxDiffusionT, -1, -5):
      mean = self._step2mean(
        tf.fill((B, N, 1), step)
      )
      stepData = dict(**data)
      stepData['diffusionT'] = mean
      # re-noise the current estimate and ask the model for a cleaner one
      stepData['diffusionPoints'] = tf.random.normal((B, N, 2), mean=result, stddev=mean)
      result = self._model(stepData, training=training)['result']
    return result

  def predict(self, data, **kwargs):
    B = self._timesteps
    userId = kwargs.get('userId', self._user['userId'])
    placeId = kwargs.get('placeId', self._user['placeId'])
    screenId = kwargs.get('screenId', self._user['screenId'])
    # put them as (1, B, 1)
    data['userId'] = np.full((1, B, 1), userId, dtype=np.int32)
    data['placeId'] = np.full((1, B, 1), placeId, dtype=np.int32)
    data['screenId'] = np.full((1, B, 1), screenId, dtype=np.int32)

    # NOTE: _infer performs the embedding lookup itself, so the raw ids are passed as-is
    result = self._infer(data)
    return result.numpy()

  def __call__(self, data, startPos=None):
    predictions = self.predict(data)
    return {
      'coords': predictions[0, -1, :],
    }

  def compile(self):
    self._optimizer = NNU.createOptimizer()
    return

  def _modelFilename(self, folder, postfix=''):
    postfix = '-' + postfix if postfix else ''
    return os.path.join(folder, '%s-%s%s.h5' % (self._modelID, 'model', postfix))

  def save(self, folder=None, postfix=''):
    path = self._modelFilename(folder, postfix)
    self._model.save_weights(path)
    embeddings = {}
    for nm in self._embeddings.keys():
      weights = self._embeddings[nm].get_weights()[0]
      embeddings[nm] = weights
      continue
    np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
    # save intermediate encoders
    if self._intermediateEncoders:
      encoders = {}
      for nm, encoder in self._intermediateEncoders.items():
        # save each variable separately
        for ww in encoder.trainable_variables:
          encoders['%s-%s' % (nm, ww.name)] = ww.numpy()
        continue
      np.savez_compressed(path.replace('.h5', '-intermediate-encoders.npz'), **encoders)
    return

  def load(self, folder=None, postfix='', embeddings=False):
    path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
    self._model.load_weights(path)
    if embeddings:
      embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
      for nm, emb in self._embeddings.items():
        w = embeddings[nm]
        if not emb.built: emb.build((None, w.shape[0]))
        emb.set_weights([w]) # replace embeddings
        continue

    if self._intermediateEncoders:
      encodersName = path.replace('.h5', '-intermediate-encoders.npz')
      if os.path.isfile(encodersName):
        encoders = np.load(encodersName)
        for nm, encoder in self._intermediateEncoders.items():
          for ww in encoder.trainable_variables:
            w = encoders['%s-%s' % (nm, ww.name)]
            ww.assign(w)
          continue
    return

  def lock(self, isLocked):
    self._model.trainable = not isLocked
    return

  @property
  def timesteps(self):
    return self._timesteps

  def trainable_variables(self):
    parts = list(self._embeddings.values()) + [self._model] + list(self._intermediateEncoders.values())
    return sum([p.trainable_variables for p in parts], [])

  def _pointLoss(self, ytrue, ypred):
    # pseudo Huber loss
    delta = 0.01
    tf.assert_equal(tf.shape(ytrue), tf.shape(ypred))
    diff = tf.square(ytrue - ypred)
    loss = tf.sqrt(diff + delta ** 2) - delta
    tf.assert_equal(tf.shape(loss), tf.shape(ytrue))
    return tf.reduce_mean(loss, axis=-1)

  def _trainStep(self, Data):
    print('Instantiate _trainStep')
    ###############
    x, (y, ) = Data
    y = y[..., 0, :]
    losses = {}
    with tf.GradientTape() as tape:
      data = x['augmented']
      data = self._replaceByEmbeddings(data)
      # add sampled T
      B = tf.shape(y)[0]
      N = self.timesteps
      maxT = 100
      diffusionT = tf.random.uniform((B, 1), minval=0, maxval=maxT, dtype=tf.int32)
      # (B, 1) -> (B, N, 1)
      diffusionT = tf.tile(diffusionT, (1, N))[..., None]
      diffusionT = self._step2mean(diffusionT)
      tf.assert_equal(tf.shape(diffusionT), (B, N, 1))

      # store the diffusion parameters
      data['diffusionT'] = diffusionT
      # sample the noisy points around the ground-truth targets
      data['diffusionPoints'] = tf.random.normal((B, N, 2), mean=y, stddev=diffusionT)
      predictions = self._model(data, training=True)
      # intermediate = predictions['intermediate']
      # assert len(intermediate) == 0, 'Intermediate predictions are not supported'

      predictedMean = predictions['result']
      gaussian = self._makeGaussian(predictedMean, diffusionT)
      losses['log_prob'] = tf.reduce_mean(
        -gaussian.log_prob(y)
      )
      losses['points'] = self._pointLoss(y, predictedMean)
      loss = sum(losses.values())
      losses['loss'] = loss

    self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
    ###############
    return losses

  def fit(self, data):
    t = time.time()
    losses = self._trainStep(data)
    losses = {k: v.numpy() for k, v in losses.items()}
    return {'time': int((time.time() - t) * 1000), 'losses': losses}

  def _eval(self, xy):
    print('Instantiate _eval')
    x, (y,) = xy
    y = y[:, :, 0]
    B, N = tf.shape(y)[0], tf.shape(y)[1]

    predictions = self._infer(x)

    mean = self._step2mean(tf.fill((B, N, 1), 0))
    gaussian = self._makeGaussian(predictions, mean)
    loss = tf.nn.sigmoid( -gaussian.log_prob(y) )
    points = predictions
    _, dist = NNU.normVec(y - predictions)
    return loss, points, dist

  def eval(self, data):
    loss, sampled, dist = self._eval(data)
    return loss.numpy(), sampled.numpy(), dist.numpy()
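
Not part of the commit: a rough standalone sketch of the simplified scheme above. The reverse process in _infer starts from zeros, re-noises the current estimate with the stddev given by _step2mean, and asks the network for a cleaner mean as the stddev shrinks. The NumPy snippet below mimics that loop; toy_denoiser and the hard-coded target are purely hypothetical stand-ins for the trained Face2LatentModel.

import numpy as np

def step2stddev(step, maxT=100.0):
  # same mapping as _step2mean: noise stddev shrinks roughly linearly with the step
  return np.clip(step / maxT + 1e-6, 1e-3, 1.0)

def toy_denoiser(noisy, target, stddev):
  # hypothetical stand-in for the trained network: pull the noisy guess toward
  # the target, trusting the guess more as the declared noise level drops
  w = stddev / (stddev + 0.05)
  return noisy + w * (target - noisy)

rng = np.random.default_rng(42)
target = np.array([0.3, 0.7])  # "true" gaze point in normalized screen coordinates
result = np.zeros(2)           # start from zeros, as _infer does

for step in range(100, -1, -5):
  stddev = step2stddev(step)
  noisy = rng.normal(loc=result, scale=stddev)   # diffusionPoints: re-noised current estimate
  result = toy_denoiser(noisy, target, stddev)   # "model" returns the denoised mean

print(result)  # lands near the target; the toy denoiser leaves a small residual error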

NN/networks.py

Lines changed: 29 additions & 5 deletions
@@ -20,6 +20,11 @@ def call(self, T):
     return T[..., 0, :]
 
 class IntermediatePredictor(tf.keras.layers.Layer):
+  def __init__(self, shift=0.5, **kwargs):
+    super().__init__(**kwargs)
+    self._shift = shift
+    return
+
   def build(self, input_shape):
     self._mlp = sMLP(
       sizes=[128, 64, 32], activation='relu',
@@ -33,7 +38,7 @@ def call(self, x):
     B = tf.shape(x)[0]
     N = tf.shape(x)[1]
     x = self._mlp(x)
-    x = 0.5 + self._decodePoints(x) # [0, 0] -> [0.5, 0.5]
+    x = self._shift + self._decodePoints(x) # [0, 0] -> [0.5, 0.5]
     tf.assert_equal(tf.shape(x), (B, N, 2))
     return x
 # End of IntermediatePredictor
@@ -138,7 +143,8 @@ def _InputSpec():
 
 def Face2LatentModel(
   pointsN=478, eyeSize=32, steps=None, latentSize=64,
-  embeddings=None
+  embeddings=None,
+  diffusion=False # whether to add the diffusion inputs
 ):
   points = L.Input((steps, pointsN, 2))
   eyeL = L.Input((steps, eyeSize, eyeSize, 1))
@@ -149,6 +155,14 @@ def Face2LatentModel(
   screenIdEmb = L.Input((steps, embeddings['size']))
 
   emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
+  if diffusion:
+    diffusionT = L.Input((steps, 1))
+    diffusionPoints = L.Input((steps, 2))
+    encodedDT = CTimeEncoderLayer()(diffusionT)
+    # shared transformation for all points
+    encodedDP = CCoordsEncodingLayer(32, sharedTransformation=True)(diffusionPoints)
+    # add the diffusion features to the embeddings
+    emb = L.Concatenate(-1)([emb, encodedDT, encodedDP])
 
   Face2Step = Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize=emb.shape[-1])
   Step2Latent = Step2LatentModel(latentSize, embeddingsSize=emb.shape[-1])
@@ -179,15 +193,25 @@ def Face2LatentModel(
     'placeId': placeIdEmb,
     'screenId': screenIdEmb,
   }
-
-  res['result'] = IntermediatePredictor()(res['latent'])
+  res['result'] = IntermediatePredictor(
+    shift=0.0 if diffusion else 0.5 # shift points to the center only when not using diffusion
+  )(
+    L.Concatenate(-1)([res['latent'], emb])
+  )
+
+  if diffusion:
+    inputs['diffusionT'] = diffusionT
+    inputs['diffusionPoints'] = diffusionPoints
+    # make the prediction a residual on top of the noisy input points
+    res['result'] = diffusionPoints + res['result']
+
   main = tf.keras.Model(inputs=inputs, outputs=res)
   return {
     'intermediate shapes': {k: v.shape for k, v in res['intermediate'].items()},
     'main': main,
     'Face2Step': Face2Step,
     'Step2Latent': Step2Latent,
-    'inputs specification': _InputSpec(),
+    'inputs specification': _InputSpec()
   }
 
 if __name__ == '__main__':
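
The key wiring change is that, with diffusion=True, IntermediatePredictor emits a zero-centered delta (shift=0.0) that is added back onto the noisy diffusionPoints, so the network refines the points rather than regressing absolute coordinates. Below is a minimal standalone Keras sketch of that residual pattern, not part of the commit; the Dense layer is a hypothetical stand-in for IntermediatePredictor and all names are illustrative.

import tensorflow as tf
from tensorflow.keras import layers as L

steps = 5
latent = L.Input((steps, 64), name='latent')
diffusionPoints = L.Input((steps, 2), name='diffusionPoints')

# stand-in for IntermediatePredictor(shift=0.0): a zero-centered 2D delta per step
delta = L.Dense(2, name='delta')(latent)
# mirrors res['result'] = diffusionPoints + res['result'] from the diff above
result = L.Add(name='result')([diffusionPoints, delta])

toy = tf.keras.Model(inputs=[latent, diffusionPoints], outputs=result)
toy.summary()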

scripts/train.py

Lines changed: 3 additions & 1 deletion
@@ -12,6 +12,7 @@
 from collections import defaultdict
 import time
 from Core.CModelTrainer import CModelTrainer
+from Core.CModelDiffusion import CModelDiffusion
 import tqdm
 import json
 import glob
@@ -186,6 +187,7 @@ def F(epoch):
 
 def _trainer_from(args):
   if args.trainer == 'default': return CModelTrainer
+  if args.trainer == 'diffusion': return CModelDiffusion
   raise Exception('Unknown trainer: %s' % (args.trainer, ))
 
 def averageModels(folder, model, noiseStd=0.0):
@@ -330,7 +332,7 @@ def performRandomSearch(epoch=0):
   parser.add_argument('--modelId', type=str)
   parser.add_argument(
     '--trainer', type=str, default='default',
-    choices=['default']
+    choices=['default', 'diffusion']
   )
   parser.add_argument(
     '--schedule', type=str, default=None,
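
With the trainer registered, a diffusion run is selected through the existing command line; both flags below appear in this diff, and the modelId value is only a placeholder:

python scripts/train.py --trainer diffusion --modelId some-model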
