diff --git a/datasets/kws20.py b/datasets/kws20.py index e911106..c677627 100644 --- a/datasets/kws20.py +++ b/datasets/kws20.py @@ -95,6 +95,12 @@ class KWS: benchmark_keywords = ['down', 'go', 'left', 'no', 'off', 'on', 'right', 'stop', 'up', 'yes', '_silence_'] + # define constants for data types (train, test, validation, benchmark) + TRAIN = np.uint(0) + TEST = np.uint(1) + VALIDATION = np.uint(2) + BENCHMARK = np.uint(3) + def __init__(self, root, classes, d_type, t_type, transform=None, quantization_scheme=None, augmentation=None, download=False, save_unquantized=False, filter_libri=False, benchmark=False): @@ -429,14 +435,14 @@ class KWS: def __filter_dtype(self): if self.d_type == 'train': - idx_to_select = (self.data_type == 0)[:, -1] + idx_to_select = (self.data_type == self.TRAIN)[:, -1] elif self.d_type == 'test': if self.benchmark: - idx_to_select = (self.data_type == 3)[:, -1] + idx_to_select = (self.data_type == self.BENCHMARK)[:, -1] else: - idx_bm = (self.data_type == 3)[:, -1] - idx_test = (self.data_type == 1)[:, -1] - idx_to_select = idx_bm | idx_test + idx_bm = (self.data_type == self.BENCHMARK)[:, -1] + idx_test = (self.data_type == self.TEST)[:, -1] + idx_to_select = torch.logical_or(idx_bm, idx_test) else: print(f'Unknown data type: {self.d_type}') return @@ -452,14 +458,14 @@ class KWS: self.data = self.data[idx_to_select, :] self.targets = self.targets[idx_to_select, :] if self.d_type == 'test': - self.data_type[idx_to_select, :] = np.uint8(1) + self.data_type[idx_to_select, :] = self.TEST self.data_type = self.data_type[idx_to_select, :] self.shift_limits = self.shift_limits[idx_to_select, :] # append validation set to the training set if validation examples are explicitly included if self.d_type == 'train': - idx_to_select = (self.data_type_original == 2)[:, -1] + idx_to_select = (self.data_type_original == self.VALIDATION)[:, -1] if idx_to_select.sum() > 0: # if validation examples exist self.data = torch.cat((self.data, self.data_original[idx_to_select, :]), dim=0) self.targets = \ @@ -515,14 +521,14 @@ class KWS: torch.tensor(idx_for_librispeech)) if self.d_type == 'train': - set_size = sum((self.data_type == 0)[:, -1]) + set_size = sum((self.data_type == self.TRAIN)[:, -1]) elif self.d_type == 'test': - set_size = sum((self.data_type == 1)[:, -1]) + set_size = sum((self.data_type == self.TEST)[:, -1]) print(f'Remaining {self.d_type} set: {set_size} elements') if self.d_type == 'train': - train_size = sum((self.data_type == 0)[:, -1]) - set_size = sum((self.data_type == 2)[:, -1]) + train_size = sum((self.data_type == self.TRAIN)[:, -1]) + set_size = sum((self.data_type == self.VALIDATION)[:, -1]) # indicate the list of validation indices to be used by distiller's dataloader self.valid_indices = range(train_size, train_size + set_size) print(f'Remaining validation set: {set_size} elements') @@ -539,14 +545,14 @@ class KWS: self.shift_limits = torch.index_select(self.shift_limits, 0, torch.tensor(idx_for_silence)) if self.d_type == 'train': - set_size = sum((self.data_type == 0)[:, -1]) + set_size = sum((self.data_type == self.TRAIN)[:, -1]) elif self.d_type == 'test': - set_size = sum((self.data_type == 1)[:, -1]) + set_size = sum((self.data_type == self.TEST)[:, -1]) print(f'Remaining {self.d_type} set: {set_size} elements') if self.d_type == 'train': - train_size = sum((self.data_type == 0)[:, -1]) - set_size = sum((self.data_type == 2)[:, -1]) + train_size = sum((self.data_type == self.TRAIN)[:, -1]) + set_size = sum((self.data_type == self.VALIDATION)[:, -1]) # indicate the list of validation indices to be used by distiller's dataloader self.valid_indices = range(train_size, train_size + set_size) print(f'Remaining validation set: {set_size} elements') @@ -580,7 +586,7 @@ class KWS: data_type, shift_limits = self.data_type[index], self.shift_limits[index] # apply dynamic shift and noise augmentation to training examples - if data_type == 0: + if data_type == self.TRAIN: inp = self.shift_and_noise_augment(inp, shift_limits) # reshape to 2D @@ -745,7 +751,7 @@ class KWS: # sample 1 out of every 9 generated silence files for the validation set silence_files = [os.path.join('_silence_', s) for s in os.listdir(self.silence_folder) - if not s[0].isdigit()] # files starting with numbers: for testing + if not s[0].isdigit()] # files starting w/ numbers: used for testing validation_set.update(silence_files[::9]) train_count = 0 @@ -802,23 +808,23 @@ class KWS: if label in rec] if record_name in raw_test_list: - d_typ = np.uint(3) # benchmark test + d_typ = self.BENCHMARK # benchmark test test_count += 1 elif record_name in test_set: - d_typ = np.uint8(1) # test + d_typ = self.TEST test_count += 1 elif record_name in validation_set: - d_typ = np.uint8(2) # val + d_typ = self.VALIDATION valid_count += 1 else: - d_typ = np.uint8(0) # train + d_typ = self.TRAIN train_count += 1 record_pth = os.path.join(self.raw_folder, record_name) record, fs = librosa.load(record_pth, offset=0, sr=None) # training and validation examples get speed augmentation - if d_typ not in (1, 3): + if d_typ not in (self.TEST, self.BENCHMARK): no_augmentations = self.augmentation['aug_num'] else: # test examples don't get speed augmentation no_augmentations = 0 @@ -865,7 +871,7 @@ class KWS: # apply static shift & noise augmentation for validation examples for sample_index in range(data_in_all.shape[0]): - if data_type_all[sample_index] == 2: + if data_type_all[sample_index] == self.VALIDATION: data_in_all[sample_index] = \ self.shift_and_noise_augment(data_in_all[sample_index], data_shift_limits_all[sample_index]) @@ -1133,7 +1139,7 @@ datasets = [ 'name': 'KWS_12_benchmark', # 10 keywords + _silence_ + _unknown_ 'input': (128, 128), 'output': KWS.dataset_dict['KWS_12_benchmark'], - 'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.6, 0.06), + 'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.056), 'loader': KWS_12_benchmark_get_datasets, }, {