In [None]:
!pip install git+https://github.com/tky823/ssspy.git

In [None]:
import numpy as np
import scipy.signal as ss
import soundfile as sf
import matplotlib.pyplot as plt
import IPython.display as ipd

In [None]:
# Experiment configuration shared by the cells below.
n_sources = 2

# Each source is truncated to this many samples (10 s at a 16000 Hz rate).
max_samples = 10 * 16000

# STFT segment length and hop size, in samples.
n_fft = 4096
hop_length = 2048

# Locations of the .npz archives; the file name encodes the number of sources.
sisec2011_npz_path = "/content/SiSEC2011-{}ch.npz".format(n_sources)
mird_npz_path = "/content/MIRD-{}ch.npz".format(n_sources)

In [None]:
# Build the multichannel mixture: every source waveform is convolved with each
# impulse response stored under the same key (presumably the MIRD room impulse
# responses, one per channel), and the resulting source images are summed.
waveform_src_img = []

# The context managers close both .npz archives once all arrays have been read.
with np.load(sisec2011_npz_path) as sisec2011_npz, np.load(mird_npz_path) as mird_npz:
    for src_idx in range(n_sources):
        key = "src_{}".format(src_idx + 1)
        waveform_src = sisec2011_npz[key][:max_samples]
        n_samples = len(waveform_src)

        # FFT-based convolution matches np.convolve up to floating-point
        # rounding but is far cheaper for signals of this length. The tail of
        # the full convolution is dropped so the image keeps the source length.
        _waveform_src_img = [
            ss.fftconvolve(waveform_src, waveform_rir)[:n_samples]
            for waveform_rir in mird_npz[key]
        ]

        _waveform_src_img = np.stack(_waveform_src_img, axis=0)  # (n_channels, n_samples)
        waveform_src_img.append(_waveform_src_img)

waveform_src_img = np.stack(waveform_src_img, axis=1)  # (n_channels, n_sources, n_samples)
waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)

In [None]:
# Render an audio player for every channel of the mixture (numbered from 1).
for channel_number, channel_waveform in enumerate(waveform_mix, start=1):
    print("Mixture: {}".format(channel_number))
    display(ipd.Audio(channel_waveform, rate=16000))
    print()

In [None]:
from ssspy.transform import whiten
from ssspy.algorithm import projection_back
from ssspy.bss.fdica import GradFDICA

In [None]:
def contrast_fn(y):
    """Contrast function handed to ``GradFDICA``: twice the magnitude of ``y``.

    Args:
        y: Scalar or NumPy array (real or complex).

    Returns:
        ``2 * |y|`` evaluated element-wise.
    """
    magnitude = np.abs(y)
    return 2 * magnitude

def score_fn(y):
    """Score function handed to ``GradFDICA``: ``y`` over its floored magnitude.

    Args:
        y: Scalar or NumPy array (real or complex).

    Returns:
        ``y / max(|y|, 1e-10)`` evaluated element-wise.
    """
    # Flooring the magnitude keeps the division finite where y is exactly zero.
    floored_magnitude = np.maximum(np.abs(y), 1e-10)
    return y / floored_magnitude

## Holonomic type

In [None]:
# Gradient-based FDICA configured with is_holonomic=True.
# Projection back is disabled inside the model; a later cell calls
# projection_back explicitly with the unwhitened mixture as reference.
holonomic_options = dict(
    step_size=1e-1,
    contrast_fn=contrast_fn,
    score_fn=score_fn,
    is_holonomic=True,
    should_apply_projection_back=False,
)
grad_fdica = GradFDICA(**holonomic_options)
print(grad_fdica)

In [None]:
# STFT of the mixture with a Hann window of n_fft samples and a hop of
# hop_length samples; only the complex spectrogram is kept.
_, _, spectrogram_mix = ss.stft(
    waveform_mix,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [None]:
# Whiten the mixture, run 500 iterations of the model on the whitened
# spectrogram, then apply projection_back using the unwhitened mixture as the
# reference.
spectrogram_mix_whitened = whiten(spectrogram_mix)
spectrogram_sep = grad_fdica(spectrogram_mix_whitened, n_iter=500)
spectrogram_est = projection_back(spectrogram_sep, reference=spectrogram_mix)

In [None]:
# Inverse STFT with the same window, segment length and overlap as the forward
# transform; the returned time axis is discarded.
_, waveform_est = ss.istft(
    spectrogram_est,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [None]:
# Render an audio player for every estimated signal (numbered from 1).
for source_number, separated_waveform in enumerate(waveform_est, start=1):
    print("Estimated source: {}".format(source_number))
    display(ipd.Audio(separated_waveform, rate=16000))
    print()

In [None]:
# Plot the values stored in grad_fdica.loss, then release the figure.
fig, ax = plt.subplots()
ax.plot(grad_fdica.loss)
plt.show()
plt.close(fig)

## Non-holonomic type

In [None]:
# Gradient-based FDICA configured with is_holonomic=False; all other options
# match the holonomic run above. Projection back is disabled inside the model;
# a later cell calls projection_back explicitly.
nonholonomic_options = dict(
    step_size=1e-1,
    contrast_fn=contrast_fn,
    score_fn=score_fn,
    is_holonomic=False,
    should_apply_projection_back=False,
)
grad_fdica = GradFDICA(**nonholonomic_options)
print(grad_fdica)

In [None]:
# Recompute the mixture STFT (Hann window, n_fft-sample segments, hop of
# hop_length samples) so this section starts from a fresh spectrogram.
_, _, spectrogram_mix = ss.stft(
    waveform_mix,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [None]:
# Same pipeline as the holonomic run: whiten, 500 iterations of the model, then
# projection_back with the unwhitened mixture as the reference.
spectrogram_mix_whitened = whiten(spectrogram_mix)
spectrogram_sep = grad_fdica(spectrogram_mix_whitened, n_iter=500)
spectrogram_est = projection_back(spectrogram_sep, reference=spectrogram_mix)

In [None]:
# Back to the time domain using the forward transform's window, segment length
# and overlap; the returned time axis is discarded.
_, waveform_est = ss.istft(
    spectrogram_est,
    window="hann",
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
)

In [None]:
# Render an audio player for every estimated signal (numbered from 1).
for source_number, separated_waveform in enumerate(waveform_est, start=1):
    print("Estimated source: {}".format(source_number))
    display(ipd.Audio(separated_waveform, rate=16000))
    print()

In [None]:
# Plot the values stored in grad_fdica.loss for this run, then release the figure.
fig, ax = plt.subplots()
ax.plot(grad_fdica.loss)
plt.show()
plt.close(fig)