-
-
Notifications
You must be signed in to change notification settings - Fork 604
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add frequency metric to determine some average per-second metrics (#760)
* add frequency metric to determine some average per-second metrics (like wps) * add usage in docstring. * fix flake. * fill the metric on each iteration instead of relying on running averages. * incorporate feedback * add missing import. * add initial tests. TODO: fix them! * fix distributed initialization check in the metric and fix tests. * fix flake * add distributed test. * add device to ensure it runs on CPU. * add non-distributed test * rename FrequencyMetric to Frequency for consistency. * move Frequency to main metrics module from contrib because of lack of 3pp * fix distributed test to incorporate workers in assertions. * fix logic of test. * fix logic to factor the scaling across workers. * fix test. * Fixed accumulation problem in distrib config * Attached reset to epoch started Co-authored-by: vfdev <vfdev.5@gmail.com>
- Loading branch information
Showing
3 changed files
with
124 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import torch | ||
import torch.distributed as dist | ||
|
||
from ignite.engine import Events | ||
from ignite.metrics import Metric | ||
from ignite.handlers.timing import Timer | ||
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced | ||
|
||
|
||
class Frequency(Metric): | ||
"""Provides metrics for the number of examples processed per second. | ||
Examples: | ||
.. code-block:: python | ||
# Compute number of tokens processed | ||
wps_metric = Frequency(output_transformer=lambda x: x['ntokens']) | ||
wps_metric.attach(trainer, name='wps') | ||
# Logging with TQDM | ||
ProgressBar(persist=True).attach(trainer, metric_names=['wps']) | ||
# Progress bar will looks like | ||
# Epoch [2/10]: [12/24] 50%|█████ , wps=400 [00:17<1:23] | ||
""" | ||
|
||
def __init__(self, output_transform=lambda x: x, device=None): | ||
self._timer = None | ||
self._acc = None | ||
self._n = None | ||
self._elapsed = None | ||
super(Frequency, self).__init__(output_transform=output_transform, device=device) | ||
|
||
@reinit__is_reduced | ||
def reset(self): | ||
self._timer = Timer() | ||
self._acc = 0 | ||
self._n = 0 | ||
self._elapsed = 0.0 | ||
super(Frequency, self).reset() | ||
|
||
@reinit__is_reduced | ||
def update(self, output): | ||
self._acc += output | ||
self._n = self._acc | ||
self._elapsed = torch.tensor(self._timer.value(), device=self._device) | ||
|
||
@sync_all_reduce("_n", "_elapsed") | ||
def compute(self): | ||
time_divisor = 1.0 | ||
|
||
if dist.is_available() and dist.is_initialized(): | ||
time_divisor *= dist.get_world_size() | ||
|
||
# Returns the average processed objects per second across all workers | ||
return self._n / self._elapsed.item() * time_divisor | ||
|
||
def completed(self, engine, name): | ||
engine.state.metrics[name] = int(self.compute()) | ||
|
||
def attach(self, engine, name, event_name=Events.ITERATION_COMPLETED): | ||
engine.add_event_handler(Events.EPOCH_STARTED, self.started) | ||
engine.add_event_handler(event_name, self.iteration_completed) | ||
engine.add_event_handler(event_name, self.completed, name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import time | ||
|
||
import pytest | ||
|
||
import torch.distributed as dist | ||
|
||
from ignite.engine import Engine, Events | ||
from ignite.metrics import Frequency | ||
|
||
|
||
def test_nondistributed_average(): | ||
artificial_time = 1 # seconds | ||
num_tokens = 100 | ||
average_upper_bound = num_tokens / artificial_time | ||
average_lower_bound = average_upper_bound * 0.9 | ||
freq_metric = Frequency() | ||
freq_metric.reset() | ||
time.sleep(artificial_time) | ||
freq_metric.update(num_tokens) | ||
average = freq_metric.compute() | ||
assert average_lower_bound < average < average_upper_bound | ||
|
||
|
||
def _test_frequency_with_engine(device, workers): | ||
|
||
artificial_time = 0.1 / workers # seconds | ||
total_tokens = 1200 // workers | ||
batch_size = 128 // workers | ||
|
||
estimated_wps = batch_size * workers / artificial_time | ||
|
||
def update_fn(engine, batch): | ||
time.sleep(artificial_time) | ||
return {"ntokens": len(batch)} | ||
|
||
engine = Engine(update_fn) | ||
wps_metric = Frequency(output_transform=lambda x: x["ntokens"], device=device) | ||
wps_metric.attach(engine, 'wps') | ||
|
||
@engine.on(Events.ITERATION_COMPLETED) | ||
def assert_wps(e): | ||
wps = e.state.metrics['wps'] | ||
assert estimated_wps * 0.85 < wps < estimated_wps, \ | ||
"{}: {} < {} < {}".format(e.state.iteration, estimated_wps * 0.85, wps, estimated_wps) | ||
|
||
data = [[i] * batch_size for i in range(0, total_tokens, batch_size)] | ||
engine.run(data, max_epochs=1) | ||
|
||
|
||
def test_frequency_with_engine(): | ||
device = "cpu" | ||
_test_frequency_with_engine(device, workers=1) | ||
|
||
|
||
@pytest.mark.distributed | ||
def test_frequency_with_engine_distributed(distributed_context_single_node_gloo): | ||
device = "cpu" | ||
_test_frequency_with_engine(device, workers=dist.get_world_size()) |