Skip to content

Commit

Permalink
Fixed AverageValueMeter and tests (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
levzlotnik authored and szagoruyko committed Nov 5, 2019
1 parent d0d7fc9 commit 6b9aa85
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion test/test_meters.py
Expand Up @@ -47,7 +47,7 @@ def testAverageValueMeter_n(self):
"""
m = meter.AverageValueMeter()
for i in range(1, 11):
m.add(i * i, n=i)
m.add(i, n=i)
mean, std = m.value()
self.assertEqual(mean, 7.0)
m.reset()
Expand Down
22 changes: 11 additions & 11 deletions torchnet/meter/averagevaluemeter.py
Expand Up @@ -11,22 +11,22 @@ def __init__(self):

def add(self, value, n=1):
self.val = value
self.sum += value
self.var += value * value
self.n += n

if self.n == 0:
self.mean, self.std = np.nan, np.nan
elif self.n == 1:
self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy
self.sum += value * n
if n <= 0:
raise ValueError("Cannot use a non-positive weight for the running stat.")
elif self.n == 0:
self.mean = 0.0 + value # This is to force a copy in torch/numpy
self.std = np.inf
self.mean_old = self.mean
self.m_s = 0.0
else:
self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n)
self.m_s += (value - self.mean_old) * (value - self.mean)
self.mean = self.mean_old + n * (value - self.mean_old) / float(self.n + n)
self.m_s += n * (value - self.mean_old) * (value - self.mean)
self.mean_old = self.mean
self.std = np.sqrt(self.m_s / (self.n - 1.0))
self.std = np.sqrt(self.m_s / (self.n + n - 1.0))
self.var = self.std ** 2

self.n += n

def value(self):
return self.mean, self.std
Expand Down

0 comments on commit 6b9aa85

Please sign in to comment.