Skip to content

Commit

Permalink
Test metrics and add support for non-factories in decorators (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored and MattPainter01 committed Jul 19, 2018
1 parent 354c776 commit 5f04b8e
Show file tree
Hide file tree
Showing 8 changed files with 349 additions and 65 deletions.
78 changes: 68 additions & 10 deletions tests/metrics/test_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,79 @@
from torch.autograd import Variable

import torchbearer
from torchbearer.metrics import Std, Metric, Mean, BatchLambda, EpochLambda
from torchbearer.metrics import Std, Metric, Mean, BatchLambda, EpochLambda, ToDict

import torch


class TestToDict(unittest.TestCase):
def setUp(self):
self._metric = Metric('test')
self._metric.train = Mock()
self._metric.eval = Mock()
self._metric.reset = Mock()
self._metric.process = Mock(return_value='process')
self._metric.process_final = Mock(return_value='process_final')

self._to_dict = ToDict(self._metric)

def test_train_process(self):
self._to_dict.train()
self._metric.train.assert_called_once()

self.assertTrue(self._to_dict.process('input') == {'test': 'process'})
self._metric.process.assert_called_once_with('input')

def test_train_process_final(self):
self._to_dict.train()
self._metric.train.assert_called_once()

self.assertTrue(self._to_dict.process_final('input') == {'test': 'process_final'})
self._metric.process_final.assert_called_once_with('input')

def test_eval_process(self):
self._to_dict.eval()
self._metric.eval.assert_called_once()

self.assertTrue(self._to_dict.process('input') == {'val_test': 'process'})
self._metric.process.assert_called_once_with('input')

def test_eval_process_final(self):
self._to_dict.eval()
self._metric.eval.assert_called_once()

self.assertTrue(self._to_dict.process_final('input') == {'val_test': 'process_final'})
self._metric.process_final.assert_called_once_with('input')

def test_reset(self):
self._to_dict.reset('test')
self._metric.reset.assert_called_once_with('test')


class TestStd(unittest.TestCase):
def setUp(self):
self._metric = Metric('test')
self._metric.process = Mock()
self._metric.process.side_effect = [torch.FloatTensor([0.1, 0.2, 0.3]),
self._metric.process.side_effect = [torch.zeros(torch.Size([])),
torch.FloatTensor([0.1, 0.2, 0.3]),
torch.FloatTensor([0.4, 0.5, 0.6]),
torch.FloatTensor([0.7, 0.8, 0.9])]
torch.FloatTensor([0.7, 0.8, 0.9]),
torch.ones(torch.Size([]))]

self._std = Std('test')
self._std.reset({})
self._target = 0.25819888974716
self._target = 0.31622776601684

def test_train(self):
self._std.train()
for i in range(3):
for i in range(5):
self._std.process(self._metric.process())
result = self._std.process_final({})
self.assertAlmostEqual(self._target, result)

def test_validate(self):
self._std.eval()
for i in range(3):
for i in range(5):
self._std.process(self._metric.process())
result = self._std.process_final({})
self.assertAlmostEqual(self._target, result)
Expand All @@ -41,24 +87,26 @@ class TestMean(unittest.TestCase):
def setUp(self):
self._metric = Metric('test')
self._metric.process = Mock()
self._metric.process.side_effect = [torch.FloatTensor([0.1, 0.2, 0.3]),
self._metric.process.side_effect = [torch.zeros(torch.Size([])),
torch.FloatTensor([0.1, 0.2, 0.3]),
torch.FloatTensor([0.4, 0.5, 0.6]),
torch.FloatTensor([0.7, 0.8, 0.9])]
torch.FloatTensor([0.7, 0.8, 0.9]),
torch.ones(torch.Size([]))]

self._mean = Mean('test')
self._mean.reset({})
self._target = 0.5

def test_train_dict(self):
self._mean.train()
for i in range(3):
for i in range(5):
self._mean.process(self._metric.process())
result = self._mean.process_final({})
self.assertAlmostEqual(self._target, result)

def test_validate_dict(self):
self._mean.eval()
for i in range(3):
for i in range(5):
self._mean.process(self._metric.process())
result = self._mean.process_final({})
self.assertAlmostEqual(self._target, result)
Expand Down Expand Up @@ -127,3 +175,13 @@ def test_validate(self):
self._metric_function.assert_called_once()
self.assertTrue(torch.eq(self._metric_function.call_args_list[0][0][1], torch.LongTensor([0, 1, 2, 3, 4])).all)
self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[0][0][0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))), 1e-12).all)

def test_not_running(self):
metric = EpochLambda('test', self._metric_function, running=False, step_size=6)
metric.reset({torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32})
metric.train()

for i in range(12):
metric.process(self._states[0])

self._metric_function.assert_not_called()
110 changes: 110 additions & 0 deletions tests/metrics/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import unittest
from unittest.mock import patch, Mock

import torchbearer.metrics as metrics
from torchbearer.metrics import default_for_key, lambda_metric, EpochLambda, BatchLambda


class TestDecorators(unittest.TestCase):
def test_default_for_key_class(self):
metric = metrics.Loss
metric = default_for_key('test')(metric)
self.assertTrue('test' in metrics.DEFAULT_METRICS)
self.assertTrue(metrics.DEFAULT_METRICS['test'].name == 'loss')
self.assertTrue(metric == metrics.Loss)

def test_default_for_key_metric(self):
metric = metrics.Loss()
metric = default_for_key('test')(metric)
self.assertTrue('test' in metrics.DEFAULT_METRICS)
self.assertTrue(metrics.DEFAULT_METRICS['test'].name == 'loss')
self.assertTrue(metric.name == 'loss')

def test_lambda_metric_epoch(self):
metric = 'test'
metric = lambda_metric('test', on_epoch=True)(metric)().build()
self.assertTrue(isinstance(metric, EpochLambda))
self.assertTrue(metric._final == 'test')

def test_lambda_metric_batch(self):
metric = 'test'
metric = lambda_metric('test')(metric)().build()
self.assertTrue(isinstance(metric, BatchLambda))
self.assertTrue(metric._metric_function == 'test')

def test_to_dict_metric(self):
metric = metrics.Metric
out = metrics.to_dict(metric)('test').build()
self.assertTrue(out.metric.name == 'test')
self.assertTrue(isinstance(out, metrics.ToDict))

@patch('torchbearer.metrics.MetricFactory.build')
def test_to_dict_metric_factory(self, build_mock):
metric = metrics.MetricFactory
build_mock.return_value = metrics.Metric('test')
out = metrics.to_dict(metric)().build()
self.assertTrue(out.metric.name == 'test')
self.assertTrue(isinstance(out, metrics.ToDict))
build_mock.assert_called_once()

def test_mean_metric(self):
metric = metrics.Metric
out = metrics.mean(metric)('test').build()
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Mean))
self.assertTrue(out.children[0].metric.name == 'test')
self.assertTrue(out.root.name == 'test')

@patch('torchbearer.metrics.MetricFactory.build')
def test_mean_metric_factory(self, build_mock):
metric = metrics.MetricFactory
build_mock.return_value = metrics.Metric('test')
out = metrics.mean(metric)().build()
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Mean))
self.assertTrue(out.children[0].metric.name == 'test')
self.assertTrue(out.root.name == 'test')

def test_std_metric(self):
metric = metrics.Metric
out = metrics.std(metric)('test').build()
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Std))
self.assertTrue(out.children[0].metric.name == 'test_std')
self.assertTrue(out.root.name == 'test')

@patch('torchbearer.metrics.MetricFactory.build')
def test_std_metric_factory(self, build_mock):
metric = metrics.MetricFactory
build_mock.return_value=metrics.Metric('test')
out = metrics.std(metric)().build()
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.Std))
self.assertTrue(out.children[0].metric.name == 'test_std')
self.assertTrue(out.root.name == 'test')

def test_running_mean_metric(self):
metric = metrics.Metric
out = metrics.running_mean(batch_size=40, step_size=20)(metric)('test').build()
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.RunningMean))
self.assertTrue(out.children[0].metric._batch_size == 40)
self.assertTrue(out.children[0].metric._step_size == 20)
self.assertTrue(out.children[0].metric.name == 'running_test')
self.assertTrue(out.root.name == 'test')

@patch('torchbearer.metrics.MetricFactory.build')
def test_running_mean_metric_factory(self, build_mock):
metric = metrics.MetricFactory
build_mock.return_value=metrics.Metric('test')
out = metrics.running_mean(metric)().build()
self.assertTrue(isinstance(out, metrics.MetricTree))
self.assertTrue(isinstance(out.children[0], metrics.ToDict))
self.assertTrue(isinstance(out.children[0].metric, metrics.RunningMean))
self.assertTrue(out.children[0].metric.name == 'running_test')
self.assertTrue(out.root.name == 'test')
128 changes: 127 additions & 1 deletion tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,98 @@
import unittest
from unittest.mock import Mock

from torchbearer.metrics import MetricList, Metric
from torchbearer.metrics import MetricFactory, MetricList, Metric, MetricTree, AdvancedMetric


class TestMetricFactory(unittest.TestCase):
def test_empty_build(self):
factory = MetricFactory()
self.assertTrue(factory.build() is None)


class TestMetricTree(unittest.TestCase):
def test_process(self):
root = Metric('test')
root.process = Mock(return_value='test')
leaf1 = Metric('test')
leaf1.process = Mock(return_value={'test': 10})
leaf2 = Metric('test')
leaf2.process = Mock(return_value=None)

tree = MetricTree(root)
tree.add_child(leaf1)
tree.add_child(leaf2)

self.assertTrue(tree.process('args') == {'test': 10})

root.process.assert_called_once_with('args')
leaf1.process.assert_called_once_with('test')
leaf2.process.assert_called_once_with('test')

def test_process_final(self):
root = Metric('test')
root.process_final = Mock(return_value='test')
leaf1 = Metric('test')
leaf1.process_final = Mock(return_value={'test': 10})
leaf2 = Metric('test')
leaf2.process_final = Mock(return_value=None)

tree = MetricTree(root)
tree.add_child(leaf1)
tree.add_child(leaf2)

self.assertTrue(tree.process_final('args') == {'test': 10})

root.process_final.assert_called_once_with('args')
leaf1.process_final.assert_called_once_with('test')
leaf2.process_final.assert_called_once_with('test')

def test_train(self):
root = Metric('test')
root.train = Mock()
leaf = Metric('test')
leaf.train = Mock()

tree = MetricTree(root)
tree.add_child(leaf)

tree.train()
root.train.assert_called_once()
leaf.train.assert_called_once()

def test_eval(self):
root = Metric('test')
root.eval = Mock()
leaf = Metric('test')
leaf.eval = Mock()

tree = MetricTree(root)
tree.add_child(leaf)

tree.eval()
root.eval.assert_called_once()
leaf.eval.assert_called_once()

def test_reset(self):
root = Metric('test')
root.reset = Mock()
leaf = Metric('test')
leaf.reset = Mock()

tree = MetricTree(root)
tree.add_child(leaf)

tree.reset({})
root.reset.assert_called_once_with({})
leaf.reset.assert_called_once_with({})


class TestMetricList(unittest.TestCase):
def test_list_in_list(self):
metric = MetricList(['acc', MetricList(['loss'])])
self.assertTrue(metric.metric_list[0].name == 'acc')
self.assertTrue(metric.metric_list[1].name == 'loss')

def test_default_acc(self):
metric = MetricList(['acc'])
self.assertTrue(metric.metric_list[0].name == 'acc', msg='acc not in: ' + str(metric.metric_list))
Expand All @@ -13,6 +101,10 @@ def test_default_loss(self):
metric = MetricList(['loss'])
self.assertTrue(metric.metric_list[0].name == 'loss', msg='loss not in: ' + str(metric.metric_list))

def test_default_epoch(self):
metric = MetricList(['epoch'])
self.assertTrue(metric.metric_list[0].name == 'epoch', msg='loss not in: ' + str(metric.metric_list))

def test_process(self):
my_mock = Metric('test')
my_mock.process = Mock(return_value={'test': -1})
Expand Down Expand Up @@ -49,3 +141,37 @@ def test_reset(self):
metric = MetricList([my_mock])
metric.reset({'state': -1})
my_mock.reset.assert_called_once_with({'state': -1})


class TestAdvancedMetric(unittest.TestCase):
def test_empty_methods(self):
metric = AdvancedMetric('test')

self.assertTrue(metric.process_train() is None)
self.assertTrue(metric.process_final_train() is None)
self.assertTrue(metric.process_validate() is None)
self.assertTrue(metric.process_final_validate() is None)

def test_train(self):
metric = AdvancedMetric('test')
metric.process_train = Mock()
metric.process_final_train = Mock()

metric.train()
metric.process('testing')
metric.process_train.assert_called_once_with('testing')

metric.process_final('testing')
metric.process_final_train.assert_called_once_with('testing')

def test_eval(self):
metric = AdvancedMetric('test')
metric.process_validate = Mock()
metric.process_final_validate = Mock()

metric.eval()
metric.process('testing')
metric.process_validate.assert_called_once_with('testing')

metric.process_final('testing')
metric.process_final_validate.assert_called_once_with('testing')

0 comments on commit 5f04b8e

Please sign in to comment.