# Train & Align NN AM (acoustic model) using additional manually time-aligned data
This notebook is like [`NN_Train_Align.ipynb`](NN_Train_Align.ipynb) but uses additional
manually time-labeled data. You should already have:
* `initial_train_cs.tsv` made by [`Prepare_Training_Data.ipynb`](Prepare_Training_Data.ipynb)
* `manual_train.tsv` made by [`Prep_Manual_Train_Data.ipynb`](Prep_Manual_Train_Data.ipynb)

In the first iteration, the model is trained only on the `manual_train.tsv` data and is then used
to time-align the data from `initial_train_cs.tsv`.
Subsequent iterations refine the model on both datasets. The part coming from `initial_train_cs.tsv` is re-aligned after every iteration, while the `manual_train.tsv` part keeps its hand-made alignment. This anchors the placement of phone boundaries to the human-made examples.

# Config

In [None]:
# --- input data ---
infile_cv = 'initial_train_cs.tsv'  # no phone targets in this tsv yet
infile_man = 'manual_train.tsv'  # manually labeled phones and their time boundaries

# --- model / output settings ---
sideview = 9  # extra MFCC frames seen on EACH side of the focus frame
mid_size = 100  # third NeuralNetwork argument (presumably the hidden-layer size; TODO confirm)
filename_base_base = "man_both_training"  # prefix of the per-epoch .pth / .tsv outputs

In [None]:
import sys
if sys.path[0] != '..':
    sys.path[0:0] = ['..'] # prepend main Prak directory
from prongen.hmm_pron import *
from acmodel.praat_ifc import *
from acmodel.nn_acmodel import *
import acmodel

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# device = "cpu"  # uncomment to force CPU
# Publish the choice as a module attribute of acmodel.nn_acmodel
# (presumably read by its helpers when moving tensors; TODO confirm).
acmodel.nn_acmodel.device = device
print(f"Using {device} device")

## Get training data

In [None]:
import pandas as pd

In [None]:
# Tab-separated CV table; keep_default_na=False keeps strings such as "NA"
# literally instead of turning them into NaN.
df_cv = pd.read_csv(infile_cv, keep_default_na=False, sep="\t")
# HMMs for the same tsv (derivatives=0: presumably no delta features; TODO confirm).
hmms_cv = get_training_hmms(infile_cv, derivatives=0)

In [None]:
# Tab-separated manually time-aligned table; keep_default_na=False keeps strings
# such as "NA" literally instead of turning them into NaN.
df_man = pd.read_csv(infile_man, keep_default_na=False, sep="\t")
# HMMs for the same tsv, built with the same settings as the CV part.
hmms_man = get_training_hmms(infile_man, derivatives=0)

In [None]:
# Inspect the CV data frame (the notebook displays the cell's final expression).
df_cv

In [None]:
# Inspect the manual data frame (the notebook displays the cell's final expression).
df_man

In [None]:
# Order matters: the manual HMMs go first, so the CV HMMs form the tail of the
# combined training material. The dataset's `ignored_end` (set further below to the
# total length of the CV targets) presumably skips that tail in the first mega-epoch.
hmms = hmms_man+hmms_cv # hmms_man should be the first

In [None]:
# Expand every phone to 3 HMM states. This only affects duration modelling;
# the emission b() is still shared by the three states of a phone.
for phone_hmm in hmms:
    triple_hmm_states(phone_hmm)

In [None]:
# create pro-forma targets, will not be used in this variant of training
for hmm in hmms_cv:
    create_start_targets(hmm)

In [None]:
# Initial b() corrections are computed from the manual tsv only, because the CV data
# has no real targets yet. The training loop recomputes them from each epoch's saved tsv.
b_log_corr = b_log_corrections(infile_man, b_set=b_set) # get b() corrections based on frequency

In [None]:
# Number of b() classes; this becomes the network's output size (out_size) below.
len(b_set)

In [None]:
in_size = hmms[0].mfcc.size(1)  # plain per-frame MFCC dimension (recomputed later for windowed input)
out_size = len(b_set)  # one network output per b() class

(" ".join(b_set), out_size, in_size)

In [None]:
# DO THIS NOW, mfccs are modified below!
# (we use tricky way to access mfcc in training, different from the inference time)
# DO THIS NOW, mfccs are modified below!
# (we use a tricky way to access mfcc in training, different from the inference time)
# Snapshot the per-frame MFCCs and targets of all HMMs before hmm.mfcc is replaced
# by windowed views further down. SpeechDataset is later built from this copy.
all_mfcc, all_targets = collect_training_material(hmms)

## Add speaker vectors (mean cepstra in 4 energy bands)

In [None]:
# Make s-vectors
all_speaker_vectors_refs = []
for hmm in hmms:
    hmm.speaker_vector = mfcc_make_speaker_vector(hmm.mfcc)
    ref = hmm.speaker_vector.to(device)
    all_speaker_vectors_refs += [ref]*len(hmm.mfcc)

## Changes for the windowed-MFCC input

In [None]:
n_mfcc = hmms[0].mfcc.size(1)  # per-frame MFCC dimension (hmm.mfcc is not windowed yet)
window_len = 2 * sideview + 1  # focus frame plus `sideview` frames on each side
# Append the s-vector: 4 energy bands (presumably 13 cepstra each; TODO confirm).
in_size = n_mfcc * window_len + 4 * 13
in_size

In [None]:
# for alignment decoding, change mfcc in all hmms (for training, we already have a copy)
# NOTE: Make speaker vectors BEFORE this!
# Replace each HMM's MFCC matrix with its sliding-window view, which alignment
# decoding consumes. Training already has its own copy (all_mfcc).
# NOTE: speaker vectors must be made BEFORE this, from the raw MFCCs.
for utt in hmms:
    extended = mfcc_add_sideview(utt.mfcc, sideview=sideview)
    utt.mfcc = mfcc_win_view(extended, sideview=sideview)

## Setup training

In [None]:
# Acoustic model: windowed MFCCs + s-vector in (in_size), one output per b() class (out_size).
# NOTE(review): the former trailing note "50 20 100=svec 50=sv50" presumably
# records mid_size values tried in earlier experiments.
model = NeuralNetwork(in_size, out_size, mid_size)
model = model.to(device)
print(model)

In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Dataset over the raw-MFCC snapshot (all_mfcc) with the initial targets.
# It receives the window half-width and one s-vector reference per frame.
training_data = SpeechDataset(all_mfcc, all_targets, b_set, sideview=sideview, speaker_vectors=all_speaker_vectors_refs) # initial alignment
# First train without cv: hmms = hmms_man + hmms_cv, so the CV targets form the tail.
# ignored_end is set to their total length (presumably the number of trailing
# frames to skip; TODO confirm in SpeechDataset). Summing the lengths avoids
# building one large concatenated string just to measure it.
training_data.ignored_end = sum(len(hmm.targets) for hmm in hmms_cv)

In [None]:
for mega_epoch in range(1, 50): # starting from 1 as we have zero tsv
    print(f"======= Train {filename_base_base}, Epoch {mega_epoch} ========")
    print(f"{len(training_data)=}")

    # Collect the current alignment of all HMMs: the manual part is fixed, and the CV
    # part comes from the previous re-alignment (pro-forma in the first mega-epoch).
    all_targets = "".join([hmm.targets for hmm in hmms])
    training_data.all_targets = all_targets  # just update the object with new targets

    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) # new dataloader for this alignment

    # The re-alignment step below puts the model into eval mode. Switch back explicitly
    # so every mega-epoch trains in training mode (relevant for dropout/batch-norm
    # layers, if any), regardless of whether train_n_epochs does it itself.
    model.train()
    train_n_epochs(train_dataloader, optimizer, model, criterion, 20 if mega_epoch == 1 else 5)

    filename_base = f"{filename_base_base}_{mega_epoch:04d}"
    torch.save(model.state_dict(), filename_base + ".pth")

    training_data.ignored_end = 0 # Use both man+cv data for all the following training
    print('Interrupted training for re-alignment...')

    model.eval() # switch to evaluation mode
    # Re-align only the CV part; the manual part keeps its hand-made alignment.
    # Decoding needs no gradients, so don't build autograd graphs.
    with torch.no_grad():
        for idx, hmm in enumerate(hmms_cv):
            if idx % 1000 == 0:
                print(f"Align {idx}")
            align_hmm(hmm, model, b_set, b_log_corr=b_log_corr, group_tripled=True)

    # Save this epoch's alignment: manual rows first, then the re-aligned CV rows
    # (align_hmm presumably updates hmm.targets in place; TODO confirm).
    df_cv['targets'] = [hmm.targets for hmm in hmms_cv]
    df = pd.concat([df_man, df_cv])
    df.to_csv(filename_base + ".tsv", sep="\t", index=False)

    # Refresh the b() frequency corrections from the newly saved alignment.
    b_log_corr = b_log_corrections(filename_base + ".tsv", b_set=b_set) # get new b() corrections based on frequency