
Commit

Merge branch 'master' into fix_2459
sadra-barikbin committed Apr 19, 2022
2 parents 229435b + 0d40173 commit a6d8ae3
Showing 33 changed files with 958 additions and 826 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -4,11 +4,11 @@ parameters:
pytorch_stable_image:
type: string
# https://hub.docker.com/r/pytorch/pytorch/tags
default: "pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime"
default: "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime"
pytorch_stable_image_devel:
type: string
# https://hub.docker.com/r/pytorch/pytorch/tags
default: "pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel"
default: "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel"
workingdir:
type: string
default: "/tmp/ignite"
7 changes: 7 additions & 0 deletions .github/workflows/pytorch-version-tests.yml
@@ -71,6 +71,13 @@ jobs:
pip install -r requirements-dev.txt
python setup.py install
- name: Install appropriate Pillow for PyTorch 1.3.1
shell: bash -l {0}
if: ${{ matrix.pytorch-version == '1.3.1' }}
run: |
pip install --upgrade 'Pillow<7'
python -c "import torchvision"
- name: Download MNIST
uses: pytorch-ignite/download-mnist-github-action@master
with:
6 changes: 3 additions & 3 deletions docker/docker.cfg
@@ -1,4 +1,4 @@
[DEFAULT]
- build_docker_image_pytorch_version = 1.10.0-cuda11.3-cudnn8
- build_docker_image_hvd_version = v0.23.0
- build_docker_image_msdp_version = v0.5.9
+ build_docker_image_pytorch_version = 1.11.0-cuda11.3-cudnn8
+ build_docker_image_hvd_version = v0.24.2
+ build_docker_image_msdp_version = v0.6.0
13 changes: 11 additions & 2 deletions examples/contrib/mnist/mnist_with_tensorboard_logger.py
@@ -135,11 +135,20 @@ def compute_metrics(engine):

tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))

- tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
+ tb_logger.attach(
+     trainer,
+     log_handler=WeightsHistHandler(
+         model,
+         whitelist=[
+             "conv",
+         ],
+     ),
+     event_name=Events.ITERATION_COMPLETED(every=100),
+ )

tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))

- tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
+ tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))

def score_function(engine):
return engine.state.metrics["accuracy"]
72 changes: 68 additions & 4 deletions ignite/contrib/handlers/tensorboard_logger.py
@@ -404,6 +404,12 @@ class WeightsHistHandler(BaseWeightsHistHandler):
Args:
model: model to log weights
tag: common title for all produced plots. For example, "generator"
whitelist: specific weights to log. Should be a list of the model's submodule
or parameter names, or a callable which takes a weight along with its name
and determines whether it should be logged. Names should be fully-qualified.
For more information please refer to `PyTorch docs
<https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_submodule>`_.
If not given, all of the model's weights are logged.
Examples:
.. code-block:: python
@@ -419,20 +425,78 @@ class WeightsHistHandler(BaseWeightsHistHandler):
event_name=Events.ITERATION_COMPLETED,
log_handler=WeightsHistHandler(model)
)
.. code-block:: python
from ignite.contrib.handlers.tensorboard_logger import *
# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")
# Log weights of `fc` layer
weights = ['fc']
# Attach the logger to the trainer to log weight histograms after each iteration
tb_logger.attach(
trainer,
event_name=Events.ITERATION_COMPLETED,
log_handler=WeightsHistHandler(model, whitelist=weights)
)
.. code-block:: python
from ignite.contrib.handlers.tensorboard_logger import *
# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")
# Log weights whose names include 'conv'.
weight_selector = lambda name, p: 'conv' in name
# Attach the logger to the trainer to log weight histograms after each iteration
tb_logger.attach(
trainer,
event_name=Events.ITERATION_COMPLETED,
log_handler=WeightsHistHandler(model, whitelist=weight_selector)
)
.. versionchanged:: 0.5.0
optional argument `whitelist` added.
"""

- def __init__(self, model: nn.Module, tag: Optional[str] = None):
+ def __init__(
+     self,
+     model: nn.Module,
+     tag: Optional[str] = None,
+     whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
+ ):
super(WeightsHistHandler, self).__init__(model, tag=tag)

weights = {}
if whitelist is None:

weights = dict(model.named_parameters())
elif callable(whitelist):

for n, p in model.named_parameters():
if whitelist(n, p):
weights[n] = p
else:

for n, p in model.named_parameters():
for item in whitelist:
if n.startswith(item):
weights[n] = p

self.weights = weights.items()

def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None:
if not isinstance(logger, TensorboardLogger):
raise RuntimeError("Handler 'WeightsHistHandler' works only with TensorboardLogger")

global_step = engine.state.get_event_attrib_value(event_name)
tag_prefix = f"{self.tag}/" if self.tag else ""
- for name, p in self.model.named_parameters():
-     if p.grad is None:
-         continue
+ for name, p in self.weights:

name = name.replace(".", "/")
logger.writer.add_histogram(
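To make the new ``whitelist`` behaviour above concrete, here is a minimal standalone sketch. The toy model and layer names are made up for illustration; only ``WeightsHistHandler`` and the list/callable forms of ``whitelist`` come from the change itself, and the expected outputs are assumptions based on the filtering logic shown above.

import torch.nn as nn
from ignite.contrib.handlers.tensorboard_logger import WeightsHistHandler

# Toy model; "features.0" and "classifier" are fully-qualified submodule names (illustrative only).
model = nn.Sequential()
model.add_module("features", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
model.add_module("classifier", nn.Linear(8, 10))

# List form: keep parameters whose fully-qualified names start with a whitelisted prefix.
handler = WeightsHistHandler(model, whitelist=["features.0"])
print([name for name, _ in handler.weights])
# expected: ['features.0.weight', 'features.0.bias']

# Callable form: the callable receives (name, parameter) and returns True to keep it.
handler = WeightsHistHandler(model, whitelist=lambda name, p: name.endswith("bias"))
print([name for name, _ in handler.weights])
# expected: ['features.0.bias', 'classifier.bias']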
4 changes: 3 additions & 1 deletion ignite/contrib/handlers/tqdm_logger.py
@@ -117,7 +117,9 @@ class ProgressBar(BaseLogger):
def __init__(
self,
persist: bool = False,
bar_format: str = "{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]",
bar_format: Union[
str, None
] = "{desc}[{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} [{elapsed}<{remaining}]",
**tqdm_kwargs: Any,
):

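A short, hedged usage sketch of the widened ``bar_format`` annotation; the assumption (not stated in this diff) is that passing ``None`` lets tqdm fall back to its own default bar format instead of ignite's.

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine

trainer = Engine(lambda engine, batch: batch)  # trivial stand-in trainer

# bar_format=None is now a legal value per the widened annotation; presumably
# tqdm then uses its built-in format (assumption, not verified here).
ProgressBar(persist=True, bar_format=None).attach(trainer)

trainer.run(range(100), max_epochs=1)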
2 changes: 1 addition & 1 deletion ignite/engine/events.py
@@ -138,7 +138,7 @@ def __or__(self, other: Any) -> "EventsList":
return EventsList() | self | other


- class EventEnum(CallableEventWithFilter, Enum):  # type: ignore[misc]
+ class EventEnum(CallableEventWithFilter, Enum):
"""Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit
this class.
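Context for this one-line change: ``EventEnum`` is the documented base class for user-defined events, so dropping the ``type: ignore`` only affects type checking, not usage. A minimal sketch of a custom event follows; the event names and the no-op handler are hypothetical.

from ignite.engine import Engine
from ignite.engine.events import EventEnum

class BackpropEvents(EventEnum):
    # hypothetical custom events, fired manually around the backward pass
    BACKWARD_STARTED = "backward_started"
    BACKWARD_COMPLETED = "backward_completed"

def update(engine, batch):
    engine.fire_event(BackpropEvents.BACKWARD_STARTED)
    # ... loss.backward() would go here ...
    engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)

trainer = Engine(update)
trainer.register_events(*BackpropEvents)

@trainer.on(BackpropEvents.BACKWARD_COMPLETED)
def on_backward_completed(engine):
    pass  # e.g. inspect gradients here

trainer.run(range(3), max_epochs=1)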
148 changes: 119 additions & 29 deletions ignite/handlers/checkpoint.py
@@ -327,6 +327,18 @@ def __init__(
self.include_self = include_self
self.greater_or_equal = greater_or_equal

def _get_filename_pattern(self, global_step: Optional[int]) -> str:
if self.filename_pattern is None:
filename_pattern = self.setup_filename_pattern(
with_prefix=len(self.filename_prefix) > 0,
with_score=self.score_function is not None,
with_score_name=self.score_name is not None,
with_global_step=global_step is not None,
)
else:
filename_pattern = self.filename_pattern
return filename_pattern

def reset(self) -> None:
"""Method to reset saved checkpoint names.
@@ -402,15 +414,7 @@ def __call__(self, engine: Engine) -> None:
name = k
checkpoint = checkpoint[name]

- if self.filename_pattern is None:
-     filename_pattern = self.setup_filename_pattern(
-         with_prefix=len(self.filename_prefix) > 0,
-         with_score=self.score_function is not None,
-         with_score_name=self.score_name is not None,
-         with_global_step=global_step is not None,
-     )
- else:
-     filename_pattern = self.filename_pattern
+ filename_pattern = self._get_filename_pattern(global_step)

filename_dict = {
"filename_prefix": self.filename_prefix,
@@ -519,41 +523,51 @@ def _check_objects(objs: Mapping, attr: str) -> None:
raise TypeError(f"Object {type(obj)} should have `{attr}` method")

@staticmethod
- def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
+ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping, Path], **kwargs: Any) -> None:
"""Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.
Args:
to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
- checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
-     "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
-     directly corresponding state_dict.
+ checkpoint: a path, a string filepath or a dictionary with state_dicts to load, e.g.
+     `{"model": model_state_dict, "optimizer": opt_state_dict}`. If `to_load` contains a single key,
+     then checkpoint can directly contain the corresponding state_dict.
kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
the user to load part of the pretrained model (useful for example, in Transfer Learning)
Examples:
.. code-block:: python
import tempfile
from pathlib import Path
import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint
trainer = Engine(lambda engine, batch: None)
- handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
- model = torch.nn.Linear(3, 3)
- optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
- to_save = {"weights": model, "optimizer": optimizer}
- trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
- trainer.run(torch.randn(10, 1), 5)
- to_load = to_save
- checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
- checkpoint = torch.load(checkpoint_fp)
- Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+     handler = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
+     model = torch.nn.Linear(3, 3)
+     optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+     to_save = {"weights": model, "optimizer": optimizer}
+     trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
+     trainer.run(torch.randn(10, 1), 5)
+     to_load = to_save
+     checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+     checkpoint = torch.load(checkpoint_fp)
+     Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
- # or using a string for checkpoint filepath
- to_load = to_save
- checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
- Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
+     # or using a string for checkpoint filepath
+     to_load = to_save
+     checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+     Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
Note:
If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
@@ -564,13 +578,13 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An
.. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
"""

- if isinstance(checkpoint, str):
+ if isinstance(checkpoint, (str, Path)):
checkpoint_obj = torch.load(checkpoint)
else:
checkpoint_obj = checkpoint

Checkpoint._check_objects(to_load, "load_state_dict")
- if not isinstance(checkpoint, (collections.Mapping, str)):
+ if not isinstance(checkpoint, (collections.Mapping, str, Path)):
raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")

if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
@@ -599,6 +613,82 @@ def _load_object(obj: Any, chkpt_obj: Any) -> None:
raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
_load_object(obj, checkpoint_obj[k])

def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, **filename_components: Any) -> None:
"""Helper method to apply ``load_state_dict`` on the objects from ``to_load``. Filename components such as
name, score and global step can be configured.
Args:
to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
load_kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
the user to load part of the pretrained model (useful for example, in Transfer Learning)
filename_components: Filename components used to define the checkpoint file path.
Keyword arguments accepted are `name`, `score` and `global_step`.
Examples:
.. code-block:: python
import tempfile
import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint
trainer = Engine(lambda engine, batch: None)
with tempfile.TemporaryDirectory() as tmpdirname:
checkpoint = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, to_save)
trainer.run(torch.randn(10, 1), 5)
to_load = to_save
# load checkpoint myprefix_checkpoint_40.pt
checkpoint.reload_objects(to_load=to_load, global_step=40)
Note:
If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
`DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model (``obj.module``).
.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
torch.nn.parallel.DistributedDataParallel.html
.. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
"""

global_step = filename_components.get("global_step", None)

filename_pattern = self._get_filename_pattern(global_step)

checkpoint = self._setup_checkpoint()
name = "checkpoint"
if len(checkpoint) == 1:
for k in checkpoint:
name = k
name = filename_components.get("name", name)
score = filename_components.get("score", None)

filename_dict = {
"filename_prefix": self.filename_prefix,
"ext": self.ext,
"name": name,
"score_name": self.score_name,
"score": score,
"global_step": global_step,
}

checkpoint_fp = filename_pattern.format(**filename_dict)

path = self.save_handler.dirname / checkpoint_fp

load_kwargs = {} if load_kwargs is None else load_kwargs

Checkpoint.load_objects(to_load=to_load, checkpoint=path, **load_kwargs)

def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
"""Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
Can be used to save internal state of the class.
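To make the filename handling that ``_get_filename_pattern`` and ``reload_objects`` rely on concrete, here is a small self-contained sketch of the formatting step. The pattern string is an illustrative assumption, not necessarily the exact output of ``setup_filename_pattern``; the component values mirror the docstring example above.

# Assumed shape of a pattern for the prefix + global-step case (illustration only).
filename_pattern = "{filename_prefix}_{name}_{global_step}.{ext}"

filename_dict = {
    "filename_prefix": "myprefix",
    "ext": "pt",
    "name": "checkpoint",
    "score_name": None,
    "score": None,
    "global_step": 40,
}

# Same formatting call as in reload_objects above; unused keys are ignored by str.format.
checkpoint_fp = filename_pattern.format(**filename_dict)
print(checkpoint_fp)  # -> myprefix_checkpoint_40.pt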
