# Training a microWakeWord Model

This notebook steps you through training a basic microWakeWord model. It is intended as a **starting point** for advanced users. You should use Python 3.10.

**The generated model will most likely not be suitable for everyday use; it may be difficult to trigger or may falsely activate too often. You will most likely have to experiment with many different settings to obtain a decent model!**

In the comment at the start of certain blocks, I note some specific settings to consider modifying.

This notebook runs on Google Colab, but training there is extremely slow compared to training on a local GPU. If you must use Colab, be sure to change the runtime type to a GPU. Even then, it is still slow!

At the end of this notebook, you will be able to download a tflite file. To use this in ESPHome, you need to write a model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for the details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for examples.
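
As a rough illustration of what such a manifest might contain, the sketch below builds one in Python and prints it as JSON. The field names are assumed from the v2 examples in the model repo, and every value (wake word text, file name, probability cutoff, arena size, minimum ESPHome version) is a placeholder; check all of them against the ESPHome documentation before use.

```python
import json

# Placeholder manifest. Field names are assumptions based on the v2 examples
# in the model repo; every value here must be replaced with your own.
manifest = {
    "type": "micro",
    "wake_word": "Ei Sexta",
    "author": "Your Name",
    "website": "https://example.com",
    "model": "ei_sexta.tflite",          # hypothetical file name
    "trained_languages": ["pt"],
    "version": 2,
    "micro": {
        "probability_cutoff": 0.97,       # placeholder, tune for your model
        "sliding_window_size": 5,         # placeholder
        "feature_step_size": 10,          # matches the 10 ms step used in training
        "tensor_arena_size": 30000,       # placeholder, depends on the model
        "minimum_esphome_version": "2024.7.0",  # placeholder
    },
}

manifest_json = json.dumps(manifest, indent=2, ensure_ascii=False)
print(manifest_json)
```

Writing `manifest_json` to a `.json` file next to the tflite file would then give ESPHome something to point at, assuming the fields are correct.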

In [None]:
# Installs microWakeWord. Be sure to restart the session after this is finished.
import platform

if platform.system() == "Darwin":
    # `pymicro-features` is installed from a fork to support building on macOS
    !pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version'

# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter
!pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'

!git clone -b november-update https://github.com/kahrendt/microWakeWord
!pip install -e ./microWakeWord

# --- Patch microWakeWord for TF 2.20+ compatibility ---
# model.evaluate() now returns plain numpy values instead of tf.Tensor,
# so .numpy() calls on metric results fail. Fall back to np.asarray() when
# the value has no .numpy() method.
import pathlib
_train_py = pathlib.Path('microWakeWord/microwakeword/train.py')
_src = _train_py.read_text()
if "_to_numpy" not in _src:
    _src = _src.replace(
        'test_set_fp = result["fp"].numpy()',
        'test_set_fp = result["fp"].numpy() if hasattr(result["fp"], "numpy") else np.asarray(result["fp"])'
    ).replace(
        'all_true_positives = ambient_predictions["tp"].numpy()',
        'all_true_positives = ambient_predictions["tp"].numpy() if hasattr(ambient_predictions["tp"], "numpy") else np.asarray(ambient_predictions["tp"])'
    ).replace(
        'ambient_false_positives = ambient_predictions["fp"].numpy() - test_set_fp',
        'ambient_false_positives = (ambient_predictions["fp"].numpy() if hasattr(ambient_predictions["fp"], "numpy") else np.asarray(ambient_predictions["fp"])) - test_set_fp'
    ).replace(
        'all_false_negatives = ambient_predictions["fn"].numpy()',
        'all_false_negatives = ambient_predictions["fn"].numpy() if hasattr(ambient_predictions["fn"], "numpy") else np.asarray(ambient_predictions["fn"])'
    )
    # Mark as patched
    _src = '# _to_numpy patch applied\n' + _src
    _train_py.write_text(_src)
    print('Patched train.py for TF 2.20+ .numpy() compat')
else:
    print('train.py already patched')
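
The patch above inlines the same `hasattr(..., "numpy")` check four times. The idea can be sketched in isolation as a small helper; `to_numpy` and `FakeTensor` are hypothetical names used only for this illustration and are not part of microWakeWord or TensorFlow.

```python
import numpy as np

def to_numpy(value):
    """Return value as a NumPy array, whether or not it exposes .numpy()."""
    if hasattr(value, "numpy"):
        # Tensor-like objects (e.g. tf.Tensor) convert through .numpy()
        return np.asarray(value.numpy())
    # Plain Python numbers and NumPy values are wrapped directly
    return np.asarray(value)

class FakeTensor:
    """Stand-in for a tensor-like object, so the sketch runs without TensorFlow."""
    def __init__(self, data):
        self._data = data

    def numpy(self):
        return np.asarray(self._data)

print(to_numpy(FakeTensor([1.0, 2.0])))
print(to_numpy(3.0))
```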

In [None]:
# Generates 1 sample per voice/spelling combo for manual verification.
# For pt-BR, we download Piper checkpoints and export the generator models.
# The default .pt release only supports English, so we convert pt_BR checkpoints.

# === CONFIGURE YOUR WAKE WORD SPELLINGS AND VOICES HERE ===
# Add/remove spellings to cover pronunciation variations.
# Add/remove voices to get different speaker characteristics.
# Each spelling produces a different phoneme sequence via espeak-ng.
target_spellings = [
    'Ei Sexta!',     # canonical pronunciation
    'Oi Sexta',      # "oi" instead of "ei" (common casual variant)
    'Ei sêxtá',      # stress on both syllables of "sexta"
]

voices = {
    'faber': {
        'ckpt_url': 'https://huggingface.co/datasets/rhasspy/piper-checkpoints/resolve/main/pt/pt_BR/faber/medium/epoch%3D6159-step%3D1230728.ckpt',
        'config_url': 'https://huggingface.co/datasets/rhasspy/piper-checkpoints/resolve/main/pt/pt_BR/faber/medium/config.json',
    },
    'cadu': {
        'ckpt_url': 'https://huggingface.co/datasets/rhasspy/piper-checkpoints/resolve/main/pt/pt_BR/cadu/medium/epoch%3D5195-step%3D109116.ckpt',
        'config_url': 'https://huggingface.co/datasets/rhasspy/piper-checkpoints/resolve/main/pt/pt_BR/cadu/medium/config.json',
    },
    'jeff': {
        'ckpt_url': 'https://huggingface.co/datasets/rhasspy/piper-checkpoints/resolve/main/pt/pt_BR/jeff/medium/epoch%3D5462-step%3D118728.ckpt',
        'config_url': 'https://huggingface.co/datasets/rhasspy/piper-checkpoints/resolve/main/pt/pt_BR/jeff/medium/config.json',
    },
}

import os
import sys
import platform
import shutil

from IPython.display import Audio, display, HTML

if not os.path.exists('./piper-sample-generator'):
    if platform.system() == 'Darwin':
        !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator
    else:
        !git clone https://github.com/rhasspy/piper-sample-generator

    # Install system dependencies
    !pip install -q torch torchaudio piper-phonemize-cross==1.2.1 pytorch-lightning

    # Patch generate_samples.py for PyTorch 2.6+ (weights_only default changed to True)
    # and for single-speaker models (no emb_g attribute)
    gen_script = 'piper-sample-generator/generate_samples.py'
    with open(gen_script, 'r') as f:
        content = f.read()
    content = content.replace('torch.load(model_path)', 'torch.load(model_path, weights_only=False)')
    content = content.replace(
        'emb0 = model.emb_g(speaker_1)\n    emb1 = model.emb_g(speaker_2)\n    g = slerp(emb0, emb1, slerp_weight).unsqueeze(-1)  # [b, h, 1]',
        'if hasattr(model, \'emb_g\'):\n        emb0 = model.emb_g(speaker_1)\n        emb1 = model.emb_g(speaker_2)\n        g = slerp(emb0, emb1, slerp_weight).unsqueeze(-1)  # [b, h, 1]\n    else:\n        g = None'
    )
    with open(gen_script, 'w') as f:
        f.write(content)

if 'piper-sample-generator/' not in sys.path:
    sys.path.append('piper-sample-generator/')

# Download and export each voice model (idempotent — skips if .pt already exists)
import torch
import pathlib

models_dir = 'piper-sample-generator/models'
model_paths = {}  # voice_name -> .pt path

for voice_name, voice_info in voices.items():
    pt_path = os.path.join(models_dir, f'pt_BR-{voice_name}-medium.pt')
    json_path = pt_path + '.json'
    model_paths[voice_name] = pt_path

    if os.path.exists(pt_path):
        print(f'{voice_name}: model already exported at {pt_path}')
        continue

    ckpt_path = os.path.join(models_dir, f'pt_BR-{voice_name}-medium.ckpt')
    print(f'{voice_name}: downloading checkpoint...')
    !curl -L -o {ckpt_path} '{voice_info["ckpt_url"]}'
    !curl -L -o {json_path} '{voice_info["config_url"]}'

    torch.serialization.add_safe_globals([pathlib.PosixPath, pathlib.WindowsPath])

    # Manually reconstruct the generator from the checkpoint to avoid
    # Lightning version incompatibilities with load_from_checkpoint.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    hparams = ckpt['hyper_parameters']

    from piper_train.vits.models import SynthesizerTrn
    model_g = SynthesizerTrn(
        n_vocab=hparams['num_symbols'],
        spec_channels=hparams.get('filter_length', 1024) // 2 + 1,
        segment_size=hparams.get('segment_size', 8192) // hparams.get('hop_length', 256),
        inter_channels=hparams.get('inter_channels', 192),
        hidden_channels=hparams.get('hidden_channels', 192),
        filter_channels=hparams.get('filter_channels', 768),
        n_heads=hparams.get('n_heads', 2),
        n_layers=hparams.get('n_layers', 6),
        kernel_size=hparams.get('kernel_size', 3),
        p_dropout=hparams.get('p_dropout', 0.1),
        resblock=hparams.get('resblock', '2'),
        resblock_kernel_sizes=hparams.get('resblock_kernel_sizes', (3, 5, 7)),
        resblock_dilation_sizes=hparams.get('resblock_dilation_sizes', ((1,2),(2,6),(3,12))),
        upsample_rates=hparams.get('upsample_rates', (8, 8, 4)),
        upsample_initial_channel=hparams.get('upsample_initial_channel', 256),
        upsample_kernel_sizes=hparams.get('upsample_kernel_sizes', (16, 16, 8)),
        n_speakers=hparams.get('num_speakers', 1),
        gin_channels=hparams.get('gin_channels', 0),
        use_sdp=hparams.get('use_sdp', True),
    )

    # Load only the generator weights from the checkpoint state_dict
    g_state = {k.replace('model_g.', ''): v for k, v in ckpt['state_dict'].items() if k.startswith('model_g.')}
    model_g.load_state_dict(g_state)
    model_g.eval()
    torch.save(model_g, pt_path)
    del ckpt, g_state, model_g  # free memory
    os.remove(ckpt_path)
    print(f'{voice_name}: exported successfully')

# Generate 1 sample per voice × spelling combo and play them all
preview_dir = 'generated_preview'
if os.path.exists(preview_dir):
    shutil.rmtree(preview_dir)
os.makedirs(preview_dir, exist_ok=True)

sample_idx = 0
for voice_name, pt_path in model_paths.items():
    for spelling in target_spellings:
        out_dir = os.path.join(preview_dir, f'{voice_name}_{sample_idx}')
        os.makedirs(out_dir, exist_ok=True)
        !python3 piper-sample-generator/generate_samples.py "{spelling}" \
            --model {pt_path} \
            --max-samples 1 \
            --batch-size 1 \
            --output-dir {out_dir}
        wav_path = os.path.join(out_dir, '0.wav')
        display(HTML(f'<b>{voice_name}</b> — "{spelling}"'))
        display(Audio(wav_path))
        sample_idx += 1

print(f'\nGenerated {sample_idx} preview samples across {len(model_paths)} voices and {len(target_spellings)} spellings.')

In [None]:
# Generates a larger amount of wake word samples from ALL voices and spellings.
# Uses the voices/spellings defined in cell 2 — run that cell first.
#
# Key parameters:
#   --length-scales: speaking speed (lower=faster, higher=slower).
#   --noise-scales: overall variability/expressiveness.
#   --noise-scale-ws: stochastic duration variation of individual phonemes.
#   --max-samples: more samples = better model. ~1000+ per combo recommended.

import os, shutil, math, glob, subprocess

output_dir = 'generated_samples'
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)

total_samples = 10000
num_combos = len(model_paths) * len(target_spellings)
samples_per_combo = math.ceil(total_samples / num_combos)

print(f'Generating ~{total_samples} total samples ({samples_per_combo} per combo, {num_combos} combos)')

file_idx = 0
for voice_name, pt_path in model_paths.items():
    for spelling in target_spellings:
        # Generate into a temp dir to avoid filename collisions
        tmp_dir = os.path.join(output_dir, f'_tmp_{voice_name}_{id(spelling)}')
        os.makedirs(tmp_dir, exist_ok=True)

        print(f'\n>>> {voice_name} — "{spelling}" ({samples_per_combo} samples)')
        subprocess.run([
            'python3', 'piper-sample-generator/generate_samples.py', spelling,
            '--model', pt_path,
            '--max-samples', str(samples_per_combo),
            '--batch-size', '100',
            '--length-scales', '0.6', '0.75', '0.85', '1.0', '1.15', '1.3', '1.5',
            '--noise-scales', '0.5', '0.667', '0.75', '0.85', '1.0', '1.2',
            '--noise-scale-ws', '0.6', '0.8', '1.0', '1.2',
            '--output-dir', tmp_dir,
        ], check=True)

        # Move files into the main output dir with unique sequential names
        for wav in sorted(glob.glob(os.path.join(tmp_dir, '*.wav'))):
            os.rename(wav, os.path.join(output_dir, f'{file_idx}.wav'))
            file_idx += 1
        shutil.rmtree(tmp_dir)

print(f'\nDone! {file_idx} samples in {output_dir}/')
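
To make the sample budget concrete, here is the arithmetic for the three voices and three spellings configured earlier. It assumes the generator produces exactly `--max-samples` clips per combination, which is not verified here.

```python
import math

total_samples = 10000
num_voices = 3       # faber, cadu, jeff
num_spellings = 3    # the three entries in target_spellings

num_combos = num_voices * num_spellings
# Rounding up means the final total slightly exceeds the requested amount
samples_per_combo = math.ceil(total_samples / num_combos)
generated_total = samples_per_combo * num_combos

print(num_combos, samples_per_combo, generated_total)
```

Because of the rounding up, the final count reported by the cell can be a little higher than `total_samples`.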

In [None]:
# Downloads audio data for augmentation. This can be slow!
# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024
#
# datasets 4.0+ replaced soundfile with torchcodec and changed the audio API.
# We need <4.0.0 so the dict-based audio decoding (row['audio']['array']) works.
import datasets as _ds_check
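# Compare the major version as an integer; comparing version strings
# would order '10.0.0' before '4.0.0'.
if int(_ds_check.__version__.split('.')[0]) >= 4: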
    import subprocess, sys, IPython
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'datasets<4.0.0'])
    print('\n⚠️  datasets was downgraded. Restarting kernel — just re-run this cell.')
    IPython.get_ipython().kernel.do_shutdown(restart=True)
    raise SystemExit  # stop execution until kernel restarts
del _ds_check

#
# **Important note!** The data downloaded here has a mixture of different
# licenses and usage restrictions. As such, any custom models trained with this
# data should be considered as appropriate for **non-commercial** personal use only.


import datasets
import scipy
import os
import glob

import numpy as np

from pathlib import Path
from tqdm import tqdm

## Download MIT RIR data

output_dir = "./mit_rirs"
if not glob.glob(os.path.join(output_dir, '*.wav')):
    os.makedirs(output_dir, exist_ok=True)
    rir_dataset = datasets.load_dataset("davidscripka/MIT_environmental_impulse_responses", split="train", streaming=True)
    # Save clips to 16-bit PCM wav files
    for row in tqdm(rir_dataset):
        name = row['audio']['path'].split('/')[-1]
        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))
else:
    print(f'MIT RIRs already downloaded ({len(glob.glob(os.path.join(output_dir, "*.wav")))} files)')

## Download noise and background audio

# Audioset Dataset (https://research.google.com/audioset/dataset/index.html)
# Download one part of the audioset .tar files, extract, and convert to 16 kHz
# For full-scale training, it's recommended to download the entire dataset from
# https://huggingface.co/datasets/agkphysics/AudioSet, and
# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)

audioset_16k_dir = "./audioset_16k"
if not glob.glob(os.path.join(audioset_16k_dir, '*.wav')):
    os.makedirs("audioset", exist_ok=True)

    fname = "bal_train09.tar"
    out_dir = f"audioset/{fname}"
    link = "https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/" + fname
    !curl -L -o {out_dir} {link}
    !tar -xf audioset/bal_train09.tar -C audioset

    os.makedirs(audioset_16k_dir, exist_ok=True)

    # Save clips to 16-bit PCM wav files
    audioset_dataset = datasets.Dataset.from_dict({"audio": [str(i) for i in Path("audioset/audio").glob("**/*.flac")]})
    audioset_dataset = audioset_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))
    for row in tqdm(audioset_dataset):
        name = row['audio']['path'].split('/')[-1].replace(".flac", ".wav")
        scipy.io.wavfile.write(os.path.join(audioset_16k_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))
else:
    print(f'Audioset already converted ({len(glob.glob(os.path.join(audioset_16k_dir, "*.wav")))} files)')

# Free Music Archive dataset
# https://github.com/mdeff/fma
# (Third-party mchl914 extra small set)

fma_16k_dir = "./fma_16k"
if not glob.glob(os.path.join(fma_16k_dir, '*.wav')):
    os.makedirs("fma", exist_ok=True)
    fname = "fma_xs.zip"
    link = "https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/" + fname
    out_dir = f"fma/{fname}"
    !curl -L -o {out_dir} {link}
    !unzip -q fma/{fname} -d fma

    os.makedirs(fma_16k_dir, exist_ok=True)

    # Save clips to 16-bit PCM wav files
    fma_dataset = datasets.Dataset.from_dict({"audio": [str(i) for i in Path("fma/fma_small").glob("**/*.mp3")]})
    fma_dataset = fma_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))
    for row in tqdm(fma_dataset):
        name = row['audio']['path'].split('/')[-1].replace(".mp3", ".wav")
        scipy.io.wavfile.write(os.path.join(fma_16k_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))
else:
    print(f'FMA already converted ({len(glob.glob(os.path.join(fma_16k_dir, "*.wav")))} files)')
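
The conversion lines in this cell scale float audio by 32767 and cast it to int16. If a decoded clip ever falls outside the [-1, 1] range, that cast can produce wrapped or otherwise invalid values, so a slightly more defensive variant clips first. This is only a sketch; `float_to_pcm16` is a hypothetical helper, not something this notebook or the `datasets` library provides.

```python
import numpy as np

def float_to_pcm16(samples):
    """Convert float audio in [-1, 1] to 16-bit PCM, clipping out-of-range values."""
    clipped = np.clip(np.asarray(samples, dtype=np.float64), -1.0, 1.0)
    # astype truncates toward zero, matching the cast used in the cell
    return (clipped * 32767).astype(np.int16)

pcm = float_to_pcm16([0.0, 0.5, 1.0, -1.0, 1.5])
print(pcm.tolist())
```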


In [None]:
# Sets up the augmentations.
# To improve your model, experiment with these settings and use more sources of
# background clips.

from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

clips = Clips(input_directory='generated_samples',
              file_pattern='*.wav',
              max_clip_duration_s=None,
              remove_silence=False,
              random_split_seed=10,
              split_count=0.1,
              )
augmenter = Augmentation(augmentation_duration_s=3.2,
                         augmentation_probabilities = {
                                "SevenBandParametricEQ": 0.1,
                                "TanhDistortion": 0.1,
                                "PitchShift": 0.1,
                                "BandStopFilter": 0.1,
                                "AddColorNoise": 0.1,
                                "AddBackgroundNoise": 0.75,
                                "Gain": 1.0,
                                "RIR": 0.5,
                            },
                         impulse_paths = ['mit_rirs'],
                         background_paths = ['fma_16k', 'audioset_16k'],
                         background_min_snr_db = -5,
                         background_max_snr_db = 10,
                         min_jitter_s = 0.195,
                         max_jitter_s = 0.205,
                         )
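
For intuition about `background_min_snr_db` and `background_max_snr_db`, the sketch below mixes a tone with noise at a chosen signal-to-noise ratio. It illustrates the general idea of SNR-controlled mixing only; it is not the code the `Augmentation` class runs, and `mix_at_snr` is a hypothetical helper.

```python
import numpy as np

def mix_at_snr(signal, noise, snr_db):
    """Scale noise so the signal-to-noise power ratio equals snr_db, then mix."""
    signal = np.asarray(signal, dtype=np.float64)
    noise = np.asarray(noise, dtype=np.float64)
    signal_power = np.mean(signal ** 2)
    noise_power = np.mean(noise ** 2)
    # SNR(dB) = 10 * log10(signal_power / noise_power)
    target_noise_power = signal_power / (10 ** (snr_db / 10))
    scaled_noise = noise * np.sqrt(target_noise_power / noise_power)
    return signal + scaled_noise, scaled_noise

rng = np.random.default_rng(0)
t = np.arange(16000) / 16000                    # one second at 16 kHz
signal = 0.1 * np.sin(2 * np.pi * 440 * t)      # stand-in for a wake word clip
noise = rng.normal(0.0, 0.05, size=t.shape)     # stand-in for background audio

mixed, scaled_noise = mix_at_snr(signal, noise, snr_db=-5)
measured_snr_db = 10 * np.log10(np.mean(signal ** 2) / np.mean(scaled_noise ** 2))
print(round(measured_snr_db, 3))
```

A negative SNR, such as the -5 dB lower bound above, means the background is louder than the wake word, so those samples are deliberately hard.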


In [None]:
# Augment a random clip and play it back to verify it works well

from IPython.display import Audio
from microwakeword.audio.audio_utils import save_clip

random_clip = clips.get_random_clip()
augmented_clip = augmenter.augment_clip(random_clip)
save_clip(augmented_clip, 'augmented_clip.wav')

Audio("augmented_clip.wav", autoplay=True)

In [None]:
# Augment samples and save the training, validation, and testing sets.
# Validating and testing on samples generated the same way as the training
# samples can make the model benchmark better than it performs in real-world
# use. Use real recordings or samples from a different TTS engine to
# potentially get more accurate benchmarks.

import os
from mmap_ninja.ragged import RaggedMmap

output_dir = 'generated_augmented_features'

if not os.path.exists(output_dir):
    os.mkdir(output_dir)

splits = ["training", "validation", "testing"]
for split in splits:
  out_dir = os.path.join(output_dir, split)
  if not os.path.exists(out_dir):
      os.mkdir(out_dir)


  split_name = "train"
  repetition = 2

  spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=10,    # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.
                                     step_ms=10,
                                     )
  if split == "validation":
    split_name = "validation"
    repetition = 1
  elif split == "testing":
    split_name = "test"
    repetition = 1
    spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=1,    # The testing set uses the streaming version of the model, so no artificial repetition is necessary
                                     step_ms=10,
                                     )

  RaggedMmap.from_generator(
      out_dir=os.path.join(out_dir, 'wakeword_mmap'),
      sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),
      batch_size=100,
      verbose=True,
  )

In [None]:
# Downloads pre-generated spectrogram features (made for microWakeWord in
# particular) for various negative datasets. This can be slow!

output_dir = './negative_datasets'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
    filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']
    for fname in filenames:
        link = link_root + fname

        zip_path = f"negative_datasets/{fname}"
        !curl -L -o {zip_path} {link}
        !unzip -q {zip_path} -d {output_dir}

In [None]:
# Save a yaml config that controls the training process
# These hyperparameters can make a huge difference in model quality.
# Experiment with sampling and penalty weights and increasing the number of
# training steps.

import yaml
import os

config = {}

config["window_step_ms"] = 10

config["train_dir"] = (
    "trained_models/wakeword"
)


# Each feature_dir should have at least one of the following folders with this structure:
#  training/
#    ragged_mmap_folders_ending_in_mmap
#  testing/
#    ragged_mmap_folders_ending_in_mmap
#  testing_ambient/
#    ragged_mmap_folders_ending_in_mmap
#  validation/
#    ragged_mmap_folders_ending_in_mmap
#  validation_ambient/
#    ragged_mmap_folders_ending_in_mmap
#
#  sampling_weight: Weight for choosing a spectrogram from this set in the batch
#  penalty_weight: Penalizing weight for incorrect predictions from this set
#  truth: Boolean whether this set has positive samples or negative samples
#  truncation_strategy: If spectrograms in the set are longer than necessary for training, how they are truncated
#       - random: choose a random portion of the entire spectrogram - useful for long negative samples
#       - truncate_start: remove the start of the spectrogram
#       - truncate_end: remove the end of the spectrogram
#       - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. Only for ambient sets

config["features"] = [
    {
        "features_dir": "generated_augmented_features",
        "sampling_weight": 2.0,
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/speech",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/no_speech",
        "sampling_weight": 5.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    { # Only used for validation and testing
        "features_dir": "negative_datasets/dinner_party_eval",
        "sampling_weight": 0.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
]

# Number of training steps in each iteration - various other settings are configured as lists that correspond to these steps
config["training_steps"] = [10000]

# Penalizing weight for incorrect class predictions - lists that correspond to training steps
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [20]

config["learning_rates"] = [
    0.001,
]  # Learning rates for Adam optimizer - list that corresponds to training steps
config["batch_size"] = 128

config["time_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["time_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps
config["freq_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["freq_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps

config["eval_step_interval"] = (
    500  # Number of training steps between evaluations on the validation sets
)
config["clip_duration_ms"] = (
    1500  # Maximum length of wake word that the streaming model will accept
)

# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization
# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize
# Available metrics:
#   - "loss" - cross entropy error on validation set
#   - "accuracy" - accuracy of validation set
#   - "recall" - recall of validation set
#   - "precision" - precision of validation set
#   - "false_positive_rate" - false positive rate of validation set
#   - "false_negative_rate" - false negative rate of validation set
#   - "ambient_false_positives" - count of false positives from the split validation_ambient set
#   - "ambient_false_positives_per_hour" - estimated number of false positives per hour on the split validation_ambient set
config["target_minimization"] = 0.9
config["minimization_metric"] = None  # Set to None to disable

config["maximization_metric"] = "average_viable_recall"

with open("training_parameters.yaml", "w") as file:
    yaml.dump(config, file)
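
Several settings in this cell are lists that correspond to the entries of `training_steps`. One way to read that, sketched below, is that each entry defines a phase and the matching list position supplies that phase's value. The two-phase schedule is hypothetical (this notebook uses a single phase), and `phase_index` is an illustrative helper rather than part of microWakeWord.

```python
from bisect import bisect_left
from itertools import accumulate

def phase_index(step, training_steps):
    """Return the index of the training phase that contains a 1-based step."""
    boundaries = list(accumulate(training_steps))  # cumulative last step of each phase
    if not 1 <= step <= boundaries[-1]:
        raise ValueError(f"step {step} is outside 1..{boundaries[-1]}")
    return bisect_left(boundaries, step)

# Hypothetical two-phase schedule: 10000 steps, then 5000 more at a lower rate
training_steps = [10000, 5000]
learning_rates = [0.001, 0.0001]

for step in (1, 10000, 10001, 15000):
    i = phase_index(step, training_steps)
    print(f"step {step}: phase {i}, learning rate {learning_rates[i]}")
```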

In [None]:
# Trains a model. When finished, it will quantize and convert the model to a
# streaming version suitable for on-device detection.
# It will resume if stopped, but it will start over at the configured training
# steps in the yaml file.
# Set --train 0 to only convert and test the best-weighted model.
# On Google Colab, it doesn't print the mini-batch results, so it may appear
# stuck for several minutes! Additionally, it is very slow compared to training
# on a local GPU.

!python -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 1 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 0 \
--test_tflite_nonstreaming 0 \
--test_tflite_nonstreaming_quantized 0 \
--test_tflite_streaming 0 \
--test_tflite_streaming_quantized 1 \
--use_weights "best_weights" \
mixednet \
--pointwise_filters "64,64,64,64" \
--repeat_in_block  "1, 1, 1, 1" \
--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \
--residual_connection "0,0,0,0" \
--first_conv_filters 32 \
--first_conv_kernel_size 5 \
--stride 1
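
Once training finishes, the converted models should end up somewhere under the configured `train_dir` (`trained_models/wakeword` in this notebook). The exact subfolder layout is not spelled out here, so this sketch simply searches recursively for `.tflite` files; `find_tflite_models` is a hypothetical helper.

```python
from pathlib import Path

def find_tflite_models(train_dir):
    """Return all .tflite files below train_dir, or an empty list if it is missing."""
    root = Path(train_dir)
    if not root.is_dir():
        return []
    return sorted(root.rglob("*.tflite"))

models = find_tflite_models("trained_models/wakeword")
if not models:
    print("No .tflite files found yet; has training finished?")
for path in models:
    print(f"{path} ({path.stat().st_size} bytes)")
```

The quantized streaming model is the one intended for on-device use, so that is the file to pair with the manifest.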