Skip to content

Commit 2f22c74

Browse files
ignore outliers
1 parent fa32d5a commit 2f22c74

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

scripts/train.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def evaluator(datasets, model, folder, args):
7575
losses = [np.inf] * len(datasets) # initialize with infinity
7676
dists = [np.inf] * len(datasets) # initialize with infinity
7777
def evaluate(onlyImproved=False):
78-
totalLoss = totalDist = 0.0
78+
totalLoss = []
79+
totalDist = []
7980
losses_dist = []
8081
for i, dataset in enumerate(datasets):
8182
loss, dist, T = _eval(dataset, model, os.path.join(folder, 'pred-%d.png' % i), args)
@@ -94,14 +95,16 @@ def evaluate(onlyImproved=False):
9495
pass
9596

9697
dists[i] = min(dist, dists[i]) # track the best distance
97-
totalLoss += loss
98-
totalDist += dist
98+
# filter the results by the distance, to ignore the outliers
99+
if dists[i] < 0.1:
100+
totalLoss.append(loss)
101+
totalDist.append(dist)
99102
continue
100103
if not onlyImproved:
101104
print('Mean loss: %.5f | Mean distance: %.5f' % (
102-
totalLoss / len(datasets), totalDist / len(datasets)
105+
np.mean(totalLoss), np.mean(totalDist)
103106
))
104-
return totalLoss / len(datasets), losses_dist
107+
return np.mean(totalLoss), losses_dist
105108
return evaluate
106109

107110
def _modelTrainingLoop(model, dataset):

0 commit comments

Comments
 (0)