add frequency metric to determine some average per-second metrics (#760)
* add frequency metric to determine some average per-second metrics (like wps)

* add usage in docstring.

* fix flake.

* fill the metric on each iteration instead of relying on running averages.

* incorporate feedback

* add missing import.

* add initial tests. TODO: fix them!

* fix distributed initialization check in the metric and fix tests.

* fix flake

* add distributed test.

* add device to ensure it runs on CPU.

* add non-distributed test

* rename FrequencyMetric to Frequency for consistency.

* move Frequency from contrib to the main metrics module, since it has no third-party (3pp) dependencies

* fix distributed test to incorporate workers in assertions.

* fix logic of test.

* fix logic to factor the scaling across workers.

* fix test.

* Fixed accumulation problem in distrib config

* Attached reset to epoch started

Co-authored-by: vfdev <vfdev.5@gmail.com>
erip and vfdev-5 committed Feb 3, 2020
1 parent 918746b commit 0375a6e
Showing 3 changed files with 124 additions and 1 deletion.
ignite/metrics/__init__.py (4 changes: 3 additions & 1 deletion)
@@ -14,6 +14,7 @@
 from ignite.metrics.confusion_matrix import ConfusionMatrix, IoU, mIoU, DiceCoefficient
 from ignite.metrics.accumulation import VariableAccumulation, Average, GeometricAverage
 from ignite.metrics.fbeta import Fbeta
+from ignite.metrics.frequency import Frequency

 __all__ = [
     'Metric',
@@ -36,5 +37,6 @@
     'Recall',
     'RootMeanSquaredError',
     'RunningAverage',
-    'VariableAccumulation'
+    'VariableAccumulation',
+    'Frequency'
 ]
ignite/metrics/frequency.py (63 changes: 63 additions & 0 deletions)
@@ -0,0 +1,63 @@
import torch
import torch.distributed as dist

from ignite.engine import Events
from ignite.metrics import Metric
from ignite.handlers.timing import Timer
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced


class Frequency(Metric):
    """Provides metrics for the number of examples processed per second.

    Examples:

    .. code-block:: python

        # Compute number of tokens processed
        wps_metric = Frequency(output_transform=lambda x: x['ntokens'])
        wps_metric.attach(trainer, name='wps')
        # Logging with TQDM
        ProgressBar(persist=True).attach(trainer, metric_names=['wps'])
        # Progress bar will look like
        # Epoch [2/10]: [12/24] 50%|█████ , wps=400 [00:17<1:23]
    """

    def __init__(self, output_transform=lambda x: x, device=None):
        self._timer = None
        self._acc = None
        self._n = None
        self._elapsed = None
        super(Frequency, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        self._timer = Timer()
        self._acc = 0
        self._n = 0
        self._elapsed = 0.0
        super(Frequency, self).reset()

    @reinit__is_reduced
    def update(self, output):
        self._acc += output
        self._n = self._acc
        # Elapsed time since the last reset, as a tensor on the metric's device.
        self._elapsed = torch.tensor(self._timer.value(), device=self._device)

    @sync_all_reduce("_n", "_elapsed")
    def compute(self):
        time_divisor = 1.0

        if dist.is_available() and dist.is_initialized():
            time_divisor *= dist.get_world_size()

        # Returns the average processed objects per second across all workers
        return self._n / self._elapsed.item() * time_divisor

    def completed(self, engine, name):
        engine.state.metrics[name] = int(self.compute())

    def attach(self, engine, name, event_name=Events.ITERATION_COMPLETED):
        # Reset at the start of each epoch; update and publish the metric
        # on each event_name occurrence (every iteration by default).
        engine.add_event_handler(Events.EPOCH_STARTED, self.started)
        engine.add_event_handler(event_name, self.iteration_completed)
        engine.add_event_handler(event_name, self.completed, name)
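
For context, here is a minimal, self-contained sketch of how the new metric is used with an Engine, following the docstring example above. The update function, sleep interval, and data are illustrative, not part of this commit:

import time

from ignite.engine import Engine, Events
from ignite.metrics import Frequency

# Toy update step: pretend each batch processes len(batch) tokens.
def update_fn(engine, batch):
    time.sleep(0.01)  # simulate work
    return {"ntokens": len(batch)}

trainer = Engine(update_fn)

# Report tokens per second in engine.state.metrics under the name 'wps'.
wps_metric = Frequency(output_transform=lambda out: out["ntokens"])
wps_metric.attach(trainer, name="wps")

@trainer.on(Events.ITERATION_COMPLETED)
def log_wps(engine):
    print("iteration {}: wps={}".format(engine.state.iteration, engine.state.metrics["wps"]))

trainer.run([[0] * 32 for _ in range(10)], max_epochs=1)

Because completed() casts the result with int(), the published value is a whole number of examples per second.
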
tests/ignite/metrics/test_frequency.py (58 changes: 58 additions & 0 deletions)
@@ -0,0 +1,58 @@
import time

import pytest

import torch.distributed as dist

from ignite.engine import Engine, Events
from ignite.metrics import Frequency


def test_nondistributed_average():
    artificial_time = 1  # seconds
    num_tokens = 100
    average_upper_bound = num_tokens / artificial_time
    average_lower_bound = average_upper_bound * 0.9
    freq_metric = Frequency()
    freq_metric.reset()
    time.sleep(artificial_time)
    freq_metric.update(num_tokens)
    average = freq_metric.compute()
    assert average_lower_bound < average < average_upper_bound


def _test_frequency_with_engine(device, workers):

    artificial_time = 0.1 / workers  # seconds
    total_tokens = 1200 // workers
    batch_size = 128 // workers

    estimated_wps = batch_size * workers / artificial_time

    def update_fn(engine, batch):
        time.sleep(artificial_time)
        return {"ntokens": len(batch)}

    engine = Engine(update_fn)
    wps_metric = Frequency(output_transform=lambda x: x["ntokens"], device=device)
    wps_metric.attach(engine, 'wps')

    @engine.on(Events.ITERATION_COMPLETED)
    def assert_wps(e):
        wps = e.state.metrics['wps']
        assert estimated_wps * 0.85 < wps < estimated_wps, \
            "{}: {} < {} < {}".format(e.state.iteration, estimated_wps * 0.85, wps, estimated_wps)

    data = [[i] * batch_size for i in range(0, total_tokens, batch_size)]
    engine.run(data, max_epochs=1)


def test_frequency_with_engine():
    device = "cpu"
    _test_frequency_with_engine(device, workers=1)


@pytest.mark.distributed
def test_frequency_with_engine_distributed(distributed_context_single_node_gloo):
    device = "cpu"
    _test_frequency_with_engine(device, workers=dist.get_world_size())
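
As a sanity check on the world-size scaling in compute() (and on the estimated_wps bound above), here is a small worked example. It assumes, as in ignite's sync_all_reduce, that _n and _elapsed are summed across workers; the numbers are illustrative:

# Two workers, each processing 600 tokens in 0.5 s (1200 wps per worker).
world_size = 2
n_per_worker, t_per_worker = 600, 0.5

# After the all-reduce, both accumulators hold sums over all workers.
n_total = world_size * n_per_worker        # 1200 tokens
elapsed_total = world_size * t_per_worker  # 1.0 s

# compute() returns n / elapsed * world_size ...
wps = n_total / elapsed_total * world_size  # 1200 / 1.0 * 2 = 2400

# ... which matches the combined throughput of all workers.
assert wps == world_size * (n_per_worker / t_per_worker)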
