
Commit

Update docs for metrics (#227)
* Update docs for metrics

* Fix epoch metric
ethanwharris committed Jul 20, 2018
1 parent da8a481 commit 3027854
Showing 5 changed files with 103 additions and 69 deletions.
7 changes: 4 additions & 3 deletions torchbearer/metrics/__init__.py
@@ -13,13 +13,16 @@
:members:
:undoc-members:
Metric Wrappers / Aggregators
Metric Wrappers
------------------------------------
.. automodule:: torchbearer.metrics.wrappers
:members:
:undoc-members:
Metric Aggregators
------------------------------------
.. automodule:: torchbearer.metrics.aggregators
:members:
:undoc-members:
@@ -29,11 +32,9 @@
.. automodule:: torchbearer.metrics.primitives
:members:
:undoc-members:
.. automodule:: torchbearer.metrics.roc_auc_score
:members:
:undoc-members:
"""

from .metrics import *
78 changes: 41 additions & 37 deletions torchbearer/metrics/aggregators.py
@@ -1,3 +1,9 @@
"""
Aggregators are a special kind of :class:`.Metric` which take as input the output from a previous metric or metrics.
As a result, via a :class:`.MetricTree`, a series of aggregators can collect statistics such as Mean or Standard
Deviation without needing to compute the underlying metric multiple times. This can, however, make the aggregators
complex to use. It is therefore typically better to use the :mod:`decorator API<.decorators>`.
"""
from torchbearer import metrics
from abc import ABCMeta, abstractmethod
from collections import deque
@@ -10,22 +16,18 @@ class RunningMetric(metrics.AdvancedMetric):
.. note::
This class only provides output during training.
Running metrics only provide output during training.
:param name: The name of the metric.
:type name: str
:param batch_size: The size of the deque used to store previous results.
:type batch_size: int
:param step_size: The number of iterations between aggregations.
:type step_size: int
"""
__metaclass__ = ABCMeta

def __init__(self, name, batch_size=50, step_size=10):
"""Initialise the deque of results.
:param name: The name of the metric. Will be prepended with 'running_'.
:type name: str
:param batch_size: The size of the deque used to store previous results.
:type batch_size: int
:param step_size: The number of iterations between aggregations.
:type step_size: int
"""
super().__init__(name)
self._batch_size = batch_size
self._step_size = step_size
@@ -36,8 +38,7 @@ def __init__(self, name, batch_size=50, step_size=10):
def _process_train(self, *args):
"""Process the metric for a single train step.
:param state: The current model state.
:type state: dict
:param args: The output of some :class:`.Metric`
:return: The metric value.
"""
@@ -57,8 +58,7 @@ def _step(self, cache):
def process_train(self, *args):
"""Add the current metric value to the cache and call '_step' is needed.
:param state: The current model state.
:type state: dict
:param args: The output of some :class:`.Metric`
:return: The current metric value.
"""
@@ -81,19 +81,17 @@ def reset(self, state):


class RunningMean(RunningMetric):
"""A running metric wrapper which outputs the mean of a sequence of observations.
"""A :class:`RunningMetric` which outputs the mean of a sequence of its input over the course of an epoch.
:param name: The name of this running mean.
:type name: str
:param batch_size: The size of the deque used to store previous results.
:type batch_size: int
:param step_size: The number of iterations between aggregations.
:type step_size: int
"""

def __init__(self, name, batch_size=50, step_size=10):
"""Wrap the given metric in initialise the parent :class:`RunningMetric`.
:param metric: The metric to wrap.
:type metric: Metric
:param batch_size: The size of the deque to store of previous results.
:type batch_size: int
:param step_size: The number of iterations between aggregations.
:type step_size: int
"""
super().__init__(name, batch_size=batch_size, step_size=step_size)

def _process_train(self, data):
@@ -104,17 +102,20 @@ def _step(self, cache):


class Std(metrics.Metric):
"""Metric wrapper which calculates the standard deviation of process outputs between calls to reset.
"""Metric aggregator which calculates the standard deviation of process outputs between calls to reset.
:param name: The name of this metric.
:type name: str
"""

def __init__(self, name):
super(Std, self).__init__(name)

def process(self, data):
"""Process the wrapped metric and compute values required for the std.
"""Compute values required for the std from the input.
:param state: The model state.
:type state: dict
:param data: The output of some previous call to :meth:`.Metric.process`.
:type data: torch.Tensor
"""
self._sum += data.sum().item()
@@ -128,8 +129,8 @@ def process(self, data):
def process_final(self, data):
"""Compute and return the final standard deviation.
:param state: The model state.
:type state: dict
:param data: The output of some previous call to :meth:`.Metric.process_final`.
:type data: torch.Tensor
:return: The standard deviation of each observation since the last reset call.
"""
@@ -151,17 +152,20 @@ def reset(self, state):


class Mean(metrics.Metric):
"""Metric wrapper which calculates the mean value of a series of observations between reset calls.
"""Metric aggregator which calculates the mean of process outputs between calls to reset.
:param name: The name of this metric.
:type name: str
"""

def __init__(self, name):
super(Mean, self).__init__(name)

def process(self, data):
"""Compute the metric value and add it to the rolling sum.
"""Add the input to the rolling sum.
:param state: The model state.
:type state: dict
:param data: The output of some previous call to :meth:`.Metric.process`.
:type data: torch.Tensor
"""
self._sum += data.sum().item()
@@ -174,8 +178,8 @@ def process(self, data):
def process_final(self, data):
"""Compute and return the mean of all metric values since the last call to reset.
:param state: The model state.
:type state: dict
:param data: The output of some previous call to :meth:`.Metric.process_final`.
:type data: torch.Tensor
:return: The mean of the metric values since the last call to reset.
"""
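For orientation, here is a minimal sketch of how these aggregators are usually reached through the decorator API rather than built by hand. The `l1` metric below is a hypothetical example, not part of torchbearer; the decorator stack is an assumption that mirrors the one applied to `EpochFactory` in primitives.py further down this diff.

import torch
from torchbearer import metrics

# Hypothetical batch metric: per-element absolute error.
# Each decorator adds a node to the MetricTree described above: mean and
# std aggregate the values over the epoch, and running_mean adds a
# RunningMean that reports every few steps during training.
@metrics.running_mean
@metrics.std
@metrics.mean
@metrics.lambda_metric('l1')
def l1(y_pred, y_true):
    return torch.abs(y_pred - y_true)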
3 changes: 1 addition & 2 deletions torchbearer/metrics/decorators.py
@@ -15,8 +15,7 @@
def default_for_key(key):
"""The :func:`default_for_key` decorator will register the given metric in the global metric dict
(`metrics.DEFAULT_METRICS`) so that it can be referenced by name in instances of :class:`.MetricList` such as in the
list given to
:meth:`.torchbearer.Model.fit`.
list given to the :class:`.torchbearer.Model`.
Example: ::
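A short sketch of the registration flow described in the hunk above. The `mse` metric here is hypothetical, and the commented `Model` call assumes the standard torchbearer constructor of this version:

import torch
from torchbearer import metrics

@metrics.default_for_key('mse')  # registers the factory in metrics.DEFAULT_METRICS
@metrics.mean
@metrics.lambda_metric('mse')
def mse(y_pred, y_true):
    return torch.pow(y_pred - y_true, 2)

# The key can now be used in the metric list given to the Model, e.g.:
# model = torchbearer.Model(net, optimiser, loss, metrics=['mse', 'loss'])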
14 changes: 11 additions & 3 deletions torchbearer/metrics/primitives.py
@@ -1,3 +1,13 @@
"""
Base metrics are the base classes which represent the metrics supplied with torchbearer. They all use the
:func:`.default_for_key` decorator so that they can be accessed in the call to :class:`.torchbearer.Model` via the
following strings:
- '`acc`' or '`accuracy`': The :class:`.CategoricalAccuracy` metric
- '`loss`': The :class:`.Loss` metric
- '`epoch`': The :class:`.Epoch` metric
- '`roc_auc`' or '`roc_auc_score`': The :class:`.RocAucScore` metric
"""
import torchbearer
from torchbearer import metrics

@@ -63,9 +73,7 @@ def _process(self, state):


@metrics.default_for_key('epoch')
@metrics.running_mean
@metrics.std
@metrics.mean
@metrics.to_dict
class EpochFactory(metrics.MetricFactory):
def build(self):
return Epoch()
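For reference, the keys listed in the new module docstring are used by passing strings to the :class:`.torchbearer.Model` metric list. A sketch follows; the toy network, optimiser, and loss are placeholder assumptions, not part of this commit.

import torch
import torchbearer

net = torch.nn.Linear(10, 2)
optimiser = torch.optim.SGD(net.parameters(), lr=0.01)
loss = torch.nn.CrossEntropyLoss()

# 'acc' and 'loss' resolve to CategoricalAccuracy and Loss through
# DEFAULT_METRICS; after this commit, 'epoch' builds an Epoch metric
# wrapped only in ToDict (no mean / std / running_mean).
model = torchbearer.Model(net, optimiser, loss, metrics=['acc', 'loss', 'epoch'])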
70 changes: 46 additions & 24 deletions torchbearer/metrics/wrappers.py
@@ -1,10 +1,38 @@
"""
Metric wrappers are classes which wrap instances of :class:`.Metric` or, in the case of :class:`EpochLambda` and
:class:`BatchLambda`, functions. Typically, these should **not** be used directly (although this is entirely possible) but
rather via the :mod:`decorator API<.decorators>`.
"""
import torchbearer
from torchbearer import metrics

import torch


class ToDict(metrics.AdvancedMetric):
"""The :class:`ToDict` class is an :class:`.AdvancedMetric` which will put output from the inner :class:`.Metric` in
a dict (mapping metric name to value) before returning. When in `eval` mode, 'val\_' will be prepended to the metric
name.
Example: ::
>>> from torchbearer import metrics
>>> @metrics.lambda_metric('my_metric')
... def my_metric(y_pred, y_true):
... return y_pred + y_true
...
>>> metric = metrics.ToDict(my_metric().build())
>>> metric.process({'y_pred': 4, 'y_true': 5})
{'my_metric': 9}
>>> metric.eval()
>>> metric.process({'y_pred': 4, 'y_true': 5})
{'val_my_metric': 9}
:param metric: The :class:`.Metric` instance to *wrap*.
:type metric: metrics.Metric
"""

def __init__(self, metric):
super(ToDict, self).__init__(metric.name)

@@ -45,23 +73,20 @@ def reset(self, state):

class BatchLambda(metrics.Metric):
"""A metric which returns the output of the given function on each batch.
:param name: The name of the metric.
:type name: str
:param metric_function: A metric function('y_pred', 'y_true') to wrap.
"""

def __init__(self, name, metric_function):
"""Construct a metric with the given name which wraps the given function.
:param name: The name of the metric.
:type name: str
:param metric_function: A metric function('y_pred', 'y_true') to wrap.
"""
super(BatchLambda, self).__init__(name)
self._metric_function = metric_function

def process(self, state):
"""Return the output of the wrapped function.
:param state: The model state.
:param state: The :class:`.torchbearer.Model` state.
:type state: dict
:return: The value of the metric function('y_pred', 'y_true').
@@ -73,20 +98,17 @@ class EpochLambda(metrics.AdvancedMetric):
"""A metric wrapper which computes the given function for concatenated values of 'y_true' and 'y_pred' each epoch.
Can be used as a running metric which computes the function for batches of outputs with a given step size during
training.
:param name: The name of the metric.
:type name: str
:param metric_function: The function('y_pred', 'y_true') to use as the metric.
:param running: True if this should act as a running metric.
:type running: bool
:param step_size: Step size to use between calls if running=True.
:type step_size: int
"""

def __init__(self, name, metric_function, running=True, step_size=50):
"""Wrap the given function as a metric with the given name.
:param name: The name of the metric.
:type name: str
:param metric_function: The function('y_pred', 'y_true') to use as the metric.
:param running: True if this should act as a running metric.
:type running: bool
:param step_size: Step size to use between calls if running=True.
:type step_size: int
"""
super(EpochLambda, self).__init__(name)
self._step = metric_function
self._final = metric_function
@@ -100,7 +122,7 @@ def process_train(self, state):
"""Concatenate the 'y_true' and 'y_pred' from the state along the 0 dimension. If this is a running metric,
the function is evaluated every 'step_size' steps.
:param state: The model state.
:param state: The :class:`.torchbearer.Model` state.
:type state: dict
:return: The current running result.
@@ -114,7 +136,7 @@ def process_final_train(self, state):
def process_final_train(self, state):
"""Evaluate the function with the aggregated outputs.
:param state: The model state.
:param state: The :class:`.torchbearer.Model` state.
:type state: dict
:return: The result of the function.
@@ -124,7 +146,7 @@ def process_validate(self, state):
def process_validate(self, state):
"""During validation, just concatenate 'y_true' and y_pred'.
:param state: The model state.
:param state: The :class:`.torchbearer.Model` state.
:type state: dict
"""
@@ -134,7 +156,7 @@ def process_final_validate(self, state):
def process_final_validate(self, state):
"""Evaluate the function with the aggregated outputs.
:param state: The model state.
:param state: The :class:`.torchbearer.Model` state.
:type state: dict
:return: The result of the function.
@@ -144,7 +166,7 @@ def reset(self, state):
def reset(self, state):
"""Reset the 'y_true' and 'y_pred' caches.
:param state: The model state.
:param state: The :class:`.torchbearer.Model` state.
:type state: dict
"""
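To tie the wrappers together, a sketch mirroring the ToDict docstring example above: `lambda_metric` wraps a function in a :class:`BatchLambda`, and :class:`ToDict` maps the metric name to the value ('val_' prefixed in eval mode). The `rmse` function is hypothetical; the `build()` and `process()` calls follow the pattern shown in the ToDict example.

import torch
from torchbearer import metrics

@metrics.lambda_metric('rmse')  # builds a BatchLambda around the function
def rmse(y_pred, y_true):
    return torch.sqrt(torch.mean((y_pred - y_true) ** 2))

metric = metrics.ToDict(rmse().build())
print(metric.process({'y_pred': torch.tensor([2.0]), 'y_true': torch.tensor([4.0])}))
# expected, per the ToDict example above: {'rmse': tensor(2.)}
metric.eval()
print(metric.process({'y_pred': torch.tensor([2.0]), 'y_true': torch.tensor([4.0])}))
# expected: {'val_rmse': tensor(2.)}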
