Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SaveCheckpoint callback for checkpointing #2317

Merged
merged 11 commits into from Nov 14, 2023
1 change: 1 addition & 0 deletions docs/api/developer.md
Expand Up @@ -244,6 +244,7 @@ TrainingPlans define train/test/val optimization steps for modules.
train.TrainingPlan
train.TrainRunner
train.SaveBestState
train.SaveCheckpoint
train.LoudEarlyStopping

```
Expand Down
3 changes: 3 additions & 0 deletions docs/release_notes/index.md
Expand Up @@ -41,6 +41,9 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
training plans {pr}`2280`.
- {class}`scvi.train.SemiSupervisedTrainingPlan` now logs the classifier
calibration error {pr}`2299`.
- Passing `enable_checkpointing=True` into `train` methods is now
compatible with our model saves. Additional options can be specified
by initializing with {class}`scvi.train.SaveCheckpoint` {pr}`2317`.

#### Fixed

Expand Down
4 changes: 2 additions & 2 deletions scvi/model/_scanvi.py
Expand Up @@ -433,7 +433,7 @@ def train(
self.module, self.n_labels, **plan_kwargs
)
if "callbacks" in trainer_kwargs.keys():
trainer_kwargs["callbacks"].concatenate(sampler_callback)
trainer_kwargs["callbacks"] + [sampler_callback]
else:
trainer_kwargs["callbacks"] = sampler_callback

Expand All @@ -455,7 +455,7 @@ def setup_anndata(
cls,
adata: AnnData,
labels_key: str,
unlabeled_category: str | int | float,
unlabeled_category: str,
layer: str | None = None,
batch_key: str | None = None,
size_factor_key: str | None = None,
Expand Down
3 changes: 2 additions & 1 deletion scvi/train/__init__.py
@@ -1,4 +1,4 @@
from ._callbacks import JaxModuleInit, LoudEarlyStopping, SaveBestState
from ._callbacks import JaxModuleInit, LoudEarlyStopping, SaveBestState, SaveCheckpoint
from ._trainer import Trainer
from ._trainingplans import (
AdversarialTrainingPlan,
Expand All @@ -22,6 +22,7 @@
"TrainRunner",
"LoudEarlyStopping",
"SaveBestState",
"SaveCheckpoint",
"JaxModuleInit",
"JaxTrainingPlan",
]
129 changes: 123 additions & 6 deletions scvi/train/_callbacks.py
@@ -1,12 +1,17 @@
from __future__ import annotations

import os
import warnings
from copy import deepcopy
from typing import Callable, Optional, Union
from datetime import datetime
from shutil import rmtree
from typing import Callable

import flax
import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.utilities import rank_zero_info

Expand All @@ -17,6 +22,120 @@
MetricCallable = Callable[[BaseModelClass], float]


class SaveCheckpoint(ModelCheckpoint):
    """``EXPERIMENTAL`` Saves model checkpoints based on a monitored metric.

    Inherits from :class:`~lightning.pytorch.callbacks.ModelCheckpoint` and
    modifies the default behavior to save the full model state instead of just
    the state dict. This is necessary for compatibility with
    :class:`~scvi.model.base.BaseModelClass`.

    The best model save and best model score based on ``monitor`` can be
    accessed post-training with the ``best_model_path`` and ``best_model_score``
    attributes, respectively.

    Known issues:

    * Does not set ``train_indices``, ``validation_indices``, and
      ``test_indices`` for checkpoints.
    * Does not set ``history`` for checkpoints. This can be accessed in the
      final model however.
    * Unsupported arguments: ``save_weights_only`` and ``save_last``

    Parameters
    ----------
    dirpath
        Base directory to save the model checkpoints. If ``None``, defaults to
        a directory formatted with the current date, time, and monitor within
        ``settings.logging_dir``.
    filename
        Name of the checkpoint directories. Can contain formatting options to be
        auto-filled. If ``None``, defaults to ``{epoch}-{step}-{monitor}``.
    monitor
        Metric to monitor for checkpointing.
    **kwargs
        Additional keyword arguments passed into
        :class:`~lightning.pytorch.callbacks.ModelCheckpoint`.
    """

    def __init__(
        self,
        dirpath: str | None = None,
        filename: str | None = None,
        monitor: str = "validation_loss",
        **kwargs,
    ):
        if dirpath is None:
            # default run directory is unique per launch and per monitor
            dirpath = os.path.join(
                settings.logging_dir,
                datetime.now().strftime("%Y-%m-%d-%H:%M:%S"),
            )
            dirpath += f"-{monitor}"

        if filename is None:
            # auto-filled by Lightning's checkpoint name formatting
            filename = "{epoch}-{step}-{" + monitor + "}"

        # these ModelCheckpoint options are incompatible with full model saves,
        # so warn and drop them instead of passing them through
        for unsupported in ("save_weights_only", "save_last"):
            if unsupported in kwargs:
                warnings.warn(
                    f"`{unsupported}` is not supported and will be ignored.",
                    RuntimeWarning,
                    stacklevel=settings.warnings_stacklevel,
                )
                kwargs.pop(unsupported)

        super().__init__(
            dirpath=dirpath,
            filename=filename,
            monitor=monitor,
            **kwargs,
        )

    def on_save_checkpoint(self, trainer: pl.Trainer, *args) -> None:
        """Saves the model state on Lightning checkpoint saves."""
        # set post training state before saving
        model = trainer._model
        model.module.eval()
        model.is_trained_ = True
        model.trainer = trainer

        monitor_candidates = self._monitor_candidates(trainer)
        save_path = self.format_checkpoint_name(monitor_candidates)
        # by default, the function above gives a .ckpt extension
        save_path = save_path.split(".ckpt")[0]
        model.save(save_path, save_anndata=False, overwrite=True)

        # restore in-training state so training can continue
        model.module.train()
        model.is_trained_ = False

    def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
        """Removes model saves that are no longer needed."""
        super()._remove_checkpoint(trainer, filepath)

        # also remove the corresponding full model save directory
        model_path = filepath.split(".ckpt")[0]
        if os.path.exists(model_path) and os.path.isdir(model_path):
            rmtree(model_path)

    def _update_best_and_save(
        self,
        current: torch.Tensor,
        trainer: pl.Trainer,
        monitor_candidates: dict[str, torch.Tensor],
    ) -> None:
        """Replaces Lightning checkpoints with our model saves."""
        super()._update_best_and_save(current, trainer, monitor_candidates)

        # drop Lightning's .ckpt file and point best_model_path at our save dir
        if os.path.exists(self.best_model_path):
            os.remove(self.best_model_path)
        self.best_model_path = self.best_model_path.split(".ckpt")[0]


class MetricsCallback(Callback):
"""Computes metrics on validation end and logs them to the logger.

Expand All @@ -38,9 +157,7 @@ class MetricsCallback(Callback):

def __init__(
self,
metric_fns: Union[
MetricCallable, list[MetricCallable], dict[str, MetricCallable]
],
metric_fns: MetricCallable | list[MetricCallable] | dict[str, MetricCallable],
):
super().__init__()

Expand Down Expand Up @@ -211,7 +328,7 @@ def teardown(
self,
_trainer: pl.Trainer,
_pl_module: pl.LightningModule,
stage: Optional[str] = None,
stage: str | None = None,
) -> None:
"""Print the reason for stopping on teardown."""
if self.early_stopping_reason is not None:
Expand Down
24 changes: 21 additions & 3 deletions scvi/train/_trainer.py
Expand Up @@ -10,7 +10,12 @@
from scvi import settings
from scvi.autotune._types import Tunable, TunableMixin

from ._callbacks import LoudEarlyStopping, MetricCallable, MetricsCallback
from ._callbacks import (
LoudEarlyStopping,
MetricCallable,
MetricsCallback,
SaveCheckpoint,
)
from ._logger import SimpleLogger
from ._progress import ProgressBar
from ._trainingplans import PyroTrainingPlan
Expand Down Expand Up @@ -42,8 +47,10 @@ class Trainer(TunableMixin, pl.Trainer):
Defaults to `scvi.settings.logging_dir`. Can be remote file paths such as
s3://mybucket/path or ‘hdfs://path/’
enable_checkpointing
If `True`, enable checkpointing. It will configure a default ModelCheckpoint
callback if there is no user-defined ModelCheckpoint in `callbacks`.
If ``True``, enables checkpointing with a default :class:`~scvi.train.SaveCheckpoint`
callback if there is no user-defined :class:`~scvi.train.SaveCheckpoint` in ``callbacks``.
checkpointing_monitor
If ``enable_checkpointing`` is ``True``, specifies the metric to monitor for checkpointing.
num_sanity_val_steps
Sanity check runs n validation batches before starting the training routine.
Set it to -1 to run all batches in all validation dataloaders.
Expand Down Expand Up @@ -94,6 +101,7 @@ def __init__(
max_epochs: Tunable[int] = 400,
default_root_dir: Optional[str] = None,
enable_checkpointing: bool = False,
checkpointing_monitor: str = "validation_loss",
num_sanity_val_steps: int = 0,
enable_model_summary: bool = False,
early_stopping: bool = False,
Expand Down Expand Up @@ -130,6 +138,16 @@ def __init__(
callbacks.append(early_stopping_callback)
check_val_every_n_epoch = 1

if enable_checkpointing and not any(
isinstance(c, SaveCheckpoint) for c in callbacks
):
callbacks.append(SaveCheckpoint(monitor=checkpointing_monitor))
check_val_every_n_epoch = 1
elif any(isinstance(c, SaveCheckpoint) for c in callbacks):
            # the user already provided the callback themselves
enable_checkpointing = True
check_val_every_n_epoch = 1

if learning_rate_monitor and not any(
isinstance(c, LearningRateMonitor) for c in callbacks
):
Expand Down
60 changes: 59 additions & 1 deletion tests/train/test_callbacks.py
@@ -1,7 +1,65 @@
import os

import pytest

import scvi
from scvi.train._callbacks import MetricsCallback
from scvi.train._callbacks import MetricsCallback, SaveCheckpoint


def test_modelcheckpoint_callback(save_path: str):
    def _validate_checkpoints(model, adata):
        # exactly one SaveCheckpoint should be attached to the trainer
        ckpt_callbacks = [
            cb for cb in model.trainer.callbacks if isinstance(cb, SaveCheckpoint)
        ]
        assert any(isinstance(cb, SaveCheckpoint) for cb in model.trainer.callbacks)
        assert len(ckpt_callbacks) == 1
        ckpt_callback = ckpt_callbacks[0]
        assert ckpt_callback.best_model_path is not None
        assert ckpt_callback.best_model_score is not None
        assert os.path.exists(ckpt_callback.best_model_path)

        # a run directory with at least one model save should exist, and the
        # save should be loadable as a trained model
        run_dirs = os.listdir(scvi.settings.logging_dir)
        assert len(run_dirs) >= 1
        run_dir = os.path.join(scvi.settings.logging_dir, run_dirs[0])
        save_dirs = os.listdir(run_dir)
        assert len(save_dirs) >= 1
        reloaded = model.__class__.load(os.path.join(run_dir, save_dirs[0]), adata=adata)
        assert reloaded.is_trained_

    def _run_for_model_cls(model_cls, adata):
        scvi.settings.logging_dir = os.path.join(save_path, model_cls.__name__)

        # enable_checkpointing=True, default monitor
        model = model_cls(adata)
        model.train(max_epochs=5, enable_checkpointing=True)
        _validate_checkpoints(model, adata)

        # enable_checkpointing=True, custom monitor
        model = model_cls(adata)
        model.train(
            max_epochs=5,
            enable_checkpointing=True,
            checkpointing_monitor="elbo_validation",
        )
        _validate_checkpoints(model, adata)

        # callback passed in manually
        model = model_cls(adata)
        model.train(
            max_epochs=5,
            callbacks=[SaveCheckpoint(monitor="elbo_validation")],
        )
        _validate_checkpoints(model, adata)

    prev_logging_dir = scvi.settings.logging_dir
    adata = scvi.data.synthetic_iid()

    scvi.model.SCVI.setup_anndata(adata)
    _run_for_model_cls(scvi.model.SCVI, adata)

    scvi.model.SCANVI.setup_anndata(adata, "labels", "label_0")
    _run_for_model_cls(scvi.model.SCANVI, adata)

    scvi.settings.logging_dir = prev_logging_dir


def test_metricscallback_init():
Expand Down