In [None]:
!pip install ssspy

In [None]:
import numpy as np
import scipy.signal as ss
import matplotlib.pyplot as plt
import IPython.display as ipd
from tqdm.notebook import tqdm

In [None]:
from ssspy.utils.dataset import download_sample_speech_data

In [None]:
# Parameters forwarded to the sample-data download below.
n_sources = 2
max_duration = 10
sisec2010_tag = "dev1_female3"

# STFT analysis parameters: window length and hop size (hop is half the window).
n_fft = 4096
hop_length = 2048

In [None]:
# Fetch the sample source images; layout is (n_channels, n_sources, n_samples).
waveform_src_img, sample_rate = download_sample_speech_data(
    sisec2010_tag=sisec2010_tag,
    n_sources=n_sources,
    max_duration=max_duration,
    conv=True,
)

# Collapse the source axis to obtain the mixture, (n_channels, n_samples).
waveform_mix = waveform_src_img.sum(axis=1)

In [None]:
# Render one audio player per mixture channel, numbered from 1.
for channel_number, channel_waveform in enumerate(waveform_mix, start=1):
    print("Mixture: {}".format(channel_number))
    display(ipd.Audio(channel_waveform, rate=sample_rate))
    print()

In [None]:
from ssspy.bss.ipsdta import GaussIPSDTA as GaussIPSDTABase

In [None]:
class GaussIPSDTA(GaussIPSDTABase):
    """GaussIPSDTA with a progress bar and configurable inner update counts.

    Each iteration (``update_once``) runs ``update_source_model`` ``source_steps``
    times, then ``update_spatial_model`` ``spatial_steps`` times, and advances a
    tqdm progress bar by one step.

    Args:
        *args: Positional arguments forwarded unchanged to the base class.
        source_steps (int): Number of source-model updates per iteration.
        spatial_steps (int): Number of spatial-model updates per iteration.
        **kwargs: Keyword arguments forwarded unchanged to the base class.
    """

    def __init__(self, *args, source_steps=1, spatial_steps=1, **kwargs):
        super().__init__(*args, **kwargs)

        self.progress_bar = None
        # Defined up front so that ``update_once`` also works when it is invoked
        # directly; ``tqdm(total=None)`` is a bar of unknown length.
        self.n_iter = None
        self.source_steps = source_steps
        self.spatial_steps = spatial_steps

    def __call__(self, *args, n_iter: int = 100, **kwargs):
        """Run the base-class separation while tracking progress.

        Args:
            *args: Positional arguments forwarded to the base class ``__call__``.
            n_iter (int): Number of iterations; also the total of the progress bar.
            **kwargs: Keyword arguments forwarded to the base class ``__call__``.

        Returns:
            Whatever the base class ``__call__`` returns.
        """
        self.n_iter = n_iter

        # Drop any bar left over from a previous (possibly interrupted) run so
        # that every call gets a fresh bar with the correct total.
        self._close_progress_bar()

        try:
            return super().__call__(*args, n_iter=n_iter, **kwargs)
        finally:
            # Always finalize the bar, even when the run raises or is interrupted.
            self._close_progress_bar()

    def update_once(self) -> None:
        """Perform one iteration: source updates, spatial updates, bar tick."""
        # Created lazily so the bar only appears once iterations actually start.
        if self.progress_bar is None:
            self.progress_bar = tqdm(total=self.n_iter)

        for _ in range(self.source_steps):
            self.update_source_model()

        for _ in range(self.spatial_steps):
            self.update_spatial_model()

        self.progress_bar.update(1)

    def _close_progress_bar(self) -> None:
        """Close the current progress bar, if any, and forget it."""
        if self.progress_bar is not None:
            self.progress_bar.close()
            self.progress_bar = None

In [None]:
# Build the separator. A fixed-seed generator is passed as `rng`, and each
# iteration performs one source-model update followed by ten spatial-model updates.
ipsdta = GaussIPSDTA(
    rng=np.random.default_rng(42),
    spatial_algorithm="VCD",
    n_basis=2,
    n_blocks=1024,  # block 1: {1, 2}, ..., block 1023: {2045, 2046}, block 1024: {2047, 2048, 2049}
    source_steps=1,
    spatial_steps=10,
)

# Show the configured separator.
print(ipsdta)

In [None]:
# STFT of the mixture; the first two return values of ss.stft are discarded.
_, _, spectrogram_mix = ss.stft(
    waveform_mix,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [None]:
# Run the separation on the mixture spectrogram for 100 iterations.
spectrogram_est = ipsdta(
    spectrogram_mix,
    n_iter=100,
)

In [None]:
# Inverse STFT with the same window settings as the forward transform;
# the first return value of ss.istft is discarded.
_, waveform_est = ss.istft(
    spectrogram_est,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [None]:
# Render one audio player per estimated source, numbered from 1.
for source_number, source_waveform in enumerate(waveform_est, start=1):
    print("Estimated source: {}".format(source_number))
    display(ipd.Audio(source_waveform, rate=sample_rate))
    print()

In [None]:
# Loss curve of the separation run. The first recorded value is skipped, exactly
# as before (presumably it is the loss before any update — TODO confirm).
loss_history = ipsdta.loss[1:]

fig, ax = plt.subplots()
# Start x at 1 so the k-th plotted point is labelled as iteration k instead of
# the implicit 0-based index.
ax.plot(range(1, len(loss_history) + 1), loss_history)
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.set_title("GaussIPSDTA loss")
plt.show()
plt.close(fig)