In [None]:
!pip install git+https://github.com/tky823/ssspy.git

In [None]:
import numpy as np
import scipy.signal as ss
import matplotlib.pyplot as plt
import IPython.display as ipd

In [None]:
from ssspy.utils.dataset import download_sample_speech_data

In [None]:
# Dataset selection passed to the sample-data download helper.
sisec2010_tag = "dev1_female3"
n_sources = 2

# Sampling rate in Hz; also used as the playback rate for the audio widgets.
sample_rate = 16000

# Upper bound on the number of samples, i.e. 10 seconds at `sample_rate`.
max_samples = sample_rate * 10

In [None]:
# Fetch the source images used for the demo.
# Result layout: (n_channels, n_sources, n_samples).
# NOTE(review): conv=False presumably requests a non-convolutive (instantaneous)
# mixture -- confirm against the ssspy documentation.
waveform_src_img = download_sample_speech_data(
    n_sources=n_sources, sisec2010_tag=sisec2010_tag, max_samples=max_samples, conv=False
)

# Collapse the source axis to obtain the observed mixture per channel.
# Result layout: (n_channels, n_samples).
waveform_mix = waveform_src_img.sum(axis=1)

In [None]:
# Render an audio player for every channel of the mixture (numbered from 1).
for channel_no, channel_wave in enumerate(waveform_mix, start=1):
    print("Mixture: {}".format(channel_no))
    ipd.display(ipd.Audio(channel_wave, rate=sample_rate))
    print()

In [None]:
from ssspy.transform import whiten
from ssspy.bss.ica import FastICA

In [None]:
def contrast_fn(x):
    """Contrast function ``log(1 + exp(x))`` (softplus), evaluated element-wise.

    The value is computed as ``logaddexp(0, x) = log(exp(0) + exp(x))``, which
    is mathematically identical to the naive expression but does not overflow
    to ``inf`` when ``x`` is large and positive.

    Args:
        x: Scalar or NumPy array of real values.

    Returns:
        ``log(1 + exp(x))`` with the same shape as ``x``.
    """
    return np.logaddexp(0, x)

def score_fn(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, evaluated element-wise.

    Written as ``exp(-log(1 + exp(-x)))`` with the inner term computed by
    ``np.logaddexp`` so that large negative ``x`` does not trigger a
    floating-point overflow in ``exp(-x)``.

    Args:
        x: Scalar or NumPy array of real values.

    Returns:
        Sigmoid of ``x`` with the same shape as ``x``; values lie in [0, 1].
    """
    return np.exp(-np.logaddexp(0, -x))

def d_score_fn(x):
    """Derivative of the logistic sigmoid, ``sigma(x) * (1 - sigma(x))``.

    Uses the identity ``1 - sigma(x) = sigma(-x)`` and evaluates the product
    in log space: ``exp(-(log(1 + exp(-x)) + log(1 + exp(x))))``. This avoids
    the overflow of ``exp(-x)`` for large negative ``x`` and the cancellation
    in ``1 - sigma`` when ``sigma`` is close to 1.

    Args:
        x: Scalar or NumPy array of real values.

    Returns:
        Element-wise derivative with the same shape as ``x``; values lie in
        [0, 0.25].
    """
    log_sigma_pos = -np.logaddexp(0, -x)  # log(sigma(x))
    log_sigma_neg = -np.logaddexp(0, x)  # log(sigma(-x)) = log(1 - sigma(x))
    return np.exp(log_sigma_pos + log_sigma_neg)

In [None]:
# Build the FastICA separator from the contrast function, its score function,
# and the derivative of the score function defined in this notebook.
ica = FastICA(contrast_fn=contrast_fn, score_fn=score_fn, d_score_fn=d_score_fn)

# Show the textual summary of the configured separator.
print(ica)

In [None]:
# Whiten the mixture first; the whitened signal is what gets passed to FastICA.
waveform_mix_whitened = whiten(waveform_mix)
# Run the separator with n_iter=10; the result is iterated over per estimated
# source in the playback cell, and `ica.loss` is plotted afterwards.
waveform_est = ica(waveform_mix_whitened, n_iter=10)

In [None]:
# Render an audio player for every separated signal (numbered from 1).
for source_no, source_wave in enumerate(waveform_est, start=1):
    print("Estimated source: {}".format(source_no))
    ipd.display(ipd.Audio(source_wave, rate=sample_rate))
    print()

In [None]:
# Plot the loss values recorded on the FastICA object during separation.
# A title and axis labels are set so the figure is readable on its own.
plt.figure()
plt.plot(ica.loss)
plt.title("FastICA loss")
plt.xlabel("Iteration")  # NOTE(review): assumes one loss entry per iteration
plt.ylabel("Loss")
plt.show()
plt.close()