diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py index 90c6a8d8..67800789 100644 --- a/src/so_vits_svc_fork/train.py +++ b/src/so_vits_svc_fork/train.py @@ -343,8 +343,8 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None: # optimizer self.manual_backward(loss_disc_all) - optim_d.zero_grad() optim_d.step() + optim_d.zero_grad() self.untoggle_optimizer(optim_d) def validation_step(self, batch, batch_idx):