Skip to content

Commit

Permalink
improve type support for fbrlogger
Browse files Browse the repository at this point in the history
  • Loading branch information
leej3 committed Apr 22, 2024
1 parent f431e60 commit 4137912
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
16 changes: 15 additions & 1 deletion ignite/handlers/fbresearch_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
MB = 1024.0 * 1024.0


def is_iterable(obj):
    """Return ``True`` if ``obj`` supports iteration, ``False`` otherwise.

    Uses EAFP: attempts ``iter(obj)`` rather than an ``isinstance`` check,
    so objects that are iterable only via ``__getitem__`` are also detected.
    """
    try:
        iter(obj)
    except TypeError:
        return False
    return True


class FBResearchLogger:
"""Logs training and validation metrics for research purposes.
Expand Down Expand Up @@ -99,8 +107,14 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
output = engine.state.output
if isinstance(output, dict):
outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
elif isinstance(output, str):
outputs.append(output)
elif isinstance(output, float):
outputs.append(f"{output:.4f}")
elif is_iterable(output):
outputs += [f"{v}" for v in output]
else:
outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output] # type: ignore
raise NotImplementedError(f"Output type {type(output)} not supported")

lrs = ""
if optimizer is not None:
Expand Down
32 changes: 30 additions & 2 deletions tests/ignite/handlers/test_fbresearch_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
from unittest.mock import MagicMock

import pytest
import torch
import torch.nn as nn
import torch.optim as optim

from ignite.engine import Engine, Events
from ignite.handlers.fbresearch_logger import FBResearchLogger # Adjust the import path as necessary
from ignite.engine import create_supervised_trainer, Engine, Events
from ignite.handlers.fbresearch_logger import FBResearchLogger
from ignite.utils import setup_logger


@pytest.fixture
Expand Down Expand Up @@ -56,3 +60,27 @@ def test_output_formatting(mock_engine, fb_research_logger, output, expected_pat

actual_output = fb_research_logger.logger.info.call_args_list[0].args[0]
assert re.search(expected_pattern, actual_output)


def test_logger_type_support():
    """Smoke-test FBResearchLogger with dict, str, list, and tuple engine outputs.

    Runs a small supervised trainer with logging attached, then forces
    ``engine.state.output`` to each supported type and re-fires the
    iteration-completed event so the logger formats every variant.
    """
    model = nn.Linear(10, 5)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()

    dataset = [(torch.rand(4, 10), torch.randint(0, 5, size=(4,))) for _ in range(100)]

    trainer = create_supervised_trainer(model, optimizer, loss_fn)

    fb_logger = FBResearchLogger(
        logger=setup_logger("trainer", level=logging.INFO), show_output=True
    )
    fb_logger.attach(trainer, name="Train", every=20, optimizer=optimizer)

    trainer.run(dataset, max_epochs=4)

    # Exercise each output type the logger claims to support, in order.
    for forced_output in ({"loss": 4.2}, "4.2", [4.2, 4.2], (4.2, 4.2)):
        trainer.state.output = forced_output
        trainer.fire_event(Events.ITERATION_COMPLETED)

0 comments on commit 4137912

Please sign in to comment.