1. Offline augmentation (przed treningiem)
   - zbalansowanie klas
   - augmentacja waveformów
   - augmentacja spektrogramów
   - wygenerowanie przykładów dla klasy silence
   - wygenerowanie spektrogramów

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.transforms as T

class WaveformAugmentations:
    """Collection of waveform-level audio augmentation techniques.

    Every method is a ``@staticmethod`` that takes a waveform tensor and
    returns the augmented tensor.
    """

    @staticmethod
    def time_shift(waveform, shift_limit=0.4):
        """Shift the audio along the time (last) axis, zero-filling the gap.

        Args:
            waveform: Tensor whose last dimension is time, e.g.
                ``(channels, samples)``. Any number of leading dimensions
                (including none) is accepted.
            shift_limit: Maximum shift expressed as a fraction of the number
                of samples. The actual fraction is drawn uniformly from
                ``[-shift_limit, shift_limit]`` using ``np.random``.

        Returns:
            A tensor with the same shape as ``waveform``. A positive draw
            moves the content towards the start (zeros appended at the end);
            a negative draw moves it towards the end (zeros prepended).
        """
        num_samples = waveform.shape[-1]
        shift = int(np.random.uniform(-shift_limit, shift_limit) * num_samples)
        # Clamp so that shift_limit > 1 cannot produce an output that is
        # longer than the input.
        shift = max(-num_samples, min(num_samples, shift))
        if shift == 0:
            return waveform

        pad_len = abs(shift)
        # Build the padding from the input itself so that channel count,
        # dtype and device always match (a hard-coded single-channel pad
        # breaks torch.cat for multi-channel audio).
        padding = waveform.new_zeros((*waveform.shape[:-1], pad_len))
        if shift > 0:
            return torch.cat([waveform[..., pad_len:], padding], dim=-1)
        return torch.cat(
            [padding, waveform[..., :num_samples - pad_len]], dim=-1
        )

    @staticmethod
    def add_noise(waveform, noise_level=0.005):
        """Add zero-mean Gaussian noise scaled by ``noise_level``.

        Args:
            waveform: Input tensor.
            noise_level: Multiplier applied to standard-normal noise drawn
                with ``torch.randn_like``.

        Returns:
            ``waveform + noise`` as a new tensor of the same shape.
        """
        noise = torch.randn_like(waveform) * noise_level
        return waveform + noise

    @staticmethod
    def pitch_shift(waveform, sample_rate, pitch_shift_limit=5):
        """Shift the pitch of the audio by a random whole number of steps.

        Args:
            waveform: Input waveform tensor.
            sample_rate: Sample rate of ``waveform``, forwarded to
                ``T.PitchShift``.
            pitch_shift_limit: The number of steps is drawn uniformly from
                the integers ``-pitch_shift_limit .. pitch_shift_limit``
                (both ends inclusive).

        Returns:
            The output of ``T.PitchShift`` applied to ``waveform``.
        """
        # np.random.randint returns a numpy integer; hand the transform a
        # plain Python int.
        n_steps = int(
            np.random.randint(-pitch_shift_limit, pitch_shift_limit + 1)
        )
        pitch_shift_transform = T.PitchShift(sample_rate, n_steps=n_steps)
        return pitch_shift_transform(waveform)

class SpectrogramAugmentations:
    """Spectrogram-level augmentations built on ``torchaudio.transforms``."""

    @staticmethod
    def time_masking(spectrogram, time_mask_param=10):
        """Run ``T.TimeMasking(time_mask_param)`` on the spectrogram.

        Args:
            spectrogram: Spectrogram tensor handed to the transform.
            time_mask_param: Value forwarded to ``T.TimeMasking``.

        Returns:
            The masked spectrogram produced by the transform.
        """
        masker = T.TimeMasking(time_mask_param)
        return masker(spectrogram)

    @staticmethod
    def freq_masking(spectrogram, freq_mask_param=10):
        """Run ``T.FrequencyMasking(freq_mask_param)`` on the spectrogram.

        Args:
            spectrogram: Spectrogram tensor handed to the transform.
            freq_mask_param: Value forwarded to ``T.FrequencyMasking``.

        Returns:
            The masked spectrogram produced by the transform.
        """
        masker = T.FrequencyMasking(freq_mask_param)
        return masker(spectrogram)

    @staticmethod
    def time_stretch(spectrogram, stretch_factor=0.8):
        """Stretch the spectrogram in time with ``T.TimeStretch``.

        Args:
            spectrogram: Spectrogram tensor; ``spectrogram.shape[1]`` is
                passed to the transform as ``n_freq``.
            stretch_factor: Forwarded to ``T.TimeStretch`` as ``fixed_rate``.

        Returns:
            The stretched spectrogram, with the temporary leading dimension
            removed again.
        """
        # NOTE(review): torchaudio documents TimeStretch for complex-valued
        # spectrograms -- confirm what callers pass in here.
        stretcher = T.TimeStretch(
            n_freq=spectrogram.shape[1],
            fixed_rate=stretch_factor,
        )
        # The transform is fed a tensor with one extra leading dimension,
        # which is squeezed away from the result.
        batched = spectrogram.unsqueeze(0)
        stretched = stretcher(batched)
        return stretched.squeeze(0)


In [None]:
import os
import torch
import torchaudio.transforms as T
import soundfile as sf
from tqdm import tqdm

def preprocess_and_save_spectrograms(
    audio_dir,
    output_dir,
    sample_rate=16000,
    n_mels=64,
    n_fft=400,
    hop_length=200,
    waveform_transform=None,
    spectrogram_transform=None,
):
    """Convert every WAV file under ``audio_dir`` into a saved mel spectrogram.

    The directory tree is walked recursively. Each file is loaded with
    ``soundfile``, optionally transformed, down-mixed to mono, resampled to
    ``sample_rate``, converted to a dB-scaled mel spectrogram, optionally
    transformed again, and written with ``torch.save`` to ``output_dir``
    under the same relative path with a ``.pt`` extension.

    Args:
        audio_dir: Root directory searched for ``.wav`` files (the extension
            is matched case-insensitively).
        output_dir: Root directory for the ``.pt`` files; created if missing.
        sample_rate: Target sample rate; files with a different rate are
            resampled.
        n_mels: Forwarded to ``T.MelSpectrogram``.
        n_fft: Forwarded to ``T.MelSpectrogram``.
        hop_length: Forwarded to ``T.MelSpectrogram``.
        waveform_transform: Optional callable applied to the
            ``(channels, samples)`` waveform right after loading, i.e.
            before the mono down-mix and before resampling, so it sees the
            file's original channel count and sample rate.
        spectrogram_transform: Optional callable applied to the dB-scaled
            mel spectrogram just before it is saved.

    Returns:
        None. The results are the files written below ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    mel_spectrogram = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    db_transform = T.AmplitudeToDB()

    print(f"Loading audio files from {audio_dir}...")
    # Sorted so that the processing order does not depend on the order in
    # which the file system happens to list directory entries.
    audio_files = sorted(
        os.path.relpath(os.path.join(root, file), audio_dir)
        for root, _, files in os.walk(audio_dir)
        for file in files
        if file.lower().endswith('.wav')
    )

    # One resampler per distinct source rate instead of one per file.
    resamplers = {}

    print(f"Processing {len(audio_files)} audio files...")
    for audio_file in tqdm(audio_files, desc="Processing audio files"):
        # Load audio. always_2d=True makes soundfile return
        # (frames, channels) for mono files too, so a single transpose
        # yields the (channels, samples) layout the rest of the pipeline
        # expects. Without it a stereo file ends up as
        # (1, frames, channels) and the mono down-mix below never triggers.
        audio_path = os.path.join(audio_dir, audio_file)
        data, orig_sr = sf.read(audio_path, always_2d=True)
        waveform = torch.tensor(data.T, dtype=torch.float32)

        # "is not None" rather than truthiness: a callable that defines
        # __len__ (e.g. an empty container module) may be falsy.
        if waveform_transform is not None:
            waveform = waveform_transform(waveform)

        # Convert to mono if the file has several channels
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample if necessary
        if orig_sr != sample_rate:
            resampler = resamplers.get(orig_sr)
            if resampler is None:
                resampler = T.Resample(orig_freq=orig_sr, new_freq=sample_rate)
                resamplers[orig_sr] = resampler
            waveform = resampler(waveform)

        spec = mel_spectrogram(waveform)
        spec = db_transform(spec)

        if spectrogram_transform is not None:
            spec = spectrogram_transform(spec)

        # Mirror the input's relative path below output_dir.
        output_path = os.path.join(output_dir, os.path.splitext(audio_file)[0] + '.pt')
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        torch.save(spec, output_path)
    print(f"Spectrograms saved to {output_dir}")

# Generate spectrograms for the training audio using the default settings.
preprocess_and_save_spectrograms(
    audio_dir="data/train/audio",
    output_dir="data/train/spectrograms",
)


In [None]:
def generate_silence_class():
    """Placeholder that is not implemented yet: does nothing, returns None.

    Presumably meant to generate examples for the "silence" class -- TODO
    confirm and implement.
    """
    return None

def balance_dataset():
    """Placeholder that is not implemented yet: does nothing, returns None.

    Presumably meant to balance the classes of the dataset -- TODO confirm
    and implement.
    """
    return None


In [None]:
import os
import torch
import torchaudio
import matplotlib.pyplot as plt
import pandas as pd

from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import soundfile as sf
import torchaudio.transforms as T


# Module-level transforms for spectrogram generation: a 64-band mel
# spectrogram configured for 16 kHz audio, and an amplitude-to-dB conversion.
mel_transform = T.MelSpectrogram(
    sample_rate=16000,
    n_mels=64,
)
db_transform = T.AmplitudeToDB()

class AudioSpectrogramsDataset(Dataset):
    """Placeholder ``Dataset`` subclass; it defines no behaviour of its own yet."""