In [1]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))
os.chdir(os.path.join(os.getcwd(), '..'))
from sklearn.model_selection import KFold, train_test_split
from src.training.loss_plots import TkLossPlotter
from src.data.dataset import AudioDataset
import torch
from torch.utils.data import Subset

Data directory: D:/clean_noisy_sound_dataset/


In [2]:
DEBUG = True
MAX_SAMPLES = 512
DATA_PREPROCESSING = False

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Manifest
Manifest is a json file that contains the paths to the clean and noisy audio files
This is what the dataset will use to load the data in chunks rather than loading 
the entire dataset into memory

In [4]:
import os
from importlib import reload
import src.data.manifest
reload(src.data.manifest)
from src.data.manifest import create_manifest, load_manifest
from src.data.dataset import AudioDataset

manifest_file_name = 'src/data/manifest_trainset_28spk_wav.json'
data_dir = 'D:/clean_noisy_sound_dataset/'
sub_dir_context = 'trainset_28spk_wav' # part of the path that noisy/clean subdirs have in common
def get_manifest(manifest_file_name):
    if not os.path.exists(manifest_file_name):
        manifest = create_manifest(manifest_file_name, data_dir, sub_dir_context)
    else:
        try: 
            manifest = load_manifest(manifest_file_name)
            if not os.path.exists(manifest[0]['noisy_path']) or not os.path.exists(manifest[0]['clean_path']):
                raise Exception('Manifest file exists but paths do not exist. Attempting to create manifest...')
        except Exception as e:
            print(f"Error loading manifest: {e}")
            manifest = create_manifest(manifest_file_name, data_dir, sub_dir_context)
            print('len(manifest):', len(manifest))
    return manifest

if DATA_PREPROCESSING:
    dataset = torch.load('src/data/processed_dataset.pt')
else:
    manifest = get_manifest(manifest_file_name)
    # print(len(manifest))
    dataset = AudioDataset(
        manifest=manifest,
        transform=None,
        segment_ms=2000,
        sample_rate=16000,
        mono=True,
        max_samples=MAX_SAMPLES if DEBUG else None
    )

# Split test set from train set
indices = torch.arange(len(dataset))
train_val_idx, test_idx = train_test_split(
    indices, test_size=0.15, random_state=42, shuffle=True
)
test_set = Subset(dataset, test_idx)
# print('shape of test_set', test_set[0][0].shape)

Data directory: D:/clean_noisy_sound_dataset/


In [5]:
print(dataset[0][0].shape, dataset[0][1].shape)

torch.Size([1, 32000]) torch.Size([1, 32000])




In [6]:
from importlib import reload
import traceback
import src.models.model_classes
reload(src.models.model_classes)
from src.models.model_classes import AudioUNet_v1
# from torch.optim.lr_scheduler import CosineAnnealingLR
from src.training.loss_functions import si_sdr_loss, gain_loss_rms_db
from src.training.helpers import denoise_loss
reload(src.training.helpers)
from src.training.helpers import *

RUN_KFOLD = False
src_file_name = 'model_tuning'
snapshot_path = lambda model: f'./loss_plots/{src_file_name}/{str(model)}'

learning_options = {
    'batch_size': 16,
    'learning_rate': 1e-4,
    'epochs': 5,
    'patience': 4,
}

patience = learning_options['patience']

if RUN_KFOLD:
    n_splits = 5
    print('-'*50)
    print(f'Running K-Fold with {n_splits} splits')
    print('-'*50)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_results = {
        'train_losses': [],
        'val_losses': []
    }
    iterator = enumerate(kf.split(train_val_idx))
else:
    # these are replacing kfold
    split_idx = int(len(train_val_idx) * 0.8)
    iterator = enumerate([[train_val_idx[:split_idx], train_val_idx[split_idx:]]])
    
l1_loss = torch.nn.functional.l1_loss
mse_loss = torch.nn.functional.mse_loss
# loss_fn = lambda pred, target: 0.5 * l1_loss(pred, target)  + 0.5 * mse_loss(pred, target)
loss_fn = denoise_loss
PLOT = DEBUG
for fold, (train_index, val_index) in iterator:
    if RUN_KFOLD:
        fold_script = f'Fold {fold+1} of {kf.n_splits}'
        print(f'{fold_script}')
        
    train_set = Subset(dataset, train_index)
    val_set = Subset(dataset, val_index)
    
    print('Getting dataloaders...')
    loss_fns = [si_sdr_loss, gain_loss_rms_db]
    weights = [1.0, 10.0]
    model = AudioUNet_v1(transforms=None, device=device, loss_fns=loss_fns, weights=weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_options['learning_rate'], weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    epochs = learning_options['epochs']
    # scheduler = CosineAnnealingLR(optimizer, T_max=epochs-3, eta_min=1e-6)

    train_loader = get_dataloader(train_set, learning_options['batch_size'], shuffle=True, device=device)
    val_loader = get_dataloader(val_set, learning_options['batch_size'], shuffle=False, device=device)

    if PLOT:
        plotter = TkLossPlotter(refresh_hz=10)
        plotter_title = f'{fold_script} Training & Validation Loss' if RUN_KFOLD else 'Training & Validation Loss'
        plotter.start(title=plotter_title)

    best_val_loss = float('inf')
    best_epoch, patience_counter = 0, 0
    
    try:
        for epoch in range(epochs):
            epoch_script = f'Epoch {epoch+1} of {epochs}'
            print(f'\r{epoch_script}', end='')
            train_loss = train_model(model, train_loader, optimizer, loss_fn)
            val_loss = evaluate_model(model, val_loader, loss_fn)
            
            # Update learning rate
            scheduler.step()

            print(f'Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:4.6f}, Val Loss: {val_loss:4.6f}')
            # Save results
            if RUN_KFOLD:
                fold_results['train_losses'].append(train_loss)
                fold_results['val_losses'].append(val_loss)

            if PLOT:
                # Update plot
                plotter.update((train_loss, val_loss))
            
            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch
                patience_counter = 0
            else:
                patience_counter += 1
            
            if patience_counter >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break
        print(f'Best Validation Loss: {best_val_loss} at epoch {best_epoch+1}')
    except Exception as e:
        traceback.print_exc()
        print(e)
    finally:
        if PLOT:
            plotter.stop(save_dir=snapshot_path(model))


Getting dataloaders...
Epoch 1 of 5

Training: 100%|██████████| 21/21 [00:12<00:00,  1.63it/s, Data load time: 0.21s]
Validating: 100%|██████████| 5/5 [00:01<00:00,  4.08it/s, Data load time: 0.07s]


Epoch 1/5 | Train Loss: 58.709553, Val Loss: 0.210978
Epoch 2 of 5

Training: 100%|██████████| 21/21 [00:11<00:00,  1.78it/s, Data load time: 0.10s]
Validating: 100%|██████████| 5/5 [00:01<00:00,  4.69it/s, Data load time: 0.08s]


Epoch 2/5 | Train Loss: 13.627794, Val Loss: 0.160731
Epoch 3 of 5

Training: 100%|██████████| 21/21 [00:12<00:00,  1.67it/s, Data load time: 0.11s]
Validating: 100%|██████████| 5/5 [00:01<00:00,  4.82it/s, Data load time: 0.06s]


Epoch 3/5 | Train Loss: 8.764475, Val Loss: 0.147916
Epoch 4 of 5

Training: 100%|██████████| 21/21 [00:12<00:00,  1.71it/s, Data load time: 0.12s]
Validating: 100%|██████████| 5/5 [00:01<00:00,  4.33it/s, Data load time: 0.07s]


Epoch 4/5 | Train Loss: 8.074298, Val Loss: 0.186122
Epoch 5 of 5

Training: 100%|██████████| 21/21 [00:13<00:00,  1.54it/s, Data load time: 0.12s]
Validating: 100%|██████████| 5/5 [00:01<00:00,  3.70it/s, Data load time: 0.08s]


Epoch 5/5 | Train Loss: 8.693830, Val Loss: 0.159091
Best Validation Loss: 0.14791629910469056 at epoch 3


In [7]:
if RUN_KFOLD:
    print("\nK-Fold Results Summary:")
    for i, (train_loss, val_loss) in enumerate(zip(fold_results['train_losses'], fold_results['val_losses'])):
        print(f"Epoch {i+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")


In [8]:
from tqdm import tqdm
reload(src.training.helpers)
from src.training.helpers import *
import time
# Convert loss to accuracy for sound data
def compute_accuracy(y_pred, y_true, threshold=0.05):
        """
        Computes accuracy for sound data by comparing predicted and true waveforms.
        Accuracy is defined as the percentage of samples where the absolute error is below a threshold.
        Args:
            y_pred (torch.Tensor): Predicted waveform, shape (batch, 1, length)
            y_true (torch.Tensor): Ground truth waveform, shape (batch, 1, length)
            threshold (float): Maximum absolute error to count as correct
        Returns:
            float: Accuracy as a percentage (0-100)
        """
        # Ensure shapes are compatible
        if y_pred.shape != y_true.shape:
            raise ValueError(f"Shape mismatch: y_pred {y_pred.shape}, y_true {y_true.shape}")
        # Compute absolute error
        abs_error = torch.abs(y_pred - y_true)
        
        # Count number of elements below threshold
        correct = (abs_error < threshold).float().sum()
        total = abs_error.numel()
        accuracy = (correct / total) * 100.0
        return accuracy

test_loader = get_dataloader(test_set, learning_options['batch_size'], shuffle=False, device=model.device)
# test_plotter = TkLossPlotter(refresh_hz=10)
# test_plotter.start(title='Test Accuracy')
acc = evaluate_model(model, test_loader, loss_fn=compute_accuracy)#, plot_updater=test_plotter.update)
print(f'Test Accuracy: {acc:.2f}%')
# test_plotter.stop(save_dir=snapshot_path(model))


Validating: 100%|██████████| 4/4 [00:01<00:00,  3.33it/s, Data load time: 0.11s]

Test Accuracy: 63.14%





In [9]:
from cv2 import norm
import numpy as np
import torch
import IPython.display as ipd
from src.data.preprocessing_helpers import undo_button
def norm_audio(x):
    x = x.astype(np.float32)
    maxv = np.max(np.abs(x))
    return x / maxv if maxv > 0 else x

def play_audio_samples(model, dataset, device, idx=0, sample_rate=16000):
    """
    Play an audible test sample: noisy input, clean target, and denoised output.
    Args:
        model: Trained model.
        dataset: Dataset object (should return (noisy, clean) pairs).
        device: Torch device.
        idx: Index of the sample to play.
        sample_rate: Audio sample rate for playback.
    """
    # Get a sample
    noisy, clean = dataset[idx]
    # Ensure shape is (1, length)
    if noisy.ndim == 1:
        noisy = noisy.unsqueeze(0)
    if clean.ndim == 1:
        clean = clean.unsqueeze(0)
    # Add batch dimension
    noisy = noisy.unsqueeze(0).to(device)
    clean = clean.unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        denoised = model(noisy.reshape(1, 1, -1))
    # Remove batch and channel dimensions
    noisy_np = noisy.squeeze().cpu().numpy()
    clean_np = clean.squeeze().cpu().numpy()
    denoised_np = denoised.squeeze().cpu().numpy()
    # Normalize for playback (avoid clipping)


    
    print("Noisy input:")
    display(ipd.Audio(norm_audio(noisy_np), rate=sample_rate))
    print("Clean target:")
    display(ipd.Audio(norm_audio(clean_np), rate=sample_rate))
    print("Denoised output:")
    display(ipd.Audio(norm_audio(denoised_np), rate=sample_rate))

# Example usage:
play_audio_samples(model, test_set, device, idx=0, sample_rate=16000)


Noisy input:


Clean target:


Denoised output:


In [None]:
# # Compare to Facebook's Demucs

# # Download and run Demucs on the same sample for comparison
# import subprocess
# import torchaudio
# import tempfile
# import os

# def run_demucs_on_sample(noisy_np, sample_rate=16000):
#     """
#     Run Facebook's Demucs on a numpy audio array and return the denoised output.
#     Requires Demucs to be installed (`pip install demucs`).
#     """
#     # Save noisy audio to a temporary wav file
#     with tempfile.TemporaryDirectory() as tmpdir:
#         noisy_path = os.path.join(tmpdir, "noisy.wav")
#         outdir = os.path.join(tmpdir, "demucs_out")
#         torchaudio.save(noisy_path, torch.tensor(noisy_np).unsqueeze(0), sample_rate)
#         # Run Demucs CLI
#         # Use --two-stems=vocals to get vocals only, or --mp3 for mp3 output, but here we want wav
#         # Use --shifts=0 for deterministic output
#         cmd = [
#             "demucs",
#             "--two-stems=vocals",
#             "--shifts", "0",
#             "--out", outdir,
#             noisy_path
#         ]
#         try:
#             subprocess.run(cmd, check=True, capture_output=True)
#         except Exception as e:
#             print("Error running Demucs:", e)
#             return None
#         # Demucs output is in outdir/demucs_noisy/vocals.wav
#         # Find the output file
#         demucs_dir = os.path.join(outdir, "demucs_noisy")
#         vocals_path = os.path.join(demucs_dir, "vocals.wav")
#         if not os.path.exists(vocals_path):
#             # fallback: try "noisy/vocals.wav"
#             vocals_path = os.path.join(outdir, "noisy", "vocals.wav")
#         if not os.path.exists(vocals_path):
#             print("Demucs output not found.")
#             return None
#         # Load the denoised audio
#         waveform, sr = torchaudio.load(vocals_path)
#         return waveform.squeeze().cpu().numpy(), sr

# # Run Demucs on the same test sample
# print("Running Demucs on the same noisy sample for comparison...")
# noisy_np = test_set[0][0].squeeze().cpu().numpy()
# demucs_result = run_demucs_on_sample(noisy_np, sample_rate=16000)
# if demucs_result is not None:
#     demucs_np, demucs_sr = demucs_result
#     print("Demucs output:")
#     display(ipd.Audio(norm_audio(demucs_np), rate=demucs_sr))
# else:
#     print("Demucs output unavailable. Make sure Demucs is installed and in your PATH.")




Running Demucs on the same noisy sample for comparison...
Demucs output not found.
Demucs output unavailable. Make sure Demucs is installed and in your PATH.
