Commit

add some mypy fixes
leej3 committed Mar 26, 2024
1 parent cc88528 commit 1f6f20d
Showing 1 changed file with 35 additions and 22 deletions.
ignite/handlers/fbresearch_logger.py (57 changes: 35 additions & 22 deletions)
@@ -1,6 +1,9 @@
 """FBResearch logger and its helper handlers."""
 
 import datetime
+from typing import Any, Optional
+
+# from typing import Any, Dict, Optional, Union
 
 import torch

@@ -33,14 +36,16 @@ class FBResearchLogger:
             logger.attach(trainer, name="Train", every=10, optimizer=my_optimizer)
     """
 
-    def __init__(self, logger, delimiter=" ", show_output=False):
+    def __init__(self, logger: Any, delimiter: str = " ", show_output: bool = False):
         self.delimiter = delimiter
-        self.logger = logger
-        self.iter_timer = None
-        self.data_timer = None
-        self.show_output = show_output
-
-    def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
+        self.logger: Any = logger
+        self.iter_timer: Timer = Timer(average=True)
+        self.data_timer: Timer = Timer(average=True)
+        self.show_output: bool = show_output
+
+    def attach(
+        self, engine: Engine, name: str, every: int = 1, optimizer: Optional[torch.optim.Optimizer] = None
+    ) -> None:
        """Attaches all the logging handlers to the given engine.
 
        Args:
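
The hunk above replaces attributes that were initialized to None, which mypy types as Optional[Timer], with eagerly constructed Timer instances, so later uses need no None checks. A minimal runnable sketch of the difference, using a stand-in Timer rather than ignite's real class:

from typing import Optional


class Timer:
    """Stand-in for ignite.handlers.Timer, for illustration only."""

    def __init__(self, average: bool = False) -> None:
        self.average = average

    def reset(self) -> None:
        pass


class LazyTimers:
    def __init__(self) -> None:
        # mypy infers Optional[Timer]; every later use needs narrowing
        self.iter_timer: Optional[Timer] = None

    def attach(self) -> None:
        assert self.iter_timer is not None  # required before any method call
        self.iter_timer.reset()


class EagerTimers:
    def __init__(self) -> None:
        # construct up front; the attribute stays plain Timer
        self.iter_timer: Timer = Timer(average=True)

    def attach(self) -> None:
        self.iter_timer.reset()  # no narrowing needed

With eager construction, attach() in the next hunk only has to reset() the existing timers instead of rebuilding them.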
@@ -54,15 +59,15 @@ def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
         engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_epoch_completed, engine, name)
         engine.add_event_handler(Events.COMPLETED, self.log_completed, engine, name)
 
-        self.iter_timer = Timer(average=True)
+        self.iter_timer.reset()
         self.iter_timer.attach(
             engine,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED,
         )
-        self.data_timer = Timer(average=True)
+        self.data_timer.reset()
         self.data_timer.attach(
             engine,
             start=Events.EPOCH_STARTED,
@@ -71,14 +76,15 @@ def attach(self, engine: Engine, name: str, every: int = 1, optimizer=None):
             step=Events.GET_BATCH_COMPLETED,
         )
 
-    def log_every(self, engine, optimizer=None):
+    def log_every(self, engine: Engine, optimizer: Optional[torch.optim.Optimizer] = None) -> None:
        """
        Logs the training progress at regular intervals.
 
        Args:
            engine (Engine): The training engine.
            optimizer (torch.optim.Optimizer, optional): The optimizer used for training. Defaults to None.
        """
+        assert engine.state.epoch_length is not None
         cuda_max_mem = ""
         if torch.cuda.is_available():
             cuda_max_mem = f"GPU Max Mem: {torch.cuda.max_memory_allocated() / MB:.0f} MB"
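
engine.state.epoch_length is Optional in ignite's typing, so the ETA arithmetic further down would fail type checking; the added assert narrows it for the rest of the method. A small self-contained sketch of the same narrowing pattern (the function name is illustrative, not from the diff):

from typing import Optional


def eta_seconds(epoch_length: Optional[int], current_iter: int, iter_avg_time: float) -> float:
    # without the assert, mypy rejects `epoch_length - current_iter`
    # as `Optional[int] - int`
    assert epoch_length is not None
    return iter_avg_time * (epoch_length - current_iter)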
@@ -89,12 +95,12 @@ def log_every(self, engine, optimizer=None):
         eta_seconds = iter_avg_time * (engine.state.epoch_length - current_iter)
 
         outputs = []
-        if self.show_output:
+        if self.show_output and engine.state.output is not None:
             output = engine.state.output
             if isinstance(output, dict):
                 outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
             else:
-                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]
+                outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]  # type: ignore
 
         lrs = ""
         if optimizer is not None:
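
This hunk pairs two fixes: engine.state.output can be None before the first iteration completes, so the guard rules that out, and the fallback branch iterates an output whose static type does not guarantee it is iterable, presumably why the # type: ignore is needed. A runnable sketch of the guarded formatting (names are illustrative, not from the diff):

from typing import Any, List


def format_output(show_output: bool, output: Any) -> List[str]:
    outputs: List[str] = []
    # engine.state.output is None until the first iteration completes
    if show_output and output is not None:
        if isinstance(output, dict):
            outputs += [f"{k}: {v:.4f}" for k, v in output.items()]
        else:
            outputs += [f"{v:.4f}" if isinstance(v, float) else f"{v}" for v in output]
    return outputs

For example, format_output(True, {"loss": 0.1234}) returns ["loss: 0.1234"], while format_output(True, None) safely returns [].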
@@ -120,7 +126,7 @@
         )
         self.logger.info(msg)
 
-    def log_epoch_started(self, engine, name):
+    def log_epoch_started(self, engine: Engine, name: str) -> None:
        """
        Logs the start of an epoch.
 
@@ -132,37 +138,44 @@ def log_epoch_started(self, engine, name):
         msg = f"{name}: start epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
         self.logger.info(msg)
 
-    def log_epoch_completed(self, engine, name):
+    def log_epoch_completed(self, engine: Engine, name: str) -> None:
        """
        Logs the completion of an epoch.
 
        Args:
-            engine (Engine): The engine object.
-            name (str): The name of the epoch.
+            engine (Engine): The engine object that triggered the event.
+            name (str): The name of the event.
+
+        Returns:
+            None
        """
         epoch_time = engine.state.times[Events.EPOCH_COMPLETED.name]
-        epoch_info = f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]" if engine.state.max_epochs > 1 else ""
+        epoch_info = (
+            f"Epoch [{engine.state.epoch}/{engine.state.max_epochs}]"
+            if engine.state.max_epochs > 1
+            else ""  # type: ignore
+        )
         msg = self.delimiter.join(
             [
                 f"{name}: {epoch_info}",
-                f"Total time: {datetime.timedelta(seconds=int(epoch_time))}",
-                f"({epoch_time / engine.state.epoch_length:.4f} s / it)",
+                f"Total time: {datetime.timedelta(seconds=int(epoch_time))}",  # type: ignore
+                f"({epoch_time / engine.state.epoch_length:.4f} s / it)",  # type: ignore
             ]
         )
         self.logger.info(msg)
 
-    def log_completed(self, engine, name):
+    def log_completed(self, engine: Engine, name: str) -> None:
        """
        Logs the completion of a run.
 
        Args:
-            engine (Engine): The engine object.
+            engine (Engine): The engine object representing the training/validation loop.
            name (str): The name of the run.
        """
-        if engine.state.max_epochs > 1:
+        if engine.state.max_epochs and engine.state.max_epochs > 1:
             total_time = engine.state.times[Events.COMPLETED.name]
+            assert total_time is not None
             msg = self.delimiter.join(
                 [
                     f"{name}: run completed",
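
The last hunk handles two more Optional values from engine.state: max_epochs, which is None until a run is configured, and the recorded run time. The truthiness check rules out None before the comparison, and the assert narrows the time before it is formatted. A self-contained sketch of both narrowings (the function name and message are illustrative):

import datetime
from typing import Optional


def run_summary(max_epochs: Optional[int], total_time: Optional[float]) -> str:
    # truthiness rules out None (and 0) before the `> 1` comparison
    if max_epochs and max_epochs > 1:
        # assert narrows Optional[float] to float for the formatting below
        assert total_time is not None
        return f"run completed in {datetime.timedelta(seconds=int(total_time))}"
    return "run completed"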
