# STT model training

## Imports + model creation

In [None]:
import pandas as pd
import tensorflow as tf

from models.stt import DeepSpeech, Jasper, TransformerSTT
from datasets import get_dataset, train_test_split, prepare_dataset, test_dataset_time
from utils import plot_spectrogram
from utils.text import get_symbols
from utils.audio import display_audio, load_audio, load_mel

# Environment report: TensorFlow build and the GPUs TensorFlow can see.
gpus = tf.config.list_physical_devices('GPU')

# Name under which the model is built / trained in the cells below.
model_name = "siwis_deep_speech"


print(f"Tensorflow version : {tf.__version__}")
print(f"Available GPU's ({len(gpus)}) : {gpus}")

In [None]:
# Text pre-processing: French cleaners and the French symbol set used as
# the model vocabulary.
# NOTE(review): the exact meaning of `maj=False` / `ponctuation=2` is defined
# by utils.text.get_symbols — confirm there.
cleaners = ['french_cleaners']

vocab = get_symbols(
    'fr', maj=False, ponctuation=2
)
print(vocab)

In [None]:
# Build a French DeepSpeech model via the `build_pretrained_deep_speech`
# factory (presumably starting from pretrained weights — confirm in
# models.stt), using the vocabulary and cleaners from the previous cell.
model = DeepSpeech.build_pretrained_deep_speech(
    nom=model_name,
    lang='fr',
    vocab=vocab,
    cleaners=cleaners,
)
print(model)

In [None]:
# Alternative architecture: a Transformer-based STT model.
# The hyper-parameters below are forwarded to the constructor; building the
# dict without passing it would silently ignore every value in it.
# NOTE(review): assumes TransformerSTT accepts these keys as keyword
# hyper-parameters — confirm against models.stt.
config = {
    'embedding_dim': 512,
    'encoder_num_layers': 4,
    'encoder_mha_num_heads': 2,
    'decoder_num_layers': 1,
    'decoder_mha_num_heads': 2,
    'encoder_enc_mha_num_heads': 2,
}
model = TransformerSTT(lang='fr', nom=model_name, **config)

In [None]:
# Inspect the text encoder attached to the current model.
encoder = model.text_encoder
print(encoder)

## Model initialization

In [None]:
# Instantiate the DeepSpeech model registered under `model_name`
# (presumably restoring the model built/saved with that name — confirm in
# models.stt) and compile it with the Adam optimizer.
model = DeepSpeech(nom=model_name)

# Learning-rate choice: a constant 1e-3 by default. Set `use_lr_scheduler`
# to True to hand the warmup-scheduler config to the optimizer instead.
use_lr_scheduler = False

warmup_lr_config = {
    'name': 'WarmupScheduler',
    'maxval': 1e-3,
    'minval': 1e-4,
    'factor': 256,
    # NOTE(review): 275 looks like the number of steps per epoch
    # (i.e. 10 warmup epochs) — confirm for the current dataset/batch size.
    'warmup_steps': 275 * 10,
}
lr_config = warmup_lr_config if use_lr_scheduler else 1e-3

model.compile(
    optimizer='adam',
    optimizer_config={
        'lr': lr_config,
    },
)

print(model)

In [None]:
# Load the 'siwis' dataset. `train` / `valid` stay None unless a pre-split
# pair is assigned here; the training cell splits `dataset` itself when
# either of them is None.
dataset_name = 'siwis'
dataset = get_dataset(dataset_name)

train = None
valid = None

print(f"Dataset length : {len(dataset)}")

## Training

In [None]:
""" Classic hyperparameters """
epochs = 25
batch_size = 32
valid_batch_size = batch_size * 2
train_prop = 0.9
train_size = int(len(dataset) * train_prop)
# Validation split is capped at 250 validation batches.
valid_size = min(len(dataset) - train_size, 250 * valid_batch_size)

shuffle_size = 1024
# NOTE(review): negative value, intended to trigger a prediction after every
# epoch — confirm how model.train interprets pred_step < 0.
pred_step = -10
augment_prct = 0.1

""" Custom training hparams """
# Audio trimming / noise-reduction options forwarded to model.train
# (trim_audio and reduce_noise are both disabled here).
trim_audio = False
reduce_noise = False
trim_threshold = 0.075
max_silence = 0.25
trim_method = 'window'
trim_mode = 'start_end'

# Mel-spectrogram trimming options forwarded to model.train (disabled here).
trim_mel = False
trim_factor = 0.6
trim_mel_method = 'max_start_end'

""" Training """

# Normalise dataset usage: a pre-split train/valid pair is used as-is,
# otherwise `dataset` is split here, so nothing else in the training
# configuration has to change.
if train is None or valid is None:
    train, valid = train_test_split(
        dataset,
        train_size=train_size,
        valid_size=valid_size,
        shuffle=True,
    )

print(f"Training samples   : {len(train)} - {len(train) // batch_size} batches")
print(f"Validation samples : {len(valid)} - {len(valid) // valid_batch_size} batches")

model.train(
    train,
    validation_data=valid,
    # Core loop settings
    epochs=epochs,
    batch_size=batch_size,
    valid_batch_size=valid_batch_size,
    pred_step=pred_step,
    shuffle_size=shuffle_size,
    augment_prct=augment_prct,
    # Audio-level processing
    trim_audio=trim_audio,
    reduce_noise=reduce_noise,
    trim_threshold=trim_threshold,
    max_silence=max_silence,
    trim_method=trim_method,
    trim_mode=trim_mode,
    # Mel-level processing
    trim_mel=trim_mel,
    trim_factor=trim_factor,
    trim_mel_method=trim_mel_method,
)

In [None]:
# Plot the model's training history.
show_history = model.plot_history
show_history()

## Dataset analysis

In [None]:
# Build the non-validation dataset pipeline (shuffle_size=0) in debug mode,
# then benchmark it with test_dataset_time (presumably measures how fast
# batches are produced — confirm in datasets).
config = model.get_dataset_config(
    batch_size=32,
    is_validation=False,
    shuffle_size=0,
)
ds = prepare_dataset(dataset, **config, debug=True)

test_dataset_time(ds)

## Tests

In [None]:
# Sanity check: shape of the mel-spectrogram computed from an all-zero 1-D
# signal of `model.audio_rate * 10` samples (10 seconds if audio_rate is
# samples per second).
a = tf.zeros((model.audio_rate * 10,))
mel = model.mel_fn(a)
print(f"Mel shape for 10 audio sec : {tf.shape(mel)}")