In [None]:
!pip install git+https://github.com/tky823/ssspy.git

In [None]:
import numpy as np
import scipy.signal as ss
import matplotlib.pyplot as plt
import IPython.display as ipd

In [None]:
from ssspy.utils.dataset import download_sample_speech_data

In [None]:
# Settings forwarded to download_sample_speech_data.
n_sources = 2
max_duration = 10  # presumably seconds -- confirm against ssspy docs
sisec2010_tag = "dev1_female3"

# STFT settings: n_fft is used as the segment length (nperseg) and the
# overlap between segments is derived as n_fft - hop_length.
n_fft = 4096
hop_length = 2048

In [None]:
# Fetch the sample speech data using the settings chosen above.
# waveform_src_img has shape (n_channels, n_sources, n_samples).
waveform_src_img, sample_rate = download_sample_speech_data(
    sisec2010_tag=sisec2010_tag,
    n_sources=n_sources,
    max_duration=max_duration,
    conv=True,
)

# Add the source images together along the source axis to form the
# multichannel mixture with shape (n_channels, n_samples).
waveform_mix = waveform_src_img.sum(axis=1)

In [None]:
# Show one audio player per channel of the mixture (numbered from 1).
for channel_number, mixture_channel in enumerate(waveform_mix, start=1):
    print("Mixture: {}".format(channel_number))
    display(ipd.Audio(mixture_channel, rate=sample_rate))
    print()

In [None]:
from ssspy.transform import whiten
from ssspy.algorithm import projection_back
from ssspy.bss.admmbss import ADMMBSS
from ssspy.linalg import prox

In [None]:
def l21_fn(y: np.ndarray) -> float:
    """Mixed L21 norm.

    The L2 norm is taken along axis 1 (frequency bins) for every
    (source, frame) pair, and the resulting norms are summed.

    Args:
        y (np.ndarray):
            Input array with shape of (n_sources, n_bins, n_frames).

    Returns:
        Scalar value of the mixed L21 norm (a NumPy floating scalar).
    """
    # One L2 norm per (source, frame): shape (n_sources, n_frames).
    group_norm = np.linalg.norm(y, axis=1)

    # Summing over both remaining axes collapses the result to a scalar.
    return np.sum(group_norm, axis=(0, 1))

def prox_l21(y: np.ndarray, step_size: float = 1) -> np.ndarray:
    """Apply proximal operator of mixed L21 norm.

    Each group is the vector along axis 1 (frequency bins) for a fixed
    source and frame. A group whose L2 norm is at most ``step_size`` is set
    to zero; every other group is scaled by ``1 - step_size / norm``.

    Args:
        y (np.ndarray):
            Input array with shape of (n_sources, n_bins, n_frames).
        step_size (float):
            Non-negative step size parameter (shrinkage threshold).

    Returns:
        np.ndarray:
            Output of the proximal operator of the mixed L21 norm,
            with shape of (n_sources, n_bins, n_frames).
    """
    norm = np.linalg.norm(y, axis=1, keepdims=True)

    # Groups at or below the threshold are zeroed out entirely.
    zeroed = norm <= step_size

    # Replace the norm of zeroed groups by 1 before dividing. This avoids
    # 0 / 0 -> NaN (and the RuntimeWarning) for all-zero groups, including
    # the step_size == 0 case that a plain clamp to step_size does not cover.
    safe_norm = np.where(zeroed, 1, norm)
    scale = np.where(zeroed, 0, 1 - step_size / safe_norm)

    return scale * y

In [None]:
# Build the ADMM-based separator, using the mixed L21 norm (l21_fn) as the
# penalty and prox_l21 as the matching proximal operator.
# NOTE(review): rho and relaxation are presumably the ADMM penalty and
# over-relaxation parameters -- confirm against the ssspy documentation.
admm_bss = ADMMBSS(
    rho=0.5,
    relaxation=1.75,
    penalty_fn=l21_fn,
    prox_penalty=prox_l21,
    # Scale restoration is switched off here; projection_back is applied
    # explicitly to the separated spectrogram in a later cell.
    scale_restoration=False,
)
# Print the configured separator object.
print(admm_bss)

In [None]:
# STFT of the mixture. The first two return values of ss.stft (frequency
# and time axes) are discarded; only the spectrogram is kept.
_, _, spectrogram_mix = ss.stft(
    waveform_mix,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [None]:
# Pre-process the mixture spectrogram before separation.
# NOTE(review): the exact behavior of whiten and normalize_by_spectral_norm is
# inferred from their names -- confirm against the ssspy documentation.
spectrogram_mix_whitened = whiten(spectrogram_mix)
spectrogram_mix_normalized = admm_bss.normalize_by_spectral_norm(spectrogram_mix_whitened)
# Run the separator on the pre-processed mixture with n_iter=1000.
spectrogram_est = admm_bss(spectrogram_mix_normalized, n_iter=1000)
# The separator was built with scale_restoration=False, so projection back is
# applied here, using the original (un-whitened) mixture as the reference.
spectrogram_est = projection_back(spectrogram_est, reference=spectrogram_mix)

In [None]:
# Inverse STFT with the same window and overlap settings as the forward
# transform. The first return value of ss.istft (time axis) is discarded.
_, waveform_est = ss.istft(
    spectrogram_est,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [None]:
# Show one audio player per estimated source (numbered from 1).
for source_number, estimated_source in enumerate(waveform_est, start=1):
    print("Estimated source: {}".format(source_number))
    display(ipd.Audio(estimated_source, rate=sample_rate))
    print()

In [None]:
# Convergence curve of the loss values recorded by the separator.
# The first recorded entry is skipped, as in the original plot.
loss_history = admm_bss.loss[1:]

# Start x at 1 so that each point sits at its index into admm_bss.loss.
# NOTE(review): assumes admm_bss.loss[0] is the value before the first
# update, so that this index equals the iteration count -- confirm.
iterations = range(1, len(loss_history) + 1)

fig, ax = plt.subplots()
ax.plot(iterations, loss_history)
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
plt.show()
plt.close(fig)