# Train & Align NN Acoustic Model
Repeatedly re-align the phone label sequences while training the phone model.
To avoid proliferation of the more frequent phones (mostly silence), we reduce the b() probabilities of frequent phones during re-alignment.

In [1]:
%run ../prongen/hmm_pron.py --in-jupyter
%run ../acmodel/plot.py
%matplotlib ipympl
%run ../acmodel/matrix.py
%run ../acmodel/praat_ifc.py
%run ../acmodel/hmm_acmodel.py

# Prefer the GPU whenever PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# device = "cpu"  # uncomment to force CPU execution
print(f"Using {device} device")

%run ../acmodel/nn_acmodel.py

hmm_pron.py library - generate Czech pron HMM. Included to this notebook.


Using cuda device


## Get training data
We previously aligned the Czech CommonVoice train set using an ultra-primitive HMM/GMM and then NNs. Let's continue from that alignment.

In [2]:
# Alternative input (alignment saved by an earlier run):
# infile = "mega4_training_0021.tsv"
infile = "nn_train.tsv"

# Tab-separated utterance table; keep_default_na=False keeps strings such as
# "NA" as literal text instead of converting them to NaN.
df = pd.read_csv(infile, keep_default_na=False, sep="\t")

# One HMM per training utterance, built from the same file.
hmms = get_training_hmms(infile, derivatives=3)

# Frequency-based b() corrections computed from the same file.
b_log_corr = b_log_corrections(infile)

In [3]:

# CONDITIONAL: presumably an optional step -- run this cell only when the
# tripled-state HMMs are wanted (TODO confirm with the author).

# Call triple_hmm_states() on every training HMM. The return value is ignored,
# so the helper presumably modifies each hmm in place (verify in hmm_acmodel.py).
for hmm in hmms:
    triple_hmm_states(hmm)

In [4]:
# Pool the features, the current target labels and the label inventory
# from all training HMMs.
all_mfcc, all_targets, b_set = collect_training_material(hmms)

# Network dimensions: input width is taken from dimension 1 of the first HMM's
# mfcc, output width is the number of distinct labels in b_set.
in_size = hmms[0].mfcc.size(1)
out_size = len(b_set)

# Display the label inventory together with both sizes.
(" ".join(b_set), out_size, in_size)

('? A E G H N O Z a b c d e f g h j k l m n o p r s t u v y z | á é ó ú ý č ď ň Ř ř š ť Ž ž',
 45,
 52)

## Setup PyTorch training tools

In [5]:
# Default-sized variant, kept for reference:
# model = NeuralNetwork(in_size, out_size).to(device)

# Build the network with 20 as the third constructor argument (50 was an
# earlier choice), then move it to the selected device.
model = NeuralNetwork(in_size, out_size, 20)
model = model.to(device)
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=52, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=20, bias=True)
    (3): ReLU()
    (4): Linear(in_features=20, out_features=20, bias=True)
    (5): ReLU()
    (6): Linear(in_features=20, out_features=45, bias=True)
  )
)


In [6]:
from torch import optim

# Cross-entropy loss for the training targets.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over all parameters of the model.
optimizer = optim.SGD(
    model.parameters(),
    momentum=0.9,
    lr=0.001,
)

In [None]:
# --- Tunables of the train / re-align loop ---------------------------------
N_MEGA_EPOCHS = 100         # number of "train, then re-align" rounds
EPOCHS_PER_ROUND = 5        # training epochs between two re-alignments
BATCH_SIZE = 64             # DataLoader batch size
B_CORR_SCALE = 1.0          # multiplier applied to the frequency-based b() corrections
ALIGN_REPORT_EVERY = 1000   # print alignment progress every N utterances

training_data = SpeechDataset(all_mfcc, all_targets, b_set)  # initial alignment

for mega_epoch in range(N_MEGA_EPOCHS):
    print(f"============ Training Mega Epoch {mega_epoch} =============")

    # The re-alignment step below switches the model to eval mode; switch back
    # explicitly so each training round starts in training mode
    # (train_n_epochs is not visible here, so do not rely on it doing this).
    model.train()

    all_targets = "".join([hmm.targets for hmm in hmms])  # collect current alignments
    # Only the targets are swapped on the existing dataset object.
    # NOTE(review): assumes SpeechDataset reads all_targets lazily -- confirm.
    training_data.all_targets = all_targets

    # New dataloader for this alignment.
    train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)

    train_n_epochs(train_dataloader, optimizer, model, criterion, EPOCHS_PER_ROUND)

    print('Interrupted training for re-alignment...')

    model.eval()  # switch to evaluation mode

    # Re-alignment is pure inference: disable autograd so no graph is built or
    # retained for the per-utterance model outputs.
    with torch.no_grad():
        for idx, hmm in enumerate(hmms):
            if idx % ALIGN_REPORT_EVERY == 0:
                print(f"Align {idx}")

            # b() corrections according to the current label frequencies.
            alp = viterbi_log_align_nn(hmm, model, b_set, b_log_corr=b_log_corr * B_CORR_SCALE)
            hmm.intervals = backward_log_alignment_pass_intervals(hmm, alp)  # also modifies alp
            hmm.indices = i = alp.max(1).indices
            # Build the aligned label string once and reuse it for both the
            # sanity check and the new training targets.
            s = "".join([hmm.b[ii] for ii in i])
            hmm.troubling = troubling_alignmet(s)
            hmm.targets = s

    df['targets'] = [hmm.targets for hmm in hmms]

    filename_base = f"minih_training_{mega_epoch:04d}"

    # Checkpoint the model together with the alignment it has just produced.
    torch.save(model.state_dict(), filename_base + ".pth")
    df.to_csv(filename_base + ".tsv", sep="\t", index=False)

    # Get new b() corrections based on the label frequencies of the saved alignment.
    b_log_corr = b_log_corrections(filename_base + ".tsv")


[1, 20000] loss: 1.516
[1, 40000] loss: 1.434
[1, 60000] loss: 1.421
[2, 20000] loss: 1.410
[2, 40000] loss: 1.406
[2, 60000] loss: 1.404
[3, 20000] loss: 1.403
[3, 40000] loss: 1.403
[3, 60000] loss: 1.401
[4, 20000] loss: 1.399
[4, 40000] loss: 1.399
[4, 60000] loss: 1.399
[5, 20000] loss: 1.394
[5, 40000] loss: 1.395
[5, 60000] loss: 1.393
Interrupted training for re-alignment...
Align 0
Align 1000
Align 2000
Align 3000
Align 4000
Align 5000
Align 6000
Align 7000
Align 8000
Align 9000
Align 10000
[1, 20000] loss: 1.400
[1, 40000] loss: 1.398
[1, 60000] loss: 1.393
[2, 20000] loss: 1.394
[2, 40000] loss: 1.394
[2, 60000] loss: 1.395
[3, 20000] loss: 1.392
[3, 40000] loss: 1.393
[3, 60000] loss: 1.394
[4, 20000] loss: 1.391
[4, 40000] loss: 1.390
[4, 60000] loss: 1.395
[5, 20000] loss: 1.393
[5, 40000] loss: 1.391
[5, 60000] loss: 1.392
Interrupted training for re-alignment...
Align 0
Align 1000
Align 2000
Align 3000
Align 4000
Align 5000
Align 6000
Align 7000
Align 8000
Align 9000
Al