# Gauss-IDLMAによる多チャネル音源分離
**注意**: このサンプルではDNNを学習していないため，分離の品質が低い．

In [None]:
%%shell
# Fetch the repository that provides the `sss` package and this example's scripts,
# and install soundfile, which is used below to read the WAV recordings.
git clone https://github.com/tky823/audio_source_separation.git
pip install soundfile

In [None]:
# Work from the IDLMA example directory: the relative paths used below (./prepare.sh, ./src, ./data/...) resolve against it.
%cd "/content/audio_source_separation/egs/sss-example/idlma"

## データの準備
[CMU ARCTICデータベース](http://www.festvox.org/cmu_arctic/)の音声，および[Multi-Channel Impulse Response Database](https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/multi-channel-impulse-response-database/)のインパルス応答を用いて，多チャネルの混合音をシミュレーションする．

In [None]:
%%shell
# Run the example's data-preparation script in the current shell.
# NOTE(review): presumably this downloads CMU ARCTIC / the MIRD impulse responses and writes the
# convolved recordings read below from ./data — confirm against prepare.sh.
. ./prepare.sh

In [None]:
import sys
# Make the repository-wide sources and this example's own sources importable, in that order.
sys.path.extend(["../../../src", "./src"])

In [None]:
import numpy as np
import scipy.signal as ss
import soundfile as sf
import IPython.display as ipd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [None]:
from sss.idlma import GaussIDLMA

In [None]:
# Render figures at 200 dpi (sharper plots than matplotlib's default).
plt.rcParams.update({'figure.dpi': 200})

窓長などについて
- $T_{60}=160$ [ms]の残響のインパルス応答を使用する．
- 空間モデルがランク$1$である（各周波数ビンで瞬時混合とみなせる）という仮定を満たすため，フーリエ変換の窓長は残響時間より長い$4096$サンプル（$=256$ [ms]，サンプリング周波数$16$ kHz）としている．
- シフト長は，窓長の半分の$2048$サンプルとしている．

In [None]:
# STFT window length (4096 samples = 256 ms at 16 kHz) and hop length (half of the window).
fft_size = 4096
hop_size = fft_size // 2

## DNN
`MLP` は`(batch_size, n_bins)`のサイズの入力から単一の音源を推定する．

`MLPforEstimation`は`(n_sources, n_bins, n_frames)`のサイズの入力に対し，音源ごとに独立な`MLP`をフレーム単位で適用して各音源を推定し，同じサイズの出力を返す．

In [None]:
class MLP(nn.Module):
    """Frame-wise multilayer perceptron mapping an ``n_bins`` vector to an ``n_bins`` vector.

    Every ``Linear`` layer (including the last one) is followed by ``ReLU``, so the
    output is non-negative.

    Args:
        n_bins (int): Number of frequency bins (input and output feature size).
        hidden_channels (int): Width of the hidden layers.
        num_layers (int): Number of ``Linear`` layers. Must be at least 1.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    def __init__(self, n_bins, hidden_channels=1024, num_layers=5):
        super().__init__()

        if num_layers < 1:
            raise ValueError("num_layers must be at least 1, but given {}.".format(num_layers))

        net = []
        for n in range(num_layers):
            # Pick both ends independently: the first layer reads n_bins features and the
            # last layer emits n_bins features. With num_layers == 1 the single layer is
            # therefore n_bins -> n_bins instead of n_bins -> hidden_channels.
            in_features = n_bins if n == 0 else hidden_channels
            out_features = n_bins if n == num_layers - 1 else hidden_channels
            net.append(nn.Linear(in_features, out_features))
            net.append(nn.ReLU())
        self.net = nn.Sequential(*net)

    def forward(self, input):
        """
        Args:
            input (batch_size, n_bins)
        Returns:
            output (batch_size, n_bins)
        """
        output = self.net(input)

        return output

In [None]:
class MLPforEstimation(nn.Module):
    """Applies one independent ``MLP`` per source to a stack of per-source inputs.

    Args:
        n_bins (int): Number of frequency bins.
        hidden_channels (int): Hidden width of every per-source ``MLP``.
        num_layers (int): Number of layers of every per-source ``MLP``.
        n_sources (int): Number of sources, i.e. how many ``MLP`` instances are built. Required.

    Raises:
        ValueError: If ``n_sources`` is not given.
    """
    def __init__(self, n_bins, hidden_channels=2049, num_layers=5, n_sources=None):
        super().__init__()

        if n_sources is None:
            raise ValueError("Specify number of sources.")

        self.net = torch.nn.ModuleList([
            MLP(n_bins, hidden_channels=hidden_channels, num_layers=num_layers)
            for _ in range(n_sources)
        ])

    def forward(self, input):
        """
        Args:
            input (n_sources, n_bins, n_frames)
        Returns:
            output (n_sources, n_bins, n_frames)
        """
        estimates = []
        for source_idx, source_input in enumerate(input):
            # The MLP works frame-wise: put frames on the batch axis, then swap back.
            frames_first = source_input.transpose(0, 1)
            estimates.append(self.net[source_idx](frames_first).transpose(0, 1))

        return torch.stack(estimates, dim=0)

このサンプルでは学習していないDNNを用いているが，実際にはDNNを事前に学習させる必要がある．

## 2音源分離

In [None]:
# Reverberant images of speakers aew (deg60) and axb (deg300) at microphones 3 and 4.
aew_mic3, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic3.wav")
aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic3, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic3.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")

# Each microphone observes the sum of both speakers' images.
x_mic3, x_mic4 = aew_mic3 + axb_mic3, aew_mic4 + axb_mic4

# One row per microphone channel.
x = np.vstack([x_mic3, x_mic4])
n_sources, T = x.shape  # 2 channels for 2 sources, so the row count doubles as the source count

### インパルス応答畳み込み後の音

In [None]:
# Each speaker's reverberant image as captured at microphone 3.
display(
    ipd.Audio(aew_mic3, rate=sr),
    ipd.Audio(axb_mic3, rate=sr),
)

### 混合音

In [None]:
# One audio player per microphone channel of the mixture; iterating over x keeps this
# in sync with the number of channels instead of hard-coding it.
for x_channel in x:
    display(ipd.Audio(x_channel, rate=sr))

In [None]:
# Multichannel STFT along the time axis; keep only the complex spectrogram.
X = ss.stft(x, nperseg=fft_size, noverlap=fft_size - hop_size)[2]

### IDLMAの実行

In [None]:
# Seed torch so that the (untrained) DNN weights are reproducible.
torch.manual_seed(111)
dnn = MLPforEstimation(n_bins=fft_size//2+1, num_layers=2, n_sources=2)
# Module.cuda() moves the parameters in place and returns the module itself.
dnn = dnn.cuda() if torch.cuda.is_available() else dnn

In [None]:
# Seed NumPy's global RNG first (presumably consumed by GaussIDLMA's initialization).
np.random.seed(seed=111)
idlma = GaussIDLMA(normalize='projection-back')

In [None]:
# 100 iterations of Gauss-IDLMA with the DNN; Y is the separated STFT (inverted below).
Y = idlma(X, iteration=100, dnn=dnn)

In [None]:
# Back to the time domain, trimmed to the original signal length T.
y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size - hop_size)[1][:, :T]

### 分離音

In [None]:
# One audio player per separated signal; iterating over y avoids hard-coding the source count.
for y_source in y:
    display(ipd.Audio(y_source, rate=sr))

In [None]:
# Loss values stored by the separator, one per iteration.
fig, ax = plt.subplots()
ax.plot(idlma.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()

## 3音源分離

In [None]:
# Reverberant images of speakers aew (deg60), axb (deg300) and bdl (deg330) at microphones 2, 4 and 5.
aew_mic2, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic2.wav")
aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
aew_mic5, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic5.wav")

axb_mic2, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic2.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")
axb_mic5, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic5.wav")

bdl_mic2, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic2.wav")
bdl_mic4, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic4.wav")
bdl_mic5, sr = sf.read("./data/cmu_us_bdl_arctic/trimmed/convolved-16000_deg330-mic5.wav")

# Each microphone observes the sum of all three speakers' images.
x_mic2 = aew_mic2 + axb_mic2 + bdl_mic2
x_mic4 = aew_mic4 + axb_mic4 + bdl_mic4
x_mic5 = aew_mic5 + axb_mic5 + bdl_mic5

# One row per microphone channel.
x = np.vstack([x_mic2, x_mic4, x_mic5])
n_sources, T = x.shape  # 3 channels for 3 sources, so the row count doubles as the source count

### インパルス応答畳み込み後の音

In [None]:
# Each speaker's reverberant image as captured at microphone 2.
display(
    ipd.Audio(aew_mic2, rate=sr),
    ipd.Audio(axb_mic2, rate=sr),
    ipd.Audio(bdl_mic2, rate=sr),
)

### 混合音

In [None]:
# One audio player per microphone channel of the mixture; iterating over x keeps this
# in sync with the number of channels instead of hard-coding it.
for x_channel in x:
    display(ipd.Audio(x_channel, rate=sr))

In [None]:
# The overlap must be fft_size - hop_size so that this STFT matches the inverse STFT used
# below (and the other sections). Passing hop_size only worked because hop_size == fft_size // 2.
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=fft_size-hop_size)

### IDLMAの実行

In [None]:
# Seed torch so that the (untrained) DNN weights are reproducible.
torch.manual_seed(111)
dnn = MLPforEstimation(n_bins=fft_size//2+1, num_layers=2, n_sources=3)
# Module.cuda() moves the parameters in place and returns the module itself.
dnn = dnn.cuda() if torch.cuda.is_available() else dnn

In [None]:
# Seed NumPy's global RNG first (presumably consumed by GaussIDLMA's initialization).
np.random.seed(seed=111)
idlma = GaussIDLMA(normalize='projection-back')

In [None]:
# 100 iterations of Gauss-IDLMA with the DNN; Y is the separated STFT (inverted below).
Y = idlma(X, iteration=100, dnn=dnn)

In [None]:
# Back to the time domain, trimmed to the original signal length T.
y = ss.istft(Y, nperseg=fft_size, noverlap=fft_size - hop_size)[1][:, :T]

### 分離音

In [None]:
# One audio player per separated signal; iterating over y avoids hard-coding the source count.
for y_source in y:
    display(ipd.Audio(y_source, rate=sr))

In [None]:
# Loss values stored by the separator, one per iteration.
fig, ax = plt.subplots()
ax.plot(idlma.loss, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
plt.show()

## コールバック関数の例

In [None]:
# Same two-speaker setup as above: aew (deg60) and axb (deg300) at microphones 3 and 4.
aew_mic3, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic3.wav")
aew_mic4, sr = sf.read("./data/cmu_us_aew_arctic/trimmed/convolved-16000_deg60-mic4.wav")
axb_mic3, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic3.wav")
axb_mic4, sr = sf.read("./data/cmu_us_axb_arctic/trimmed/convolved-16000_deg300-mic4.wav")

# Each microphone observes the sum of both speakers' images.
x_mic3, x_mic4 = aew_mic3 + axb_mic3, aew_mic4 + axb_mic4

# One row per microphone channel.
x = np.vstack([x_mic3, x_mic4])
n_sources, T = x.shape  # 2 channels for 2 sources, so the row count doubles as the source count

In [None]:
# Reference source images at microphone 3 (passed as `target` for the SDR callback) and the mixture STFT.
s = np.vstack([aew_mic3, axb_mic3])
X = ss.stft(x, nperseg=fft_size, noverlap=fft_size - hop_size)[2]

### SDR改善量の記録

In [None]:
%%shell
# mir_eval provides bss_eval_sources, which is used below to compute SDR.
pip install mir_eval

In [None]:
from mir_eval.separation import bss_eval_sources

from algorithm.projection_back import projection_back

In [None]:
def record_sdri(model):
    """Callback that logs the mean SDR improvement of the current separation result.

    Reads ``model.reference_id`` (reference channel index), ``model.target``
    (time-domain reference sources, one row per source), ``model.input`` (mixture STFT)
    and ``model.estimation`` (current separated STFT). Appends the mean over sources of
    ``SDR(estimate) - SDR(unprocessed reference channel)`` to ``model.sdri``.

    The baseline SDR of the unprocessed reference channel is computed on the first call
    and cached on the model as ``model.sdr_input``.
    """
    ref = model.reference_id
    target = model.target  # time domain, (n_sources, T)
    mixture_spec, estimate_spec = model.input, model.estimation  # time-frequency domain
    num_sources, length = target.shape

    # Projection-back against the reference channel, then back to the time domain.
    weights = projection_back(estimate_spec, reference=mixture_spec[ref])
    estimate_spec = estimate_spec * weights[..., np.newaxis]  # (n_sources, n_bins, n_frames)
    estimate = ss.istft(estimate_spec, nperseg=fft_size, noverlap=fft_size-hop_size)[1][:, :length]

    if not hasattr(model, 'sdr_input'):
        # Baseline: the unprocessed reference channel used as the estimate of every source.
        mixture = ss.istft(mixture_spec, nperseg=fft_size, noverlap=fft_size-hop_size)[1]
        baseline = np.tile(mixture[ref, :length], reps=(num_sources, 1))
        model.sdr_input = bss_eval_sources(target, estimated_sources=baseline)[0]

    sdr_estimated = bss_eval_sources(target, estimated_sources=estimate)[0]
    model.sdri.append((sdr_estimated - model.sdr_input).mean())

In [None]:
# Seed torch so that the (untrained) DNN weights are reproducible.
torch.manual_seed(111)
dnn = MLPforEstimation(n_bins=fft_size//2+1, num_layers=2, n_sources=2)
# Module.cuda() moves the parameters in place and returns the module itself.
dnn = dnn.cuda() if torch.cuda.is_available() else dnn

In [None]:
# Seed NumPy's global RNG first (presumably consumed by GaussIDLMA's initialization).
np.random.seed(seed=111)
# record_sdri is handed to the separator as its callback.
idlma = GaussIDLMA(callback=record_sdri, normalize='projection-back')

In [None]:
# record_sdri reads model.target and appends to model.sdri; presumably GaussIDLMA exposes these
# extra keyword arguments as attributes — verify against sss.idlma.
Y = idlma(X, iteration=100, dnn=dnn, target=s, sdri=[])

In [None]:
# Mean SDR improvement appended by record_sdri, one value per callback invocation.
fig, ax = plt.subplots()
ax.plot(idlma.sdri, color='black')
ax.set_xlabel('Iteration')
ax.set_ylabel('SDR improvement')
plt.show()