# Multichannel audio source separation by ProxIVA

In [None]:
%%shell
# Clone only when the checkout is missing, so that re-running this cell
# (e.g. after a kernel restart) does not fail on an existing directory.
if [ ! -d audio_source_separation ]; then
    git clone https://github.com/tky823/audio_source_separation.git
fi

In [None]:
# Move into the IVA example directory; the relative paths used by later cells start from here.
%cd "/content/audio_source_separation/egs/bss-example/iva"

In [None]:
import sys
# Add the repository's `src` directory (relative to the current working directory) to the import path.
sys.path.append("../../../src")

In [None]:
import numpy as np
import scipy.signal as ss
import soundfile as sf
import IPython.display as ipd
import matplotlib.pyplot as plt

In [None]:
from bss.iva import ProxLaplaceIVA

In [None]:
# Render every figure of this notebook at 200 dpi.
plt.rcParams['figure.dpi'] = 200

## 1 Music source separation

### Data preparation for music source separation
We already created multichannel mixtures using the impulse responses of [Multi-Channel Impulse Response Database](https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/multi-channel-impulse-response-database/).
You can find the original sources (piano & violin) and their mixture in `audio_source_separation/dataset/sample-song/`.

### Target sources

In [None]:
# Read the two reference sources. `sr` is overwritten by the second call,
# so it ends up holding the sample rate reported for the violin file.
source_piano, sr = sf.read("../../../dataset/sample-song/sample-2_piano_16000.wav")
source_violin, sr = sf.read("../../../dataset/sample-song/sample-2_violin_16000.wav")

In [None]:
# One audio player per reference source, piano first.
display(
    ipd.Audio(source_piano, rate=sr),
    ipd.Audio(source_violin, rate=sr),
)

In [None]:
# Stack the reference sources row-wise: row 0 is the piano, row 1 the violin.
s = np.vstack([source_piano, source_violin])

### Mixture

In [None]:
mixture, sr = sf.read("../../../dataset/sample-song/sample-2_mixture_16000.wav")
# Transpose so that the first axis of `x` is the one unpacked as `n_channels` below
# (presumably sf.read returns (samples, channels) for a multichannel file).
x = mixture.T
n_channels, T = x.shape
# The number of sources is set equal to the number of observed channels.
n_sources = n_channels

In [None]:
# One audio player per observed channel of the mixture.
display(*(ipd.Audio(x[idx], rate=sr) for idx in range(n_channels)))

Configuration of STFT
- The reverberation time of the impulse response is $T_{60}=160$ [ms].
- The window length is $4096$ samples (= $256$ [ms]), i.e. longer than the reverberation time, so that the rank-1 constraint assumed by IVA holds.
- The hop length is half of the window length, i.e. $2048$ samples (= $128$ [ms]).

In [None]:
# STFT window length and hop length, both in samples.
fft_size, hop_size = 4096, 2048

### Execution of IVA

In [None]:
# Only the STFT matrix is kept; the frequency and time axes returned by scipy are discarded.
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
# Fix NumPy's global seed before building the model
# (presumably the separator draws from NumPy's global RNG — confirm).
np.random.seed(111)
iva = ProxLaplaceIVA(regularizer=1, step=1.75)

In [None]:
# Print the string representation of the configured separator.
print(iva)

In [None]:
# Apply the separator to the mixture spectrogram with `iteration=100`.
Y = iva(X, iteration=100)

In [None]:
# Inverse STFT with the same window/overlap as the forward transform.
_, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
# Keep only the first T samples so the output length matches the mixture.
y = y[:, :T]

### Separated sources

In [None]:
# One audio player per separated signal.
display(*(ipd.Audio(y[idx], rate=sr) for idx in range(n_sources)))

In [None]:
# Loss recorded by the separator, plotted against the iteration index.
fig, ax = plt.subplots()
ax.plot(iva.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()

## 2 Speech separation

### 2.1 Data preparation for speech separation
Create multichannel mixtures using the speech recordings of the [CMU ARCTIC database](http://www.festvox.org/cmu_arctic/) and the impulse responses of the [Multi-Channel Impulse Response Database](https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/multi-channel-impulse-response-database/).

In [None]:
%%shell
# Source the data preparation script of this example directory.
. ./prepare.sh

Configuration of STFT
- The reverberation time of the impulse response is $T_{60}=160$ [ms].
- The window length is $4096$ samples (= $256$ [ms]), i.e. longer than the reverberation time, so that the rank-1 constraint assumed by IVA holds.
- The hop length is half of the window length, i.e. $2048$ samples (= $128$ [ms]).

In [None]:
# STFT window length and hop length, both in samples (used by all speech experiments below).
fft_size, hop_size = 4096, 2048

### 2.2 2 speakers

In [None]:
# First channel: sum of the two convolved utterances read from the "mic3" files.
aew_mic3, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic3.wav")
axb_mic3, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic3.wav")
x_mic3 = aew_mic3 + axb_mic3

# Second channel: the same two speakers, read from the "mic4" files.
aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
x_mic4 = aew_mic4 + axb_mic4

# Stack the channels row-wise; the source count is set equal to the channel count.
x = np.vstack([x_mic3, x_mic4])
n_channels, T = x.shape
n_sources = n_channels

#### Target sources after convolution of impulse response

In [None]:
# Players for the convolved utterance of each speaker as read from the mic3 files.
display(
    ipd.Audio(aew_mic3, rate=sr),
    ipd.Audio(axb_mic3, rate=sr),
)

#### Mixture

In [None]:
# One audio player per observed channel of the mixture.
display(*(ipd.Audio(x[idx], rate=sr) for idx in range(n_channels)))

#### Execution of IVA

In [None]:
# Only the STFT matrix is kept; the frequency and time axes returned by scipy are discarded.
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
# Fix NumPy's global seed before building the model
# (presumably the separator draws from NumPy's global RNG — confirm).
np.random.seed(111)
iva = ProxLaplaceIVA(regularizer=1, step=1.75)

In [None]:
# Print the string representation of the configured separator.
print(iva)

In [None]:
# Apply the separator to the mixture spectrogram with `iteration=500`.
Y = iva(X, iteration=500)

In [None]:
# Inverse STFT with the same window/overlap as the forward transform.
_, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
# Keep only the first T samples so the output length matches the mixture.
y = y[:,:T]

#### Separated sources

In [None]:
# One audio player per separated signal.
display(*(ipd.Audio(y[idx], rate=sr) for idx in range(n_sources)))

In [None]:
# Loss recorded by the separator, plotted against the iteration index.
fig, ax = plt.subplots()
ax.plot(iva.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()

### 2.3 3 speakers

In [None]:
# First channel: sum of the three convolved utterances read from the "mic2" files.
aew_mic2, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic2.wav")
axb_mic2, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic2.wav")
bdl_mic2, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic2.wav")
x_mic2 = aew_mic2 + axb_mic2 + bdl_mic2

# Second channel: the same three speakers, read from the "mic4" files.
aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
bdl_mic4, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic4.wav")
x_mic4 = aew_mic4 + axb_mic4 + bdl_mic4

# Third channel: the same three speakers, read from the "mic5" files.
aew_mic5, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic5.wav")
axb_mic5, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic5.wav")
bdl_mic5, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic5.wav")
x_mic5 = aew_mic5 + axb_mic5 + bdl_mic5

# Stack the channels row-wise; the source count is set equal to the channel count.
x = np.vstack([x_mic2, x_mic4, x_mic5])
n_channels, T = x.shape
n_sources = n_channels

#### Target sources after convolution of impulse response

In [None]:
# Players for the convolved utterance of each speaker as read from the mic2 files.
display(
    ipd.Audio(aew_mic2, rate=sr),
    ipd.Audio(axb_mic2, rate=sr),
    ipd.Audio(bdl_mic2, rate=sr),
)

#### Mixture

In [None]:
# One audio player per observed channel of the mixture.
display(*(ipd.Audio(x[idx], rate=sr) for idx in range(n_channels)))

#### Execution of IVA

In [None]:
# Only the STFT matrix is kept; the frequency and time axes returned by scipy are discarded.
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
# Fix NumPy's global seed before building the model
# (presumably the separator draws from NumPy's global RNG — confirm).
np.random.seed(111)
iva = ProxLaplaceIVA(regularizer=1, step=1.75)

In [None]:
# Print the string representation of the configured separator.
print(iva)

In [None]:
# Apply the separator to the mixture spectrogram with `iteration=100`.
Y = iva(X, iteration=100)

In [None]:
# Inverse STFT with the same window/overlap as the forward transform.
_, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
# Keep only the first T samples so the output length matches the mixture.
y = y[:,:T]

#### Separated sources

In [None]:
# One audio player per separated signal.
display(*(ipd.Audio(y[idx], rate=sr) for idx in range(n_sources)))

In [None]:
# Loss recorded by the separator, plotted against the iteration index.
fig, ax = plt.subplots()
ax.plot(iva.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()

## 3 Comparison between AuxIVA and ProxIVA

In [None]:
%%shell
pip install mir_eval

In [None]:
from mir_eval.separation import bss_eval_sources

In [None]:
from bss.iva import AuxLaplaceIVA
from algorithm.projection_back import projection_back

In [None]:
# Same files as the two-speaker experiment of section 2.2.
# First channel: sum of the two convolved utterances read from the "mic3" files.
aew_mic3, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic3.wav")
axb_mic3, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic3.wav")
x_mic3 = aew_mic3 + axb_mic3

# Second channel: the same two speakers, read from the "mic4" files.
aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
x_mic4 = aew_mic4 + axb_mic4

# Stack the channels row-wise; the source count is set equal to the channel count.
x = np.vstack([x_mic3, x_mic4])
n_channels, T = x.shape
n_sources = n_channels

In [None]:
# Evaluation references: the convolved utterance of each speaker as read from the mic3 files.
s = np.vstack([aew_mic3, axb_mic3])
# `fft_size` and `hop_size` are the values set earlier in the notebook.
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
def record_sdri(model):
    """Append the mean SDR improvement of the current estimate to ``model.sdri``.

    Reads ``model.reference_id``, ``model.target`` (time domain) and
    ``model.input`` / ``model.estimation`` (time-frequency domain).
    The SDR of the unprocessed reference channel is computed on the first
    call only and cached on the model as ``model.sdr_input``.
    Relies on the notebook-level ``fft_size`` and ``hop_size``.
    """
    ref = model.reference_id
    target = model.target
    spec_in, spec_est = model.input, model.estimation
    n_src, n_samples = target.shape
    istft_kwargs = dict(nperseg=fft_size, noverlap=fft_size-hop_size)

    # Rescale the estimate against the reference channel before going back to the time domain.
    gains = projection_back(spec_est, reference=spec_in[ref])
    spec_est = spec_est * gains[..., np.newaxis]  # (n_sources, n_bins, n_frames)
    estimate = ss.istft(spec_est, **istft_kwargs)[1][:, :n_samples]

    if not hasattr(model, 'sdr_input'):
        # Baseline: the reference channel of the mixture, repeated once per source.
        baseline = ss.istft(spec_in, **istft_kwargs)[1][ref, :n_samples]
        baseline = np.tile(baseline, reps=(n_src, 1))
        model.sdr_input = bss_eval_sources(target, estimated_sources=baseline)[0]

    sdr_estimated = bss_eval_sources(target, estimated_sources=estimate)[0]
    improvement = sdr_estimated - model.sdr_input

    model.sdri.append(improvement.mean())

In [None]:
np.random.seed(111)
# `record_sdri` is registered as the callback. `target` and `sdri` are presumably stored
# on the model as the attributes that `record_sdri` reads (`model.target`, `model.sdri`).
aux_iva = AuxLaplaceIVA(callback=record_sdri)
Y = aux_iva(X, iteration=50, target=s, sdri=[])

In [None]:
np.random.seed(111)
# Register the same callback as the AuxIVA run. `target` and `sdri` are only
# consumed by `record_sdri`, so without it no SDR improvement is recorded for
# ProxIVA and the two methods cannot be compared on that metric.
prox_iva = ProxLaplaceIVA(regularizer=1, step=1.75, callback=record_sdri)
Y = prox_iva(X, iteration=50, target=s, sdri=[])

In [None]:
# Loss curves of both methods on shared axes.
fig, ax = plt.subplots()
ax.plot(aux_iva.loss, color='black', label='AuxIVA')
ax.plot(prox_iva.loss, color='mediumblue', label='ProxIVA')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.legend()
plt.show()