From 4137912bc32de178e02cf15f1149c59b6760ed4c Mon Sep 17 00:00:00 2001
From: leej3
Date: Mon, 22 Apr 2024 16:06:53 +0100
Subject: [PATCH] improve type support for fbrlogger

---
 ignite/handlers/fbresearch_logger.py          | 16 +++++++++-
 .../ignite/handlers/test_fbresearch_logger.py | 32 +++++++++++++++++--
 2 files changed, 45 insertions(+), 3 deletions(-)

diff --git a/ignite/handlers/fbresearch_logger.py b/ignite/handlers/fbresearch_logger.py
index a291138e48d5..879c066f4a8d 100644
--- a/ignite/handlers/fbresearch_logger.py
+++ b/ignite/handlers/fbresearch_logger.py
@@ -14,6 +14,14 @@
 MB = 1024.0 * 1024.0
 
 
+def is_iterable(obj):
+    try:
+        iter(obj)
+        return True
+    except TypeError:
+        return False
+
+
 class FBResearchLogger:
     """Logs training and validation metrics for research purposes.
 
@@ -99,8 +107,14 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
             output = engine.state.output
             if isinstance(output, dict):
                 outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
+            elif isinstance(output, str):
+                outputs.append(output)
+            elif isinstance(output, float):
+                outputs.append(f"{output:.4f}")
+            elif is_iterable(output):
+                outputs += [f"{v}" for v in output]
             else:
-                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]  # type: ignore
+                raise NotImplementedError(f"Output type {type(output)} not supported")
 
         lrs = ""
         if optimizer is not None:
diff --git a/tests/ignite/handlers/test_fbresearch_logger.py b/tests/ignite/handlers/test_fbresearch_logger.py
index b85bdcf2794e..ce67f62aaa70 100644
--- a/tests/ignite/handlers/test_fbresearch_logger.py
+++ b/tests/ignite/handlers/test_fbresearch_logger.py
@@ -3,9 +3,13 @@
 from unittest.mock import MagicMock
 
 import pytest
+import torch
+import torch.nn as nn
+import torch.optim as optim
 
-from ignite.engine import Engine, Events
-from ignite.handlers.fbresearch_logger import FBResearchLogger  # Adjust the import path as necessary
+from ignite.engine import create_supervised_trainer, Engine, Events
+from ignite.handlers.fbresearch_logger import FBResearchLogger
+from ignite.utils import setup_logger
 
 
 @pytest.fixture
@@ -56,3 +60,27 @@ def test_output_formatting(mock_engine, fb_research_logger, output, expected_pat
     actual_output = fb_research_logger.logger.info.call_args_list[0].args[0]
 
     assert re.search(expected_pattern, actual_output)
+
+
+def test_logger_type_support():
+    model = nn.Linear(10, 5)
+    opt = optim.SGD(model.parameters(), lr=0.001)
+    criterion = nn.CrossEntropyLoss()
+
+    data = [(torch.rand(4, 10), torch.randint(0, 5, size=(4,))) for _ in range(100)]
+
+    trainer = create_supervised_trainer(model, opt, criterion)
+
+    logger = setup_logger("trainer", level=logging.INFO)
+    logger = FBResearchLogger(logger=logger, show_output=True)
+    logger.attach(trainer, name="Train", every=20, optimizer=opt)
+
+    trainer.run(data, max_epochs=4)
+    trainer.state.output = {"loss": 4.2}
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+    trainer.state.output = "4.2"
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+    trainer.state.output = [4.2, 4.2]
+    trainer.fire_event(Events.ITERATION_COMPLETED)
+    trainer.state.output = (4.2, 4.2)
+    trainer.fire_event(Events.ITERATION_COMPLETED)
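
For reference, a minimal sketch of the output-formatting dispatch this patch introduces into FBResearchLogger.log_every, runnable in isolation with only the standard library. format_output is a hypothetical stand-alone helper written for illustration; it is not part of ignite's API. Because strings and dicts are themselves iterable, their isinstance checks must precede the generic is_iterable branch, which is why the patch orders the branches as it does.

from typing import Any, List


def is_iterable(obj: Any) -> bool:
    # Duck-typed check, same in spirit as the helper the patch adds:
    # an object counts as iterable if iter() accepts it.
    try:
        iter(obj)
        return True
    except TypeError:
        return False


def format_output(output: Any) -> List[str]:
    # Hypothetical stand-alone mirror of the dispatch in log_every.
    if isinstance(output, dict):
        # Metric dict: render each entry as "name: value" to 4 decimal places.
        return [f"{k}: {v:.4f}" for k, v in output.items()]
    elif isinstance(output, str):
        # Strings are iterable, so this branch must come before is_iterable,
        # otherwise "4.2" would be split into its characters.
        return [output]
    elif isinstance(output, float):
        return [f"{output:.4f}"]
    elif is_iterable(output):
        # Lists, tuples, generators, etc.: render element-wise.
        return [f"{v}" for v in output]
    raise NotImplementedError(f"Output type {type(output)} not supported")


print(format_output({"loss": 4.2}))  # ['loss: 4.2000']
print(format_output("4.2"))          # ['4.2']
print(format_output(4.2))            # ['4.2000']
print(format_output([4.2, 4.2]))     # ['4.2', '4.2']
print(format_output((4.2, 4.2)))     # ['4.2', '4.2']

These five cases correspond one-to-one to the outputs exercised by the new test_logger_type_support test above; any other type (an int, for example) falls through every branch and raises NotImplementedError, matching the patch's else clause.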