# IS-MNMFによる多チャネル音源分離

In [None]:
%%shell
git clone https://github.com/tky823/audio_source_separation.git

In [None]:
%cd "/content/audio_source_separation/egs/bss-example/mnmf"

In [None]:
import sys
sys.path.append("../../../src")

In [None]:
import numpy as np
import scipy.signal as ss
import soundfile as sf
import IPython.display as ipd
import matplotlib.pyplot as plt

In [None]:
from bss.mnmf import MultichannelISNMF as MNMF

In [None]:
plt.rcParams['figure.dpi'] = 200

## 1\. 楽音分離

### 楽音分離のためのデータ準備
[Multi-Channel Impulse Response Database](https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/multi-channel-impulse-response-database/)のインパルス応答により作成した多チャネルの混合音を用いる．
音源（ピアノとベース）および混合音は`audio_source_separation/dataset/sample-song/`で確認できる．

### 目的音源

In [None]:
source_piano, sr = sf.read("../../../dataset/sample-song/sample-3_piano_16000.wav")
source_bass, sr = sf.read("../../../dataset/sample-song/sample-3_bass_16000.wav")

In [None]:
display(ipd.Audio(source_piano, rate=sr))
display(ipd.Audio(source_bass, rate=sr))

In [None]:
# Clean reference sources stacked row-wise (assumes both files are mono and of equal length).
# Stored as `s` instead of `y` so that the separated output assigned to `y` later in this
# section does not silently overwrite the references.
s = np.vstack([source_piano, source_bass])

### 混合音

In [None]:
mixture, sr = sf.read("../../../dataset/sample-song/sample-3_mixture_16000.wav")
x = mixture.T
n_channels, T = x.shape
n_sources = n_channels

In [None]:
for idx in range(n_channels):
    display(ipd.Audio(x[idx], rate=sr))

窓長などについて
- $T_{60}=160$ [ms]の残響のインパルス応答を使用する．
- フーリエ変換の窓長は，$4096$サンプル（$=256$ [ms]）としている．
- シフト長は，窓長の半分の$2048$サンプルとしている．

In [None]:
fft_size = 4096  # STFT window length in samples (256 ms at 16 kHz)
hop_size = fft_size // 2  # shift by half a window (50% overlap)

### MNMFの実行

In [None]:
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
np.random.seed(111)
mnmf = MNMF(n_basis=20, normalize=False)

In [None]:
print(mnmf)

In [None]:
Y = mnmf(X, iteration=200)

In [None]:
_, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
y = y[:, :T]

### 分離音

In [None]:
for idx in range(n_sources):
    display(ipd.Audio(y[idx], rate=sr))

In [None]:
# Loss of the MNMF updates, one value per recorded iteration.
fig, ax = plt.subplots()
ax.plot(mnmf.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()

## 2\. 話者分離

### 2.1 話者分離のためのデータの準備
[CMU ARCTICデータベース](http://www.festvox.org/cmu_arctic/)の音声，および[Multi-Channel Impulse Response Database](https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/multi-channel-impulse-response-database/)のインパルス応答を用いて，多チャネルの混合音をシミュレーションする．

In [None]:
%%shell
. ./prepare.sh

窓長などについて
- $T_{60}=360$ [ms]の残響のインパルス応答を使用する．
- 空間相関行列をフルランクと仮定しているため，窓長が残響時間より短くても残響の影響を表現できる．そこで，フーリエ変換の窓長は$4096$サンプル（$=256$ [ms] $< T_{60}$）としている．
- シフト長は，窓長の半分の$2048$サンプルとしている．

In [None]:
fft_size = 4096  # STFT window length in samples (256 ms at 16 kHz)
hop_size = fft_size // 2  # shift by half a window (50% overlap)

### 2.2 2話者分離

In [None]:
# Two-speaker mixture observed at mic3 and mic4: at each microphone, add the two
# speakers' convolved recordings (file names indicate aew at 60 deg, axb at 300 deg).
# `sr` is overwritten by every read; all files are presumably 16 kHz (per their names).
aew_mic3, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic3.wav")
axb_mic3, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic3.wav")
x_mic3 = aew_mic3 + axb_mic3

aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
x_mic4 = aew_mic4 + axb_mic4

# One row per microphone -> (n_channels, T); assumes mono files of equal length.
x = np.vstack([x_mic3, x_mic4])
n_channels, T = x.shape
# Determined case: as many sources as microphones.
n_sources = n_channels

#### インパルス応答畳み込み後の音

In [None]:
display(ipd.Audio(aew_mic3, rate=sr))
display(ipd.Audio(axb_mic3, rate=sr))

#### 混合音

In [None]:
for idx in range(n_channels):
    display(ipd.Audio(x[idx], rate=sr))

#### MNMFの実行

In [None]:
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
np.random.seed(111)
mnmf = MNMF(n_basis=20)

In [None]:
print(mnmf)

In [None]:
Y = mnmf(X, iteration=200)

In [None]:
_, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
y = y[:,:T]

#### 分離音

In [None]:
for idx in range(n_sources):
    display(ipd.Audio(y[idx], rate=sr))

In [None]:
# Loss of the MNMF updates, one value per recorded iteration.
fig, ax = plt.subplots()
ax.plot(mnmf.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()

### 2.3 3話者分離（3マイク）

In [None]:
# Three-speaker mixture observed at mic2, mic4 and mic5: at each microphone, add the
# three speakers' convolved recordings (file names indicate aew at 60 deg, axb at
# 300 deg, bdl at 330 deg). `sr` is overwritten by every read; all files are
# presumably 16 kHz (per their names).
aew_mic2, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic2.wav")
axb_mic2, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic2.wav")
bdl_mic2, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic2.wav")
x_mic2 = aew_mic2 + axb_mic2 + bdl_mic2

aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
bdl_mic4, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic4.wav")
x_mic4 = aew_mic4 + axb_mic4 + bdl_mic4

aew_mic5, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic5.wav")
axb_mic5, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic5.wav")
bdl_mic5, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic5.wav")
x_mic5 = aew_mic5 + axb_mic5 + bdl_mic5

# One row per microphone -> (n_channels, T); assumes mono files of equal length.
x = np.vstack([x_mic2, x_mic4, x_mic5])
n_channels, T = x.shape
# Determined case: three sources, three microphones.
n_sources = n_channels

#### インパルス応答畳み込み後の音

In [None]:
display(ipd.Audio(aew_mic2, rate=sr))
display(ipd.Audio(axb_mic2, rate=sr))
display(ipd.Audio(bdl_mic2, rate=sr))

#### 混合音

In [None]:
for idx in range(n_channels):
    display(ipd.Audio(x[idx], rate=sr))

#### MNMFの実行

In [None]:
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
np.random.seed(111)
mnmf = MNMF(n_basis=20)

In [None]:
print(mnmf)

In [None]:
Y = mnmf(X, iteration=200)

In [None]:
_, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
y = y[:,:T]

#### 分離音

In [None]:
for idx in range(n_sources):
    display(ipd.Audio(y[idx], rate=sr))

In [None]:
# Loss of the MNMF updates, one value per recorded iteration.
fig, ax = plt.subplots()
ax.plot(mnmf.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()

### 2.4 3話者分離（2マイク）

In [None]:
# Three-speaker mixture observed at only two microphones (mic2 and mic5): at each
# microphone, add the three speakers' convolved recordings (file names indicate aew at
# 60 deg, axb at 300 deg, bdl at 330 deg). `sr` is overwritten by every read; all
# files are presumably 16 kHz (per their names).
aew_mic2, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic2.wav")
axb_mic2, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic2.wav")
bdl_mic2, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic2.wav")
x_mic2 = aew_mic2 + axb_mic2 + bdl_mic2

aew_mic5, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic5.wav")
axb_mic5, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic5.wav")
bdl_mic5, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic5.wav")
x_mic5 = aew_mic5 + axb_mic5 + bdl_mic5

# One row per microphone -> (n_channels, T) with n_channels == 2; assumes mono files of equal length.
x = np.vstack([x_mic2, x_mic5])
n_channels, T = x.shape
# Underdetermined case: three speakers but only two microphones, so the number of
# sources cannot be taken from n_channels and is set explicitly.
n_sources = 3

#### インパルス応答畳み込み後の音

In [None]:
display(ipd.Audio(aew_mic2, rate=sr))
display(ipd.Audio(axb_mic2, rate=sr))
display(ipd.Audio(bdl_mic2, rate=sr))

#### 混合音

In [None]:
for idx in range(n_channels):
    display(ipd.Audio(x[idx], rate=sr))

#### MNMFの実行

In [None]:
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
np.random.seed(111)
mnmf = MNMF(n_basis=20, n_sources=n_sources)

In [None]:
print(mnmf)

In [None]:
Y = mnmf(X, iteration=200)

In [None]:
_, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
y = y[:,:T]

#### 分離音

In [None]:
for idx in range(n_sources):
    display(ipd.Audio(y[idx], rate=sr))

In [None]:
# Loss of the MNMF updates, one value per recorded iteration.
fig, ax = plt.subplots()
ax.plot(mnmf.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()

## 3\. コールバック関数の例

In [None]:
# Same two-speaker mixture as in section 2.2 (mic3 and mic4), rebuilt so this section
# runs on its own. File names indicate aew at 60 deg and axb at 300 deg; `sr` is
# overwritten by every read and is presumably 16 kHz for all files.
aew_mic3, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic3.wav")
axb_mic3, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic3.wav")
x_mic3 = aew_mic3 + axb_mic3

aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
x_mic4 = aew_mic4 + axb_mic4

# One row per microphone -> (n_channels, T); assumes mono files of equal length.
x = np.vstack([x_mic3, x_mic4])
n_channels, T = x.shape
# Determined case: as many sources as microphones.
n_sources = n_channels

In [None]:
# Reference signals for SDR evaluation: each speaker's convolved image at mic3, which is
# the first row of `x`. NOTE(review): this presumably matches the model's default
# reference_id (0) — confirm in bss.mnmf.
s = np.vstack([aew_mic3, axb_mic3])
# STFT of the mixture; record_sdri inverts STFTs with the same fft_size / hop_size.
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

#### 3.1 SDR改善量の記録

In [None]:
%%shell
pip install mir_eval

In [None]:
from mir_eval.separation import bss_eval_sources

In [None]:
def record_sdri(model):
    """Append the mean SDR improvement (SDRi) of the current estimate to ``model.sdri``.

    Meant to be passed as ``callbacks=record_sdri`` when constructing ``MNMF``; the
    model object is the only argument.

    Attributes read from ``model``:
        reference_id: row index into the inverse STFT of ``model.input`` that is used
            as the unprocessed baseline signal.
        target: reference signals in the time domain, unpacked as ``(n_sources, T)``.
        input: mixture in the time-frequency domain.
        estimation: current separated signals in the time-frequency domain.

    Attributes written to ``model``:
        sdri: list that receives ``mean(SDR(estimate) - SDR(mixture))``; created if it
            does not exist yet.
        sdr_input: SDR of the reference-channel mixture against ``target`` (one value
            per source). It does not depend on the iteration, so it is cached. The cache
            is invalidated when the model is rerun with a different input or target.

    The inverse STFTs use the notebook-level ``fft_size`` and ``hop_size``, which must
    match the forward STFT that produced ``model.input``.
    """
    reference_id = model.reference_id
    s = model.target  # Time domain
    X, Y = model.input, model.estimation  # Time-frequency domain
    n_sources, T = s.shape

    _, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
    y = y[:, :T]

    # The baseline depends only on (input, target). Remember which objects it was
    # computed for, so that reusing the same model on new data does not silently
    # subtract a stale baseline from a previous run.
    cached_for = getattr(model, '_sdr_input_for', None)
    cache_is_valid = (
        hasattr(model, 'sdr_input')
        and cached_for is not None
        and cached_for[0] is X
        and cached_for[1] is s
    )

    if cache_is_valid:
        sdr_input = model.sdr_input
    else:
        _, x = ss.istft(X, nperseg=fft_size, noverlap=fft_size-hop_size)
        x = x[reference_id, :T]
        # Use the same reference-channel mixture as the "estimate" of every source.
        x = np.tile(x, reps=(n_sources, 1))
        sdr_input, _, _, _ = bss_eval_sources(s, estimated_sources=x)
        model.sdr_input = sdr_input
        model._sdr_input_for = (X, s)

    sdr_estimated, _, _, _ = bss_eval_sources(s, estimated_sources=y)
    sdri = sdr_estimated - sdr_input

    # Tolerate calls where ``sdri=[]`` was not passed to the model.
    if not hasattr(model, 'sdri'):
        model.sdri = []
    model.sdri.append(sdri.mean())

In [None]:
np.random.seed(111)
mnmf = MNMF(n_basis=20, callbacks=record_sdri)

In [None]:
Y = mnmf(X, iteration=200, target=s, sdri=[])

In [None]:
# Mean SDR improvement recorded by the callback, one value per callback invocation.
fig, ax = plt.subplots()
ax.plot(mnmf.sdri, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('SDR improvement')
plt.show()