Skip to content

Commit be83836

Browse files
misc
1 parent 88ea8b4 commit be83836

File tree

3 files changed

+35
-17
lines changed

3 files changed

+35
-17
lines changed

NN/Utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ def call(self, x):
2727
coefs = tf.pow(self._base, powers)
2828
return tf.reduce_sum(x * coefs, axis=-1)
2929
############################################
30+
SMLP_GLOBAL_DROPOUT = 0.01
3031
class sMLP(tf.keras.layers.Layer):
31-
def __init__(self, sizes, activation='linear', dropout=0.01, **kwargs):
32+
def __init__(self, sizes, activation='linear', dropout=None, **kwargs):
3233
super().__init__(**kwargs)
34+
dropout = SMLP_GLOBAL_DROPOUT if dropout is None else dropout
3335
layers = []
3436
for i, sz in enumerate(sizes):
3537
if 0.0 < dropout:

app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,9 @@ def _modelFromArgs(args):
342342
stats = json.load(f)
343343

344344
# My own ids hardcoded here for simplicity
345-
userId = '98fdb9d9-14ef-9276-31e6-836e830acc19'
346-
placeId = 'ce42c1a9-f4ef-42d6-a219-cf25fad912ed'
347-
screenId = 'ce42c1a9-f4ef-42d6-a219-cf25fad912ed/29f35417-7bb7-3c94-124c-2ae16bda235d'
345+
userId = 'ce42c1a9-f4ef-42d6-a219-cf25fad912ed'
346+
placeId = '29ecaa6a-d3b5-784b-887e-f50a0c6533fa'
347+
screenId = placeId + '/' + '29f35417-7bb7-3c94-124c-2ae16bda235d'
348348
return CModelWrapper(
349349
timesteps=args.steps,
350350
user=dict(
@@ -353,7 +353,7 @@ def _modelFromArgs(args):
353353
screenId=screenId,
354354
),
355355
stats=stats,
356-
weights=dict(folder=args.folder, postfix=args.model)
356+
weights=dict(folder=args.folder, postfix=args.model, embeddings=True)
357357
)
358358

359359
def _predictorFromArgs(args):

scripts/train.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,25 +73,27 @@ def _eval(dataset, model, plotFilename, args):
7373

7474
def evaluator(datasets, model, folder, args):
7575
losses = [np.inf] * len(datasets) # initialize with infinity
76-
def evaluate():
76+
def evaluate(onlyImproved=False):
7777
totalLoss = totalDist = 0.0
7878
for i, dataset in enumerate(datasets):
7979
loss, dist, T = _eval(dataset, model, os.path.join(folder, 'pred-%d.png' % i), args)
80-
print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
81-
i + 1, len(datasets), T, loss, losses[i], dist
82-
))
80+
if not onlyImproved:
81+
print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
82+
i + 1, len(datasets), T, loss, losses[i], dist
83+
))
8384
if loss < losses[i]:
84-
print('Improved %.5f => %.5f' % (losses[i], loss))
85+
print('Test %d / %d | Improved %.5f => %.5f' % (i + 1, len(datasets), losses[i], loss))
8586
model.save(folder, postfix='best-%d' % i) # save the model separately
8687
losses[i] = loss
8788
pass
8889

8990
totalLoss += loss
9091
totalDist += dist
9192
continue
92-
print('Mean loss: %.5f | Mean distance: %.5f' % (
93-
totalLoss / len(datasets), totalDist / len(datasets)
94-
))
93+
if not onlyImproved:
94+
print('Mean loss: %.5f | Mean distance: %.5f' % (
95+
totalLoss / len(datasets), totalDist / len(datasets)
96+
))
9597
return totalLoss / len(datasets)
9698
return evaluate
9799

@@ -268,10 +270,19 @@ def main(args):
268270
break
269271
if 'reset' == args.on_patience:
270272
print('Resetting the model to the average of the best models')
271-
# and add some noise
272-
averageModels(folder, model, noiseStd=0.01)
273-
bestEpoch = epoch
274-
continue
273+
bestEpoch = epoch # reset the best epoch
274+
for _ in range(args.restarts):
275+
# and add some noise
276+
averageModels(folder, model, noiseStd=args.noise)
277+
# re-evaluate the model with the new weights
278+
testLoss = eval(onlyImproved=True)
279+
if testLoss < bestLoss:
280+
print('Improved %.5f => %.5f' % (bestLoss, testLoss))
281+
bestLoss = testLoss
282+
bestEpoch = epoch
283+
model.save(folder, postfix='best')
284+
continue
285+
continue
275286
continue
276287
return
277288

@@ -299,6 +310,11 @@ def main(args):
299310
help='JSON file with the scheduler parameters for sampling the training dataset'
300311
)
301312
parser.add_argument('--debug', action='store_true')
313+
parser.add_argument('--noise', type=float, default=1e-4)
314+
parser.add_argument(
315+
'--restarts', type=int, default=1,
316+
help='Number of times to restart the model reinitializing the weights'
317+
)
302318

303319
main(parser.parse_args())
304320
pass

0 commit comments

Comments
 (0)