From 2f5833cdf5c1111f2374e638fba388812228745b Mon Sep 17 00:00:00 2001 From: leej3 Date: Mon, 22 Apr 2024 16:06:53 +0100 Subject: [PATCH 1/7] fbr logger: improve types and kwargs supported --- ignite/handlers/fbresearch_logger.py | 34 ++++++-- ignite/metrics/maximum_mean_discrepancy.py | 2 +- ignite/utils.py | 78 +++++++++++++++++++ .../ignite/handlers/test_fbresearch_logger.py | 52 ++++++++++++- tests/ignite/test_utils.py | 25 +++++- 5 files changed, 179 insertions(+), 12 deletions(-) diff --git a/ignite/handlers/fbresearch_logger.py b/ignite/handlers/fbresearch_logger.py index 395561ae575..18dcd47fbe6 100644 --- a/ignite/handlers/fbresearch_logger.py +++ b/ignite/handlers/fbresearch_logger.py @@ -1,18 +1,20 @@ """FBResearch logger and its helper handlers.""" import datetime -from typing import Any, Optional - -# from typing import Any, Dict, Optional, Union +from typing import Any, Callable, List, Optional import torch +from ignite import utils from ignite.engine import Engine, Events from ignite.handlers import Timer +from ignite.handlers.utils import global_step_from_engine # noqa MB = 1024.0 * 1024.0 +__all__ = ["FBResearchLogger", "global_step_from_engine"] + class FBResearchLogger: """Logs training and validation metrics for research purposes. @@ -98,7 +100,13 @@ def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False self.show_output: bool = show_output def attach( - self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None + self, + engine: Engine, + name: str, + every: int = 1, + output_transform: Optional[Callable] = None, + state_attributes: Optional[List[str]] = None, + optimizer: Optional[torch.optim.Optimizer] = None, ) -> None: """Attaches all the logging handlers to the given engine. @@ -106,8 +114,13 @@ def attach( engine: The engine to attach the logging handlers to. name: The name of the engine (e.g., "Train", "Validate") to include in log messages. every: Frequency of iterations to log information. Logs are generated every 'every' iterations. + output_transform: A function to select the value to log. + state_attributes: A list of attributes to log. optimizer: The optimizer used during training to log current learning rates. 
""" + self.name = name + self.output_transform = output_transform + self.state_attributes = state_attributes engine.add_event_handler(Events.EPOCH_STARTED, self.log_epoch_started, engine, name) engine.add_event_handler(Events.ITERATION_COMPLETED(every=every), self.log_every, engine, optimizer=optimizer) engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name) @@ -151,10 +164,9 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] = outputs = [] if self.show_output and engine.state.output is not None: output = engine.state.output - if isinstance(output, dict): - outputs += [f"{k}: {v:.4f}" for k, v in output.items()] - else: - outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output] # type: ignore + if self.output_transform is not None: + output = self.output_transform(output) + outputs = utils._to_str_list(output) lrs = "" if optimizer is not None: @@ -164,6 +176,11 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] = for i, g in enumerate(optimizer.param_groups): lrs += f"lr [g{i}]: {g['lr']:.5f}" + state_attrs = [] + if self.state_attributes is not None: + state_attrs = utils._to_str_list( + {name: getattr(engine.state, name, None) for name in self.state_attributes} + ) msg = self.delimiter.join( [ f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]", @@ -172,6 +189,7 @@ def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] = f"{lrs}", ] + outputs + + [" ".join(state_attrs)] + [ f"Iter time: {iter_avg_time:.4f} s", f"Data prep time: {self.data_timer.value():.4f} s", diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py index 24faf5758c6..586aa94ffb7 100644 --- a/ignite/metrics/maximum_mean_discrepancy.py +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -29,7 +29,7 @@ class MaximumMeanDiscrepancy(Metric): More details can be found in `Gretton et al. 2012`__. - __ https://jmlr.csail.mit.edu/papers/v13/gretton12a.html + __ https://www.jmlr.org/papers/volume13/gretton12a/gretton12a.pdf - ``update`` must receive output of the form ``(x, y)``. - ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`. diff --git a/ignite/utils.py b/ignite/utils.py index 6e5b2176d6a..c45c49a2449 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -2,6 +2,7 @@ import functools import hashlib import logging +import numbers import random import shutil import warnings @@ -14,6 +15,7 @@ "convert_tensor", "apply_to_tensor", "apply_to_type", + "_to_str_list", "to_onehot", "setup_logger", "manual_seed", @@ -90,6 +92,82 @@ def _tree_map( return func(x, key=key) +def _to_str_list(data: Any) -> List[str]: + """ + Recursively flattens and formats complex data structures, including keys for + dictionaries, into a list of human-readable strings. + + This function processes nested dictionaries, lists, tuples, numbers, and + PyTorch tensors, formatting numbers to four decimal places and handling + tensors with special formatting rules. It's particularly useful for logging, + debugging, or any scenario where a human-readable representation of complex, + nested data structures is required. + + The function handles the following types: + + - Numbers: Formatted to four decimal places. + - PyTorch tensors: + - Scalars are formatted to four decimal places. + - 1D tensors with more than 10 elements show the first 10 elements + followed by an ellipsis. + - 1D tensors with 10 or fewer elements are fully listed. 
+ - Multi-dimensional tensors display their shape. + - Dictionaries: Each key-value pair is included in the output with the key + as a prefix. + - Lists and tuples: Flattened and included in the output. Empty lists/tuples are represented + by an empty string. + - None values: Represented by an empty string. + + Args: + data: The input data to be flattened and formatted. It can be a nested + combination of dictionaries, lists, tuples, numbers, and PyTorch + tensors. + + Returns: + A list of formatted strings, each representing a part of the input data + structure. + """ + formatted_items: List[str] = [] + + def format_item(item: Any, prefix: str = "") -> Optional[str]: + if isinstance(item, numbers.Number): + return f"{prefix}{item:.4f}" + elif torch.is_tensor(item): + if item.dim() == 0: + return f"{prefix}{item.item():.4f}" # Format scalar tensor without brackets + elif item.dim() == 1 and item.size(0) > 10: + return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item[:10]) + ", ...]" + elif item.dim() == 1: + return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item) + "]" + else: + return f"{prefix}Shape: {item.shape}" + elif isinstance(item, dict): + for key, value in item.items(): + formatted_value = format_item(value, f"{key}: ") + if formatted_value is not None: + formatted_items.append(formatted_value) + elif isinstance(item, (list, tuple)): + if not item: + if prefix: + formatted_items.append(f"{prefix}") + else: + values = [format_item(x) for x in item] + values_str = [v for v in values if v is not None] + if values_str: + formatted_items.append(f"{prefix}" + ", ".join(values_str)) + elif item is None: + if prefix: + formatted_items.append(f"{prefix}") + return None + + # Directly handle single numeric values + if isinstance(data, numbers.Number): + return [f"{data:.4f}"] + + format_item(data) + return formatted_items + + class _CollectionItem: types_as_collection_item: Tuple = (int, float, torch.Tensor) diff --git a/tests/ignite/handlers/test_fbresearch_logger.py b/tests/ignite/handlers/test_fbresearch_logger.py index b85bdcf2794..728c97870e0 100644 --- a/tests/ignite/handlers/test_fbresearch_logger.py +++ b/tests/ignite/handlers/test_fbresearch_logger.py @@ -3,9 +3,13 @@ from unittest.mock import MagicMock import pytest +import torch +import torch.nn as nn +import torch.optim as optim -from ignite.engine import Engine, Events -from ignite.handlers.fbresearch_logger import FBResearchLogger # Adjust the import path as necessary +from ignite.engine import create_supervised_trainer, Engine, Events +from ignite.handlers.fbresearch_logger import FBResearchLogger +from ignite.utils import setup_logger @pytest.fixture @@ -56,3 +60,47 @@ def test_output_formatting(mock_engine, fb_research_logger, output, expected_pat actual_output = fb_research_logger.logger.info.call_args_list[0].args[0] assert re.search(expected_pattern, actual_output) + + +def test_logger_type_support(): + model = nn.Linear(10, 5) + opt = optim.SGD(model.parameters(), lr=0.001) + criterion = nn.CrossEntropyLoss() + + data = [(torch.rand(4, 10), torch.randint(0, 5, size=(4,))) for _ in range(100)] + + trainer = create_supervised_trainer(model, opt, criterion) + + logger = setup_logger("trainer", level=logging.INFO) + logger = FBResearchLogger(logger=logger, show_output=True) + logger.attach(trainer, name="Train", every=20, optimizer=opt) + + trainer.run(data, max_epochs=4) + trainer.state.output = {"loss": 4.2} + trainer.fire_event(Events.ITERATION_COMPLETED) + trainer.state.output = "4.2" + 
trainer.fire_event(Events.ITERATION_COMPLETED) + trainer.state.output = [4.2, 4.2] + trainer.fire_event(Events.ITERATION_COMPLETED) + trainer.state.output = (4.2, 4.2) + trainer.fire_event(Events.ITERATION_COMPLETED) + + +def test_fbrlogger_with_output_transform(mock_logger): + trainer = Engine(lambda e, b: 42) + fbr = FBResearchLogger(logger=mock_logger, show_output=True) + fbr.attach(trainer, "Training", output_transform=lambda x: {"loss": x}) + trainer.run(data=[10], epoch_length=1, max_epochs=1) + assert "loss: 42.0000" in fbr.logger.info.call_args_list[-2].args[0] + + +def test_fbrlogger_with_state_attrs(mock_logger): + trainer = Engine(lambda e, b: 42) + fbr = FBResearchLogger(logger=mock_logger, show_output=True) + fbr.attach(trainer, "Training", state_attributes=["alpha", "beta", "gamma"]) + trainer.state.alpha = 3.899 + trainer.state.beta = torch.tensor(12.21) + trainer.state.gamma = torch.tensor([21.0, 6.0]) + trainer.run(data=[10], epoch_length=1, max_epochs=1) + attrs = "alpha: 3.8990 beta: 12.2100 gamma: [21.0000, 6.0000]" + assert attrs in fbr.logger.info.call_args_list[-2].args[0] diff --git a/tests/ignite/test_utils.py b/tests/ignite/test_utils.py index 828533ce201..98039255397 100644 --- a/tests/ignite/test_utils.py +++ b/tests/ignite/test_utils.py @@ -8,7 +8,7 @@ from packaging.version import Version from ignite.engine import Engine, Events -from ignite.utils import convert_tensor, deprecated, hash_checkpoint, setup_logger, to_onehot +from ignite.utils import _to_str_list, convert_tensor, deprecated, hash_checkpoint, setup_logger, to_onehot def test_convert_tensor(): @@ -55,6 +55,29 @@ def test_convert_tensor(): convert_tensor(12345) +@pytest.mark.parametrize( + "input_data,expected", + [ + (42, ["42.0000"]), + ([{"a": 15, "b": torch.tensor([2.0])}], ["a: 15.0000", "b: [2.0000]"]), + ({"a": 10, "b": 2.33333}, ["a: 10.0000", "b: 2.3333"]), + ({"x": torch.tensor(0.1234), "y": [1, 2.3567]}, ["x: 0.1234", "y: 1.0000, 2.3567"]), + (({"nested": [3.1415, torch.tensor(0.0001)]},), ["nested: 3.1415, 0.0001"]), + ( + {"large_vector": torch.tensor(range(20))}, + ["large_vector: [0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, ...]"], + ), + ({"large_matrix": torch.randn(5, 5)}, ["large_matrix: Shape: torch.Size([5, 5])"]), + ({"empty": []}, ["empty: "]), + ([], []), + ({"none": None}, ["none: "]), + ({1: 100, 2: 200}, ["1: 100.0000", "2: 200.0000"]), + ], +) +def test__to_str_list(input_data, expected): + assert _to_str_list(input_data) == expected + + def test_to_onehot(): indices = torch.tensor([0, 1, 2, 3], dtype=torch.long) actual = to_onehot(indices, 4) From e7887137cf0810c9b6eda5284e1d1b6879391032 Mon Sep 17 00:00:00 2001 From: leej3 Date: Tue, 25 Jun 2024 11:50:28 +0100 Subject: [PATCH 2/7] remove autolist for utils --- docs/source/utils.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 887168436ea..da8c2814ea6 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -7,7 +7,6 @@ Module with helper methods .. autosummary:: :nosignatures: - :autolist: .. 
automodule:: ignite.utils :members: From cd3d36702d02a13def0ddd66486317abaafda6f1 Mon Sep 17 00:00:00 2001 From: leej3 Date: Tue, 25 Jun 2024 11:57:45 +0100 Subject: [PATCH 3/7] add clean directive to docs Makefile --- docs/Makefile | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/Makefile b/docs/Makefile index 3d1f9ada6a8..413cdff94ad 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -22,6 +22,13 @@ docset: html rebuild: rm -rf source/generated && make clean && make html +clean: + @echo "Cleaning up..." + python -c "import shutil; shutil.rmtree('$(BUILDDIR)', ignore_errors=True)" + python -c "import shutil; shutil.rmtree('$(SOURCEDIR)/generated', ignore_errors=True)" + python -c "import os; [os.remove(f) for f in os.listdir('.') if f.endswith('.pyc')]" + python -c "import shutil; import os; [shutil.rmtree(f) for f in os.listdir('.') if f == '__pycache__' and os.path.isdir(f)]" + .PHONY: help Makefile docset # Catch-all target: route all unknown targets to Sphinx using the new From 5ab1c882897167c0b3a670904e4f7f227d88301c Mon Sep 17 00:00:00 2001 From: leej3 Date: Tue, 25 Jun 2024 14:42:02 +0100 Subject: [PATCH 4/7] tidy matrix display --- ignite/utils.py | 2 +- tests/ignite/test_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/utils.py b/ignite/utils.py index c45c49a2449..1b2d8c03775 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -140,7 +140,7 @@ def format_item(item: Any, prefix: str = "") -> Optional[str]: elif item.dim() == 1: return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item) + "]" else: - return f"{prefix}Shape: {item.shape}" + return f"{prefix}Shape: {list(item.shape)}" elif isinstance(item, dict): for key, value in item.items(): formatted_value = format_item(value, f"{key}: ") diff --git a/tests/ignite/test_utils.py b/tests/ignite/test_utils.py index 98039255397..1b10ba168f0 100644 --- a/tests/ignite/test_utils.py +++ b/tests/ignite/test_utils.py @@ -67,7 +67,7 @@ def test_convert_tensor(): {"large_vector": torch.tensor(range(20))}, ["large_vector: [0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, ...]"], ), - ({"large_matrix": torch.randn(5, 5)}, ["large_matrix: Shape: torch.Size([5, 5])"]), + ({"large_matrix": torch.randn(5, 5)}, ["large_matrix: Shape: [5, 5]"]), ({"empty": []}, ["empty: "]), ([], []), ({"none": None}, ["none: "]), From d13d6cf9e25f76121c7b250f51f7a604f2be5a9e Mon Sep 17 00:00:00 2001 From: leej3 Date: Tue, 25 Jun 2024 15:01:27 +0100 Subject: [PATCH 5/7] make reporting of shape more compact --- ignite/utils.py | 2 +- tests/ignite/test_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/utils.py b/ignite/utils.py index 1b2d8c03775..1345e2bb0d8 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -140,7 +140,7 @@ def format_item(item: Any, prefix: str = "") -> Optional[str]: elif item.dim() == 1: return f"{prefix}[" + ", ".join(f"{x.item():.4f}" for x in item) + "]" else: - return f"{prefix}Shape: {list(item.shape)}" + return f"{prefix}Shape{list(item.shape)}" elif isinstance(item, dict): for key, value in item.items(): formatted_value = format_item(value, f"{key}: ") diff --git a/tests/ignite/test_utils.py b/tests/ignite/test_utils.py index 1b10ba168f0..4b00fb8c67a 100644 --- a/tests/ignite/test_utils.py +++ b/tests/ignite/test_utils.py @@ -67,7 +67,7 @@ def test_convert_tensor(): {"large_vector": torch.tensor(range(20))}, ["large_vector: [0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, ...]"], ), - 
({"large_matrix": torch.randn(5, 5)}, ["large_matrix: Shape: [5, 5]"]), + ({"large_matrix": torch.randn(5, 5)}, ["large_matrix: Shape[5, 5]"]), ({"empty": []}, ["empty: "]), ([], []), ({"none": None}, ["none: "]), From 016bc96dc83e8dfc88b6c10735edb42bc4dc4a25 Mon Sep 17 00:00:00 2001 From: leej3 Date: Thu, 27 Jun 2024 13:44:57 +0100 Subject: [PATCH 6/7] remove superfluous import --- ignite/handlers/fbresearch_logger.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ignite/handlers/fbresearch_logger.py b/ignite/handlers/fbresearch_logger.py index 18dcd47fbe6..4243a636b6f 100644 --- a/ignite/handlers/fbresearch_logger.py +++ b/ignite/handlers/fbresearch_logger.py @@ -8,12 +8,10 @@ from ignite import utils from ignite.engine import Engine, Events from ignite.handlers import Timer -from ignite.handlers.utils import global_step_from_engine # noqa - MB = 1024.0 * 1024.0 -__all__ = ["FBResearchLogger", "global_step_from_engine"] +__all__ = ["FBResearchLogger"] class FBResearchLogger: From c6decd50dd2c86d2ae2ed977da22d1a933790f43 Mon Sep 17 00:00:00 2001 From: leej3 Date: Thu, 27 Jun 2024 14:27:43 +0100 Subject: [PATCH 7/7] fix bug in autosummary --- docs/source/conf.py | 10 +++++++++- docs/source/utils.rst | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 80c15e9b4d2..e26a50785f2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -301,7 +301,15 @@ def run(self): names = [name[0] for name in getmembers(module)] # Filter out members w/o doc strings - names = [name for name in names if getattr(module, name).__doc__ is not None] + filtered_names = [] + for name in names: + try: + if not name.startswith("_") and getattr(module, name).__doc__ is not None: + filtered_names.append(name) + except AttributeError: + continue + + names = filtered_names if auto == "autolist": # Get list of all classes and functions inside module diff --git a/docs/source/utils.rst b/docs/source/utils.rst index da8c2814ea6..887168436ea 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -7,6 +7,7 @@ Module with helper methods .. autosummary:: :nosignatures: + :autolist: .. automodule:: ignite.utils :members: