@@ -73,29 +73,35 @@ def _eval(dataset, model, plotFilename, args):
73
73
74
74
def evaluator (datasets , model , folder , args ):
75
75
losses = [np .inf ] * len (datasets ) # initialize with infinity
76
+ dists = [np .inf ] * len (datasets ) # initialize with infinity
76
77
def evaluate (onlyImproved = False ):
77
78
totalLoss = totalDist = 0.0
79
+ losses_dist = []
78
80
for i , dataset in enumerate (datasets ):
79
81
loss , dist , T = _eval (dataset , model , os .path .join (folder , 'pred-%d.png' % i ), args )
82
+ losses_dist .append ((loss , losses [i ], dist , dists [i ]))
80
83
isImproved = loss < losses [i ]
81
84
if (not onlyImproved ) or isImproved :
82
- print ('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
83
- i + 1 , len (datasets ), T , loss , losses [i ], dist
85
+ print ('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f (%.5f) ' % (
86
+ i + 1 , len (datasets ), T , loss , losses [i ], dist , dists [ i ]
84
87
))
85
88
if isImproved :
86
- print ('Test %d / %d | Improved %.5f => %.5f' % (i + 1 , len (datasets ), losses [i ], loss ))
89
+ print ('Test %d / %d | Improved %.5f => %.5f, Distance: %.5f => %.5f' % (
90
+ i + 1 , len (datasets ), losses [i ], loss , dists [i ], dist
91
+ ))
87
92
model .save (folder , postfix = 'best-%d' % i ) # save the model separately
88
93
losses [i ] = loss
89
94
pass
90
95
96
+ dists [i ] = min (dist , dists [i ]) # track the best distance
91
97
totalLoss += loss
92
98
totalDist += dist
93
99
continue
94
100
if not onlyImproved :
95
101
print ('Mean loss: %.5f | Mean distance: %.5f' % (
96
102
totalLoss / len (datasets ), totalDist / len (datasets )
97
103
))
98
- return totalLoss / len (datasets )
104
+ return totalLoss / len (datasets ), losses_dist
99
105
return evaluate
100
106
101
107
def _modelTrainingLoop (model , dataset ):
@@ -247,15 +253,20 @@ def main(args):
247
253
for nm in glob .glob (os .path .join (folder , 'test-main' , 'test-*/' ))
248
254
]
249
255
eval = evaluator (evalDatasets , model , folder , args )
250
- bestLoss = eval () # evaluate loaded model
256
+ bestLoss , _ = eval () # evaluate loaded model
251
257
bestEpoch = 0
252
258
# wrapper for the evaluation function. It saves the model if it is better
253
259
def evalWrapper (eval ):
254
260
def f (epoch , onlyImproved = False ):
255
261
nonlocal bestLoss , bestEpoch
256
- newLoss = eval (onlyImproved = onlyImproved )
262
+ newLoss , losses = eval (onlyImproved = onlyImproved )
257
263
if newLoss < bestLoss :
258
264
print ('Improved %.5f => %.5f' % (bestLoss , newLoss ))
265
+ if onlyImproved : #details
266
+ for i , (loss , bestLoss_ , dist , bestDist ) in enumerate (losses ):
267
+ print ('Test %d | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (i + 1 , loss , bestLoss_ , dist , bestDist ))
268
+ continue
269
+ print ('-' * 80 )
259
270
bestLoss = newLoss
260
271
bestEpoch = epoch
261
272
model .save (folder , postfix = 'best' )
0 commit comments