# Test on samples

In [None]:
import torch
from torch.utils.data import DataLoader
from audiosep.data import VoiceNoiseDatamodule
from audiosep.models import SpectroUNet2D
import numpy as np
import os
import librosa
from audiosep.data import SR, N_FFT, HOP_LENGTH
import IPython.display as ipd

## Inference on batch

In [11]:
# Data: build the datamodule and prepare only the test split.
dm = VoiceNoiseDatamodule(train_data_dir="../data/train", test_data_dir="../data/test", batch_size=1, num_workers=0)
dm.setup(stage='test')

# Model: restore the trained network from a saved checkpoint.
checkpoint_path = "../wandb/run-20251208_210907-qqtu5dys/files/checkpoints/spectro_unet2d-epoch=49.ckpt"
model = SpectroUNet2D.load_from_checkpoint(checkpoint_path)

# Pick the best available device: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Move the model to the device and switch it to evaluation mode.
model = model.to(device)
model.eval()

# Take the first batch produced by the test dataloader.
# NOTE(review): the reconstruction cell reads `example_dirs[0]`, so it presumably
# relies on the test dataloader not shuffling — confirm against the datamodule.
batch = next(iter(dm.test_dataloader()))
x_batch, y_batch = batch  # x: (B,1,F,T), y: dict with "voice","noise"

# Keep the batch dimension on the input (slice, not index); index the targets.
x = x_batch[0:1].to(device)
y = {k: v[0] for k, v in y_batch.items()}

# Inference without gradient tracking.
with torch.no_grad():
    est_voice, est_noise, masks = model(x)   # (1,1,F,T) each

# The masks returned by the model are used as-is downstream.
# NOTE(review): a stray `masks = F.softmax(logits)` statement used to follow here.
# It was mis-indented (IndentationError) and referenced the undefined names `F`
# and `logits`, so it has been removed. Open question kept from the original:
# does the model already apply a softmax to its masks? Confirm in SpectroUNet2D.


## Reconstruction

In [None]:
# Bring the model's masks back to host memory as a NumPy array and split them
# into the voice (index 0) and noise (index 1) components.
masks_np = masks.squeeze().cpu().numpy()  # (2,F,T)
mask_voice, mask_noise = masks_np[0], masks_np[1]

# Show both mask shapes as a quick sanity check.
for _mask in (mask_voice, mask_noise):
    print(_mask.shape)

# Locate the audio files of the first test example; the raw mix is needed to
# recover the phase for the ISTFT.
example_dir = dm.test_dataset.example_dirs[0]
folder_path = os.path.join(dm.test_data_dir, example_dir)
mix_candidates = [entry for entry in os.listdir(folder_path) if entry.startswith("mix")]
mix_file = mix_candidates[0]
mix_path = os.path.join(folder_path, mix_file)
voice_path, noise_path = (
    os.path.join(folder_path, wav_name) for wav_name in ("voice.wav", "noise.wav")
)

(513, 107)
(513, 107)


In [24]:
# Read the raw mixture and compute its complex STFT so the phase is available.
y_mix, _ = librosa.load(mix_path, sr=SR, mono=True)
S_mix = librosa.stft(y_mix, n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann")  # complex (F,T)

# Align the number of STFT frames with the mask: zero-pad on the right when the
# STFT is shorter, truncate when it is longer, leave untouched when equal.
n_mask_frames = mask_voice.shape[1]
frame_deficit = n_mask_frames - S_mix.shape[1]
if frame_deficit > 0:
    S_mix = np.pad(S_mix, ((0, 0), (0, frame_deficit)), mode='constant')
elif frame_deficit < 0:
    S_mix = S_mix[:, :n_mask_frames]

print(S_mix.shape)

(513, 107)


In [41]:
def _to_waveform(spec):
    """Invert a complex STFT to a waveform with the same length as the loaded mix."""
    return librosa.istft(spec, hop_length=HOP_LENGTH, window="hann", length=y_mix.shape[0])


# Mask the mixture's complex STFT per source; the sum is the re-assembled mix.
S_voice = S_mix * mask_voice
S_noise = S_mix * mask_noise
S_rec_mix = S_voice + S_noise

# Back to the time domain.
y_rec_voice = _to_waveform(S_voice)
y_rec_noise = _to_waveform(S_noise)
y_rec_mix = _to_waveform(S_rec_mix)

# Reference signals for listening and comparison.
y_voice_ref, _ = librosa.load(voice_path, sr=SR, mono=True)
y_noise_ref, _ = librosa.load(noise_path, sr=SR, mono=True)
y_mix_ref = y_mix  # the mix was already loaded above

# Play the clean voice, the noisy mix, and the separated voice, in that order.
for _caption, _waveform in (
    ("\nOriginal voice:", y_voice_ref),
    ("\nOriginal noisy mix:", y_mix_ref),
    (f"{model.__class__.__name__} separation:", y_rec_voice),
):
    print(_caption)
    display(ipd.Audio(_waveform, rate=SR))



Original voice:



Original noisy mix:


SpectroUNet2D separation:


In [43]:
def snr_db(ref, est):
    """Signal-to-noise ratio in decibels of an estimate against a reference.

    Both inputs are truncated to their common length. The result is
    ``10 * log10(sum(ref**2) / sum((ref - est)**2))``, with 1e-12 added to the
    numerator and the denominator so neither a silent reference nor a perfect
    estimate produces a division by zero or a log of zero.

    Args:
        ref: Reference signal (array supporting slicing and element-wise math).
        est: Estimated signal, same kind of array as ``ref``.

    Returns:
        The SNR value in dB.
    """
    eps = 1e-12
    n = min(len(ref), len(est))
    ref_cut, est_cut = ref[:n], est[:n]
    signal_energy = np.sum(ref_cut ** 2)
    error_energy = np.sum((ref_cut - est_cut) ** 2)
    return 10.0 * np.log10((signal_energy + eps) / (error_energy + eps))

# Score the separated voice against the clean voice reference and report it.
snr_voice = snr_db(ref=y_voice_ref, est=y_rec_voice)
print(f"SNR of original voice vs nn separated voice: {snr_voice:.3f} dB")

SNR of original voice vs nn separated voice: 1.741 dB
