Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

将多卡/单卡的权重保存成同一种格式;支持shell脚本中data为多个字符串的输入;修改test文件 #43

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion academicodec/models/encodec/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torchaudio
from torch.utils.data import Dataset
from pathlib import Path


class NSynthDataset(Dataset):
Expand All @@ -12,7 +13,10 @@ class NSynthDataset(Dataset):
def __init__(self, audio_dir):
super().__init__()
self.filenames = []
self.filenames.extend(glob.glob(audio_dir + "/*.wav"))
for sub_dir in audio_dir:
print(sub_dir)
self.filenames.extend(list(Path(sub_dir).glob("**/*.wav")))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你这里这样该就无法支持别的目录结构了,原始的写法是目录里面直接是 *.wav 文件,如果想支持多种类型的输入可以加判断


print(len(self.filenames))
_, self.sr = torchaudio.load(self.filenames[0])
self.max_len = 24000 # 24000
Expand Down
36 changes: 22 additions & 14 deletions academicodec/models/encodec/main_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,14 @@ def get_args():
parser.add_argument(
'--train_data_path',
type=str,
nargs='*',
# default='/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/dongchao/code4/InstructTTS2/data_process/soundstream_data/train16k.lst',
default="/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/data/codec_data_24k/train_valid_lists/train.lst",
help='training data')
parser.add_argument(
'--valid_data_path',
type=str,
nargs='*',
# default='/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/dongchao/code4/InstructTTS2/data_process/soundstream_data/val16k.lst',
default="/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/data/codec_data_24k/train_valid_lists/valid_256.lst",
help='validation data')
Expand All @@ -130,6 +132,12 @@ def get_args():
# default for 16k_320d
default=[1, 1.5, 2, 4, 6, 12],
help='target_bandwidths of net3.py')
parser.add_argument(
'--lr',
type=float,
default=3e-4,
help="learning rate"
)
args = parser.parse_args()
time_str = time.strftime('%Y-%m-%d-%H-%M')
if args.resume:
Expand Down Expand Up @@ -237,13 +245,13 @@ def main_worker(local_rank, args):
sampler=valid_sampler)
logger.log_info("Build optimizers and lr-schedulers")
optimizer_g = torch.optim.AdamW(
soundstream.parameters(), lr=3e-4, betas=(0.5, 0.9))
soundstream.parameters(), lr=args.lr, betas=(0.5, 0.9))
lr_scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optimizer_g, gamma=0.999)
optimizer_d = torch.optim.AdamW(
itertools.chain(stft_disc.parameters(),
msd.parameters(), mpd.parameters()),
lr=3e-4,
lr=args.lr,
betas=(0.5, 0.9))
lr_scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optimizer_d, gamma=0.999)
Expand Down Expand Up @@ -273,12 +281,12 @@ def train(args, soundstream, stft_disc, msd, mpd, train_loader, valid_loader,
stft_disc.train()
msd.train()
mpd.train()
train_loss_d = 0.0
train_adv_g_loss = 0.0
train_feat_loss = 0.0
train_rec_loss = 0.0
train_loss_g = 0.0
train_commit_loss = 0.0
train_loss_d = 0.0 # hinge-loss adversarial loss function
train_adv_g_loss = 0.0 # adversarial loss for the generator
train_feat_loss = 0.0 # a relative feature matching loss for the generator
train_rec_loss = 0.0 # l_t and l_f
train_loss_g = 0.0 # loss generation
train_commit_loss = 0.0 # RVQ commit loss
k_iter = 0
if args.distributed:
train_loader.sampler.set_epoch(epoch)
Expand Down Expand Up @@ -427,12 +435,12 @@ def train(args, soundstream, stft_disc, msd, mpd, train_loader, valid_loader,
fmap_f_g, y_ds_hat_r, y_ds_hat_g,
fmap_s_r, fmap_s_g)
valid_loss_d += loss_d.item()
if dist.get_rank() == 0:
best_model = soundstream.state_dict().copy()
latest_model_soundstream = soundstream.state_dict().copy()
latest_model_dis = stft_disc.state_dict().copy()
latest_mpd = mpd.state_dict().copy()
latest_msd = msd.state_dict().copy()
if not args.distributed or dist.get_rank() == 0:
best_model = soundstream.module.state_dict().copy() if args.distributed else soundstream.state_dict().copy()
latest_model_soundstream = soundstream.module.state_dict().copy() if args.distributed else soundstream.state_dict().copy()
latest_model_dis = stft_disc.module.state_dict().copy() if args.distributed else stft_disc.state_dict().copy()
latest_mpd = mpd.module.state_dict().copy() if args.distributed else mpd.state_dict().copy()
latest_msd = msd.module.state_dict().copy() if args.distributed else msd.state_dict().copy()
if valid_rec_loss < best_val_loss:
best_val_loss = valid_rec_loss
best_val_epoch = epoch
Expand Down
37 changes: 23 additions & 14 deletions academicodec/models/encodec/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
"""Command-line for audio compression."""
import argparse
import os
import sys
import typing as tp
from collections import OrderedDict
Expand Down Expand Up @@ -64,6 +63,9 @@ def get_parser():
default=[8, 5, 4, 2],
help='ratios of SoundStream, shoud be set for different hop_size (32d, 320, 240d, ...)'
)
parser.add_argument(
'-f', '--force', action='store_true',
help='Overwrite output file if it exists.')
parser.add_argument(
'--target_bandwidths',
type=float,
Expand All @@ -77,6 +79,14 @@ def get_parser():
# default for 16k_320d
default=12,
help='target_bw of net3.py')
parser.add_argument(
'--ext',
type=str,
# default for 16k_320d
default="wav",
help='audio extension',
choices=["wav","flac","mp3"])


return parser

Expand All @@ -85,6 +95,11 @@ def fatal(*args):
print(*args, file=sys.stderr)
sys.exit(1)

def check_output_exists(args):
if not args.output.parent.exists():
fatal(f"Output folder for {args.output} does not exist.")
if args.output.exists() and not args.force:
fatal(f"Output file {args.output} exist. Use -f / --force to overwrite.")

# 这只是打印了但是没有真的 clip
def check_clipping(wav, rescale):
Expand All @@ -106,7 +121,6 @@ def test_one(args, wav_root, store_root, rescale, soundstream):
# wav = wav[0].unsqueeze(0)
# # 重采样为模型的采样率
# wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=args.sr)(wav)

# load wav with librosa
wav, sr = librosa.load(wav_root, sr=args.sr)
wav = torch.tensor(wav).unsqueeze(0)
Expand Down Expand Up @@ -160,31 +174,26 @@ def test_batch():
print("args.target_bandwidths:", args.target_bandwidths)
if not args.input.exists():
fatal(f"Input file {args.input} does not exist.")
input_lists = os.listdir(args.input)
input_lists = list(args.input.glob(f"**/*.{args.ext}"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里这样写的话只能读取一种格式了,不能同时读取 *.mp3 和 *.wav

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

但是好像一个文件夹下很少出现这种情况(?emmm 主要是有些文件夹下有对应的音频文件和转录的文件,所以我才这么写了orz

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我自己输入文件随便搞的,有各种格式,比如可以先列出所有文件,然后再把是 *.mp3 .wav,.flac 的筛选出来

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

明白了,先取所有文件,然后只保留特定的mp3,flac,wav这种格式的文件

input_lists.sort()
soundstream = SoundStream(
n_filters=32,
D=512,
ratios=args.ratios,
sample_rate=args.sr,
target_bandwidths=args.target_bandwidths)
parameter_dict = torch.load(args.resume_path)
new_state_dict = OrderedDict()
# k 为 module.xxx.weight, v 为权重
for k, v in parameter_dict.items():
# 截取`module.`后面的xxx.weight
name = k[7:]
new_state_dict[name] = v
soundstream.load_state_dict(new_state_dict) # load model
soundstream.load_state_dict(torch.load(args.resume_path)) # load model
remove_encodec_weight_norm(soundstream)
soundstream.cuda()
soundstream.eval()
os.makedirs(args.output, exist_ok=True)
if not args.output.exists():
args.output.mkdirs(parents=True)
check_output_exists(args)
for audio in input_lists:
test_one(
args=args,
wav_root=os.path.join(args.input, audio),
store_root=os.path.join(args.output, audio),
wav_root=args.input.joinpath(audio.name),
store_root=args.output.joinpath(audio.name),
rescale=args.rescale,
soundstream=soundstream)

Expand Down