diff --git a/datasets/kws20.py b/datasets/kws20.py index 7d757ed..be50e70 100644 --- a/datasets/kws20.py +++ b/datasets/kws20.py @@ -70,31 +70,30 @@ class KWS: url_librispeech = 'http://us.openslr.org/resources/12/dev-clean.tar.gz' fs = 16000 - class_dict = {'backward': 0, 'bed': 1, 'bird': 2, 'cat': 3, 'dog': 4, 'down': 5, - 'eight': 6, 'five': 7, 'follow': 8, 'forward': 9, 'four': 10, 'go': 11, - 'happy': 12, 'house': 13, 'learn': 14, 'left': 15, - 'librispeech': 16, 'marvin': 17, 'nine': 18, 'no': 19, 'off': 20, 'on': 21, - 'one': 22, 'right': 23, 'seven': 24, 'sheila': 25, 'SILENCE': 26, 'six': 27, - 'stop': 28, 'three': 29, 'tree': 30, 'two': 31, 'up': 32, 'visual': 33, - 'wow': 34, 'yes': 35, 'zero': 36} + class_dict = {'_silence_': 0, 'backward': 1, 'bed': 2, 'bird': 3, 'cat': 4, 'dog': 5, + 'down': 6, 'eight': 7, 'five': 8, 'follow': 9, 'forward': 10, 'four': 11, + 'go': 12, 'happy': 13, 'house': 14, 'learn': 15, 'left': 16, 'librispeech': 17, + 'marvin': 18, 'nine': 19, 'no': 20, 'off': 21, 'on': 22, 'one': 23, 'right': 24, + 'seven': 25, 'sheila': 26, 'six': 27, 'stop': 28, 'three': 29, 'tree': 30, + 'two': 31, 'up': 32, 'visual': 33, 'wow': 34, 'yes': 35, 'zero': 36} dataset_dict = { - 'KWS': ('up', 'down', 'left', 'right', 'stop', 'go', 'UNKNOWN'), + 'KWS': ('up', 'down', 'left', 'right', 'stop', 'go', '_unknown_'), 'KWS_20': ('up', 'down', 'left', 'right', 'stop', 'go', 'yes', 'no', 'on', 'off', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'zero', - 'UNKNOWN'), + '_unknown_'), 'KWS_35': ('backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', - 'up', 'visual', 'wow', 'yes', 'zero', 'UNKNOWN'), + 'up', 'visual', 'wow', 'yes', 'zero', '_unknown_'), 'KWS_12_benchmark': ('down', 'go', 'left', 'no', 'off', 'on', 'right', 'stop', 'up', 'yes', - 'SILENCE', 'UNKNOWN') + '_silence_', '_unknown_') } benchmark_keywords = ['down', 'go', 'left', 'no', 'off', 'on', 'right', 'stop', - 'up', 'yes', 'SILENCE'] + 'up', 'yes', '_silence_'] def __init__(self, root, classes, d_type, t_type, transform=None, quantization_scheme=None, augmentation=None, download=False, save_unquantized=False, filter_libri=False, @@ -126,7 +125,7 @@ class KWS: print(f'\nProcessing {self.d_type}...') self.__filter_dtype() - if 'SILENCE' not in self.classes: + if '_silence_' not in self.classes: self.__filter_silence() if self.filter_libri: @@ -168,7 +167,7 @@ class KWS: def silence_folder(self): """Folder for the silence data. """ - return os.path.join(self.raw_folder, 'silence') + return os.path.join(self.raw_folder, '_silence_') def __parse_quantization(self, quantization_scheme): if quantization_scheme: @@ -211,35 +210,35 @@ class KWS: self.__makedir_exist_ok(self.raw_folder) self.__makedir_exist_ok(self.processed_folder) - # download Speech Command + # download Google Speech Commands dataset filename = self.url_speechcommand.rpartition('/')[2] self.__download_and_extract_archive(self.url_speechcommand, download_root=self.raw_folder, filename=filename) - # sampling long segments of background noise + # sample long segments of background noise (total 560*6 = 3360 samples) self.__sample_wav(os.path.join(self.raw_folder, '_background_noise_'), - self.silence_folder, 1000) + self.silence_folder, 560) - # download Speech Command test + # download Google Speech Commands official test set for 10 keywords + _silence_ + _unknown_ filename = self.url_test.rpartition('/')[2] self.__download_and_extract_archive(self.url_test, download_root=self.raw_test_folder, filename=filename) - # copying test silence files under Speech Command silence folder + # copy test silence files to raw folder, for ease of processing in gen_datasets shutil.copytree(os.path.join(self.raw_test_folder, '_silence_'), - os.path.join(self.raw_folder, 'silence'), dirs_exist_ok=True) + os.path.join(self.raw_folder, '_silence_'), dirs_exist_ok=True) - print('Test for silence class successfully copied under raw/silence folder.') + print('Test for _silence_ class successfully copied under raw/_silence_ folder.') - # download LibriSpeech + # download LibriSpeech dev-clean dataset filename = self.url_librispeech.rpartition('/')[2] self.__download_and_extract_archive(self.url_librispeech, download_root=self.librispeech_folder, filename=filename) - # convert the LibriSpeech audio files to 1-sec 16KHz .wav, stored under raw/librispeech + # convert the LibriSpeech audio files to 1-sec 16KHz .wav, store under raw/librispeech self.__resample_convert_wav(folder_in=self.librispeech_folder, folder_out=os.path.join(self.raw_folder, 'librispeech')) @@ -485,7 +484,7 @@ class KWS: self.new_class_dict = {} for c in self.classes: if c not in self.class_dict: - if c == 'UNKNOWN': + if c == '_unknown_': continue raise ValueError(f'Class {c} not found in data') num_elems = (self.targets == self.class_dict[c]).cpu().sum() @@ -495,17 +494,17 @@ class KWS: new_class_label += 1 num_elems = (self.targets < initial_new_class_label).cpu().sum() - print(f'Class UNKNOWN: {num_elems} elements') + print(f'Class _unknown_: {num_elems} elements') self.targets[(self.targets < initial_new_class_label)] = new_class_label self.targets -= initial_new_class_label self.new_class_dict = {c: self.new_class_dict[c] - initial_new_class_label for c in self.new_class_dict.keys()} - self.new_class_dict['UNKNOWN'] = len(self.new_class_dict) + self.new_class_dict['_unknown_'] = len(self.new_class_dict) def __filter_librispeech(self): - print('Filtering librispeech elements...') + print('Filtering out librispeech elements...') idx_for_librispeech = [idx for idx, val in enumerate(self.targets) if val != self.class_dict['librispeech']] @@ -519,20 +518,20 @@ class KWS: set_size = sum((self.data_type == 0)[:, -1]) elif self.d_type == 'test': set_size = sum((self.data_type == 1)[:, -1]) - print(f'{self.d_type} set: {set_size} elements') + print(f'Remaining {self.d_type} set: {set_size} elements') if self.d_type == 'train': train_size = sum((self.data_type == 0)[:, -1]) set_size = sum((self.data_type == 2)[:, -1]) # indicate the list of validation indices to be used by distiller's dataloader self.valid_indices = range(train_size, train_size + set_size) - print(f'validation set: {set_size} elements') + print(f'Remaining validation set: {set_size} elements') def __filter_silence(self): - print('Filtering silence elements...') + print('Filtering out _silence_ elements...') idx_for_silence = [idx for idx, val in enumerate(self.targets) - if val != self.class_dict['SILENCE']] + if val != self.class_dict['_silence_']] self.data = torch.index_select(self.data, 0, torch.tensor(idx_for_silence)) self.targets = torch.index_select(self.targets, 0, torch.tensor(idx_for_silence)) @@ -543,14 +542,14 @@ class KWS: set_size = sum((self.data_type == 0)[:, -1]) elif self.d_type == 'test': set_size = sum((self.data_type == 1)[:, -1]) - print(f'{self.d_type} set: {set_size} elements') + print(f'Remaining {self.d_type} set: {set_size} elements') if self.d_type == 'train': train_size = sum((self.data_type == 0)[:, -1]) set_size = sum((self.data_type == 2)[:, -1]) # indicate the list of validation indices to be used by distiller's dataloader self.valid_indices = range(train_size, train_size + set_size) - print(f'validation set: {set_size} elements') + print(f'Remaining validation set: {set_size} elements') def __len__(self): return len(self.data) @@ -725,7 +724,7 @@ class KWS: lst = sorted(os.listdir(self.raw_folder)) labels = [d for d in lst if os.path.isdir(os.path.join(self.raw_folder, d)) - and d[0].isalpha()] + and d[0].isalpha() or d == '_silence_'] # show the size of dataset for each keyword print('------------- Label Size ---------------') @@ -737,15 +736,17 @@ class KWS: # read testing_list.txt & validation_list.txt into sets for fast access with open(os.path.join(self.raw_folder, 'testing_list.txt'), encoding="utf-8") as f: test_set = set(f.read().splitlines()) - test_silence = [os.path.join('silence', rec) for rec in os.listdir( + test_silence = [os.path.join('_silence_', rec) for rec in os.listdir( os.path.join(self.raw_test_folder, '_silence_'))] test_set.update(test_silence) with open(os.path.join(self.raw_folder, 'validation_list.txt'), encoding="utf-8") as f: validation_set = set(f.read().splitlines()) - val_silence = [os.path.join('silence', s) for s in os.listdir(self.silence_folder) - if 'running_tap' in s] - validation_set.update(val_silence) + + # sample 1 out of every 9 generated silence files for the validation set + silence_files = [os.path.join('_silence_', s) for s in os.listdir(self.silence_folder) + if not s[0].isdigit()] # files starting with numbers: used for testing + validation_set.update(silence_files[::9]) train_count = 0 test_count = 0 @@ -787,7 +788,7 @@ class KWS: print(f'\t{r + 1} of {record_len}') if label in self.benchmark_keywords: - if label == 'SILENCE': + if label == '_silence_': raw_test_list = os.listdir(os.path.join( self.raw_test_folder, '_silence_')) else: @@ -796,6 +797,7 @@ class KWS: else: raw_test_list = os.listdir(os.path.join(self.raw_test_folder, '_unknown_')) # example record name: "backward_a6f2fd71_nohash_3.wav.wav" + # note: the double "wav" extension is due to errors in the original dataset raw_test_list = [os.path.join(label, rec[-25:-4]) for rec in raw_test_list if label in rec] @@ -815,7 +817,7 @@ class KWS: record_pth = os.path.join(self.raw_folder, record_name) record, fs = librosa.load(record_pth, offset=0, sr=None) - # training and validation examplesget speed augmentation + # training and validation examples get speed augmentation if d_typ not in (1, 3): no_augmentations = self.augmentation['aug_num'] else: # test examples don't get speed augmentation @@ -1022,7 +1024,7 @@ def KWS_20_msnoise_mixed_get_datasets(data, load_train=True, load_test=True, noise_dataset_train = MSnoise(root=data_dir, classes=noise_type, d_type='train', dataset_len=len(kws_train_dataset), desired_probs=desired_probs, - transform=None, quantize=False, download=False) + transform=None, quantize=False, download=True) train_dataset = SignalMixer(signal_dataset=kws_train_dataset, snr_range=snr_range, @@ -1042,7 +1044,7 @@ def KWS_20_msnoise_mixed_get_datasets(data, load_train=True, load_test=True, def KWS_12_benchmark_get_datasets(data, load_train=True, load_test=True): """ Returns the KWS dataset benchmark for 12 classes. 10 keywords and - SILENCE + UNKNOWN. + _silence_ + _unknown_. """ return KWS_get_datasets(data, load_train, load_test, dataset_name='KWS_12_benchmark', num_classes=11, filter_libri=True, benchmark=True) @@ -1099,21 +1101,21 @@ def MixedKWS_20_get_datasets_10dB(data, load_train=True, load_test=True, datasets = [ { - 'name': 'KWS', # 6 keywords + unknown + 'name': 'KWS', # 6 keywords + _unknown_ 'input': (512, 64), 'output': KWS.dataset_dict['KWS'], 'weight': (1, 1, 1, 1, 1, 1, 0.06), 'loader': KWS_get_datasets, }, { - 'name': 'KWS_20', # 20 keywords + unknown + 'name': 'KWS_20', # 20 keywords + _unknown_ 'input': (128, 128), 'output': KWS.dataset_dict['KWS_20'], 'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.07), 'loader': KWS_20_get_datasets, }, { - 'name': 'KWS_35', # 35 keywords + unknown + 'name': 'KWS_35', # 35 keywords + _unknown_ 'input': (128, 128), 'output': KWS.dataset_dict['KWS_35'], 'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, @@ -1121,28 +1123,28 @@ datasets = [ 'loader': KWS_35_get_datasets, }, { - 'name': 'KWS_20_msnoise_mixed', + 'name': 'KWS_20_msnoise_mixed', # 20 keywords + _unknown_ 'input': (128, 128), 'output': KWS.dataset_dict['KWS_20'], 'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.07), 'loader': KWS_20_msnoise_mixed_get_datasets, }, { - 'name': 'KWS_12_benchmark', # 10 keyword + silence + unknown + 'name': 'KWS_12_benchmark', # 10 keywords + _silence_ + _unknown_ 'input': (128, 128), 'output': KWS.dataset_dict['KWS_12_benchmark'], 'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.6, 0.06), 'loader': KWS_12_benchmark_get_datasets, }, { - 'name': 'MixedKWS20_10dB', + 'name': 'MixedKWS20_10dB', # 20 keywords + _unknown_ 'input': (128, 128), 'output': KWS.dataset_dict['KWS_20'], 'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.07), 'loader': MixedKWS_20_get_datasets_10dB, }, { - 'name': 'KWS_35_unquantized', # 35 keywords + unknown + 'name': 'KWS_35_unquantized', # 35 keywords + _unknown_ 'input': (128, 128), 'output': KWS.dataset_dict['KWS_35'], 'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,