In [1]:
import os
import torchaudio
from torch.utils.data import Dataset
import random
import os
import random
import hashlib
import torchaudio
import re
import numpy as np
import torch
from torch.utils.data import Dataset

# Seed every random number generator this notebook imports (Python's `random`,
# NumPy and PyTorch) with the same value so that the random.shuffle /
# random.choice / random.randint calls further down repeat for identical inputs.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

class SpeechCommandsDataset(Dataset):
    """Index a directory tree that holds one sub-directory of ``.wav`` clips per word.

    Every sub-directory of ``dataset_path`` whose name does not start with ``_``
    is treated as a label. Clips whose label is listed in ``selected_words`` get
    their own class index (1, 2, ... in the order given); clips of every other
    label share index 0 under ``unknown_label``.

    Args:
        dataset_path: Root directory containing the per-word sub-directories.
        unknown_label: Name stored in ``label_to_index`` for class index 0.
        selected_words: Words that receive their own class. Defaults to
            ``['yes', 'no']``.
    """

    def __init__(self, dataset_path, unknown_label='unknown', selected_words=None):
        self.selected_words = selected_words if selected_words is not None else ['yes', 'no']
        self.audio_labels = []      # class index of each clip, parallel to audio_paths
        self.audio_paths = []       # path of each clip
        self.total_size_bytes = 0   # summed on-disk size of all indexed clips

        # 'unknown' is always class 0; the selected words follow in the order given.
        # setdefault keeps the first index when a word is repeated.
        self.label_to_index = {unknown_label: 0}
        for word in self.selected_words:
            self.label_to_index.setdefault(word, len(self.label_to_index))
        unknown_index = self.label_to_index[unknown_label]

        for label in sorted(os.listdir(dataset_path)):
            if label.startswith('_'):  # Skip the '_background_noise_' directory
                continue
            label_path = os.path.join(dataset_path, label)
            if not os.path.isdir(label_path):
                continue

            # Any directory that is not a selected word is folded into 'unknown'
            label_index = self.label_to_index.get(label, unknown_index)

            # Sorted so that the sample order does not depend on the file system
            for audio_file in sorted(os.listdir(label_path)):
                if not audio_file.endswith('.wav'):
                    continue
                file_path = os.path.join(label_path, audio_file)
                self.audio_paths.append(file_path)
                self.audio_labels.append(label_index)
                self.total_size_bytes += os.path.getsize(file_path)

    def __len__(self):
        """Return the number of indexed clips."""
        return len(self.audio_paths)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(waveform, label)``.

        The waveform comes from ``torchaudio.load`` with size-1 dimensions
        squeezed away; the sample rate is discarded. An empty waveform is
        reported on stdout but still returned.
        """
        audio_info = self.audio_paths[idx]
        label = self.audio_labels[idx]
        waveform, sample_rate = torchaudio.load(audio_info)
        waveform = waveform.squeeze()
        if waveform.numel() == 0:
            print(f"Empty waveform at idx: {idx}, path: {audio_info}")
        return waveform, label

In [2]:
# Modulus (minus one) used to fold the SHA-1 digest into a bounded integer.
MAX_NUM_WAVS_PER_CLASS = 2**27 - 1

def which_set(filename, validation_percentage, testing_percentage):
    """Pick 'training', 'validation' or 'testing' for a file from a hash of its name.

    Only the base name up to (not including) '_nohash_' is hashed, so files that
    share that prefix always land in the same split, and the assignment does not
    change when other files are added or removed.

    Args:
        filename: Path of the audio file; the directory part is ignored.
        validation_percentage: Share (0-100) of hash space sent to validation.
        testing_percentage: Share (0-100) of hash space sent to testing.

    Returns:
        One of the strings 'training', 'validation' or 'testing'.
    """
    stable_key = re.sub(r'_nohash_.*$', '', os.path.basename(filename))
    digest_hex = hashlib.sha1(stable_key.encode('utf-8')).hexdigest()

    # Fold the digest into [0, MAX_NUM_WAVS_PER_CLASS] and scale it to 0..100
    bucket = int(digest_hex, 16) % (MAX_NUM_WAVS_PER_CLASS + 1)
    percentage_hash = bucket * (100.0 / MAX_NUM_WAVS_PER_CLASS)

    if percentage_hash < validation_percentage:
        return 'validation'
    if percentage_hash < (testing_percentage + validation_percentage):
        return 'testing'
    return 'training'

def distribute_files_by_label(dataset, validation_percentage, testing_percentage):
    """Group file paths by label name ('yes', 'no', 'unknown') and by split.

    Args:
        dataset: Object exposing ``label_to_index`` (must contain the keys
            'yes', 'no' and 'unknown'; a missing key raises ``KeyError``),
            ``audio_paths`` and ``audio_labels``.
        validation_percentage: Forwarded to ``which_set``.
        testing_percentage: Forwarded to ``which_set``.

    Returns:
        ``{label_name: {'training': [...], 'validation': [...], 'testing': [...]}}``
        holding file paths. Files whose label index matches none of the three
        names are left out.
    """
    tracked_names = ('yes', 'no', 'unknown')

    # Reverse map index -> name; setdefault keeps the first name if two share an index
    name_for_index = {}
    for name in tracked_names:
        name_for_index.setdefault(dataset.label_to_index[name], name)

    sets = {name: {'training': [], 'validation': [], 'testing': []} for name in tracked_names}

    for file_path, label in zip(dataset.audio_paths, dataset.audio_labels):
        split = which_set(file_path, validation_percentage, testing_percentage)
        owner = name_for_index.get(label)
        if owner is not None:
            sets[owner][split].append(file_path)

    return sets

def distribute_files(dataset, validation_percentage, testing_percentage):
    """Split the dataset's ``(file_path, label_index)`` pairs into three lists.

    The split of each file is decided by ``which_set`` on its path.

    Returns:
        ``(train_files, validation_files, test_files)``, each a list of
        ``(file_path, label_index)`` tuples in dataset order.
    """
    train_files, validation_files, test_files = [], [], []
    named_buckets = {'training': train_files, 'validation': validation_files}

    for entry in zip(dataset.audio_paths, dataset.audio_labels):
        split = which_set(entry[0], validation_percentage, testing_percentage)
        # Whatever is neither training nor validation belongs to the test split
        named_buckets.get(split, test_files).append(entry)

    return train_files, validation_files, test_files


def balance_unknown_files(train_files, validation_files, test_files, unknown_label_index,
                          yes_label_index=None):
    """Cap the 'unknown' class of each split at that split's number of 'yes' clips.

    Each split is shuffled, all non-unknown entries are kept, at most
    ``count('yes')`` unknown entries are kept, and the result is shuffled again.
    The input lists are not modified.

    Args:
        train_files: List of ``(path, label_index)`` tuples for training.
        validation_files: Same, for validation.
        test_files: Same, for testing.
        unknown_label_index: Class index of the 'unknown' label.
        yes_label_index: Class index of the 'yes' label. When omitted, it is
            read from the notebook-level ``dataset.label_to_index['yes']`` as
            the original implementation did.

    Returns:
        ``(train_balanced, validation_balanced, test_balanced)`` lists of
        ``(path, label_index)`` tuples.
    """
    if yes_label_index is None:
        # Backward-compatible fallback on the notebook global; pass the index
        # explicitly to keep the function independent of kernel state.
        yes_label_index = dataset.label_to_index['yes']

    # Work on copies so the caller's lists keep their order
    splits = [list(train_files), list(validation_files), list(test_files)]
    yes_counts = [sum(1 for _, label in split if label == yes_label_index) for split in splits]

    # All three input shuffles happen before any output shuffle, which keeps the
    # sequence of RNG draws (and thus the seeded result) the same as before.
    for split in splits:
        random.shuffle(split)

    balanced = []
    for split, quota in zip(splits, yes_counts):
        known = [item for item in split if item[1] != unknown_label_index]
        unknown = [item for item in split if item[1] == unknown_label_index][:quota]
        balanced.append(known + unknown)

    for split in balanced:
        random.shuffle(split)

    train_balanced, validation_balanced, test_balanced = balanced
    return train_balanced, validation_balanced, test_balanced


def process_background_noise(background_noise_path, segment_length, target_samples, num_train, num_val, num_test):
    """Draw random fixed-length segments from the background-noise recordings.

    Args:
        background_noise_path: Directory holding the noise ``.wav`` files.
        segment_length: Segment duration; multiplied by each file's sample rate
            to get the segment size in frames (i.e. seconds).
        target_samples: Number of segments to draw in total.
        num_train: Number of segments handed to the training split.
        num_val: Number of segments handed to the validation split.
        num_test: Number of segments handed to the test split.

    Returns:
        ``(train, val, test)`` lists of ``(noise_path, start_frame, end_frame)``.

    Raises:
        ValueError: If fewer than ``num_train + num_val + num_test`` segments
            could be drawn (including when no recording is long enough).
    """
    # Sorted before shuffling: os.listdir order is arbitrary, so without this the
    # seeded shuffle would not give the same result on every machine.
    noise_files = sorted(f for f in os.listdir(background_noise_path) if f.endswith('.wav'))
    random.shuffle(noise_files)

    # Inspect every recording once instead of reloading a whole file per draw,
    # and keep only those that can hold at least one full segment.
    candidates = []
    for noise_file in noise_files:
        noise_path = os.path.join(background_noise_path, noise_file)
        waveform, sample_rate = torchaudio.load(noise_path)
        samples_per_segment = int(sample_rate * segment_length)
        # size(1) is the frame count, assuming the (channels, frames) layout
        max_start = waveform.size(1) - samples_per_segment
        if max_start >= 0:  # a file of exactly one segment is usable (start = 0)
            candidates.append((noise_path, samples_per_segment, max_start))

    # With no usable recording the loop is skipped, so it cannot spin forever
    audio_segments = []
    while candidates and len(audio_segments) < target_samples:
        noise_path, samples_per_segment, max_start = random.choice(candidates)
        start = random.randint(0, max_start)
        audio_segments.append((noise_path, start, start + samples_per_segment))

    if len(audio_segments) < (num_train + num_val + num_test):
        raise ValueError("Not enough samples collected to meet the required distribution.")

    random.shuffle(audio_segments)
    train = audio_segments[:num_train]
    val = audio_segments[num_train:num_train + num_val]
    test = audio_segments[num_train + num_val:num_train + num_val + num_test]

    return train, val, test


def save_segments(segments, save_path, prefix):
    """Cut each ``(path, start, end)`` segment out of its source file and save it.

    Clips are written to ``save_path`` (created if missing) as
    ``<prefix>_<position>.wav`` at the source file's sample rate.

    Returns:
        The list of written file paths, in segment order.
    """
    os.makedirs(save_path, exist_ok=True)

    written = []
    for index, segment in enumerate(segments):
        source_path, first_frame, last_frame = segment
        # Read only the requested frame window rather than the whole recording
        clip, rate = torchaudio.load(
            source_path,
            frame_offset=first_frame,
            num_frames=last_frame - first_frame,
        )
        destination = os.path.join(save_path, f"{prefix}_{index}.wav")
        torchaudio.save(destination, clip, rate)
        written.append(destination)
    return written

In [3]:
DATASET_PATH = './train/train/audio'
BACKGROUND_NOISE_PATH = './train/train/audio/_background_noise_'
dataset = SpeechCommandsDataset(DATASET_PATH)

# Share of the hash space routed to validation and testing; the rest is training
validation_percentage = 10.0
testing_percentage = 10.0

# 1) hash-based split of every clip, 2) cap 'unknown' at the number of 'yes' clips
train_files, validation_files, test_files = distribute_files(
    dataset, validation_percentage, testing_percentage
)
train_balanced, validation_balanced, test_balanced = balance_unknown_files(
    train_files,
    validation_files,
    test_files,
    dataset.label_to_index['unknown'],
)

# Display the result counts
print("Training files count:", len(train_balanced))
print("Validation files count:", len(validation_balanced))
print("Testing files count:", len(test_balanced))


# From here on the *_files names hold bare paths (labels dropped), not tuples
train_files = [path for path, _ in train_balanced]
validation_files = [path for path, _ in validation_balanced]
test_files = [path for path, _ in test_balanced]

Training files count: 5573
Validation files count: 792
Testing files count: 764


In [4]:
# Count the 'yes' clips of each split from the label index stored next to every
# path. Parsing the parent directory out of the path and substring-matching 'yes'
# would also count any label that merely contains "yes" and depends on os.sep.
yes_label_index = dataset.label_to_index['yes']
yes_count_train = sum(1 for _, label in train_balanced if label == yes_label_index)
yes_count_val = sum(1 for _, label in validation_balanced if label == yes_label_index)
yes_count_test = sum(1 for _, label in test_balanced if label == yes_label_index)

In [5]:
# Draw as many candidate segments as there are training files, then hand out one
# segment per 'yes' clip to each split.
segment_length = 1
target_background_samples = len(train_files)
num_train_samples, num_val_samples, num_test_samples = (
    yes_count_train,
    yes_count_val,
    yes_count_test,
)

# Despite the *_unknown names, these are background-noise segments that are
# later written out as the 'silence' clips.
train_unknown, val_unknown, test_unknown = process_background_noise(
    background_noise_path=BACKGROUND_NOISE_PATH,
    segment_length=segment_length,
    target_samples=target_background_samples,
    num_train=num_train_samples,
    num_val=num_val_samples,
    num_test=num_test_samples,
)

In [6]:
# Output directories for the generated silence clips, one per split
train_unknown_save_path = './train/train/audio/silence/train'
val_unknown_save_path = './train/train/audio/silence/val'
test_unknown_save_path = './train/train/audio/silence/test'

# Cut the sampled background-noise segments out of their recordings and save them
train_silence_files = save_segments(
    segments=train_unknown, save_path=train_unknown_save_path, prefix='train_silence'
)
val_silence_files = save_segments(
    segments=val_unknown, save_path=val_unknown_save_path, prefix='val_silence'
)
test_silence_files = save_segments(
    segments=test_unknown, save_path=test_unknown_save_path, prefix='test_silence'
)

In [7]:
# Merge the speech clips with the generated silence clips of the same split.
# Going through a set makes re-running this cell harmless (no duplicate paths).
train_files_set = set(train_files)
validation_files_set = set(validation_files)
test_files_set = set(test_files)

train_files_set.update(train_silence_files)
validation_files_set.update(val_silence_files)
test_files_set.update(test_silence_files)

# sorted() rather than list(): the iteration order of a set of strings changes
# between interpreter sessions (string hash randomisation), which would make the
# seeded shuffle below produce a different order after every kernel restart.
train_files = sorted(train_files_set)
validation_files = sorted(validation_files_set)
test_files = sorted(test_files_set)

random.shuffle(train_files)
random.shuffle(validation_files)
random.shuffle(test_files)

In [10]:
# Debug peek at the paths save_segments returned for the validation silence clips.
val_silence_files

['./train/train/audio/silence/val\\val_silence_0.wav',
 './train/train/audio/silence/val\\val_silence_1.wav',
 './train/train/audio/silence/val\\val_silence_2.wav',
 './train/train/audio/silence/val\\val_silence_3.wav',
 './train/train/audio/silence/val\\val_silence_4.wav',
 './train/train/audio/silence/val\\val_silence_5.wav',
 './train/train/audio/silence/val\\val_silence_6.wav',
 './train/train/audio/silence/val\\val_silence_7.wav',
 './train/train/audio/silence/val\\val_silence_8.wav',
 './train/train/audio/silence/val\\val_silence_9.wav',
 './train/train/audio/silence/val\\val_silence_10.wav',
 './train/train/audio/silence/val\\val_silence_11.wav',
 './train/train/audio/silence/val\\val_silence_12.wav',
 './train/train/audio/silence/val\\val_silence_13.wav',
 './train/train/audio/silence/val\\val_silence_14.wav',
 './train/train/audio/silence/val\\val_silence_15.wav',
 './train/train/audio/silence/val\\val_silence_16.wav',
 './train/train/audio/silence/val\\val_silence_17.wav',
 '

In [8]:
# Report the size of each split after the generated silence clips were merged in.
print("Final training files count:", len(train_files))
print("Final validation files count:", len(validation_files))
print("Final testing files count:", len(test_files))

Final training files count: 7433
Final validation files count: 1053
Final testing files count: 1020


In [9]:
def save_file_list(file_list, file_name):
    """Write every path in ``file_list`` to ``file_name``, one per line.

    An existing file is overwritten; an empty list produces an empty file.
    """
    with open(file_name, 'w') as manifest:
        manifest.writelines(entry + '\n' for entry in file_list)

# Persist each split as a newline-separated list of paths. The file names are
# relative, so the three files are created in the current working directory.
save_file_list(train_files, 'train_files.txt')
save_file_list(validation_files, 'validation_files.txt')
save_file_list(test_files, 'test_files.txt')

print("File lists saved to text files.")

File lists saved to text files.
