Backport PR #2672 on branch 1.2.x (feat(train): add load_best_on_end argument to save checkpoint) (#2675)

Backport PR #2672: feat(train): add load_best_on_end argument to save checkpoint

Co-authored-by: Martin Kim <46072231+martinkim0@users.noreply.github.com>
meeseeksmachine and martinkim0 committed Apr 3, 2024
1 parent 6dd31aa commit 204f992
Showing 3 changed files with 55 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,8 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
 - Add `unsigned` argument to {meth}`scvi.hub.HubModel.pull_from_s3` to allow for unsigned
   downloads of models from AWS S3 {pr}`2615`.
 - Add support for `batch_key` in {meth}`scvi.model.CondSCVI.setup_anndata` {pr}`2626`.
+- Add `load_best_on_end` argument to {class}`scvi.train.SaveCheckpoint` to load the best model
+  state at the end of training {pr}`2672`.

#### Changed

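The new flag is opt-in. A minimal usage sketch (illustrative, not part of this diff): it assumes scvi-tools with this change applied and an AnnData object `adata` already registered via `SCVI.setup_anndata`; `SCVI` stands in for any `BaseModelClass` subclass.

```python
import scvi
from scvi.train import SaveCheckpoint

model = scvi.model.SCVI(adata)  # assumes SCVI.setup_anndata(adata) was already called
model.train(
    max_epochs=100,
    # monitor defaults to "validation_loss"; load_best_on_end restores the
    # best weights into the model once training finishes.
    callbacks=[SaveCheckpoint(monitor="validation_loss", load_best_on_end=True)],
)
```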
33 changes: 23 additions & 10 deletions scvi/train/_callbacks.py
@@ -18,21 +18,20 @@
 from scvi import settings
 from scvi.dataloaders import AnnDataLoader
 from scvi.model.base import BaseModelClass
+from scvi.model.base._utils import _load_saved_files

 MetricCallable = Callable[[BaseModelClass], float]


 class SaveCheckpoint(ModelCheckpoint):
-    """``EXPERIMENTAL`` Saves model checkpoints based on a monitored metric.
+    """``BETA`` Saves model checkpoints based on a monitored metric.

-    Inherits from :class:`~lightning.pytorch.callbacks.ModelCheckpoint` and
-    modifies the default behavior to save the full model state instead of just
-    the state dict. This is necessary for compatibility with
-    :class:`~scvi.model.base.BaseModelClass`.
+    Inherits from :class:`~lightning.pytorch.callbacks.ModelCheckpoint` and modifies the default
+    behavior to save the full model state instead of just the state dict. This is necessary for
+    compatibility with :class:`~scvi.model.base.BaseModelClass`.

-    The best model save and best model score based on ``monitor`` can be
-    accessed post-training with the ``best_model_path`` and ``best_model_score``
-    attributes, respectively.
+    The best model save and best model score based on ``monitor`` can be accessed post-training
+    with the ``best_model_path`` and ``best_model_score`` attributes, respectively.

     Known issues:
@@ -50,6 +49,8 @@ class SaveCheckpoint(ModelCheckpoint):
         If ``None``, defaults to ``{epoch}-{step}-{monitor}``.
     monitor
         Metric to monitor for checkpointing.
+    load_best_on_end
+        If ``True``, loads the best model state into the model at the end of training.
     **kwargs
         Additional keyword arguments passed into the constructor for
         :class:`~lightning.pytorch.callbacks.ModelCheckpoint`.
@@ -60,17 +61,16 @@ def __init__(
         dirpath: str | None = None,
         filename: str | None = None,
         monitor: str = "validation_loss",
+        load_best_on_end: bool = False,
         **kwargs,
     ):
         if dirpath is None:
             dirpath = os.path.join(
                 settings.logging_dir,
                 datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + f"_{monitor}",
             )
-
         if filename is None:
             filename = "{epoch}-{step}-{" + monitor + "}"
-
         if "save_weights_only" in kwargs:
             warnings.warn(
                 "`save_weights_only` is not supported in `SaveCheckpoint` and will be ignored.",
@@ -85,6 +85,7 @@ def __init__(
                 stacklevel=settings.warnings_stacklevel,
             )
             kwargs.pop("save_last")
+        self.load_best_on_end = load_best_on_end

         super().__init__(
             dirpath=dirpath,
@@ -139,6 +140,18 @@ def _update_best_and_save(
             os.remove(self.best_model_path)
         self.best_model_path = self.best_model_path.split(".ckpt")[0]

+    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        """Loads the best model state into the model at the end of training."""
+        if not self.load_best_on_end:
+            return
+
+        _, _, best_state_dict, _ = _load_saved_files(
+            self.best_model_path,
+            load_adata=False,
+            map_location=pl_module.module.device,
+        )
+        pl_module.module.load_state_dict(best_state_dict)
+

 class SubSampleLabels(Callback):
     """Subsample labels."""
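As the `__init__` above shows, checkpoints default to a timestamped, metric-suffixed folder under `scvi.settings.logging_dir`, and the best save stays accessible after training. A short sketch of inspecting it, assuming `model` was trained with the callback as in the earlier example:

```python
import os

import scvi
from scvi.train import SaveCheckpoint

# Retrieve the callback from the trained model's Lightning callbacks.
callback = next(c for c in model.trainer.callbacks if isinstance(c, SaveCheckpoint))
print(callback.best_model_path)   # full model save (".ckpt" suffix stripped)
print(callback.best_model_score)  # best observed value of the monitored metric

# Default dirpath: a "<timestamp>_<monitor>" folder under the logging dir.
print(os.listdir(scvi.settings.logging_dir))
```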
35 changes: 30 additions & 5 deletions tests/train/test_callbacks.py
@@ -1,29 +1,52 @@
 import os

 import pytest

 import scvi
-from scvi.train._callbacks import SaveCheckpoint


-def test_savecheckpoint(save_path: str):
-    def check_checkpoint_logging(model, adata):
+@pytest.mark.parametrize("load_best_on_end", [True, False])
+def test_savecheckpoint(save_path: str, load_best_on_end: bool):
+    import torch
+    from anndata import AnnData
+
+    from scvi.model.base import BaseModelClass
+    from scvi.train._callbacks import SaveCheckpoint
+
+    def check_checkpoint_logging(model: BaseModelClass, adata: AnnData):
         assert any(isinstance(c, SaveCheckpoint) for c in model.trainer.callbacks)
+
         callback = [c for c in model.trainer.callbacks if isinstance(c, SaveCheckpoint)]
         assert len(callback) == 1
+
         callback = callback[0]
         assert callback.best_model_path is not None
         assert callback.best_model_score is not None
         assert os.path.exists(callback.best_model_path)
+
         log_dirs = os.listdir(scvi.settings.logging_dir)
         assert len(log_dirs) >= 1
+
         checkpoints_dir = os.path.join(scvi.settings.logging_dir, log_dirs[0])
         checkpoint_dirs = os.listdir(checkpoints_dir)
         assert len(checkpoint_dirs) >= 1
+
         checkpoint_dir = os.path.join(checkpoints_dir, checkpoint_dirs[0])
         checkpoint = model.__class__.load(checkpoint_dir, adata=adata)
         assert checkpoint.is_trained_

-    def test_model_cls(model_cls, adata):
+        if load_best_on_end:
+            best_model = model.__class__.load(callback.best_model_path, adata=adata)
+            assert best_model.is_trained_
+
+            current_state_dict = model.module.state_dict()
+            best_state_dict = best_model.module.state_dict()
+            assert len(current_state_dict) == len(best_state_dict)
+            for k, v in current_state_dict.items():
+                assert torch.equal(v, best_state_dict[k])
+                assert v.device == best_state_dict[k].device
+
+    def test_model_cls(model_cls, adata: AnnData):
         scvi.settings.logging_dir = os.path.join(save_path, model_cls.__name__)

         # enable_checkpointing=True, default monitor
@@ -44,7 +67,9 @@ def test_model_cls(model_cls, adata):
         model = model_cls(adata)
         model.train(
             max_epochs=5,
-            callbacks=[SaveCheckpoint(monitor="elbo_validation")],
+            callbacks=[
+                SaveCheckpoint(monitor="elbo_validation", load_best_on_end=load_best_on_end)
+            ],
         )
         check_checkpoint_logging(model, adata)
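The core of the new assertions: with `load_best_on_end=True`, the in-memory weights must match the best checkpoint tensor for tensor. A standalone sketch of that check, reusing `model`, `adata`, and `callback` from the examples above:

```python
import torch

# best_model_path points at a full model save, so Model.load() works on it.
best_model = model.__class__.load(callback.best_model_path, adata=adata)
for name, param in model.module.state_dict().items():
    # Each tensor should equal the best checkpoint's, on the same device.
    assert torch.equal(param, best_model.module.state_dict()[name])
    assert param.device == best_model.module.state_dict()[name].device
```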
