Feature/metric dim (#490)
* Add dim argument to mean metric

* Add dim argument to running mean decorator

* Update mean metric

* Add dim argument to mean metric

* Add var metric and update std metric

* Update CHANGELOG.md

* Update ToDict metric tests

* Update basic opt example to use new dim argument
ethanwharris committed Jan 24, 2019
1 parent e04f35d commit eca7028
Showing 8 changed files with 259 additions and 65 deletions.
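A minimal sketch of the new dim argument described above, based on the aggregator tests added in this commit (the metric name 'est' and the example tensors are illustrative, not part of the change):

import torch
from torchbearer.metrics import Mean

# With dim=0 the aggregator reduces over the batch dimension only, so
# process_final returns one value per remaining element instead of a scalar.
mean = Mean('est', dim=0)
mean.process(torch.Tensor([[1., 2.], [3., 4.]]))
mean.process(torch.Tensor([[4., 3.], [2., 1.]]))
print(mean.process_final())  # two values, one per column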
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,7 +8,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added torchbearer.variational, a sub-package for implementations of state of the art variational auto-encoders
- Added SimpleUniform and SimpleExponential distributions
- Added a decorator which can be used to cite a research article as part of a doc string
- Added an optional dimension argument to the mean, std and running_mean metric aggregators
- Added a var metric and decorator which can be used to calculate the variance of a metric
- Added an unbiased flag to the std and var metrics to optionally not apply Bessel's correction (consistent with torch.std / torch.var)
### Changed
- Changed the default behaviour of the std metric to compute the sample std, in line with torch.std
### Deprecated
### Removed
### Fixed
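As a rough usage sketch of the CHANGELOG entries above (mirroring the new aggregator tests; the name 'est' and the tensors are illustrative), the std and var metrics now default to the sample estimate, with unbiased=False disabling Bessel's correction:

import torch
from torchbearer.metrics import Std, Var

# unbiased=False restores the previous population (uncorrected) behaviour;
# dim=0 gives one statistic per column rather than a single scalar.
std = Std('est', dim=0, unbiased=False)
var = Var('est', dim=0, unbiased=False)
for batch in (torch.Tensor([[1., 2.], [3., 4.]]), torch.Tensor([[4., 3.], [2., 1.]])):
    std.process(batch)
    var.process(batch)
print(std.process_final())  # one std per column
print(var.process_final())  # one variance per column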
7 changes: 3 additions & 4 deletions docs/_static/examples/basic_opt.py
@@ -1,8 +1,6 @@
import torch
from torch.nn import Module

import numpy as np

import torchbearer as tb

ESTIMATE = tb.state_key('est')
@@ -27,7 +25,7 @@ def f(self):
return torch.sum(out**2)

def forward(self, _, state):
state[ESTIMATE] = np.round(self.pars.detach().cpu().numpy(), 4)
state[ESTIMATE] = self.pars.detach().unsqueeze(1)
return self.f()


@@ -41,6 +39,7 @@ def loss(y_pred, y_true):
model = Net(p)
optim = torch.optim.SGD(model.parameters(), lr=0.0001)

tbtrial = tb.Trial(model, optim, loss, [ESTIMATE, 'loss'], pass_state=True).for_train_steps(training_steps).to('cuda')
tbtrial = tb.Trial(model, optim, loss, [tb.metrics.running_mean(ESTIMATE, dim=1), 'loss'], pass_state=True)
tbtrial.for_train_steps(training_steps).to('cuda')
tbtrial.run()
print(list(model.parameters())[0].data)
10 changes: 5 additions & 5 deletions docs/examples/basic_opt.rst
@@ -14,7 +14,7 @@ We store the current estimates for the minimum as parameters in the model (so Py

.. literalinclude:: /_static/examples/basic_opt.py
:language: python
:lines: 11-31
:lines: 9-29

The Loss
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -24,7 +24,7 @@ Note that as we are using a base loss, torchbearer passes this the network outpu

.. literalinclude:: /_static/examples/basic_opt.py
:language: python
:lines: 34-35
:lines: 32-33


Optimising
@@ -37,20 +37,20 @@ We have set the number of optimisation steps for this example as 50000.

.. literalinclude:: /_static/examples/basic_opt.py
:language: python
:lines: 38-39
:lines: 36-37

The learning rate chosen for this example is very low and we could get convergence much faster with a larger rate, however this allows us to view convergence in real time.
We define the model and optimiser in the standard way.

.. literalinclude:: /_static/examples/basic_opt.py
:language: python
:lines: 41-42
:lines: 39-40

Finally we start the optimising on the GPU and print the final minimum estimate.

.. literalinclude:: /_static/examples/basic_opt.py
:language: python
:lines: 44-46
:lines: 42-45

Usually torchbearer will infer the number of training steps from the data generator.
Since for this example we have no data to give the model (which will be passed `None`), we need to tell torchbearer how many steps to run using the ``for_train_steps`` method.
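For reference, a minimal sketch of the pattern described above, reusing the names from the updated docs/_static/examples/basic_opt.py shown earlier (model, optim, loss and ESTIMATE are assumed to be defined as in that file):

import torchbearer as tb

training_steps = 50000

# No data generator is given, so each batch passed to the model is None and
# the number of steps must be set explicitly with for_train_steps.
tbtrial = tb.Trial(model, optim, loss,
                   [tb.metrics.running_mean(ESTIMATE, dim=1), 'loss'],
                   pass_state=True)
tbtrial.for_train_steps(training_steps).to('cuda')
tbtrial.run()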
58 changes: 52 additions & 6 deletions tests/metrics/test_aggregators.py
@@ -2,11 +2,24 @@

from unittest.mock import Mock, call

from torchbearer.metrics import RunningMean, Metric, RunningMetric, Mean, Std
from torchbearer.metrics import RunningMean, Metric, RunningMetric, Mean, Std, Var

import torch


class TestVar(unittest.TestCase):
def test_variance_dim(self):
var = Var('test', dim=0)
var.process(torch.Tensor([[1., 2.], [3., 4.]]))
var.process(torch.Tensor([[4., 3.], [2., 1.]]))
var.process(torch.Tensor([[1., 1.], [1., 1.]]))

res = var.process_final()
self.assertTrue(len(res) == 2)
for m in res:
self.assertTrue(abs(m - 1.6000) < 0.0001)


class TestStd(unittest.TestCase):
def setUp(self):
self._metric = Metric('test')
@@ -17,7 +30,7 @@ def setUp(self):
torch.FloatTensor([0.7, 0.8, 0.9]),
torch.ones(torch.Size([]))]

self._std = Std('test')
self._std = Std('test', unbiased=False)
self._std.reset({})
self._target = 0.31622776601684

@@ -55,7 +68,7 @@ def setUpMoreDims(self):
torch.FloatTensor([[0.4, 0.5, 0.6], [1.4, 1.5, 1.6]]),
torch.FloatTensor([[0.7, 0.8, 0.9], [1.7, 1.8, 1.9]]),
torch.ones(torch.Size([]))]
self._std = Std('test')
self._std = Std('test', unbiased=False)
self._std.reset({})
self._target = 0.57662804083742

@@ -66,6 +79,17 @@ def test_more_dims(self):
result = self._std.process_final({})
self.assertAlmostEqual(self._target, result, places=5)

def test_std_dim(self):
std = Std('test', dim=0)
std.process(torch.Tensor([[1., 2.], [3., 4.]]))
std.process(torch.Tensor([[4., 3.], [2., 1.]]))
std.process(torch.Tensor([[1., 1.], [1., 1.]]))

res = std.process_final()
self.assertTrue(len(res) == 2)
for m in res:
self.assertTrue(abs(m - 1.2649) < 0.0001)


class TestMean(unittest.TestCase):
def setUp(self):
@@ -116,6 +140,17 @@ def test_more_dims(self):
result = self._mean.process_final({})
self.assertAlmostEqual(self._target, result, places=5)

def test_mean_dim(self):
mean = Mean('test', dim=0)
mean.process(torch.Tensor([[1., 2.], [3., 4.]]))
mean.process(torch.Tensor([[4., 3.], [2., 1.]]))
mean.process(torch.Tensor([[1., 1.], [1., 1.]]))

res = mean.process_final()
self.assertTrue(len(res) == 2)
for m in res:
self.assertTrue(abs(m - 2.0) < 0.0001)


class TestRunningMetric(unittest.TestCase):
def setUp(self):
Expand All @@ -137,15 +172,15 @@ def test_cache_one_step(self):

def test_empty_methods(self):
metric = RunningMetric('test')
self.assertTrue(metric._step(['test']) is None)
self.assertTrue(metric._process_train(['test']) is None)
self.assertRaises(NotImplementedError, lambda: metric._step(['test']) is None)
self.assertRaises(NotImplementedError, lambda: metric._process_train(['test']) is None)


class TestRunningMean(unittest.TestCase):
def setUp(self):
self._metric = Metric('test')
self._mean = RunningMean('test')
self._cache = [1.0, 1.5, 2.0]
self._cache = [torch.Tensor([1.0]), torch.Tensor([1.5]), torch.Tensor([2.0])]
self._target = 1.5

def test_train(self):
@@ -155,3 +190,14 @@ def test_train(self):
def test_step(self):
result = self._mean._step(self._cache)
self.assertEqual(self._target, result)

def test_dims(self):
mean = RunningMean('test', dim=0)
cache = [mean._process_train(torch.Tensor([[1., 2.], [3., 4.]])),
mean._process_train(torch.Tensor([[4., 3.], [2., 1.]])),
mean._process_train(torch.Tensor([[1., 1.], [1., 1.]]))]

res = mean._step(cache)
self.assertTrue(len(res) == 2)
for m in res:
self.assertTrue(abs(m - 2.0) < 0.0001)
42 changes: 36 additions & 6 deletions tests/metrics/test_decorators.py
@@ -56,58 +56,88 @@ def test_to_dict_metric_instance(self):

def test_mean_metric_class(self):
metric = metrics.Metric
out = metrics.mean(metric)('test')
out = metrics.mean(dim=10)(metric)('test')
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Mean))
self.assertTrue(out.children[0].metric._kwargs['dim'] == 10)
self.assertTrue(out.children[0].metric.name == 'test')
self.assertTrue(out.root.name == 'test')

def test_mean_metric_instance(self):
metric = metrics.Metric('test')
out = metrics.mean(metric)
out = metrics.mean(metric, dim=10)
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Mean))
self.assertTrue(out.children[0].metric._kwargs['dim'] == 10)
self.assertTrue(out.children[0].metric.name == 'test')
self.assertTrue(out.root.name == 'test')

def test_std_metric_class(self):
metric = metrics.Metric
out = metrics.std(metric)('test')
out = metrics.std(dim=10, unbiased=False)(metric)('test')
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Std))
self.assertTrue(out.children[0].metric._kwargs['dim'] == 10)
self.assertTrue(not out.children[0].metric._unbiased)
self.assertTrue(out.children[0].metric.name == 'test_std')
self.assertTrue(out.root.name == 'test')

def test_std_metric_instance(self):
metric = metrics.Metric('test')
out = metrics.std(metric)
out = metrics.std(metric, dim=10, unbiased=False)
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Std))
self.assertTrue(out.children[0].metric._kwargs['dim'] == 10)
self.assertTrue(not out.children[0].metric._unbiased)
self.assertTrue(out.children[0].metric.name == 'test_std')
self.assertTrue(out.root.name == 'test')

def test_var_metric_class(self):
metric = metrics.Metric
out = metrics.var(dim=10, unbiased=False)(metric)('test')
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Var))
self.assertTrue(out.children[0].metric._kwargs['dim'] == 10)
self.assertTrue(not out.children[0].metric._unbiased)
self.assertTrue(out.children[0].metric.name == 'test_var')
self.assertTrue(out.root.name == 'test')

def test_var_metric_instance(self):
metric = metrics.Metric('test')
out = metrics.var(metric, dim=10, unbiased=False)
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Var))
self.assertTrue(out.children[0].metric._kwargs['dim'] == 10)
self.assertTrue(not out.children[0].metric._unbiased)
self.assertTrue(out.children[0].metric.name == 'test_var')
self.assertTrue(out.root.name == 'test')

def test_running_mean_metric_class(self):
metric = metrics.Metric
out = metrics.running_mean(batch_size=40, step_size=20)(metric)('test')
out = metrics.running_mean(batch_size=40, step_size=20, dim=10)(metric)('test')
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.RunningMean))
self.assertTrue(out.children[0].metric._batch_size == 40)
self.assertTrue(out.children[0].metric._step_size == 20)
self.assertTrue(out.children[0].metric._kwargs['dim'] == 10)
self.assertTrue(out.children[0].metric.name == 'running_test')
self.assertTrue(out.root.name == 'test')

def test_running_mean_metric_instance(self):
metric = metrics.Metric('test')
out = metrics.running_mean(batch_size=40, step_size=20)(metric)
out = metrics.running_mean(batch_size=40, step_size=20, dim=10)(metric)
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.RunningMean))
self.assertTrue(out.children[0].metric._batch_size == 40)
self.assertTrue(out.children[0].metric._step_size == 20)
self.assertTrue(out.children[0].metric._kwargs['dim'] == 10)
self.assertTrue(out.children[0].metric.name == 'running_test')
self.assertTrue(out.root.name == 'test')
14 changes: 14 additions & 0 deletions tests/metrics/test_wrappers.py
@@ -49,6 +49,20 @@ def test_eval_process_final(self):
self.assertTrue(self._to_dict.process_final('input') == {'val_test': 'process_final'})
self._metric.process_final.assert_called_once_with('input')

def test_eval_train(self):
self._to_dict.eval(data_key=torchbearer.TRAIN_DATA)
self.assertEqual(self._metric.eval.call_count, 1)

self.assertTrue(self._to_dict.process_final('input') == {'train_test': 'process_final'})
self._metric.process_final.assert_called_once_with('input')

def test_eval_test(self):
self._to_dict.eval(data_key=torchbearer.TEST_DATA)
self.assertEqual(self._metric.eval.call_count, 1)

self.assertTrue(self._to_dict.process_final('input') == {'test_test': 'process_final'})
self._metric.process_final.assert_called_once_with('input')

def test_reset(self):
self._to_dict.reset('test')
self._metric.reset.assert_called_once_with('test')
