Commit

Inference in chunks and no float64, to avoid oom

sevagh committed Apr 8, 2022
1 parent 4795bf9 commit 4551669
Showing 3 changed files with 182 additions and 137 deletions.
29 changes: 29 additions & 0 deletions docs/inference.md
@@ -0,0 +1,29 @@
No-chunk inference:
bss evaluation to store in exp-04-trained-models/slicq-wslicq
drums ==> SDR: 3.930 SIR: 8.810 ISR: 6.519 SAR: 5.847
bass ==> SDR: -4.861 SIR: -5.665 ISR: 6.639 SAR: 6.193
other ==> SDR: 0.348 SIR: -1.541 ISR: 1.419 SAR: 2.100
vocals ==> SDR: 5.895 SIR: 13.361 ISR: 9.629 SAR: 7.209
accompaniment ==> SDR: 12.948 SIR: 15.880 ISR: 20.559 SAR: 15.627


Chunked inference (chunk size 2621440 samples, i.e. 59.44 seconds @ 44100 Hz):
bss evaluation to store in exp-04-trained-models/slicq-wslicq
drums ==> SDR: 3.990 SIR: 8.558 ISR: 6.609 SAR: 5.896
bass ==> SDR: -4.809 SIR: -5.768 ISR: 6.494 SAR: 6.074
other ==> SDR: 0.350 SIR: -1.408 ISR: 1.442 SAR: 1.995
vocals ==> SDR: 5.887 SIR: 13.241 ISR: 9.703 SAR: 7.153
accompaniment ==> SDR: 13.001 SIR: 15.890 ISR: 20.690 SAR: 15.513


Tradeoffs here and there, no cause for alarm; keep chunking.
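
For reference, a minimal sketch of the chunking arithmetic behind the numbers above (2621440 samples per chunk is 2621440 / 44100 ≈ 59.44 seconds). The helper name and shapes are illustrative, not the project's API:

```python
import torch

def split_into_chunks(audio: torch.Tensor, chunk_size: int = 2621440):
    """Split audio of shape (..., nb_timesteps) into time chunks (illustrative helper)."""
    total = audio.shape[-1]
    nchunks = total // chunk_size + (1 if total % chunk_size else 0)
    return [
        audio[..., i * chunk_size : min((i + 1) * chunk_size, total)]
        for i in range(nchunks)
    ]

# e.g. a 3-minute stereo track at 44.1 kHz -> 4 chunks of at most ~59.44 s each
chunks = split_into_chunks(torch.randn(1, 2, 180 * 44100))
print([c.shape[-1] for c in chunks])
```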

Remove float64 in the Wiener-EM function. Results are still similar:

No float64, chunked inference:
bss evaluation to store in exp-04-trained-models/slicq-wslicq
drums ==> SDR: 3.990 SIR: 8.558 ISR: 6.609 SAR: 5.896
bass ==> SDR: -4.809 SIR: -5.768 ISR: 6.494 SAR: 6.074
other ==> SDR: 0.350 SIR: -1.408 ISR: 1.442 SAR: 1.995
vocals ==> SDR: 5.887 SIR: 13.241 ISR: 9.703 SAR: 7.153
accompaniment ==> SDR: 13.001 SIR: 15.890 ISR: 20.690 SAR: 15.513
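
A rough back-of-the-envelope check (my own arithmetic, not taken from the commit) of why dropping the float64 round-trip in the Wiener-EM step helps with OOM: the per-source complex spectrogram buffer alone halves in size.

```python
# Illustrative sizes only: one ~59.44 s stereo chunk, 4096-point STFT, 1024 hop,
# 4 targets. The exact buffer counts inside expectation_maximization differ.
nb_frames = 2621440 // 1024        # ~2560 STFT frames
nb_bins = 4096 // 2 + 1            # 2049 frequency bins
nb_channels, nb_sources = 2, 4
complex_elems = nb_frames * nb_bins * nb_channels * nb_sources

gib = 2**30
print(f"float32 (complex64):  {complex_elems * 2 * 4 / gib:.2f} GiB")
print(f"float64 (complex128): {complex_elems * 2 * 8 / gib:.2f} GiB")
```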
7 changes: 2 additions & 5 deletions xumx_slicq/filtering.py
@@ -421,14 +421,11 @@ def wiener(
mix_stft = mix_stft / max_abs
y = y / max_abs

y = y.to(torch.float64)

# call expectation maximization
y = expectation_maximization(y, mix_stft.to(torch.float64), iterations, eps=eps, slicq=slicq)[0]
y = expectation_maximization(y, mix_stft, iterations, eps=eps, slicq=slicq)[0]

# scale estimates up again
y = y * max_abs
return y.to(torch.float32)
return y * max_abs


def _covariance(y_j):
283 changes: 151 additions & 132 deletions xumx_slicq/model.py
@@ -1,5 +1,6 @@
from typing import Optional

from tqdm import trange
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -349,6 +350,7 @@ def __init__(
stft_wiener: bool = True,
softmask: bool = False,
wiener_win_len: Optional[int] = 300,
chunk_size: Optional[int] = 2621440,
n_fft: Optional[int] = 4096,
n_hop: Optional[int] = 1024,
):
@@ -374,6 +376,7 @@ def __init__(
self.n_fft = n_fft
self.n_hop = n_hop
self.wiener_win_len = wiener_win_len
self.chunk_size = chunk_size if chunk_size is not None else sys.maxsize

if not self.stft_wiener:
# first, get frequency and time limits to build the large zero-padded matrix
@@ -399,7 +402,7 @@ def freeze(self):
self.eval()

@torch.no_grad()
def forward(self, audio: Tensor) -> Tensor:
def forward(self, audio_big: Tensor) -> Tensor:
"""Performing the separation on audio input
Args:
@@ -411,160 +414,176 @@ def forward(self, audio: Tensor) -> Tensor:
shape `(nb_samples, nb_targets, nb_channels, nb_timesteps)`
"""
nb_sources = 4
nb_samples = audio.shape[0]

X = self.nsgt(audio)
Xmag = self.complexnorm(X)

# xumx inference - magnitude slicq estimate
Ymag_bass, Ymag_vocals, Ymag_other, Ymag_drums = self.xumx_model(Xmag)

if self.stft_wiener:
print('STFT WIENER')

# initial mix phase + magnitude estimate
Ycomplex_bass = phasemix_sep(X, Ymag_bass)
Ycomplex_vocals = phasemix_sep(X, Ymag_vocals)
Ycomplex_drums = phasemix_sep(X, Ymag_drums)
Ycomplex_other = phasemix_sep(X, Ymag_other)

y_bass = self.insgt(Ycomplex_bass, audio.shape[-1])
y_drums = self.insgt(Ycomplex_drums, audio.shape[-1])
y_other = self.insgt(Ycomplex_other, audio.shape[-1])
y_vocals = self.insgt(Ycomplex_vocals, audio.shape[-1])

# initial estimate was obtained with slicq
# now we switch to the STFT domain for the wiener step

audio = torch.squeeze(audio, dim=0)

mix_stft = torch.view_as_real(torch.stft(audio, self.n_fft, hop_length=self.n_hop, return_complex=True))
X = torch.abs(torch.view_as_complex(mix_stft))

# initializing spectrograms variable
spectrograms = torch.zeros(X.shape + (nb_sources,), dtype=audio.dtype, device=X.device)

for j, target_name in enumerate(self.ordered_targets):
# apply current model to get the source spectrogram
if target_name == 'bass':
target_est = torch.squeeze(y_bass, dim=0)
elif target_name == 'vocals':
target_est = torch.squeeze(y_vocals, dim=0)
elif target_name == 'drums':
target_est = torch.squeeze(y_drums, dim=0)
elif target_name == 'other':
target_est = torch.squeeze(y_other, dim=0)
spectrograms[..., j] = torch.abs(torch.stft(target_est, self.n_fft, hop_length=self.n_hop, return_complex=True))

# transposing it as
# (nb_samples, nb_frames, nb_bins,{1,nb_channels}, nb_sources)

spectrograms = spectrograms.permute(2, 1, 0, 3)

# rearranging it into:
# (nb_samples, nb_frames, nb_bins, nb_channels, 2) to feed
# into filtering methods
mix_stft = mix_stft.permute(2, 1, 0, 3)

nb_frames = spectrograms.shape[0]
targets_stft = torch.zeros(
mix_stft.shape + (nb_sources,), dtype=audio.dtype, device=mix_stft.device
)

pos = 0
if self.wiener_win_len:
wiener_win_len = self.wiener_win_len
else:
wiener_win_len = nb_frames
while pos < nb_frames:
cur_frame = torch.arange(pos, min(nb_frames, pos + wiener_win_len))
pos = int(cur_frame[-1]) + 1

targets_stft[cur_frame] = wiener(
spectrograms[cur_frame],
mix_stft[cur_frame],
self.niter,
softmask=self.softmask,
slicq=False, # stft wiener
nb_samples = audio_big.shape[0]
N = audio_big.shape[-1]

nchunks = (N // self.chunk_size)
if (N % self.chunk_size) != 0:
nchunks += 1

print(f'n chunks: {nchunks}')

final_estimates = []

for chunk_idx in trange(nchunks):
audio = audio_big[..., chunk_idx * self.chunk_size: min((chunk_idx + 1) * self.chunk_size, N)]
print(f'audio.shape: {audio.shape}')

X = self.nsgt(audio)
Xmag = self.complexnorm(X)

# xumx inference - magnitude slicq estimate
Ymag_bass, Ymag_vocals, Ymag_other, Ymag_drums = self.xumx_model(Xmag)

if self.stft_wiener:
print('STFT WIENER')

# initial mix phase + magnitude estimate
Ycomplex_bass = phasemix_sep(X, Ymag_bass)
Ycomplex_vocals = phasemix_sep(X, Ymag_vocals)
Ycomplex_drums = phasemix_sep(X, Ymag_drums)
Ycomplex_other = phasemix_sep(X, Ymag_other)

y_bass = self.insgt(Ycomplex_bass, audio.shape[-1])
y_drums = self.insgt(Ycomplex_drums, audio.shape[-1])
y_other = self.insgt(Ycomplex_other, audio.shape[-1])
y_vocals = self.insgt(Ycomplex_vocals, audio.shape[-1])

# initial estimate was obtained with slicq
# now we switch to the STFT domain for the wiener step

audio = torch.squeeze(audio, dim=0)

mix_stft = torch.view_as_real(torch.stft(audio, self.n_fft, hop_length=self.n_hop, return_complex=True))
X = torch.abs(torch.view_as_complex(mix_stft))

# initializing spectrograms variable
spectrograms = torch.zeros(X.shape + (nb_sources,), dtype=audio.dtype, device=X.device)

for j, target_name in enumerate(self.ordered_targets):
# apply current model to get the source spectrogram
if target_name == 'bass':
target_est = torch.squeeze(y_bass, dim=0)
elif target_name == 'vocals':
target_est = torch.squeeze(y_vocals, dim=0)
elif target_name == 'drums':
target_est = torch.squeeze(y_drums, dim=0)
elif target_name == 'other':
target_est = torch.squeeze(y_other, dim=0)
spectrograms[..., j] = torch.abs(torch.stft(target_est, self.n_fft, hop_length=self.n_hop, return_complex=True))

# transposing it as
# (nb_samples, nb_frames, nb_bins,{1,nb_channels}, nb_sources)

spectrograms = spectrograms.permute(2, 1, 0, 3)

# rearranging it into:
# (nb_samples, nb_frames, nb_bins, nb_channels, 2) to feed
# into filtering methods
mix_stft = mix_stft.permute(2, 1, 0, 3)

nb_frames = spectrograms.shape[0]
targets_stft = torch.zeros(
mix_stft.shape + (nb_sources,), dtype=audio.dtype, device=mix_stft.device
)

# getting to (nb_samples, nb_targets, channel, fft_size, n_frames, 2)
targets_stft = torch.view_as_complex(targets_stft.permute(4, 2, 1, 0, 3).contiguous())
pos = 0
if self.wiener_win_len:
wiener_win_len = self.wiener_win_len
else:
wiener_win_len = nb_frames
while pos < nb_frames:
cur_frame = torch.arange(pos, min(nb_frames, pos + wiener_win_len))
pos = int(cur_frame[-1]) + 1

targets_stft[cur_frame] = wiener(
spectrograms[cur_frame],
mix_stft[cur_frame],
self.niter,
softmask=self.softmask,
slicq=False, # stft wiener
)

# getting to (nb_samples, nb_targets, channel, fft_size, n_frames, 2)
targets_stft = torch.view_as_complex(targets_stft.permute(4, 2, 1, 0, 3).contiguous())

# inverse STFT
estimates = torch.empty(audio.shape + (nb_sources,), dtype=audio.dtype, device=audio.device)
# inverse STFT
estimates = torch.empty(audio.shape + (nb_sources,), dtype=audio.dtype, device=audio.device)

for j, target_name in enumerate(self.ordered_targets):
estimates[..., j] = torch.istft(targets_stft[j, ...], self.n_fft, hop_length=self.n_hop, length=audio.shape[-1])
for j, target_name in enumerate(self.ordered_targets):
estimates[..., j] = torch.istft(targets_stft[j, ...], self.n_fft, hop_length=self.n_hop, length=audio.shape[-1])

estimates = torch.unsqueeze(estimates, dim=0).permute(0, 3, 1, 2).contiguous()
else:
print('sliCQT WIENER')
estimates = torch.unsqueeze(estimates, dim=0).permute(0, 3, 1, 2).contiguous()
else:
print('sliCQT WIENER')

# block-wise wiener
# assemble it all into a zero-padded matrix
# block-wise wiener
# assemble it all into a zero-padded matrix

nb_slices = X[0].shape[3]
last_dim = 2
nb_slices = X[0].shape[3]
last_dim = 2

X_matrix = torch.zeros((nb_samples, self.nb_channels, self.total_f_bins, nb_slices, self.max_t_bins, last_dim), dtype=X[0].dtype, device=X[0].device)
spectrograms = torch.zeros(X_matrix.shape[:-1] + (nb_sources,), dtype=audio.dtype, device=X_matrix.device)
X_matrix = torch.zeros((nb_samples, self.nb_channels, self.total_f_bins, nb_slices, self.max_t_bins, last_dim), dtype=X[0].dtype, device=X[0].device)
spectrograms = torch.zeros(X_matrix.shape[:-1] + (nb_sources,), dtype=audio.dtype, device=X_matrix.device)

freq_start = 0
for i, X_block in enumerate(X):
nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins, last_dim = X_block.shape
freq_start = 0
for i, X_block in enumerate(X):
nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins, last_dim = X_block.shape

# assign up to the defined time bins - to the right will be zeros
X_matrix[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :] = X_block
# assign up to the defined time bins - to the right will be zeros
X_matrix[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :] = X_block

spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 0] = Ymag_vocals[i]
spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 1] = Ymag_drums[i]
spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 2] = Ymag_bass[i]
spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 3] = Ymag_other[i]
spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 0] = Ymag_vocals[i]
spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 1] = Ymag_drums[i]
spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 2] = Ymag_bass[i]
spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, 3] = Ymag_other[i]

freq_start += nb_f_bins
freq_start += nb_f_bins

spectrograms = wiener(
torch.squeeze(spectrograms, dim=0),
torch.squeeze(X_matrix, dim=0),
self.niter,
softmask=self.softmask,
slicq=True,
)
spectrograms = wiener(
torch.squeeze(spectrograms, dim=0),
torch.squeeze(X_matrix, dim=0),
self.niter,
softmask=self.softmask,
slicq=True,
)

# reverse the wiener/EM permutes etc.
spectrograms = torch.unsqueeze(spectrograms.permute(2, 1, 0, 3, 4), dim=0)
spectrograms = spectrograms.reshape(nb_samples, self.nb_channels, self.total_f_bins, nb_slices, self.max_t_bins, *spectrograms.shape[-2:])
# reverse the wiener/EM permutes etc.
spectrograms = torch.unsqueeze(spectrograms.permute(2, 1, 0, 3, 4), dim=0)
spectrograms = spectrograms.reshape(nb_samples, self.nb_channels, self.total_f_bins, nb_slices, self.max_t_bins, *spectrograms.shape[-2:])

slicq_vocals = [None]*len(X)
slicq_bass = [None]*len(X)
slicq_drums = [None]*len(X)
slicq_other = [None]*len(X)
slicq_vocals = [None]*len(X)
slicq_bass = [None]*len(X)
slicq_drums = [None]*len(X)
slicq_other = [None]*len(X)

estimates = torch.empty(audio.shape + (nb_sources,), dtype=audio.dtype, device=audio.device)
estimates = torch.empty(audio.shape + (nb_sources,), dtype=audio.dtype, device=audio.device)

nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins = X_matrix.shape[:-1]
nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins = X_matrix.shape[:-1]

# matrix back to list form for insgt
freq_start = 0
for i, X_block in enumerate(X):
nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins, _ = X_block.shape
# matrix back to list form for insgt
freq_start = 0
for i, X_block in enumerate(X):
nb_samples, self.nb_channels, nb_f_bins, nb_slices, nb_t_bins, _ = X_block.shape

slicq_vocals[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 0].contiguous()
slicq_drums[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 1].contiguous()
slicq_bass[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 2].contiguous()
slicq_other[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 3].contiguous()
slicq_vocals[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 0].contiguous()
slicq_drums[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 1].contiguous()
slicq_bass[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 2].contiguous()
slicq_other[i] = spectrograms[:, :, freq_start:freq_start+nb_f_bins, :, : nb_t_bins, :, 3].contiguous()

freq_start += nb_f_bins
freq_start += nb_f_bins

estimates[..., 0] = self.insgt(slicq_vocals, audio.shape[-1])
estimates[..., 1] = self.insgt(slicq_drums, audio.shape[-1])
estimates[..., 2] = self.insgt(slicq_bass, audio.shape[-1])
estimates[..., 3] = self.insgt(slicq_other, audio.shape[-1])
estimates[..., 0] = self.insgt(slicq_vocals, audio.shape[-1])
estimates[..., 1] = self.insgt(slicq_drums, audio.shape[-1])
estimates[..., 2] = self.insgt(slicq_bass, audio.shape[-1])
estimates[..., 3] = self.insgt(slicq_other, audio.shape[-1])

estimates = estimates.permute(0, 3, 1, 2).contiguous()
estimates = estimates.permute(0, 3, 1, 2).contiguous()
final_estimates.append(estimates)

return estimates
ests_concat = torch.cat(final_estimates, axis=-1)
print(f'ests concat: {ests_concat.shape}')
return ests_concat

def to_dict(self, estimates: Tensor, aggregate_dict: Optional[dict] = None) -> dict:
"""Convert estimates as stacked tensor to dictionary
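
Distilled from the diff above, a minimal sketch of the new chunk loop wrapped around the per-chunk separation; `chunked_forward` and `separate_one_chunk` are illustrative names, not functions in the repository.

```python
import torch
from tqdm import trange

def chunked_forward(separate_one_chunk, audio_big: torch.Tensor, chunk_size: int = 2621440):
    """Run separation chunk by chunk and stitch the estimates back together in time."""
    N = audio_big.shape[-1]
    nchunks = N // chunk_size + (1 if N % chunk_size else 0)
    final_estimates = []
    for chunk_idx in trange(nchunks):
        audio = audio_big[..., chunk_idx * chunk_size : min((chunk_idx + 1) * chunk_size, N)]
        # per-chunk output: (nb_samples, nb_targets, nb_channels, nb_timesteps)
        final_estimates.append(separate_one_chunk(audio))
    return torch.cat(final_estimates, dim=-1)
```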
