@@ -73,25 +73,27 @@ def _eval(dataset, model, plotFilename, args):
73
73
74
74
def evaluator (datasets , model , folder , args ):
75
75
losses = [np .inf ] * len (datasets ) # initialize with infinity
76
- def evaluate ():
76
+ def evaluate (onlyImproved = False ):
77
77
totalLoss = totalDist = 0.0
78
78
for i , dataset in enumerate (datasets ):
79
79
loss , dist , T = _eval (dataset , model , os .path .join (folder , 'pred-%d.png' % i ), args )
80
- print ('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
81
- i + 1 , len (datasets ), T , loss , losses [i ], dist
82
- ))
80
+ if not onlyImproved :
81
+ print ('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
82
+ i + 1 , len (datasets ), T , loss , losses [i ], dist
83
+ ))
83
84
if loss < losses [i ]:
84
- print ('Improved %.5f => %.5f' % (losses [i ], loss ))
85
+ print ('Test %d / %d | Improved %.5f => %.5f' % (i + 1 , len ( datasets ), losses [i ], loss ))
85
86
model .save (folder , postfix = 'best-%d' % i ) # save the model separately
86
87
losses [i ] = loss
87
88
pass
88
89
89
90
totalLoss += loss
90
91
totalDist += dist
91
92
continue
92
- print ('Mean loss: %.5f | Mean distance: %.5f' % (
93
- totalLoss / len (datasets ), totalDist / len (datasets )
94
- ))
93
+ if not onlyImproved :
94
+ print ('Mean loss: %.5f | Mean distance: %.5f' % (
95
+ totalLoss / len (datasets ), totalDist / len (datasets )
96
+ ))
95
97
return totalLoss / len (datasets )
96
98
return evaluate
97
99
@@ -268,10 +270,19 @@ def main(args):
268
270
break
269
271
if 'reset' == args .on_patience :
270
272
print ('Resetting the model to the average of the best models' )
271
- # and add some noise
272
- averageModels (folder , model , noiseStd = 0.01 )
273
- bestEpoch = epoch
274
- continue
273
+ bestEpoch = epoch # reset the best epoch
274
+ for _ in range (args .restarts ):
275
+ # and add some noise
276
+ averageModels (folder , model , noiseStd = args .noise )
277
+ # re-evaluate the model with the new weights
278
+ testLoss = eval (onlyImproved = True )
279
+ if testLoss < bestLoss :
280
+ print ('Improved %.5f => %.5f' % (bestLoss , testLoss ))
281
+ bestLoss = testLoss
282
+ bestEpoch = epoch
283
+ model .save (folder , postfix = 'best' )
284
+ continue
285
+ continue
275
286
continue
276
287
return
277
288
@@ -299,6 +310,11 @@ def main(args):
299
310
help = 'JSON file with the scheduler parameters for sampling the training dataset'
300
311
)
301
312
parser .add_argument ('--debug' , action = 'store_true' )
313
+ parser .add_argument ('--noise' , type = float , default = 1e-4 )
314
+ parser .add_argument (
315
+ '--restarts' , type = int , default = 1 ,
316
+ help = 'Number of times to restart the model reinitializing the weights'
317
+ )
302
318
303
319
main (parser .parse_args ())
304
320
pass
0 commit comments