improve type support for fbrlogger #3238

Open · wants to merge 2 commits into base: master
34 changes: 26 additions & 8 deletions ignite/handlers/fbresearch_logger.py
@@ -1,18 +1,20 @@
"""FBResearch logger and its helper handlers."""

import datetime
from typing import Any, Optional

# from typing import Any, Dict, Optional, Union
from typing import Any, Callable, List, Optional

import torch

from ignite import utils
from ignite.engine import Engine, Events
from ignite.handlers import Timer
from ignite.handlers.utils import global_step_from_engine # noqa


MB = 1024.0 * 1024.0

__all__ = ["FBResearchLogger", "global_step_from_engine"]


class FBResearchLogger:
"""Logs training and validation metrics for research purposes.
@@ -98,16 +100,27 @@ def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False
self.show_output: bool = show_output

def attach(
self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None
self,
engine: Engine,
name: str,
every: int = 1,
output_transform: Optional[Callable] = None,
state_attributes: Optional[List[str]] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
) -> None:
"""Attaches all the logging handlers to the given engine.

Args:
engine: The engine to attach the logging handlers to.
name: The name of the engine (e.g., "Train", "Validate") to include in log messages.
every: Logging frequency; a message is emitted every ``every`` iterations.
output_transform: A callable applied to ``engine.state.output`` to select the value(s) to log.
state_attributes: A list of ``engine.state`` attribute names to include in log messages.
optimizer: The optimizer used during training to log current learning rates.
"""
self.name = name
self.output_transform = output_transform
self.state_attributes = state_attributes
engine.add_event_handler(Events.EPOCH_STARTED, self.log_epoch_started, engine, name)
engine.add_event_handler(Events.ITERATION_COMPLETED(every=every), self.log_every, engine, optimizer=optimizer)
engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
@@ -151,10 +164,9 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
outputs = []
if self.show_output and engine.state.output is not None:
output = engine.state.output
if isinstance(output, dict):
outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
else:
outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output] # type: ignore
if self.output_transform is not None:
output = self.output_transform(output)
outputs = utils._to_str_list(output)

lrs = ""
if optimizer is not None:
@@ -164,6 +176,11 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
for i, g in enumerate(optimizer.param_groups):
lrs += f"lr [g{i}]: {g['lr']:.5f}"

state_attrs = []
if self.state_attributes is not None:
state_attrs = utils._to_str_list(
{name: getattr(engine.state, name, None) for name in self.state_attributes}
)
msg = self.delimiter.join(
[
f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]",
@@ -172,6 +189,7 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] =
f"{lrs}",
]
+ outputs
+ [" ".join(state_attrs)]
+ [
f"Iter time: {iter_avg_time:.4f} s",
f"Data prep time: {self.data_timer.value():.4f} s",
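For context, this is roughly how the extended `attach` signature is exercised end to end. A minimal usage sketch, assuming a toy model and random data; `alpha` is a hypothetical `engine.state` attribute that the training loop itself would have to set:

```python
import logging

import torch
import torch.nn as nn
import torch.optim as optim

from ignite.engine import create_supervised_trainer
from ignite.handlers.fbresearch_logger import FBResearchLogger
from ignite.utils import setup_logger

# Toy model, optimizer, and data; placeholders for a real training setup.
model = nn.Linear(10, 5)
optimizer = optim.SGD(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
data = [(torch.rand(4, 10), torch.randint(0, 5, size=(4,))) for _ in range(100)]

trainer = create_supervised_trainer(model, optimizer, criterion)

fbr_logger = FBResearchLogger(logger=setup_logger("trainer", level=logging.INFO), show_output=True)
fbr_logger.attach(
    trainer,
    name="Train",
    every=20,
    output_transform=lambda output: {"loss": output},  # wrap the scalar loss so it is logged as "loss: ..."
    state_attributes=["alpha"],  # hypothetical engine.state attribute set elsewhere in the training loop
    optimizer=optimizer,
)

trainer.run(data, max_epochs=2)
```

Passing `output_transform` keeps the handler agnostic to the shape of `engine.state.output`, while `state_attributes` pulls extra values straight off the engine state.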
2 changes: 1 addition & 1 deletion ignite/metrics/maximum_mean_discrepancy.py
@@ -29,7 +29,7 @@ class MaximumMeanDiscrepancy(Metric):

More details can be found in `Gretton et al. 2012`__.

__ https://jmlr.csail.mit.edu/papers/v13/gretton12a.html
__ https://www.jmlr.org/papers/volume13/gretton12a/gretton12a.pdf

- ``update`` must receive output of the form ``(x, y)``.
- ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`.
78 changes: 78 additions & 0 deletions ignite/utils.py
@@ -2,6 +2,7 @@
import functools
import hashlib
import logging
import numbers
import random
import shutil
import warnings
@@ -14,6 +15,7 @@
"convert_tensor",
"apply_to_tensor",
"apply_to_type",
"_to_str_list",
"to_onehot",
"setup_logger",
"manual_seed",
@@ -90,6 +92,82 @@ def _tree_map(
return func(x, key=key)


def _to_str_list(data: Any) -> List[str]:
"""
Recursively flattens and formats complex data structures, including keys for
dictionaries, into a list of human-readable strings.

This function processes nested dictionaries, lists, tuples, numbers, and
PyTorch tensors, formatting numbers to four decimal places and handling
tensors with special formatting rules. It's particularly useful for logging,
debugging, or any scenario where a human-readable representation of complex,
nested data structures is required.

The function handles the following types:

- Numbers: Formatted to four decimal places.
- PyTorch tensors:
- Scalars are formatted to four decimal places.
- 1D tensors with more than 10 elements show the first 10 elements
followed by an ellipsis.
- 1D tensors with 10 or fewer elements are fully listed.
- Multi-dimensional tensors display their shape.
- Dictionaries: Each key-value pair is included in the output with the key
as a prefix.
- Lists and tuples: Flattened and included in the output. Empty lists/tuples are represented
by an empty string.
- None values: Represented by an empty string.

Args:
data: The input data to be flattened and formatted. It can be a nested
combination of dictionaries, lists, tuples, numbers, and PyTorch
tensors.

Returns:
A list of formatted strings, each representing a part of the input data
structure.
"""
formatted_items: List[str] = []

def format_item(item: Any, prefix: str = "") -> Optional[str]:
if isinstance(item, numbers.Number):
return f"{prefix}{item:.4f}"
elif torch.is_tensor(item):
if item.dim() == 0:
return f"{prefix}{item.item():.4f}" # Format scalar tensor without brackets
elif item.dim() == 1 and item.size(0) > 10:
return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item[:10]) + ", ...]"
elif item.dim() == 1:
return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item) + "]"
else:
return f"{prefix} {list(item.shape)}"
elif isinstance(item, dict):
for key, value in item.items():
formatted_value = format_item(value, f"{key}: ")
if formatted_value is not None:
formatted_items.append(formatted_value)
elif isinstance(item, (list, tuple)):
if not item:
if prefix:
formatted_items.append(f"{prefix}")
else:
values = [format_item(x) for x in item]
values_str = [v for v in values if v is not None]
if values_str:
formatted_items.append(f"{prefix}" + ", ".join(values_str))
elif item is None:
if prefix:
formatted_items.append(f"{prefix}")
return None

# Directly handle single numeric values
if isinstance(data, numbers.Number):
return [f"{data:.4f}"]

format_item(data)
return formatted_items


class _CollectionItem:
types_as_collection_item: Tuple = (int, float, torch.Tensor)

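As a quick illustration of what the new helper produces, here is a short sketch consistent with the docstring above and the parametrized expectations in `tests/ignite/test_utils.py` below:

```python
import torch

from ignite.utils import _to_str_list

# Nested dicts/lists are flattened into key-prefixed, 4-decimal strings.
print(_to_str_list({"a": 10, "b": 2.33333}))
# ['a: 10.0000', 'b: 2.3333']

print(_to_str_list({"x": torch.tensor(0.1234), "y": [1, 2.3567]}))
# ['x: 0.1234', 'y: 1.0000, 2.3567']

# Empty containers and None values collapse to just the key prefix.
print(_to_str_list({"empty": [], "none": None}))
# ['empty: ', 'none: ']
```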
52 changes: 50 additions & 2 deletions tests/ignite/handlers/test_fbresearch_logger.py
@@ -3,9 +3,13 @@
from unittest.mock import MagicMock

import pytest
import torch
import torch.nn as nn
import torch.optim as optim

from ignite.engine import Engine, Events
from ignite.handlers.fbresearch_logger import FBResearchLogger # Adjust the import path as necessary
from ignite.engine import create_supervised_trainer, Engine, Events
from ignite.handlers.fbresearch_logger import FBResearchLogger
from ignite.utils import setup_logger


@pytest.fixture
@@ -56,3 +60,47 @@ def test_output_formatting(mock_engine, fb_research_logger, output, expected_pat

actual_output = fb_research_logger.logger.info.call_args_list[0].args[0]
assert re.search(expected_pattern, actual_output)


def test_logger_type_support():
model = nn.Linear(10, 5)
opt = optim.SGD(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

data = [(torch.rand(4, 10), torch.randint(0, 5, size=(4,))) for _ in range(100)]

trainer = create_supervised_trainer(model, opt, criterion)

logger = setup_logger("trainer", level=logging.INFO)
logger = FBResearchLogger(logger=logger, show_output=True)
logger.attach(trainer, name="Train", every=20, optimizer=opt)

trainer.run(data, max_epochs=4)
trainer.state.output = {"loss": 4.2}
trainer.fire_event(Events.ITERATION_COMPLETED)
trainer.state.output = "4.2"
trainer.fire_event(Events.ITERATION_COMPLETED)
trainer.state.output = [4.2, 4.2]
trainer.fire_event(Events.ITERATION_COMPLETED)
trainer.state.output = (4.2, 4.2)
trainer.fire_event(Events.ITERATION_COMPLETED)


def test_fbrlogger_with_output_transform(mock_logger):
trainer = Engine(lambda e, b: 42)
fbr = FBResearchLogger(logger=mock_logger, show_output=True)
fbr.attach(trainer, "Training", output_transform=lambda x: {"loss": x})
trainer.run(data=[10], epoch_length=1, max_epochs=1)
assert "loss: 42.0000" in fbr.logger.info.call_args_list[-2].args[0]


def test_fbrlogger_with_state_attrs(mock_logger):
trainer = Engine(lambda e, b: 42)
fbr = FBResearchLogger(logger=mock_logger, show_output=True)
fbr.attach(trainer, "Training", state_attributes=["alpha", "beta", "gamma"])
trainer.state.alpha = 3.899
trainer.state.beta = torch.tensor(12.21)
trainer.state.gamma = torch.tensor([21.0, 6.0])
trainer.run(data=[10], epoch_length=1, max_epochs=1)
attrs = "alpha: 3.8990 beta: 12.2100 gamma: [21.0000, 6.0000]"
assert attrs in fbr.logger.info.call_args_list[-2].args[0]
25 changes: 24 additions & 1 deletion tests/ignite/test_utils.py
@@ -8,7 +8,7 @@
from packaging.version import Version

from ignite.engine import Engine, Events
from ignite.utils import convert_tensor, deprecated, hash_checkpoint, setup_logger, to_onehot
from ignite.utils import _to_str_list, convert_tensor, deprecated, hash_checkpoint, setup_logger, to_onehot


def test_convert_tensor():
@@ -55,6 +55,29 @@ def test_convert_tensor():
convert_tensor(12345)


@pytest.mark.parametrize(
"input_data,expected",
[
(42, ["42.0000"]),
([{"a": 15, "b": torch.tensor([2.0])}], ["a: 15.0000", "b: [2.0000]"]),
({"a": 10, "b": 2.33333}, ["a: 10.0000", "b: 2.3333"]),
({"x": torch.tensor(0.1234), "y": [1, 2.3567]}, ["x: 0.1234", "y: 1.0000, 2.3567"]),
(({"nested": [3.1415, torch.tensor(0.0001)]},), ["nested: 3.1415, 0.0001"]),
(
{"large_vector": torch.tensor(range(20))},
["large_vector: [0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, ...]"],
),
({"large_matrix": torch.randn(5, 5)}, ["large_matrix: Shape: torch.Size([5, 5])"]),
({"empty": []}, ["empty: "]),
([], []),
({"none": None}, ["none: "]),
({1: 100, 2: 200}, ["1: 100.0000", "2: 200.0000"]),
],
)
def test__to_str_list(input_data, expected):
assert _to_str_list(input_data) == expected


def test_to_onehot():
indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
actual = to_onehot(indices, 4)