@@ -239,31 +239,52 @@ def main(args):
239
239
model = trainer (** model )
240
240
model ._model .summary ()
241
241
242
- if args .average :
243
- averageModels (folder , model )
244
242
# find folders with the name "/test-*/"
245
243
evalDatasets = [
246
244
CTestLoader (nm )
247
245
for nm in glob .glob (os .path .join (folder , 'test-main' , 'test-*/' ))
248
246
]
249
247
eval = evaluator (evalDatasets , model , folder , args )
250
- bestLoss = eval ()
248
+ bestLoss = eval () # evaluate loaded model
251
249
bestEpoch = 0
250
+ # wrapper for the evaluation function. It saves the model if it is better
251
+ def evalWrapper (eval ):
252
+ def f (epoch , onlyImproved = False ):
253
+ nonlocal bestLoss , bestEpoch
254
+ newLoss = eval (onlyImproved = onlyImproved )
255
+ if newLoss < bestLoss :
256
+ print ('Improved %.5f => %.5f' % (bestLoss , newLoss ))
257
+ bestLoss = newLoss
258
+ bestEpoch = epoch
259
+ model .save (folder , postfix = 'best' )
260
+ return
261
+ return f
262
+
263
+ eval = evalWrapper (eval )
264
+
265
+ def performRandomSearch (epoch = 0 ):
266
+ nonlocal bestLoss , bestEpoch
267
+ averageModels (folder , model , noiseStd = 0.0 )
268
+ eval (epoch = epoch , onlyImproved = True ) # evaluate the averaged model
269
+ for _ in range (args .restarts ):
270
+ # and add some noise
271
+ averageModels (folder , model , noiseStd = args .noise )
272
+ # re-evaluate the model with the new weights
273
+ eval (epoch = epoch , onlyImproved = True )
274
+ continue
275
+ return
276
+
277
+ if args .average :
278
+ performRandomSearch ()
279
+
252
280
trainStep = _modelTrainingLoop (model , trainDataset )
253
281
for epoch in range (args .epochs ):
254
282
trainStep (
255
283
desc = 'Epoch %.*d / %d' % (len (str (args .epochs )), epoch , args .epochs ),
256
284
sampleParams = getSampleParams (epoch )
257
285
)
258
286
model .save (folder , postfix = 'latest' )
259
-
260
- testLoss = eval ()
261
- if testLoss < bestLoss :
262
- print ('Improved %.5f => %.5f' % (bestLoss , testLoss ))
263
- bestLoss = testLoss
264
- bestEpoch = epoch
265
- model .save (folder , postfix = 'best' )
266
- continue
287
+ eval (epoch )
267
288
268
289
print ('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch , bestLoss ))
269
290
if args .patience <= (epoch - bestEpoch ):
@@ -272,19 +293,7 @@ def main(args):
272
293
break
273
294
if 'reset' == args .on_patience :
274
295
print ('Resetting the model to the average of the best models' )
275
- bestEpoch = epoch # reset the best epoch
276
- for _ in range (args .restarts ):
277
- # and add some noise
278
- averageModels (folder , model , noiseStd = args .noise )
279
- # re-evaluate the model with the new weights
280
- testLoss = eval (onlyImproved = True )
281
- if testLoss < bestLoss :
282
- print ('Improved %.5f => %.5f' % (bestLoss , testLoss ))
283
- bestLoss = testLoss
284
- bestEpoch = epoch
285
- model .save (folder , postfix = 'best' )
286
- continue
287
- continue
296
+ performRandomSearch (epoch = epoch )
288
297
continue
289
298
return
290
299
0 commit comments