In [None]:
!python -m pip install --upgrade setuptools
!pip install git+https://github.com/tky823/ssspy.git@feature/cacgmm

In [2]:
import numpy as np
import scipy.signal as ss
import matplotlib.pyplot as plt
import IPython.display as ipd
from tqdm.notebook import tqdm

In [3]:
from ssspy.utils.dataset import download_sample_speech_data

In [4]:
# Experiment configuration shared by the cells below.
n_sources = 2  # passed to download_sample_speech_data
max_duration = 10  # passed as max_duration to download_sample_speech_data (presumably seconds — confirm)
sisec2010_tag = "dev1_female3"  # SiSEC 2010 sample identifier
n_fft = 4096  # STFT segment length (used as nperseg)
hop_length = 2048  # STFT hop; the STFT cell uses noverlap = n_fft - hop_length

In [5]:
# Fetch the per-source images (conv=True) and build the observed mixture from them.
waveform_src_img, sample_rate = download_sample_speech_data(
    sisec2010_tag=sisec2010_tag,
    n_sources=n_sources,
    max_duration=max_duration,
    conv=True,
)  # (n_channels, n_sources, n_samples)

# Summing the source images over the source axis yields one mixture per channel.
waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)

In [6]:
# Render one audio player per mixture channel, labelled with a 1-based index.
for channel_number, channel_waveform in enumerate(waveform_mix, start=1):
    print("Mixture: {}".format(channel_number))
    display(ipd.Audio(channel_waveform, rate=sample_rate))
    print()

Mixture: 1



Mixture: 2





In [7]:
from ssspy.special import logsumexp, softmax
from ssspy.linalg import quadratic
from ssspy.bss.cacgmm import CACGMM as CACGMMbase
from ssspy.bss._psd import to_psd

In [16]:
class CACGMM(CACGMMbase):
    """CACGMM with optional random covariance initialization, trace normalization,
    PSD projection after every parameter update, and a tqdm progress bar.

    Args:
        normalization (bool): If ``True``, rescale every covariance matrix to unit
            (real) trace after initialization and after every iteration.
            Default: ``True``.
        covariance_randomization (bool): If ``True``, replace the base-class initial
            covariance with random diagonal matrices whose diagonals sum to 1.
            Default: ``False``.
        *args, **kwargs: Forwarded to ``CACGMMbase``. Because ``normalization`` and
            ``covariance_randomization`` come first positionally, base-class
            arguments (e.g. ``n_sources``, ``rng``) must be passed by keyword.
    """

    def __init__(self, normalization: bool = True, covariance_randomization: bool = False, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.normalization = normalization
        self.covariance_randomization = covariance_randomization
        # tqdm bar for the __call__ in progress; None when no call is running.
        self.progress_bar = None

    def __call__(
        self, input: np.ndarray, n_iter: int = 100, initial_call: bool = True, **kwargs
    ) -> np.ndarray:
        """Run the base-class separation on ``input`` while showing a progress bar.

        Args:
            input: Mixture passed through to ``CACGMMbase.__call__``.
            n_iter: Number of iterations; also the progress bar total.
            initial_call: Forwarded to ``CACGMMbase.__call__``.
            **kwargs: Forwarded to ``CACGMMbase.__call__``.

        Returns:
            Whatever ``CACGMMbase.__call__`` returns.
        """
        self.n_iter = n_iter
        # Create a fresh bar for every call. A bar created lazily only once would be
        # reused (with a stale total) by later calls on the same instance.
        self.progress_bar = tqdm(total=n_iter)

        try:
            return super().__call__(input, n_iter=n_iter, initial_call=initial_call, **kwargs)
        finally:
            # Close the bar even when an iteration raises (e.g. LinAlgError).
            self.progress_bar.close()
            self.progress_bar = None

    def _init_parameters(self, rng: np.random.Generator = None) -> None:
        """Initialize parameters via the base class, then optionally randomize and/or
        normalize the covariance.

        Args:
            rng: Generator used for the random covariance. Falls back to
                ``self.rng`` when ``None``.
        """
        super()._init_parameters(rng=rng)

        if self.covariance_randomization:
            if rng is None:
                rng = self.rng

            B = self.covariance
            n_sources, n_bins, n_channels = self.n_sources, self.n_bins, self.n_channels

            # Random diagonal entries, rescaled so each matrix has unit trace.
            B_diag = rng.random((n_sources, n_bins, n_channels))
            B_diag = B_diag / B_diag.sum(axis=-1, keepdims=True)

            # Build with the base covariance's dtype so the replacement matches it.
            eye = np.eye(n_channels, dtype=B.dtype)
            self.covariance = B_diag[:, :, :, np.newaxis] * eye

        if self.normalization:
            self.normalize_covariance()

    def update_once(self) -> None:
        """Run one base-class iteration, re-normalize if enabled, and tick the bar."""
        super().update_once()

        if self.normalization:
            self.normalize_covariance()

        # The bar only exists while __call__ is running.
        if self.progress_bar is not None:
            self.progress_bar.update(1)

    def update_parameters(self) -> None:
        """Update parameters via the base class, then pass the covariance through ``to_psd``."""
        super().update_parameters()

        # NOTE(review): to_psd presumably projects onto the PSD cone, flooring with
        # flooring_fn — confirm in ssspy.bss._psd.
        self.covariance = to_psd(self.covariance, flooring_fn=self.flooring_fn)

    def normalize_covariance(self) -> None:
        """Divide each covariance matrix (over the last two axes) by the real part of its trace."""
        B = self.covariance

        trace = np.real(np.trace(B, axis1=-2, axis2=-1))
        self.covariance = B / trace[..., np.newaxis, np.newaxis]

In [17]:
# scipy.signal.stft returns (frequencies, segment times, STFT); keep only the complex STFT.
_, _, spectrogram_mix = ss.stft(
    waveform_mix,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

## Initialize covariance with the identity matrix

In [18]:
# Baseline: base-class initial covariance, no randomization, no trace normalization.
cacgmm = CACGMM(
    rng=np.random.default_rng(42),
    covariance_randomization=False,
    normalization=False,
)
print(cacgmm)

CACGMM(n_sources=None, record_loss=True, reference_id=0)


In [19]:
# Run 200 iterations of the baseline model.
spectrogram_est = cacgmm(
    spectrogram_mix,
    n_iter=200,
)

  0%|          | 0/200 [00:00<?, ?it/s]

## Initialize covariance randomly without normalization

In the recorded run, this configuration raised `LinAlgError`.

In [20]:
# Random diagonal initial covariance, without trace normalization.
cacgmm = CACGMM(
    rng=np.random.default_rng(42),
    covariance_randomization=True,
    normalization=False,
)
print(cacgmm)

CACGMM(n_sources=None, record_loss=True, reference_id=0)


In [21]:
# The recorded run of this configuration raised LinAlgError. Catch and report it, so the
# failure is shown as this experiment's result and "Run All" reaches the next section.
# On failure, `spectrogram_est` keeps whatever value it had before this cell.
try:
    spectrogram_est = cacgmm(spectrogram_mix, n_iter=200)
except np.linalg.LinAlgError as err:
    print("LinAlgError: {}".format(err))

  0%|          | 0/200 [00:00<?, ?it/s]

LinAlgError: ignored

## Initialize covariance randomly with normalization

In [22]:
# Random diagonal initial covariance, with trace normalization at init and every iteration.
cacgmm = CACGMM(
    rng=np.random.default_rng(42),
    covariance_randomization=True,
    normalization=True,
)
print(cacgmm)

CACGMM(n_sources=None, record_loss=True, reference_id=0)


In [23]:
# Run 200 iterations of the normalized, randomly-initialized model.
spectrogram_est = cacgmm(
    spectrogram_mix,
    n_iter=200,
)

  0%|          | 0/200 [00:00<?, ?it/s]