# Train & Align NN Acoustic Model with Window-to-MFCC input
Repeatedly re-align the phone label sequence while training the phone model.
To keep the most frequent phones (above all, silence) from spreading over the alignment, we scale down the b() (emission) probabilities of frequent phones during re-alignment.

In [None]:
# Load the project scripts listed below into the notebook namespace.
# NOTE(review): `torch` is not imported in this cell; presumably one of these
# scripts brings it into scope -- TODO confirm.
%run ../prongen/hmm_pron.py --in-jupyter
%run ../acmodel/plot.py
# Interactive matplotlib backend (needs the ipympl extension installed).
%matplotlib ipympl
%run ../acmodel/matrix.py
%run ../acmodel/praat_ifc.py
%run ../acmodel/hmm_acmodel.py

# Use the GPU when one is available; uncomment the next line to force the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
#device = "cpu"
print(f"Using {device} device")

# Loaded only after `device` is defined (presumably because the NN code reads
# this global -- TODO confirm).
%run ../acmodel/nn_acmodel.py

## Get training data
We previously aligned the Czech CommonVoice training set, first with an ultra-primitive HMM/GMM and then with NNs. Here we continue from that alignment.

In [None]:
# Alignment TSV to start from (an older alternative is kept commented out).
#infile = "mega4_training_0021.tsv"
infile = 'nn_train.tsv'
# keep_default_na=False keeps strings such as "NA" as literal text instead of NaN.
# The table is kept so that new alignments can be written back into it later.
df = pd.read_csv(infile, sep="\t", keep_default_na=False)
# derivatives=0: presumably no delta features are added -- TODO confirm.
hmms = get_training_hmms(infile, derivatives=0)
b_log_corr = b_log_corrections(infile) # get b() corrections based on frequency

In [None]:

# CONDITIONAL: run this cell only for the variant that applies
# triple_hmm_states() to every utterance HMM (presumably several states per
# phone -- TODO confirm); skip it otherwise. It modifies the HMMs in place.

for hmm in hmms:
    triple_hmm_states(hmm)

In [None]:
# Gather features, per-frame targets and the set of distinct output labels
# from all utterance HMMs (presumably concatenated frame-wise -- TODO confirm).
all_mfcc, all_targets, b_set = collect_training_material(hmms)

out_size = len(b_set)  # one network output per label in b_set
in_size = hmms[0].mfcc.size(1)  # per-frame feature width; recomputed below for the window input

# Notebook display: label inventory and network dimensions.
" ".join(b_set), out_size, in_size

## Changes for the Window-to-MFCC input

In [None]:
sideview = 9 # how many additional MFCC frames before and after the focus point are seen
# Network input width = per-frame feature width times the 2*sideview+1 frames in the window.
# NOTE(review): this must run before the next cell replaces hmm.mfcc with a
# windowed view; otherwise size(1) is already the window width.
in_size = hmms[0].mfcc.size(1) * (sideview+1+sideview)
in_size

In [None]:
# for alignment decoding, change mfcc in all hmms (for training, we already have a copy)
# NOTE(review): each run reassigns hmm.mfcc from its current value, so re-running
# this cell presumably windows already-windowed features -- run it once only.

for hmm in hmms:
    hmm.mfcc = mfcc_win_view(mfcc_add_sideview(hmm.mfcc))



In [None]:
class SpeechDataset(Dataset):
    """Frame-level dataset: a window of MFCC frames plus a one-hot label target.

    Item ``idx`` is centred on frame ``idx + sideview`` and spans ``sideview``
    frames on each side. The first and last ``sideview`` frames are therefore
    never used as centre frames.

    Args:
        all_mfcc: frame-indexed features, sliced along the first axis
            (presumably a (frames, n_mfcc) tensor -- TODO confirm).
        all_targets: one label per frame (a string in this notebook). It may be
            replaced later through the ``all_targets`` attribute.
        b_set: ordered collection of all labels; the order fixes each label's
            one-hot position.
        sidewiew: number of context frames before and after the centre frame.
            The misspelt name is kept so existing callers keep working.

    NOTE(review): if all_mfcc/all_targets are concatenations of several
    utterances, windows near utterance boundaries span two utterances.
    """

    def __init__(self, all_mfcc, all_targets, b_set, sidewiew=9):
        self.all_mfcc = all_mfcc
        self.all_targets = all_targets
        # Honour the argument; previously it was ignored in favour of the
        # notebook-global `sideview`.
        self.sideview = sidewiew

        # Row i of the identity matrix is the one-hot target of the i-th label.
        self.wanted_outputs = torch.eye(len(b_set), device=device).double()
        self.output_map = {b: self.wanted_outputs[i] for i, b in enumerate(b_set)}

    def __len__(self):
        # Only frames with a full window on both sides are items; never negative.
        return max(0, len(self.all_targets) - 2 * self.sideview)

    def __getitem__(self, idx):
        centre = idx + self.sideview
        window = self.all_mfcc[centre - self.sideview:centre + self.sideview + 1]
        return window, self.output_map[self.all_targets[centre]]


## Set up PyTorch training tools

In [None]:
# Acoustic model mapping in_size inputs (the MFCC window) to out_size label scores.
# The third argument is presumably a hidden-layer width; 50 and 20 were
# alternatives tried (see trailing comment) -- TODO confirm.
#model = NeuralNetwork(in_size, out_size).to(device)
model = NeuralNetwork(in_size, out_size, 100).to(device) # 50 20
print(model)

In [None]:
import torch.optim as optim

# Cross-entropy over the label classes. SpeechDataset supplies one-hot float rows,
# which CrossEntropyLoss treats as class-probability targets.
criterion = nn.CrossEntropyLoss()
# Plain SGD with momentum; learning rate 1e-3 == 0.001.
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

In [None]:
# The dataset is built once from the initial alignment; each mega epoch only swaps
# in the latest targets. Its window half-width must match the `sideview` used to
# compute `in_size`, so pass that variable rather than a literal.
training_data = SpeechDataset(all_mfcc, all_targets, b_set, sidewiew=sideview) # initial alignment

# Scale of the frequency-based b() log-corrections used in re-alignment
# (1.0 reproduces the earlier behaviour).
b_log_corr_weight = 1.0

for mega_epoch in range(100):
    print(f"============ Training Mega Epoch {mega_epoch} =============")

    # Concatenate the current per-utterance frame labels into one target string.
    all_targets = "".join(hmm.targets for hmm in hmms)  # collect alignments
    training_data.all_targets = all_targets  # just update the object with new targets

    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) # new dataloader for this alignment

    # The previous mega epoch left the model in eval mode for alignment; switch
    # back so training-only layers (dropout, batch norm, ...) behave correctly.
    model.train()
    train_n_epochs(train_dataloader, optimizer, model, criterion, 5)

    filename_base = f"wali_training_{mega_epoch:04d}"
    torch.save(model.state_dict(), filename_base + ".pth")

    print('Interrupted training for re-alignment...')

    model.eval() # switch to evaluation mode

    # Alignment only runs inference, so autograd bookkeeping is unnecessary.
    with torch.no_grad():
        for idx, hmm in enumerate(hmms):
            if idx % 1000 == 0:
                print(f"Align {idx}")

            # b() corrections according to current frame frequency
            alp = viterbi_log_align_nn(hmm, model, b_set, b_log_corr=b_log_corr * b_log_corr_weight)
            hmm.intervals = backward_log_alignment_pass_intervals(hmm, alp) # also modifies alp
            # Best-scoring state per frame, mapped to its label character.
            hmm.indices = i = alp.max(1).indices
            s = "".join([hmm.b[ii] for ii in i])
            hmm.troubling = troubling_alignmet(s)
            hmm.targets = s

    # Persist the new alignment next to the model checkpoint.
    df['targets'] = [hmm.targets for hmm in hmms]
    df.to_csv(filename_base + ".tsv", sep="\t", index=False)

    b_log_corr = b_log_corrections(filename_base + ".tsv") # get new b() corrections based on frequency
