In [1]:
!pip install git+https://github.com/tky823/ssspy.git@b8da718

Collecting git+https://github.com/tky823/ssspy.git@b8da718
  Cloning https://github.com/tky823/ssspy.git (to revision b8da718) to /tmp/pip-req-build-fy3_g3te
  Running command git clone --filter=blob:none --quiet https://github.com/tky823/ssspy.git /tmp/pip-req-build-fy3_g3te
[0m  Running command git checkout -q b8da718
  Resolved https://github.com/tky823/ssspy.git to commit b8da718
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: ssspy
  Building wheel for ssspy (pyproject.toml) ... done
  Created wheel for ssspy: filename=ssspy-0.1.8.dev137-py3-none-any.whl size=123658 sha256=6fc562e6829bb28add0f4d41fd84549ee2bbb4b9afd3fd51a128e5ef05925174
  Stored in directory: /tmp/pip-ephem-wheel-cache-niojx3_u/wheels/d2/fd/7c/a24bdbab80510a55c46c96ff0e722c4da21977994d9b

In [21]:
import numpy as np
import scipy.signal as ss
import IPython.display as ipd
from tqdm.notebook import tqdm

In [22]:
from ssspy.utils.dataset import download_sample_speech_data

In [23]:
# Experiment configuration.
n_sources = 3  # number of sources to separate; the mixture plays back as 3 channels below
max_duration = 10  # forwarded to download_sample_speech_data (presumably seconds — confirm)
sisec2010_tag = "dev1_female3"  # SiSEC 2010 sample identifier
n_fft = 4096
hop_length = n_fft // 2  # i.e. 2048, 50% overlap between STFT frames

In [24]:
# Fetch the convolutive sample source images.
waveform_src_img, sample_rate = download_sample_speech_data(
    n_sources=n_sources,
    sisec2010_tag=sisec2010_tag,
    max_duration=max_duration,
    conv=True,
)  # (n_channels, n_sources, n_samples)

# The observed mixture at each channel is the sum of all source images (axis 1).
waveform_mix = waveform_src_img.sum(axis=1)  # (n_channels, n_samples)

In [25]:
# Play back every channel of the observed mixture (labels are 1-based).
for idx, waveform in enumerate(waveform_mix):
    print(f"Mixture: {idx + 1}")
    display(ipd.Audio(waveform, rate=sample_rate))
    print()

Mixture: 1



Mixture: 2



Mixture: 3





In [26]:
import functools
from typing import Optional, Callable, List, Union, Iterable, Tuple
from ssspy.utils.flooring import choose_flooring_fn
from ssspy.bss.iva import AuxIVA as AuxIVABase
from ssspy.bss.iva import EPS, max_flooring

In [27]:
def contrast_fn(y):
    """Contrast function ``G(y) = 2 * ||y||_2``, with the norm taken along axis 1.

    Args:
        y (numpy.ndarray): Separated signals. For AuxIVA this is presumably
            (n_sources, n_bins, n_frames); the norm collapses axis 1.

    Returns:
        numpy.ndarray: Twice the Euclidean norm of ``y`` along axis 1.
    """
    magnitude = np.linalg.norm(y, axis=1)
    return 2 * magnitude

def d_contrast_fn(y):
    """Derivative of ``G(r) = 2 * r`` with respect to ``r``.

    Args:
        y (numpy.ndarray): Input values (norms, as used by the AuxIVA weight update).

    Returns:
        numpy.ndarray: An array shaped and typed like ``y`` filled with 2.
    """
    return np.ones_like(y) * 2

In [28]:
def update_by_fast_ip1(
    separated: np.ndarray,
    weight: np.ndarray,
    flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(
        max_flooring, eps=EPS
    ),
    overwrite: bool = True,
) -> np.ndarray:
    r"""Update separated signals by fast iterative projection (FastIP1).

    The update is carried out directly on the separated signals
    :math:`\mathbf{y}_{ij} = \mathbf{W}_i\mathbf{x}_{ij}` rather than on the demixing
    matrix. For each source :math:`n` in turn and every frequency bin :math:`i`:

    .. math::
        \tilde{\mathbf{U}}_{in} = \frac{1}{J}\sum_{j} \varphi_{nj}\,
            \mathbf{y}_{ij}\mathbf{y}_{ij}^\mathsf{H},\qquad
        \boldsymbol{\eta}_{in} = \tilde{\mathbf{U}}_{in}^{-1}\mathbf{e}_{n},\qquad
        y_{ijn} \leftarrow
            \frac{\boldsymbol{\eta}_{in}^\mathsf{H}\mathbf{y}_{ij}}{\sqrt{\eta_{inn}}}.

    Sources are processed sequentially. The covariance for source ``n`` is therefore
    built from outputs that already include the updates of sources ``0, ..., n-1``.

    Args:
        separated (numpy.ndarray): Estimated spectrograms to be updated.
            The shape is (n_sources, n_bins, n_frames).
        weight (numpy.ndarray): Weights for estimated spectrogram.
            The shape is (n_sources, n_bins, n_frames), or any shape that broadcasts
            to it, e.g. (n_sources, 1, n_frames).
        flooring_fn (callable, optional): A flooring function for numerical stability.
            This function is expected to return the same shape tensor as the input.
            If you explicitly set ``flooring_fn=None``,
            the identity function (``lambda x: x``) is used.
            Default: ``functools.partial(max_flooring, eps=EPS)``.
        overwrite (bool): Overwrite ``separated`` if ``overwrite=True``. Default: ``True``.

    Returns:
        numpy.ndarray: Updated spectrograms of shape (n_sources, n_bins, n_frames).
        When ``overwrite=True`` this is ``separated`` itself, modified in place.
    """
    if flooring_fn is None:
        # Identity function: no flooring.
        flooring_fn = lambda x: x  # noqa: E731

    Y = separated if overwrite else separated.copy()
    varphi = weight
    n_sources, n_bins, _ = Y.shape

    for src_idx in range(n_sources):
        # Outer products y y^H for every (bin, frame):
        # (n_sources, n_sources, n_bins, n_frames). Recomputed on every iteration
        # because the rows of Y updated earlier in this loop have changed.
        YY_conj = Y[:, np.newaxis, :, :] * Y[np.newaxis, :, :, :].conj()
        # Weighted time average, then move bins first: (n_bins, n_sources, n_sources).
        U_tilde_n = np.mean(varphi[src_idx] * YY_conj, axis=-1).transpose(2, 0, 1)

        # Solve U_tilde_n @ eta_n = e_n in every bin. e_n is passed as an explicit
        # column (n_bins, n_sources, 1). NumPy >= 2.0 treats a 2-D right-hand side
        # as a matrix rather than a stack of vectors, which would fail here.
        e_n = np.zeros((n_bins, n_sources, 1))
        e_n[:, src_idx, 0] = 1
        eta_n = np.linalg.solve(U_tilde_n, e_n)[..., 0]  # (n_bins, n_sources)

        # eta_nn = e_n^T U^{-1} e_n is real and positive when U_tilde_n is Hermitian
        # positive definite. Drop the imaginary round-off and clip at 0 before sqrt.
        eta_nn = np.sqrt(np.maximum(np.real(eta_n[:, src_idx]), 0))
        eta_nn = flooring_fn(eta_nn)
        eta_n = eta_n / eta_nn[:, np.newaxis]

        # y_n <- sum_m conj(eta_nm) * y_m in each bin; (n_sources, n_bins, 1)
        # broadcasts over frames.
        eta_n_conj = eta_n.transpose(1, 0).conj()[:, :, np.newaxis]
        Y[src_idx] = np.sum(eta_n_conj * Y, axis=0)

    return Y

In [29]:
class AuxIVA(AuxIVABase):
    """AuxIVA that reports its iterations with a notebook progress bar.

    A fresh ``tqdm`` bar is created lazily on the first ``update_once`` of each call.
    It is closed when the call finishes, including on error. Calling the same
    instance several times therefore shows one complete bar per run instead of
    reusing an exhausted one.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Created lazily in update_once, once the iteration count is known.
        self.progress_bar = None

    def __call__(self, *args, n_iter: int = 100, **kwargs):
        # update_once reads this to size the progress bar.
        self.n_iter = n_iter
        # Discard any bar left over from a previous call so this run gets its own.
        self.progress_bar = None

        try:
            return super().__call__(*args, n_iter=n_iter, **kwargs)
        finally:
            # Finalize the widget; otherwise it stays open and the next call would reuse it.
            if self.progress_bar is not None:
                self.progress_bar.close()
                self.progress_bar = None

    def update_once(self) -> None:
        if self.progress_bar is None:
            # getattr tolerates direct update_once() calls that bypass __call__.
            self.progress_bar = tqdm(total=getattr(self, "n_iter", None))

        super().update_once()

        self.progress_bar.update(1)

In [30]:
class AuxIVAFastIP1(AuxIVA):
    """AuxIVA variant whose spatial update is fast iterative projection ("FastIP1").

    Each iteration rewrites ``self.output`` directly via ``update_by_fast_ip1``.
    ``demix_filter`` is set to ``None`` in ``_reset`` and is not updated by this class.
    """

    def __init__(
        self,
        contrast_fn: Callable[[np.ndarray], np.ndarray] = None,
        d_contrast_fn: Callable[[np.ndarray], np.ndarray] = None,
        flooring_fn: Optional[Callable[[np.ndarray], np.ndarray]] = functools.partial(
            max_flooring, eps=EPS
        ),
        pair_selector: Optional[Callable[[int], Iterable[Tuple[int, int]]]] = None,
        callbacks: Optional[
            Union[Callable[["AuxIVA"], None], List[Callable[["AuxIVA"], None]]]
        ] = None,
        scale_restoration: Union[bool, str] = True,
        record_loss: bool = True,
        reference_id: int = 0,
    ) -> None:
        # super(AuxIVA, self) skips AuxIVA.__init__ and calls the next class in the MRO
        # (AuxIVABase) directly. ``spatial_algorithm`` is not forwarded, so the base
        # constructor uses whatever default it defines; "FastIP1" is assigned below.
        # NOTE(review): presumably this avoids the base class rejecting an unknown
        # spatial_algorithm name — confirm against ssspy.
        # NOTE(review): ``pair_selector`` is accepted but neither forwarded nor stored,
        # so it has no effect on this class.
        super(AuxIVA, self).__init__(
            contrast_fn=contrast_fn,
            d_contrast_fn=d_contrast_fn,
            flooring_fn=flooring_fn,
            callbacks=callbacks,
            scale_restoration=scale_restoration,
            record_loss=record_loss,
            reference_id=reference_id,
        )

        self.spatial_algorithm = "FastIP1"
        # Set here because AuxIVA.__init__, which would set it, is skipped above.
        self.progress_bar = None

    def _reset(self, **kwargs) -> None:
        """Reset base-class state, then drop the demixing filter.

        FastIP1 only updates ``self.output``, so no demixing matrix is maintained.
        """
        super()._reset(**kwargs)

        self.demix_filter = None

    def update_once(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        """Run one FastIP1 sweep and advance the progress bar.

        Args:
            flooring_fn: A flooring function, or ``"self"`` to use the instance's own;
                resolved with ``choose_flooring_fn``.
        """
        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        # Bar is sized by self.n_iter, which AuxIVA.__call__ (inherited) sets before iterating.
        if self.progress_bar is None:
            self.progress_bar = tqdm(total=self.n_iter)

        self.update_once_fast_ip1(flooring_fn=flooring_fn)

        self.progress_bar.update(1)

    def update_once_fast_ip1(
        self,
        flooring_fn: Optional[Union[str, Callable[[np.ndarray], np.ndarray]]] = "self",
    ) -> None:
        """Compute auxiliary weights from the current output and apply FastIP1 to it.

        Args:
            flooring_fn: A flooring function, or ``"self"`` to use the instance's own.
        """
        flooring_fn = choose_flooring_fn(flooring_fn, method=self)

        # Output is presumably (n_sources, n_bins, n_frames) — the layout
        # update_by_fast_ip1 documents — so r is the per-(source, frame) norm over bins.
        Y = self.output
        r = np.linalg.norm(Y, axis=1)
        # Auxiliary weight varphi = G'(r) / (2 r), with the denominator floored.
        denom = flooring_fn(2 * r)
        varphi = self.d_contrast_fn(r) / denom

        # Insert a bins axis so the weights broadcast as (n_sources, 1, n_frames).
        self.output = update_by_fast_ip1(Y, varphi[:, np.newaxis, :], flooring_fn=flooring_fn)

In [31]:
# Per-channel STFT of the mixture; the frequency and time axes returned by stft are discarded.
_, _, spectrogram_mix = ss.stft(
    waveform_mix,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [32]:
# Baseline: AuxIVA with the library's IP1 spatial update and the custom contrast.
iva_ip1 = AuxIVA(
    spatial_algorithm="IP1",
    contrast_fn=contrast_fn,
    d_contrast_fn=d_contrast_fn,
)
print(iva_ip1)

AuxIVA(spatial_algorithm=IP1, scale_restoration=True, record_loss=True, reference_id=0)


In [33]:
# Run 50 iterations of the IP1 model on the mixture STFT.
# NOTE: spectrogram_est is overwritten later by the FastIP1 run.
spectrogram_est = iva_ip1(spectrogram_mix, n_iter=50)

  0%|          | 0/50 [00:00<?, ?it/s]

In [34]:
# Back to the time domain with the same window/segment settings as the forward STFT.
_, waveform_est = ss.istft(
    spectrogram_est,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [35]:
# Play back each source estimated by the IP1 model (labels are 1-based).
for idx, waveform in enumerate(waveform_est):
    print(f"Estimated source: {idx + 1}")
    display(ipd.Audio(waveform, rate=sample_rate))
    print()

Estimated source: 1



Estimated source: 2



Estimated source: 3





In [36]:
# Same contrast as the baseline, but with the FastIP1 spatial update defined above.
iva_fast_ip1 = AuxIVAFastIP1(
    contrast_fn=contrast_fn,
    d_contrast_fn=d_contrast_fn,
)
print(iva_fast_ip1)

AuxIVA(spatial_algorithm=FastIP1, scale_restoration=True, record_loss=True, reference_id=0)


In [37]:
# Run the FastIP1 model for the same 50 iterations as the IP1 baseline.
# NOTE: this overwrites the IP1 result held in spectrogram_est.
spectrogram_est = iva_fast_ip1(spectrogram_mix, n_iter=50)

  0%|          | 0/50 [00:00<?, ?it/s]

In [38]:
# Inverse STFT of the FastIP1 estimates, mirroring the forward STFT settings.
_, waveform_est = ss.istft(
    spectrogram_est,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [39]:
# Play back each source estimated by the FastIP1 model (labels are 1-based).
for idx, waveform in enumerate(waveform_est):
    print(f"Estimated source: {idx + 1}")
    display(ipd.Audio(waveform, rate=sample_rate))
    print()

Estimated source: 1



Estimated source: 2



Estimated source: 3



