# Multichannel audio source separation by ILRMA

In [None]:
%%shell
git clone https://github.com/tky823/audio_source_separation.git
pip install soundfile

In [None]:
%cd "/content/audio_source_separation/egs/bss-example/ilrma"

## Data preparation
Create multichannel mixtures using audio from the [CMU ARCTIC database](http://www.festvox.org/cmu_arctic/) and impulse responses from the [Multi-Channel Impulse Response Database](https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/multi-channel-impulse-response-database/).

In [None]:
%%shell
. ./prepare.sh

In [None]:
import sys
sys.path.append("../../../src")

In [None]:
import numpy as np
import scipy.signal as ss
import soundfile as sf
import IPython.display as ipd
import matplotlib.pyplot as plt

In [None]:
from bss.ilrma import GaussILRMA

In [None]:
plt.rcParams['figure.dpi'] = 200

Configuration of STFT
- The reverberation time of the impulse response is $T_{60}=160$ [ms].
- The window length is $4096$ samples (= $256$ [ms]). It is chosen longer than the reverberation time because ILRMA relies on the rank-1 spatial model, which assumes the impulse response fits within a single STFT window.
- The hop length is half of the window length, i.e. $2048$ samples (= $128$ [ms]).

In [None]:
fft_size, hop_size = 4096, 2048

## 2 speakers

In [None]:
# Speakers "aew" and "axb" after convolution with the room impulse response,
# as observed at microphone 3. Summing them gives the mixture at that microphone.
# NOTE(review): `sr` is overwritten by every read; all files presumably share
# the same sampling rate (the file names suggest 16 kHz) -- confirm.
aew_mic3, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic3.wav")
axb_mic3, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic3.wav")
x_mic3 = aew_mic3 + axb_mic3

# Same two speakers as observed at microphone 4.
aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
x_mic4 = aew_mic4 + axb_mic4

# Stack the microphone signals row-wise: one row per microphone.
# The row count is stored as n_sources (two microphones for two speakers here)
# and T is the length used later to trim the reconstructed signals.
x = np.vstack([x_mic3, x_mic4])
n_sources, T = x.shape

### Target sources after convolution of impulse response

In [None]:
ipd.Audio(aew_mic3, rate=sr)

In [None]:
ipd.Audio(axb_mic3, rate=sr)

### Mixture

In [None]:
ipd.Audio(x[0], rate=sr)

In [None]:
ipd.Audio(x[1], rate=sr)

### Execution of ILRMA

In [None]:
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
np.random.seed(111)
ilrma = GaussILRMA(n_bases=2, normalize='projection-back')

In [None]:
Y = ilrma(X, iteration=100)

In [None]:
# Back to the time domain with the same window length and overlap as the forward STFT.
_, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
# Keep only the first T samples so the output has the same length as the input mixture
# (the reconstruction can be longer, presumably because of STFT padding).
y = y[:,:T]

### Separated sources

In [None]:
ipd.Audio(y[0], rate=sr)

In [None]:
ipd.Audio(y[1], rate=sr)

In [None]:
# Plot the values stored in ilrma.loss against their index
# (presumably one loss value per ILRMA iteration -- confirm against GaussILRMA).
plt.figure()
plt.plot(ilrma.loss, color='black')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()

## 3 speakers

In [None]:
# Speakers "aew", "axb" and "bdl" after convolution with the room impulse
# response, as observed at microphone 2. Their sum is the mixture at that microphone.
# NOTE(review): `sr` is overwritten by every read; all files presumably share
# the same sampling rate -- confirm.
aew_mic2, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic2.wav")
axb_mic2, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic2.wav")
bdl_mic2, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic2.wav")
x_mic2 = aew_mic2 + axb_mic2 + bdl_mic2

# Same three speakers as observed at microphone 4.
aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
bdl_mic4, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic4.wav")
x_mic4 = aew_mic4 + axb_mic4 + bdl_mic4

# Same three speakers as observed at microphone 5.
aew_mic5, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic5.wav")
axb_mic5, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic5.wav")
bdl_mic5, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic5.wav")
x_mic5 = aew_mic5 + axb_mic5 + bdl_mic5

# Stack the microphone signals row-wise: one row per microphone
# (three microphones for three speakers). T is used later to trim the output.
x = np.vstack([x_mic2, x_mic4, x_mic5])
n_sources, T = x.shape

### Target sources after convolution of impulse response

In [None]:
ipd.Audio(aew_mic2, rate=sr)

In [None]:
ipd.Audio(axb_mic2, rate=sr)

In [None]:
ipd.Audio(bdl_mic2, rate=sr)

### Mixture

In [None]:
ipd.Audio(x[0], rate=sr)

In [None]:
ipd.Audio(x[1], rate=sr)

In [None]:
ipd.Audio(x[2], rate=sr)

### Execution of ILRMA

In [None]:
# `noverlap` is the number of samples shared by consecutive windows, i.e. the
# window length minus the hop -- not the hop itself. The two only coincide
# while hop_size is exactly half of fft_size, so spell out the overlap
# explicitly, as the other STFT calls in this notebook do.
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

In [None]:
np.random.seed(111)
ilrma = GaussILRMA(n_bases=2, normalize='projection-back')

In [None]:
Y = ilrma(X, iteration=100)

In [None]:
# `noverlap` is the number of overlapping samples (window length minus hop),
# not the hop itself; the two are equal only when the hop is half the window.
_, y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size-hop_size)
# Keep only the first T samples so the output has the same length as the input mixture.
y = y[:,:T]

### Separated sources

In [None]:
ipd.Audio(y[0], rate=sr)

In [None]:
ipd.Audio(y[1], rate=sr)

In [None]:
ipd.Audio(y[2], rate=sr)

In [None]:
# Plot the values stored in ilrma.loss against their index
# (presumably one loss value per ILRMA iteration -- confirm against GaussILRMA).
plt.figure()
plt.plot(ilrma.loss, color='black')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()

## Example of Callback Function

In [None]:
# Rebuild the two-speaker, two-microphone mixture (microphones 3 and 4),
# because the three-speaker section above overwrote x, n_sources and T.
aew_mic3, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic3.wav")
axb_mic3, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic3.wav")
x_mic3 = aew_mic3 + axb_mic3

aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
x_mic4 = aew_mic4 + axb_mic4

# One row per microphone; T is the signal length used later for trimming.
x = np.vstack([x_mic3, x_mic4])
n_sources, T = x.shape

In [None]:
# Reference signals for evaluation: the two reverberant sources as observed
# at microphone 3, one per row.
s = np.vstack([aew_mic3, axb_mic3])
# Time-frequency representation of the mixture that is fed to ILRMA.
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

### Recording the SDR improvement

In [None]:
%%shell
pip install mir_eval

In [None]:
from mir_eval.separation import bss_eval_sources

from algorithm.projection_back import projection_back

In [None]:
def record_sdri(model):
    """Callback that appends the mean SDR improvement of `model` to `model.sdri`.

    The current estimate is rescaled by projection back onto the reference
    channel of the mixture, converted to the time domain and scored with
    `bss_eval_sources` against `model.target`. The SDR of the unprocessed
    reference channel is computed on the first call only and cached on the
    model as `model.sdr_input`.

    Reads the notebook-level `fft_size` and `hop_size` for the inverse STFT.

    Args:
        model: Object exposing `reference_id`, `target` (time domain),
            `input` and `estimation` (time-frequency domain) and a list
            `sdri` to append to.
    """
    ref_idx = model.reference_id
    target = model.target  # Time domain
    mixture_spec, estimate_spec = model.input, model.estimation  # Time-frequency domain
    n_sources, n_samples = target.shape
    n_overlap = fft_size - hop_size

    # Fix the scale ambiguity of the estimate before evaluating it.
    scale = projection_back(estimate_spec, reference=mixture_spec[ref_idx])
    estimate_spec = estimate_spec * scale[..., np.newaxis]  # (n_sources, n_bins, n_frames)
    _, estimate = ss.istft(estimate_spec, nperseg=fft_size, noverlap=n_overlap)
    estimate = estimate[:, :n_samples]

    if not hasattr(model, 'sdr_input'):
        # Baseline: the reference-channel mixture, repeated once per source,
        # stands in for an "unseparated" estimate.
        _, mixture = ss.istft(mixture_spec, nperseg=fft_size, noverlap=n_overlap)
        mixture = np.tile(mixture[ref_idx, :n_samples], reps=(n_sources, 1))
        model.sdr_input = bss_eval_sources(target, estimated_sources=mixture)[0]
    baseline_sdr = model.sdr_input

    estimated_sdr = bss_eval_sources(target, estimated_sources=estimate)[0]
    improvement = estimated_sdr - baseline_sdr

    model.sdri.append(improvement.mean())

In [None]:
# Fix NumPy's global seed so the run is repeatable
# (presumably GaussILRMA initializes its parameters from this RNG -- confirm).
np.random.seed(111)
# record_sdri is passed as the callback; how often it is invoked is decided by
# GaussILRMA (presumably once per iteration -- confirm).
ilrma = GaussILRMA(n_bases=2, normalize='projection-back', callback=record_sdri)

In [None]:
# NOTE(review): the extra keyword arguments are presumably stored as attributes
# on the model, since record_sdri reads model.target and appends to model.sdri
# -- confirm against GaussILRMA.
Y = ilrma(X, iteration=100, target=s, sdri=[])

In [None]:
# Plot the mean SDR improvement values that record_sdri appended to ilrma.sdri,
# in the order they were recorded.
plt.figure()
plt.plot(ilrma.sdri, color='black')
plt.xlabel('Iteration')
plt.ylabel('SDR improvement')
plt.show()

### Saving the bases and activations

In [None]:
import os

In [None]:
class BaseActivationSaver:
    """Callable that dumps a model's `base` and `activation` arrays to disk.

    Every call writes `<base_dir>/<k>.npz`, where `k` counts the calls starting
    from 1, with the arrays stored under the keys `base` and `activation`.

    Args:
        base_dir: Directory receiving the `.npz` files; created if missing.
    """

    def __init__(self, base_dir='tmp'):
        os.makedirs(base_dir, exist_ok=True)

        self.base_dir = base_dir
        self.iteration = 0  # number of snapshots written so far

    def __call__(self, model):
        """Write one snapshot of `model.base` and `model.activation`."""
        step = self.iteration + 1
        destination = os.path.join(self.base_dir, f"{step}.npz")
        np.savez(destination, base=model.base, activation=model.activation)

        # Advance the counter only after the file has been written.
        self.iteration = step

In [None]:
saver = BaseActivationSaver()

In [None]:
np.random.seed(111)
ilrma = GaussILRMA(n_bases=2, normalize='projection-back', callback=saver)

In [None]:
Y = ilrma(X, iteration=100)

In [None]:
# Side-by-side comparison of the snapshots saved at iteration 1 (left) and
# iteration 100 (right) for the source with index 1.
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(5, 2))

# Snapshot written by the saver callback on its first call.
npz = np.load("tmp/{}.npz".format(1))
base, activation = npz['base'], npz['activation']
# Product of bases and activations on a dB scale; 1e-12 guards against log10(0).
TV = 10 * np.log10(base @ activation + 1e-12)

n_sources, n_bins, n_frames = TV.shape
# Frame times advance by one hop. Bin k of the one-sided STFT computed with a
# window of fft_size samples lies at k * sr / fft_size, so the axis ends at
# sr / 2. Dividing by n_bins (= fft_size / 2 + 1) instead would stretch the
# axis to almost the full sampling rate.
t, f = np.arange(n_frames) * (hop_size / sr), np.arange(n_bins) * (sr / fft_size)

axes[0].pcolormesh(t, f, TV[1])

# Snapshot written on the 100th (final) call.
npz = np.load("tmp/{}.npz".format(100))
base, activation = npz['base'], npz['activation']
TV = 10 * np.log10(base @ activation + 1e-12)
axes[1].pcolormesh(t, f, TV[1])

plt.show()