Skip to content

Commit

Permalink
cleanup librispeech usage
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed May 10, 2024
1 parent fa65858 commit e671922
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 47 deletions.
27 changes: 3 additions & 24 deletions users/zeyer/experiments/exp2024_04_23_baselines/aed.py
Expand Up @@ -25,7 +25,6 @@
from i6_experiments.users.zeyer.model_interfaces import ModelDef, ModelDefWithCfg, RecogDef, TrainDef
from i6_experiments.users.zeyer.returnn.models.rf_layerdrop import SequentialLayerDrop
from i6_experiments.users.zeyer.speed_pert.librosa_config import speed_pert_librosa_config
from i6_experiments.users.zeyer.accum_grad_schedules.piecewise_linear import dyn_accum_grad_piecewise_linear

from .configs import *
from .configs import _get_cfg_lrlin_oclr_by_bs_nep, _batch_size_factor
Expand Down Expand Up @@ -89,6 +88,7 @@ def train_exp(
*,
model_def: Optional[Union[ModelDefWithCfg, ModelDef[Model]]] = None,
vocab: str = "bpe10k",
train_vocab_opts: Optional[Dict[str, Any]] = None,
train_def: Optional[TrainDef[Model]] = None,
model_config: Optional[Dict[str, Any]] = None,
config_updates: Optional[Dict[str, Any]] = None,
Expand All @@ -104,12 +104,13 @@ def train_exp(
"""
from i6_experiments.users.zeyer.train_v3 import train
from i6_experiments.users.zeyer.recog import recog_training_exp
from i6_experiments.users.zeyer.datasets.librispeech import get_librispeech_task_raw_v2

if _sis_prefix is None:
_sis_setup_global_prefix()

prefix = _sis_prefix + "/" + name
task = _get_ls_task(vocab=vocab)
task = get_librispeech_task_raw_v2(vocab=vocab, train_vocab_opts=train_vocab_opts)
config = config.copy()
config = dict_update_deep(config, config_updates, config_deletes)
if "__num_epochs" in config:
Expand Down Expand Up @@ -159,28 +160,6 @@ def _sis_setup_global_prefix(prefix_name: Optional[str] = None):
_sis_prefix = prefix_name


_ls_task = {}  # cache: vocab name -> Task (built lazily, one per vocab)


def _get_ls_task(*, vocab: str = "bpe10k") -> Task:
    """
    Return the LibriSpeech task for the given vocab ("bpe10k" or "spm10k"),
    constructing and caching it on first use.
    """
    global _ls_task
    try:
        return _ls_task[vocab]
    except KeyError:
        pass  # not cached yet -- build it below

    from i6_experiments.users.zeyer.datasets.librispeech import get_librispeech_task_raw_v2, bpe10k, spm_10k

    # Check via ``sis c ...`` and then ``tk.print_graph()``.
    # We use the v2 task variant because it keeps unwanted jobs out of the graph:
    # CorpusToStmJob, BlissChangeEncodingJob (flac -> ogg conversion),
    # SearchWordsToCTMJob, and CorpusToTxtJob.
    vocab_def = {"bpe10k": bpe10k, "spm10k": spm_10k}[vocab]
    task = get_librispeech_task_raw_v2(vocab=vocab_def)
    _ls_task[vocab] = task
    return task


def aed_model_def(*, epoch: int, in_dim: Dim, target_dim: Dim) -> Model:
"""Function is run within RETURNN."""
from returnn.config import get_global_config
Expand Down
25 changes: 2 additions & 23 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
Expand Up @@ -76,12 +76,13 @@ def train_exp(
"""
from i6_experiments.users.zeyer.train_v3 import train
from i6_experiments.users.zeyer.recog import recog_training_exp
from i6_experiments.users.zeyer.datasets.librispeech import get_librispeech_task_raw_v2

if _sis_prefix is None:
_sis_setup_global_prefix()

prefix = _sis_prefix + "/" + name
task = _get_ls_task(vocab=vocab)
task = get_librispeech_task_raw_v2(vocab=vocab)
config = config.copy()
config = dict_update_deep(config, config_updates, config_deletes)
if "__num_epochs" in config:
Expand Down Expand Up @@ -131,28 +132,6 @@ def _sis_setup_global_prefix(prefix_name: Optional[str] = None):
_sis_prefix = prefix_name


_ls_task = {}  # cache: vocab name -> Task (built lazily, one per vocab)


def _get_ls_task(*, vocab: str = "bpe10k") -> Task:
    """
    Return the LibriSpeech task for the given vocab ("bpe10k" or "spm10k"),
    constructing and caching it on first use.
    """
    global _ls_task
    try:
        return _ls_task[vocab]
    except KeyError:
        pass  # not cached yet -- build it below

    from i6_experiments.users.zeyer.datasets.librispeech import get_librispeech_task_raw_v2, bpe10k, spm_10k

    # Check via ``sis c ...`` and then ``tk.print_graph()``.
    # We use the v2 task variant because it keeps unwanted jobs out of the graph:
    # CorpusToStmJob, BlissChangeEncodingJob (flac -> ogg conversion),
    # SearchWordsToCTMJob, and CorpusToTxtJob.
    vocab_def = {"bpe10k": bpe10k, "spm10k": spm_10k}[vocab]
    task = get_librispeech_task_raw_v2(vocab=vocab_def)
    _ls_task[vocab] = task
    return task


def ctc_model_def(*, epoch: int, in_dim: Dim, target_dim: Dim) -> Model:
"""Function is run within RETURNN."""
from returnn.config import get_global_config
Expand Down

0 comments on commit e671922

Please sign in to comment.