In [2]:
!pip install speechbrain
!pip install museval
!pip install pesq
!pip install mir_eval

Collecting speechbrain
  Downloading speechbrain-1.0.2-py3-none-any.whl.metadata (23 kB)
Collecting hyperpyyaml (from speechbrain)
  Downloading HyperPyYAML-1.2.2-py3-none-any.whl.metadata (7.6 kB)
Collecting ruamel.yaml>=0.17.28 (from hyperpyyaml->speechbrain)
  Downloading ruamel.yaml-0.18.10-py3-none-any.whl.metadata (23 kB)
Collecting ruamel.yaml.clib>=0.2.7 (from ruamel.yaml>=0.17.28->hyperpyyaml->speechbrain)
  Downloading ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.7 kB)
Downloading speechbrain-1.0.2-py3-none-any.whl (824 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m824.8/824.8 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading HyperPyYAML-1.2.2-py3-none-any.whl (16 kB)
Downloading ruamel.yaml-0.18.10-py3-none-any.whl (117 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.7/117.7 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ruamel.yaml.clib-0.2.12

In [30]:
import os
import torch
import torchaudio
import librosa
import numpy as np
import pandas as pd
from speechbrain.inference.separation import SepformerSeparation as separator
from pesq import pesq
from tqdm import tqdm
import torch
from transformers import WavLMModel, WavLMForXVector
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn as nn
from peft import LoraConfig, get_peft_model

In [4]:
device_env = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device_env}")

Using device: cuda


In [5]:
from speechbrain.inference.separation import SepformerSeparation as SepFormerModel

sepformer_model = SepFormerModel.from_hparams(
    source="speechbrain/sepformer-whamr", 
    savedir='pretrained_models/sepformer-whamr',
    run_opts={"device":"cuda"}
).to(device_env)

hyperparams.yaml:   0%|          | 0.00/1.51k [00:00<?, ?B/s]

masknet.ckpt:   0%|          | 0.00/113M [00:00<?, ?B/s]

encoder.ckpt:   0%|          | 0.00/17.3k [00:00<?, ?B/s]

decoder.ckpt:   0%|          | 0.00/17.3k [00:00<?, ?B/s]

  state_dict = torch.load(path, map_location=device)


In [6]:
import os
import torchaudio
import librosa
import numpy as np
import mir_eval
from tqdm import tqdm

source_audio_dir = '/kaggle/input/vow-2-test/aac'
train_mix_dir = '/kaggle/working/mixed_train'
test_mix_dir = '/kaggle/working/mixed_test'
sep_output_dir = '/kaggle/working/sepformer_output'

device_env = "cuda" if torch.cuda.is_available() else "cpu"
sep_model = SepFormerModel.from_hparams(
    source="speechbrain/sepformer-whamr", 
    savedir='pretrained_models/sepformer-whamr',
    run_opts={"device": "cuda"}
)

In [9]:
def calculate_metrics(ref_signal, est_signal):
    ref_signal = np.atleast_2d(ref_signal)
    est_signal = np.atleast_2d(est_signal)
    if ref_signal.shape[1] != est_signal.shape[1]:
        raise ValueError(f"Shape mismatch: ref_signal {ref_signal.shape}, est_signal {est_signal.shape}")
    sdr_val, sir_val, sar_val, _ = mir_eval.separation.bss_eval_sources(ref_signal, est_signal)
    return sdr_val[0], sir_val[0], sar_val[0]

def get_m4a_files(root_folder):
    file_list = []
    for current_root, _, files in os.walk(root_folder):
        for item in files:
            if item.endswith('.m4a'):
                file_list.append(os.path.join(current_root, item))
    return file_list

def create_mixed_audio(path_id1, path_id2, mix_output_dir):
    files_first = get_m4a_files(path_id1)
    files_second = get_m4a_files(path_id2)
    
    if not files_first or not files_second:
        print(f"Skipping pair {path_id1}, {path_id2} (missing valid audio files)")
        return None, None, None
    
    selected_file1 = np.random.choice(files_first)
    selected_file2 = np.random.choice(files_second)

    audio_signal1, sample_rate1 = torchaudio.load(selected_file1)
    audio_signal2, sample_rate2 = torchaudio.load(selected_file2)

    resample_transform = torchaudio.transforms.Resample(orig_freq=sample_rate1, new_freq=8000)
    audio_signal1 = resample_transform(audio_signal1) if sample_rate1 != 8000 else audio_signal1
    audio_signal2 = resample_transform(audio_signal2) if sample_rate2 != 8000 else audio_signal2

    audio_signal1 = audio_signal1.mean(dim=0) if audio_signal1.shape[0] > 1 else audio_signal1.squeeze(0)
    audio_signal2 = audio_signal2.mean(dim=0) if audio_signal2.shape[0] > 1 else audio_signal2.squeeze(0)

    common_length = min(audio_signal1.shape[0], audio_signal2.shape[0])
    audio_signal1, audio_signal2 = audio_signal1[:common_length], audio_signal2[:common_length]

    combined_signal = audio_signal1 + audio_signal2
    combined_signal = combined_signal / torch.max(torch.abs(combined_signal))

    os.makedirs(mix_output_dir, exist_ok=True)
    mix_filename = os.path.join(mix_output_dir, f'{os.path.basename(selected_file1).split(".")[0]}_{os.path.basename(selected_file2).split(".")[0]}.wav')
    combined_signal = combined_signal.unsqueeze(0)
    torchaudio.save(mix_filename, combined_signal, 8000)
    return mix_filename, selected_file1, selected_file2


In [15]:
!rm -rf /kaggle/working/mixed_train

test_mix_dir = "/kaggle/working/mixtures"
mix_dictionary = {}  

audio_ids = sorted(os.listdir(source_audio_dir))[50:]
for idx in range(0, len(audio_ids) - 1, 2):
    folder1 = os.path.join(source_audio_dir, audio_ids[idx])
    folder2 = os.path.join(source_audio_dir, audio_ids[idx + 1])
    mix_path, orig_file1, orig_file2 = create_mixed_audio(folder1, folder2, test_mix_dir)
    print(mix_path)
    if mix_path:
        mix_dictionary[mix_path] = (orig_file1, orig_file2)
norm_mix_dictionary = {os.path.abspath(k): v for k, v in mix_dictionary.items()}

/kaggle/working/mixtures/00030_00339.wav
/kaggle/working/mixtures/00285_00052.wav
/kaggle/working/mixtures/00021_00059.wav
/kaggle/working/mixtures/00090_00019.wav
/kaggle/working/mixtures/00240_00056.wav
/kaggle/working/mixtures/00007_00058.wav
/kaggle/working/mixtures/00420_00147.wav
/kaggle/working/mixtures/00184_00249.wav
/kaggle/working/mixtures/00091_00375.wav
/kaggle/working/mixtures/00454_00021.wav
/kaggle/working/mixtures/00022_00402.wav
/kaggle/working/mixtures/00245_00189.wav
/kaggle/working/mixtures/00216_00153.wav
/kaggle/working/mixtures/00150_00342.wav
/kaggle/working/mixtures/00019_00134.wav
/kaggle/working/mixtures/00331_00200.wav
/kaggle/working/mixtures/00012_00002.wav
/kaggle/working/mixtures/00152_00037.wav
/kaggle/working/mixtures/00148_00196.wav
/kaggle/working/mixtures/00112_00115.wav
/kaggle/working/mixtures/00199_00481.wav
/kaggle/working/mixtures/00050_00146.wav
/kaggle/working/mixtures/00092_00096.wav
/kaggle/working/mixtures/00299_00015.wav
/kaggle/working/

In [16]:
import warnings
warnings.filterwarnings("ignore")

os.makedirs(sep_output_dir, exist_ok=True)

accumulated_sdr, accumulated_sir, accumulated_sar, accumulated_pesq = 0, 0, 0, 0
file_count = 0

def process_and_assess(mixed_filepath):
    global accumulated_sdr, accumulated_sir, accumulated_sar, accumulated_pesq, file_count

    abs_mixed_filepath = os.path.abspath(mixed_filepath)
    
    if abs_mixed_filepath not in norm_mix_dictionary:
        print(f"Warning: Missing key for file: {abs_mixed_filepath}")
        print("Available keys (first 5):", list(norm_mix_dictionary.keys())[:5])
        return None 
    mix_wave, fs = torchaudio.load(mixed_filepath)
    resample_fs = torchaudio.transforms.Resample(orig_freq=fs, new_freq=8000)
    mix_wave = resample_fs(mix_wave) if fs != 8000 else mix_wave

    separated_output = sep_model.separate_batch(mix_wave.to(device_env))
    separated_np = separated_output.cpu().squeeze(0).numpy()

    base_filename = os.path.basename(mixed_filepath).split(".")[0]
    
    orig_path1, orig_path2 = norm_mix_dictionary[abs_mixed_filepath]
    
    if orig_path1 and orig_path2:
        ref_wave1, _ = librosa.load(orig_path1, sr=8000)
        ref_wave2, _ = librosa.load(orig_path2, sr=8000)
        common_samples = min(len(ref_wave1), len(ref_wave2))
        ref_wave1, ref_wave2 = ref_wave1[:common_samples], ref_wave2[:common_samples]
        est_wave1 = separated_np[:, 0]
        est_wave2 = separated_np[:, 1]
        reference_stack = np.stack((ref_wave1, ref_wave2), axis=0)

        sdr_val, sir_val, sar_val = calculate_metrics(reference_stack, separated_np.T)
        pesq_val1 = pesq(8000, ref_wave1, est_wave1, mode='nb')
        pesq_val2 = pesq(8000, ref_wave2, est_wave2, mode='nb')

        accumulated_sdr += sdr_val
        accumulated_sir += sir_val
        accumulated_sar += sar_val
        accumulated_pesq += (pesq_val1 + pesq_val2)
        file_count += 1

    out_path1 = os.path.join(sep_output_dir, base_filename + "_spk1.wav")
    out_path2 = os.path.join(sep_output_dir, base_filename + "_spk2.wav")
    torchaudio.save(out_path1, torch.tensor(separated_np[0]).unsqueeze(0), 8000)
    torchaudio.save(out_path2, torch.tensor(separated_np[1]).unsqueeze(0), 8000)

    return out_path1, out_path2



In [51]:
accumulated_sdr, accumulated_sir, accumulated_sar, accumulated_pesq = 0, 0, 0, 0
file_count = 0

print("processed")
mixed_files_list = [os.path.join(test_mix_dir, f) for f in os.listdir(test_mix_dir)]
for mix_file in tqdm(mixed_files_list):
    process_and_assess(mix_file)

processed


  1%|          | 1/170 [00:00<02:08,  1.32it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


  4%|▎         | 6/170 [00:02<01:22,  1.98it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


  5%|▍         | 8/170 [00:04<01:23,  1.94it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


  6%|▋         | 11/170 [00:04<01:03,  2.51it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available ke

 11%|█         | 18/170 [00:05<00:34,  4.36it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 14%|█▎        | 23/170 [00:06<00:29,  5.02it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 15%|█▌        | 26/170 [00:07<00:29,  4.82it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 16%|█▋        | 28/170 [00:07<00:35,  4.01it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available ke

 26%|██▌       | 44/170 [00:09<00:18,  6.92it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available ke

 34%|███▎      | 57/170 [00:11<00:21,  5.23it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 36%|███▋      | 62/170 [00:12<00:22,  4.82it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 39%|███▉      | 67/170 [00:13<00:18,  5.44it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 41%|████      | 70/170 [00:14<00:20,  4.89it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available ke

 54%|█████▍    | 92/170 [00:14<00:06, 11.99it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 56%|█████▋    | 96/170 [00:15<00:07,  9.76it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 58%|█████▊    | 98/170 [00:16<00:10,  7.06it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 59%|█████▉    | 100/170 [00:17<00:12,  5.75it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 61%|██████    | 103/170 [00:18<00:14,  4.78it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available ke

 66%|██████▋   | 113/170 [00:19<00:08,  7.02it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 70%|███████   | 119/170 [00:21<00:10,  5.02it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available ke

 77%|███████▋  | 131/170 [00:21<00:04,  7.96it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 79%|███████▉  | 135/170 [00:22<00:04,  7.50it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available ke

 84%|████████▍ | 143/170 [00:23<00:03,  7.15it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 86%|████████▌ | 146/170 [00:25<00:05,  4.26it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']


 88%|████████▊ | 149/170 [00:26<00:05,  3.77it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available ke

 92%|█████████▏| 156/170 [00:27<00:02,  5.00it/s]

Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available keys (first 5): ['/kaggle/working/mixtures/00030_00339.wav', '/kaggle/working/mixtures/00285_00052.wav', '/kaggle/working/mixtures/00021_00059.wav', '/kaggle/working/mixtures/00090_00019.wav', '/kaggle/working/mixtures/00240_00056.wav']
Available ke

100%|██████████| 170/170 [00:28<00:00,  5.92it/s]


In [19]:
print("\nAverage Evaluation Metrics:")
print(f"Average SDR: {accumulated_sdr / file_count:.2f}")
print(f"Average SIR: {accumulated_sir / file_count:.2f}")
print(f"Average SAR: {accumulated_sar / file_count:.2f}")
print(f"Average PESQ: {accumulated_pesq / (2 * file_count):.2f}")


Average Evaluation Metrics:
Average SDR: 3.25
Average SIR: 15.98
Average SAR: 5.63
Average PESQ: 1.62


In [20]:
from sklearn.metrics import accuracy_score

def obtain_speaker_embedding(model_obj, audio_path):
    waveform_data, _ = torchaudio.load(audio_path)
    waveform_data = waveform_data.to(device_env)
    with torch.no_grad():
        embedding_out = model_obj(waveform_data).last_hidden_state.mean(dim=1)
    return embedding_out.cpu().numpy()

actual_labels, predicted_labels_pre, predicted_labels_finetuned = [], [], []
for folder in tqdm(os.listdir(sep_output_dir), desc="Evaluating Speaker Identification"):
    folder_path = os.path.join(sep_output_dir, folder)
    if not os.path.isdir(folder_path):
        continue

    spk_label1, spk_label2 = folder.split("_")[:2]
    actual_labels += [spk_label1, spk_label2]

    audio_spk1 = os.path.join(folder_path, f"speaker1.wav")
    audio_spk2 = os.path.join(folder_path, f"speaker2.wav")
    print(audio_spk1)
    emb_pre1 = obtain_speaker_embedding(wavlm_pretrained, audio_spk1)
    emb_pre2 = obtain_speaker_embedding(wavlm_pretrained, audio_spk2)
    emb_fine1 = obtain_speaker_embedding(speaker_classifier, audio_spk1)
    emb_fine2 = obtain_speaker_embedding(speaker_classifier, audio_spk2)

    predicted_labels_pre += [f"id{np.argmax(emb_pre1)}", f"id{np.argmax(emb_pre2)}"]
    predicted_labels_finetuned += [f"id{np.argmax(emb_fine1)}", f"id{np.argmax(emb_fine2)}"]




Evaluating Speaker Identification: 100%|██████████| 68/68 [00:00<00:00, 57735.36it/s]


In [36]:
fine_tuned_wavlm = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus").to(device_env)
fine_tuned_model_path = "/kaggle/input/fine-tuned-wavlm-model/fine_tuned_wavlm.pth"
fine_tuned_wavlm.load_state_dict(torch.load(fine_tuned_model_path, map_location=device_env))


In [38]:
wavlm_model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus").to(device_env)
wavlm_model.load_state_dict(torch.load(model_path, map_location=device_env))




In [34]:
sep_output_dir = "/kaggle/working/sepformer_output"

def extract_speaker_embedding(audio_path):
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
    waveform = waveform.mean(dim=0).unsqueeze(0)

    with torch.no_grad():
        embedding = wavlm_model(waveform.to(device_env)).embeddings.cpu().numpy()
    return embedding

separated_files = sorted([f for f in os.listdir(sep_output_dir) if f.endswith(".wav")])
num_correct = 0
total_speakers = len(separated_files) // 2

reference_embeddings = {}


for mix_file, (orig_file1, orig_file2) in norm_mix_dictionary.items():
    ref_embedding1 = extract_speaker_embedding(orig_file1)
    ref_embedding2 = extract_speaker_embedding(orig_file2)
    reference_embeddings[orig_file1] = ref_embedding1
    reference_embeddings[orig_file2] = ref_embedding2

for sep_file in separated_files:
    sep_path = os.path.join(sep_output_dir, sep_file)
    sep_embedding = extract_speaker_embedding(sep_path)

    best_match = None
    best_score = -1

    for ref_path, ref_embedding in reference_embeddings.items():
        score = cosine_similarity(sep_embedding, ref_embedding)[0, 0]
        if score > best_score:
            best_score = score
            best_match = ref_path

    if best_match in norm_mix_dictionary[os.path.abspath(sep_path)]:
        num_correct += 1


In [53]:
rank1_pretrained = compute_rank1_accuracy(pretrained_wavlm)
rank1_finetuned = compute_rank1_accuracy(fine_tuned_wavlm)

print("\nRank-1 Speaker Identification Accuracy:")
print(f"Pre-trained WavLM: {rank1_pretrained * 100:.2f}%")
print(f"Fine-tuned WavLM: {rank1_finetuned * 100:.2f}%")


Rank-1 Speaker Identification Accuracy:
Pre-trained WavLM: 16.17%
Fine-tuned WavLM: 26.47%
