# Align using NN Acoustic Model
https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html

In [None]:
%run ../prongen/hmm_pron.py --in-jupyter
%run ../acmodel/plot.py
%matplotlib ipympl

In [None]:
%run ../acmodel/matrix.py
%run ../acmodel/praat_ifc.py

%run ../acmodel/hmm_acmodel.py

# Device on which tensors and the model are placed in this notebook.
# GPU auto-selection is currently disabled; to re-enable it use:
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f"Using {device} device")

%run ../acmodel/nn_acmodel.py

In [None]:
from IPython.core.magic import register_line_cell_magic

@register_line_cell_magic
def writetemplate(line, cell=None):
    """Write the cell body to a file after filling in ``{name}`` placeholders.

    Usage::

        %%writetemplate path/to/output
        text with {some_global} placeholders

    Args:
        line: Text following the magic name; used verbatim as the output path.
        cell: Cell body, used as a ``str.format`` template. Placeholders are
            resolved against ``globals()``; literal braces must be doubled
            (``{{`` / ``}}``).

    Raises:
        ValueError: If invoked in line mode (``%writetemplate``), where there
            is no cell body to write.
        KeyError: If the template names a variable that is not defined.
    """
    if cell is None:
        # A line/cell magic is called with only `line` when used in line
        # mode; fail with a clear message instead of an opaque TypeError.
        raise ValueError(
            "writetemplate must be used as a cell magic: %%writetemplate <path>"
        )
    # Explicit encoding so non-ASCII template text is written identically
    # regardless of the platform's default locale.
    with open(line, 'w', encoding='utf-8') as f:
        f.write(cell.format(**globals()))

## Get training data
We aligned the Czech CommonVoice train set using an ultra-primitive HMM/GMM. Let's use it as a starting point.

In [None]:
import pandas as pd
# Do not truncate long cell contents when a DataFrame is displayed.
pd.options.display.max_colwidth = None

# Tab-separated training table. Default NaN detection is switched off so that
# no field (e.g. an empty string or the text "NA") is turned into NaN.
df = pd.read_table("mega_training_0005.tsv", keep_default_na=False)

# Alternative input file:
#df = pd.read_table("mini.tsv", keep_default_na=False)

In [None]:
# Build one HMM object per table row and attach the row's stored targets to it.
hmms = []
for wav, sentence, targets in zip(df["wav"].values, df["sentence"].values, df["targets"].values):
    # Alternative constructor call without the `derivatives` argument:
    #hmm = HMM(sentence, wav=wav)
    hmm = HMM(sentence, wav=wav, derivatives=3)
    hmm.targets = targets
    hmms.append(hmm)

In [None]:
# Sorted list of every distinct character that occurs in any hmm.b
# (the set of all phone names in the training set).
b_set = sorted(set("".join(hmm.b for hmm in hmms)))
out_size = len(b_set)
# Second dimension of the first utterance's MFCC tensor.
in_size = hmms[0].mfcc.size(1)
# Last expression of the cell, so the notebook shows it as the cell output.
" ".join(b_set), out_size, in_size

In [None]:
# Concatenate the target strings of all utterances into one string; its length
# is later checked against the number of stacked MFCC rows.
all_targets = "".join(hmm.targets for hmm in hmms)
train_len = len(all_targets)

In [None]:
# Stack the MFCC tensors of all utterances along dim 0, as float64 on `device`.
# (.to() returns the converted tensor, so the result has to be assigned.)
all_mfcc = torch.cat([hmm.mfcc for hmm in hmms]).to(device=device, dtype=torch.float64)
# The number of stacked rows must match the number of target symbols.
assert all_mfcc.shape[0] == train_len

In [None]:
# Bundle the stacked features, the concatenated target string and the symbol
# inventory into a dataset object.
# NOTE(review): SpeechDataset is not defined in the visible cells — presumably
# it comes from one of the %run scripts above; confirm.
training_data = SpeechDataset(all_mfcc, all_targets, b_set)

In [None]:
# Network dimensions: second dimension of the MFCC tensor and the number of
# symbols in b_set (the same expressions as in the b_set cell).
in_size, out_size = hmms[0].mfcc.shape[1], len(b_set)

In [None]:
# Instantiate the network with the computed input/output sizes, then move it
# to the selected device and show its structure.
model = NeuralNetwork(in_size, out_size)
model = model.to(device)
print(model)

In [None]:
# Iterate over the training dataset in shuffled batches of 64 items.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

In [None]:
# Restore the trained weights from disk. map_location remaps the stored tensors
# onto `device`, so a checkpoint that was saved from a GPU still loads when the
# notebook runs on the CPU.
model.load_state_dict(torch.load('mega_training_0005.pth', map_location=device))
# Put the model into evaluation mode. Kept as the cell's last expression, so
# the notebook also displays whatever eval() returns.
model.eval()

In [None]:
#df['targets'] = [hmm.targets for hmm in hmms]

#df.to_csv("nn_train_g2.tsv", sep="\t", index=False)
#df.to_csv("nn_train_g3.tsv", sep="\t", index=False)
#df.to_csv("nn_train_g4.tsv", sep="\t", index=False)
#df.to_csv("nn_train_g5.tsv", sep="\t", index=False)

# Choose wav

In [None]:
#%%time
# Align the selected utterance(s) with the NN acoustic model and store the
# results (intervals, per-frame indices, targets, trouble flag) on each HMM.
for idx, hmm in enumerate(hmms[32:33]):   # 3177 is quite problematic   4177 nice   4178 forw/backw slight diverg.
    # Progress output; idx counts from 0 within the slice, not within hmms.
    if idx%100==0:
        print(idx)

    alp = viterbi_log_align_nn(hmm, model, b_set)
    # Must run before the maxima are read below: this call also modifies alp.
    hmm.intervals = backward_log_alignment_pass_intervals(hmm, alp)
    # Index of the maximum along dim 1 for every row of alp.
    hmm.indices = i = alp.max(1).indices
    # Decode the index sequence into symbols once and reuse the string for
    # both the trouble check and the new targets.
    s = "".join([hmm.b[ii] for ii in i])
    hmm.troubling = troubling_alignmet(s)
    hmm.targets = s

#CPU times: user 39min 35s, sys: 14.8 s, total: 39min 50s
#Wall time: 4min 58s

In [None]:
# Display the targets of `hmm`, i.e. the last HMM bound by a preceding cell.
hmm.targets

In [None]:
# Display the size of this utterance's MFCC tensor.
hmm.mfcc.size()

In [None]:
# Display alp as left by the most recently executed alignment cell.
alp

In [None]:
# Colormap in use; others that were tried:
# hsv, viridis, twilight, brg, gist_rainbow, gist_ncar
cmap = 'nipy_spectral'
# Values below -400 are raised to -400 before plotting.
plot_matrix(alp.clamp(min=-400), cmap=cmap)

In [None]:
# NOTE(review): STOP is not defined in the visible cells, so this presumably
# raises NameError on purpose to halt a "Run All" at this point — confirm.
STOP

In [None]:
# Plot the wav file attached to the current HMM.
# NOTE(review): plot_wavfile presumably comes from ../acmodel/plot.py — confirm.
plot_wavfile(hmm.wav)

In [None]:
# Print the string representation of the current HMM.
print(hmm)

In [None]:
# NOTE(review): presumably creates hmm.timrev, which the next cell reads — confirm.
hmm.add_timrev()

In [None]:
    # Run the same alignment steps on hmm.timrev, passing timrev=True to the
    # aligner, and store the results on hmm.timrev.
    alp = viterbi_log_align_nn(hmm.timrev, model, b_set, timrev=True)
    # Must run before the maxima are read below: this call also modifies alp.
    hmm.timrev.intervals = backward_log_alignment_pass_intervals(hmm.timrev, alp)
    hmm.timrev.indices = i = alp.max(1).indices
    # Decode the index sequence into symbols once and reuse the string for
    # both the trouble check and the targets.
    s = "".join([hmm.timrev.b[ii] for ii in i])
    hmm.timrev.troubling = troubling_alignmet(s)
    hmm.timrev.targets = s
    # Reverse alp along both dimensions.
    # NOTE(review): presumably so the plot below has the same orientation as
    # the forward alignment — confirm.
    alp = alp.flip(0).flip(1)


In [None]:
# Colormap in use; others that were tried:
# hsv, viridis, twilight, brg, gist_rainbow, gist_ncar
cmap = 'nipy_spectral'
# Plot alp with everything below -400 raised to -400.
plot_matrix(alp.clamp(min=-400), cmap=cmap)