In [12]:
from collections import defaultdict
from math import sqrt

import torch
from torch.utils.data import DataLoader, Subset, random_split

from AudioDataset import AudioDataset, pad_collate_fn
from PerClassNormalize import PerClassNormalize

# Create full dataset with set number of samples in each class

In [13]:
# Parameters handed to AudioDataset in the cells below.

# Directory holding the training audio (replace with your own path).
root_dir = r"C:\Users\monik\OneDrive\Pulpit\Projekt Warsztaty\data\train"

# Directory used as the spectrogram cache (replace with your own path).
cache_dir = r"C:\Users\monik\OneDrive\Pulpit\Projekt Warsztaty\cache_spectrograms"

# Classes to keep; append "silence" if an experiment needs it.
allowed_classes = [
    "yes", "no",
    "up", "down",
    "left", "right",
    "on", "off",
    "stop", "go",
]

# Desired number of samples in each class.
max_samples_per_class = 100

In [14]:
# Create the full, un-normalized dataset. No transform is applied here so the
# raw spectrogram values can be used to compute normalization statistics.
dataset = AudioDataset(
    root_dir            = root_dir,
    cache_dir           = cache_dir,
    export_dir          = None,
    preprocess          = True,
    transform           = None,
    allowed_classes     = allowed_classes,
    save_spectrograms   = False,
    max_samples_per_class= max_samples_per_class
)

# Compute split sizes; the validation split takes whatever the rounded-down
# training share leaves over, so the two sizes always add up to len(dataset).
train_frac = 0.8
train_size = int(train_frac * len(dataset))
val_size   = len(dataset) - train_size

# Randomly split into two subsets. The generator is seeded explicitly so that
# Restart Kernel -> Run All reproduces the same partition (and therefore the
# same normalization statistics) on every run.
split_seed = 42
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Check
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")


Train samples: 880
Validation samples: 220


# Find mean and std for each class for normalization

In [15]:
# Running per-class totals over the training split only: sum of values, sum
# of squared values, and number of spectrogram elements.
sums   = defaultdict(float)
sqs    = defaultdict(float)
counts = defaultdict(int)

for spec, label in train_dataset:
    # spec shape (per the original note): (C, H, W) or (1, freq_bins, time_steps)
    # Coerce the label to a plain int so the dictionaries are keyed by value.
    # NOTE(review): assumes labels are integer class indices — confirm against AudioDataset.
    label = int(label)
    n_elems = spec.numel()
    sums[label]   += spec.sum().item()
    sqs[label]    += (spec**2).sum().item()
    counts[label] += n_elems

# now build per-class means & stds
class_means = {}
class_stds  = {}

for lbl in counts:
    mu = sums[lbl] / counts[lbl]
    # Var[x] = E[x^2] - E[x]^2. Rounding can push this a hair below zero for a
    # (nearly) constant class, which would make sqrt() raise; clamp at zero.
    var = max((sqs[lbl] / counts[lbl]) - mu**2, 0.0)
    class_means[lbl] = mu
    class_stds[lbl]  = sqrt(var)

# Example: map from label index -> class name
# NOTE(review): relies on the iteration order of dataset.allowed matching the
# label indices produced by the dataset — confirm in AudioDataset.
idx2class = {i:c for i,c in enumerate(dataset.allowed)}

print("Per-class means & stds:")
for idx, cls in idx2class.items():
    if idx not in class_means:
        # A class can be absent from the training split (e.g. with a very
        # small max_samples_per_class); report it instead of raising KeyError.
        print(f"  {cls:5s} → no training samples")
        continue
    print(f"  {cls:5s} → mean={class_means[idx]:.4f}, std={class_stds[idx]:.4f}")


Per-class means & stds:
  go    → mean=-29.8854, std=20.9561
  left  → mean=-34.1892, std=21.7751
  off   → mean=-31.6523, std=20.2902
  on    → mean=-32.1178, std=21.0336
  down  → mean=-31.4741, std=20.8006
  yes   → mean=-29.7448, std=21.3592
  stop  → mean=-34.5246, std=21.9527
  right → mean=-32.6826, std=21.3948
  no    → mean=-35.7149, std=20.5595
  up    → mean=-31.8529, std=21.7311


# Create full dataset again with normalization added

In [16]:
# Create the normalization transform from the training-split statistics.
# NOTE(review): the statistics are keyed by class label, so applying them
# presumably requires each sample's label — confirm how PerClassNormalize is
# expected to be used at inference time.
norm_transform = PerClassNormalize(class_means, class_stds)

# Remember which sample indices belong to each split BEFORE `dataset` is
# rebuilt. Drawing a fresh random_split here would produce an unrelated
# partition, so samples used to compute the statistics would end up in the
# validation set.
train_indices = list(train_dataset.indices)
val_indices   = list(val_dataset.indices)

# Create the same dataset but with normalization applied.
# NOTE(review): reusing indices assumes AudioDataset yields the same samples in
# the same order when constructed twice with identical arguments — confirm,
# especially for the max_samples_per_class sub-sampling.
dataset = AudioDataset(
    root_dir            = root_dir,
    cache_dir           = cache_dir,
    export_dir          = None,
    preprocess          = True,
    transform           = norm_transform,      # <-- apply normalization
    allowed_classes     = allowed_classes,
    save_spectrograms   = False,
    max_samples_per_class= max_samples_per_class
)

# Fail loudly if the rebuilt dataset does not line up with the stored split.
if len(dataset) != len(train_indices) + len(val_indices):
    raise ValueError(
        "Normalized dataset size does not match the original train/validation split"
    )

# Re-create the exact same train/validation partition on the normalized data.
train_dataset = Subset(dataset, train_indices)
val_dataset   = Subset(dataset, val_indices)

# Wrap in DataLoaders; both loaders share everything except shuffling.
loader_kwargs = dict(
    batch_size=32,
    collate_fn=pad_collate_fn,
    num_workers=4,
    pin_memory=True,
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

val_loader   = DataLoader(val_dataset, shuffle=False, **loader_kwargs)


After running the code above (and setting the desired parameters), you will get two loader objects:

- **`train_loader`**: Supplies shuffled batches from your training split.  
- **`val_loader`**: Supplies (usually un-shuffled) batches from your validation split for monitoring generalization.

These loaders can be used for further experiments.