In [11]:
import os
import torchaudio
from torch.utils.data import Dataset
import random
import os
import random
import hashlib
import torchaudio
import re
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

class SpeechCommandsDataset(Dataset):
    def __init__(self, dataset_path, unknown_label='unknown', target_background_samples=2375, selected_words=None):
        self.selected_words = selected_words if selected_words is not None else []
        self.audio_labels = []
        self.audio_paths = []
        self.label_to_index = {unknown_label: 0}
        self.total_size_bytes = 0
        self.target_background_samples = target_background_samples
        self.background_samples_added = 0
        all_labels = sorted(os.listdir(dataset_path))

        for label in all_labels:
            if label.startswith('_') or (self.selected_words and label not in self.selected_words):
                continue  
            label_path = os.path.join(dataset_path, label)
            if os.path.isdir(label_path):
                label_index = len(self.label_to_index)
                self.label_to_index[label] = label_index

                for audio_file in os.listdir(label_path):
                    if audio_file.endswith('.wav'):
                        file_path = os.path.join(label_path, audio_file)
                        self.audio_paths.append(file_path)
                        self.audio_labels.append(label_index)
                        self.total_size_bytes += os.path.getsize(file_path)

    def __len__(self):
        return len(self.audio_paths)
    
    def __getitem__(self, idx):
        audio_info = self.audio_paths[idx]
        label = self.audio_labels[idx]
        if isinstance(audio_info, tuple):
            waveform, sample_rate = torchaudio.load(audio_info[0], frame_offset=audio_info[1], num_frames=audio_info[2]-audio_info[1])
        else:
            waveform, sample_rate = torchaudio.load(audio_info)
        waveform = waveform.squeeze()
        if waveform.numel() == 0:  # Check if the waveform is empty
            print(f"Empty waveform at idx: {idx}, path: {audio_info}")
        return waveform, label

In [2]:
MAX_NUM_WAVS_PER_CLASS = 2**27 - 1

def which_set(filename, validation_percentage, testing_percentage):
    """Determine the dataset split based on filename hashing."""
    base_name = os.path.basename(filename)
    hash_name = re.sub(r'_nohash_.*$', '', base_name)
    hash_name_hashed = hashlib.sha1(hash_name.encode('utf-8')).hexdigest()
    percentage_hash = (int(hash_name_hashed, 16) % (MAX_NUM_WAVS_PER_CLASS + 1)) * (100.0 / MAX_NUM_WAVS_PER_CLASS)
    
    if percentage_hash < validation_percentage:
        return 'validation'
    elif percentage_hash < (testing_percentage + validation_percentage):
        return 'testing'
    else:
        return 'training'

def distribute_files_by_label(dataset, validation_percentage, testing_percentage):
    """Distribute files into training, validation, and testing sets based on labels."""
    label_indices = {
        'yes': dataset.label_to_index['yes'],
        'no': dataset.label_to_index['no'],
        'unknown': dataset.label_to_index.get('unknown', -1)  # Default -1 if 'unknown' is not predefined
    }
    sets = {label: {'training': [], 'validation': [], 'testing': []} for label in label_indices}

    for file_path, label in zip(dataset.audio_paths, dataset.audio_labels):
        set_type = which_set(file_path, validation_percentage, testing_percentage)
        for label_name, index in label_indices.items():
            if label == index:
                sets[label_name][set_type].append(file_path)
                break

    return sets

def balance_and_shuffle_files(sets):
    """Balance the number of unknown samples to match the number of yes samples, then shuffle each set."""
    num_yes_train = len(sets['yes']['training'])
    num_yes_validation = len(sets['yes']['validation'])
    num_yes_testing = len(sets['yes']['testing'])

    random.shuffle(sets['unknown']['training'])
    random.shuffle(sets['unknown']['validation'])
    random.shuffle(sets['unknown']['testing'])

    sets['unknown']['training'] = sets['unknown']['training'][:num_yes_train]
    sets['unknown']['validation'] = sets['unknown']['validation'][:num_yes_validation]
    sets['unknown']['testing'] = sets['unknown']['testing'][:num_yes_testing]

    train_files = sets['yes']['training'] + sets['no']['training'] + sets['unknown']['training']
    validation_files = sets['yes']['validation'] + sets['no']['validation'] + sets['unknown']['validation']
    test_files = sets['yes']['testing'] + sets['no']['testing'] + sets['unknown']['testing']

    random.shuffle(train_files)
    random.shuffle(validation_files)
    random.shuffle(test_files)

    return train_files, validation_files, test_files

def process_background_noise(background_noise_path, segment_length, target_samples, num_train, num_val, num_test):
    """Process and distribute background noise into dataset segments."""
    noise_files = [f for f in os.listdir(background_noise_path) if f.endswith('.wav')]
    random.shuffle(noise_files)
    
    audio_segments = []
    while len(audio_segments) < target_samples and noise_files:
        noise_file = random.choice(noise_files)
        noise_path = os.path.join(background_noise_path, noise_file)
        waveform, sample_rate = torchaudio.load(noise_path)
        total_samples = waveform.size(1)
        samples_per_segment = int(sample_rate * segment_length)
        max_start = total_samples - samples_per_segment
        
        if max_start > 0:
            start = random.randint(0, max_start)
            end = start + samples_per_segment
            audio_segments.append((noise_path, start, end))
            if len(audio_segments) >= target_samples:
                break

    if len(audio_segments) < (num_train + num_val + num_test):
        raise ValueError("Not enough samples collected to meet the required distribution.")

    random.shuffle(audio_segments)
    train = audio_segments[:num_train]
    val = audio_segments[num_train:num_train + num_val]
    test = audio_segments[num_train + num_val:num_train + num_val + num_test]

    return train, val, test


def save_segments(segments, save_path, prefix):
    """Save segments of audio to disk."""
    saved_paths = []
    os.makedirs(save_path, exist_ok=True)
    for i, (path, start, end) in enumerate(segments):
        waveform, sample_rate = torchaudio.load(path, frame_offset=start, num_frames=end-start)
        segment_path = os.path.join(save_path, f"{prefix}_{i}.wav")
        torchaudio.save(segment_path, waveform, sample_rate)
        saved_paths.append(segment_path)
    return saved_paths

In [3]:
DATASET_PATH = './train/train/audio'
BACKGROUND_NOISE_PATH = './train/train/audio/_background_noise_'
dataset = SpeechCommandsDataset(DATASET_PATH, BACKGROUND_NOISE_PATH)

# Set distribution percentages
validation_percentage = 10.0
testing_percentage = 10.0

# Distribute files
distributed_sets = distribute_files_by_label(dataset, validation_percentage, testing_percentage)

# Balance and shuffle unknown samples to match the number of yes samples
train_files, validation_files, test_files = balance_and_shuffle_files(distributed_sets)

# Display the result counts
print("Training files count:", len(train_files))
print("Validation files count:", len(validation_files))
print("Testing files count:", len(test_files))

Training files count: 3713
Validation files count: 531
Testing files count: 508


In [4]:
yes_count_train = sum(1 for file_path in train_files if 'yes' in file_path.split(os.sep)[-2])
yes_count_val = sum(1 for file_path in validation_files if 'yes' in file_path.split(os.sep)[-2])
yes_count_test = sum(1 for file_path in test_files if 'yes' in file_path.split(os.sep)[-2])

In [5]:
target_background_samples = len(train_files)
num_train_samples = yes_count_train
num_val_samples = yes_count_val
num_test_samples = yes_count_test
segment_length=1

train_unknown, val_unknown, test_unknown = process_background_noise(
    BACKGROUND_NOISE_PATH, segment_length, target_background_samples, 
    num_train_samples, num_val_samples, num_test_samples
)


In [6]:
train_unknown_save_path = './train/train/audio/silence/train'
val_unknown_save_path = './train/train/audio/silence/val'
test_unknown_save_path = './train/train/audio/silence/test'

train_silence_files = save_segments(train_unknown, train_unknown_save_path, 'train_silence')
val_silence_files = save_segments(val_unknown, val_unknown_save_path, 'val_silence')
test_silence_files = save_segments(test_unknown, test_unknown_save_path, 'test_silence')


In [7]:
train_files_set = set(train_files)
validation_files_set = set(validation_files)
test_files_set = set(test_files)

train_files_set.update(train_silence_files)
validation_files_set.update(val_silence_files)
test_files_set.update(test_silence_files)

train_files = list(train_files_set)
validation_files = list(validation_files_set)
test_files = list(test_files_set)

random.shuffle(train_files)
random.shuffle(validation_files)
random.shuffle(test_files)

In [8]:
print("Final training files count:", len(train_files))
print("Final validation files count:", len(validation_files))
print("Final testing files count:", len(test_files))

Final training files count: 5573
Final validation files count: 792
Final testing files count: 764


In [12]:
class AudioDataset(Dataset):
    def __init__(self, file_paths):
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        waveform, sample_rate = torchaudio.load(self.file_paths[idx])
        return waveform, sample_rate

train_dataset = AudioDataset(train_files)
validation_dataset = AudioDataset(validation_files)
test_dataset = AudioDataset(test_files)

In [13]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
