In [1]:
!pip install git+https://github.com/tky823/ssspy.git@b8da718

Collecting git+https://github.com/tky823/ssspy.git@b8da718
  Cloning https://github.com/tky823/ssspy.git (to revision b8da718) to /tmp/pip-req-build-imedc_71
  Running command git clone --filter=blob:none --quiet https://github.com/tky823/ssspy.git /tmp/pip-req-build-imedc_71
[0m  Running command git checkout -q b8da718
  Resolved https://github.com/tky823/ssspy.git to commit b8da718
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: ssspy
  Building wheel for ssspy (pyproject.toml) ... [?25l[?25hdone
  Created wheel for ssspy: filename=ssspy-0.1.8.dev137-py3-none-any.whl size=123658 sha256=6cee03dc70227b1babb0578b51197da777e543c5a1e71a3a0eff15678d38e7b1
  Stored in directory: /tmp/pip-ephem-wheel-cache-0i6kda5j/wheels/d2/fd/7c/a24bdbab80510a55c46c96ff0e722c4da21977994d9b

In [20]:
import numpy as np
import scipy.signal as ss
import IPython.display as ipd
from tqdm.notebook import tqdm

In [21]:
from ssspy.utils.dataset import download_sample_speech_data

In [22]:
# Experiment configuration shared by the download, STFT and separation cells.
n_sources = 4  # number of sources to load and to separate
max_duration = 10  # passed to download_sample_speech_data (presumably seconds — confirm)
sisec2010_tag = "dev1_female4"  # which SiSEC 2010 sample to use

# STFT parameters: frame length and a hop of half a frame.
n_fft = 4096
hop_length = n_fft // 2

In [23]:
# Source images, shape (n_channels, n_sources, n_samples).
waveform_src_img, sample_rate = download_sample_speech_data(
    sisec2010_tag=sisec2010_tag,
    n_sources=n_sources,
    max_duration=max_duration,
    conv=True,
)
# The observed mixture at each channel is the sum of the source images: (n_channels, n_samples).
waveform_mix = waveform_src_img.sum(axis=1)

In [24]:
# Render one audio player per observed channel.
for mix_idx, mix_waveform in enumerate(waveform_mix):
    print("Mixture: {}".format(mix_idx + 1))
    display(ipd.Audio(mix_waveform, rate=sample_rate))
    print()

Mixture: 1



Mixture: 2



Mixture: 3



Mixture: 4





In [25]:
import functools
from typing import Optional, Callable, List, Union, Iterable, Tuple
from ssspy.utils.flooring import choose_flooring_fn
from ssspy.bss.ilrma import GaussILRMA as GaussILRMABase
from ssspy.bss.ilrma import EPS, max_flooring, source_algorithms

In [26]:
def update_by_fast_ip1(
    separated: np.ndarray,
    weight: np.ndarray,
    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(
        max_flooring, eps=EPS
    ),
    overwrite: bool = True,
) -> np.ndarray:
    r"""Update separated spectrograms by fast iterative projection (FastIP1).

    The demixing filters are never formed explicitly. For each source ``n`` in turn:

    1. Build the weighted covariance of the *current* separated signals,
       ``U_tilde_n[i] = mean_j weight[n, i, j] * y[:, i, j] y[:, i, j]^H``.
    2. Solve ``U_tilde_n[i] eta_n[i] = e_n`` per frequency bin ``i``.
    3. Normalize ``eta_n`` by ``sqrt(max(Re(eta_n[n]), 0))``, after flooring.
    4. Replace ``y_n`` by ``sum_m conj(eta_n[m]) * y_m``.

    Because step 4 changes ``Y``, the covariance is recomputed for every source.

    Args:
        separated (numpy.ndarray): Estimated spectrograms to be updated.
            The shape is (n_sources, n_bins, n_frames).
        weight (numpy.ndarray): Weights for estimated spectrogram.
            The shape is (n_sources, n_bins, n_frames).
        flooring_fn (callable, optional): A flooring function for numerical stability.
            This function is expected to return the same shape tensor as the input.
            If you explicitly set ``flooring_fn=None``,
            the identity function (``lambda x: x``) is used.
            Default: ``functools.partial(max_flooring, eps=EPS)``.
        overwrite (bool): Overwrite ``separated`` if ``overwrite=True``. Default: ``True``.

    Returns: numpy.ndarray of updated spectrograms. The shape is (n_sources, n_bins, n_frames).

    """
    if flooring_fn is None:
        # No ``identity`` helper exists in this notebook, so fall back to an inline no-op.
        def flooring_fn(x: np.ndarray) -> np.ndarray:
            return x

    if overwrite:
        Y = separated
    else:
        Y = separated.copy()

    varphi = weight
    n_sources, n_bins, n_frames = Y.shape

    # One identity matrix per frequency bin: (n_bins, n_sources, n_sources).
    E = np.tile(np.eye(n_sources), reps=(n_bins, 1, 1))

    # Bin-major *view* of Y, shape (n_bins, n_sources, n_frames). It shares memory with Y,
    # so the in-place row updates at the end of each iteration are visible through it.
    Y_bin = Y.transpose(1, 0, 2)

    for src_idx in range(n_sources):
        varphi_n = varphi[src_idx][:, np.newaxis, :]  # (n_bins, 1, n_frames)

        # Weighted covariance of the current separated signals, (n_bins, n_sources, n_sources).
        # A batched matmul avoids materialising the 4-D outer product of every source pair.
        U_tilde_n = (varphi_n * Y_bin) @ Y_bin.conj().transpose(0, 2, 1) / n_frames

        # Give the right-hand side an explicit column axis: NumPy >= 2.0 no longer treats a
        # (n_bins, n_sources) array as a stack of vectors in ``np.linalg.solve``.
        e_n = E[:, src_idx, :, np.newaxis]  # (n_bins, n_sources, 1)
        eta_n = np.linalg.solve(U_tilde_n, e_n)[..., 0]  # (n_bins, n_sources)

        # Normalization factor; clip at 0 before sqrt to guard against round-off.
        eta_nn = np.real(eta_n[:, src_idx])
        eta_nn = np.maximum(eta_nn, 0)
        eta_nn = np.sqrt(eta_nn)
        eta_nn = flooring_fn(eta_nn)
        eta_n = eta_n / eta_nn[:, np.newaxis]

        # y_n <- sum_m conj(eta_n[m]) * y_m, evaluated independently for every bin.
        eta_n_conj = eta_n.transpose(1, 0).conj()[:, :, np.newaxis]  # (n_sources, n_bins, 1)
        Y[src_idx] = np.sum(eta_n_conj * Y, axis=0)

    return Y

In [27]:
class GaussILRMA(GaussILRMABase):
    """ssspy ``GaussILRMA`` that reports iteration progress with a tqdm progress bar.

    Every ``__call__`` opens a fresh bar with ``total=n_iter``. Each ``update_once``
    advances it by one step, and it is closed when the call returns or raises.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Progress bar of the currently running ``__call__``; ``None`` between runs.
        self.progress_bar = None

    def __call__(self, *args, n_iter: int = 100, **kwargs):
        self.n_iter = n_iter

        # Create a new bar per call so that repeated calls on the same instance
        # do not keep advancing a bar that already reached its previous total.
        self.progress_bar = tqdm(total=n_iter)

        try:
            return super().__call__(*args, n_iter=n_iter, **kwargs)
        finally:
            self.progress_bar.close()
            self.progress_bar = None

    def update_once(self) -> None:
        if self.progress_bar is None:
            # ``update_once`` invoked outside ``__call__``: open a bar lazily. Use the last
            # requested ``n_iter`` if any, otherwise an open-ended bar (total unknown).
            self.progress_bar = tqdm(total=getattr(self, "n_iter", None))

        super().update_once()

        self.progress_bar.update(1)

In [28]:
class GaussILRMAFastIP1(GaussILRMA):
    """Gaussian ILRMA whose spatial model is updated by ``update_by_fast_ip1``.

    Each spatial update rewrites the separated spectrograms ``self.output`` in place,
    without going through demixing filters; ``self.demix_filter`` is reset to ``None``.

    The constructor mirrors the base class, except that ``spatial_algorithm`` is fixed
    to ``"FastIP1"``. ``pair_selector`` is accepted for signature compatibility only:
    it is neither forwarded to the base class nor stored.
    """

    def __init__(
        self,
        n_basis: int,
        source_algorithm: str = "MM",
        domain: float = 2,
        partitioning: bool = False,
        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(
            max_flooring, eps=EPS
        ),
        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,
        callbacks: Optional[
            Union[Callable[["GaussILRMA"], None], List[Callable[["GaussILRMA"], None]]]
        ] = None,
        normalization: Optional[Union[bool, str]] = True,
        scale_restoration: Union[bool, str] = True,
        record_loss: bool = True,
        reference_id: int = 0,
        rng: Optional[np.random.Generator] = None,
    ) -> None:
        # The base class is built with its own defaults for spatial_algorithm,
        # source_algorithm, domain and normalization; those are overridden below.
        super().__init__(
            n_basis=n_basis,
            partitioning=partitioning,
            flooring_fn=flooring_fn,
            callbacks=callbacks,
            scale_restoration=scale_restoration,
            record_loss=record_loss,
            reference_id=reference_id,
            rng=rng,
        )

        assert source_algorithm in source_algorithms, "Not support {}.".format(source_algorithm)
        assert 0 < domain <= 2, "domain parameter should be chosen from [0, 2]."
        # The ME source update is only allowed with domain == 2.
        assert (
            source_algorithm != "ME" or domain == 2
        ), "domain parameter should be 2 when you specify ME algorithm."

        self.spatial_algorithm = "FastIP1"
        self.source_algorithm = source_algorithm
        self.domain = domain
        self.normalization = normalization

    def _reset(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
        **kwargs,
    ) -> None:
        """Reset through the base class, then drop the demixing filter (FastIP1 does not use it)."""
        super()._reset(flooring_fn=choose_flooring_fn(flooring_fn, method=self), **kwargs)

        self.demix_filter = None

    def update_spatial_model(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        """Dispatch the spatial update to the FastIP1 rule."""
        self.update_spatial_model_fast_ip1(
            flooring_fn=choose_flooring_fn(flooring_fn, method=self)
        )

    def update_spatial_model_fast_ip1(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        """Update ``self.output`` with one FastIP1 sweep over all sources.

        The weights are the reciprocal of the NMF reconstruction (with the latent
        partitioning when ``self.partitioning`` is set) raised to ``2 / self.domain``.
        """
        exponent = 2 / self.domain

        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        separated = self.output

        if self.partitioning:
            latent = self.latent
            nmf_model = self.reconstruct_nmf(self.basis, self.activation, latent=latent)
        else:
            nmf_model = self.reconstruct_nmf(self.basis, self.activation)

        weight = 1 / nmf_model ** exponent

        self.output = update_by_fast_ip1(separated, weight, flooring_fn=flooring_fn)

In [29]:
# Multichannel STFT: (n_channels, n_bins, n_frames); the frequency and time grids are unused.
_, _, spectrogram_mix = ss.stft(
    waveform_mix,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [30]:
# Baseline: ssspy's built-in IP1 spatial update. The seed matches the FastIP1 run below.
ilrma_ip1 = GaussILRMA(
    n_basis=2,
    spatial_algorithm="IP1",
    source_algorithm="MM",
    domain=2,
    partitioning=False,
    record_loss=False,
    rng=np.random.default_rng(42),
)
print(ilrma_ip1)

GaussILRMA(n_basis=2, spatial_algorithm=IP1, source_algorithm=MM, domain=2, partitioning=False, normalization=True, scale_restoration=True, record_loss=False, reference_id=0)


In [31]:
# 100 iterations of IP1-based ILRMA; progress is reported by the wrapper's tqdm bar.
spectrogram_est = ilrma_ip1(
    spectrogram_mix,
    n_iter=100,
)

  0%|          | 0/100 [00:00<?, ?it/s]

In [32]:
# Back to the time domain with the same STFT parameters; the time grid is unused.
_, waveform_est = ss.istft(
    spectrogram_est,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [33]:
# One audio player per source estimated by IP1-based ILRMA.
for est_idx, est_waveform in enumerate(waveform_est):
    print("Estimated source: {}".format(est_idx + 1))
    display(ipd.Audio(est_waveform, rate=sample_rate))
    print()

Estimated source: 1



Estimated source: 2



Estimated source: 3



Estimated source: 4





In [34]:
# Same configuration and seed as the IP1 baseline; only the spatial update differs.
ilrma_fast_ip1 = GaussILRMAFastIP1(
    n_basis=2,
    source_algorithm="MM",
    domain=2,
    partitioning=False,
    record_loss=False,
    rng=np.random.default_rng(42),
)
print(ilrma_fast_ip1)

GaussILRMA(n_basis=2, spatial_algorithm=FastIP1, source_algorithm=MM, domain=2, partitioning=False, normalization=True, scale_restoration=True, record_loss=False, reference_id=0)


In [35]:
# 100 iterations of FastIP1-based ILRMA on the same mixture.
spectrogram_est = ilrma_fast_ip1(
    spectrogram_mix,
    n_iter=100,
)

  0%|          | 0/100 [00:00<?, ?it/s]

In [36]:
# Inverse STFT of the FastIP1 estimates with the analysis parameters used above.
_, waveform_est = ss.istft(
    spectrogram_est,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [37]:
# One audio player per source estimated by FastIP1-based ILRMA.
for est_idx, est_waveform in enumerate(waveform_est):
    print("Estimated source: {}".format(est_idx + 1))
    display(ipd.Audio(est_waveform, rate=sample_rate))
    print()

Estimated source: 1



Estimated source: 2



Estimated source: 3



Estimated source: 4



