diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py index 98cae807..bb1c3596 100644 --- a/src/so_vits_svc_fork/train.py +++ b/src/so_vits_svc_fork/train.py @@ -66,6 +66,10 @@ def train( utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan")) datamodule = VCDataModule(hparams) + strategy = ( + "ddp_find_unused_parameters_true" if torch.cuda.device_count() > 1 else "auto" + ) + LOG.info(f"Using strategy: {strategy}") trainer = pl.Trainer( logger=TensorBoardLogger(model_path), # profiler="simple", @@ -77,6 +81,7 @@ def train( else "bf16-mixed" if hparams.train.get("bf16_run", False) else 32, + strategy=strategy, ) model = VitsLightning(reset_optimizer=reset_optimizer, **hparams) trainer.fit(model, datamodule=datamodule) @@ -326,7 +331,6 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None: ) # generator loss - LOG.debug("Calculating generator loss") y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(y, y_hat) with autocast(enabled=False):