In [1]:
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '..'))
from sklearn.model_selection import train_test_split
import torch
from src.data import create_manifest, load_manifest, AudioDataset
from src.training import TkLossPlotter
from torch.utils.data import Subset

In [2]:
# Notebook-wide switches. When DEBUG is on, MAX_SAMPLES is passed to the
# AudioDataset as its `max_samples` cap so a full run stays quick.
DEBUG = True
MAX_SAMPLES = 256

# Work from the project root (the directory that contains `src/`), because the
# manifest path used later is relative to it. Only step up one level when
# `src/` is not already visible, so re-running this cell does not keep
# climbing the directory tree.
if not os.path.isdir('src'):
    os.chdir(os.path.join(os.getcwd(), '..'))

In [3]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Manifest
The manifest is a JSON file that lists the paths to the clean and noisy audio files.
The dataset uses it to load the data in chunks rather than loading
the entire dataset into memory.

In [4]:
# Reuse the manifest when it is already on disk; otherwise build it.
manifest_file_name = 'src/data/manifest_trainset_28spk_wav.json'
if os.path.exists(manifest_file_name):
    # NOTE(review): the existence check uses the full relative path, but the
    # load call passes only the bare file name. Confirm that load_manifest
    # resolves this name to the same file that was checked above.
    manifest = load_manifest('manifest_trainset_28spk_wav.json')
else:
    # NOTE(review): 'D:/denoise_sound_files/' is a machine-specific absolute
    # path; consider moving it to a configurable constant.
    manifest = create_manifest(manifest_file_name, 'D:/denoise_sound_files/', 'trainset_28spk_wav')

In [5]:
# Dataset settings. `max_samples` is only set in DEBUG mode; otherwise None is
# passed and no cap is requested.
dataset_kwargs = dict(
    manifest=manifest,
    transform=None,
    segment_ms=2000,
    sample_rate=16000,
    mono=True,
    max_samples=MAX_SAMPLES if DEBUG else None,
)
dataset = AudioDataset(**dataset_kwargs)

# Hold out 15% of the dataset indices as the test set. The split is shuffled
# with a fixed seed so it is the same on every run.
indices = torch.arange(len(dataset))
train_val_idx, test_idx = train_test_split(
    indices,
    test_size=0.15,
    shuffle=True,
    random_state=42,
)
test_set = Subset(dataset, test_idx)

In [6]:
from src.models import *
from src.training import *

# Hyper-parameters shared by every model in this comparison run.
learning_options = {
    'batch_size': 16,
    'learning_rate': .001,
    'epochs': 10,
    'patience': 4,
}

# Number of epochs without a validation improvement that triggers early stopping.
patience = learning_options['patience']

l1_loss = torch.nn.functional.l1_loss
mse_loss = torch.nn.functional.mse_loss

# Objective that is actually handed to train_model / evaluate_model.
loss_fn = denoise_loss


def neg_si_sdr(pred, target):
    """Return the negation of ``si_sdr_loss(pred, target)``."""
    # NOTE(review): if si_sdr_loss already returns a quantity to minimise,
    # negating it flips the optimisation direction. Verify before using it
    # as a training objective.
    return -si_sdr_loss(pred, target)


# Candidate loss terms and their weights (presumably paired by position).
# NOTE(review): neither list is consumed by the training loop in this
# notebook, which only uses `loss_fn`.
loss_fns = [neg_si_sdr, spectral_l1, l1_loss, mse_loss]
weights = [4.0, 3.0, 2.0, 1.0]

# train_val_idx was produced by a shuffled train_test_split, so cutting it at
# the 80% mark yields the train and validation index sets directly.
split_idx = int(0.8 * len(train_val_idx))
train_index, val_index = train_val_idx[:split_idx], train_val_idx[split_idx:]

train_set, val_set = Subset(dataset, train_index), Subset(dataset, val_index)

print('Getting dataloaders...')

# Both architectures are wrapped with identical optimisation settings so the
# comparison only differs in the network itself.
# NOTE(review): step_size=1 with gamma=0.25 presumably configures the LR
# scheduler that the training loop steps once per epoch; if so the learning
# rate is quartered every epoch. Confirm against ModelWrapper.
models = [
    ModelWrapper(net, learning_options['learning_rate'], weight_decay=5e-4, step_size=1, gamma=0.25)
    for net in (
        AudioUNet_v1(transforms=None, device=device),
        AudioUNet_v2(transforms=None, device=device),
    )
]
model1, model2 = models

epochs = learning_options['epochs']

# Only the training loader is shuffled; validation order stays fixed.
train_loader = get_dataloader(train_set, learning_options['batch_size'], shuffle=True, device=device)
val_loader = get_dataloader(val_set, learning_options['batch_size'], shuffle=False, device=device)

# One live plot per model, keyed by the model wrapper object. All plotters
# are created first and started afterwards.
plotters = {}
for model in models:
    plotters[model] = TkLossPlotter(refresh_hz=10)

for model in models:
    plotters[model].start(title=f'{model} Training & Validation Loss')

# Train all models side by side, one epoch at a time. Each model keeps its own
# loss history, scheduler and early-stopping state on its wrapper object.
for epoch in range(epochs):
    epoch_script = f'Epoch {epoch+1} of {epochs}'
    print(f'\r{epoch_script}', end='')

    for model in models:
        # A model that ran out of patience sits out the remaining epochs.
        if model.stopped_early:
            continue

        train_loss = train_model(model, train_loader, model.optimizer, loss_fn)
        val_loss = evaluate_model(model, val_loader, loss_fn)
        print(f'Model {model} train loss: {train_loss:.6f}, val loss: {val_loss:.6f}')

        model.train_losses.append(train_loss)
        model.val_losses.append(val_loss)

        # The plot receives the running mean over all epochs so far, not the
        # loss of the epoch that just finished.
        tloss_avg = sum(model.train_losses) / len(model.train_losses)
        vloss_avg = sum(model.val_losses) / len(model.val_losses)
        plotters[model].update((tloss_avg, vloss_avg))

        # The scheduler is stepped once per epoch for every active model.
        model.scheduler.step()

        # Early stopping: reset the counter on a new best validation loss,
        # otherwise count one more epoch without improvement.
        if val_loss < model.best_val_loss:
            model.best_val_loss, model.best_epoch, model.patience_counter = val_loss, epoch, 0
        else:
            model.patience_counter += 1

        if model.patience_counter >= patience:
            print(f'Model {model} stopped early at epoch {epoch+1}')
            model.stopped_early = True


Getting dataloaders...
Epoch 1 of 10

Training: 100%|██████████| 10/10 [00:06<00:00,  1.47it/s, Data load time: 0.25s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  4.82it/s, Data load time: 0.07s]


Model AudioUNet_v1() train loss: 0.499934, val loss: 0.117406


Training: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s, Data load time: 0.20s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  3.89it/s, Data load time: 0.25s]


Model AudioUNet_v2() train loss: 0.189307, val loss: 0.170927
Epoch 2 of 10

Training: 100%|██████████| 10/10 [00:06<00:00,  1.50it/s, Data load time: 0.20s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  4.35it/s, Data load time: 0.09s]


Model AudioUNet_v1() train loss: 0.119339, val loss: 0.126906


Training: 100%|██████████| 10/10 [00:03<00:00,  2.90it/s, Data load time: 0.13s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  6.80it/s, Data load time: 0.07s]


Model AudioUNet_v2() train loss: 0.173535, val loss: 0.176492
Epoch 3 of 10

Training: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s, Data load time: 0.11s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  4.92it/s, Data load time: 0.07s]


Model AudioUNet_v1() train loss: 0.126430, val loss: 0.121797


Training: 100%|██████████| 10/10 [00:03<00:00,  2.93it/s, Data load time: 0.17s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  6.09it/s, Data load time: 0.11s]


Model AudioUNet_v2() train loss: 0.170891, val loss: 0.161656
Epoch 4 of 10

Training: 100%|██████████| 10/10 [00:05<00:00,  1.71it/s, Data load time: 0.13s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  4.50it/s, Data load time: 0.07s]


Model AudioUNet_v1() train loss: 0.119810, val loss: 0.116844


Training: 100%|██████████| 10/10 [00:03<00:00,  2.89it/s, Data load time: 0.13s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  5.95it/s, Data load time: 0.10s]


Model AudioUNet_v2() train loss: 0.167487, val loss: 0.160093
Epoch 5 of 10

Training: 100%|██████████| 10/10 [00:06<00:00,  1.50it/s, Data load time: 0.13s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  3.73it/s, Data load time: 0.08s]


Model AudioUNet_v1() train loss: 0.119112, val loss: 0.119783


Training: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s, Data load time: 0.16s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  6.13it/s, Data load time: 0.08s]


Model AudioUNet_v2() train loss: 0.166631, val loss: 0.169351
Epoch 6 of 10

Training: 100%|██████████| 10/10 [00:07<00:00,  1.32it/s, Data load time: 0.13s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  3.78it/s, Data load time: 0.08s]


Model AudioUNet_v1() train loss: 0.118978, val loss: 0.113640


Training: 100%|██████████| 10/10 [00:04<00:00,  2.36it/s, Data load time: 0.13s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  5.91it/s, Data load time: 0.07s]


Model AudioUNet_v2() train loss: 0.170753, val loss: 0.167860
Epoch 7 of 10

Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s, Data load time: 0.13s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  3.50it/s, Data load time: 0.12s]


Model AudioUNet_v1() train loss: 0.115679, val loss: 0.120579


Training: 100%|██████████| 10/10 [00:04<00:00,  2.30it/s, Data load time: 0.16s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  5.19it/s, Data load time: 0.09s]


Model AudioUNet_v2() train loss: 0.174827, val loss: 0.171050
Epoch 8 of 10

Training: 100%|██████████| 10/10 [00:07<00:00,  1.39it/s, Data load time: 0.13s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  3.91it/s, Data load time: 0.08s]


Model AudioUNet_v1() train loss: 0.116030, val loss: 0.116033


Training: 100%|██████████| 10/10 [00:04<00:00,  2.22it/s, Data load time: 0.15s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  3.06it/s, Data load time: 0.15s]


Model AudioUNet_v2() train loss: 0.168197, val loss: 0.165837
Model AudioUNet_v2() stopped early at epoch 8
Epoch 9 of 10

Training: 100%|██████████| 10/10 [00:07<00:00,  1.36it/s, Data load time: 0.13s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  3.98it/s, Data load time: 0.08s]


Model AudioUNet_v1() train loss: 0.119091, val loss: 0.117866
Epoch 10 of 10

Training: 100%|██████████| 10/10 [00:07<00:00,  1.35it/s, Data load time: 0.14s]
Validating: 100%|██████████| 2/2 [00:00<00:00,  3.82it/s, Data load time: 0.09s]

Model AudioUNet_v1() train loss: 0.118281, val loss: 0.116731
Model AudioUNet_v1() stopped early at epoch 10





In [7]:
import os
src_file_name = 'model_tuning'


def snapshot_path(model):
    """Return the per-model output directory under ./loss_plots/<src_file_name>/."""
    return f'./loss_plots/{src_file_name}/{model.name}'


# Stop every training plotter, passing its model's directory as `save_dir`.
for model in models:
    plotters[model].stop(save_dir=snapshot_path(model))

In [8]:
# Report each model's lowest validation loss and the epoch it occurred in.
# The training loop stores `best_epoch` from a 0-based `range`, hence the +1.
# NOTE: the loop variable `model` stays bound after this loop (ordinary Python
# scoping), so later cells can observe it; avoid renaming it casually.
for model in models:
    print(f'Model {model.model} Best Validation Loss: {model.best_val_loss} at epoch {model.best_epoch+1}')

Model AudioUNet_v1() Best Validation Loss: 0.11364008486270905 at epoch 6
Model AudioUNet_v2() Best Validation Loss: 0.1600930020213127 at epoch 4


In [10]:
from tqdm import tqdm
import time
from importlib import reload
import src.training.helpers as th
reload(th)
from src.training.helpers import *
# Convert loss to accuracy for sound data
def compute_accuracy(y_pred, y_true, threshold=0.05) -> torch.Tensor:
    """
    Percentage of waveform samples predicted to within ``threshold`` of the target.

    A sample counts as correct when its absolute error is strictly below
    ``threshold``.

    Args:
        y_pred (torch.Tensor): Predicted waveform, e.g. shape (batch, 1, length).
        y_true (torch.Tensor): Ground-truth waveform with the same shape as ``y_pred``.
        threshold (float): Maximum absolute error (exclusive) to count as correct.

    Returns:
        torch.Tensor: 0-dim tensor holding the accuracy as a percentage (0-100).
        The value is deliberately left as a tensor because this function is
        passed to ``evaluate_model`` in place of a loss function; call
        ``.item()`` to get a Python float.

    Raises:
        ValueError: If ``y_pred`` and ``y_true`` have different shapes.
    """
    # Refuse to broadcast silently: mismatched shapes would give a meaningless score.
    if y_pred.shape != y_true.shape:
        raise ValueError(f"Shape mismatch: y_pred {y_pred.shape}, y_true {y_true.shape}")
    abs_error = torch.abs(y_pred - y_true)
    # Number of elements whose error is below the threshold.
    correct = (abs_error < threshold).float().sum()
    total = abs_error.numel()
    return (correct / total) * 100.0

# Score every trained model on the held-out test split. compute_accuracy is
# passed where evaluate_model expects a loss function, so the returned value
# is printed as an accuracy percentage.
# The loader uses the notebook-level `device` instead of `model.device`, which
# only worked because `model` was left over from a loop in an earlier cell.
test_loader = get_dataloader(test_set, learning_options['batch_size'], shuffle=False, device=device)
for model in models:
    # Each model gets its own plotter: the previous version stopped the single
    # shared plotter after the first model and then kept feeding it updates.
    test_plotter = TkLossPlotter(refresh_hz=10)
    test_plotter.start(title='Test Accuracy')
    acc = evaluate_model(model, test_loader, compute_accuracy, plot_updater=test_plotter.update)
    print(f'Test Accuracy: {acc:.2f}%')
    # NOTE(review): this is the same directory the training-loss plotter was
    # stopped with; confirm the two snapshots do not overwrite each other.
    test_plotter.stop(save_dir=snapshot_path(model))

Validating: 100%|██████████| 2/2 [00:01<00:00,  1.34it/s, Data load time: 0.54s]


Test Accuracy: 76.61%


Validating: 100%|██████████| 2/2 [00:00<00:00,  5.36it/s, Data load time: 0.11s]

Test Accuracy: 70.88%





In [11]:
import numpy as np
import torch
import IPython.display as ipd

def play_audio_samples(models, dataset, device, idx=0, sample_rate=16000):
    """
    Render audio players for one dataset item: the noisy input, the clean
    target, and the output of every model in ``models``.

    Args:
        models: Iterable of trained models; each must be callable on a batch,
            provide ``eval()`` and expose a ``name`` attribute.
        dataset: Indexable dataset whose items unpack as ``(noisy, clean)``.
        device: Torch device the noisy input is moved to before inference.
        idx: Index of the dataset item to play.
        sample_rate: Sample rate passed to the audio widgets.

    Side effects:
        Calls ``eval()`` on every model and leaves it in that mode.
    """
    def _with_channel_dim(wave):
        # 1-D waveforms get a leading channel axis -> (1, length).
        return wave.unsqueeze(0) if wave.ndim == 1 else wave

    def _peak_normalise(wave):
        # Scale to a peak of 1 for playback; all-zero audio is returned as is.
        wave = wave.astype(np.float32)
        peak = np.max(np.abs(wave))
        return wave / peak if peak > 0 else wave

    noisy, clean = dataset[idx]
    noisy, clean = _with_channel_dim(noisy), _with_channel_dim(clean)

    # Add the batch axis and run every model without tracking gradients.
    noisy_batch = noisy.unsqueeze(0).to(device)
    outputs = []
    with torch.no_grad():
        for model in models:
            model.eval()
            outputs.append(model(noisy_batch))

    # Drop the singleton axes and convert everything to NumPy before display.
    noisy_np = noisy.squeeze().cpu().numpy()
    clean_np = clean.squeeze().cpu().numpy()
    denoised_nps = [out.squeeze().cpu().numpy() for out in outputs]

    for caption, wave in (("Noisy input:", noisy_np), ("Clean target:", clean_np)):
        print(caption)
        ipd.display(ipd.Audio(_peak_normalise(wave), rate=sample_rate))

    print("Denoised output:")
    for model, denoised_np in zip(models, denoised_nps):
        print(f"Denoised output for {model.name}:")
        ipd.display(ipd.Audio(_peak_normalise(denoised_np), rate=sample_rate))

# Example usage:
play_audio_samples(models, test_set, device, idx=0, sample_rate=16000)



Noisy input:


Clean target:


Denoised output:
Denoised output for AudioUNet_v1():


Denoised output for AudioUNet_v2():


In [12]:
def si_sdr(x, s, eps=1e-8):
    """Scale-invariant signal-to-distortion ratio (SI-SDR) of ``x`` against ``s``, in dB.

    ``x`` is the estimate and ``s`` the reference. Dimension 1 is squeezed
    from both before the computation. The reference is rescaled by the
    least-squares projection coefficient of ``x`` onto ``s``; the ratio of the
    energy of that projection to the energy of the residual is returned on a
    10*log10 scale, reduced over the last dimension. ``eps`` guards the
    divisions and the logarithm against zero energy.

    Higher is better. The improvement metric (SI-SDRi) is the difference
    between this value for the denoised signal and for the noisy signal; that
    subtraction is done by the caller, not here.
    """
    estimate, reference = x.squeeze(1), s.squeeze(1)
    ref_energy = torch.sum(reference ** 2, dim=-1, keepdim=True) + eps
    scale = torch.sum(estimate * reference, dim=-1, keepdim=True) / ref_energy
    projection = scale * reference
    residual = estimate - projection
    signal_power = torch.sum(projection ** 2, dim=-1) + eps
    residual_power = torch.sum(residual ** 2, dim=-1) + eps
    return 10 * torch.log10(signal_power / residual_power)

def test_model(model, test_set, device):
    """Score ``model`` on the first item of ``test_set``.

    Only a single sample is evaluated: the first element of the first batch of
    a batch-size-1, unshuffled loader.

    Args:
        model: Callable model; must provide ``eval()``.
        test_set: Dataset handed to ``get_dataloader``.
        device: Torch device used for the loader and the tensors.

    Returns:
        tuple: ``(si_sdri, performance, progress)`` where
            * ``si_sdri`` is SI-SDR(pred, clean) - SI-SDR(noisy, clean) in dB,
            * ``performance`` is mean|pred - noisy| / mean|pred - clean|,
            * ``progress`` is 1 - mean|pred - clean| / mean|noisy - clean|.
        All three are tensors.
    """
    eps = 1e-8

    def distance(a, b):
        # Mean absolute difference between two waveforms.
        return (a - b).abs().mean()

    test_loader = get_dataloader(test_set, 1, shuffle=False, device=device)
    test_batch = next(iter(test_loader))
    noisy_sample, clean_sample = test_batch[0], test_batch[1]

    # Take the first item and restore the batch axis.
    noisy_tensor = noisy_sample[0].unsqueeze(0).to(device)
    clean_tensor = clean_sample[0].unsqueeze(0).to(device)

    # Evaluate in inference mode with no autograd graph, matching what
    # play_audio_samples does for the same models.
    model.eval()
    with torch.no_grad():
        pred = model(noisy_tensor)

    pred_error = distance(pred, clean_tensor)
    noisy_error = distance(noisy_tensor, clean_tensor)

    # Both ratios use eps so a perfect match cannot divide by zero.
    performance = distance(pred, noisy_tensor) / (pred_error + eps)
    progress = 1 - (pred_error / (noisy_error + eps))

    si_sdri = si_sdr(pred, clean_tensor) - si_sdr(noisy_tensor, clean_tensor)
    return si_sdri, performance, progress

# Print the single-sample metrics returned by test_model for each model.
for model in models:
    si_sdri, performance, progress = test_model(model, test_set, device)
    model_label = str(model)
    print(f"Model {model_label} SI-SDRi: {si_sdri.item():.2f} dB")
    print(f'Model {model_label} performance: {performance:.6f}, progress: {progress:.6f}')

Model AudioUNet_v1() SI-SDRi: -5.10 dB
Model AudioUNet_v1() performance: 1.006414, progress: 0.369844
Model AudioUNet_v2() SI-SDRi: -35.31 dB
Model AudioUNet_v2() performance: 0.786563, progress: 0.273745


In [43]:
import src.demonstrative.data_visualization as dv
# Reload before the star-import so the names bound below come from the
# refreshed module; reloading afterwards leaves them pointing at the old objects.
reload(dv)
from src.demonstrative.data_visualization import *

def visualize_diff_between_models(model1, model2, test_set, device):
    """Plot the outputs of two models for the first test item against each other.

    The noisy input of the first batch (batch size 1, unshuffled) is passed
    through both models and the two outputs are handed to
    ``visually_compare_audio_waveforms`` with ``model1`` as the front waveform
    and ``model2`` as the back waveform.

    Args:
        model1: Model whose output is drawn in front; must expose ``device``.
        model2: Model whose output is drawn behind.
        test_set: Dataset handed to ``get_dataloader``.
        device: Device used to build the loader. Inference itself runs on
            ``model1.device``, which replaces this value inside the function.
    """
    loader = get_dataloader(test_set, 1, shuffle=False, device=device)
    first_batch = next(iter(loader))
    # The clean target is unpacked but not used by this comparison.
    noisy_sample, clean_sample = first_batch[0], first_batch[1]

    device = model1.device
    front_waveform, back_waveform = [
        net(noisy_sample.to(device)).detach().cpu() for net in (model1, model2)
    ]

    front_label, back_label = str(model1), str(model2)
    visually_compare_audio_waveforms(
        front_waveform,
        back_waveform,
        sample_rate=16000,
        title=f'{front_label} vs {back_label}',
        front_label=front_label,
        back_label=back_label,
    )

def visualize_diff_between_model_and_data(model, test_set, device, data_label='Noisy', save_path=None):
    """Plot a model's output against the noisy or the clean waveform of the first test item.

    Args:
        model: Callable model; must expose ``device``.
        test_set: Dataset handed to ``get_dataloader``.
        device: Device used to build the loader. Inference runs on
            ``model.device``, which replaces this value inside the function.
        data_label: ``'Noisy'`` draws model(noisy) in front of the noisy input;
            ``'Clean'`` draws the clean target in front of model(clean).
        save_path: Forwarded to ``visually_compare_audio_waveforms``.

    Raises:
        ValueError: If ``data_label`` is neither ``'Noisy'`` nor ``'Clean'``.
    """
    loader = get_dataloader(test_set, 1, shuffle=False, device=device)
    first_batch = next(iter(loader))
    noisy_sample, clean_sample = first_batch[0], first_batch[1]

    device = model.device

    if data_label not in ('Noisy', 'Clean'):
        raise ValueError(f"Invalid data_label: {data_label}")

    model_label = str(model)
    if data_label == 'Noisy':
        front_waveform = model(noisy_sample.to(device)).detach().cpu()
        back_waveform = noisy_sample.to(device).detach().cpu()
        front_label, back_label = model_label, 'Noisy'
        title = f'{model_label} vs "Noisy"'
    else:
        front_waveform = clean_sample.to(device).detach().cpu()
        # NOTE(review): the model is fed the CLEAN sample here, so the plot
        # shows how the model alters clean audio. If the intent is to compare
        # the denoised output with the clean target, this should probably be
        # model(noisy_sample). Confirm.
        back_waveform = model(clean_sample.to(device)).detach().cpu()
        front_label, back_label = 'Clean', model_label
        title = f'"Clean" vs {model_label}'

    visually_compare_audio_waveforms(
        front_waveform,
        back_waveform,
        sample_rate=16000,
        title=title,
        front_label=front_label,
        back_label=back_label,
        save_path=save_path,
    )


In [44]:
# save_path = 'visualizations/_comparison.png'
# visualize_diff_between_models(models[0], models[1], test_set, device, save)

In [47]:
import os
# Create the output folder if needed. exist_ok=True makes this safe to re-run
# and avoids the gap between checking for the folder and creating it.
os.makedirs('./visualizations', exist_ok=True)

# Save both comparison plots (against the noisy input and against the clean
# target) for the first model.
model = models[0]
save_path = f'visualizations/{str(model)}_vs_noisy_comparison.png'
visualize_diff_between_model_and_data(model, test_set, device, data_label='Noisy', save_path=save_path)
save_path = f'visualizations/{str(model)}_vs_clean_comparison.png'
visualize_diff_between_model_and_data(model, test_set, device, data_label='Clean', save_path=save_path)



In [48]:
# Save both comparison plots for the second model: first against the noisy
# input, then against the clean target.
model = models[1]
for data_label in ('Noisy', 'Clean'):
    save_path = f'visualizations/{str(model)}_vs_{data_label.lower()}_comparison.png'
    visualize_diff_between_model_and_data(model, test_set, device, data_label=data_label, save_path=save_path)