add fbresearch_logger.py
Add FBResearchLogger class from unmerged branch object-detection-example

Add minimal docs and tests
leej3 committed Mar 25, 2024
1 parent a7246e1 · commit f7fe4a2
Showing 3 changed files with 188 additions and 0 deletions.
docs/source/handlers.rst: 1 addition, 0 deletions
@@ -54,6 +54,7 @@ Loggers

visdom_logger
wandb_logger
fbresearch_logger

.. seealso::

ignite/handlers/fbresearch_logger.py: 141 additions, 0 deletions
@@ -0,0 +1,141 @@
"""FBResearch logger and its helper handlers."""

import datetime

import torch

from ignite.engine import Engine, Events
from ignite.handlers import Timer


MB = 1024.0 * 1024.0


class FBResearchLogger:
"""Logs training and validation metrics for research purposes.
This logger is designed to attach to an Ignite Engine and log various metrics
and system stats at configurable intervals, including learning rates, iteration
times, and GPU memory usage.
Args:
logger (logging.Logger): The logger to use for output.
delimiter (str): The delimiter to use between metrics in the log output.
show_output (bool): Flag to enable logging of the output from the engine's process function.
Examples:
.. code-block:: python
import logging
from ignite.handlers.fbresearch_logger import *
logger = FBResearchLogger(logger=logging.Logger(__name__), show_output=True)
logger.attach(trainer, name="Train", every=10, optimizer=my_optimizer)
"""

def __init__(self, logger, delimiter=" ", show_output=False):
self.delimiter = delimiter
self.logger = logger
self.iter_timer = None
self.data_timer = None
self.show_output = show_output

def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
"""Attaches all the logging handlers to the given engine.
Args:
engine (Engine): The engine to attach the logging handlers to.
name (str): The name of the engine (e.g., "Train", "Validate") to include in log messages.
every (int): Frequency of iterations to log information. Logs are generated every 'every' iterations.
optimizer: The optimizer used during training to log current learning rates.
"""
engine.add_event_handler(Events.EPOCH_STARTED, self.log_epoch_started, engine, name)
engine.add_event_handler(Events.ITERATION_COMPLETED(every=every), self.log_every, engine, optimizer=optimizer)
engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
engine.add_event_handler(Events.COMPLETED, self.log_completed, engine, name)

self.iter_timer = Timer(average=True)
self.iter_timer.attach(
engine,
start=Events.EPOCH_STARTED,
resume=Events.ITERATION_STARTED,
pause=Events.ITERATION_COMPLETED,
step=Events.ITERATION_COMPLETED,
)
self.data_timer = Timer(average=True)
self.data_timer.attach(
engine,
start=Events.EPOCH_STARTED,
resume=Events.GET_BATCH_STARTED,
pause=Events.GET_BATCH_COMPLETED,
step=Events.GET_BATCH_COMPLETED,
)

def log_every(self, engine, optimizer=None):
cuda_max_mem = ""
if torch.cuda.is_available():
cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"

current_iter = engine.state.iteration % (engine.state.epoch_length + 1)
iter_avg_time = self.iter_timer.value()

eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)

outputs = []
if self.show_output:
output = engine.state.output
if isinstance(output, dict):
outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
else:
outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]

lrs = ""
if optimizer is not None:
if len(optimizer.param_groups) == 1:
lrs += f"lr: {optimizer.param_groups[0]['lr']:.5f}"
else:
                for i, g in enumerate(optimizer.param_groups):
                    # Trailing space keeps per-group learning rates from running together.
                    lrs += f"lr [g{i}]: {g['lr']:.5f} "

msg = self.delimiter.join(
[
f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]",
f"[{current_iter}/{engine.state.epoch_length}]:",
f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}",
f"{lrs}",
]
+ outputs
+ [
f"Iter time: {iter_avg_time:.4f} s",
f"Data prep time: {self.data_timer.value():.4f} s",
cuda_max_mem,
]
)
self.logger.info(msg)

def log_epoch_started(self, engine, name):
msg = f"{name}: start epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
self.logger.info(msg)

def log_epoch_completed(self, engine, name):
epoch_time = engine.state.times[Events.EPOCH_COMPLETED.name]
epoch_info = f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]" if engine.state.max_epochs > 1 else ""
msg = self.delimiter.join(
[
f"{name}: {epoch_info}",
f"Total time: {datetime.timedelta(seconds=int(epoch_time))}",
f"({epoch_time / engine.state.epoch_length:.4f} s / it)",
]
)
self.logger.info(msg)

def log_completed(self, engine, name):
if engine.state.max_epochs > 1:
total_time = engine.state.times[Events.COMPLETED.name]
msg = self.delimiter.join(
[
f"{name}: run completed",
f"Total time: {datetime.timedelta(seconds=int(total_time))}",
]
)
self.logger.info(msg)
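
For context, the sketch below shows one way this handler might be wired into a training script. The model, optimizer, loss, and data are illustrative stand-ins and are not part of this commit; only the FBResearchLogger calls mirror the API added above.

import logging

import torch
import torch.nn as nn

from ignite.engine import Engine
from ignite.handlers.fbresearch_logger import FBResearchLogger

logging.basicConfig(level=logging.INFO)

# Toy model and optimizer, used only to produce something to log.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

def train_step(engine, batch):
    x, y = batch
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    # Returned dict is picked up by the logger when show_output=True.
    return {"loss": loss.item()}

trainer = Engine(train_step)

# Log every 10 iterations, including the current learning rate.
FBResearchLogger(logger=logging.getLogger("trainer"), show_output=True).attach(
    trainer, name="Train", every=10, optimizer=optimizer
)

data = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(100)]
trainer.run(data, max_epochs=2)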
tests/ignite/handlers/test_fbresearch_logger.py: 46 additions, 0 deletions
@@ -0,0 +1,46 @@
import logging
import re
from unittest.mock import MagicMock

import pytest

from ignite.engine import Engine, Events
from ignite.handlers.fbresearch_logger import FBResearchLogger # Adjust the import path as necessary


@pytest.fixture
def mock_engine():
engine = Engine(lambda e, b: None)
engine.state.epoch = 1
engine.state.max_epochs = 10
engine.state.epoch_length = 100
engine.state.iteration = 50
return engine


@pytest.fixture
def mock_logger():
return MagicMock(spec=logging.Logger)


@pytest.fixture
def fb_research_logger(mock_logger):
yield FBResearchLogger(logger=mock_logger, show_output=True)


@pytest.mark.parametrize(
"output,expected_pattern",
[
({"loss": 0.456, "accuracy": 0.789}, r"loss. *0.456.*accuracy. *0.789"),
((0.456, 0.789), r"0.456.*0.789"),
([0.456, 0.789], r"0.456.*0.789"),
],
)
def test_output_formatting(mock_engine, fb_research_logger, output, expected_pattern):
# Ensure the logger correctly formats and logs the output for each type
mock_engine.state.output = output
fb_research_logger.attach(mock_engine, name="Test", every=1)
mock_engine.fire_event(Events.ITERATION_COMPLETED)

actual_output = fb_research_logger.logger.info.call_args_list[0].args[0]
assert re.search(expected_pattern, actual_output)
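
For reference, with a dict output like the first parametrized case and no optimizer attached, the message assembled in log_every comes out shaped roughly like the line below (values are illustrative; the lr and GPU fields are empty strings when unavailable, so extra delimiter padding appears in their place):

    Epoch [1/10]  [50/100]:  ETA: 0:00:00    loss: 0.4560  accuracy: 0.7890  Iter time: 0.0000 s  Data prep time: 0.0000 s

These tests exercise only that formatting path and can be run directly with pytest tests/ignite/handlers/test_fbresearch_logger.py.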
