# Train Grouped Phones discriminator - 5 iters before re-align
Starting with the latest NN phone alignment, train a phone-group discriminator. Groups are merged like this:
- short/long is ignored for vowels (mostly to help ó)
- voiced/voiceless is ignored (this groups rare voiced variants of č and c with the frequent voiceless ones)

In [None]:
%run ../prongen/hmm_pron.py --in-jupyter
%run ../acmodel/plot.py
%matplotlib ipympl
###%run ../acmodel/matrix.py
%run ../acmodel/praat_ifc.py
%run ../acmodel/hmm_acmodel.py

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# (Assign "cpu" by hand here to force CPU execution.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

%run ../acmodel/nn_acmodel.py

In [None]:
# Show which device was selected above
device

## Get training data
We previously aligned the Czech CommonVoice train set using an ultra-primitive HMM/GMM and then NNs. Let's replace every non-silent phone label by either 'c' or 'v' (so we have 3 labels: "cv|").

In [None]:
# Tab-separated input table (presumably the alignment written by an earlier run)
infile = "mega6_training_0028.tsv"
# keep_default_na=False: do not convert strings such as "NA" or empty fields to NaN
df = pd.read_csv(infile, sep="\t", keep_default_na=False)
# One sentence hmm per utterance, built from the same file
hmms = get_training_hmms(infile, derivatives=3)
# NOTE: b_log_corr is reassigned later in the notebook from the grouped targets
b_log_corr = b_log_corrections(infile) # get b() corrections based on frequency

In [None]:
# Gather the training material from all sentence hmms:
# features, per-frame target labels and the label set (b_set)
all_mfcc, all_targets, b_set = collect_training_material(hmms)

# out_size is not taken from b_set; it is set later from the grouped label set
#out_size = len(b_set)
# Network input width = size of dimension 1 of the first sentence's mfcc tensor
in_size = hmms[0].mfcc.size(1)

# Display the original label set and the input width
" ".join(b_set), in_size

### Map phone labels to group representants

In [None]:
# Phone groups, separated by spaces. All phones of one group are collapsed
# into a single label; the first character of a group is its representative.
groups = 'aá eé yý oó uú pb td ťď kg HhG cZ čŽ sz šž fv'

In [None]:
# Lookup table: phone -> label used for training.
# Every phone of the original label set starts out mapped to itself ...
lab = {phone: phone for phone in b_set}
# ... and members of a group are then redirected to the group's first phone.
for grp in groups.split():
    lab.update(dict.fromkeys(grp, grp[0]))
#lab

In [None]:
# Peek at the first 200 original (ungrouped) frame labels
all_targets[:200]

In [None]:
# Translate every frame label to its group representative
x_targets = "".join(map(lab.__getitem__, all_targets))

In [None]:
# Peek at the first 200 grouped frame labels (compare with all_targets[:200])
x_targets[:200]

In [None]:
# Sorted list of the distinct grouped labels. The position of a label in
# this list is used as its NN output column during alignment.
x_set = sorted(set(x_targets))
" ".join(x_set)

In [None]:
# Number of network outputs = number of distinct grouped labels
out_size = len(x_set)

## Setup PyTorch training tools

In [None]:
# Display the network input and output sizes
in_size, out_size

In [None]:
# Create the network and move it to the selected device.
# NeuralNetwork is defined outside this notebook; presumably its arguments
# are (input width, number of output classes) - TODO confirm.
model = NeuralNetwork(in_size, out_size).to(device)
print(model)

In [None]:
import torch.optim as optim

# Loss: cross-entropy between the network outputs and the target labels
criterion = nn.CrossEntropyLoss()
# Optimizer: SGD with momentum over all model parameters
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
def b_log_corrections_from_targets(targets):
    """
    Derive additive log(b()) corrections from label frequencies.

    `targets` is the string of all target labels. Every distinct label gets
    the value -log(count), which suppresses very frequent phones and boosts
    rare ones. The result is a 1-D tensor ordered by sorted label.
    """
    frequencies = Counter(targets)
    counts_by_label = [frequencies[label] for label in sorted(frequencies)]
    return -torch.log(torch.tensor(counts_by_label))

In [None]:
# Replace the corrections loaded from the input file with ones computed
# from the grouped labels, so there is one value per label in x_targets.
b_log_corr = b_log_corrections_from_targets(x_targets) # get b() corrections based on frequency of new targets

In [None]:
# Display the frequency corrections for the grouped labels
b_log_corr

In [None]:
# Rewrite the labels stored in each sentence hmm (frame targets and
# state labels) to group representatives, for the first iteration.
for hmm in hmms:
    for attr in ("targets", "b"):
        setattr(hmm, attr, "".join(lab[ph] for ph in getattr(hmm, attr)))

In [None]:
def compute_hmm_nn_log_b(hmm, nn_model, full_b_set, b_log_corr=None):
    """
    For a sentence hmm model with an attached mfcc, compute ln(b()) values
    for every sound frame and every model state, using NN phone model.

    The network is applied to hmm.mfcc (converted to double and moved to the
    global `device`), its outputs are turned into log-probabilities along
    dimension 1 and optionally shifted by `b_log_corr`. Finally one column
    is selected for every entry of hmm.b, so a label that occurs several
    times in hmm.b has its column repeated.

    Args:
        hmm: sentence model; hmm.mfcc and hmm.b are read.
        nn_model: callable network producing one output column per label.
        full_b_set: ordered labels; the position of a label is taken as its
            output column in the network.
        b_log_corr: optional tensor added to the log-probabilities of every
            frame (presumably one value per output column - TODO confirm).

    Returns:
        CPU tensor with one row per network output row and one column per
        entry of hmm.b.

    Raises:
        KeyError: if hmm.b contains a label missing from full_b_set.
    """
    # Inference only: do not record an autograd graph for the forward pass
    with torch.no_grad():
        logits = nn_model(hmm.mfcc.double().to(device)).detach().to('cpu')

    pred_probab = nn.LogSoftmax(dim=1)(logits)
    # Identity test: comparing a tensor with None via != is fragile
    if b_log_corr is not None:
        pred_probab += b_log_corr[None]  # [None] adds a leading axis for broadcasting

    # Now select b() columns as needed for this hmm
    ph_to_i = {ph: i for i, ph in enumerate(full_b_set)}  # map phone to column

    idx = torch.tensor([ph_to_i[ph] for ph in hmm.b])
    return pred_probab[:, idx]  # repeat each b() column as needed



def viterbi_log_align_nn(hmm, nn_model, full_b_set, timrev=False, b_log_corr=None):
    """
    Forward alignment pass of hmm states against mfcc frames, in log domain.

    The per-frame ln(b()) values come from compute_hmm_nn_log_b(); with
    `timrev` set they are reversed along the frame axis first. Returns a
    tensor stacking one row per frame; each row holds the renormalised state
    scores as they were before that frame's b() values were added.

    Side effect: stores the edge representation of hmm.A in
    hmm.optimized_edges.
    """
    log_b = compute_hmm_nn_log_b(hmm, nn_model, full_b_set, b_log_corr)
    if timrev:
        log_b = log_b.flip(0)

    n_frames = hmm.mfcc.size()[0]
    n_states = len(hmm.A)

    # Initial scores: 0 for the first state, -inf for all the others
    scores = torch.tensor([0] + [float('-inf')] * (n_states - 1))

    # Efficient representation of the transition matrix
    edges_from, edges_to = matrix_extra_edges(torch.tensor(hmm.A))
    edge_pair = (edges_from, edges_to)
    # Kept for the backward pass (TODO: do this elsewhere)
    hmm.optimized_edges = edge_pair

    rows = []  # one row of alpha logprobs per frame
    for t in range(n_frames):
        scores -= scores.max()  # renormalize so the best score is 0
        rows.append(scores.clone())
        # next_x's return value is unused, so it presumably updates scores in place
        next_x(scores, edge_pair)
        scores += log_b[t]
    return torch.stack(rows)

In [None]:
# Alternate NN training with re-alignment:
# 100 rounds of (5 training epochs -> re-align all sentences -> save).
training_data = SpeechDataset(all_mfcc, x_targets, x_set) # initial alignment


for mega_epoch in range(100):
    print(f"============ Training Group Epoch {mega_epoch} =============")

    all_targets = "".join([hmm.targets for hmm in hmms])  # collect alignments
    training_data.all_targets = all_targets  # just update the object with new targets

    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) # new dataloader for this alignment

    # The re-alignment below switches the model to eval mode; switch back
    # explicitly so every round trains in training mode.
    model.train()
    train_n_epochs(train_dataloader, optimizer, model, criterion, 5)

    print('Interupted training for re-alignment...')

    model.eval() # switch to evaluation mode


    for idx, hmm in enumerate(hmms):
        if idx % 1000 == 0:
            print(f"Align {idx}")

        alp = viterbi_log_align_nn(hmm, model, x_set, b_log_corr=b_log_corr*1.0) # b() corrections according to current frame frequency
        hmm.intervals = backward_log_alignment_pass_intervals(hmm, alp) # also modifies alp
        # Per frame, the index of the state with the highest value in alp
        best_states = alp.max(1).indices
        hmm.indices = best_states
        # Build the aligned label string once and reuse it
        aligned = "".join([hmm.b[ii] for ii in best_states])
        hmm.troubling = troubling_alignmet(aligned)
        hmm.targets = aligned


    df['targets'] = [hmm.targets for hmm in hmms]

    filename_base = f"group5i_training_{mega_epoch:04d}"

    torch.save(model.state_dict(), filename_base + ".pth")
    df.to_csv(filename_base + ".tsv", sep="\t", index=False)

    # NOTE(review): the initial corrections came from
    # b_log_corrections_from_targets(); confirm that b_log_corrections()
    # yields one value per entry of x_set, in the same order.
    b_log_corr = b_log_corrections(filename_base + ".tsv") # get new b() corrections based on frequency


In [None]:
# Debugging aid: redo the forward alignment pass for the sentence hmm that
# the name `hmm` is still bound to after the preceding loop
alp = viterbi_log_align_nn(hmm, model, x_set, b_log_corr=b_log_corr*1.0)

In [None]:
# Display the current frequency corrections
b_log_corr

In [None]:
# Display the device in use
device

In [None]:
# Inspect the mfcc tensor attached to the first sentence hmm
hmms[0].mfcc