Skip to content

Commit bcb125f

Browse files
more data to track training progress
1 parent 8e28a38 commit bcb125f

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

scripts/train.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,29 +73,35 @@ def _eval(dataset, model, plotFilename, args):
7373

7474
def evaluator(datasets, model, folder, args):
7575
losses = [np.inf] * len(datasets) # initialize with infinity
76+
dists = [np.inf] * len(datasets) # initialize with infinity
7677
def evaluate(onlyImproved=False):
7778
totalLoss = totalDist = 0.0
79+
losses_dist = []
7880
for i, dataset in enumerate(datasets):
7981
loss, dist, T = _eval(dataset, model, os.path.join(folder, 'pred-%d.png' % i), args)
82+
losses_dist.append((loss, losses[i], dist, dists[i]))
8083
isImproved = loss < losses[i]
8184
if (not onlyImproved) or isImproved:
82-
print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
83-
i + 1, len(datasets), T, loss, losses[i], dist
85+
print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (
86+
i + 1, len(datasets), T, loss, losses[i], dist, dists[i]
8487
))
8588
if isImproved:
86-
print('Test %d / %d | Improved %.5f => %.5f' % (i + 1, len(datasets), losses[i], loss))
89+
print('Test %d / %d | Improved %.5f => %.5f, Distance: %.5f => %.5f' % (
90+
i + 1, len(datasets), losses[i], loss, dists[i], dist
91+
))
8792
model.save(folder, postfix='best-%d' % i) # save the model separately
8893
losses[i] = loss
8994
pass
9095

96+
dists[i] = min(dist, dists[i]) # track the best distance
9197
totalLoss += loss
9298
totalDist += dist
9399
continue
94100
if not onlyImproved:
95101
print('Mean loss: %.5f | Mean distance: %.5f' % (
96102
totalLoss / len(datasets), totalDist / len(datasets)
97103
))
98-
return totalLoss / len(datasets)
104+
return totalLoss / len(datasets), losses_dist
99105
return evaluate
100106

101107
def _modelTrainingLoop(model, dataset):
@@ -247,15 +253,20 @@ def main(args):
247253
for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/'))
248254
]
249255
eval = evaluator(evalDatasets, model, folder, args)
250-
bestLoss = eval() # evaluate loaded model
256+
bestLoss, _ = eval() # evaluate loaded model
251257
bestEpoch = 0
252258
# wrapper for the evaluation function. It saves the model if it is better
253259
def evalWrapper(eval):
254260
def f(epoch, onlyImproved=False):
255261
nonlocal bestLoss, bestEpoch
256-
newLoss = eval(onlyImproved=onlyImproved)
262+
newLoss, losses = eval(onlyImproved=onlyImproved)
257263
if newLoss < bestLoss:
258264
print('Improved %.5f => %.5f' % (bestLoss, newLoss))
265+
if onlyImproved: #details
266+
for i, (loss, bestLoss_, dist, bestDist) in enumerate(losses):
267+
print('Test %d | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (i + 1, loss, bestLoss_, dist, bestDist))
268+
continue
269+
print('-' * 80)
259270
bestLoss = newLoss
260271
bestEpoch = epoch
261272
model.save(folder, postfix='best')

0 commit comments

Comments
 (0)