1. Offline augmentation (przed treningiem)
   - zbalansowanie klas
   - augmentacja waveformów
   - augmentacja spektrogramów
   - wygenerowanie przykładów dla klasy silence
   - wygenerowanie spektrogramów

In [4]:
import random
import torchaudio
from WaveformAugmentations import WaveformAugmentations
from IPython.display import Audio
import matplotlib.pyplot as plt
import os
import torchaudio.transforms as T
import torch
from tqdm import tqdm

SAMPLE_RATE = 16000

## Waveform Augmentation

In [5]:
def load_background_waveforms(background_noise_dir='data/train/audio/_background_noise_'):
    background_waveforms = []
    for file in os.listdir(background_noise_dir):
        if file.endswith('.wav'):
            try:
                path = os.path.join(background_noise_dir, file)
                waveform, sr = torchaudio.load(path)
                waveform = waveform.squeeze(0)  # remove channel dimension
                if sr != SAMPLE_RATE:
                    resampler = T.Resample(orig_freq=sr, new_freq=SAMPLE_RATE)
                    waveform = resampler(waveform)
                background_waveforms.append(waveform)
                print(f"Loaded background: {file}")
            except Exception as e:
                print(f"Error loading {file}: {e}")
    return background_waveforms

def create_augmentation_pipeline(sample_rate, background_waveforms=None, rir_waveforms=None, p=0.5):
    """Create a function that applies random augmentations with probability p"""
    def augment_waveform(waveform):
        augmented = waveform.clone()
        
        if random.random() < p:
            augmented = WaveformAugmentations.time_shift(augmented)
        
        if random.random() < p:
            augmented = WaveformAugmentations.add_noise(augmented, noise_level=random.uniform(0.001, 0.005))
        
        if random.random() < p:
            augmented = WaveformAugmentations.pitch_shift(augmented, sample_rate)
        
        if random.random() < p:
            augmented = WaveformAugmentations.volume_control(augmented)
        
        if random.random() < p*0.6:
            augmented = WaveformAugmentations.speed_change(augmented, sample_rate)
            
        if random.random() < p*0.6:
            augmented = WaveformAugmentations.reverb(augmented, sample_rate)
        
        if background_waveforms and random.random() < p*0.8:
            bg_waveform = random.choice(background_waveforms)
            augmented = WaveformAugmentations.mix_background(augmented, bg_waveform)
        
        if rir_waveforms and random.random() < p*0.6:
            rir_waveform = random.choice(rir_waveforms)
            augmented = WaveformAugmentations.convolution_reverb(augmented, rir_waveform)
            
        return augmented
    return augment_waveform

Test augmentation pipeline

In [6]:
background_waveforms = load_background_waveforms()
augmentation_pipeline = create_augmentation_pipeline(sample_rate=SAMPLE_RATE, p=0.2)

wav_path = r"data\train\audio\yes\0c2ca723_nohash_0.wav"
waveform, sample_rate = torchaudio.load(wav_path)
waveform = waveform.squeeze(0)  # remove channel dimension
Audio(waveform.numpy(), rate=sample_rate)

Loaded background: doing_the_dishes.wav
Loaded background: dude_miaowing.wav
Loaded background: exercise_bike.wav
Loaded background: pink_noise.wav
Loaded background: running_tap.wav
Loaded background: white_noise.wav


In [7]:
augmented_waveform = augmentation_pipeline(waveform)
Audio(augmented_waveform.numpy(), rate=sample_rate)

In [8]:
def generate_silence_class():
    pass

def balance_dataset():
    pass

In [9]:
def preprocess_with_augmentation(
    audio_dir, 
    output_dir,
    sample_rate=16000,
    n_mels=64,
    n_fft=400,
    hop_length=200,
    augment_probability=0.5,
    augmented_copies=1,  # Number of augmented copies to generate per original sample
):
    os.makedirs(output_dir, exist_ok=True)
    
    # Load background and RIR samples for augmentation
    background_waveforms = load_background_waveforms()
    
    # Create augmentation pipeline
    augment_fn = create_augmentation_pipeline(sample_rate, background_waveforms, p=augment_probability)
    
    # Create transforms
    mel_spectrogram = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    db_transform = T.AmplitudeToDB()
    
    # Get all audio files
    audio_files = []
    for root, _, files in os.walk(audio_dir):
        for file in files:
            if file.endswith('.wav'):
                audio_files.append(os.path.relpath(os.path.join(root, file), audio_dir))
    
    print(f"Processing {len(audio_files)} audio files...")
    
    # Track class distribution
    class_counts = {}
    for audio_file in tqdm(audio_files, desc="Processing audio files"):
        # Extract class name from the path
        class_name = os.path.dirname(audio_file)
        class_counts[class_name] = class_counts.get(class_name, 0) + 1
        
        # Load audio
        audio_path = os.path.join(audio_dir, audio_file)
        waveform, orig_sr = torchaudio.load(audio_path)
        waveform = waveform.squeeze(0)  # remove channel dimension
        
        # Resample if necessary
        if orig_sr != sample_rate:
            resampler = T.Resample(orig_freq=orig_sr, new_freq=sample_rate)
            waveform = resampler(waveform)
        
        # Save original spectrogram
        spec = mel_spectrogram(waveform)
        spec = db_transform(spec)
        output_path = os.path.join(output_dir, os.path.splitext(audio_file)[0] + '.pt')
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        torch.save(spec, output_path)
        
        # Generate augmented versions
        for i in range(augmented_copies):
            augmented = augment_fn(waveform)
            
            # Generate spectrogram for augmented waveform
            aug_spec = mel_spectrogram(augmented)
            aug_spec = db_transform(aug_spec)
            
            # Save augmented spectrogram with suffix
            aug_output_path = os.path.join(
                output_dir, 
                os.path.splitext(audio_file)[0] + f'_aug{i+1}.pt'
            )
            os.makedirs(os.path.dirname(aug_output_path), exist_ok=True)
            torch.save(aug_spec, aug_output_path)
    
    print(f"All spectrograms saved to {output_dir}")
    print(f"Class distribution: {class_counts}")
    
    # Plot class distribution
    plt.figure(figsize=(12, 6))
    plt.bar(class_counts.keys(), class_counts.values())
    plt.xticks(rotation=45, ha='right')
    plt.title('Class Distribution')
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.tight_layout()
    plt.show()
    
    return class_counts

In [None]:
# Run the augmented preprocessing with 2 augmented copies per original sample
class_counts = preprocess_with_augmentation(
    audio_dir="data/train/audio", 
    output_dir="data/train/spectrograms",
    augment_probability=0.1,  
    augmented_copies=1,
    sample_rate=SAMPLE_RATE,
    n_mels=64,
    n_fft=400,
    hop_length=200,
)

Loaded background: doing_the_dishes.wav
Loaded background: dude_miaowing.wav
Loaded background: exercise_bike.wav
Loaded background: pink_noise.wav
Loaded background: running_tap.wav
Loaded background: white_noise.wav
Processing 64727 audio files...


Processing audio files:   1%|          | 433/64727 [01:31<3:45:27,  4.75it/s] 


KeyboardInterrupt: 