Skip to content

Commit

Permalink
Improve RunningAverage reset when epoch_bound=False (#2950)
Browse files Browse the repository at this point in the history
* Do the improvement

* A few bug fix in test

* two improvements in test
  • Loading branch information
sadra-barikbin committed May 22, 2023
1 parent 1df9932 commit a99ea7f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
2 changes: 2 additions & 0 deletions ignite/metrics/running_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = Ep
if self.epoch_bound:
# restart average every epoch
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
else:
engine.add_event_handler(Events.STARTED, self.started)
# compute metric
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
# apply running average
Expand Down
14 changes: 8 additions & 6 deletions tests/ignite/metrics/test_running_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ def test_epoch_unbound():
batch_size = 10
n_classes = 10
data = list(range(n_iters))
loss_values = iter(range(n_epochs * n_iters))
y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_epochs * n_iters, batch_size)))
y_pred_batch_values = iter(np.random.rand(n_epochs * n_iters, batch_size, n_classes))
loss_values = iter(range(2 * n_epochs * n_iters))
y_true_batch_values = iter(np.random.randint(0, n_classes, size=(2 * n_epochs * n_iters, batch_size)))
y_pred_batch_values = iter(np.random.rand(2 * n_epochs * n_iters, batch_size, n_classes))

def update_fn(engine, batch):
loss_value = next(loss_values)
Expand All @@ -146,9 +146,7 @@ def update_fn(engine, batch):

running_avg_acc = [None]

@trainer.on(Events.STARTED)
def running_avg_output_init(engine):
engine.state.running_avg_output = None
trainer.state.running_avg_output = None

@trainer.on(Events.ITERATION_COMPLETED, running_avg_acc)
def manual_running_avg_acc(engine, running_avg_acc):
Expand Down Expand Up @@ -187,6 +185,10 @@ def assert_equal_running_avg_output_values(engine):

trainer.run(data, max_epochs=3)

running_avg_acc[0] = None
trainer.state.running_avg_output = None
trainer.run(data, max_epochs=3)


def test_multiple_attach():
n_iters = 100
Expand Down

0 comments on commit a99ea7f

Please sign in to comment.