fix(train): fix fp16_run not being mix precision and fix bf16 errors (#…
34j committed Apr 10, 2023
1 parent 5418d4e commit b0dd0ed
Showing 2 changed files with 50 additions and 13 deletions.
src/so_vits_svc_fork/f0.py (1 addition & 1 deletion)

@@ -179,7 +179,7 @@ def compute_f0_crepe(
         pad=True,
     )

-    f0 = pitch.squeeze(0).cpu().numpy()
+    f0 = pitch.squeeze(0).cpu().float().numpy()
     p_len = p_len or wav_numpy.shape[0] // hop_length
     f0 = _resize_f0(f0, p_len)
     return f0
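
Note on the added `.float()`: NumPy has a float16 dtype but no bfloat16, so calling `.numpy()` on a bf16 tensor raises a TypeError. Upcasting to fp32 on the CPU side avoids this regardless of the training precision. A minimal sketch of the failure mode (the tensor shape is illustrative):

    import torch

    pitch = torch.randn(1, 100, dtype=torch.bfloat16)  # stand-in for the crepe output
    # pitch.squeeze(0).cpu().numpy()  # TypeError: Got unsupported ScalarType BFloat16
    f0 = pitch.squeeze(0).cpu().float().numpy()  # upcast to fp32, then convert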
src/so_vits_svc_fork/train.py (49 additions & 12 deletions)

@@ -27,6 +27,7 @@

 LOG = getLogger(__name__)
 torch.backends.cudnn.benchmark = True
+torch.set_float32_matmul_precision("high")


 class VCDataModule(pl.LightningDataModule):
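
`torch.set_float32_matmul_precision("high")` is a stock PyTorch (>= 1.12) switch that lets fp32 matmuls use faster reduced-precision kernels such as TF32 on Ampere-class GPUs; it also quiets Lightning's warning about unset matmul precision on Tensor Core hardware. A minimal sketch (a CUDA device is assumed):

    import torch

    torch.set_float32_matmul_precision("high")  # allow TF32-style matmul kernels
    a = torch.randn(1024, 1024, device="cuda")
    b = a @ a  # result is still fp32, but the kernel may trade mantissa bits for speed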
@@ -71,9 +72,9 @@ def train(
         val_check_interval=hparams.train.eval_interval,
         max_epochs=hparams.train.epochs,
         check_val_every_n_epoch=None,
-        precision=16
+        precision="16-mixed"
         if hparams.train.fp16_run
-        else "bf16"
+        else "bf16-mixed"
         if hparams.train.get("bf16_run", False)
         else 32,
     )
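
The string values follow the PyTorch Lightning 2.x precision API: "16-mixed" and "bf16-mixed" run autocast mixed precision, where weights and optimizer state stay in fp32 and only forward ops are autocast down, whereas per the commit title the bare `16`/`"bf16"` values were not giving true mixed-precision training here. A minimal sketch of the resulting Trainer configuration, assuming Lightning >= 2.0:

    import pytorch_lightning as pl

    trainer = pl.Trainer(
        precision="16-mixed",  # fp16 autocast + loss scaling; weights stay fp32
        # precision="bf16-mixed"  # bf16 autocast; no loss scaler needed
        # precision=32  # plain fp32 (the fallback in this diff)
    )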
@@ -159,6 +160,40 @@ def stft(

             torch.stft = stft

+        elif "bf" in self.trainer.precision:
+            LOG.warning("Using bf. Patching torch.stft to use fp32.")
+
+            def stft(
+                input: torch.Tensor,
+                n_fft: int,
+                hop_length: int | None = None,
+                win_length: int | None = None,
+                window: torch.Tensor | None = None,
+                center: bool = True,
+                pad_mode: str = "reflect",
+                normalized: bool = False,
+                onesided: bool | None = None,
+                return_complex: bool | None = None,
+            ) -> torch.Tensor:
+                dtype = input.dtype
+                input = input.float()
+                if window is not None:
+                    window = window.float()
+                return torch.functional.stft(
+                    input,
+                    n_fft,
+                    hop_length,
+                    win_length,
+                    window,
+                    center,
+                    pad_mode,
+                    normalized,
+                    onesided,
+                    return_complex,
+                ).to(dtype)
+
+            torch.stft = stft
+
     def set_current_epoch(self, epoch: int):
         LOG.info(f"Setting current epoch to {epoch}")
         self.trainer.fit_loop.epoch_progress.current.completed = epoch
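
Why the patch: PyTorch's FFT kernels generally lack bfloat16 support, so `torch.stft` on a bf16 input errors out under "bf16-mixed" training. The wrapper records the incoming dtype, upcasts the input and window to fp32, and casts the result back. It delegates to `torch.functional.stft`, which still holds a reference to the original implementation, so reassigning `torch.stft` does not make the wrapper call itself. A minimal sketch of the same round-trip in isolation (sizes are illustrative):

    import torch

    x = torch.randn(16000, dtype=torch.bfloat16)
    win = torch.hann_window(1024).to(torch.bfloat16)
    # a bf16 input to torch.stft raises a RuntimeError on most builds, hence .float()
    spec = torch.stft(x.float(), n_fft=1024, window=win.float(), return_complex=True)
    mag = spec.abs().to(x.dtype)  # magnitude cast back to bf16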
@@ -239,7 +274,7 @@ def log_audio_dict(self, audio_dict: dict[str, Any]) -> None:
         for k, v in audio_dict.items():
             writer.add_audio(
                 k,
-                v,
+                v.float(),
                 self.total_batch_idx,
                 sample_rate=self.hparams.data.sampling_rate,
             )
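
The `.float()` here guards the same bf16-to-NumPy conversion: `SummaryWriter.add_audio` converts the tensor to a NumPy array internally, which fails for bfloat16. A minimal sketch (the log directory and tensor are illustrative):

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter("logs/example")  # hypothetical log dir
    audio = torch.randn(1, 22050, dtype=torch.bfloat16)
    writer.add_audio("sample", audio.float(), global_step=0, sample_rate=22050)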
@@ -334,21 +369,21 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
             self.log_image_dict(
                 {
                     "slice/mel_org": utils.plot_spectrogram_to_numpy(
-                        y_mel[0].data.cpu().numpy()
+                        y_mel[0].data.cpu().float().numpy()
                     ),
                     "slice/mel_gen": utils.plot_spectrogram_to_numpy(
-                        y_hat_mel[0].data.cpu().numpy()
+                        y_hat_mel[0].data.cpu().float().numpy()
                     ),
                     "all/mel": utils.plot_spectrogram_to_numpy(
-                        mel[0].data.cpu().numpy()
+                        mel[0].data.cpu().float().numpy()
                     ),
                     "all/lf0": so_vits_svc_fork.utils.plot_data_to_numpy(
-                        lf0[0, 0, :].cpu().numpy(),
-                        pred_lf0[0, 0, :].detach().cpu().numpy(),
+                        lf0[0, 0, :].cpu().float().numpy(),
+                        pred_lf0[0, 0, :].detach().cpu().float().numpy(),
                     ),
                     "all/norm_lf0": so_vits_svc_fork.utils.plot_data_to_numpy(
-                        lf0[0, 0, :].cpu().numpy(),
-                        norm_lf0[0, 0, :].detach().cpu().numpy(),
+                        lf0[0, 0, :].cpu().float().numpy(),
+                        norm_lf0[0, 0, :].detach().cpu().float().numpy(),
                     ),
                 }
             )
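
Every tensor logged as an image goes through the same detach → cpu → float → numpy chain so that the plotting utilities receive an fp32 array no matter which precision the trainer runs in. A small helper could express the pattern once; the name `to_np` is hypothetical and not part of this commit:

    import numpy as np
    import torch

    def to_np(t: torch.Tensor) -> np.ndarray:
        # detach from autograd, move to host, upcast to fp32, convert to NumPy
        return t.detach().cpu().float().numpy()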
@@ -395,9 +430,11 @@ def validation_step(self, batch, batch_idx):
         self.log_image_dict(
             {
                 "gen/mel": utils.plot_spectrogram_to_numpy(
-                    y_hat_mel[0].cpu().numpy()
+                    y_hat_mel[0].cpu().float().numpy()
                 ),
-                "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
+                "gt/mel": utils.plot_spectrogram_to_numpy(
+                    mel[0].cpu().float().numpy()
+                ),
             }
         )
         if self.current_epoch == 0 or batch_idx != 0: