From 19c58bdb8bc0d44d9dd7317520d10da04b85c556 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Thu, 10 Aug 2023 08:44:25 +0000
Subject: [PATCH 1/8] add sec control in dataset

---
 egs_s2/LibriLight/conf/small_medium_iter.yaml | 22 +++++++++-----
 soundstorm/s2/data/semantic_dataset.py        | 29 +++++++++----------
 2 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/egs_s2/LibriLight/conf/small_medium_iter.yaml b/egs_s2/LibriLight/conf/small_medium_iter.yaml
index dbd466e..64cb2c5 100644
--- a/egs_s2/LibriLight/conf/small_medium_iter.yaml
+++ b/egs_s2/LibriLight/conf/small_medium_iter.yaml
@@ -27,6 +27,10 @@ model:
             timestep_type: 'adalayernorm' # adainsnorm or adalayernorm and abs
             mlp_hidden_times: 4
             semantic_token_nums: 300
+            prompt_semantic_emb_len: 10 # should be > max_prompt_sec in dataset
+            target_semantic_emb_len: 30 # should be > max_target_sec in dataset
+            prompt_acoustic_emb_len: 10 # can be the same as prompt_semantic
+            target_acoustic_emb_len: 30 # can be the same as target_semantic
             content_emb_config:
               target: soundstorm.s2.models.dalle_wav.mask_embedding.DalleMaskImageEmbedding
               params:
@@ -37,11 +41,11 @@ model:
                 pos_emb_type: embedding
 
 solver:
-  base_lr: 0.3e-05  # 3.0e-6 x 8 because max_token_one_batch is x 8 (the old value was 10k)
-  adjust_lr: none  # do not adjust lr according to total batch_size
+  base_lr: 0.3e-05 # 3.0e-6 x 8 because max_token_one_batch is x 8 (the old value was 10k)
+  adjust_lr: none # do not adjust lr according to total batch_size
   max_iters: 550000 # 550k iters on 8 GPUs for small + medium, ~70 epochs, ~7.8k iters/epoch, ~2.5 h of training per epoch
-  save_iters: 1000 # 1k, saving a ckpt costs ~0.5 h
-  dev_iters: 1000
+  save_iters: 1500 # 1.5k, saving a ckpt costs ~0.3 h
+  dev_iters: 1500 # num of iters for each GPU; for the 8 GPUs here, should be x2 when using 4 GPUs so the model sees the same number of samples
   ema:
     decay: 0.99
     update_interval: 25
@@ -69,7 +73,7 @@ solver:
       threshold: 1.0e-1
       threshold_mode: rel
       warmup_lr: 0.45e-3 # the lr to be reached after warmup
-      warmup: 800 # this is the number of warmup iters
+      warmup: 800 # num of iters to warm up
 
 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
@@ -79,7 +83,9 @@ dataloader:
   num_workers: 2
   train_datasets: # a list of configs, so we can combine several datasets
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
      params:
        codec_name: hificodec
        num_quant: 4 # does not work when != 4 for hificodec; can be 3 for soundstream and encodec
-        semantic_token_nums: 300 # 1000 for mhubert, 500 for en_hubert
+        semantic_token_nums: 300 # same as the num of kmeans bins
+        max_prompt_sec: 3 # same as LibriTTS
+        max_target_sec: 20 # LibriTTS uses 10; use 20 here for longer TTS
         semantic_path: dump/train/semantic_token.tsv
         acoustic_path: dump/train/acoustic_token/hificodec.pth
@@ -88,6 +94,8 @@ dataloader:
   dev_datasets:
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
      params:
        codec_name: hificodec
        num_quant: 4 # does not work when != 4 for hificodec; can be 3 for soundstream and encodec
-        semantic_token_nums: 300 # 1000 for mhubert, 500 for en_hubert
+        semantic_token_nums: 300 # same as the num of kmeans bins
+        max_prompt_sec: 3 # same as LibriTTS
+        max_target_sec: 20 # LibriTTS uses 10; use 20 here for longer TTS
         semantic_path: dump/dev/semantic_token.tsv
         acoustic_path: dump/dev/acoustic_token/hificodec.pth
diff --git a/soundstorm/s2/data/semantic_dataset.py b/soundstorm/s2/data/semantic_dataset.py
index c3b883d..0f179f7 100644
--- a/soundstorm/s2/data/semantic_dataset.py
+++ b/soundstorm/s2/data/semantic_dataset.py
@@ -35,23 +35,24 @@ def __init__(
             num_quant,
             semantic_path,
             acoustic_path,
-            codec_name='hificodec',
-            max_length=(250, 250),
-            max_token_one_batch=10000,
-            # 1000 for mhubert, 500 for en_hubert
-            semantic_token_nums=1000):
+            codec_name: str='hificodec',
+            max_token_one_batch: int=10000,
+            semantic_token_nums: int=1000,
+            max_prompt_sec: int=3,
+            max_target_sec: int=10):
         super().__init__()
         self.semantic_data = pd.read_csv(semantic_path, delimiter='\t')
         # get dict
         self.acoustic_data = torch.load(acoustic_path)
-        self.max_length = max_length
         self.num_quant = 4 if codec_name == 'hificodec' else num_quant
         # 16000 / 320 = 50
         self.hz = 50  # resolution
-        # use one 3 s segment by default
-        self.segment_size = 3
+
+        self.max_prompt_sec = max_prompt_sec
+        self.max_target_sec = max_target_sec
+        self.max_sec = self.max_prompt_sec + self.max_target_sec
 
         # NOTE by yuantian: same as SemanticTokenizer.dim_codebook
         self.semantic_token_nums = semantic_token_nums
@@ -80,10 +81,8 @@ def __init__(
 
     def init_batch(self):
         # this function aims to prepare batches
-        # the total token count of one batch is set to 5600 (?)
         # first sort by the length of semantic_data
         # target is at most 10 s and prompt 3 s; 1 s corresponds to 50 tokens,
-        # so with 10 s, one sample has 500*3 + 500 + 150 + 150 = 2300 tokens and only 2 samples fit -- what does the *3 mean (?)
         max_token_one_batch = self.max_token_one_batch
         sementic_ls = []
         len_ls = []
@@ -104,7 +103,7 @@ def init_batch(self):
             range(len(len_ls)), key=lambda k: len_ls[k], reverse=True)
         start_batch_id = 0
         # the maximum length is 13 s
-        max_len = 13 * self.hz
+        max_len = self.max_sec * self.hz
         tmp_prompt_semantics = []
         tmp_target_semantics = []
         tmp_prompt_acoustics = []
@@ -129,8 +128,8 @@ def init_batch(self):
             over_semantic_len = over_semantic.shape[1]
             if over_semantic_len > max_len:
                 # if the audio is longer than 13 s, split it into 3 + 10: a 3 s prompt and a 10 s target
-                prompt_len = 3 * self.hz
-                targen_len = 10 * self.hz
+                prompt_len = self.max_prompt_sec * self.hz
+                targen_len = self.max_target_sec * self.hz
                 # first pick a random prompt start point; at most the last 13 s are kept
                 # total length minus 13 s
                 max_prompt_index = over_semantic_len - max_len
@@ -144,9 +143,9 @@ def init_batch(self):
                 target_acoustic = over_acoustic[:, prompt_end:prompt_end +
                                                 targen_len]
             # if the length is greater than 6 s and less than 13 s
-            elif over_semantic_len > 6 * self.hz and over_semantic_len < max_len:
+            elif over_semantic_len > 2 * self.max_prompt_sec * self.hz and over_semantic_len < max_len:
                 # a 3 s prompt; everything after it is the target
-                prompt_len = 3 * self.hz
+                prompt_len = self.max_prompt_sec * self.hz
                 prompt_semantic = over_semantic[:, :prompt_len]
                 prompt_acoustic = over_acoustic[:, :prompt_len]
                 # everything after the first 3 s is the target
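Note: PATCH 1 parameterizes the previously hard-coded 3 s / 10 s segmentation as max_prompt_sec / max_target_sec. A minimal sketch of the resulting split rule, assuming 50 semantic tokens per second; split_prompt_target is an illustrative helper, not code from this repo:

    import random


    def split_prompt_target(total_len, hz=50, max_prompt_sec=3, max_target_sec=10):
        # returns (prompt_start, prompt_end, target_end) as token indices
        max_len = (max_prompt_sec + max_target_sec) * hz
        if total_len > max_len:
            # long clip: random 3 s prompt, the following 10 s as target
            start = random.randint(0, total_len - max_len)
            end = start + max_prompt_sec * hz
            return start, end, end + max_target_sec * hz
        if total_len > 2 * max_prompt_sec * hz:
            # medium clip: first 3 s as prompt, the rest as target
            return 0, max_prompt_sec * hz, total_len
        # short clip: split evenly between prompt and target
        return 0, total_len // 2, total_len


    # e.g. a 30 s utterance (1500 tokens) -> a 150-token prompt and a 500-token target
    print(split_prompt_target(1500))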
From bc1040ac62d996d9cbfdff18d6407b4702dc6a48 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Thu, 10 Aug 2023 11:57:21 +0000
Subject: [PATCH 2/8] raw dataset_6k ok

---
 egs_s2/LibriLight/conf/small_medium_iter.yaml |  12 +-
 egs_s2/LibriLight/local/train.sh              |  14 +-
 soundstorm/s1/AR/exps/get_txt_librilight.py   |   2 +-
 soundstorm/s2/data/build.py                   |   8 +-
 soundstorm/s2/data/build_librilight_6k.py     |  89 ++++++
 .../s2/data/semantic_dataset_librilight_6k.py | 292 ++++++++++++++++++
 soundstorm/s2/exps/train_librilight_6k.py     | 290 +++++++++++++++++
 7 files changed, 690 insertions(+), 17 deletions(-)
 create mode 100644 soundstorm/s2/data/build_librilight_6k.py
 create mode 100644 soundstorm/s2/data/semantic_dataset_librilight_6k.py
 create mode 100644 soundstorm/s2/exps/train_librilight_6k.py

diff --git a/egs_s2/LibriLight/conf/small_medium_iter.yaml b/egs_s2/LibriLight/conf/small_medium_iter.yaml
index 64cb2c5..1b46aec 100644
--- a/egs_s2/LibriLight/conf/small_medium_iter.yaml
+++ b/egs_s2/LibriLight/conf/small_medium_iter.yaml
@@ -79,23 +79,23 @@ dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
   num_workers: 2
   train_datasets: # a list of configs, so we can combine several datasets
-    - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
+    - target: soundstorm.s2.data.semantic_dataset_librilight_6k.SemanticDataset
      params:
        codec_name: hificodec
        num_quant: 4 # does not work when != 4 for hificodec; can be 3 for soundstream and encodec
        semantic_token_nums: 300 # same as the num of kmeans bins
        max_prompt_sec: 3 # same as LibriTTS
        max_target_sec: 20 # LibriTTS uses 10; use 20 here for longer TTS
-        semantic_path: dump/train/semantic_token.tsv
-        acoustic_path: dump/train/acoustic_token/hificodec.pth
+        semantic_dirs: ['dump/small/train/']
+        acoustic_dirs: ['dump/small/train/acoustic_token/']
 
   dev_datasets:
-    - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
+    - target: soundstorm.s2.data.semantic_dataset_librilight_6k.SemanticDataset
      params:
        codec_name: hificodec
        num_quant: 4 # does not work when != 4 for hificodec; can be 3 for soundstream and encodec
        semantic_token_nums: 300 # same as the num of kmeans bins
        max_prompt_sec: 3 # same as LibriTTS
        max_target_sec: 20 # LibriTTS uses 10; use 20 here for longer TTS
-        semantic_path: dump/dev/semantic_token.tsv
-        acoustic_path: dump/dev/acoustic_token/hificodec.pth
+        semantic_dirs: ['dump/small/dev/']
+        acoustic_dirs: ['dump/small/dev/acoustic_token/']
diff --git a/egs_s2/LibriLight/local/train.sh b/egs_s2/LibriLight/local/train.sh
index 86cd717..1c4a2ca 100755
--- a/egs_s2/LibriLight/local/train.sh
+++ b/egs_s2/LibriLight/local/train.sh
@@ -7,14 +7,16 @@ log_frequency=$4
 dist_url=$5
 dump_dir=$6
 
-python3 ${BIN_DIR}/train_large.py \
+# NOTE: the *_dirs args must not be followed by '='
+python3 ${BIN_DIR}/train_librilight_6k.py \
     --config_file=${config_path} \
-    --train_semantic_path=${root_dir}/${dump_dir}/train/semantic_token.tsv \
-    --train_acoustic_path=${root_dir}/${dump_dir}/train/acoustic_token/hificodec.pth \
-    --dev_semantic_path=${root_dir}/${dump_dir}/dev/semantic_token.tsv \
-    --dev_acoustic_path=${root_dir}/${dump_dir}/dev/acoustic_token/hificodec.pth \
+    --train_semantic_dirs ''${root_dir}'/'${dump_dir}'/small/train/' ''${root_dir}'/'${dump_dir}'/large/train/' \
+    --train_acoustic_dirs ''${root_dir}'/'${dump_dir}'/small/train/acoustic_token/' ''${root_dir}'/'${dump_dir}'/medium/train/acoustic_token/' \
+    --dev_semantic_dirs ''${root_dir}'/'${dump_dir}'/small/dev/' ''${root_dir}'/'${dump_dir}'/medium/dev/' \
+    --dev_acoustic_dirs ''${root_dir}'/'${dump_dir}'/small/dev/acoustic_token/' ''${root_dir}'/'${dump_dir}'/medium/dev/acoustic_token/' \
     --output=${root_dir}/${train_output_path} \
     --log_frequency=${log_frequency} \
     --dist_url=${dist_url} \
     --hificodec_model_path=pretrained_model/hificodec/HiFi-Codec-16k-320d-large-universal \
-    --hificodec_config_path=pretrained_model/hificodec/config_16k_320d.json
\ No newline at end of file
+    --hificodec_config_path=pretrained_model/hificodec/config_16k_320d.json \
+    --train_with_iter=True
\ No newline at end of file
diff --git a/soundstorm/s1/AR/exps/get_txt_librilight.py b/soundstorm/s1/AR/exps/get_txt_librilight.py
index 99c22a4..49da7d7 100644
--- a/soundstorm/s1/AR/exps/get_txt_librilight.py
+++ b/soundstorm/s1/AR/exps/get_txt_librilight.py
@@ -157,7 +157,7 @@ def process_sentences(args,
                 if asr_result is not False:
                     txt_dict[subset][utt_id] = asr_result
                 else:
-                    print('asr result of {utt_id} is False')
+                    print(f'asr result of {utt_id} is False')
             except Exception:
                 print(f"{utt_id} occur Exception")
                 traceback.print_exc()
diff --git a/soundstorm/s2/data/build.py b/soundstorm/s2/data/build.py
index c522b37..898566c 100644
--- a/soundstorm/s2/data/build.py
+++ b/soundstorm/s2/data/build.py
@@ -5,16 +5,15 @@
 from torch.utils.data import ConcatDataset
 
 
-
 def build_dataloader(config, args=None, return_dataset=False):
     dataset_cfg = config['dataloader']
     batch_size = 1
     train_dataset = []
     for ds_cfg in dataset_cfg['train_datasets']:
-        # ds_cfg['params']['data_root'] = dataset_cfg.get('data_root', '')
         ds_cfg['params']['semantic_path'] = args.train_semantic_path
         ds_cfg['params']['acoustic_path'] = args.train_acoustic_path
-        ds_cfg['params']['max_token_one_batch'] = dataset_cfg['max_token_one_batch']
+        ds_cfg['params']['max_token_one_batch'] = dataset_cfg[
+            'max_token_one_batch']
         ds = instantiate_from_config(ds_cfg)
         train_dataset.append(ds)
     if len(train_dataset) > 1:
@@ -25,7 +24,8 @@ def build_dataloader(config, args=None, return_dataset=False):
     for ds_cfg in dataset_cfg['dev_datasets']:
         ds_cfg['params']['semantic_path'] = args.dev_semantic_path
         ds_cfg['params']['acoustic_path'] = args.dev_acoustic_path
-        ds_cfg['params']['max_token_one_batch'] = dataset_cfg['max_token_one_batch']
+        ds_cfg['params']['max_token_one_batch'] = dataset_cfg[
+            'max_token_one_batch']
         ds = instantiate_from_config(ds_cfg)
         dev_dataset.append(ds)
     if len(dev_dataset) > 1:
diff --git a/soundstorm/s2/data/build_librilight_6k.py b/soundstorm/s2/data/build_librilight_6k.py
new file mode 100644
index 0000000..e8dfa22
--- /dev/null
+++ b/soundstorm/s2/data/build_librilight_6k.py
@@ -0,0 +1,89 @@
+# Fast loader
+# it helps to read data quickly, which improves the training time.
+import torch
+from soundstorm.s2.utils.misc import instantiate_from_config
+from torch.utils.data import ConcatDataset
+
+
+def build_dataloader(config, args=None, return_dataset=False):
+    dataset_cfg = config['dataloader']
+    batch_size = 1
+    train_dataset = []
+    for ds_cfg in dataset_cfg['train_datasets']:
+        ds_cfg['params']['semantic_dirs'] = args.train_semantic_dirs
+        ds_cfg['params']['acoustic_dirs'] = args.train_acoustic_dirs
+        ds_cfg['params']['max_token_one_batch'] = dataset_cfg[
+            'max_token_one_batch']
+        ds = instantiate_from_config(ds_cfg)
+        train_dataset.append(ds)
+    if len(train_dataset) > 1:
+        train_dataset = ConcatDataset(train_dataset)
+    else:
+        train_dataset = train_dataset[0]
+    dev_dataset = []
+    for ds_cfg in dataset_cfg['dev_datasets']:
+        ds_cfg['params']['semantic_dirs'] = args.dev_semantic_dirs
+        ds_cfg['params']['acoustic_dirs'] = args.dev_acoustic_dirs
+        ds_cfg['params']['max_token_one_batch'] = dataset_cfg[
+            'max_token_one_batch']
+        ds = instantiate_from_config(ds_cfg)
+        dev_dataset.append(ds)
+    if len(dev_dataset) > 1:
+        dev_dataset = ConcatDataset(dev_dataset)
+    else:
+        dev_dataset = dev_dataset[0]
+
+    if args is not None and args.distributed:
+        # I add "num_replicas=world_size, rank=rank"
+        train_sampler = torch.utils.data.distributed.DistributedSampler(
+            train_dataset, shuffle=True)
+        dev_sampler = torch.utils.data.distributed.DistributedSampler(
+            dev_dataset, shuffle=False)
+        train_iters = len(train_sampler) // batch_size
+        dev_iters = len(dev_sampler) // batch_size
+    else:
+        train_sampler = None
+        dev_sampler = None
+        # done once per epoch
+        train_iters = len(train_dataset) // batch_size
+        dev_iters = len(dev_dataset) // batch_size
+    num_workers = dataset_cfg['num_workers']
+    persistent_workers = True if num_workers > 0 else False
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=(train_sampler is None),
+        num_workers=num_workers,
+        pin_memory=True,
+        sampler=train_sampler,
+        drop_last=True,
+        collate_fn=train_dataset.collater,
+        persistent_workers=persistent_workers,
+        # fixes the "bad value(s) in fds_to_keep" error when num_workers > 0
+        multiprocessing_context='fork')
+
+    dev_loader = torch.utils.data.DataLoader(
+        dev_dataset,
+        batch_size=batch_size,
+        # (dev_sampler is None),
+        shuffle=False,
+        num_workers=num_workers,
+        sampler=dev_sampler,
+        drop_last=True,
+        pin_memory=True,
+        collate_fn=train_dataset.collater,
+        persistent_workers=persistent_workers,
+        multiprocessing_context='fork')
+
+    dataload_info = {
+        'train_loader': train_loader,
+        'dev_loader': dev_loader,
+        'train_iterations': train_iters,
+        'dev_iterations': dev_iters
+    }
+
+    if return_dataset:
+        dataload_info['train_dataset'] = train_dataset
+        dataload_info['dev_dataset'] = dev_dataset
+
+    return dataload_info
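Note: build_librilight_6k.py always builds its DataLoader with batch_size=1, because every dataset item is already a pre-packed batch and the collater merely unwraps it. A self-contained sketch of that pattern under those assumptions (toy names, not the repo's classes):

    import torch
    from torch.utils.data import DataLoader, Dataset


    class PrePackedDataset(Dataset):
        # each item is already a whole batch: a list of variable-length tensors
        def __init__(self):
            self.batches = [[torch.zeros(1, 5), torch.zeros(1, 3)],
                            [torch.zeros(1, 7)]]

        def __len__(self):
            return len(self.batches)

        def __getitem__(self, index):
            return {'prompt_semantic': self.batches[index]}

        def collater(self, samples):
            # batch_size=1, so samples always holds exactly one pre-packed batch
            return samples[0]


    ds = PrePackedDataset()
    loader = DataLoader(ds, batch_size=1, collate_fn=ds.collater)
    for batch in loader:
        print(len(batch['prompt_semantic']))  # 2, then 1

This is why the token budget, not batch_size, controls memory use: the "real" batch size varies per item.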
diff --git a/soundstorm/s2/data/semantic_dataset_librilight_6k.py b/soundstorm/s2/data/semantic_dataset_librilight_6k.py
new file mode 100644
index 0000000..72d69af
--- /dev/null
+++ b/soundstorm/s2/data/semantic_dataset_librilight_6k.py
@@ -0,0 +1,292 @@
+import os
+import random
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+
+# BaseDataset code from NATSpeech
+'''
+Dataset construction strategy:
+(1) the prompt is capped at max_prompt_sec (3 s) and the target at max_target_sec (10 s)
+(2) if the total length is under 6 s, half goes to the prompt and half to the target
+(3) return 4 parts -- se_pro, se_tar, ac_pro, ac_tar -- each padded to the max sample length of its batch
+'''
+
+
+def get_files_by_suffix(path, suffix):
+    files = []
+    for file_name in os.listdir(path):
+        if file_name.endswith(suffix):
+            files.append(os.path.join(path, file_name))
+    return files
+
+
+def pad_2D(inputs, PAD, print_len=False):
+    # when each sample in inputs is 2D, this function can be used
+    def pad(x, max_len):
+        return F.pad(x, (0, max_len - x.shape[-1]), mode="constant", value=PAD)
+
+    max_len = max(np.shape(x)[-1] for x in inputs)
+    input_len = len(inputs)
+    if print_len:
+        min_len = min(np.shape(x)[-1] for x in inputs)
+        print("input_len, max_len, min_len, max_len-min_len:", input_len,
+              max_len, min_len, max_len - min_len)
+    output = np.stack([pad(x, max_len) for x in inputs])
+    return output
+
+
+class SemanticDataset(torch.utils.data.Dataset):
+    def __init__(self,
+                 num_quant,
+                 semantic_dirs,
+                 acoustic_dirs,
+                 codec_name: str='hificodec',
+                 max_token_one_batch: int=10000,
+                 semantic_token_nums: int=1000,
+                 max_prompt_sec: int=3,
+                 max_target_sec: int=10):
+        super().__init__()
+
+        self.semantic_data_dict = dict()
+        self.acoustic_data_dict = dict()
+        # self.semantic_data = pd.read_csv(semantic_path, delimiter='\t')
+        # # get dict
+        # self.acoustic_data = torch.load(acoustic_path)
+        semantic_files = []
+        acoustic_files = []
+        for semantic_dir in semantic_dirs:
+            semantic_files += get_files_by_suffix(semantic_dir, 'tsv')
+        for acoustic_dir in acoustic_dirs:
+            acoustic_files += get_files_by_suffix(acoustic_dir, 'pth')
+
+        for semantic_file in semantic_files:
+            name_list = semantic_file.split("/")
+            rank_name = '_'.join(name_list[-1].split('.')[0].split('_')[-2:])
+            key_name = f'{name_list[-3]}_{rank_name}'
+            self.semantic_data_dict[key_name] = pd.read_csv(
+                semantic_file, delimiter='\t')
+        for acoustic_file in acoustic_files:
+            name_list = acoustic_file.split("/")
+            rank_name = '_'.join(name_list[-1].split('.')[0].split('_')[-2:])
+            key_name = f'{name_list[-4]}_{rank_name}'
+            self.acoustic_data_dict[key_name] = torch.load(acoustic_file)
+
+        self.num_quant = 4 if codec_name == 'hificodec' else num_quant
+        # 16000 / 320 = 50
+        self.hz = 50  # resolution
+
+        self.max_prompt_sec = max_prompt_sec
+        self.max_target_sec = max_target_sec
+        self.max_sec = self.max_prompt_sec + self.max_target_sec
+
+        # NOTE by yuantian: same as SemanticTokenizer.dim_codebook
+        self.semantic_token_nums = semantic_token_nums
+        # self.prompt_semantic_start_id = self.semantic_token_nums
+        self.prompt_semantic_end_id = self.semantic_token_nums + 1
+        # self.target_semantic_start_id = self.semantic_token_nums + 2
+        self.target_semantic_end_id = self.semantic_token_nums + 3
+        # NOTE by yuantian: N in codec
+        self.acoustic_token_nums = 1024
+        self.prompt_acoustic_eos = self.acoustic_token_nums
+        self.target_acoustic_eos = self.acoustic_token_nums + 1
+
+        self.batch_prompt_semantics = {}
+        self.batch_target_semantics = {}
+        self.batch_prompt_acoustics = {}
+        self.batch_target_acoustics = {}
+
+        # maximum number of tokens in one batch
+        self.max_token_one_batch = max_token_one_batch
+        self.inited = False
+        self.start_batch_id = 0
+        self.total_semantic_data_len = 0
+        self.total_acoustic_data_len = 0
+
+        if not self.inited:
+            # call the init function
+            for key_name in self.semantic_data_dict.keys():
+                self.init_batch(key_name)
+            self.inited = True
+        print("self.total_semantic_data_len:", self.total_semantic_data_len)
+        print("self.total_acoustic_data_len:", self.total_acoustic_data_len)
+
+    def init_batch(self, key_name):
+        # this function aims to prepare batches
+        # first sort by the length of semantic_data
+        # target is at most 10 s and prompt 3 s; 1 s corresponds to 50 tokens
+        if key_name not in self.acoustic_data_dict.keys():
+            print(f'{key_name} not in self.acoustic_data_dict')
+            return None
+
+        semantic_data = self.semantic_data_dict[key_name]
+        acoustic_data = self.acoustic_data_dict[key_name]
+        max_token_one_batch = self.max_token_one_batch
+        sementic_ls = []
+        len_ls = []
+        semantic_data_len = len(semantic_data)
+        acoustic_data_len = len(acoustic_data.keys())
+
+        self.total_semantic_data_len += semantic_data_len
+        self.total_acoustic_data_len += acoustic_data_len
+
+        for i in range(semantic_data_len):
+            # iterate in order
+            # get str
+            semantic_str = semantic_data['semantic_audio'][i]
+            # get token list
+            tmp = [int(idx) for idx in semantic_str.split(' ')]
+            sementic_ls.append(tmp)
+            len_ls.append(len(tmp))
+        # sort by the values in the list and return the corresponding indices
+        sorted_id = sorted(
+            range(len(len_ls)), key=lambda k: len_ls[k], reverse=True)
+
+        # the maximum length is max_sec (13 s by default)
+        max_len = self.max_sec * self.hz
+        tmp_prompt_semantics = []
+        tmp_target_semantics = []
+        tmp_prompt_acoustics = []
+        tmp_target_acoustics = []
+        tmp_tot_tokens = 0
+        for i in range(len(sorted_id)):
+            # get the index
+            index = sorted_id[i]
+            # get the semantic
+            # (1, T)
+            over_semantic = torch.tensor(sementic_ls[index]).unsqueeze(0)
+            # handle the case where item_name is not in acoustic_data
+            item_name = semantic_data['item_name'][index]
+            try:
+                acoustic_str = acoustic_data[item_name]
+            except Exception:
+                print(item_name, "not in self.acoustic_data!")
+                continue
+            # only keep the first num_quant codebooks
+            # this shows that acoustic_token is stored as (C, T)
+            over_acoustic = acoustic_str[:self.num_quant, ...]
+            over_semantic_len = over_semantic.shape[1]
+            if over_semantic_len > max_len:
+                # if the audio is longer than 13 s, split it into 3 + 10: a 3 s prompt and a 10 s target
+                prompt_len = self.max_prompt_sec * self.hz
+                targen_len = self.max_target_sec * self.hz
+                # first pick a random prompt start point; at most the last 13 s are kept
+                # total length minus 13 s
+                max_prompt_index = over_semantic_len - max_len
+                left_start = random.randint(0, max_prompt_index)
+                prompt_end = left_start + prompt_len
+                prompt_semantic = over_semantic[:, left_start:prompt_end]
+                prompt_acoustic = over_acoustic[:, left_start:prompt_end]
+                # take the following 10 s
+                target_semantic = over_semantic[:, prompt_end:prompt_end +
                                                 targen_len]
+                target_acoustic = over_acoustic[:, prompt_end:prompt_end +
                                                 targen_len]
+            # if the length is greater than 6 s and less than 13 s
+            elif over_semantic_len > 2 * self.max_prompt_sec * self.hz and over_semantic_len < max_len:
+                # a 3 s prompt; everything after it is the target
+                prompt_len = self.max_prompt_sec * self.hz
+                prompt_semantic = over_semantic[:, :prompt_len]
+                prompt_acoustic = over_acoustic[:, :prompt_len]
+                # everything after the first 3 s is the target
+                target_semantic = over_semantic[:, prompt_len:]
+                target_acoustic = over_acoustic[:, prompt_len:]
+            else:
+                # shorter than 6 s: just split evenly
+                mid_id = int(over_semantic_len / 2)
+                # choose 3 s
+                prompt_semantic = over_semantic[:, :mid_id]
+                prompt_acoustic = over_acoustic[:, :mid_id]
+                # everything after the midpoint is the target
+                target_semantic = over_semantic[:, mid_id:]
+                target_acoustic = over_acoustic[:, mid_id:]
+            # count the tokens of the current sample
+            cal_num = prompt_semantic.shape[1] + target_semantic.shape[
+                1] + prompt_acoustic.shape[1] + target_acoustic.shape[1]
+            if tmp_tot_tokens + cal_num < max_token_one_batch:
+                # if the batch is not full yet, keep adding
+                # shape: (1, 150)
+                tmp_prompt_semantics.append(prompt_semantic)
+                tmp_target_semantics.append(target_semantic)
+                tmp_prompt_acoustics.append(prompt_acoustic)
+                tmp_target_acoustics.append(target_acoustic)
+                # accumulate the token count of the current batch
+                tmp_tot_tokens += cal_num
+            else:
+                # the batch is full
+                # save batch
+                self.batch_prompt_semantics[str(
+                    self.start_batch_id)] = tmp_prompt_semantics
+                self.batch_target_semantics[str(
+                    self.start_batch_id)] = tmp_target_semantics
+                self.batch_prompt_acoustics[str(
+                    self.start_batch_id)] = tmp_prompt_acoustics
+                self.batch_target_acoustics[str(
+                    self.start_batch_id)] = tmp_target_acoustics
+                # clear the previous step
+                tmp_prompt_semantics = []
+                tmp_target_semantics = []
+                tmp_prompt_acoustics = []
+                tmp_target_acoustics = []
+                # reset to 0
+                tmp_tot_tokens = 0
+                # start a new batch
+                tmp_prompt_semantics.append(prompt_semantic)
+                tmp_target_semantics.append(target_semantic)
+                tmp_prompt_acoustics.append(prompt_acoustic)
+                tmp_target_acoustics.append(target_acoustic)
+                tmp_tot_tokens += cal_num
+                self.start_batch_id += 1
+        # add the last batch
+        self.batch_prompt_semantics[str(
+            self.start_batch_id)] = tmp_prompt_semantics
+        self.batch_target_semantics[str(
+            self.start_batch_id)] = tmp_target_semantics
+        self.batch_prompt_acoustics[str(
+            self.start_batch_id)] = tmp_prompt_acoustics
+        self.batch_target_acoustics[str(
+            self.start_batch_id)] = tmp_target_acoustics
+
+    def __len__(self):
+        return len(self.batch_prompt_semantics)
+
+    def __getitem__(self, index):
+        prompt_semantic = self.batch_prompt_semantics[str(index)]
+        target_semantic = self.batch_target_semantics[str(index)]
+        prompt_acoustic = self.batch_prompt_acoustics[str(index)]
+        target_acoustic = self.batch_target_acoustics[str(index)]
+        sample = {}
+        sample['prompt_semantic'] = prompt_semantic
+        sample['target_semantic'] = target_semantic
+        sample['prompt_acoustic'] = prompt_acoustic
+        sample['target_acoustic'] = target_acoustic
+        return sample
+
+    def collater(self, samples):
+        # why only take index 0? => because samples is a list whose length is always 1; batch_size must be 1 here
+        # prompt_semantics holds n tensors, and n is not fixed
+        # len(prompt_semantics) = 100 means batch_size = 100; the batch size is not fixed
+        prompt_semantics = samples[0]['prompt_semantic']
+        target_semantics = samples[0]['target_semantic']
+        prompt_acoustics = samples[0]['prompt_acoustic']
+        target_acoustics = samples[0]['target_acoustic']
+        # in this version we no longer use a padding token; instead, we use the eos token
+        # pad to the longest item in the batch
+        prompt_semantics = pad_2D(prompt_semantics, self.prompt_semantic_end_id)
+        target_semantics = pad_2D(target_semantics, self.target_semantic_end_id)
+        prompt_acoustics = pad_2D(prompt_acoustics, self.prompt_acoustic_eos)
+        # pad with 1025
+        target_acoustics = pad_2D(target_acoustics, self.target_acoustic_eos)
+        # mask the padded part of target_acoustics
+        x_mask = (target_acoustics == self.target_acoustic_eos)
+        new_samples = {}
+        # (B, 1, T); B and T are dynamic
+        new_samples['prompt_semantics'] = torch.from_numpy(prompt_semantics)
+        new_samples['target_semantics'] = torch.from_numpy(target_semantics)
+        new_samples['prompt_acoustics'] = torch.from_numpy(prompt_acoustics)
+        # (B, 4, T); B and T are dynamic
+        new_samples['target_acoustics'] = torch.from_numpy(target_acoustics)
+        new_samples['x_mask'] = torch.from_numpy(x_mask[:, 0, :])
+        return new_samples
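Note: init_batch packs samples greedily under max_token_one_batch after sorting them longest-first. The same packing, reduced to token counts only; pack_by_token_budget is an illustrative helper, not code from this repo:

    def pack_by_token_budget(sample_lens, max_token_one_batch):
        # group per-sample token counts into batches under a shared budget
        batches, cur, cur_tokens = [], [], 0
        for n in sorted(sample_lens, reverse=True):
            if cur_tokens + n < max_token_one_batch:
                cur.append(n)
                cur_tokens += n
            else:
                batches.append(cur)
                cur, cur_tokens = [n], n
        batches.append(cur)  # the last, possibly partial batch
        return batches


    # e.g. a 30000-token budget with 2300-token samples -> batches of 13
    print([len(b) for b in pack_by_token_budget([2300] * 40, 30000)])  # [13, 13, 13, 1]

Sorting longest-first keeps each batch's lengths similar, so the eos padding added later in the collater stays small.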
diff --git a/soundstorm/s2/exps/train_librilight_6k.py b/soundstorm/s2/exps/train_librilight_6k.py
new file mode 100644
index 0000000..6d91e90
--- /dev/null
+++ b/soundstorm/s2/exps/train_librilight_6k.py
@@ -0,0 +1,290 @@
+# train and eval are controlled by iters, not epochs
+# ------------------------------------------
+# Diffsound, By Dongchao Yang
+# based on https://github.com/cientgu/VQ-Diffusion
+# ------------------------------------------
+import argparse
+import os
+import warnings
+
+import torch
+from academicodec.models.hificodec.vqvae import VQVAE
+from soundstorm.s2.data.build_librilight_6k import build_dataloader
+from soundstorm.s2.distributed.launch import launch
+from soundstorm.s2.engine.logger import Logger
+from soundstorm.s2.models.dalle_wav.build import build_model
+from soundstorm.s2.utils.misc import merge_opts_to_config
+from soundstorm.s2.utils.misc import modify_config_for_debug
+from soundstorm.s2.utils.misc import seed_everything
+from soundstorm.utils import str2bool
+from soundstorm.utils.io import load_yaml_config
+
+NODE_RANK = os.environ['INDEX'] if 'INDEX' in os.environ else 0
+NODE_RANK = int(NODE_RANK)
+MASTER_ADDR, MASTER_PORT = (os.environ['CHIEF_IP'],
+                            22275) if 'CHIEF_IP' in os.environ else (
+                                "127.0.0.1", 29500)
+MASTER_PORT = int(MASTER_PORT)
+DIST_URL = 'tcp://%s:%s' % (MASTER_ADDR, MASTER_PORT)
+NUM_NODE = os.environ['HOST_NUM'] if 'HOST_NUM' in os.environ else 1
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='PyTorch Training script')
+    parser.add_argument(
+        '--config_file',
+        type=str,
+        default='conf/default.yaml',
+        help='path of config file')
+    parser.add_argument(
+        '--output',
+        type=str,
+        default='exp/default',
+        help='directory to save the results')
+    parser.add_argument(
+        '--log_frequency',
+        type=int,
+        default=100,
+        help='print frequency (default: 100 iter)')
+    parser.add_argument(
+        '--load_path',
+        type=str,
+        default=None,
+        help='path to model that need to be loaded, used for loading pretrained model'
+    )
+    parser.add_argument(
+        "--auto_resume",
+        type=str2bool,
+        default=True,
+        help="automatically resume the training")
+    # args for dataset
+    parser.add_argument(
+        '--train_semantic_dirs',
+        type=list,
+        nargs='+',
+        default=["dump/large/dev/"],
+        help='dirs of train semantic')
+    parser.add_argument(
+        '--train_acoustic_dirs',
+        type=list,
+        nargs='+',
+        default=["dump/small/train/acoustic/"],
+        help='dirs of train acoustic')
+    parser.add_argument(
+        '--dev_semantic_dirs',
+        type=list,
+        nargs='+',
+        default=["dump/small/dev/"],
+        help='dirs of dev semantic')
+    parser.add_argument(
+        '--dev_acoustic_dirs',
+        type=list,
+        nargs='+',
+        default=["dump/small/dev/acoustic/"],
+        help='dirs of dev acoustic')
+
+    # args for ddp
+    parser.add_argument(
+        '--num_node',
+        type=int,
+        default=NUM_NODE,
+        help='number of nodes for distributed training')
+    parser.add_argument(
+        '--ngpus_per_node',
+        type=int,
+        default=8,
+        help='number of gpu on one node')
+    parser.add_argument(
+        '--node_rank',
+        type=int,
+        default=NODE_RANK,
+        help='node rank for distributed training')
+    parser.add_argument(
+        '--dist_url',
+        type=str,
+        default=DIST_URL,
+        help='url used to set up distributed training')
+    parser.add_argument(
+        '--gpu',
+        type=int,
+        default=None,
+        help='GPU id to use. If given, only the specific gpu will be'
+        ' used, and ddp will be disabled')
+    parser.add_argument(
+        '--local_rank',
+        default=-1,
+        type=int,
+        help='node rank for distributed training')
+    parser.add_argument(
+        "--sync_bn", type=str2bool, default=False, help="use sync BN layer")
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="use tensorboard for logging")
+    parser.add_argument("--timestamp", type=str2bool, default=True)
+    # args for random
+    parser.add_argument(
+        '--seed',
+        type=int,
+        default=None,
+        help='seed for initializing training. ')
+    parser.add_argument(
+        "--cudnn_deterministic",
+        type=str2bool,
+        default=False,
+        help="set cudnn.deterministic True")
+    parser.add_argument(
+        "--amp",
+        type=str2bool,
+        default=False,
+        help="automatic mixed precision")
+    parser.add_argument(
+        "--debug", type=str2bool, default=False, help="set as debug mode")
+    # for HiFi-Codec
+    parser.add_argument(
+        "--hificodec_model_path",
+        type=str,
+        default='pretrained_model/hificodec//HiFi-Codec-16k-320d')
+    parser.add_argument(
+        "--hificodec_config_path",
+        type=str,
+        default='pretrained_model/hificodec/config_16k_320d.json')
+
+    # args for modifying the config
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER, )
+
+    parser.add_argument(
+        "--train_with_iter",
+        type=str2bool,
+        default=False,
+        help="control training with epoch or iter")
+    args = parser.parse_args()
+    args.cwd = os.path.abspath(os.path.dirname(__file__))
+
+    new_train_semantic_dirs = []
+    new_train_acoustic_dirs = []
+    new_dev_semantic_dirs = []
+    new_dev_acoustic_dirs = []
+    # format dataset dirs
+    for item in args.train_semantic_dirs:
+        new_train_semantic_dirs.append(''.join(item))
+    args.train_semantic_dirs = new_train_semantic_dirs
+
+    for item in args.train_acoustic_dirs:
+        new_train_acoustic_dirs.append(''.join(item))
+    args.train_acoustic_dirs = new_train_acoustic_dirs
+
+    for item in args.dev_semantic_dirs:
+        new_dev_semantic_dirs.append(''.join(item))
+    args.dev_semantic_dirs = new_dev_semantic_dirs
+
+    for item in args.dev_acoustic_dirs:
+        new_dev_acoustic_dirs.append(''.join(item))
+    args.dev_acoustic_dirs = new_dev_acoustic_dirs
+
+    # modify args for debugging
+    if args.debug:
+        if args.gpu is None:
+            args.gpu = 0
+    return args
+
+
+def main():
+    args = get_args()
+    if args.seed is not None or args.cudnn_deterministic:
+        seed_everything(args.seed, args.cudnn_deterministic)
+    if args.gpu is not None:
+        warnings.warn(
+            'You have chosen a specific GPU. This will completely disable ddp.')
+        torch.cuda.set_device(args.gpu)
+        args.ngpus_per_node = 1
+        args.world_size = 1
+    else:
+        print('args.num_node ', args.num_node)
+        if args.num_node == 1:
+            args.dist_url == "auto"
+        else:
+            assert args.num_node > 1
+        args.ngpus_per_node = torch.cuda.device_count()
+        args.world_size = args.ngpus_per_node * args.num_node  #
+    launch(
+        main_worker,
+        args.ngpus_per_node,
+        args.num_node,
+        args.node_rank,
+        args.dist_url,
+        args=(args, ))
+
+
+def main_worker(local_rank, args):
+    args.local_rank = local_rank
+    args.global_rank = args.local_rank + args.node_rank * args.ngpus_per_node
+    args.distributed = args.world_size > 1
+    # load config
+    config = load_yaml_config(args.config_file)
+    # merge command-line opts into the config
+    config = merge_opts_to_config(config, args.opts)
+    if args.debug:
+        config = modify_config_for_debug(config)
+    # get logger
+    logger = Logger(args)
+    logger.save_config(config)
+
+    # get model
+    model = build_model(config)
+    if args.sync_bn:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+    # for sample()
+    # NOTE by yuantian: all threads use some of the memory of GPU 0, which needs to be fixed
+    if local_rank == 0:
+        hificodec = VQVAE(
+            config_path=args.hificodec_config_path,
+            ckpt_path=args.hificodec_model_path,
+            with_encoder=True)
+        hificodec.generator.remove_weight_norm()
+        hificodec.encoder.remove_weight_norm()
+        hificodec.eval()
+    else:
+        hificodec = None
+
+    # get dataloader
+    print("start build dataloader...")
+    dataloader_info = build_dataloader(config, args)
+    # get solver
+    if args.train_with_iter is True:
+        from soundstorm.s2.engine.solver_iter import Solver
+        print("import Solver from soundstorm.s2.engine.solver_iter...")
+    else:
+        from soundstorm.s2.engine.solver import Solver
+        print("import Solver from soundstorm.s2.engine.solver...")
+
+    solver = Solver(
+        config=config,
+        args=args,
+        model=model,
+        dataloader=dataloader_info,
+        logger=logger,
+        hificodec=hificodec)
+
+    # resume
+    # only load the model parameters
+    if args.load_path is not None:
+        solver.resume(
+            path=args.load_path,
+            # load_model=True,
+            load_optimizer_and_scheduler=False,
+            load_others=False)
+    elif args.auto_resume:
+        print("in auto_resume")
+        solver.resume()
+    solver.train()
+    torch.cuda.empty_cache()
+
+
+if __name__ == '__main__':
+    main()

From e7b64ed4db2e0cbb2958a7ae4bc676423bc48c84 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Thu, 10 Aug 2023 12:32:11 +0000
Subject: [PATCH 3/8] use default

---
 egs_s2/LibriLight/conf/default.yaml           |  37 ++++---
 egs_s2/LibriLight/conf/small_medium_iter.yaml | 101 ------------------
 egs_s2/LibriLight/local/train.sh              |   2 +-
 .../{run_small_medium.sh => run.sh}           |   5 +-
 soundstorm/s2/exps/train_librilight_6k.py     |   7 +-
 5 files changed, 32 insertions(+), 120 deletions(-)
 delete mode 100644 egs_s2/LibriLight/conf/small_medium_iter.yaml
 rename egs_s2/LibriLight/{run_small_medium.sh => run.sh} (95%)

diff --git a/egs_s2/LibriLight/conf/default.yaml b/egs_s2/LibriLight/conf/default.yaml
index f2bcf31..5d28d40 100644
--- a/egs_s2/LibriLight/conf/default.yaml
+++ b/egs_s2/LibriLight/conf/default.yaml
@@ -1,3 +1,4 @@
+# config for LibriLight 6k (small + medium)
 # 30k_basex1_hubert_L7km300
 model:
   target: soundstorm.s2.models.dalle_wav.dalle_wav.DALLE
@@ -27,6 +28,10 @@ model:
             timestep_type: 'adalayernorm' # adainsnorm or adalayernorm and abs
             mlp_hidden_times: 4
             semantic_token_nums: 300
+            prompt_semantic_emb_len: 10 # should be > max_prompt_sec in dataset
+            target_semantic_emb_len: 30 # should be > max_target_sec in dataset
+            prompt_acoustic_emb_len: 10 # can be the same as prompt_semantic
+            target_acoustic_emb_len: 30 # can be the same as target_semantic
             content_emb_config:
               target: soundstorm.s2.models.dalle_wav.mask_embedding.DalleMaskImageEmbedding
               params:
@@ -37,11 +42,11 @@ model:
                 pos_emb_type: embedding
 
 solver:
-  base_lr: 0.3e-05  # 3.0e-6 x 8 because max_token_one_batch is x 8 (the old value was 10k)
-  adjust_lr: none  # do not adjust lr according to total batch_size
-  max_epochs: 500 # 400 for LibriTTS (train-clean-100 + train-clean-360), 9.2k epochs for LJSpeech
-  save_epochs: 1
-  dev_epochs: 1
+  base_lr: 0.3e-05 # 3.0e-6 x 8 because max_token_one_batch is x 8 (the old value was 10k)
+  adjust_lr: none # do not adjust lr according to total batch_size
+  max_iters: 550000 # 550k iters on 8 GPUs for small + medium, ~70 epochs, ~7.8k iters/epoch, ~2.5 h of training per epoch
+  save_iters: 1500 # 1.5k, saving a ckpt costs ~0.3 h
+  dev_iters: 1500 # num of iters for each GPU; for the 8 GPUs here, should be x2 when using 4 GPUs so the model sees the same number of samples
   ema:
     decay: 0.99
     update_interval: 25
@@ -69,25 +74,29 @@ solver:
       threshold: 1.0e-1
       threshold_mode: rel
      warmup_lr: 0.45e-3 # the lr to be reached after warmup
-      warmup: 800 # ~ 2 epochs
+      warmup: 800 # num of iters to warm up
 
 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
   num_workers: 2
   train_datasets: # a list of configs, so we can combine several datasets
-    - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
+    - target: soundstorm.s2.data.semantic_dataset_librilight_6k.SemanticDataset
      params:
        codec_name: hificodec
        num_quant: 4 # does not work when != 4 for hificodec; can be 3 for soundstream and encodec
-        semantic_token_nums: 300 # 1000 for mhubert, 500 for en_hubert
-        semantic_path: dump/train/semantic_token.tsv
-        acoustic_path: dump/train/acoustic_token/hificodec.pth
+        semantic_token_nums: 300 # same as the num of kmeans bins
+        max_prompt_sec: 3 # same as LibriTTS
+        max_target_sec: 20 # LibriTTS uses 10; use 20 here for longer TTS
+        semantic_dirs: ['dump/small/train/']
+        acoustic_dirs: ['dump/small/train/acoustic_token/']
 
   dev_datasets:
-    - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
+    - target: soundstorm.s2.data.semantic_dataset_librilight_6k.SemanticDataset
      params:
        codec_name: hificodec
        num_quant: 4 # does not work when != 4 for hificodec; can be 3 for soundstream and encodec
-        semantic_token_nums: 300 # 1000 for mhubert, 500 for en_hubert
-        semantic_path: dump/dev/semantic_token.tsv
-        acoustic_path: dump/dev/acoustic_token/hificodec.pth
+        semantic_token_nums: 300 # same as the num of kmeans bins
+        max_prompt_sec: 3 # same as LibriTTS
+        max_target_sec: 20 # LibriTTS uses 10; use 20 here for longer TTS
+        semantic_dirs: ['dump/small/dev/']
+        acoustic_dirs: ['dump/small/dev/acoustic_token/']
diff --git a/egs_s2/LibriLight/conf/small_medium_iter.yaml b/egs_s2/LibriLight/conf/small_medium_iter.yaml
deleted file mode 100644
index 1b46aec..0000000
--- a/egs_s2/LibriLight/conf/small_medium_iter.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# 30k_basex1_hubert_L7km300
-model:
-  target: soundstorm.s2.models.dalle_wav.dalle_wav.DALLE
-  params:
-    content_info: {key: audio}
-    condition_info: {key: text}
-    n_q: 4 # the encodec codebook's number, same as `num_quant` of dataloader
-    diffusion_config:
-      target: soundstorm.s2.models.dalle_wav.diffusion_transformer.DiffusionTransformer
-      params:
-        diffusion_step: 100
-        alpha_init_type: 'alpha1' # init_type = fix or cos or linear
-        auxiliary_loss_weight: 5.0e-4
-        adaptive_auxiliary_loss: True
-        mask_weight: [1, 1] # the loss weight on mask region and non-mask region
-        transformer_config:
-          target: soundstorm.s2.models.dalle_wav.transformer_utils.Text2ImageTransformer
-          params:
-            attn_type: 'selfcross' # using self attention
-            n_layer: 16 # we may use large model
-            n_embd: 512 # the dim of embedding dims
-            condition_dim: 512
-            n_head: 8
-            attn_pdrop: 0.0
-            resid_pdrop: 0.0
-            block_activate: GELU2
-            timestep_type: 'adalayernorm' # adainsnorm or adalayernorm and abs
-            mlp_hidden_times: 4
-            semantic_token_nums: 300
-            prompt_semantic_emb_len: 10 # should be > max_prompt_sec in dataset
-            target_semantic_emb_len: 30 # should be > max_target_sec in dataset
-            prompt_acoustic_emb_len: 10 # can be the same as prompt_semantic
-            target_acoustic_emb_len: 30 # can be the same as target_semantic
-            content_emb_config:
-              target: soundstorm.s2.models.dalle_wav.mask_embedding.DalleMaskImageEmbedding
-              params:
-                num_embed: 1026 # should be quantize_number
-                max_size: 16000
-                embed_dim: 512 # the dim of position embedding
-                trainable: True
-                pos_emb_type: embedding
-
-solver:
-  base_lr: 0.3e-05 # 3.0e-6 x 8 because max_token_one_batch is x 8 (the old value was 10k)
-  adjust_lr: none # do not adjust lr according to total batch_size
-  max_iters: 550000 # 550k iters on 8 GPUs for small + medium, ~70 epochs, ~7.8k iters/epoch, ~2.5 h of training per epoch
-  save_iters: 1500 # 1.5k, saving a ckpt costs ~0.3 h
-  dev_iters: 1500 # num of iters for each GPU; for the 8 GPUs here, should be x2 when using 4 GPUs so the model sees the same number of samples
-  ema:
-    decay: 0.99
-    update_interval: 25
-    device: cpu
-  clip_grad_norm:
-    target: soundstorm.s2.engine.clip_grad_norm.ClipGradNorm
-    params:
-      start_iteration: 0
-      end_iteration: 5000
-      max_norm: 0.5
-  optimizers_and_schedulers: # a list of configs, so we can configure several optimizers and schedulers
-    - name: none # default is None
-      optimizer:
-        target: torch.optim.AdamW
-        params:
-          betas: !!python/tuple [0.9, 0.96]
-          weight_decay: 1.0e-2
-      scheduler:
-        step_iteration: 1
-        target: soundstorm.s2.engine.lr_scheduler.ReduceLROnPlateauWithWarmup
-        params:
-          factor: 0.5
-          patience: 25000
-          min_lr: 1.0e-06
-          threshold: 1.0e-1
-          threshold_mode: rel
-          warmup_lr: 0.45e-3 # the lr to be reached after warmup
-          warmup: 800 # num of iters to warm up
-
-dataloader:
-  max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
-  num_workers: 2
-  train_datasets: # a list of configs, so we can combine several datasets
-    - target: soundstorm.s2.data.semantic_dataset_librilight_6k.SemanticDataset
-      params:
-        codec_name: hificodec
-        num_quant: 4 # does not work when != 4 for hificodec; can be 3 for soundstream and encodec
-        semantic_token_nums: 300 # same as the num of kmeans bins
-        max_prompt_sec: 3 # same as LibriTTS
-        max_target_sec: 20 # LibriTTS uses 10; use 20 here for longer TTS
-        semantic_dirs: ['dump/small/train/']
-        acoustic_dirs: ['dump/small/train/acoustic_token/']
-
-  dev_datasets:
-    - target: soundstorm.s2.data.semantic_dataset_librilight_6k.SemanticDataset
-      params:
-        codec_name: hificodec
-        num_quant: 4 # does not work when != 4 for hificodec; can be 3 for soundstream and encodec
-        semantic_token_nums: 300 # same as the num of kmeans bins
-        max_prompt_sec: 3 # same as LibriTTS
-        max_target_sec: 20 # LibriTTS uses 10; use 20 here for longer TTS
-        semantic_dirs: ['dump/small/dev/']
-        acoustic_dirs: ['dump/small/dev/acoustic_token/']
diff --git a/egs_s2/LibriLight/local/train.sh b/egs_s2/LibriLight/local/train.sh
index 1c4a2ca..02b20cb 100755
--- a/egs_s2/LibriLight/local/train.sh
+++ b/egs_s2/LibriLight/local/train.sh
@@ -10,7 +10,7 @@ dump_dir=$6
 # NOTE: the *_dirs args must not be followed by '='
 python3 ${BIN_DIR}/train_librilight_6k.py \
     --config_file=${config_path} \
-    --train_semantic_dirs ''${root_dir}'/'${dump_dir}'/small/train/' ''${root_dir}'/'${dump_dir}'/large/train/' \
+    --train_semantic_dirs ''${root_dir}'/'${dump_dir}'/small/train/' ''${root_dir}'/'${dump_dir}'/medium/train/' \
     --train_acoustic_dirs ''${root_dir}'/'${dump_dir}'/small/train/acoustic_token/' ''${root_dir}'/'${dump_dir}'/medium/train/acoustic_token/' \
     --dev_semantic_dirs ''${root_dir}'/'${dump_dir}'/small/dev/' ''${root_dir}'/'${dump_dir}'/medium/dev/' \
     --dev_acoustic_dirs ''${root_dir}'/'${dump_dir}'/small/dev/acoustic_token/' ''${root_dir}'/'${dump_dir}'/medium/dev/acoustic_token/' \
diff --git a/egs_s2/LibriLight/run_small_medium.sh b/egs_s2/LibriLight/run.sh
similarity index 95%
rename from egs_s2/LibriLight/run_small_medium.sh
rename to egs_s2/LibriLight/run.sh
index 0af1cdd..f7acc43 100755
--- a/egs_s2/LibriLight/run_small_medium.sh
+++ b/egs_s2/LibriLight/run.sh
@@ -1,5 +1,6 @@
 #!/bin/bash
 # run_base_L7_km300
+# train LibriLight 6k (small + medium) by default
 set -e
 source path.sh
@@ -7,12 +8,12 @@ source path.sh
 gpus=0,1,2,3
 stage=0
 stop_stage=100
-train_output_path='exp_librilight/small_medium'
+train_output_path='exp_librilight/default'
 # dir to set part/all of dump dataset and experiment result
 root_dir='/nfs-speech-cpfs/dev/yuantian04/Vivid_TTS/SoundStorm/SoundStorm/SoundStorm'
 # there should be *.wav, */*.wav or */*/*.wav in data_dir
 data_dir='~/datasets/LibriLight'
-config_path='conf/small_medium_iter.yaml'
+config_path='conf/default.yaml'
 log_frequency=1
 # 'tcp://%s:%s' % (MASTER_ADDR, MASTER_PORT)
 dist_url='tcp://127.0.0.1:29505'
diff --git a/soundstorm/s2/exps/train_librilight_6k.py b/soundstorm/s2/exps/train_librilight_6k.py
index 6d91e90..f3006fd 100644
--- a/soundstorm/s2/exps/train_librilight_6k.py
+++ b/soundstorm/s2/exps/train_librilight_6k.py
@@ -5,6 +5,7 @@
 # ------------------------------------------
 import argparse
 import os
+import time
 import warnings
 
 import torch
@@ -234,7 +235,7 @@ def main_worker(local_rank, args):
     # get logger
     logger = Logger(args)
     logger.save_config(config)
-    
+
     # get model
     model = build_model(config)
     if args.sync_bn:
@@ -251,10 +252,12 @@ def main_worker(local_rank, args):
         hificodec.eval()
     else:
         hificodec = None
-    
+
     # get dataloader
     print("start build dataloader...")
+    start_build_time = time.time()
     dataloader_info = build_dataloader(config, args)
+    print(f"time of build dataloader: {time.time() - start_build_time}")
     # get solver
     if args.train_with_iter is True:
         from soundstorm.s2.engine.solver_iter import Solver

From 6d1d6cb183d7c0edc9248f5e563124333ace3c26 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Fri, 11 Aug 2023 03:08:31 +0000
Subject: [PATCH 4/8] add max_sec control when synthesizing

---
 egs_s2/LibriLight/local/test.sh           |  18 +++-
 egs_s2/LibriLight/run.sh                  |   2 +-
 egs_s2/LibriTTS/run.sh                    |   2 +-
 egs_s2/LibriTTS/run_base_L7_km300.sh      |   5 +-
 egs_s2/LibriTTS/run_base_L7_km300_iter.sh |   5 +-
 soundstorm/s2/exps/synthesize.py          |  39 ++++++--
 soundstorm/s2/exps/test.py                | 109 ++++++++++++++------
 7 files changed, 131 insertions(+), 49 deletions(-)
 mode change 120000 => 100755 egs_s2/LibriLight/local/test.sh

diff --git a/egs_s2/LibriLight/local/test.sh b/egs_s2/LibriLight/local/test.sh
deleted file mode 120000
index 48bd482..0000000
--- a/egs_s2/LibriLight/local/test.sh
+++ /dev/null
@@ -1 +0,0 @@
-../../LibriTTS/local/test.sh
\ No newline at end of file
diff --git a/egs_s2/LibriLight/local/test.sh b/egs_s2/LibriLight/local/test.sh
new file mode 100755
index 0000000..f06871e
--- /dev/null
+++ b/egs_s2/LibriLight/local/test.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+# test with test set
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+root_dir=$4
+dump_dir=$5
+
+python3 ${BIN_DIR}/test.py \
+    --config_file=${config_path} \
+    --ckpt_path=${root_dir}/${train_output_path}/checkpoint/${ckpt_name} \
+    --test_semantic_path=${root_dir}/${dump_dir}/small/test/semantic_token_0_3.tsv \
+    --test_acoustic_path=${root_dir}/${dump_dir}/small/test/acoustic_token/hificodec_0_3.pth \
+    --output_dir=${root_dir}/${train_output_path}/test_output \
+    --hificodec_model_path=pretrained_model/hificodec/HiFi-Codec-16k-320d-large-universal \
+    --hificodec_config_path=pretrained_model/hificodec/config_16k_320d.json
\ No newline at end of file
diff --git a/egs_s2/LibriLight/run.sh b/egs_s2/LibriLight/run.sh
index f7acc43..d5aebc6 100755
--- a/egs_s2/LibriLight/run.sh
+++ b/egs_s2/LibriLight/run.sh
@@ -18,7 +18,7 @@ log_frequency=1
 # 'tcp://%s:%s' % (MASTER_ADDR, MASTER_PORT)
 dist_url='tcp://127.0.0.1:29505'
 # use which checkpoint file to test
-ckpt_name='000301e_471119iter.pth'
+ckpt_name='33000iter.pth'
 # should be same with ${layer} in hubert_kms.sh
 layer=7
 # should be same with ${hubert_path} in hubert_kms.sh
diff --git a/egs_s2/LibriTTS/run.sh b/egs_s2/LibriTTS/run.sh
index b76a580..a17c244 100755
--- a/egs_s2/LibriTTS/run.sh
+++ b/egs_s2/LibriTTS/run.sh
@@ -61,5 +61,5 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh \
         ${config_path} ${train_output_path} ${ckpt_name} ${root_dir} \
         ${hubert_path} ${quantizer_path} ${prompt_wav_path} \
-        ${S1_config_file} ${S1_ckpt_path} ${sil_token}|| exit -1
+        ${S1_config_file} ${S1_ckpt_path} ${sil_token} || exit -1
 fi
diff --git a/egs_s2/LibriTTS/run_base_L7_km300.sh b/egs_s2/LibriTTS/run_base_L7_km300.sh
index 3c604a1..f78df03 100755
--- a/egs_s2/LibriTTS/run_base_L7_km300.sh
+++ b/egs_s2/LibriTTS/run_base_L7_km300.sh
@@ -24,8 +24,7 @@ hubert_path=pretrained_model/hubert/hubert_base_ls960.pt
 quantizer_path=pretrained_model/hubert/train-clean-360_hubert_base_ls960_L7_km300.bin
 dump_dir=dump_libritts_universal_hificodec
 # for synthesize_e2e.sh
-prompt_wav_path='/nfs-speech-cpfs/dev/yuantian04/Vivid_TTS/SoundStorm/SoundStorm/SoundStorm/dump_l
-ibritts_base_L9_km500/test/synthesize_input/1006_135212_000060_000004.wav'
+prompt_wav_path='/nfs-speech-cpfs/dev/yuantian04/Vivid_TTS/SoundStorm/SoundStorm/SoundStorm/dump_libritts_base_L9_km500/test/synthesize_input/1006_135212_000060_000004.wav'
 S1_config_file='../../egs_s1/AR/LibriTTS/conf/base_L7bin300.yaml'
 S1_ckpt_path='/nfs-speech-cpfs/dev/yuantian04/Vivid_TTS/SoundStorm/SoundStorm/ar_s1/SoundStorm/exp/base_L7_km300/ckpt/epoch=99-step=49000.ckpt'
 # 4 for 300 bin, you should modify this due to your own dump data
@@ -61,5 +60,5 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh \
         ${config_path} ${train_output_path} ${ckpt_name} ${root_dir} \
         ${hubert_path} ${quantizer_path} ${prompt_wav_path} \
-        ${S1_config_file} ${S1_ckpt_path} ${sil_token}|| exit -1
+        ${S1_config_file} ${S1_ckpt_path} ${sil_token} || exit -1
 fi
diff --git a/egs_s2/LibriTTS/run_base_L7_km300_iter.sh b/egs_s2/LibriTTS/run_base_L7_km300_iter.sh
index ddfbbac..7199dd0 100755
--- a/egs_s2/LibriTTS/run_base_L7_km300_iter.sh
+++ b/egs_s2/LibriTTS/run_base_L7_km300_iter.sh
@@ -24,8 +24,7 @@ hubert_path=pretrained_model/hubert/hubert_base_ls960.pt
 quantizer_path=pretrained_model/hubert/train-clean-360_hubert_base_ls960_L7_km300.bin
 dump_dir=dump_libritts_universal_hificodec
 # for synthesize_e2e.sh
-prompt_wav_path='/nfs-speech-cpfs/dev/yuantian04/Vivid_TTS/SoundStorm/SoundStorm/SoundStorm/dump_l
-ibritts_base_L9_km500/test/synthesize_input/1006_135212_000060_000004.wav'
+prompt_wav_path='/nfs-speech-cpfs/dev/yuantian04/Vivid_TTS/SoundStorm/SoundStorm/SoundStorm/dump_libritts_base_L9_km500/test/synthesize_input/1006_135212_000060_000004.wav'
 S1_config_file='../../egs_s1/AR/LibriTTS/conf/base_L7bin300.yaml'
 S1_ckpt_path='/nfs-speech-cpfs/dev/yuantian04/Vivid_TTS/SoundStorm/SoundStorm/ar_s1/SoundStorm/exp/base_L7_km300/ckpt/epoch=99-step=49000.ckpt'
 # 4 for 300 bin, you should modify this due to your own dump data
@@ -61,5 +60,5 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh \
         ${config_path} ${train_output_path} ${ckpt_name} ${root_dir} \
         ${hubert_path} ${quantizer_path} ${prompt_wav_path} \
-        ${S1_config_file} ${S1_ckpt_path} ${sil_token}|| exit -1
+        ${S1_config_file} ${S1_ckpt_path} ${sil_token} || exit -1
 fi
diff --git a/soundstorm/s2/exps/synthesize.py b/soundstorm/s2/exps/synthesize.py
index 26be398..081edd3 100644
--- a/soundstorm/s2/exps/synthesize.py
+++ b/soundstorm/s2/exps/synthesize.py
@@ -51,22 +51,23 @@ def hificodec_decode(hificodec, acoustic_token, rescale=True):
 def get_batch(prompt_semantic_tokens,
               prompt_acoustic_tokens,
               target_semantic_tokens,
-              num_quant=4,
-              hz=50):
+              num_quant: int=4,
+              hz: int=50,
+              max_prompt_sec: int=3,
+              max_target_sec: int=10):
     # the max in transformer_utils.py is 20, and one stop token is padded, so the max here is 19
     # but training uses at most 10 s, so anything longer than 10 s cannot be synthesized
-    max_sec = 10
 
     # the prompt is at most 3 s
-    if prompt_acoustic_tokens.shape[1] > 6 * hz:
-        prompt_len = 3 * hz
+    if prompt_acoustic_tokens.shape[1] > 2 * max_prompt_sec * hz:
+        prompt_len = max_prompt_sec * hz
     else:
         prompt_len = prompt_acoustic_tokens.shape[1] // 2
     prompt_semantic_tokens = prompt_semantic_tokens[:, :prompt_len]
     prompt_acoustic_tokens = prompt_acoustic_tokens[:, :prompt_len]
     # the target is at most 10 s
-    target_semantic_tokens = target_semantic_tokens[:, :max_sec * hz]
+    target_semantic_tokens = target_semantic_tokens[:, :max_target_sec * hz]
     # acoustic_token and semantic_token lengths are aligned
     target_T = target_semantic_tokens.shape[-1]
     # fake target_acoustics_tokens
@@ -86,8 +87,14 @@ def get_batch(prompt_semantic_tokens,
     return samples
 
 
-def evaluate(args, hificodec, soundstorm, semantic_tokenizer=None):
-    num_quant = 4
+def evaluate(args,
+             hificodec,
+             soundstorm,
+             semantic_tokenizer=None,
+             num_quant: int=4,
+             max_prompt_sec: int=3,
+             max_target_sec: int=10):
+
     sample_rate = 16000
     hz = 50
     output_dir = Path(args.output_dir)
@@ -141,7 +148,9 @@ def evaluate(args,
             prompt_acoustic_tokens=prompt_acoustic_tokens,
             target_semantic_tokens=target_semantic_tokens,
             num_quant=num_quant,
-            hz=hz)
+            hz=hz,
+            max_prompt_sec=max_prompt_sec,
+            max_target_sec=max_target_sec)
 
         batch = move_tensors_to_cuda(batch)
 
@@ -250,7 +259,17 @@ def main():
         duplicate=True)
 
     # costs 14 s for a 10 s target
-    evaluate(args, hificodec, soundstorm, semantic_tokenizer)
+    num_quant = 4
+    max_prompt_sec = 3
+    max_target_sec = 10
+    evaluate(
+        args,
+        hificodec,
+        soundstorm,
+        semantic_tokenizer,
+        num_quant=num_quant,
+        max_prompt_sec=max_prompt_sec,
+        max_target_sec=max_target_sec)
 
 
 if __name__ == "__main__":
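Note: at synthesis time, get_batch above caps the prompt at max_prompt_sec (or half of what is available) and clips the target to max_target_sec, creating a dummy acoustic tensor for the model to in-fill. A sketch of just that sizing logic, under the same 50 Hz assumption (tensor contents are placeholders):

    import torch

    hz = 50                                  # tokens per second
    max_prompt_sec, max_target_sec = 3, 10
    prompt_acoustic = torch.zeros(4, 700, dtype=torch.long)   # pretend 14 s of prompt codes
    target_semantic = torch.zeros(1, 900, dtype=torch.long)   # pretend 18 s of target tokens

    # prompt: at most 3 s, otherwise half of what is available
    if prompt_acoustic.shape[1] > 2 * max_prompt_sec * hz:
        prompt_len = max_prompt_sec * hz
    else:
        prompt_len = prompt_acoustic.shape[1] // 2
    # target: clipped to 10 s; the acoustic part is a dummy the model will fill in
    target_semantic = target_semantic[:, :max_target_sec * hz]
    fake_target_acoustic = torch.zeros(4, target_semantic.shape[-1], dtype=torch.long)
    print(prompt_len, fake_target_acoustic.shape)  # 150 torch.Size([4, 500])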
diff --git a/soundstorm/s2/exps/test.py b/soundstorm/s2/exps/test.py
index 734027f..9839046 100644
--- a/soundstorm/s2/exps/test.py
+++ b/soundstorm/s2/exps/test.py
@@ -29,7 +29,7 @@ def move_tensors_to_cuda(d):
     return d
 
 
-def hificodec_decode(hificodec, acoustic_token, rescale=True):
+def hificodec_decode(hificodec, acoustic_token, rescale: bool=True):
     """
     acoustic_token: shape [B, Nq, T]
     """
@@ -51,7 +51,13 @@ def hificodec_decode(hificodec, acoustic_token, rescale=True):
     return wav
 
 
-def get_one_sample(acoustic_data, semantic_data, index, num_quant, hz):
+def get_one_sample(acoustic_data,
+                   semantic_data,
+                   index: int,
+                   num_quant: int=4,
+                   hz: int=50,
+                   max_prompt_sec: int=3,
+                   max_target_sec: int=10):
     '''
     one sample forms one batch
    (1) if the total length is over 6 s, the first 3 s are the prompt and the rest is the target
    (2) if the total length is under 6 s, half goes to the prompt and half to the target
     '''
@@ -70,20 +76,21 @@ def get_one_sample(acoustic_data,
     # shape (4, T)
     # the T of acoustic_tokens may differ slightly from the T of semantic_tokens
     acoustic_tokens = acoustic_str[:num_quant, ...]
-    if acoustic_tokens.shape[1] > 6 * hz:
-        prompt_len = 3 * hz
+    if acoustic_tokens.shape[1] > 2 * max_prompt_sec * hz:
+        prompt_len = max_prompt_sec * hz
     else:
         prompt_len = acoustic_tokens.shape[1] // 2
     prompt_acoustic_tokens = acoustic_tokens[:, :prompt_len]
     prompt_semantic_tokens = semantic_tokens[:, :prompt_len]
-    target_semantic_tokens = semantic_tokens[:, prompt_len:prompt_len + 10 * hz]
+    target_semantic_tokens = semantic_tokens[:, prompt_len:prompt_len +
+                                             max_target_sec * hz]
     prompt_semantic_tokens = prompt_semantic_tokens
     target_semantic_tokens = target_semantic_tokens
     prompt_acoustic_tokens = prompt_acoustic_tokens
-    target_acoustics_tokens = acoustic_tokens[:, prompt_len:prompt_len + 10 *
-                                              hz]
+    target_acoustics_tokens = acoustic_tokens[:, prompt_len:prompt_len +
+                                              max_target_sec * hz]
     target_acoustics_tokens = target_acoustics_tokens
 
     result = {}
@@ -96,13 +103,21 @@ def get_one_sample(acoustic_data,
 
 
 # one wav per batch
-def get_batch(acoustic_data, semantic_data, index, num_quant=4, hz=50):
+def get_batch(acoustic_data,
+              semantic_data,
+              index: int,
+              num_quant: int=4,
+              hz: int=50,
+              max_prompt_sec: int=3,
+              max_target_sec: int=10):
     result = get_one_sample(
         acoustic_data=acoustic_data,
         semantic_data=semantic_data,
         index=index,
         num_quant=num_quant,
-        hz=hz)
+        hz=hz,
+        max_prompt_sec=max_prompt_sec,
+        max_target_sec=max_target_sec)
     prompt_acoustic_tokens = result['prompt_acoustic_tokens']
     prompt_semantic_tokens = result['prompt_semantic_tokens']
     target_semantic_tokens = result['target_semantic_tokens']
@@ -119,14 +134,15 @@ def get_batch(acoustic_data,
     return samples
 
 
-def get_big_batch(
-        acoustic_data,
-        semantic_data,
-        index_list,
-        prompt_semantic_end_id,
-        target_semantic_end_id,
-        num_quant=4,
-        hz=50, ):
+def get_big_batch(acoustic_data,
+                  semantic_data,
+                  index_list,
+                  prompt_semantic_end_id: int,
+                  target_semantic_end_id: int,
+                  num_quant: int=4,
+                  hz: int=50,
+                  max_prompt_sec: int=3,
+                  max_target_sec: int=10):
     tmp_prompt_semantics = []
     tmp_target_semantics = []
     tmp_prompt_acoustics = []
@@ -137,7 +153,9 @@ def get_big_batch(acoustic_data,
             semantic_data=semantic_data,
             index=index,
             num_quant=num_quant,
-            hz=hz)
+            hz=hz,
+            max_prompt_sec=max_prompt_sec,
+            max_target_sec=max_target_sec)
 
         prompt_semantic = result['prompt_semantic_tokens']
         target_semantic = result['target_semantic_tokens']
@@ -172,8 +190,13 @@ def get_big_batch(acoustic_data,
 
 
 # evaluate one wav per batch
-def evaluate(args, hificodec, soundstorm):
-    num_quant = 4
+def evaluate(args,
+             hificodec,
+             soundstorm,
+             num_quant: int=4,
+             max_prompt_sec: int=3,
+             max_target_sec: int=10):
+
     sample_rate = 16000
     hz = 50
     output_dir = Path(args.output_dir)
@@ -185,7 +208,13 @@ def evaluate(args,
     for index, utt_id in enumerate(semantic_data['item_name'][:20]):
         # handle the case where item_name is not in acoustic_data
         batch = get_batch(
-            acoustic_data, semantic_data, index, num_quant=num_quant, hz=hz)
+            acoustic_data,
+            semantic_data,
+            index,
+            num_quant=num_quant,
+            hz=hz,
+            max_prompt_sec=max_prompt_sec,
+            max_target_sec=max_target_sec)
         batch = move_tensors_to_cuda(batch)
         # something went wrong with this index of the data
         if batch is None:
@@ -211,13 +240,16 @@ def evaluate_batch(args,
                    hificodec,
                    soundstorm,
-                   prompt_semantic_end_id,
-                   target_semantic_end_id,
-                   batch_size=2):
+                   prompt_semantic_end_id: int,
+                   target_semantic_end_id: int,
+                   batch_size: int=2,
+                   num_quant: int=4,
+                   max_prompt_sec: int=3,
+                   max_target_sec: int=10):
     # read the test set in order; several audio clips form one batch
-    num_quant = 4
     sample_rate = 16000
     hz = 50
+
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
     acoustic_data = torch.load(args.test_acoustic_path)
@@ -236,7 +268,7 @@ def evaluate_batch(args,
         for i in range(0, len(all_indexs), batch_size)
     ]
 
-    for i, index_list in enumerate(index_lists):
+    for i, index_list in enumerate(index_lists[:10]):
         utt_ids = utt_id_lists[i]
         batch = get_big_batch(
             acoustic_data,
@@ -244,8 +276,10 @@ def evaluate_batch(args,
             index_list,
             prompt_semantic_end_id=prompt_semantic_end_id,
             target_semantic_end_id=target_semantic_end_id,
-            num_quant=4,
-            hz=hz, )
+            num_quant=num_quant,
+            hz=hz,
+            max_prompt_sec=max_prompt_sec,
+            max_target_sec=max_target_sec)
         batch = move_tensors_to_cuda(batch)
         # something went wrong with this index of the data
         if batch is None:
@@ -345,18 +379,33 @@ def main():
     semantic_token_nums = config['dataloader']['train_datasets'][0]['params'][
         'semantic_token_nums']
+    num_quant = config['dataloader']['train_datasets'][0]['params']['num_quant']
+    max_prompt_sec = config['dataloader']['train_datasets'][0]['params'].get(
+        'max_prompt_sec', 3)
+    max_target_sec = config['dataloader']['train_datasets'][0]['params'].get(
+        'max_target_sec', 10)
+
     prompt_semantic_end_id = semantic_token_nums + 1
     target_semantic_end_id = semantic_token_nums + 3
     # costs 14 s for a 10 s target
-    evaluate(args, hificodec, soundstorm)
+    evaluate(
+        args,
+        hificodec,
+        soundstorm,
+        num_quant=num_quant,
+        max_prompt_sec=max_prompt_sec,
+        max_target_sec=max_target_sec)
     # evaluate_batch(
     #     args,
     #     hificodec=hificodec,
     #     soundstorm=soundstorm,
     #     prompt_semantic_end_id=prompt_semantic_end_id,
     #     target_semantic_end_id=target_semantic_end_id,
-    #     batch_size=2)
+    #     batch_size=2,
+    #     num_quant=num_quant,
+    #     max_prompt_sec=max_prompt_sec,
+    #     max_target_sec=max_target_sec)
 
 
 if __name__ == "__main__":
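Note: in the batched test path above, the eos ids double as padding values, and the mask falls straight out of the padded tensor. A compact sketch of that trick; the value 1025 matches target_acoustic_eos = acoustic_token_nums + 1 from the dataset:

    import torch
    import torch.nn.functional as F

    target_acoustic_eos = 1025
    batch = [torch.zeros(4, 7, dtype=torch.long), torch.zeros(4, 5, dtype=torch.long)]
    max_len = max(t.shape[-1] for t in batch)
    padded = torch.stack(
        [F.pad(t, (0, max_len - t.shape[-1]), value=target_acoustic_eos) for t in batch])
    x_mask = padded == target_acoustic_eos  # True on padded (eos) positions
    print(padded.shape, x_mask[:, 0, :].sum(dim=-1))  # torch.Size([2, 4, 7]) tensor([0, 2])

Because eos and padding share an id, the model never needs a separate pad token; anything past the real sequence is uniformly "end".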
@@ -78,7 +78,8 @@ solver:

 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
-  num_workers: 2
+  num_workers: 32
+  prefetch_factor: 50
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset_librilight_6k.SemanticDataset
       params:
diff --git a/soundstorm/s2/data/build.py b/soundstorm/s2/data/build.py
index 898566c..1784658 100644
--- a/soundstorm/s2/data/build.py
+++ b/soundstorm/s2/data/build.py
@@ -48,6 +48,7 @@ def build_dataloader(config, args=None, return_dataset=False):
     train_iters = len(train_dataset) // batch_size
     dev_iters = len(dev_dataset) // batch_size
     num_workers = dataset_cfg['num_workers']
+    prefetch_factor = dataset_cfg.get('prefetch_factor', 2)
     persistent_workers = True if num_workers > 0 else False
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
@@ -60,7 +61,8 @@ def build_dataloader(config, args=None, return_dataset=False):
         collate_fn=train_dataset.collater,
         persistent_workers=persistent_workers,
         # fix the "bad value(s) in fds_to_keep" error when num_workers > 0
-        multiprocessing_context='fork')
+        multiprocessing_context='fork',
+        prefetch_factor=prefetch_factor)

     dev_loader = torch.utils.data.DataLoader(
         dev_dataset,
@@ -73,7 +75,8 @@ def build_dataloader(config, args=None, return_dataset=False):
         pin_memory=True,
         collate_fn=train_dataset.collater,
         persistent_workers=persistent_workers,
-        multiprocessing_context='fork')
+        multiprocessing_context='fork',
+        prefetch_factor=prefetch_factor)

     dataload_info = {
         'train_loader': train_loader,
diff --git a/soundstorm/s2/data/build_librilight_6k.py b/soundstorm/s2/data/build_librilight_6k.py
index e8dfa22..5cd8ad1 100644
--- a/soundstorm/s2/data/build_librilight_6k.py
+++ b/soundstorm/s2/data/build_librilight_6k.py
@@ -48,6 +48,7 @@ def build_dataloader(config, args=None, return_dataset=False):
     train_iters = len(train_dataset) // batch_size
     dev_iters = len(dev_dataset) // batch_size
     num_workers = dataset_cfg['num_workers']
+    prefetch_factor = dataset_cfg.get('prefetch_factor', 2)
     persistent_workers = True if num_workers > 0 else False
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
@@ -60,7 +61,8 @@ def build_dataloader(config, args=None, return_dataset=False):
         collate_fn=train_dataset.collater,
         persistent_workers=persistent_workers,
         # fix the "bad value(s) in fds_to_keep" error when num_workers > 0
-        multiprocessing_context='fork')
+        multiprocessing_context='fork',
+        prefetch_factor=prefetch_factor)

     dev_loader = torch.utils.data.DataLoader(
         dev_dataset,
@@ -73,7 +75,8 @@ def build_dataloader(config, args=None, return_dataset=False):
         pin_memory=True,
         collate_fn=train_dataset.collater,
         persistent_workers=persistent_workers,
-        multiprocessing_context='fork')
+        multiprocessing_context='fork',
+        prefetch_factor=prefetch_factor)

     dataload_info = {
         'train_loader': train_loader,
diff --git a/soundstorm/s2/distributed/launch.py b/soundstorm/s2/distributed/launch.py
index c0fdf86..dc3792d 100644
--- a/soundstorm/s2/distributed/launch.py
+++ b/soundstorm/s2/distributed/launch.py
@@ -2,6 +2,7 @@
 # Diffsound
 # code based on https://github.com/cientgu/VQ-Diffusion
 # ------------------------------------------
+import os
 import soundstorm.s2.distributed.distributed as dist_fn
 import torch
 from torch import distributed as dist
@@ -29,8 +30,8 @@ def launch(fn,
     world_size = n_machine * n_gpu_per_machine

     if world_size > 1:
-        # if "OMP_NUM_THREADS" not in os.environ:
-        #     os.environ["OMP_NUM_THREADS"] = "1"
+        if "OMP_NUM_THREADS" not in os.environ:
"OMP_NUM_THREADS" not in os.environ: + os.environ["OMP_NUM_THREADS"] = "1" if dist_url == "auto": if n_machine != 1: From ffc2198a61ccf80ac395365bb338c73b5c086b40 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Fri, 11 Aug 2023 05:40:09 +0000 Subject: [PATCH 6/8] fix config --- egs_s2/LibriLight/conf/default.yaml | 2 +- egs_s2/LibriLight/local/train.sh | 4 +++- egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml | 2 +- egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml | 2 +- egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml | 2 +- egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml | 2 +- egs_s2/LibriTTS/local/train.sh | 4 +++- egs_s2/LibriTTS/local/train_iter.sh | 4 +++- 8 files changed, 14 insertions(+), 8 deletions(-) diff --git a/egs_s2/LibriLight/conf/default.yaml b/egs_s2/LibriLight/conf/default.yaml index c952378..4763d9e 100644 --- a/egs_s2/LibriLight/conf/default.yaml +++ b/egs_s2/LibriLight/conf/default.yaml @@ -78,7 +78,7 @@ solver: dataloader: max_token_one_batch: 30000 # 影响单卡显存占用, 81k for 80G GPU (A100) (LibriTTS) - num_workers: 32 + num_workers: 4 prefetch_factor: 50 train_datasets: # a list of configures, so we can combine several schedulers - target: soundstorm.s2.data.semantic_dataset_librilight_6k.SemanticDataset diff --git a/egs_s2/LibriLight/local/train.sh b/egs_s2/LibriLight/local/train.sh index 02b20cb..a3120e7 100755 --- a/egs_s2/LibriLight/local/train.sh +++ b/egs_s2/LibriLight/local/train.sh @@ -7,8 +7,10 @@ log_frequency=$4 dist_url=$5 dump_dir=$6 +opm_num=8 + # 注意 *_dirs 参数后面不可以有 ''=' -python3 ${BIN_DIR}/train_librilight_6k.py \ +OMP_NUM_THREADS=${opm_num} python3 ${BIN_DIR}/train_librilight_6k.py \ --config_file=${config_path} \ --train_semantic_dirs ''${root_dir}'/'${dump_dir}'/small/train/' ''${root_dir}'/'${dump_dir}'/medium/train/' \ --train_acoustic_dirs ''${root_dir}'/'${dump_dir}'/small/train/acoustic_token/' ''${root_dir}'/'${dump_dir}'/medium/train/acoustic_token/' \ diff --git a/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml b/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml index 5f9b53d..f31e543 100644 --- a/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml +++ b/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml @@ -40,7 +40,7 @@ model: solver: base_lr: 0.3e-05 # 3.0e-6 x 8 cause max_token_one_batch is x 8 (the old is 10k) adjust_lr: none # not adjust lr according to total batch_size - max_epochs: 500 # 400 for LibriTTS (train-clean-100 + train-clean-360) 9.2k epoch for LJspeech + max_epochs: 400 # 400 for LibriTTS (train-clean-100 + train-clean-360) 9.2k epoch for LJspeech save_epochs: 1 dev_epochs: 1 ema: diff --git a/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml b/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml index ec3465f..377930d 100644 --- a/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml +++ b/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml @@ -40,7 +40,7 @@ model: solver: base_lr: 0.3e-05 # 3.0e-6 x 8 cause max_token_one_batch is x 8 (the old is 10k) adjust_lr: none # not adjust lr according to total batch_size - max_epochs: 500 # 400 for LibriTTS (train-clean-100 + train-clean-360) 9.2k epoch for LJspeech + max_epochs: 400 # 400 for LibriTTS (train-clean-100 + train-clean-360) 9.2k epoch for LJspeech save_epochs: 1 dev_epochs: 1 ema: diff --git a/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml b/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml index 8c45518..c404ade 100644 --- a/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml +++ b/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml @@ -40,7 +40,7 @@ model: solver: base_lr: 0.6e-05 # 3.0e-6 x 8 cause max_token_one_batch is x 8 (the old is 10k) adjust_lr: none # not adjust lr 
-  max_epochs: 500 # 400 for LibriTTS (train-clean-100 + train-clean-360), 9.2k epochs for LJSpeech
+  max_epochs: 400 # 400 for LibriTTS (train-clean-100 + train-clean-360), 9.2k epochs for LJSpeech
   save_epochs: 1
   dev_epochs: 1
   ema:
diff --git a/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml b/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml
index 1cda37b..842010f 100644
--- a/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml
+++ b/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml
@@ -40,7 +40,7 @@ model:
 solver:
   base_lr: 0.6e-05 # 3.0e-6 x 8 because max_token_one_batch is x 8 (the old is 10k)
   adjust_lr: none # do not adjust lr according to total batch_size
-  max_epochs: 500 # 400 for LibriTTS (train-clean-100 + train-clean-360), 9.2k epochs for LJSpeech
+  max_epochs: 400 # 400 for LibriTTS (train-clean-100 + train-clean-360), 9.2k epochs for LJSpeech
   save_epochs: 1
   dev_epochs: 1
   ema:
diff --git a/egs_s2/LibriTTS/local/train.sh b/egs_s2/LibriTTS/local/train.sh
index eae6405..cd41001 100755
--- a/egs_s2/LibriTTS/local/train.sh
+++ b/egs_s2/LibriTTS/local/train.sh
@@ -7,7 +7,9 @@ log_frequency=$4
 dist_url=$5
 dump_dir=$6

-python3 ${BIN_DIR}/train.py \
+opm_num=8
+
+OMP_NUM_THREADS=${opm_num} python3 ${BIN_DIR}/train.py \
     --config_file=${config_path} \
     --train_semantic_path=${root_dir}/${dump_dir}/train/semantic_token.tsv \
     --train_acoustic_path=${root_dir}/${dump_dir}/train/acoustic_token/hificodec.pth \
diff --git a/egs_s2/LibriTTS/local/train_iter.sh b/egs_s2/LibriTTS/local/train_iter.sh
index 9a9a8c5..6a90bf6 100755
--- a/egs_s2/LibriTTS/local/train_iter.sh
+++ b/egs_s2/LibriTTS/local/train_iter.sh
@@ -7,7 +7,9 @@ log_frequency=$4
 dist_url=$5
 dump_dir=$6

-python3 ${BIN_DIR}/train.py \
+opm_num=8
+
+OMP_NUM_THREADS=${opm_num} python3 ${BIN_DIR}/train.py \
     --config_file=${config_path} \
     --train_semantic_path=${root_dir}/${dump_dir}/train/semantic_token.tsv \
     --train_acoustic_path=${root_dir}/${dump_dir}/train/acoustic_token/hificodec.pth \
From ee758b37900f0684426e66eacb3ac9920e51397b Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Fri, 11 Aug 2023 06:02:04 +0000
Subject: [PATCH 7/8] fix num_worker of LibriTTS

---
 egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml   | 2 +-
 egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml   | 2 +-
 egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml | 2 +-
 egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml   | 2 +-
 soundstorm/s2/exps/train.py                  | 4 ++++
 5 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml b/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml
index f31e543..236bab7 100644
--- a/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml
+++ b/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml
@@ -74,7 +74,7 @@ solver:

 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
-  num_workers: 2
+  num_workers: 4
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
       params:
diff --git a/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml b/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml
index 377930d..81682d6 100644
--- a/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml
+++ b/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml
@@ -74,7 +74,7 @@ solver:

 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
-  num_workers: 2
+  num_workers: 4
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
       params:
diff --git a/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml b/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml
index c404ade..5d7f8c6 100644
--- a/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml
+++ b/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml
@@ -74,7 +74,7 @@ solver:

 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
-  num_workers: 2
+  num_workers: 4
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
       params:
diff --git a/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml b/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml
index 842010f..16efc03 100644
--- a/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml
+++ b/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml
@@ -74,7 +74,7 @@ solver:

 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
-  num_workers: 2
+  num_workers: 4
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
       params:
diff --git a/soundstorm/s2/exps/train.py b/soundstorm/s2/exps/train.py
index c37c21d..2181b02 100644
--- a/soundstorm/s2/exps/train.py
+++ b/soundstorm/s2/exps/train.py
@@ -5,6 +5,7 @@
 # ------------------------------------------
 import argparse
 import os
+import time
 import warnings

 import torch
@@ -221,7 +222,10 @@ def main_worker(local_rank, args):
         hificodec = None

     # get dataloader
+    print("start building dataloader...")
+    start_build_time = time.time()
     dataloader_info = build_dataloader(config, args)
+    print(f"time to build dataloader: {time.time() - start_build_time}")
     # get solver
     if args.train_with_iter is True:
         from soundstorm.s2.engine.solver_iter import Solver
From 5db1ffa7262701fcd6a2b8977e3f24c836fbe074 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Fri, 11 Aug 2023 06:12:47 +0000
Subject: [PATCH 8/8] add prefetch_factor control for LibriTTS

---
 egs_s2/LJSpeech/conf/default.yaml            | 3 ++-
 egs_s2/LibriLight/local/train.sh             | 4 ++--
 egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml   | 1 +
 egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml   | 1 +
 egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml | 1 +
 egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml   | 1 +
 egs_s2/LibriTTS/local/train.sh               | 4 ++--
 egs_s2/LibriTTS/local/train_iter.sh          | 4 ++--
 8 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/egs_s2/LJSpeech/conf/default.yaml b/egs_s2/LJSpeech/conf/default.yaml
index a740cfb..5355146 100644
--- a/egs_s2/LJSpeech/conf/default.yaml
+++ b/egs_s2/LJSpeech/conf/default.yaml
@@ -73,7 +73,8 @@ solver:

 dataloader:
   max_token_one_batch: 96000 # affects per-GPU memory usage; 96k for an 80G GPU (A100) (LJSpeech)
-  num_workers: 2
+  num_workers: 4
+  prefetch_factor: 50
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
       params:
diff --git a/egs_s2/LibriLight/local/train.sh b/egs_s2/LibriLight/local/train.sh
index a3120e7..aa9920e 100755
--- a/egs_s2/LibriLight/local/train.sh
+++ b/egs_s2/LibriLight/local/train.sh
@@ -7,10 +7,10 @@ log_frequency=$4
 dist_url=$5
 dump_dir=$6

-opm_num=8
+omp_num=8

 # note: do not put '=' after the *_dirs arguments
-OMP_NUM_THREADS=${opm_num} python3 ${BIN_DIR}/train_librilight_6k.py \
+OMP_NUM_THREADS=${omp_num} python3 ${BIN_DIR}/train_librilight_6k.py \
     --config_file=${config_path} \
     --train_semantic_dirs ''${root_dir}'/'${dump_dir}'/small/train/' ''${root_dir}'/'${dump_dir}'/medium/train/' \
     --train_acoustic_dirs ''${root_dir}'/'${dump_dir}'/small/train/acoustic_token/' ''${root_dir}'/'${dump_dir}'/medium/train/acoustic_token/' \
diff --git a/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml b/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml
index 236bab7..b20da26 100644
--- a/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml
+++ b/egs_s2/LibriTTS/conf/30k_lrx1_L7km300.yaml
@@ -75,6 +75,7 @@ solver:
 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
   num_workers: 4
+  prefetch_factor: 50
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
       params:
diff --git a/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml b/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml
index 81682d6..e2c16c4 100644
--- a/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml
+++ b/egs_s2/LibriTTS/conf/30k_lrx1_L9km500.yaml
@@ -75,6 +75,7 @@ solver:
 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
   num_workers: 4
+  prefetch_factor: 50
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
       params:
diff --git a/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml b/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml
index 5d7f8c6..1e90073 100644
--- a/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml
+++ b/egs_s2/LibriTTS/conf/30k_lrx2_L10km1024.yaml
@@ -75,6 +75,7 @@ solver:
 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
   num_workers: 4
+  prefetch_factor: 50
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
       params:
diff --git a/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml b/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml
index 16efc03..335384d 100644
--- a/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml
+++ b/egs_s2/LibriTTS/conf/30k_lrx2_L9km500.yaml
@@ -75,6 +75,7 @@ solver:
 dataloader:
   max_token_one_batch: 30000 # affects per-GPU memory usage; 81k for an 80G GPU (A100) (LibriTTS)
   num_workers: 4
+  prefetch_factor: 50
   train_datasets: # a list of configs, so we can combine several schedulers
     - target: soundstorm.s2.data.semantic_dataset.SemanticDataset
       params:
diff --git a/egs_s2/LibriTTS/local/train.sh b/egs_s2/LibriTTS/local/train.sh
index cd41001..3c9307e 100755
--- a/egs_s2/LibriTTS/local/train.sh
+++ b/egs_s2/LibriTTS/local/train.sh
@@ -7,9 +7,9 @@ log_frequency=$4
 dist_url=$5
 dump_dir=$6

-opm_num=8
+omp_num=8

-OMP_NUM_THREADS=${opm_num} python3 ${BIN_DIR}/train.py \
+OMP_NUM_THREADS=${omp_num} python3 ${BIN_DIR}/train.py \
     --config_file=${config_path} \
     --train_semantic_path=${root_dir}/${dump_dir}/train/semantic_token.tsv \
     --train_acoustic_path=${root_dir}/${dump_dir}/train/acoustic_token/hificodec.pth \
diff --git a/egs_s2/LibriTTS/local/train_iter.sh b/egs_s2/LibriTTS/local/train_iter.sh
index 6a90bf6..fd7bc71 100755
--- a/egs_s2/LibriTTS/local/train_iter.sh
+++ b/egs_s2/LibriTTS/local/train_iter.sh
@@ -7,9 +7,9 @@ log_frequency=$4
 dist_url=$5
 dump_dir=$6

-opm_num=8
+omp_num=8

-OMP_NUM_THREADS=${opm_num} python3 ${BIN_DIR}/train.py \
+OMP_NUM_THREADS=${omp_num} python3 ${BIN_DIR}/train.py \
     --config_file=${config_path} \
     --train_semantic_path=${root_dir}/${dump_dir}/train/semantic_token.tsv \
     --train_acoustic_path=${root_dir}/${dump_dir}/train/acoustic_token/hificodec.pth \
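
Taken together, these patches parameterize one splitting rule across test.py and the datasets: tokens run at hz frames per second; a clip longer than 2 * max_prompt_sec uses a fixed-length prompt, a shorter clip splits in half, and the target is capped at max_target_sec. A minimal standalone sketch of that rule follows; split_prompt_target and the toy tensor are hypothetical names for illustration, not part of the repo:

    import torch

    def split_prompt_target(tokens: torch.Tensor,
                            hz: int=50,
                            max_prompt_sec: int=3,
                            max_target_sec: int=10):
        """tokens: [Nq, T] codebook indices at `hz` frames per second."""
        T = tokens.shape[1]
        # long clips get a fixed-length prompt; short clips split in half
        if T > 2 * max_prompt_sec * hz:
            prompt_len = max_prompt_sec * hz
        else:
            prompt_len = T // 2
        prompt = tokens[:, :prompt_len]
        # the target is capped at max_target_sec seconds of frames
        target = tokens[:, prompt_len:prompt_len + max_target_sec * hz]
        return prompt, target

    # toy check: a 15 s clip at 50 Hz -> 3 s prompt and 10 s target
    toy = torch.randint(0, 300, (4, 15 * 50))
    prompt, target = split_prompt_target(toy)
    print(prompt.shape, target.shape)  # torch.Size([4, 150]) torch.Size([4, 500])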
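The prefetch_factor knob added in patch 5 only takes effect with worker processes: PyTorch rejects an explicit prefetch_factor (and persistent_workers) when num_workers == 0. A minimal sketch of the DataLoader wiring under that assumption (the TensorDataset stand-in and loader_kwargs are hypothetical, not the repo's build_dataloader):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(1000.))  # toy stand-in dataset
    num_workers = 4
    loader_kwargs = dict(batch_size=8, pin_memory=True, num_workers=num_workers)
    if num_workers > 0:
        loader_kwargs.update(
            persistent_workers=True,         # keep workers alive across epochs
            prefetch_factor=50,              # per worker: 4 x 50 = 200 batches in flight
            multiprocessing_context='fork')  # works around the fds_to_keep error noted in build.py
    loader = DataLoader(dataset, **loader_kwargs)

With num_workers=4 and prefetch_factor=50 the loader can keep up to 200 batches buffered, so the setting trades host memory for fewer input stalls; smaller values are safer on memory-constrained machines.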