diff --git a/scvelo/tools/_vi_model.py b/scvelo/tools/_vi_model.py index 2483bc5c..d4d8f1df 100644 --- a/scvelo/tools/_vi_model.py +++ b/scvelo/tools/_vi_model.py @@ -126,6 +126,8 @@ def train( lr: float = 1e-2, weight_decay: float = 1e-2, use_gpu: Optional[Union[str, int, bool]] = None, + accelerator: str = "auto", + devices: Union[int, List[int], str] = "auto", train_size: float = 0.9, validation_size: Optional[float] = None, batch_size: int = 256, @@ -179,7 +181,6 @@ def train( train_size=train_size, validation_size=validation_size, batch_size=batch_size, - use_gpu=use_gpu, ) training_plan = TrainingPlan(self.module, **plan_kwargs)