In [1]:
import os
import sys
import torch
import torch.nn.functional as F

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

from IPython.display import Audio

# Define model and ckpt

# YAML hyperparameter file describing the model to instantiate.
HPARAM_FILE = 'hparams/WSJ0Mix/dpmamba_M.yaml'
# Checkpoint directory; it is expected to hold one '<module name>.ckpt' file
# per entry of hparams['modules'].
CKPT_PATH = 'results/WSJ0Mix/dpmamba_M/1234/save/CKPT+2024-03-03+18-23-45+00'
# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Load hparams

argv = [HPARAM_FILE]
# Optional command-line style override, e.g.:
# argv += ['--data_folder', '/yourpath/wsj0-mix/2speakers']
hparam_file, run_opts, overrides = sb.parse_arguments(argv)
with open(hparam_file) as f:
    hparams = load_hyperpyyaml(f, overrides)

# Load model weights

for name, mod in hparams['modules'].items():
    # map_location remaps the stored tensors onto `device` while loading, so
    # a checkpoint written during a GPU run can be restored on another device.
    state_dict = torch.load(
        os.path.join(CKPT_PATH, name + '.ckpt'),
        map_location=device,
    )
    mod.load_state_dict(state_dict)
    mod.to(device)
    # Switch to inference mode before running the separation demo.
    mod.eval()

In [3]:
# Forward and data functions

@torch.no_grad()
def separate(mix, hparams):
    """Run the encoder -> mask network -> decoder chain on a mixture.

    Arguments
    ---------
    mix : torch.Tensor
        Input mixture. Dimension 1 is treated as the time axis
        (presumably shaped (batch, time) -- TODO confirm against the encoder).
    hparams : dict
        Must provide the callables 'Encoder', 'MaskNet' and 'Decoder' and
        the integer 'num_spks'.

    Returns
    -------
    torch.Tensor
        The decoded sources joined along a new last dimension, with
        dimension 1 zero-padded or cropped to ``mix.size(1)``.
    """
    # Encode the mixture and estimate the masks.
    encoded = hparams['Encoder'](mix)
    masks = hparams['MaskNet'](encoded)

    # Replicate the encoded mixture once per source and apply the masks.
    n_src = hparams['num_spks']
    replicated = torch.stack([encoded for _ in range(n_src)])
    masked = replicated * masks

    # Decode every masked representation; sources end up on the last axis.
    decoder = hparams['Decoder']
    decoded = [decoder(masked[idx]) for idx in range(n_src)]
    sources = torch.stack(decoded, dim=-1)

    # Align dimension 1 of the estimate with the input length.
    target_len = mix.size(1)
    missing = target_len - sources.size(1)
    if missing > 0:
        sources = F.pad(sources, (0, 0, 0, missing))
    else:
        sources = sources[:, :target_len, :]

    return sources


def dataio_prep(hparams):
    """Build the test dataset and attach its audio-loading pipelines.

    Reads the CSV at ``hparams["test_data"]`` (substituting "data_root"
    with ``hparams["data_folder"]``), registers one dynamic item per audio
    column that loads the file with ``read_audio``, and restricts the
    dataset outputs to "id" plus the loaded signals. The "s3" column is
    only handled when ``hparams["num_spks"] == 3``.

    Returns the resulting ``DynamicItemDataset``.
    """
    pipeline = sb.utils.data_pipeline

    test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["test_data"],
        replacements={"data_root": hparams["data_folder"]},
    )
    datasets = [test_data]

    # (CSV column holding the path, key under which the signal is provided)
    streams = [
        ("mix_wav", "mix_sig"),
        ("s1_wav", "s1_sig"),
        ("s2_wav", "s2_sig"),
    ]
    if hparams["num_spks"] == 3:
        streams.append(("s3_wav", "s3_sig"))

    output_keys = ["id"]
    for wav_key, sig_key in streams:

        # The decorators bind the current keys when the function is defined,
        # so every pass through the loop registers a distinct dynamic item.
        @pipeline.takes(wav_key)
        @pipeline.provides(sig_key)
        def load_signal(wav_path):
            return sb.dataio.dataio.read_audio(wav_path)

        sb.dataio.dataset.add_dynamic_item(datasets, load_signal)
        output_keys.append(sig_key)

    sb.dataio.dataset.set_output_keys(datasets, output_keys)

    return test_data

In [4]:
# Main

test_set = dataio_prep(hparams)

# Only the first test item is needed for the listening demo below.
data = next(iter(test_set))
mix = data['mix_sig'].to(device)

# Stack every reference source the dataset exposes (s3_sig is only present
# when dataio_prep was run with num_spks == 3), one source per last-dim column.
target_keys = [key for key in ('s1_sig', 's2_sig', 's3_sig') if key in data]
tars = torch.stack([data[key] for key in target_keys], dim=-1).to(device)

# separate() reads the length from dim 1, so add a leading dim for the call
# and drop it again from the result.
ests = separate(mix.unsqueeze(0), hparams).squeeze(0)

In [5]:
# Play the input mixture (moved to CPU for the widget) at 8000 Hz.
Audio(data=mix.cpu(), rate=8000)

In [6]:
# Play column 0 of the stacked reference sources at 8000 Hz.
Audio(data=tars[:, 0].cpu(), rate=8000)

In [7]:
# Play column 1 of the stacked reference sources at 8000 Hz.
Audio(data=tars[:, 1].cpu(), rate=8000)

In [8]:
# Play column 0 of the separated estimates at 8000 Hz.
Audio(data=ests[:, 0].cpu(), rate=8000)

In [9]:
# Play column 1 of the separated estimates at 8000 Hz.
Audio(data=ests[:, 1].cpu(), rate=8000)