Skip to content

Commit

Permalink
Merge branch 'master' into feature/#2465_LRScheduler_attach_Events
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Mar 22, 2022
2 parents 56be4c5 + 3a286b1 commit 52dcbcf
Show file tree
Hide file tree
Showing 11 changed files with 232 additions and 129 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ parameters:
pytorch_stable_image:
type: string
# https://hub.docker.com/r/pytorch/pytorch/tags
default: "pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime"
default: "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime"
pytorch_stable_image_devel:
type: string
# https://hub.docker.com/r/pytorch/pytorch/tags
default: "pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel"
default: "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel"
workingdir:
type: string
default: "/tmp/ignite"
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/pytorch-version-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ jobs:
pip install -r requirements-dev.txt
python setup.py install
- name: Install appropriate Pillow for PyTorch 1.3.1
shell: bash -l {0}
if: ${{ matrix.pytorch-version == '1.3.1' }}
run: |
pip install --upgrade 'Pillow<7'
python -c "import torchvision"
- name: Download MNIST
uses: pytorch-ignite/download-mnist-github-action@master
with:
Expand Down
6 changes: 3 additions & 3 deletions docker/docker.cfg
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[DEFAULT]
build_docker_image_pytorch_version = 1.10.0-cuda11.3-cudnn8
build_docker_image_hvd_version = v0.23.0
build_docker_image_msdp_version = v0.5.9
build_docker_image_pytorch_version = 1.11.0-cuda11.3-cudnn8
build_docker_image_hvd_version = v0.24.2
build_docker_image_msdp_version = v0.6.0
5 changes: 3 additions & 2 deletions examples/contrib/cifar10/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@


def get_train_test_datasets(path):
if not Path(path).exists():
Path.mkdir(path, parents=True)
path = Path(path)
if not path.exists():
path.mkdir(parents=True)
download = True
else:
download = True if len(os.listdir(path)) < 1 else False
Expand Down
12 changes: 6 additions & 6 deletions ignite/contrib/metrics/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Te
except ImportError:
raise RuntimeError("This contrib module requires sklearn to be installed.")

y_true = y_targets.numpy()
y_pred = y_preds.numpy()
y_true = y_targets.cpu().numpy()
y_pred = y_preds.cpu().numpy()
return precision_recall_curve(y_true, y_pred)


Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(

def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("EpochMetric must have at least one example before it can be computed.")
raise NotComputableError("PrecisionRecallCurve must have at least one example before it can be computed.")

_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)
Expand All @@ -101,11 +101,11 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if idist.get_rank() == 0:
# Run compute_fn on zero rank only
precision, recall, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
precision = torch.Tensor(precision)
recall = torch.Tensor(recall)
precision = torch.tensor(precision)
recall = torch.tensor(recall)
# thresholds can have negative strides, not compatible with torch tensors
# https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
thresholds = torch.Tensor(thresholds.copy())
thresholds = torch.tensor(thresholds.copy())
else:
precision, recall, thresholds = None, None, None

Expand Down
2 changes: 1 addition & 1 deletion ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __or__(self, other: Any) -> "EventsList":
return EventsList() | self | other


class EventEnum(CallableEventWithFilter, Enum): # type: ignore[misc]
class EventEnum(CallableEventWithFilter, Enum):
"""Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit
this class.
Expand Down
148 changes: 119 additions & 29 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,18 @@ def __init__(
self.include_self = include_self
self.greater_or_equal = greater_or_equal

def _get_filename_pattern(self, global_step: Optional[int]) -> str:
if self.filename_pattern is None:
filename_pattern = self.setup_filename_pattern(
with_prefix=len(self.filename_prefix) > 0,
with_score=self.score_function is not None,
with_score_name=self.score_name is not None,
with_global_step=global_step is not None,
)
else:
filename_pattern = self.filename_pattern
return filename_pattern

def reset(self) -> None:
"""Method to reset saved checkpoint names.
Expand Down Expand Up @@ -402,15 +414,7 @@ def __call__(self, engine: Engine) -> None:
name = k
checkpoint = checkpoint[name]

if self.filename_pattern is None:
filename_pattern = self.setup_filename_pattern(
with_prefix=len(self.filename_prefix) > 0,
with_score=self.score_function is not None,
with_score_name=self.score_name is not None,
with_global_step=global_step is not None,
)
else:
filename_pattern = self.filename_pattern
filename_pattern = self._get_filename_pattern(global_step)

filename_dict = {
"filename_prefix": self.filename_prefix,
Expand Down Expand Up @@ -519,41 +523,51 @@ def _check_objects(objs: Mapping, attr: str) -> None:
raise TypeError(f"Object {type(obj)} should have `{attr}` method")

@staticmethod
def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping, Path], **kwargs: Any) -> None:
"""Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.
Args:
to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
"optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
directly corresponding state_dict.
checkpoint: a path, a string filepath or a dictionary with state_dicts to load, e.g.
`{"model": model_state_dict, "optimizer": opt_state_dict}`. If `to_load` contains a single key,
then checkpoint can contain directly corresponding state_dict.
kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
the user to load part of the pretrained model (useful for example, in Transfer Learning)
Examples:
.. code-block:: python
import tempfile
from pathlib import Path
import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)
to_load = to_save
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
with tempfile.TemporaryDirectory() as tmpdirname:
handler = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)
to_load = to_save
checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
# or using a string for checkpoint filepath
# or using a string for checkpoint filepath
to_load = to_save
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
to_load = to_save
checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
Note:
If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
Expand All @@ -564,13 +578,13 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An
.. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
"""

if isinstance(checkpoint, str):
if isinstance(checkpoint, (str, Path)):
checkpoint_obj = torch.load(checkpoint)
else:
checkpoint_obj = checkpoint

Checkpoint._check_objects(to_load, "load_state_dict")
if not isinstance(checkpoint, (collections.Mapping, str)):
if not isinstance(checkpoint, (collections.Mapping, str, Path)):
raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")

if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
Expand Down Expand Up @@ -599,6 +613,82 @@ def _load_object(obj: Any, chkpt_obj: Any) -> None:
raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
_load_object(obj, checkpoint_obj[k])

def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, **filename_components: Any) -> None:
"""Helper method to apply ``load_state_dict`` on the objects from ``to_load``. Filename components such as
name, score and global state can be configured.
Args:
to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
load_kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
the user to load part of the pretrained model (useful for example, in Transfer Learning)
filename_components: Filename components used to define the checkpoint file path.
Keyword arguments accepted are `name`, `score` and `global_state`.
Examples:
.. code-block:: python
import tempfile
import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint
trainer = Engine(lambda engine, batch: None)
with tempfile.TemporaryDirectory() as tmpdirname:
checkpoint = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, to_save)
trainer.run(torch.randn(10, 1), 5)
to_load = to_save
# load checkpoint myprefix_checkpoint_40.pt
checkpoint.load_objects(to_load=to_load, global_step=40)
Note:
If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
`DataParallel`_, method ``load_state_dict`` will applied to their internal wrapped model (``obj.module``).
.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
torch.nn.parallel.DistributedDataParallel.html
.. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
"""

global_step = filename_components.get("global_step", None)

filename_pattern = self._get_filename_pattern(global_step)

checkpoint = self._setup_checkpoint()
name = "checkpoint"
if len(checkpoint) == 1:
for k in checkpoint:
name = k
name = filename_components.get("name", name)
score = filename_components.get("score", None)

filename_dict = {
"filename_prefix": self.filename_prefix,
"ext": self.ext,
"name": name,
"score_name": self.score_name,
"score": score,
"global_step": global_step,
}

checkpoint_fp = filename_pattern.format(**filename_dict)

path = self.save_handler.dirname / checkpoint_fp

load_kwargs = {} if load_kwargs is None else load_kwargs

Checkpoint.load_objects(to_load=to_load, checkpoint=path, **load_kwargs)

def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
"""Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
Can be used to save internal state of the class.
Expand Down
2 changes: 2 additions & 0 deletions tests/ignite/contrib/metrics/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def get_test_cases():
for _ in range(3):
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
y_pred = y_pred.to(device)
y = y.to(device)
_test(y_pred, y, batch_size, "cpu")
if device.type != "xla":
_test(y_pred, y, batch_size, idist.device())
Expand Down
22 changes: 12 additions & 10 deletions tests/ignite/contrib/metrics/test_precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def update_fn(engine, batch):

data = list(range(size // batch_size))
precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
precision = precision.numpy()
recall = recall.numpy()
thresholds = thresholds.numpy()
precision = precision.cpu().numpy()
recall = recall.cpu().numpy()
thresholds = thresholds.cpu().numpy()

assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
Expand Down Expand Up @@ -168,9 +168,9 @@ def _test(y_pred, y, batch_size, metric_device):
res = prc.compute()

assert isinstance(res, Tuple)
assert precision_recall_curve(np_y, np_y_pred)[0] == pytest.approx(res[0])
assert precision_recall_curve(np_y, np_y_pred)[1] == pytest.approx(res[1])
assert precision_recall_curve(np_y, np_y_pred)[2] == pytest.approx(res[2])
assert precision_recall_curve(np_y, np_y_pred)[0] == pytest.approx(res[0].cpu().numpy())
assert precision_recall_curve(np_y, np_y_pred)[1] == pytest.approx(res[1].cpu().numpy())
assert precision_recall_curve(np_y, np_y_pred)[2] == pytest.approx(res[2].cpu().numpy())

def get_test_cases():
test_cases = [
Expand All @@ -183,9 +183,11 @@ def get_test_cases():
]
return test_cases

for _ in range(5):
for _ in range(3):
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
y_pred = y_pred.to(device)
y = y.to(device)
_test(y_pred, y, batch_size, "cpu")
if device.type != "xla":
_test(y_pred, y, batch_size, idist.device())
Expand Down Expand Up @@ -229,9 +231,9 @@ def update(engine, i):
assert precision.shape == sk_precision.shape
assert recall.shape == sk_recall.shape
assert thresholds.shape == sk_thresholds.shape
assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
assert pytest.approx(thresholds) == sk_thresholds
assert pytest.approx(precision.cpu().numpy()) == sk_precision
assert pytest.approx(recall.cpu().numpy()) == sk_recall
assert pytest.approx(thresholds.cpu().numpy()) == sk_thresholds

metric_devices = ["cpu"]
if device.type != "xla":
Expand Down
7 changes: 7 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,9 @@ def test_model_checkpoint_simple_recovery(dirname):
assert fname.exists()
loaded_objects = torch.load(fname)
assert loaded_objects == model.state_dict()
to_load = {"model": DummyModel()}
h.reload_objects(to_load=to_load, global_step=1)
assert to_load["model"].state_dict() == model.state_dict()


def test_model_checkpoint_simple_recovery_from_existing_non_empty(dirname):
Expand All @@ -600,6 +603,9 @@ def _test(ext, require_empty):
assert previous_fname.exists()
loaded_objects = torch.load(fname)
assert loaded_objects == model.state_dict()
to_load = {"model": DummyModel()}
h.reload_objects(to_load=to_load, global_step=1)
assert to_load["model"].state_dict() == model.state_dict()
fname.unlink()

_test(".txt", require_empty=True)
Expand Down Expand Up @@ -1118,6 +1124,7 @@ def _get_multiple_objs_to_save():
assert str(dirname / _PREFIX) in str(fname)
assert fname.exists()
Checkpoint.load_objects(to_save, str(fname))
Checkpoint.load_objects(to_save, fname)
fname.unlink()

# case: multiple objects
Expand Down

0 comments on commit 52dcbcf

Please sign in to comment.