
Update docs for metrics (#226)
ethanwharris authored and MattPainter01 committed Jul 20, 2018
1 parent 8735eee commit 67af6c0
Showing 12 changed files with 592 additions and 434 deletions.
3 changes: 2 additions & 1 deletion docs/conf.py
@@ -32,6 +32,7 @@ def __getattr__(cls, name):
 
 sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
 
+sys.path.insert(0, os.path.abspath('.'))
 sys.path.insert(0, os.path.abspath('../'))


@@ -44,7 +45,7 @@ def __getattr__(cls, name):
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode']
+extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx']
 
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
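The two changes above give Sphinx what it needs for the metrics docs: the extra `sys.path` entries make the project importable for `autodoc`, and `sphinx.ext.intersphinx` lets docstrings cross-reference external documentation. A minimal sketch of how intersphinx is typically completed in `conf.py`; the `intersphinx_mapping` targets below are illustrative assumptions, not part of this commit:

# conf.py (sketch) -- the intersphinx_mapping targets are assumed here,
# not taken from this commit.
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx']

# Each entry maps a prefix to an external objects.inv inventory, so a role
# like :class:`torch.Tensor` can resolve to the external documentation.
intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),        # assumed target
    'torch': ('https://pytorch.org/docs/stable/', None),  # assumed target
}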
10 changes: 7 additions & 3 deletions docs/index.rst
@@ -4,21 +4,25 @@
    contain the root `toctree` directive.
 
 Welcome to torchbearer's documentation!
-====================================
+=======================================
 
 .. toctree::
    :glob:
    :maxdepth: 1
    :caption: Examples
 
-   examples/*
+   examples/quickstart
+   examples/vae
+   examples/gan
 
 .. toctree::
    :glob:
    :maxdepth: 1
    :caption: Package Reference
 
-   code/*
+   code/main
+   code/callbacks
+   code/metrics
 
 
 Indices and tables
140 changes: 29 additions & 111 deletions tests/metrics/test_aggregators.py
@@ -2,58 +2,11 @@
 
 from unittest.mock import Mock, call
 
-from torch.autograd import Variable
-
-import torchbearer
-from torchbearer.metrics import Std, Metric, Mean, BatchLambda, EpochLambda, ToDict
+from torchbearer.metrics import RunningMean, Metric, RunningMetric, Mean, Std
 
 import torch
 
 
-class TestToDict(unittest.TestCase):
-    def setUp(self):
-        self._metric = Metric('test')
-        self._metric.train = Mock()
-        self._metric.eval = Mock()
-        self._metric.reset = Mock()
-        self._metric.process = Mock(return_value='process')
-        self._metric.process_final = Mock(return_value='process_final')
-
-        self._to_dict = ToDict(self._metric)
-
-    def test_train_process(self):
-        self._to_dict.train()
-        self._metric.train.assert_called_once()
-
-        self.assertTrue(self._to_dict.process('input') == {'test': 'process'})
-        self._metric.process.assert_called_once_with('input')
-
-    def test_train_process_final(self):
-        self._to_dict.train()
-        self._metric.train.assert_called_once()
-
-        self.assertTrue(self._to_dict.process_final('input') == {'test': 'process_final'})
-        self._metric.process_final.assert_called_once_with('input')
-
-    def test_eval_process(self):
-        self._to_dict.eval()
-        self._metric.eval.assert_called_once()
-
-        self.assertTrue(self._to_dict.process('input') == {'val_test': 'process'})
-        self._metric.process.assert_called_once_with('input')
-
-    def test_eval_process_final(self):
-        self._to_dict.eval()
-        self._metric.eval.assert_called_once()
-
-        self.assertTrue(self._to_dict.process_final('input') == {'val_test': 'process_final'})
-        self._metric.process_final.assert_called_once_with('input')
-
-    def test_reset(self):
-        self._to_dict.reset('test')
-        self._metric.reset.assert_called_once_with('test')
-
 class TestStd(unittest.TestCase):
     def setUp(self):
         self._metric = Metric('test')
@@ -112,76 +65,41 @@ def test_validate_dict(self):
         self.assertAlmostEqual(self._target, result)
 
 
-class TestBatchLambda(unittest.TestCase):
+class TestRunningMetric(unittest.TestCase):
     def setUp(self):
-        self._metric_function = Mock(return_value='test')
-        self._metric = BatchLambda('test', self._metric_function)
-        self._states = [{torchbearer.Y_TRUE: Variable(torch.FloatTensor([1])), torchbearer.Y_PRED: Variable(torch.FloatTensor([2]))},
-                        {torchbearer.Y_TRUE: Variable(torch.FloatTensor([3])), torchbearer.Y_PRED: Variable(torch.FloatTensor([4]))},
-                        {torchbearer.Y_TRUE: Variable(torch.FloatTensor([5])), torchbearer.Y_PRED: Variable(torch.FloatTensor([6]))}]
+        self._metric = RunningMetric('test', batch_size=5, step_size=5)
+        self._metric.reset({})
+        self._metric._process_train = Mock(return_value=3)
+        self._metric._step = Mock(return_value='output')
 
-    def test_train(self):
+    def test_train_called_with_state(self):
         self._metric.train()
-        calls = []
-        for i in range(len(self._states)):
-            self._metric.process(self._states[i])
-            calls.append(call(self._states[i][torchbearer.Y_PRED].data, self._states[i][torchbearer.Y_TRUE].data))
-        self._metric_function.assert_has_calls(calls)
+        self._metric.process({'test': -1})
+        self._metric._process_train.assert_called_with({'test': -1})
 
-    def test_validate(self):
-        self._metric.eval()
-        calls = []
-        for i in range(len(self._states)):
-            self._metric.process(self._states[i])
-            calls.append(call(self._states[i][torchbearer.Y_PRED].data, self._states[i][torchbearer.Y_TRUE].data))
-        self._metric_function.assert_has_calls(calls)
-
-
-class TestEpochLambda(unittest.TestCase):
-    def setUp(self):
-        self._metric_function = Mock(return_value='test')
-        self._metric = EpochLambda('test', self._metric_function, step_size=3)
-        self._metric.reset({torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32})
-        self._states = [{torchbearer.BATCH: 0, torchbearer.Y_TRUE: torch.LongTensor([0]), torchbearer.Y_PRED: torch.FloatTensor([0.0]), torchbearer.DEVICE: 'cpu'},
-                        {torchbearer.BATCH: 1, torchbearer.Y_TRUE: torch.LongTensor([1]), torchbearer.Y_PRED: torch.FloatTensor([0.1]), torchbearer.DEVICE: 'cpu'},
-                        {torchbearer.BATCH: 2, torchbearer.Y_TRUE: torch.LongTensor([2]), torchbearer.Y_PRED: torch.FloatTensor([0.2]), torchbearer.DEVICE: 'cpu'},
-                        {torchbearer.BATCH: 3, torchbearer.Y_TRUE: torch.LongTensor([3]), torchbearer.Y_PRED: torch.FloatTensor([0.3]), torchbearer.DEVICE: 'cpu'},
-                        {torchbearer.BATCH: 4, torchbearer.Y_TRUE: torch.LongTensor([4]), torchbearer.Y_PRED: torch.FloatTensor([0.4]), torchbearer.DEVICE: 'cpu'}]
 
-    def test_train(self):
+    def test_cache_one_step(self):
         self._metric.train()
-        calls = [[torch.FloatTensor([0.0]), torch.LongTensor([0])],
-                 [torch.FloatTensor([0.0, 0.1, 0.2, 0.3]), torch.LongTensor([0, 1, 2, 3])]]
-        for i in range(len(self._states)):
-            self._metric.process(self._states[i])
-        self.assertEqual(2, len(self._metric_function.call_args_list))
-        for i in range(len(self._metric_function.call_args_list)):
-            self.assertTrue(torch.eq(self._metric_function.call_args_list[i][0][0], calls[i][0]).all)
-            self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[i][0][1], -calls[i][1])), 1e-12).all)
-        self._metric_function.reset_mock()
-        self._metric.process_final({})
 
-        self._metric_function.assert_called_once()
-        self.assertTrue(torch.eq(self._metric_function.call_args_list[0][0][1], torch.LongTensor([0, 1, 2, 3, 4])).all)
-        self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[0][0][0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))), 1e-12).all)
+        for i in range(6):
+            self._metric.process({})
+        self._metric._step.assert_has_calls([call([3]), call([3, 3, 3, 3, 3])])
 
-    def test_validate(self):
-        self._metric.eval()
-        for i in range(len(self._states)):
-            self._metric.process(self._states[i])
-        self._metric_function.assert_not_called()
-        self._metric.process_final_validate({})
+    def test_empty_methods(self):
+        metric = RunningMetric('test')
+        self.assertTrue(metric._step(['test']) is None)
+        self.assertTrue(metric._process_train(['test']) is None)
 
-        self._metric_function.assert_called_once()
-        self.assertTrue(torch.eq(self._metric_function.call_args_list[0][0][1], torch.LongTensor([0, 1, 2, 3, 4])).all)
-        self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[0][0][0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))), 1e-12).all)
 
-    def test_not_running(self):
-        metric = EpochLambda('test', self._metric_function, running=False, step_size=6)
-        metric.reset({torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32})
-        metric.train()
+class TestRunningMean(unittest.TestCase):
+    def setUp(self):
+        self._metric = Metric('test')
+        self._mean = RunningMean('test')
+        self._cache = [1.0, 1.5, 2.0]
+        self._target = 1.5
 
-        for i in range(12):
-            metric.process(self._states[0])
+    def test_train(self):
+        result = self._mean._process_train(torch.FloatTensor([1.0, 1.5, 2.0]))
+        self.assertAlmostEqual(self._target, result, 3, 0.002)
 
-        self._metric_function.assert_not_called()
+    def test_step(self):
+        result = self._mean._step(self._cache)
+        self.assertEqual(self._target, result)
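A note for readers of the new tests: the assertions pin down a cache-and-step contract. `RunningMetric` appears to cache one value per `process` call (via `_process_train`), keep at most `batch_size` of them, and call `_step(cache)` every `step_size` calls, which is why the mock sees `call([3])` on the first process and `call([3, 3, 3, 3, 3])` once six calls have filled the five-slot cache. The sketch below is inferred from these tests only, not copied from torchbearer's implementation:

# Sketch of the running-metric pattern the tests above exercise.
# Inferred from the assertions; not the library's actual source.
class RunningMetricSketch:
    def __init__(self, name, batch_size=50, step_size=10):
        self.name = name
        self._batch_size = batch_size  # maximum cache length
        self._step_size = step_size    # emit a running value every N calls
        self._cache = []
        self._calls = 0

    def _process_train(self, state):
        return None  # subclasses reduce one batch to a cached value

    def _step(self, cache):
        return None  # subclasses reduce the cache to the running value

    def process(self, state):
        self._cache.append(self._process_train(state))
        if len(self._cache) > self._batch_size:
            self._cache.pop(0)  # discard the oldest cached value
        result = None
        if self._calls % self._step_size == 0:
            result = self._step(list(self._cache))
        self._calls += 1
        return result


# RunningMean then caches a per-batch mean and averages the cache, matching
# test_train (mean of [1.0, 1.5, 2.0] is 1.5) and test_step above.
class RunningMeanSketch(RunningMetricSketch):
    def _process_train(self, data):
        return data.mean().item()  # data is a torch tensor for one batch

    def _step(self, cache):
        return sum(cache) / len(cache)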
47 changes: 0 additions & 47 deletions tests/metrics/test_running.py

This file was deleted.

129 changes: 129 additions & 0 deletions tests/metrics/test_wrappers.py
@@ -0,0 +1,129 @@
import unittest

from unittest.mock import Mock, call

from torch.autograd import Variable

import torchbearer
from torchbearer.metrics import Std, Metric, Mean, BatchLambda, EpochLambda, ToDict

import torch


class TestToDict(unittest.TestCase):
    def setUp(self):
        self._metric = Metric('test')
        self._metric.train = Mock()
        self._metric.eval = Mock()
        self._metric.reset = Mock()
        self._metric.process = Mock(return_value='process')
        self._metric.process_final = Mock(return_value='process_final')

        self._to_dict = ToDict(self._metric)

    def test_train_process(self):
        self._to_dict.train()
        self._metric.train.assert_called_once()

        self.assertTrue(self._to_dict.process('input') == {'test': 'process'})
        self._metric.process.assert_called_once_with('input')

    def test_train_process_final(self):
        self._to_dict.train()
        self._metric.train.assert_called_once()

        self.assertTrue(self._to_dict.process_final('input') == {'test': 'process_final'})
        self._metric.process_final.assert_called_once_with('input')

    def test_eval_process(self):
        self._to_dict.eval()
        self._metric.eval.assert_called_once()

        self.assertTrue(self._to_dict.process('input') == {'val_test': 'process'})
        self._metric.process.assert_called_once_with('input')

    def test_eval_process_final(self):
        self._to_dict.eval()
        self._metric.eval.assert_called_once()

        self.assertTrue(self._to_dict.process_final('input') == {'val_test': 'process_final'})
        self._metric.process_final.assert_called_once_with('input')

    def test_reset(self):
        self._to_dict.reset('test')
        self._metric.reset.assert_called_once_with('test')


class TestBatchLambda(unittest.TestCase):
    def setUp(self):
        self._metric_function = Mock(return_value='test')
        self._metric = BatchLambda('test', self._metric_function)
        self._states = [{torchbearer.Y_TRUE: Variable(torch.FloatTensor([1])), torchbearer.Y_PRED: Variable(torch.FloatTensor([2]))},
                        {torchbearer.Y_TRUE: Variable(torch.FloatTensor([3])), torchbearer.Y_PRED: Variable(torch.FloatTensor([4]))},
                        {torchbearer.Y_TRUE: Variable(torch.FloatTensor([5])), torchbearer.Y_PRED: Variable(torch.FloatTensor([6]))}]

    def test_train(self):
        self._metric.train()
        calls = []
        for i in range(len(self._states)):
            self._metric.process(self._states[i])
            calls.append(call(self._states[i][torchbearer.Y_PRED].data, self._states[i][torchbearer.Y_TRUE].data))
        self._metric_function.assert_has_calls(calls)

    def test_validate(self):
        self._metric.eval()
        calls = []
        for i in range(len(self._states)):
            self._metric.process(self._states[i])
            calls.append(call(self._states[i][torchbearer.Y_PRED].data, self._states[i][torchbearer.Y_TRUE].data))
        self._metric_function.assert_has_calls(calls)


class TestEpochLambda(unittest.TestCase):
    def setUp(self):
        self._metric_function = Mock(return_value='test')
        self._metric = EpochLambda('test', self._metric_function, step_size=3)
        self._metric.reset({torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32})
        self._states = [{torchbearer.BATCH: 0, torchbearer.Y_TRUE: torch.LongTensor([0]), torchbearer.Y_PRED: torch.FloatTensor([0.0]), torchbearer.DEVICE: 'cpu'},
                        {torchbearer.BATCH: 1, torchbearer.Y_TRUE: torch.LongTensor([1]), torchbearer.Y_PRED: torch.FloatTensor([0.1]), torchbearer.DEVICE: 'cpu'},
                        {torchbearer.BATCH: 2, torchbearer.Y_TRUE: torch.LongTensor([2]), torchbearer.Y_PRED: torch.FloatTensor([0.2]), torchbearer.DEVICE: 'cpu'},
                        {torchbearer.BATCH: 3, torchbearer.Y_TRUE: torch.LongTensor([3]), torchbearer.Y_PRED: torch.FloatTensor([0.3]), torchbearer.DEVICE: 'cpu'},
                        {torchbearer.BATCH: 4, torchbearer.Y_TRUE: torch.LongTensor([4]), torchbearer.Y_PRED: torch.FloatTensor([0.4]), torchbearer.DEVICE: 'cpu'}]

    def test_train(self):
        self._metric.train()
        calls = [[torch.FloatTensor([0.0]), torch.LongTensor([0])],
                 [torch.FloatTensor([0.0, 0.1, 0.2, 0.3]), torch.LongTensor([0, 1, 2, 3])]]
        for i in range(len(self._states)):
            self._metric.process(self._states[i])
        self.assertEqual(2, len(self._metric_function.call_args_list))
        for i in range(len(self._metric_function.call_args_list)):
            self.assertTrue(torch.eq(self._metric_function.call_args_list[i][0][0], calls[i][0]).all)
            self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[i][0][1], -calls[i][1])), 1e-12).all)
        self._metric_function.reset_mock()
        self._metric.process_final({})

        self._metric_function.assert_called_once()
        self.assertTrue(torch.eq(self._metric_function.call_args_list[0][0][1], torch.LongTensor([0, 1, 2, 3, 4])).all)
        self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[0][0][0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))), 1e-12).all)

    def test_validate(self):
        self._metric.eval()
        for i in range(len(self._states)):
            self._metric.process(self._states[i])
        self._metric_function.assert_not_called()
        self._metric.process_final_validate({})

        self._metric_function.assert_called_once()
        self.assertTrue(torch.eq(self._metric_function.call_args_list[0][0][1], torch.LongTensor([0, 1, 2, 3, 4])).all)
        self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[0][0][0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))), 1e-12).all)

    def test_not_running(self):
        metric = EpochLambda('test', self._metric_function, running=False, step_size=6)
        metric.reset({torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32})
        metric.train()

        for i in range(12):
            metric.process(self._states[0])

        self._metric_function.assert_not_called()
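Taken together, these tests document the wrapper contract: `BatchLambda` calls the wrapped function with `(y_pred.data, y_true.data)` for each batch; `EpochLambda` accumulates predictions and targets, re-running the function every `step_size` training batches and once more when the epoch ends; and `ToDict` names the result, prefixing `val_` in eval mode. A usage sketch assembled only from the constructor calls and state keys visible above; the metric functions themselves are illustrative assumptions:

# Usage sketch for the wrapper metrics covered by these tests.
# The lambdas are made up for illustration; constructor signatures and
# state keys mirror the tests above.
import torch
import torchbearer
from torchbearer.metrics import BatchLambda, EpochLambda, ToDict

# BatchLambda: the function sees one batch of (y_pred, y_true) at a time.
mae = BatchLambda('mae', lambda y_pred, y_true: (y_pred - y_true).abs().mean())

# EpochLambda: the function sees everything accumulated so far,
# re-evaluated every `step_size` training batches.
acc = EpochLambda('acc',
                  lambda y_pred, y_true: (y_pred.round().long() == y_true).float().mean(),
                  step_size=3)

# ToDict: wraps the output as {'mae': value} in train mode and
# {'val_mae': value} in eval mode, as test_eval_process shows.
named = ToDict(mae)
named.train()
state = {torchbearer.Y_PRED: torch.FloatTensor([2.0]), torchbearer.Y_TRUE: torch.FloatTensor([1.0])}
print(named.process(state))  # -> {'mae': tensor(1.)}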
