diff --git a/src/so_vits_svc_fork/modules/commons.py b/src/so_vits_svc_fork/modules/commons.py index 16b3cfbe..68e990d5 100644 --- a/src/so_vits_svc_fork/modules/commons.py +++ b/src/so_vits_svc_fork/modules/commons.py @@ -1,27 +1,79 @@ -import math +from __future__ import annotations import torch -from torch.nn import functional as F +import torch.nn.functional as F +from torch import Tensor + + +def slice_segments(x: Tensor, starts: Tensor, length: int) -> Tensor: + if length is None: + return x + length = min(length, x.size(-1)) + x_slice = torch.zeros((x.size()[:-1] + (length,)), dtype=x.dtype, device=x.device) + ends = starts + length + for i, (start, end) in enumerate(zip(starts, ends)): + # LOG.debug(i, start, end, x.size(), x[i, ..., start:end].size(), x_slice.size()) + # x_slice[i, ...] = x[i, ..., start:end] need to pad + # x_slice[i, ..., :end - start] = x[i, ..., start:end] this does not work + x_slice[i, ...] = F.pad(x[i, ..., start:end], (0, max(0, length - x.size(-1)))) + return x_slice + + +def rand_slice_segments_with_pitch( + x: Tensor, f0: Tensor, x_lengths: Tensor | int | None, segment_size: int | None +): + if segment_size is None: + return x, f0, torch.arange(x.size(0), device=x.device) + if x_lengths is None: + x_lengths = x.size(-1) * torch.ones( + x.size(0), dtype=torch.long, device=x.device + ) + # slice_starts = (torch.rand(z.size(0), device=z.device) * (z_lengths - segment_size)).long() + slice_starts = ( + torch.rand(x.size(0), device=x.device) + * torch.max( + x_lengths - segment_size, torch.zeros_like(x_lengths, device=x.device) + ) + ).long() + z_slice = slice_segments(x, slice_starts, segment_size) + f0_slice = slice_segments(f0, slice_starts, segment_size) + return z_slice, f0_slice, slice_starts + + +def slice_2d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor: + batch_size, num_features, seq_len = x.shape + ends = starts + length + idxs = ( + torch.arange(seq_len, device=x.device) + .unsqueeze(0) + .unsqueeze(1) + .repeat(batch_size, num_features, 1) + ) + mask = (idxs >= starts.unsqueeze(-1).unsqueeze(-1)) & ( + idxs < ends.unsqueeze(-1).unsqueeze(-1) + ) + return x[mask].reshape(batch_size, num_features, length) -def slice_pitch_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, idx_str:idx_end] - return ret +def slice_1d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor: + batch_size, seq_len = x.shape + ends = starts + length + idxs = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1) + mask = (idxs >= starts.unsqueeze(-1)) & (idxs < ends.unsqueeze(-1)) + return x[mask].reshape(batch_size, length) -def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size) - return ret, ret_pitch, ids_str +def _slice_segments_v3(x: Tensor, starts: Tensor, length: int) -> Tensor: + shape = x.shape[:-1] + (length,) + ends = starts + length + idxs = torch.arange(x.shape[-1], device=x.device).unsqueeze(0).unsqueeze(0) + unsqueeze_dims = len(shape) - len( + x.shape + ) # calculate number of dimensions to unsqueeze + starts = starts.reshape(starts.shape + (1,) * unsqueeze_dims) + ends = ends.reshape(ends.shape + (1,) * unsqueeze_dims) + mask = (idxs >= starts) & (idxs < ends) + return x[mask].reshape(shape) def init_weights(m, mean=0.0, std=0.01): @@ -40,89 +92,6 @@ def convert_pad_shape(pad_shape): return pad_shape -def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += ( - 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) - ) - return kl - - -def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - -def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def rand_spec_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( - num_timescales - 1 - ) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - def subsequent_mask(length): mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) return mask @@ -138,11 +107,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): return acts -def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - def sequence_mask(length, max_length=None): if max_length is None: max_length = length.max() @@ -150,24 +114,6 @@ def sequence_mask(length, max_length=None): return x.unsqueeze(0) < length.unsqueeze(1) -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - duration.device - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2, 3) * mask - return path - - def clip_grad_value_(parameters, clip_value, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py index ea8c46a0..4e1b02e4 100644 --- a/src/so_vits_svc_fork/train.py +++ b/src/so_vits_svc_fork/train.py @@ -375,6 +375,7 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None: ids_slice * self.hparams.data.hop_length, self.hparams.train.segment_size, ) + y = y[..., : y_hat.shape[-1]] # generator loss y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(y, y_hat)