From 4551669f6ddedaf49eba10a73dbf61b997182202 Mon Sep 17 00:00:00 2001 From: Sevag Hanssian Date: Fri, 8 Apr 2022 12:31:56 -0400 Subject: [PATCH] Inference in chunks and no float64, to avoid oom --- docs/inference.md | 29 ++++ xumx_slicq/filtering.py | 7 +- xumx_slicq/model.py | 283 +++++++++++++++++++++------------------- 3 files changed, 182 insertions(+), 137 deletions(-) create mode 100644 docs/inference.md diff --git a/docs/inference.md b/docs/inference.md new file mode 100644 index 0000000..d963a5e --- /dev/null +++ b/docs/inference.md @@ -0,0 +1,29 @@ +No-chunk inference: + bss evaluation to store in exp-04-trained-models/slicq-wslicq + drums ==> SDR: 3.930 SIR: 8.810 ISR: 6.519 SAR: 5.847 + bass ==> SDR: -4.861 SIR: -5.665 ISR: 6.639 SAR: 6.193 + other ==> SDR: 0.348 SIR: -1.541 ISR: 1.419 SAR: 2.100 + vocals ==> SDR: 5.895 SIR: 13.361 ISR: 9.629 SAR: 7.209 + accompaniment ==> SDR: 12.948 SIR: 15.880 ISR: 20.559 SAR: 15.627 + + +Chunked inference (2621440, 59.44 seconds @ 44100 Hz) + bss evaluation to store in exp-04-trained-models/slicq-wslicq + drums ==> SDR: 3.990 SIR: 8.558 ISR: 6.609 SAR: 5.896 + bass ==> SDR: -4.809 SIR: -5.768 ISR: 6.494 SAR: 6.074 + other ==> SDR: 0.350 SIR: -1.408 ISR: 1.442 SAR: 1.995 + vocals ==> SDR: 5.887 SIR: 13.241 ISR: 9.703 SAR: 7.153 + accompaniment ==> SDR: 13.001 SIR: 15.890 ISR: 20.690 SAR: 15.513 + + +tradeoffs here and there, no cause for alarm. keep chunking + +Remove float64 in Wiener-EM function. results are still similar: + +no float64, chunked inference: + bss evaluation to store in exp-04-trained-models/slicq-wslicq + drums ==> SDR: 3.990 SIR: 8.558 ISR: 6.609 SAR: 5.896 + bass ==> SDR: -4.809 SIR: -5.768 ISR: 6.494 SAR: 6.074 + other ==> SDR: 0.350 SIR: -1.408 ISR: 1.442 SAR: 1.995 + vocals ==> SDR: 5.887 SIR: 13.241 ISR: 9.703 SAR: 7.153 + accompaniment ==> SDR: 13.001 SIR: 15.890 ISR: 20.690 SAR: 15.513 diff --git a/xumx_slicq/filtering.py b/xumx_slicq/filtering.py index 27d1a41..658346b 100644 --- a/xumx_slicq/filtering.py +++ b/xumx_slicq/filtering.py @@ -421,14 +421,11 @@ def wiener( mix_stft = mix_stft / max_abs y = y / max_abs - y = y.to(torch.float64) - # call expectation maximization - y = expectation_maximization(y, mix_stft.to(torch.float64), iterations, eps=eps, slicq=slicq)[0] + y = expectation_maximization(y, mix_stft, iterations, eps=eps, slicq=slicq)[0] # scale estimates up again - y = y * max_abs - return y.to(torch.float32) + return y * max_abs def _covariance(y_j): diff --git a/xumx_slicq/model.py b/xumx_slicq/model.py index af509ff..a8e65a6 100644 --- a/xumx_slicq/model.py +++ b/xumx_slicq/model.py @@ -1,5 +1,6 @@ from typing import Optional +from tqdm import trange import torch import torch.nn as nn import torch.nn.functional as F @@ -349,6 +350,7 @@ def __init__( stft_wiener: bool = True, softmask: bool = False, wiener_win_len: Optional[int] = 300, + chunk_size: Optional[int] = 2621440, n_fft: Optional[int] = 4096, n_hop: Optional[int] = 1024, ): @@ -374,6 +376,7 @@ def __init__( self.n_fft = n_fft self.n_hop = n_hop self.wiener_win_len = wiener_win_len + self.chunk_size = chunk_size if chunk_size is not None else sys.maxsize if not self.stft_wiener: # first, get frequency and time limits to build the large zero-padded matrix @@ -399,7 +402,7 @@ def freeze(self): self.eval() @torch.no_grad() - def forward(self, audio: Tensor) -> Tensor: + def forward(self, audio_big: Tensor) -> Tensor: """Performing the separation on audio input Args: @@ -411,160 +414,176 @@ def forward(self, audio: Tensor) -> Tensor: shape 
`(nb_samples, nb_targets, nb_channels, nb_timesteps)` """ nb_sources = 4 - nb_samples = audio.shape[0] - - X = self.nsgt(audio) - Xmag = self.complexnorm(X) - - # xumx inference - magnitude slicq estimate - Ymag_bass, Ymag_vocals, Ymag_other, Ymag_drums = self.xumx_model(Xmag) - - if self.stft_wiener: - print('STFT WIENER') - - # initial mix phase + magnitude estimate - Ycomplex_bass = phasemix_sep(X, Ymag_bass) - Ycomplex_vocals = phasemix_sep(X, Ymag_vocals) - Ycomplex_drums = phasemix_sep(X, Ymag_drums) - Ycomplex_other = phasemix_sep(X, Ymag_other) - - y_bass = self.insgt(Ycomplex_bass, audio.shape[-1]) - y_drums = self.insgt(Ycomplex_drums, audio.shape[-1]) - y_other = self.insgt(Ycomplex_other, audio.shape[-1]) - y_vocals = self.insgt(Ycomplex_vocals, audio.shape[-1]) - - # initial estimate was obtained with slicq - # now we switch to the STFT domain for the wiener step - - audio = torch.squeeze(audio, dim=0) - - mix_stft = torch.view_as_real(torch.stft(audio, self.n_fft, hop_length=self.n_hop, return_complex=True)) - X = torch.abs(torch.view_as_complex(mix_stft)) - - # initializing spectrograms variable - spectrograms = torch.zeros(X.shape + (nb_sources,), dtype=audio.dtype, device=X.device) - - for j, target_name in enumerate(self.ordered_targets): - # apply current model to get the source spectrogram - if target_name == 'bass': - target_est = torch.squeeze(y_bass, dim=0) - elif target_name == 'vocals': - target_est = torch.squeeze(y_vocals, dim=0) - elif target_name == 'drums': - target_est = torch.squeeze(y_drums, dim=0) - elif target_name == 'other': - target_est = torch.squeeze(y_other, dim=0) - spectrograms[..., j] = torch.abs(torch.stft(target_est, self.n_fft, hop_length=self.n_hop, return_complex=True)) - - # transposing it as - # (nb_samples, nb_frames, nb_bins,{1,nb_channels}, nb_sources) - - spectrograms = spectrograms.permute(2, 1, 0, 3) - - # rearranging it into: - # (nb_samples, nb_frames, nb_bins, nb_channels, 2) to feed - # into filtering methods - mix_stft = mix_stft.permute(2, 1, 0, 3) - - nb_frames = spectrograms.shape[0] - targets_stft = torch.zeros( - mix_stft.shape + (nb_sources,), dtype=audio.dtype, device=mix_stft.device - ) - - pos = 0 - if self.wiener_win_len: - wiener_win_len = self.wiener_win_len - else: - wiener_win_len = nb_frames - while pos < nb_frames: - cur_frame = torch.arange(pos, min(nb_frames, pos + wiener_win_len)) - pos = int(cur_frame[-1]) + 1 - - targets_stft[cur_frame] = wiener( - spectrograms[cur_frame], - mix_stft[cur_frame], - self.niter, - softmask=self.softmask, - slicq=False, # stft wiener + nb_samples = audio_big.shape[0] + N = audio_big.shape[-1] + + nchunks = (N // self.chunk_size) + if (N % self.chunk_size) != 0: + nchunks += 1 + + print(f'n chunks: {nchunks}') + + final_estimates = [] + + for chunk_idx in trange(nchunks): + audio = audio_big[..., chunk_idx * self.chunk_size: min((chunk_idx + 1) * self.chunk_size, N)] + print(f'audio.shape: {audio.shape}') + + X = self.nsgt(audio) + Xmag = self.complexnorm(X) + + # xumx inference - magnitude slicq estimate + Ymag_bass, Ymag_vocals, Ymag_other, Ymag_drums = self.xumx_model(Xmag) + + if self.stft_wiener: + print('STFT WIENER') + + # initial mix phase + magnitude estimate + Ycomplex_bass = phasemix_sep(X, Ymag_bass) + Ycomplex_vocals = phasemix_sep(X, Ymag_vocals) + Ycomplex_drums = phasemix_sep(X, Ymag_drums) + Ycomplex_other = phasemix_sep(X, Ymag_other) + + y_bass = self.insgt(Ycomplex_bass, audio.shape[-1]) + y_drums = self.insgt(Ycomplex_drums, audio.shape[-1]) + y_other = 
self.insgt(Ycomplex_other, audio.shape[-1]) + y_vocals = self.insgt(Ycomplex_vocals, audio.shape[-1]) + + # initial estimate was obtained with slicq + # now we switch to the STFT domain for the wiener step + + audio = torch.squeeze(audio, dim=0) + + mix_stft = torch.view_as_real(torch.stft(audio, self.n_fft, hop_length=self.n_hop, return_complex=True)) + X = torch.abs(torch.view_as_complex(mix_stft)) + + # initializing spectrograms variable + spectrograms = torch.zeros(X.shape + (nb_sources,), dtype=audio.dtype, device=X.device) + + for j, target_name in enumerate(self.ordered_targets): + # apply current model to get the source spectrogram + if target_name == 'bass': + target_est = torch.squeeze(y_bass, dim=0) + elif target_name == 'vocals': + target_est = torch.squeeze(y_vocals, dim=0) + elif target_name == 'drums': + target_est = torch.squeeze(y_drums, dim=0) + elif target_name == 'other': + target_est = torch.squeeze(y_other, dim=0) + spectrograms[..., j] = torch.abs(torch.stft(target_est, self.n_fft, hop_length=self.n_hop, return_complex=True)) + + # transposing it as + # (nb_samples, nb_frames, nb_bins,{1,nb_channels}, nb_sources) + + spectrograms = spectrograms.permute(2, 1, 0, 3) + + # rearranging it into: + # (nb_samples, nb_frames, nb_bins, nb_channels, 2) to feed + # into filtering methods + mix_stft = mix_stft.permute(2, 1, 0, 3) + + nb_frames = spectrograms.shape[0] + targets_stft = torch.zeros( + mix_stft.shape + (nb_sources,), dtype=audio.dtype, device=mix_stft.device ) - # getting to (nb_samples, nb_targets, channel, fft_size, n_frames, 2) - targets_stft = torch.view_as_complex(targets_stft.permute(4, 2, 1, 0, 3).contiguous()) + pos = 0 + if self.wiener_win_len: + wiener_win_len = self.wiener_win_len + else: + wiener_win_len = nb_frames + while pos < nb_frames: + cur_frame = torch.arange(pos, min(nb_frames, pos + wiener_win_len)) + pos = int(cur_frame[-1]) + 1 + + targets_stft[cur_frame] = wiener( + spectrograms[cur_frame], + mix_stft[cur_frame], + self.niter, + softmask=self.softmask, + slicq=False, # stft wiener + ) + + # getting to (nb_samples, nb_targets, channel, fft_size, n_frames, 2) + targets_stft = torch.view_as_complex(targets_stft.permute(4, 2, 1, 0, 3).contiguous()) - # inverse STFT - estimates = torch.empty(audio.shape + (nb_sources,), dtype=audio.dtype, device=audio.device) + # inverse STFT + estimates = torch.empty(audio.shape + (nb_sources,), dtype=audio.dtype, device=audio.device) - for j, target_name in enumerate(self.ordered_targets): - estimates[..., j] = torch.istft(targets_stft[j, ...], self.n_fft, hop_length=self.n_hop, length=audio.shape[-1]) + for j, target_name in enumerate(self.ordered_targets): + estimates[..., j] = torch.istft(targets_stft[j, ...], self.n_fft, hop_length=self.n_hop, length=audio.shape[-1]) - estimates = torch.unsqueeze(estimates, dim=0).permute(0, 3, 1, 2).contiguous() - else: - print('sliCQT WIENER') + estimates = torch.unsqueeze(estimates, dim=0).permute(0, 3, 1, 2).contiguous() + else: + print('sliCQT WIENER') - # block-wise wiener - # assemble it all into a zero-padded matrix + # block-wise wiener + # assemble it all into a zero-padded matrix - nb_slices = X[0].shape[3] - last_dim = 2 + nb_slices = X[0].shape[3] + last_dim = 2 - X_matrix = torch.zeros((nb_samples, self.nb_channels, self.total_f_bins, nb_slices, self.max_t_bins, last_dim), dtype=X[0].dtype, device=X[0].device) - spectrograms = torch.zeros(X_matrix.shape[:-1] + (nb_sources,), dtype=audio.dtype, device=X_matrix.device) + X_matrix = torch.zeros((nb_samples, 
self.nb_channels, self.total_f_bins, nb_slices, self.max_t_bins, last_dim), dtype=X[0].dtype, device=X[0].device) + spectrograms = torch.zeros(X_matrix.shape[:-1] + (nb_sources,), dtype=audio.dtype, device=X_matrix.device) - freq_start = 0 - for i, X_block in enumerate(X): - nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins, last_dim = X_block.shape + freq_start = 0 + for i, X_block in enumerate(X): + nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins, last_dim = X_block.shape - # assign up to the defined time bins - to the right will be zeros - X_matrix[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :] = X_block + # assign up to the defined time bins - to the right will be zeros + X_matrix[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :] = X_block - spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 0] = Ymag_vocals[i] - spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 1] = Ymag_drums[i] - spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 2] = Ymag_bass[i] - spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 3] = Ymag_other[i] + spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 0] = Ymag_vocals[i] + spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 1] = Ymag_drums[i] + spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 2] = Ymag_bass[i] + spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 3] = Ymag_other[i] - freq_start += nb_f_bins + freq_start += nb_f_bins - spectrograms = wiener( - torch.squeeze(spectrograms, dim=0), - torch.squeeze(X_matrix, dim=0), - self.niter, - softmask=self.softmask, - slicq=True, - ) + spectrograms = wiener( + torch.squeeze(spectrograms, dim=0), + torch.squeeze(X_matrix, dim=0), + self.niter, + softmask=self.softmask, + slicq=True, + ) - # reverse the wiener/EM permutes etc. - spectrograms = torch.unsqueeze(spectrograms.permute(2, 1, 0, 3, 4), dim=0) - spectrograms = spectrograms.reshape(nb_samples, self.nb_channels, self.total_f_bins, nb_slices, self.max_t_bins, *spectrograms.shape[-2:]) + # reverse the wiener/EM permutes etc. 
+ spectrograms = torch.unsqueeze(spectrograms.permute(2, 1, 0, 3, 4), dim=0) + spectrograms = spectrograms.reshape(nb_samples, self.nb_channels, self.total_f_bins, nb_slices, self.max_t_bins, *spectrograms.shape[-2:]) - slicq_vocals = [None]*len(X) - slicq_bass = [None]*len(X) - slicq_drums = [None]*len(X) - slicq_other = [None]*len(X) + slicq_vocals = [None]*len(X) + slicq_bass = [None]*len(X) + slicq_drums = [None]*len(X) + slicq_other = [None]*len(X) - estimates = torch.empty(audio.shape + (nb_sources,), dtype=audio.dtype, device=audio.device) + estimates = torch.empty(audio.shape + (nb_sources,), dtype=audio.dtype, device=audio.device) - nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins = X_matrix.shape[:-1] + nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins = X_matrix.shape[:-1] - # matrix back to list form for insgt - freq_start = 0 - for i, X_block in enumerate(X): - nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins, _ = X_block.shape + # matrix back to list form for insgt + freq_start = 0 + for i, X_block in enumerate(X): + nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins, _ = X_block.shape - slicq_vocals[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 0].contiguous() - slicq_drums[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 1].contiguous() - slicq_bass[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 2].contiguous() - slicq_other[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 3].contiguous() + slicq_vocals[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 0].contiguous() + slicq_drums[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 1].contiguous() + slicq_bass[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 2].contiguous() + slicq_other[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 3].contiguous() - freq_start += nb_f_bins + freq_start += nb_f_bins - estimates[..., 0] = self.insgt(slicq_vocals, audio.shape[-1]) - estimates[..., 1] = self.insgt(slicq_drums, audio.shape[-1]) - estimates[..., 2] = self.insgt(slicq_bass, audio.shape[-1]) - estimates[..., 3] = self.insgt(slicq_other, audio.shape[-1]) + estimates[..., 0] = self.insgt(slicq_vocals, audio.shape[-1]) + estimates[..., 1] = self.insgt(slicq_drums, audio.shape[-1]) + estimates[..., 2] = self.insgt(slicq_bass, audio.shape[-1]) + estimates[..., 3] = self.insgt(slicq_other, audio.shape[-1]) - estimates = estimates.permute(0, 3, 1, 2).contiguous() + estimates = estimates.permute(0, 3, 1, 2).contiguous() + final_estimates.append(estimates) - return estimates + ests_concat = torch.cat(final_estimates, axis=-1) + print(f'ests concat: {ests_concat.shape}') + return ests_concat def to_dict(self, estimates: Tensor, aggregate_dict: Optional[dict] = None) -> dict: """Convert estimates as stacked tensor to dictionary
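
The chunked inference added to `model.py` above boils down to a split/process/concatenate loop over the time axis. Below is a minimal, self-contained sketch of that scheme, not the xumx-slicq API itself: `separate_fn` is a hypothetical stand-in for one full NSGT -> network -> Wiener -> inverse-NSGT pass over a single chunk, and 2621440 is the default chunk size from the patch (about 59.4 seconds at 44100 Hz).

```python
import torch


def separate_fn(audio: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder for separating one chunk:
    # (nb_samples, nb_channels, nb_timesteps) -> (nb_samples, 4, nb_channels, nb_timesteps).
    # Here it just stacks 4 copies of the mix so the sketch runs on its own.
    return audio.unsqueeze(1).repeat(1, 4, 1, 1)


def chunked_separate(audio: torch.Tensor, chunk_size: int = 2621440) -> torch.Tensor:
    # Split the waveform along time, separate each chunk independently,
    # then concatenate the per-chunk estimates back along the time axis.
    n = audio.shape[-1]
    nchunks = n // chunk_size
    if n % chunk_size != 0:
        nchunks += 1

    estimates = []
    for i in range(nchunks):
        chunk = audio[..., i * chunk_size: min((i + 1) * chunk_size, n)]
        estimates.append(separate_fn(chunk))

    # chunks do not overlap, so the concatenated output matches the input length
    return torch.cat(estimates, dim=-1)


if __name__ == "__main__":
    mix = torch.randn(1, 2, 3 * 2621440 + 1000)  # a bit over 3 chunks
    out = chunked_separate(mix)
    assert out.shape == (1, 4, 2, mix.shape[-1])
```

Because each chunk is processed independently, and the Wiener-EM step now stays in float32, peak memory is bounded by the chunk size rather than the track length; the BSS numbers above suggest the accuracy cost of chunking is negligible.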