Commit ed74a15

added normalisation

removed useless branch; improved training script

1 parent c4501c0 · commit ed74a15

3 files changed: +35 additions, -53 deletions

NN/Utils.py
Lines changed: 2 additions & 0 deletions

@@ -232,6 +232,7 @@ def __init__(self, mlp=None, **kwargs):
     super().__init__(**kwargs)
     if mlp is None: mlp = lambda x: x
     self._mlp = mlp
+    self._norm = L.LayerNormalization()
     return
 
   def build(self, input_shapes):
@@ -243,6 +244,7 @@ def build(self, input_shapes):
   def call(self, x):
     assert isinstance(x, list), "expected list of inputs"
     xhat = tf.concat(x, axis=-1)
+    xhat = self._norm(xhat)
     xhat = self._mlp(xhat)
     xhat = self._lastDense(xhat)
     x0 = x[0]
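The change inserts a LayerNormalization between the concatenation and the MLP, so inputs arriving at different scales are standardized before being mixed. A minimal self-contained sketch of the pattern; the class name, default MLP, and Dense width are illustrative assumptions, and the real layer in NN/Utils.py also mixes in x[0] afterwards:

```python
import tensorflow as tf

class CombineInputs(tf.keras.layers.Layer):
  def __init__(self, mlp=None, units=16, **kwargs):
    super().__init__(**kwargs)
    if mlp is None: mlp = lambda x: x
    self._mlp = mlp
    self._norm = tf.keras.layers.LayerNormalization() # the step added by this commit
    self._lastDense = tf.keras.layers.Dense(units)
    return

  def call(self, x):
    assert isinstance(x, list), "expected list of inputs"
    xhat = tf.concat(x, axis=-1) # concatenate along the feature axis
    xhat = self._norm(xhat)      # standardize scales before mixing
    xhat = self._mlp(xhat)
    return self._lastDense(xhat)

# inputs at very different scales still yield comparably scaled features
layer = CombineInputs()
out = layer([tf.random.normal((4, 8)), 100.0 * tf.random.normal((4, 5))])
```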

NN/networks.py
Lines changed: 0 additions & 29 deletions

@@ -185,36 +185,7 @@ def Face2LatentModel(
   IP = lambda x: IntermediatePredictor()(x) # own IntermediatePredictor for each output
   res['intermediate'] = {k: IP(x) for k, x in intermediate.items()}
   res['result'] = IP(res['latent'])
-  ###################################
-  # TODO: figure out is this helpful or not
-  # branch for global coordinates transformation
-  # predict shift, rotation, scale
-  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
-  emb = sMLP(sizes=[64, 64, 64, 64, 32], activation='relu')(emb[:, 0])
-  shift = L.Dense(2, name='GlobalShift')(emb)[:, None]
-  rotation = L.Dense(1, name='GlobalRotation', activation='sigmoid')(emb)[:, None] * np.pi
-  scale = L.Dense(2, name='GlobalScale')(emb)[:, None]
-
-  shifted = res['result'] + shift - 0.5 # [0.5, 0.5] -> [0, 0]
-  # Rotation matrix components
-  cos_rotation = L.Lambda(lambda x: tf.cos(x))(rotation)
-  sin_rotation = L.Lambda(lambda x: tf.sin(x))(rotation)
-  rotation_matrix = L.Lambda(lambda x: tf.stack([x[0], x[1]], axis=-1))([cos_rotation, sin_rotation])
-
-  # Apply rotation
-  rotated = L.Lambda(
-    lambda x: tf.einsum('isj,iomj->isj', x[0], x[1])
-  )([shifted, rotation_matrix]) + 0.5 # [0, 0] -> [0.5, 0.5] back
-
-  # Apply scale
-  scaled = rotated * scale
-  def clipWithGradient(x):
-    res = tf.clip_by_value(x, 0.0, 1.0)
-    return x + tf.stop_gradient(res - x)
 
-  res['result'] = L.Lambda(clipWithGradient)(scaled)
-  ###################################
-
   main = tf.keras.Model(inputs=inputs, outputs=res)
   return {
     'main': main,
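The removed branch predicted a global shift, rotation, and scale from the user/place/screen embeddings and clipped the transformed result back into [0, 1]. Note that the einsum 'isj,iomj->isj' appears to scale the coordinates elementwise by (cos θ, sin θ) rather than apply a true 2×2 rotation matrix, which may be part of why the branch was dropped. The final clipWithGradient step is a reusable straight-through trick on its own; a standalone sketch:

```python
import tensorflow as tf

def clip_with_gradient(x):
  # forward pass: values clamped to [0, 1]; backward pass: identity gradient
  clipped = tf.clip_by_value(x, 0.0, 1.0)
  return x + tf.stop_gradient(clipped - x)

x = tf.Variable([-0.5, 0.3, 1.7])
with tf.GradientTape() as tape:
  y = tf.reduce_sum(clip_with_gradient(x))
print(y.numpy())                    # 1.3 == 0.0 + 0.3 + 1.0
print(tape.gradient(y, x).numpy())  # [1. 1. 1.] -- gradients pass through the clip
```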

scripts/train.py
Lines changed: 33 additions & 24 deletions

@@ -239,31 +239,52 @@ def main(args):
   model = trainer(**model)
   model._model.summary()
 
-  if args.average:
-    averageModels(folder, model)
   # find folders with the name "/test-*/"
   evalDatasets = [
     CTestLoader(nm)
     for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/'))
   ]
   eval = evaluator(evalDatasets, model, folder, args)
-  bestLoss = eval()
+  bestLoss = eval() # evaluate loaded model
   bestEpoch = 0
+  # wrapper for the evaluation function. It saves the model if it is better
+  def evalWrapper(eval):
+    def f(epoch, onlyImproved=False):
+      nonlocal bestLoss, bestEpoch
+      newLoss = eval(onlyImproved=onlyImproved)
+      if newLoss < bestLoss:
+        print('Improved %.5f => %.5f' % (bestLoss, newLoss))
+        bestLoss = newLoss
+        bestEpoch = epoch
+        model.save(folder, postfix='best')
+      return
+    return f
+
+  eval = evalWrapper(eval)
+
+  def performRandomSearch(epoch=0):
+    nonlocal bestLoss, bestEpoch
+    averageModels(folder, model, noiseStd=0.0)
+    eval(epoch=epoch, onlyImproved=True) # evaluate the averaged model
+    for _ in range(args.restarts):
+      # and add some noise
+      averageModels(folder, model, noiseStd=args.noise)
+      # re-evaluate the model with the new weights
+      eval(epoch=epoch, onlyImproved=True)
+      continue
+    return
+
+  if args.average:
+    performRandomSearch()
+
   trainStep = _modelTrainingLoop(model, trainDataset)
   for epoch in range(args.epochs):
     trainStep(
       desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
       sampleParams=getSampleParams(epoch)
     )
     model.save(folder, postfix='latest')
-
-    testLoss = eval()
-    if testLoss < bestLoss:
-      print('Improved %.5f => %.5f' % (bestLoss, testLoss))
-      bestLoss = testLoss
-      bestEpoch = epoch
-      model.save(folder, postfix='best')
-      continue
+    eval(epoch)
 
     print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
     if args.patience <= (epoch - bestEpoch):
@@ -272,19 +293,7 @@ def main(args):
       break
     if 'reset' == args.on_patience:
       print('Resetting the model to the average of the best models')
-      bestEpoch = epoch # reset the best epoch
-      for _ in range(args.restarts):
-        # and add some noise
-        averageModels(folder, model, noiseStd=args.noise)
-        # re-evaluate the model with the new weights
-        testLoss = eval(onlyImproved=True)
-        if testLoss < bestLoss:
-          print('Improved %.5f => %.5f' % (bestLoss, testLoss))
-          bestLoss = testLoss
-          bestEpoch = epoch
-          model.save(folder, postfix='best')
-        continue
-      continue
+      performRandomSearch(epoch=epoch)
     continue
   return
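The refactor folds the duplicated "evaluate, then checkpoint on improvement" logic into the evalWrapper closure and centralizes "average the saved checkpoints, then retry with noisy perturbations" in performRandomSearch (noiseStd=0.0 yields the plain average). averageModels itself is not shown in this diff; a hedged sketch of what such a helper might do, where the weight-list convention and set_weights interface follow Keras and are assumptions, not the repo's actual code:

```python
import numpy as np

def average_with_noise(weight_sets, model, noiseStd=0.0):
  # weight_sets: list of per-checkpoint weight lists (lists of numpy arrays)
  averaged = [
    np.mean([ws[i] for ws in weight_sets], axis=0)
    for i in range(len(weight_sets[0]))
  ]
  if 0.0 < noiseStd: # noiseStd=0.0 reproduces the plain average
    averaged = [w + np.random.normal(0.0, noiseStd, size=w.shape) for w in averaged]
  model.set_weights(averaged)
  return
```

Repeatedly re-sampling noise around the averaged weights and keeping only improvements makes this a cheap random search in weight space around the checkpoint average.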