In [1]:
from pathlib import Path
from typing import List, Tuple, Optional

import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio

import numpy as np
import matplotlib.pyplot as plt

In [2]:
class SpeechCommandsDataset(Dataset):
    """Dataset of ``.wav`` clips stored as ``root_dir/<label>/**/*.wav``.

    Every immediate sub-directory of ``root_dir`` is treated as one label and
    every ``.wav`` file found (recursively) below it becomes one sample.

    Two labelling modes are supported:

    * ``"original"`` - each directory name is its own class.
    * ``"modified"`` - directories whose name is listed in ``commands`` keep
      their name; every other directory is mapped to the single class
      ``"unknown"``.

    Attributes:
        root_dir: ``Path`` built from the ``root_dir`` argument.
        transform: optional callable applied to each loaded waveform.
        mode: the labelling mode, ``"original"`` or ``"modified"``.
        commands: list of command names kept as separate classes in
            ``"modified"`` mode.
        labels: sorted list of unique class names.
        label_to_index: mapping from class name to its position in ``labels``.
        samples: list of ``(wav_path, class_index)`` pairs, ordered by
            directory name and then by file path.
    """

    def __init__(self, root_dir: str, transform=None, mode: str = "original",
        commands: Optional[List[str]] = None,  # list of command labels
    ):
        """Scan ``root_dir`` and build the sample index (no audio is loaded here).

        Args:
            root_dir: directory containing one sub-directory per label.
            transform: optional callable applied to the waveform in
                ``__getitem__``.
            mode: ``"original"`` or ``"modified"`` (see the class docstring).
            commands: command names to keep in ``"modified"`` mode; any
                iterable of strings is accepted. Defaults to ten built-in
                commands when ``None``.

        Raises:
            AssertionError: if ``mode`` is not one of the two supported values.
        """
        assert mode in {"original", "modified"}, "mode must be 'original' or 'modified'"

        self.root_dir = Path(root_dir)
        self.transform = transform
        self.mode = mode

        # Set the known commands if provided, otherwise use default
        if commands is None:
            self.commands = [
                "yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go"
            ]
        else:
            # Copy into a list so tuples/other iterables work and later
            # mutation of the caller's object cannot change this dataset.
            self.commands = list(commands)

        # Build samples
        self.samples = []
        all_labels = sorted({p.name for p in self.root_dir.iterdir() if p.is_dir()})

        known_commands = set(self.commands)
        if self.mode == "original":
            self.labels = all_labels
        else:  # modified
            # Set union removes duplicates, so class indices stay contiguous
            # even when ``commands`` repeats a name or already has "unknown".
            self.labels = sorted(known_commands | {"unknown"})

        self.label_to_index = {label: idx for idx, label in enumerate(self.labels)}

        for label in all_labels:
            # The target class depends only on the directory, so resolve it
            # once per directory rather than once per file.
            if self.mode == "original" or label in known_commands:
                target_label = label
            else:
                target_label = "unknown"
            target_index = self.label_to_index[target_label]

            # glob() order is filesystem-dependent; sorting makes the sample
            # order reproducible across machines and runs.
            for wav_file in sorted((self.root_dir / label).glob("**/*.wav")):
                self.samples.append((wav_file, target_index))

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, int]:
        """Load the sample at ``idx`` and return ``(waveform, class_index)``.

        The sample rate returned by ``torchaudio.load`` is discarded. Clips
        are returned as loaded, so waveform lengths may differ between
        samples unless ``transform`` normalises them.
        """
        audio_path, label = self.samples[idx]
        waveform, _ = torchaudio.load(audio_path)

        # Explicit None check: a callable that happens to be falsy (e.g. an
        # empty container module) must still be applied.
        if self.transform is not None:
            waveform = self.transform(waveform)

        return waveform, label

In [3]:
def collate_fn(batch):
    """Collate ``(waveform, label)`` pairs without stacking the waveforms.

    Waveforms are returned as a plain list in batch order (they are not
    padded or stacked), while the labels are packed into a single tensor.

    Args:
        batch: sequence of ``(waveform, label)`` pairs.

    Returns:
        ``(waveforms, labels)`` where ``waveforms`` is a list and ``labels``
        is the tensor built from the per-sample labels.
    """
    # Transpose the batch: one column of waveforms, one column of labels.
    waveform_column, label_column = tuple(zip(*batch))
    unstacked = [clip for clip in waveform_column]
    return unstacked, torch.tensor(label_column)

In [7]:
# Index the training split; "modified" mode folds every non-command
# directory into the single "unknown" class.
train_dataset = SpeechCommandsDataset("./../../data/train", mode="modified")

# collate_fn keeps the waveforms as a list, so clips of different lengths
# can share a batch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=40,
    shuffle=True,
    collate_fn=collate_fn,
)

# Sanity check: show the per-clip shapes and the label tensor of the first
# batch only, then stop iterating.
for batch in train_loader:
    waveforms, labels = batch
    for i, waveform in enumerate(waveforms):
        print(f"Waveform {i} shape: {waveform.shape}")
    print(labels)
    break

Waveform 0 shape: torch.Size([1, 16000])
Waveform 1 shape: torch.Size([1, 16000])
Waveform 2 shape: torch.Size([1, 16000])
Waveform 3 shape: torch.Size([1, 16000])
Waveform 4 shape: torch.Size([1, 9558])
Waveform 5 shape: torch.Size([1, 16000])
Waveform 6 shape: torch.Size([1, 16000])
Waveform 7 shape: torch.Size([1, 16000])
Waveform 8 shape: torch.Size([1, 16000])
Waveform 9 shape: torch.Size([1, 16000])
Waveform 10 shape: torch.Size([1, 16000])
Waveform 11 shape: torch.Size([1, 16000])
Waveform 12 shape: torch.Size([1, 16000])
Waveform 13 shape: torch.Size([1, 12971])
Waveform 14 shape: torch.Size([1, 16000])
Waveform 15 shape: torch.Size([1, 16000])
Waveform 16 shape: torch.Size([1, 16000])
Waveform 17 shape: torch.Size([1, 16000])
Waveform 18 shape: torch.Size([1, 16000])
Waveform 19 shape: torch.Size([1, 16000])
Waveform 20 shape: torch.Size([1, 16000])
Waveform 21 shape: torch.Size([1, 16000])
Waveform 22 shape: torch.Size([1, 16000])
Waveform 23 shape: torch.Size([1, 9558])
Wave

In [8]:
from collections import Counter

# Tally how many samples carry each integer class index. Counter keeps the
# indices in the order they are first encountered in ``samples``.
label_counts = Counter(
    class_index for _path, class_index in train_dataset.samples
)

# Invert the name -> index mapping so each index can be printed by name.
index_to_label = dict(
    (position, name) for name, position in train_dataset.label_to_index.items()
)

for label_idx in label_counts:
    count = label_counts[label_idx]
    print(f"{index_to_label[label_idx]}: {count} samples")


unknown: 34123 samples
down: 1842 samples
go: 1861 samples
left: 1839 samples
no: 1853 samples
off: 1839 samples
on: 1864 samples
right: 1852 samples
stop: 1885 samples
up: 1843 samples
yes: 1860 samples


In [1]:
from src.Dataset import SpeechCommandsDataset
from collections import Counter
# Same per-class tally as before, this time using the dataset class
# imported from src.Dataset.
train_dataset = SpeechCommandsDataset("./../../data/train", mode="modified")

# Count samples per integer class index, in first-seen order.
label_counts = Counter()
label_counts.update(target for _, target in train_dataset.samples)

# Reverse lookup: class index -> class name.
index_to_label = {
    index: name for name, index in train_dataset.label_to_index.items()
}

for label_idx, count in label_counts.items():
    print(f"{index_to_label[label_idx]}: {count} samples")

unknown: 32550 samples
down: 1842 samples
go: 1861 samples
left: 1839 samples
no: 1853 samples
off: 1839 samples
on: 1864 samples
right: 1852 samples
silence: 1573 samples
stop: 1885 samples
up: 1843 samples
yes: 1860 samples
