In [None]:
!python -m pip install --upgrade setuptools
!pip install -qU kaleido # to save image
!pip install -q git+https://github.com/tky823/ssspy.git@v0.0.2
!pip install -q git+https://github.com/tky823/self-study_2022-summer.git

In [None]:
import os
import itertools
from concurrent.futures import ProcessPoolExecutor

In [None]:
import numpy as np
import scipy.signal as ss
import soundfile as sf
from tqdm.notebook import tqdm
from ssspy.bss._select_pair import combination_pair_selector
from ssspy.bss.ilrma import TILRMA

In [None]:
from study2022summer.data import download_data
from study2022summer.reporter import SDRiReporter

In [None]:
# Experimental conditions shared by the cells below.
n_sources = 3
sample_rate = 16000  # sampling rate used when the separated waveforms are written to WAV
max_samples = float("inf")  # forwarded to download_data; presumably "no truncation" - confirm
n_fft, hop_length = 4096, 1024  # STFT segment length and shift (noverlap = n_fft - hop_length)
window = "hann"
reference_id = 0  # forwarded to TILRMA as reference_id

if n_sources == 3:
    # Each inner list holds one value per source; every list of speakers, directions
    # and microphones is later combined with every other one.
    cmu_arctic_tags = [["aew", "ahw", "aup"], ["awb", "axb", "bdl"], ["clb", "eey", "fem"], ["gka", "jmk", "ksp"]]
    degrees = [[0, 15, 345], [15, 30, 330], [45, 315, 345], [330, 45, 300]]
    channels = [[2, 3, 4], [2, 4, 5], [1, 3, 5], [1, 4, 6]]
else:
    # Presets are only defined for three sources; fail with an explicit message.
    raise ValueError(f"Only n_sources=3 is supported, but n_sources={n_sources} is given.")

In [None]:
# Hyperparameters of the separation experiment.

# Update rules to compare; each one is passed to TILRMA as spatial_algorithm.
spatial_algorithms = [
    "IP1",
    "IP2",
    "ISS1",
    "ISS2",
]

# Number of iterations, and the save_freq handed to SDRiReporter.
n_iter, save_freq = 200, 200

# Model settings passed to TILRMA.
n_basis, dof = 2, 1000

# Seed of the random generator given to TILRMA.
seed = 42

In [None]:
# Call download_data once for every combination of speakers, directions and
# microphones. The return value is discarded here; presumably this only warms a
# local cache before the parallel runs - confirm against download_data.
for _cmu_arctic_tags, _degrees, _channels in itertools.product(cmu_arctic_tags, degrees, channels):
    download_data(
        cmu_arctic_tags=_cmu_arctic_tags,
        max_samples=max_samples,
        degrees=_degrees,
        channels=_channels,
    )

In [None]:
def run(spatial_algorithm, cmu_arctic_tags, degrees, channels, n_basis=2, dof=1000, max_samples=48000):
    """Separate one mixture with t-ILRMA and return the final SDR improvement.

    The source images are obtained from ``download_data``, summed over axis 1 to
    build the mixture, separated by ``TILRMA`` in the STFT domain and converted
    back to waveforms, which are written as WAV files under
    ``estimated_sources/<spatial_algorithm>/``.

    Besides its arguments, the function reads the notebook-level settings
    ``window``, ``n_fft``, ``hop_length``, ``save_freq``, ``reference_id``,
    ``seed``, ``n_iter`` and ``sample_rate``.

    Args:
        spatial_algorithm: Name of the update rule passed to ``TILRMA``. For
            ``"IP2"`` and ``"ISS2"``, ``combination_pair_selector`` is passed as
            ``pair_selector`` in addition.
        cmu_arctic_tags: Sequence of speaker tags (strings) forwarded to
            ``download_data``; also used in the output file name.
        degrees: Sequence of directions forwarded to ``download_data``; also
            used in the output file name.
        channels: Sequence of channel indices forwarded to ``download_data``;
            also used in the output file name.
        n_basis: ``n_basis`` argument of ``TILRMA``.
        dof: ``dof`` argument of ``TILRMA``.
        max_samples: ``max_samples`` argument of ``download_data``.

    Returns:
        Tuple ``(spatial_algorithm, ilrma.sdri[-1])``, i.e. the algorithm name
        and the last entry of the SDR improvement history.
    """
    # Only "IP2" and "ISS2" are given a pair selector.
    if spatial_algorithm in ["IP2", "ISS2"]:
        kwargs = {
            "pair_selector": combination_pair_selector
        }
    else:
        kwargs = {}

    waveform_src_img = download_data(
        cmu_arctic_tags=cmu_arctic_tags,
        max_samples=max_samples,
        degrees=degrees,
        channels=channels
    )
    # The mixture is the sum over axis 1 (presumably the source axis - confirm
    # against download_data).
    waveform_mix = np.sum(waveform_src_img, axis=1)
    _, _, spectrogram_mix = ss.stft(
        waveform_mix,
        window=window,
        nperseg=n_fft,
        noverlap=n_fft-hop_length
    )

    sdri_reporter = SDRiReporter(
        waveform_src_img,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        save_freq=save_freq
    )

    # A fresh generator per call keeps every run reproducible from the same seed.
    ilrma = TILRMA(
        n_basis=n_basis,
        dof=dof,
        spatial_algorithm=spatial_algorithm,
        callbacks=sdri_reporter,
        record_loss=False,
        reference_id=reference_id,
        rng=np.random.default_rng(seed),
        **kwargs
    )

    # sdri / times start as empty lists; ilrma.sdri is read back below.
    spectrogram_est = ilrma(spectrogram_mix, n_iter=n_iter, sdri=[], times=[])

    _, waveform_est = ss.istft(
        spectrogram_est,
        window=window,
        nperseg=n_fft,
        noverlap=n_fft-hop_length
    )

    # Keep the results of each update algorithm in its own directory. The same
    # (speakers, directions, microphones) condition is processed once per
    # algorithm, possibly at the same time in different worker processes, so a
    # shared directory would make the runs overwrite each other's files.
    save_dir = os.path.join("estimated_sources", spatial_algorithm)
    os.makedirs(save_dir, exist_ok=True)

    # <tags>_<degrees>_<channels>_<source index>.wav
    filename_template = "_".join([
        "-".join(cmu_arctic_tags),
        "-".join(str(degree) for degree in degrees),
        "-".join(str(channel) for channel in channels),
        "{}.wav",
    ])

    for src_idx, waveform in enumerate(waveform_est):
        # Source indices in file names are 1-based.
        save_path = os.path.join(save_dir, filename_template.format(src_idx + 1))
        sf.write(save_path, waveform, sample_rate)

    return spatial_algorithm, ilrma.sdri[-1]

In [None]:
# Evaluate every (algorithm, speakers, directions, microphones) combination in
# worker processes and store the final SDR improvement of each run.
with ProcessPoolExecutor(max_workers=64) as executor:
    conditions = itertools.product(spatial_algorithms, cmu_arctic_tags, degrees, channels)
    # Each condition unpacks to run()'s four leading positional arguments.
    futures = [
        executor.submit(run, *condition, n_basis=n_basis, dof=dof, max_samples=max_samples)
        for condition in conditions
    ]

    # One list of results per update algorithm.
    sdri = {algorithm: [] for algorithm in spatial_algorithms}

    # Futures are consumed in submission order, so each list follows the order
    # in which the conditions were submitted.
    for future in tqdm(futures):
        spatial_algorithm, final_sdri = future.result()
        sdri[spatial_algorithm].append(final_sdri)

    npz_name = f"{TILRMA.__name__}_{n_sources}src_dof{dof}_{n_basis}bases_seed{seed}_SDRi.npz"
    np.savez(npz_name, **sdri)

In [None]:
import plotly.graph_objects as go

In [None]:
from study2022summer.utils import box_plot_sdri

In [None]:
# Read every saved array into a plain dict and close the .npz archive right away;
# np.load returns an NpzFile that otherwise keeps the file open.
with np.load(f"{TILRMA.__name__}_{n_sources}src_dof{dof}_{n_basis}bases_seed{seed}_SDRi.npz") as npz:
    sdri = {name: npz[name] for name in npz.files}

In [None]:
fig = go.Figure()

# box_plot_sdri is called once per update algorithm, in the order of spatial_algorithms.
for spatial_algorithm in spatial_algorithms:
    box_plot_sdri(fig, sdri[spatial_algorithm], label=spatial_algorithm)

layout = dict(
    title=f"<i>t</i>-ILRMA (dof={dof}, {n_basis}bases, seed={seed})",
    xaxis_title="Update algorithm",
    yaxis_title="SDR improvement [dB]",
    font={"size": 20},
)
fig.update_layout(**layout)

# Save the figure as a PNG next to the notebook, then display it.
figure_path = f"{TILRMA.__name__}_{n_sources}src_dof{dof}_{n_basis}bases_seed{seed}_SDRi.png"
fig.write_image(figure_path, scale=10, width=990, height=540)
fig.show()