In [None]:
%%shell
# Fetch the repository whose `src` directory is imported by the Python cells below.
git clone https://github.com/tky823/DNN-based_source_separation.git

# Run the install from the tutorials directory of the clone.
cd "./DNN-based_source_separation/egs/tutorials"

# Install the Python packages listed in the tutorials' requirements file.
pip install -r requirements.txt

## Prepare dataset

In [None]:
%%shell
# Download dataset
# Destination directory for the extracted LibriSpeech test-clean data.
librispeech_root="/content/LibriSpeech"

mkdir -p "${librispeech_root}/test-clean"
# Fetch the archive into /tmp, unpack it there, then delete the archive.
wget "http://www.openslr.org/resources/12/test-clean.tar.gz" -P "/tmp"
tar -xf "/tmp/test-clean.tar.gz" -C "/tmp/"
rm "/tmp/test-clean.tar.gz"
# Move the extracted contents under the destination directory.
mv "/tmp/LibriSpeech/test-clean/"* "${librispeech_root}/test-clean/"

In [None]:
import os
import sys
import random
import json

In [None]:
# Make the cloned repository's `src` directory importable.
# The path is relative, so it is resolved against the kernel's working directory.
sys.path.append("DNN-based_source_separation/src")
# Seed Python's `random` module with a fixed value.
random.seed(111)

In [None]:
import matplotlib.pyplot as plt
import IPython.display as ipd

In [None]:
# Set the default matplotlib font size to 20 for every figure drawn below.
plt.rcParams['font.size'] = 20

In [None]:
import torch
import torchaudio

In [None]:
from utils.audio import build_window
from algorithm.frequency_mask import compute_ideal_binary_mask
from transforms.pca import PCA
from models.danet import DANet

In [None]:
# Scatter colors, indexed by source, used by plot_latent_2d / plot_latent_3d.
COLORS = ["red", "blue"]
# Sampling rate passed to the model loader and to the audio players.
SAMPLE_RATE_LIBRISPEECH = 16000

# Number of sources; read as a global by WaveDataset.__getitem__ and estimate().
n_sources = 2
# NOTE(review): not referenced by any cell visible here; compute_threshold() uses its own default of 40.
threshold = 40

In [None]:
class WaveDataset:
    """Indexable collection of source mixtures described by a JSON file.

    Each JSON entry is expected to hold a ``'sources'`` mapping with one
    ``'source-<k>'`` record per source. Every record gives a ``'path'``
    (joined onto ``librispeech_root``) and ``'start'`` / ``'end'`` frame
    positions of the segment to read.

    The number of sources read per item comes from the module-level
    ``n_sources`` variable.

    Args:
        librispeech_root: Directory that the per-source ``'path'`` values are
            joined onto.
        json_path: Path of the JSON file to read. It must be given; the
            ``None`` default is not a usable value.
    """

    def __init__(self, librispeech_root, json_path=None):
        self.librispeech_root = librispeech_root

        with open(json_path) as f:
            self.json_data = json.load(f)

    def __getitem__(self, idx):
        """Load the ``idx``-th mixture.

        Returns:
            ``(waveform_mix, waveform_src)``: ``waveform_src`` is the loaded
            segments concatenated along dim 0, and ``waveform_mix`` is their
            sum over dim 0 with that dimension kept.
            NOTE(review): assumes every file is single-channel and all
            segments have the same length — confirm against the JSON.
        """
        entry = self.json_data[idx]['sources']

        waveform_src = []

        for src_idx in range(n_sources):
            source = entry['source-{}'.format(src_idx)]
            audio_path = os.path.join(self.librispeech_root, source['path'])
            start, end = source['start'], source['end']
            # torchaudio.load takes the first frame to read as `frame_offset`;
            # the previous `offset=` keyword is not part of that signature.
            waveform, _ = torchaudio.load(audio_path, frame_offset=start, num_frames=end - start)
            waveform_src.append(waveform)

        waveform_src = torch.cat(waveform_src, dim=0)
        waveform_mix = torch.sum(waveform_src, dim=0, keepdim=True)

        return waveform_mix, waveform_src

    def __len__(self):
        """Number of entries in the JSON file."""
        return len(self.json_data)

In [None]:
def create_sample_waveforms(n_sources=2, data_idx=3):
    """Load one fixed example mixture from the LibriSpeech test-clean JSON.

    Args:
        n_sources: Number of sources. It selects the ``test-<n>mix.json`` file
            and how many ``'source-<k>'`` records are read.
        data_idx: Index of the JSON entry to load (defaults to 3, the value
            that used to be hard-coded).

    Returns:
        ``(waveform_mix, waveform_src)``: ``waveform_src`` is the loaded
        segments concatenated along dim 0, and ``waveform_mix`` is their sum
        over dim 0 with that dimension kept.
    """
    librispeech_root = "/content/LibriSpeech"

    json_path = "/content/DNN-based_source_separation/dataset/LibriSpeech/test-clean/test-{}mix.json".format(n_sources)
    with open(json_path) as f:
        json_data = json.load(f)

    entry = json_data[data_idx]['sources']

    waveform_src = []

    for src_idx in range(n_sources):
        source = entry['source-{}'.format(src_idx)]
        audio_path = os.path.join(librispeech_root, source['path'])
        start, end = source['start'], source['end']
        # torchaudio.load takes the first frame to read as `frame_offset`;
        # the previous `offset=` keyword is not part of that signature.
        waveform, _ = torchaudio.load(audio_path, frame_offset=start, num_frames=end - start)
        waveform_src.append(waveform)

    waveform_src = torch.cat(waveform_src, dim=0)
    waveform_mix = torch.sum(waveform_src, dim=0, keepdim=True)

    return waveform_mix, waveform_src

In [None]:
def load_model(sample_rate=SAMPLE_RATE_LIBRISPEECH, n_sources=2):
    """Return the pretrained DANet for the "librispeech" task.

    Args:
        sample_rate: Forwarded to ``DANet.build_from_pretrained``.
        n_sources: Forwarded to ``DANet.build_from_pretrained``.
    """
    return DANet.build_from_pretrained(
        task="librispeech",
        sample_rate=sample_rate,
        n_sources=n_sources,
    )

In [None]:
def compute_threshold(amplitude, threshold=40, eps=1e-12):
    """Build a 0/1 weight marking entries within ``threshold`` dB of the peak.

    Args:
        amplitude: Tensor of amplitudes.
        threshold: Distance in dB below the maximum log-amplitude at which the
            cutoff is placed.
        eps: Added to ``amplitude`` before the logarithm.

    Returns:
        Tensor shaped and typed like ``amplitude``, 1 where the amplitude is
        strictly above the cutoff and 0 elsewhere.
    """
    level_db = 20 * torch.log10(amplitude + eps)
    peak_db = torch.max(level_db)
    # Convert the dB cutoff back to the linear amplitude scale.
    cutoff = 10**((peak_db - threshold) / 20)
    is_salient = amplitude > cutoff

    return torch.where(is_salient, torch.ones_like(amplitude), torch.zeros_like(amplitude))

In [None]:
def estimate(spectrogram_mix):
    """Run the global ``model`` on a mixture spectrogram.

    The amplitude of ``spectrogram_mix`` is given a leading dimension and
    passed to ``model.extract_latent``, together with the weight from
    ``compute_threshold`` and the global ``n_sources``. The leading dimension
    is removed from every output again, and the estimated amplitude is
    combined with the phase of the mixture.

    Returns:
        ``(spectrogram_est, latent, attractor, threshold_weight)``.
    """
    phase_mix = torch.angle(spectrogram_mix)
    # Leading dimension added for the model call; removed again below.
    amplitude_batch = torch.abs(spectrogram_mix).unsqueeze(dim=0)
    threshold_weight = compute_threshold(amplitude_batch)

    model.eval()
    with torch.no_grad():
        amplitude_est, latent, attractor = model.extract_latent(
            amplitude_batch,
            threshold_weight=threshold_weight,
            n_sources=n_sources,
        )

    amplitude_est, latent, attractor, threshold_weight = (
        tensor.squeeze(dim=0)
        for tensor in (amplitude_est, latent, attractor, threshold_weight)
    )
    # Reuse the mixture phase for the estimated amplitude.
    spectrogram_est = amplitude_est * torch.exp(1j * phase_mix)

    return spectrogram_est, latent, attractor, threshold_weight

In [None]:
# Dataset root and the mixture description file for the configured number of sources.
librispeech_root = "/content/LibriSpeech"
json_path = "/content/DNN-based_source_separation/dataset/LibriSpeech/test-clean/test-{}mix.json".format(n_sources)
dataset = WaveDataset(librispeech_root, json_path=json_path)
# waveform_mix, waveform_src = create_sample_waveforms()
# Use the first mixture of the dataset as the running example.
waveform_mix, waveform_src = dataset[0]

# One audio player per source, followed by one for the mixture.
for idx in range(n_sources):
    display(ipd.Audio(waveform_src[idx], rate=SAMPLE_RATE_LIBRISPEECH))

display(ipd.Audio(waveform_mix, rate=SAMPLE_RATE_LIBRISPEECH))

In [None]:
# Load the pretrained model and reuse its own STFT settings below.
model = load_model(sample_rate=SAMPLE_RATE_LIBRISPEECH, n_sources=n_sources)
n_fft, hop_length = model.n_fft, model.hop_length
window_fn = model.window_fn
# Build the analysis window from the model's window function name and FFT size.
window = build_window(n_fft, window_fn=window_fn)

In [None]:
# Complex one-sided STFTs of the mixture and of the individual sources, using the same settings.
spectrogram_mix = torch.stft(waveform_mix, n_fft=n_fft, hop_length=hop_length, window=window, onesided=True, return_complex=True)
spectrogram_src = torch.stft(waveform_src, n_fft=n_fft, hop_length=hop_length, window=window, onesided=True, return_complex=True)

# Compute ideal binary mask for plotting
# Dim 0 of `spectrogram_src` is given as the source dimension.
mask = compute_ideal_binary_mask(spectrogram_src, source_dim=0)

In [None]:
# Run the model on the example mixture; the latent outputs are plotted further below.
spectrogram_est, latent, attractor, threshold_weight = estimate(spectrogram_mix)

In [None]:
# Invert the estimated spectrograms with the same STFT settings, trimmed to the mixture length.
waveform_est = torch.istft(spectrogram_est, n_fft=n_fft, hop_length=hop_length, window=window, onesided=True, length=waveform_mix.size(-1), return_complex=False)
# Split along dim 0 into one single-row tensor per source.
waveform_est = torch.split(waveform_est, [1]*n_sources, dim=0)

# One audio player per estimated source.
for idx in range(n_sources):
    display(ipd.Audio(waveform_est[idx].detach(), rate=SAMPLE_RATE_LIBRISPEECH))

## Plot principal components

In [None]:
def plot_latent_2d(latent, attractor, mask, alpha=0.1, lims=None):
    """Scatter the first two columns of the latent vectors and attractors.

    Args:
        latent: 2-D tensor with one embedding per row. Rows are selected with
            the flat positions of the nonzero entries of ``mask[idx]``.
        attractor: Tensor indexed by source along dim 0, with the embedding
            along the last dim. ``attractor[idx]`` may be a single vector or
            several rows of vectors.
        mask: Tensor whose dim 0 indexes sources. The nonzero entries of
            ``mask[idx].flatten()`` pick the rows of ``latent`` drawn in
            ``COLORS[idx]``.
        alpha: Opacity of the latent points.
        lims: Optional limits applied to both axes.
    """
    n_sources = mask.size(0)
    embed_dim = attractor.size(-1)
    plt.figure(figsize=(12, 8))

    for idx in range(n_sources):
        color = COLORS[idx]
        indices, = torch.nonzero(mask[idx].flatten(), as_tuple=True)
        x, y = torch.unbind(latent[indices], dim=1)[:2]
        plt.scatter(x, y, color=color, alpha=alpha)

    for idx in range(n_sources):
        # attractor[idx] is 1-D when attractor is (n_sources, dim), and
        # unbinding dim=1 of a 1-D tensor raises. Viewing it as rows of
        # vectors first handles both the 1-D and the 2-D case.
        points = attractor[idx].reshape(-1, embed_dim)
        x, y = torch.unbind(points, dim=1)[:2]
        plt.scatter(x, y, color="black", marker="^", s=300, linewidths=3, edgecolors="white")

    if lims is not None:
        plt.xlim(lims)
        plt.ylim(lims)

    plt.xlabel("PCA1")
    plt.ylabel("PCA2")

    plt.show()
    plt.close()

def plot_latent_3d(latent, attractor, mask, alpha=0.1, lims=None):
    """Scatter the first three columns of the latent vectors and attractors in 3-D.

    Args:
        latent: 2-D tensor with one embedding per row. Rows are selected with
            the flat positions of the nonzero entries of ``mask[idx]``.
        attractor: Tensor indexed by source along dim 0, with the embedding
            along the last dim. ``attractor[idx]`` may be a single vector or
            several rows of vectors.
        mask: Tensor whose dim 0 indexes sources. The nonzero entries of
            ``mask[idx].flatten()`` pick the rows of ``latent`` drawn in
            ``COLORS[idx]``.
        alpha: Opacity of the latent points.
        lims: Optional limits applied to all three axes.
    """
    n_sources = mask.size(0)
    embed_dim = attractor.size(-1)

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    for idx in range(n_sources):
        color = COLORS[idx]
        indices, = torch.nonzero(mask[idx].flatten(), as_tuple=True)
        x, y, zs = torch.unbind(latent[indices], dim=1)[:3]
        ax.scatter(x, y, zs=zs, color=color, alpha=alpha)

    for idx in range(n_sources):
        # attractor[idx] is 1-D when attractor is (n_sources, dim), and
        # unbinding dim=1 of a 1-D tensor raises. Viewing it as rows of
        # vectors first handles both the 1-D and the 2-D case.
        points = attractor[idx].reshape(-1, embed_dim)
        x, y, zs = torch.unbind(points, dim=1)[:3]
        ax.scatter(x, y, zs=zs, color="black", marker="^", s=300, linewidths=3, edgecolors="white")

    if lims is not None:
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        ax.set_zlim(lims)

    ax.set_xlabel("PCA1")
    ax.set_ylabel("PCA2")
    ax.set_zlabel("PCA3")

    plt.show()
    plt.close()

In [None]:
# Collapse all leading dimensions so each row of `latent` is one embedding vector.
latent = latent.view(-1, latent.size(-1))
# Flat positions where the threshold weight is nonzero.
salient_indices, = torch.nonzero(threshold_weight.flatten(), as_tuple=True)
# Embeddings at those positions only.
latent_salient = latent[salient_indices]

In [None]:
# Applies PCA
pca = PCA()

pca.train()
_ = pca(latent_salient)

pca.eval()
latent_projected = pca(latent)
attractor_projected = pca(attractor)

In [None]:
# The ideal binary mask is multiplied by the threshold weight, so only entries nonzero in both are drawn.
plot_latent_2d(latent_projected, attractor_projected, mask * threshold_weight, lims=(-5, 5))

In [None]:
# Same inputs as the 2-D plot, drawn with a third axis.
plot_latent_3d(latent_projected, attractor_projected, mask * threshold_weight, lims=(-5, 5))

## Extract all test attractors

In [None]:
# Collect the attractors the model produces for every mixture in the dataset.
attractors = []

for waveform_mix, waveform_src in dataset:
    spectrogram_mix = torch.stft(waveform_mix, n_fft=n_fft, hop_length=hop_length, window=window, onesided=True, return_complex=True)

    _, _, attractor, _ = estimate(spectrogram_mix)

    attractors.append(attractor)

# torch.stack takes the sequence of tensors as its first argument; it was missing here.
attractors = torch.stack(attractors, dim=0)
# Flatten to one attractor per row so the PCA and the `unbind(dim=1)` plots below see a 2-D tensor.
# NOTE(review): presumably each attractor is (n_sources, embed_dim), giving (len(dataset) * n_sources, embed_dim) — confirm.
attractors = attractors.reshape(-1, attractors.size(-1))

In [None]:
def plot_attractors_2d(attractors, alpha=0.1, lims=None):
    """Scatter the first two columns of ``attractors`` in blue.

    Args:
        attractors: Tensor that is unbound along dim 1; the first two
            resulting columns become the x and y coordinates.
        alpha: Opacity of the points.
        lims: Optional limits applied to both axes.
    """
    plt.figure(figsize=(12, 8))

    columns = torch.unbind(attractors, dim=1)
    first, second = columns[:2]
    plt.scatter(first, second, color="blue", alpha=alpha)

    if lims is not None:
        for apply_limits in (plt.xlim, plt.ylim):
            apply_limits(lims)

    for set_label, text in ((plt.xlabel, "PCA1"), (plt.ylabel, "PCA2")):
        set_label(text)

    plt.show()
    plt.close()

def plot_attractors_3d(attractors, alpha=0.1, lims=None):
    """Scatter the first three columns of ``attractors`` in blue on 3-D axes.

    Args:
        attractors: Tensor that is unbound along dim 1; the first three
            resulting columns become the x, y and z coordinates.
        alpha: Opacity of the points.
        lims: Optional limits applied to all three axes.
    """
    figure = plt.figure(figsize=(12, 8))
    axes = figure.add_subplot(111, projection='3d')

    columns = torch.unbind(attractors, dim=1)
    first, second, third = columns[:3]
    axes.scatter(first, second, zs=third, color="blue", alpha=alpha)

    if lims is not None:
        for apply_limits in (axes.set_xlim, axes.set_ylim, axes.set_zlim):
            apply_limits(lims)

    labelled_setters = (
        (axes.set_xlabel, "PCA1"),
        (axes.set_ylabel, "PCA2"),
        (axes.set_zlabel, "PCA3"),
    )
    for set_label, text in labelled_setters:
        set_label(text)

    plt.show()
    plt.close()

In [None]:
# Applies PCA
pca = PCA()

pca.train()
attractors_projected = pca(attractors)

In [None]:
# First two projected components of all collected attractors, with default opacity and axis limits.
plot_attractors_2d(attractors_projected)

In [None]:
# First three projected components of all collected attractors, with default opacity and axis limits.
plot_attractors_3d(attractors_projected)