-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
将多卡/单卡的权重保存成同一种格式;支持shell脚本中data为多个字符串的输入;修改test文件 #43
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
# LICENSE file in the root directory of this source tree. | ||
"""Command-line for audio compression.""" | ||
import argparse | ||
import os | ||
import sys | ||
import typing as tp | ||
from collections import OrderedDict | ||
|
@@ -64,6 +63,9 @@ def get_parser(): | |
default=[8, 5, 4, 2], | ||
help='ratios of SoundStream, shoud be set for different hop_size (32d, 320, 240d, ...)' | ||
) | ||
parser.add_argument( | ||
'-f', '--force', action='store_true', | ||
help='Overwrite output file if it exists.') | ||
parser.add_argument( | ||
'--target_bandwidths', | ||
type=float, | ||
|
@@ -77,6 +79,14 @@ def get_parser(): | |
# default for 16k_320d | ||
default=12, | ||
help='target_bw of net3.py') | ||
parser.add_argument( | ||
'--ext', | ||
type=str, | ||
# default for 16k_320d | ||
default="wav", | ||
help='audio extension', | ||
choices=["wav","flac","mp3"]) | ||
|
||
|
||
return parser | ||
|
||
|
@@ -85,6 +95,11 @@ def fatal(*args): | |
print(*args, file=sys.stderr) | ||
sys.exit(1) | ||
|
||
def check_output_exists(args): | ||
if not args.output.parent.exists(): | ||
fatal(f"Output folder for {args.output} does not exist.") | ||
if args.output.exists() and not args.force: | ||
fatal(f"Output file {args.output} exist. Use -f / --force to overwrite.") | ||
|
||
# 这只是打印了但是没有真的 clip | ||
def check_clipping(wav, rescale): | ||
|
@@ -106,7 +121,6 @@ def test_one(args, wav_root, store_root, rescale, soundstream): | |
# wav = wav[0].unsqueeze(0) | ||
# # 重采样为模型的采样率 | ||
# wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=args.sr)(wav) | ||
|
||
# load wav with librosa | ||
wav, sr = librosa.load(wav_root, sr=args.sr) | ||
wav = torch.tensor(wav).unsqueeze(0) | ||
|
@@ -160,31 +174,26 @@ def test_batch(): | |
print("args.target_bandwidths:", args.target_bandwidths) | ||
if not args.input.exists(): | ||
fatal(f"Input file {args.input} does not exist.") | ||
input_lists = os.listdir(args.input) | ||
input_lists = list(args.input.glob(f"**/*.{args.ext}")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里这样写的话只能读取一种格式了,不能同时读取 *.mp3 和 *.wav There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 但是好像一个文件夹下很少出现这种情况(?emmm 主要是有些文件夹下有对应的音频文件和转录的文件,所以我才这么写了orz There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我自己输入文件随便搞的,有各种格式,比如可以先列出所有文件,然后再把是 *.mp3 .wav,.flac 的筛选出来 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 明白了,先取所有文件,然后只保留特定的mp3,flac,wav这种格式的文件 |
||
input_lists.sort() | ||
soundstream = SoundStream( | ||
n_filters=32, | ||
D=512, | ||
ratios=args.ratios, | ||
sample_rate=args.sr, | ||
target_bandwidths=args.target_bandwidths) | ||
parameter_dict = torch.load(args.resume_path) | ||
new_state_dict = OrderedDict() | ||
# k 为 module.xxx.weight, v 为权重 | ||
for k, v in parameter_dict.items(): | ||
# 截取`module.`后面的xxx.weight | ||
name = k[7:] | ||
new_state_dict[name] = v | ||
soundstream.load_state_dict(new_state_dict) # load model | ||
soundstream.load_state_dict(torch.load(args.resume_path)) # load model | ||
remove_encodec_weight_norm(soundstream) | ||
soundstream.cuda() | ||
soundstream.eval() | ||
os.makedirs(args.output, exist_ok=True) | ||
if not args.output.exists(): | ||
args.output.mkdirs(parents=True) | ||
check_output_exists(args) | ||
for audio in input_lists: | ||
test_one( | ||
args=args, | ||
wav_root=os.path.join(args.input, audio), | ||
store_root=os.path.join(args.output, audio), | ||
wav_root=args.input.joinpath(audio.name), | ||
store_root=args.output.joinpath(audio.name), | ||
rescale=args.rescale, | ||
soundstream=soundstream) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你这里这样该就无法支持别的目录结构了,原始的写法是目录里面直接是 *.wav 文件,如果想支持多种类型的输入可以加判断