Skip to content

Commit

Permalink
Feature/metrics (#214)
Browse files Browse the repository at this point in the history
* Update metric API to use decorators

* Update CHANGELOG.md

* Rename wrappers to aggregators
  • Loading branch information
ethanwharris committed Jul 19, 2018
1 parent 39c782d commit 354c776
Show file tree
Hide file tree
Showing 14 changed files with 377 additions and 223 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added a decorator API for metrics which allows decorators to be used for metric construction
- Added a default_for_key decorator which can be used to associate a string with a given metric in metric lists
### Changed
- Changed the API for running metrics and aggregators to no longer wrap a metric but instead receive input
### Deprecated
### Removed
### Fixed

## [0.1.3] - 2018-07-17
### Added
- Added a flag (step_on_batch) to the LR Scheduler callbacks which allows for step() to be called on each iteration instead of each epoch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ def setUp(self):
torch.FloatTensor([0.4, 0.5, 0.6]),
torch.FloatTensor([0.7, 0.8, 0.9])]

self._std = Std(self._metric)
self._std = Std('test')
self._std.reset({})
self._target = 0.25819888974716

def test_train(self):
self._std.train()
for i in range(3):
self._std.process({})
self._std.process(self._metric.process())
result = self._std.process_final({})
self.assertAlmostEqual(self._target, result)

def test_validate(self):
self._std.eval()
for i in range(3):
self._std.process({})
self._std.process(self._metric.process())
result = self._std.process_final({})
self.assertAlmostEqual(self._target, result)

Expand All @@ -45,21 +45,21 @@ def setUp(self):
torch.FloatTensor([0.4, 0.5, 0.6]),
torch.FloatTensor([0.7, 0.8, 0.9])]

self._mean = Mean(self._metric)
self._mean = Mean('test')
self._mean.reset({})
self._target = 0.5

def test_train_dict(self):
self._mean.train()
for i in range(3):
self._mean.process({})
self._mean.process(self._metric.process())
result = self._mean.process_final({})
self.assertAlmostEqual(self._target, result)

def test_validate_dict(self):
self._mean.eval()
for i in range(3):
self._mean.process({})
self._mean.process(self._metric.process())
result = self._mean.process_final({})
self.assertAlmostEqual(self._target, result)

Expand Down
8 changes: 4 additions & 4 deletions tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@
class TestMetricList(unittest.TestCase):
def test_default_acc(self):
metric = MetricList(['acc'])
self.assertTrue(metric.metric_list[0].name == 'running_acc', msg='running_acc not in: ' + str(metric.metric_list))
self.assertTrue(metric.metric_list[0].name == 'acc', msg='acc not in: ' + str(metric.metric_list))

def test_default_loss(self):
metric = MetricList(['loss'])
self.assertTrue(metric.metric_list[0].name == 'running_loss', msg='running_loss not in: ' + str(metric.metric_list))
self.assertTrue(metric.metric_list[0].name == 'loss', msg='loss not in: ' + str(metric.metric_list))

def test_process(self):
my_mock = Metric('test')
my_mock.process = Mock(return_value=-1)
my_mock.process = Mock(return_value={'test': -1})
metric = MetricList([my_mock])
result = metric.process({'state': -1})
self.assertEqual({'test': -1}, result)
my_mock.process.assert_called_once_with({'state': -1})

def test_process_final(self):
my_mock = Metric('test')
my_mock.process_final = Mock(return_value=-1)
my_mock.process_final = Mock(return_value={'test': -1})
metric = MetricList([my_mock])
result = metric.process_final({'state': -1})
self.assertEqual({'test': -1}, result)
Expand Down
9 changes: 8 additions & 1 deletion tests/metrics/test_roc_auc_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from unittest.mock import Mock, patch

import torchbearer
from torchbearer.metrics import RocAucScore
from torchbearer.metrics import RocAucScore, MetricList

import torch

Expand Down Expand Up @@ -44,3 +44,10 @@ def test_non_one_hot(self, mock_sklearn_metrics):
mock_sklearn_metrics.roc_auc_score.call_args_list[0][0][1])
except AssertionError:
self.fail('y_pred not correctly passed to sklearn')

def test_default_roc(self):
mlist = MetricList(['roc_auc'])
self.assertTrue(mlist.metric_list[0].name == 'roc_auc_score')

mlist = MetricList(['roc_auc_score'])
self.assertTrue(mlist.metric_list[0].name == 'roc_auc_score')
6 changes: 2 additions & 4 deletions tests/metrics/test_running.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,12 @@ def test_cache_one_step(self):
class TestRunningMean(unittest.TestCase):
def setUp(self):
self._metric = Metric('test')
self._metric.process = Mock(return_value=torch.FloatTensor([1.0, 1.5, 2.0]))
self._mean = RunningMean(self._metric)
self._mean = RunningMean('test')
self._cache = [1.0, 1.5, 2.0]
self._target = 1.5

def test_train(self):
result = self._mean._process_train({'test': -1})
self._metric.process.assert_called_with({'test': -1})
result = self._mean._process_train(torch.FloatTensor([1.0, 1.5, 2.0]))
self.assertAlmostEqual(self._target, result, 3, 0.002)

def test_step(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_torchbearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class TestTorchbearer(TestCase):

def test_main_loop_metrics(self):
metric = Metric('test')
metric.process = Mock(return_value=0)
metric.process_final = Mock(return_value=0)
metric.process = Mock(return_value={'test': 0})
metric.process_final = Mock(return_value={'test': 0})
metric.reset = Mock(return_value=None)
metric_list = MetricList([metric])

Expand Down Expand Up @@ -404,8 +404,8 @@ def test_main_loop_callback_calls(self):

def test_test_loop_metrics(self):
metric = Metric('test')
metric.process = Mock(return_value=0)
metric.process_final = Mock(return_value=0)
metric.process = Mock(return_value={'test': 0})
metric.process_final = Mock(return_value={'test': 0})
metric.reset = Mock(return_value=None)
metric_list = MetricList([metric])

Expand Down
23 changes: 13 additions & 10 deletions torchbearer/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
"""
Base Classes
------------------------------------
.. automodule:: torchbearer.metrics.metrics
:members:
:undoc-members:
Metric Wrappers
Decorators - The Decorator API
------------------------------------
.. automodule:: torchbearer.metrics.wrappers
.. automodule:: torchbearer.metrics.decorators
:members:
:undoc-members:
.. automodule:: torchbearer.metrics.running
Metric Aggregators
------------------------------------
.. automodule:: torchbearer.metrics.wrappers
:members:
:undoc-members:
Default Metrics
------------------------------------
.. automodule:: torchbearer.metrics.defaults
.. automodule:: torchbearer.metrics.running
:members:
:undoc-members:
Valued Metrics
Base Metrics
------------------------------------
.. automodule:: torchbearer.metrics.primitives
Expand All @@ -34,8 +37,8 @@
"""

from .metrics import *
from .defaults import *
from .wrappers import *
from .aggregators import *
from .running import *
from .decorators import *
from .roc_auc_score import *
from .primitives import *
124 changes: 42 additions & 82 deletions torchbearer/metrics/wrappers.py → torchbearer/metrics/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,101 +4,68 @@
import torch


def std(metric):
"""Utility function to wrap the given metric in an :class:`Std`.
:param metric: The metric to wrap.
:return: Std -- A standard deviation metric which wraps the input.
"""
return Std(metric)


def mean(metric):
"""Utility function to wrap the given metric in an :class:`Mean`.
:param metric: The metric to wrap.
:return: Mean -- A mean metric which wraps the input.
"""
return Mean(metric)


def statistics(metric):
"""Utility function to wrap the given metric in a set of default statistics.
:param metric: The metric to wrap.
:return: MetricList -- A metric list containing a mean and std.
"""
return metrics.MetricList([mean(metric), std(metric)])


stats = statistics
class ToDict(metrics.AdvancedMetric):
def __init__(self, metric):
super(ToDict, self).__init__(metric.name)

self.metric = metric

class Wrapper(metrics.Metric):
"""Basic metric wrapper class which masks the processing methods.
"""
def process_train(self, *args):
val = self.metric.process(*args)
if val is not None:
return {self.metric.name: val}

def __init__(self, metric, postfix):
"""Wrap the given metric and append the given string to the metric name.
def process_validate(self, *args):
val = self.metric.process(*args)
if val is not None:
return {'val_' + self.metric.name: val}

:param metric: The metric to wrap.
:type metric: Metric
:param postfix: String to add to the metric name.
:type postfix: str
def process_final_train(self, *args):
val = self.metric.process_final(*args)
if val is not None:
return {self.metric.name: val}

"""
super().__init__(metric.name + postfix)
self._metric = metric
def process_final_validate(self, *args):
val = self.metric.process_final(*args)
if val is not None:
return {'val_' + self.metric.name: val}

def eval(self):
"""Call eval on the underlying metric.
"""
super().eval()
self._metric.eval()
self.metric.eval()

def train(self):
"""Call train on the underlying metric.
"""
super().train()
self._metric.train()
self.metric.train()

def reset(self, state):
"""Call reset on the underlying metric.
"""
super().reset(state)
self._metric.reset(state)
self.metric.reset(state)


class Std(Wrapper):
class Std(metrics.Metric):
"""Metric wrapper which calculates the standard deviation of process outputs between calls to reset.
"""

def __init__(self, metric):
"""Wrap the given metric.
:param metric: The metric to wrap.
:type metric: Metric
"""
super().__init__(metric, '_std')
def __init__(self, name):
super(Std, self).__init__(name)

def process(self, state):
def process(self, data):
"""Process the wrapped metric and compute values required for the std.
:param state: The model state.
:type state: dict
"""
result = self._metric.process(state)
self._sum += result.sum().item()
self._sum_sq += result.pow(2).sum().item()
self._sum += data.sum().item()
self._sum_sq += data.pow(2).sum().item()

if result.size() == torch.Size([]):
if data.size() == torch.Size([]):
self._count += 1
else:
self._count += result.size(0)
self._count += data.size(0)

def process_final(self, state):
def process_final(self, data):
"""Compute and return the final standard deviation.
:param state: The model state.
Expand All @@ -123,35 +90,28 @@ def reset(self, state):
self._count = 0


class Mean(Wrapper):
class Mean(metrics.Metric):
"""Metric wrapper which calculates the mean value of a series of observations between reset calls.
"""

def __init__(self, metric):
"""Wrap the given metric.
:param metric: The metric to wrap.
:type metric: Metric
def __init__(self, name):
super(Mean, self).__init__(name)

"""
super().__init__(metric, '')

def process(self, state):
def process(self, data):
"""Compute the metric value and add it to the rolling sum.
:param state: The model state.
:type state: dict
"""
result = self._metric.process(state)
self._sum += result.sum().item()
self._sum += data.sum().item()

if result.size() == torch.Size([]):
if data.size() == torch.Size([]):
self._count += 1
else:
self._count += result.size(0)
self._count += data.size(0)

def process_final(self, state):
def process_final(self, data):
"""Compute and return the mean of all metric values since the last call to reset.
:param state: The model state.
Expand Down Expand Up @@ -185,7 +145,7 @@ def __init__(self, name, metric_function):
:param metric_function: A metric function('y_pred', 'y_true') to wrap.
"""
super().__init__(name)
super(BatchLambda, self).__init__(name)
self._metric_function = metric_function

def process(self, state):
Expand Down Expand Up @@ -217,7 +177,7 @@ def __init__(self, name, metric_function, running=True, step_size=50):
:type step_size: int
"""
super().__init__(name)
super(EpochLambda, self).__init__(name)
self._step = metric_function
self._final = metric_function
self._step_size = step_size
Expand Down

0 comments on commit 354c776

Please sign in to comment.