Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
luca.gaudino committed May 10, 2024
1 parent a08554d commit bcd64a4
Showing 1 changed file with 4 additions and 5 deletions.
Expand Up @@ -1259,7 +1259,7 @@ def get_train_data(**kwargs):
# "beam_sizes": [32, 70],
# },
"model_att_only_currL": {
"scales": [(0.7, 0.3, 0.0, 1.0)],
"scales": [(0.7, 0.3, 0.3, 1.0)],
},
# "model_ctc0.3_att0.7_lay6": {
# "scales": [(0.8, 0.2, 0.55)],
Expand Down Expand Up @@ -1302,9 +1302,8 @@ def get_train_data(**kwargs):
# joint_training_model_names_2[first_model_name]["beam_sizes"],
# joint_training_model_names_2[first_model_name]["scales"],
# ):
for beam_size, prior_scale, scales in product(
for beam_size, scales in product(
[32],
[0.0, 0.1 ,0.2 ,0.3, 0.4, 0.5, 0.6, 0.7],
dict_sep_recombine[first_model_name]["scales"],
):
search_args = copy.deepcopy(args)
Expand All @@ -1317,7 +1316,7 @@ def get_train_data(**kwargs):
}
search_args["beam_size"] = beam_size
search_args["ctc_log_prior_file"] = models["model_ctc_only"]["prior"]
att_scale, ctc_scale,_,_= scales
att_scale, ctc_scale,prior_scale,_= scales
label_scale = 1.0

search_args["decoder_args"] = CTCDecoderArgs(
Expand Down Expand Up @@ -1354,7 +1353,7 @@ def get_train_data(**kwargs):
checkpoint=models[first_model_name]["ckpt"],
search_args=search_args,
bpe_size=BPE_1K,
test_sets=["dev"],
test_sets=["dev", "test"],
remove_label={"<s>", "<blank>"},
use_sclite=True,
time_rqmt=2.0,
Expand Down

0 comments on commit bcd64a4

Please sign in to comment.