# Multichannel audio source separation by FastMNMF

In [None]:
%%shell
# Clone the `feature/mnmf` branch only when the repository is not already present.
# An unconditional `git clone` aborts with "destination path ... already exists"
# when this cell is re-run (e.g. "Run all" after a runtime restart), which would
# stop the notebook at its very first cell.
if [ ! -d audio_source_separation ]; then
    git clone -b feature/mnmf https://github.com/tky823/audio_source_separation.git
fi

In [None]:
# Move into the MNMF example directory of the cloned repository (assumes the clone lives under /content).
# The relative paths used by later cells ("../../../src", "../../../dataset/...") are resolved from here.
%cd "/content/audio_source_separation/egs/bss-example/mnmf"

In [None]:
import sys
# Add the repository's `src` directory to the module search path.
# The entry is relative, so it is resolved against the current working directory at import time.
sys.path.append("../../../src")

In [None]:
import numpy as np
import scipy.signal as ss
import soundfile as sf
import IPython.display as ipd
import matplotlib.pyplot as plt

In [None]:
from bss.mnmf import FastMultichannelISNMF as FastMNMF

In [None]:
# Render all Matplotlib figures at 200 dpi.
plt.rc('figure', dpi=200)

## 1\. Music source separation

### Data preparation for music source separation
We have already created multichannel mixtures using impulse responses from the [Multi-Channel Impulse Response Database](https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/multi-channel-impulse-response-database/).
You can find the original sources (piano & violin) and their mixture in `audio_source_separation/dataset/sample-song/`.

### Target sources

In [None]:
# Load the two reference (target) sources; `sf.read` returns the samples and the sampling rate.
source_piano, sr_piano = sf.read("../../../dataset/sample-song/sample-2_piano_16000.wav")
source_violin, sr_violin = sf.read("../../../dataset/sample-song/sample-2_violin_16000.wav")

# A single `sr` is used for both signals afterwards, so make the assumption explicit
# instead of silently keeping only the sampling rate of the file read last.
assert sr_piano == sr_violin, f"Sampling rates differ: piano={sr_piano}, violin={sr_violin}"
sr = sr_piano

In [None]:
# Audio players for the reference piano and violin sources, in that order.
display(
    ipd.Audio(source_piano, rate=sr),
    ipd.Audio(source_violin, rate=sr),
)

In [None]:
# Stack the reference sources row-wise (piano first, violin second).
# NOTE(review): `y` is reassigned further down with the output of the inverse STFT,
# so these reference signals are no longer reachable through `y` after separation.
y = np.vstack([source_piano, source_violin])

### Mixture

In [None]:
# Read the multichannel mixture; `sf.read` returns the samples together with the sampling rate.
mixture, sr = sf.read("../../../dataset/sample-song/sample-2_mixture_16000.wav")

# Transpose so that the first axis of `x` indexes channels and the second indexes time samples.
x = np.transpose(mixture)
n_channels, T = x.shape

# The number of sources to estimate is set equal to the number of channels.
n_sources = n_channels

In [None]:
# One audio player per channel of the mixture (rows of `x`).
for channel_signal in x:
    display(ipd.Audio(channel_signal, rate=sr))

Configuration of STFT
- The reverberation time is $T_{60}=160$ [ms] in the impulse response.
- The window length is $4096$ samples (= $256$ [ms]).
- The hop length is half of the window length, i.e. $2048$ samples (= $128$ [ms]).

In [None]:
# STFT window length in samples.
fft_size = 4096
# Hop length is half of the window length (2048 samples).
hop_size = fft_size // 2

### Execution of MNMF

In [None]:
# STFT of the mixture along the time axis; consecutive frames overlap by
# `fft_size - hop_size` samples. The frequency and time vectors returned by
# SciPy are not needed, only the complex spectrogram `X`.
overlap = fft_size - hop_size
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=overlap)

In [None]:
# Seed NumPy's global random number generator before constructing the model.
# NOTE(review): presumably FastMNMF draws its initial parameters from this RNG, which would make runs repeatable — confirm.
np.random.seed(111)
# Build the FastMNMF separator with `n_bases=4`; all other options keep their defaults.
mnmf = FastMNMF(n_bases=4)

In [None]:
# Show the model's string representation.
print(mnmf)

In [None]:
# Apply the model to the mixture spectrogram with `iteration=50`; the result `Y` is fed to the inverse STFT below.
# NOTE(review): `Y` is presumably the separated spectrograms with sources on the first axis — confirm against `bss.mnmf`.
Y = mnmf(X, iteration=50)

In [None]:
# Inverse STFT with the same window and overlap settings as the forward transform;
# the returned time vector is discarded.
_, y_full = ss.istft(Y, nperseg=fft_size, noverlap=fft_size - hop_size)

# Keep at most the first `T` samples so the result is no longer than the input mixture.
y = y_full[:, :T]

### Separated sources

In [None]:
# One audio player per separated signal (rows of `y`).
for source_idx in range(n_sources):
    estimated_source = y[source_idx]
    display(ipd.Audio(estimated_source, rate=sr))

In [None]:
# Plot the values stored in `mnmf.loss` against their index, labelled as the iteration number.
fig, ax = plt.subplots()
ax.plot(mnmf.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()