solo updates #933

Merged
merged 10 commits on Feb 23, 2021
28 changes: 28 additions & 0 deletions scvi/external/solo/_model.py
@@ -7,6 +7,7 @@
import numpy as np
import torch
from anndata import AnnData
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from scvi import _CONSTANTS
from scvi.compose import auto_move_data
@@ -152,6 +153,10 @@ def train(
validation_size: Optional[float] = None,
batch_size: int = 128,
plan_kwargs: Optional[dict] = None,
callbacks: Optional[list] = None,
early_stopping: bool = True,
early_stopping_patience: int = 30,
early_stopping_min_delta: float = 0.0,
**kwargs,
):
"""
@@ -174,6 +179,10 @@ def train(
Minibatch size to use during training.
plan_kwargs
Keyword args for :class:`~scvi.lightning.ClassifierTrainingPlan`. Keyword arguments passed to
    `train()` will overwrite values present in `plan_kwargs`, when appropriate.
callbacks
    Additional PyTorch Lightning callbacks to pass to the trainer.
early_stopping
    Whether to stop training early when the validation loss stops improving.
early_stopping_patience
    Number of consecutive epochs the validation loss may fail to improve by more than
    `early_stopping_min_delta` before training stops.
early_stopping_min_delta
    Minimum decrease in the validation loss for an epoch to count as an improvement;
    smaller changes count towards patience.
**kwargs
Other keyword args for :class:`~scvi.lightning.Trainer`.
@@ -185,13 +194,32 @@ def train(
plan_kwargs.update(update_dict)
else:
plan_kwargs = update_dict

# Never mutate the caller's callbacks list (or a shared mutable default).
callbacks = [] if callbacks is None else list(callbacks)

if early_stopping:
    # Track the validation loss and stop once it stops improving.
callbacks += [
EarlyStopping(
monitor="validation_loss",
min_delta=early_stopping_min_delta,
patience=early_stopping_patience,
mode="min",
)
]
check_val_every_n_epoch = 1
else:
    # Validation is only needed for early stopping; without it, run
    # validation only when the caller explicitly requests it.
    check_val_every_n_epoch = (
        check_val_every_n_epoch
        if check_val_every_n_epoch is not None
        else np.inf
    )

super().train(
max_epochs=max_epochs,
use_gpu=use_gpu,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
plan_kwargs=plan_kwargs,
callbacks=callbacks,
**kwargs,
)
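
Taken together, these changes make early stopping on the validation loss the default for SOLO training. A minimal usage sketch under the scvi-tools API of this release (the pretrained `scvi_model` and its data setup are assumed, not shown here):

from scvi.external import SOLO

# `scvi_model` is assumed to be an already-trained scvi.model.SCVI instance.
solo = SOLO.from_scvi_model(scvi_model)
solo.train(
    max_epochs=400,
    early_stopping=True,            # installs the EarlyStopping callback above
    early_stopping_patience=30,     # tolerate 30 non-improving validation epochs
    early_stopping_min_delta=0.0,   # any decrease in validation_loss counts
)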

10 changes: 10 additions & 0 deletions scvi/lightning/_trainingplans.py
@@ -47,6 +47,8 @@ class TrainingPlan(pl.LightningModule):
Threshold for measuring the new optimum.
lr_scheduler_metric
Which metric to track for learning rate reduction.
lr_min
Minimum learning rate allowed; the plateau scheduler will not reduce the learning rate below this value.
**loss_kwargs
Keyword args to pass to the loss method of the `vae_model`.
`kl_weight` should not be passed here and is handled automatically.
@@ -69,6 +71,7 @@ def __init__(
lr_scheduler_metric: Literal[
"elbo_validation", "reconstruction_loss_validation", "kl_local_validation"
] = "elbo_validation",
lr_min: float = 0,
**loss_kwargs,
):
super(TrainingPlan, self).__init__()
@@ -85,6 +88,7 @@ def __init__(
self.lr_patience = lr_patience
self.lr_scheduler_metric = lr_scheduler_metric
self.lr_threshold = lr_threshold
self.lr_min = lr_min
self.loss_kwargs = loss_kwargs

# automatic handling of kl weight
@@ -172,6 +176,7 @@ def configure_optimizers(self):
patience=self.lr_patience,
factor=self.lr_factor,
threshold=self.lr_threshold,
min_lr=self.lr_min,
threshold_mode="abs",
verbose=True,
)
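
For reference, `min_lr` is plain PyTorch behavior rather than anything scvi-tools specific: once repeated plateaus have decayed the learning rate down to the floor, further reductions are clamped. A self-contained sketch:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.6,
    patience=2,
    threshold=0.0,
    threshold_mode="abs",
    min_lr=1e-6,  # the floor that `lr_min` is plumbed into above
)
for _ in range(60):
    scheduler.step(1.0)  # a flat validation metric forces repeated reductions
print(optimizer.param_groups[0]["lr"])  # 1e-06: clamped, never below min_lr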
@@ -228,6 +233,8 @@ class AdversarialTrainingPlan(TrainingPlan):
Threshold for measuring the new optimum.
lr_scheduler_metric
Which metric to track for learning rate reduction.
lr_min
Minimum learning rate allowed; the plateau scheduler will not reduce the learning rate below this value.
adversarial_classifier
Whether to use adversarial classifier in the latent space
scale_adversarial_loss
@@ -254,6 +261,7 @@ def __init__(
lr_scheduler_metric: Literal[
"elbo_validation", "reconstruction_loss_validation", "kl_local_validation"
] = "elbo_validation",
lr_min: float = 0,
adversarial_classifier: Union[bool, Classifier] = False,
scale_adversarial_loss: Union[float, Literal["auto"]] = "auto",
**loss_kwargs,
@@ -270,6 +278,7 @@ def __init__(
lr_patience=lr_patience,
lr_threshold=lr_threshold,
lr_scheduler_metric=lr_scheduler_metric,
lr_min=lr_min,
)
if adversarial_classifier is True:
self.n_output_classifier = self.module.n_batch
@@ -361,6 +370,7 @@ def configure_optimizers(self):
patience=self.lr_patience,
factor=self.lr_factor,
threshold=self.lr_threshold,
min_lr=self.lr_min,
threshold_mode="abs",
verbose=True,
)
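
Because both training plans now accept `lr_min`, and the model `train()` methods forward `plan_kwargs` to the plan, the floor can be set from user code. A hedged sketch (a set-up `adata` and the plan's `reduce_lr_on_plateau` flag are assumptions, not shown in this diff):

import scvi

# `adata` is assumed to be an AnnData object already set up for scvi-tools.
model = scvi.model.SCVI(adata)
model.train(
    max_epochs=200,
    plan_kwargs={
        "reduce_lr_on_plateau": True,  # enable the scheduler (assumed flag)
        "lr_min": 1e-6,                # forwarded as ReduceLROnPlateau(min_lr=...)
    },
)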