# Train & Align NN AM with Win-to-MFCC and simple speaker adaptation
Repeatedly re-align the phone label sequences while training the phone model.
To avoid proliferation of the more frequent phones (mostly silence), we diminish the b() probabilities of frequent phones during re-alignment. We use 3 states per phone.

This version of training adds simple speaker adaptation by making the average cepstra over the recording visible as NN inputs (a very simple i-vector-like approach). We split the MFCC frames into 4 groups according to log-energy (first into above-average and below-average, then each group is further split in two the same way).

(Future versions might compute MFCC averages for key frequent phones.)

# Config

In [None]:
# Experiment configuration for this notebook run.
filename_base_base = "sv100mulX_training"  # prefix of the checkpoint / alignment files written during training
infile = "sv200c-100_training_0024.tsv"  # alignment TSV this run starts from
mid_size = 100  # passed to NeuralNetwork as its third argument
sideview = 9  # how many additional MFCC frames before and after the focus point are seen

In [None]:
# Bootstrap the notebook: execute the project helper scripts with IPython's
# `%run` magic so that what they define becomes available in this notebook.
# Comments are kept on their own lines because anything after a `%run` file
# name is passed to the script as command-line arguments.

# `--in-jupyter` is handed to hmm_pron.py as a command-line argument.
%run ../prongen/hmm_pron.py --in-jupyter
%run ../acmodel/plot.py
# Select the ipympl matplotlib backend for figures.
%matplotlib ipympl
%run ../acmodel/matrix.py
%run ../acmodel/praat_ifc.py
%run ../acmodel/hmm_acmodel.py

# Prefer the GPU when PyTorch reports one, otherwise fall back to the CPU.
# NOTE(review): `torch` is not imported in this cell; presumably one of the
# scripts above brings it into scope - verify.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Uncomment the next line to force CPU execution.
#device = "cpu"
print(f"Using {device} device")

# Run last, after `device` has been chosen.
# NOTE(review): confirm whether nn_acmodel.py actually depends on this ordering.
%run ../acmodel/nn_acmodel.py

## Get training data
We previously aligned the Czech CommonVoice train set using an ultra-primitive HMM/GMM and then NNs. Let's continue from there.

In [None]:
# Other input files that have been used here (assign to `infile` to switch):
#   "mega4_training_0021.tsv"
#   'nn_train.tsv'
#   'sv200con_training_0004.tsv'
#   'sv200c-300_training_0027.tsv'

# Raw table of the input TSV; with keep_default_na=False and no na_values,
# no strings are parsed as NaN.
df = pd.read_csv(
    infile,
    sep="\t",
    keep_default_na=False,
)

# HMMs for the training utterances, built from the same file
# (derivatives=0 is passed through to get_training_hmms).
hmms = get_training_hmms(infile, derivatives=0)

# b() corrections based on frequency, computed from the same file.
b_log_corr = b_log_corrections(infile)

In [None]:
# Upgrade to 3 states per phone (just for duration, b() is still shared)
for hmm in hmms:
    #triple_hmm_states(hmm)
    multiply_hmm_states(hmm)

In [None]:
# Gather the training material from all HMMs: the MFCC data, the target
# string and the set of b() symbols.
all_mfcc, all_targets, b_set = collect_training_material(hmms)

# Network sizes: one input per MFCC column of the first HMM, one output per
# symbol in b_set. (in_size is recomputed later for the windowed input.)
in_size = hmms[0].mfcc.size(1)
out_size = len(b_set)

# Cell output: the symbol inventory and both sizes.
(" ".join(b_set), out_size, in_size)

## Add speaker vectors (mean cepstra in 4 energy bands)

In [None]:
# Compute one s-vector per HMM, store it on the HMM, and build a flat list
# with one entry per element of hmm.mfcc (presumably one per frame). All
# entries of a recording point at the same vector moved to `device`.
all_speaker_vectors_refs = []
for hmm in hmms:
    hmm.speaker_vector = mfcc_make_speaker_vector(hmm.mfcc)
    ref = hmm.speaker_vector.to(device)
    # Repeat the reference, not a copy, once per entry of this recording.
    all_speaker_vectors_refs.extend([ref] * len(hmm.mfcc))

## Changes for the Window-to-MFCC input

In [None]:
# Input width of the network: a window of `sideview` frames on each side of
# the focus frame, plus the appended s-vector.
frame_width = hmms[0].mfcc.size(1)
window_frames = sideview + 1 + sideview
svec_width = 4 * 13  # added s-vector
in_size = frame_width * window_frames + svec_width
in_size

In [None]:
# Alignment decoding needs the windowed input, so replace the mfcc of every
# hmm with its window view (training already has its own copy).
# NOTE: the speaker vectors must be made BEFORE this cell runs!
for hmm in hmms:
    with_sideview = mfcc_add_sideview(hmm.mfcc, sideview=sideview)
    hmm.mfcc = mfcc_win_view(with_sideview, sideview=sideview)

## Setup training

In [None]:
# Create the network and move it to the selected device.
# Original author's note on this line: 50 20 100=svec 50=sv50
model = NeuralNetwork(in_size, out_size, mid_size)
model = model.to(device)
print(model)

In [None]:
import torch.optim as optim

# Loss function and optimiser: cross-entropy with SGD plus momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

# Dataset over the collected material, using the initial alignment as targets.
training_data = SpeechDataset(
    all_mfcc,
    all_targets,
    b_set,
    sideview=sideview,
    speaker_vectors=all_speaker_vectors_refs,
)

In [None]:
# Alternate between training the network and re-aligning the phone labels.
# Each mega-epoch:
#   1. collects the current alignments from the HMMs as training targets,
#   2. trains on them (train_n_epochs is called with 5 as its last argument),
#   3. saves the model weights to "<filename_base>.pth",
#   4. re-aligns every HMM with the updated model,
#   5. writes the new alignments to "<filename_base>.tsv" and recomputes the
#      frequency-based b() corrections from that file.
for mega_epoch in range(30):
    print(f"======= Train {filename_base_base}, Epoch {mega_epoch} ========")

    # The re-alignment step below puts the model into eval mode; switch back
    # to training mode so that later mega-epochs do not train in eval mode.
    model.train()

    all_targets = "".join([hmm.targets for hmm in hmms])  # collect alignments
    training_data.all_targets = all_targets  # just update the object with new targets

    # New dataloader for this alignment.
    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

    train_n_epochs(train_dataloader, optimizer, model, criterion, 5)

    # Zero-padded mega-epoch number keeps the output files sorted by name.
    filename_base = f"{filename_base_base}_{mega_epoch:04d}"
    torch.save(model.state_dict(), filename_base + ".pth")

    print('Interrupted training for re-alignment...')

    model.eval()  # switch to evaluation mode

    # Alignment only runs the model forward, so gradient tracking is not needed.
    with torch.no_grad():
        for idx, hmm in enumerate(hmms):
            if idx % 1000 == 0:
                print(f"Align {idx}")
            # group_tripled=False is used here; an earlier variant passed True.
            # The return value is kept in `alp` for interactive inspection; the
            # new alignment is read from hmm.targets below (presumably updated
            # in place by align_hmm - verify).
            alp = align_hmm(hmm, model, b_set, b_log_corr=b_log_corr*1.0, group_tripled=False)

    df['targets'] = [hmm.targets for hmm in hmms]

    df.to_csv(filename_base + ".tsv", sep="\t", index=False)

    # Get new b() corrections based on frequency in the fresh alignment.
    b_log_corr = b_log_corrections(filename_base + ".tsv")