# Align Milanek using NN Acoustic Model
PyTorch tutorial on building a neural network: https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html

In [None]:
%run ../prongen/hmm_pron.py --in-jupyter
#%run ../acmodel/plot.py
#%matplotlib ipympl

In [None]:
# Load the acoustic-model helper scripts. Later cells rely on names these
# scripts define (e.g. HMM, NeuralNetwork, SpeechDataset, viterbi_log_align_nn);
# which script defines which is not visible here.
# NOTE(review): plain %run executes each script in a fresh namespace (only -i
# shares the notebook's) and copies its top-level names into the notebook
# afterwards, so a script that defines e.g. `device` would overwrite it.
%run ../acmodel/matrix.py
%run ../acmodel/praat_ifc.py

%run ../acmodel/hmm_acmodel.py

# Inference is pinned to the CPU; the commented line selects CUDA when available.
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f"Using {device} device")

%run ../acmodel/nn_acmodel.py

In [None]:
from IPython.core.magic import register_line_cell_magic

@register_line_cell_magic
def writetemplate(line, cell):
    """
    Cell magic ``%%writetemplate PATH``: write the cell body to PATH.

    Placeholders such as ``{tft}`` in the body are filled in with
    ``str.format`` using this namespace's globals (the notebook's, since the
    function is defined in a notebook cell).

    The file is written as UTF-8, so non-ASCII label text (e.g. Czech phone
    names such as 'á' or 'ť') is stored the same way on every platform,
    independent of the locale's default encoding.

    Raises:
        ValueError: if used as a line magic (no cell body) or without a path.
    """
    if cell is None:
        # register_line_cell_magic also allows `%writetemplate`, which passes
        # no body; fail with a clear message instead of `None.format`.
        raise ValueError("writetemplate must be used as a cell magic: %%writetemplate PATH")
    path = line.strip()
    if not path:
        raise ValueError("writetemplate needs a target file name: %%writetemplate PATH")
    with open(path, 'w', encoding='utf-8') as f:
        f.write(cell.format(**globals()))

### Load the training data to compute the layer sizes and `b_set`
TODO: replace this by storing all inference details (sizes, label set) together with the saved model.

In [None]:
import pandas as pd
# Do not truncate long cell contents (e.g. per-frame `targets` strings) when displaying frames.
pd.options.display.max_colwidth = None
# keep_default_na=False keeps empty fields as "" instead of turning them into NaN.
df = pd.read_csv("mega3_training_0003.tsv", keep_default_na=False, sep="\t")

#df = pd.read_csv("mini.tsv", sep="\t", keep_default_na=False)

In [None]:
# Build one HMM per training utterance and attach its reference frame labels.
# zip is consumed lazily; wrapping it in list() only built a throwaway copy.
hmms = []
for wav, sentence, targets in zip(df.wav.values, df.sentence.values, df.targets.values):
    #hmm = HMM(sentence, wav=wav)
    hmm = HMM(sentence, wav=wav, derivatives=3)
    # Presumably one label character per MFCC frame (see the length check on all_mfcc below).
    hmm.targets = targets
    hmms.append(hmm)

In [None]:
# Sorted inventory of every phone (state) label that occurs in the training HMMs.
b_set = sorted(set("".join(h.b for h in hmms)))
out_size = len(b_set)
# Feature dimension: columns of the (frames x features) MFCC matrix.
in_size = hmms[0].mfcc.shape[1]
" ".join(b_set), out_size, in_size

In [None]:
# One long label string covering every training frame, utterances in order.
all_targets = "".join(h.targets for h in hmms)
train_len = len(all_targets)

In [None]:
# Stack all utterances' frames along dim 0 as float64 on `device`.
frame_blocks = [h.mfcc for h in hmms]
all_mfcc = torch.cat(frame_blocks).double().to(device)
# Exactly one target label per frame is required.
assert all_mfcc.shape[0] == train_len

In [None]:
# Wrap frames and labels as a dataset.
# NOTE(review): b_set presumably defines the label -> class-index mapping; confirm in SpeechDataset.
training_data = SpeechDataset(all_mfcc, all_targets, b_set)

In [None]:
# NOTE(review): none of the cells visible here use training_data or
# train_dataloader; they look like leftovers from the training notebook.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

# Load NN Acoustic Model

In [None]:
# get just b_set (from the TSV targets alone, without building the HMMs)
# A single str.join replaces the `t += targets` loop, which copies the
# growing string on every step and is quadratic in the worst case.
t = "".join(df.targets.values)

b_set = sorted({*t}) # make sorted set of all phone names in the training set
#out_size = len(b_set)
" ".join(b_set)

In [None]:
# Layer sizes derived from the data: input = MFCC feature columns,
# output = number of labels in b_set.
in_size = hmms[0].mfcc.size(1)
out_size = len(b_set)

In [None]:
# Alternative to the cell above: hard-coded sizes. They presumably have to
# match the checkpoint loaded below, because load_state_dict rejects
# parameters whose shapes differ.
in_size = 52
out_size = 45


In [None]:
# Label set passed with the model to the aligner. It is replaced by the
# grouped label set if the "SKIP THIS" grouping cell below is run.
x_set = b_set

In [None]:
# SKIP THIS
# (optional: presumably only for a model trained on grouped labels)

# Phone groups for label merging: whitespace-separated groups; group_labels
# maps every phone in a group to the group's first character.
groups = 'aá eé yý oó uú pb td ťď kg HhG cZ čŽ sz šž fv'


def group_labels(groups, labels):
    """
    Simplify label set by replacing each label in a string by the
    first member of its group. Input in labels can be any iterable
    (likely list or string), output is a string.
    """
    lab = {}
    for phone in labels:
        lab[phone] = phone # default to be overwriten below
    for grp in groups.split():
        for p in grp:
            lab[p] = grp[0] # first phone in group represents it
        
    
    return "".join(lab[p] for p in labels)


# Reduced label inventory after grouping; the model output size follows it.
x_set = sorted(set(group_labels(groups, b_set)))
out_size = len(x_set) # 29

In [None]:
# Build the network with the sizes chosen above and move it to `device`.
# NOTE(review): the third argument (50) is presumably the hidden-layer width;
# it must agree with the checkpoint loaded below — confirm in nn_acmodel.py.
model = NeuralNetwork(in_size, out_size, 50).to(device)
print(model)

# Load NN model weights (choose the checkpoint below)

In [None]:
# Checkpoint to load; the commented names are alternatives kept for quick switching.
#filename_base = 'mega6_training_0024'
#filename_base = 'group5i_training_0021'
#filename_base = 'group_training_0042'
#filename_base = 'tri_training_0011'
#filename_base = 'tri_training_0022'

filename_base = 'minig_training_0000'

# map_location remaps the stored tensors onto `device`, so the load also
# works on a CPU-only machine when the weights were saved from CUDA.
model.load_state_dict(torch.load(filename_base+".pth", map_location=device))
# NOTE(review): presumably per-label log-score corrections stored next to the
# weights; confirm in b_log_corrections.
b_log_corr = b_log_corrections(filename_base+".tsv")

# Inference mode (switches layers such as dropout to evaluation behaviour).
model.eval()

# Select one wav to align

In [None]:
# Transcript (no punctuation) of the recording selected as `wav` in the next
# cell; newlines and surrounding spaces are part of the literal.
milanek = """
 Milánek 
 maminka se zeptala Milana 
 Milánku už máš hotový úkol 
 kdy ho budeš psát 
 Milan chvilku přemýšlel a pak odpověděl 
 já musím napsat pár souvětí na Říhovou kde budou nějaké gramatické fígle 
 například vztažné věty 
 čárky před á a podobně 
 vlastně jsem tě chtěl poprosit jestli mi s tím nepomůžeš 
 můžeme se na to mrknout klidně hned řekla maminka 
 jen bych dala vařit vodu na čaj a podívám se jestli máme citrony 
 jak dlouho nám ten úkol zabere 
 bude to těžké 
 no mají tam být i různé příklady na zastaralou a knižní slovní zásobu 
 skoro půlku jsem už ve škole udělal ale moc dobře mi to nešlo 
 chtěl bych začít co nejdřív 
 až budu hotov došel bych ti do lékárny pro ten pneumocyt 
 a potom půjdu hrát fotbal 
 včera jsem dal čtyři góly 
 nebýt láďových faulů mohlo jich být víc 
 jen míč budu muset přifouknout 
 neboj se dám pozor na auta
"""

In [None]:
# Text and recording to align. The commented pair is an alternative that uses
# only the first two sentences, with a separate wav file.
#txt = "Milánku, už máš hotový úkol? Kdy ho budeš psát?"
#wav = "milanek/mil.wav"
txt = milanek
wav = "milanek/AH-milanek.wav"

In [None]:
# Load the recording and display its original sample rate.
waveform, sample_rate = torchaudio.load(wav)
sample_rate

In [None]:
# Resample to 16 kHz. new_freq is spelled out (it equals torchaudio's default)
# because the save cell below writes the file with the same hard-coded 16000.
transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
resampled_waveform = transform(waveform)
waveform.size(), resampled_waveform.size()

In [None]:
# Write the 16 kHz audio (the Resample target rate) to a file, because HMM
# takes a wav path.
torchaudio.save("resampled.wav", resampled_waveform, 16000)

In [None]:
# HMM for the transcript, with derivatives=3 as for the training HMMs above.
hmm = HMM(txt, wav='resampled.wav', derivatives=3)

In [None]:
# SKIP THIS
# (optional: merge state labels into groups, presumably only for a grouped-label model)

hmm.b = group_labels(groups, hmm.b)

In [None]:
# Inspect the HMM's state label string.
hmm.b

In [None]:
# CONDITIONAL:
# NOTE(review): presumably only for models trained on tripled states; if run,
# group_tripled_intervals must be applied to the resulting intervals below.

triple_hmm_states(hmm)

In [None]:
%%time
alp = viterbi_log_align_nn(hmm, model, x_set, b_log_corr=b_log_corr*1.0)
hmm.intervals = backward_log_alignment_pass_intervals(hmm, alp) # also modifies alp
hmm.indices = i = alp.max(1).indices
s = "".join([hmm.b[ii] for ii in i])
hmm.troubling = troubling_alignmet(s)
hmm.targets = "".join([hmm.b[ii] for ii in i])
# Wall time: 5.89 s., for triple states: 6.28 s

In [None]:
def group_tripled_intervals(intervals):
    """
    Fix tripling of decoded intervals caused by triple_hmm_states().

    Each consecutive triple of ``(begin, end, phone)`` intervals is merged into
    one interval from the first interval's begin to the third one's end.

    Args:
        intervals: iterable of ``(begin, end, phone)`` triples. Its length must
            be a multiple of three, and each consecutive triple must share one
            phone. The input is not modified.

    Returns:
        list of merged ``(begin, end, phone)`` tuples.

    Raises:
        ValueError: if the number of intervals is not a multiple of three, or
            if a triple mixes different phones.
    """
    items = list(intervals)
    if len(items) % 3:
        raise ValueError(f"expected a multiple of 3 intervals, got {len(items)}")
    result = []
    # Step through the list three at a time. This is O(n); the former
    # `*rest` unpacking copied the whole remaining list on every step.
    for k in range(0, len(items), 3):
        (beg, _, phone), (_, _, p2), (_, end, p3) = items[k:k + 3]
        if not phone == p2 == p3:
            raise ValueError(
                f"intervals {k}..{k + 2} do not share one phone: "
                f"{phone!r}, {p2!r}, {p3!r}"
            )
        result.append((beg, end, phone))
    return result

In [None]:
# Undo the tripling from triple_hmm_states() (only valid if that cell was run).
hmm.intervals = group_tripled_intervals(hmm.intervals)

In [None]:
# TextGrid text for the intervals ("segmenty" is presumably the tier name);
# the %%writetemplate cell below substitutes it via {tft}.
tft = textgrid_file_text({"segmenty": hmm.intervals})

In [None]:
%%writetemplate milanek/milanek_minig00.TextGrid
{tft}

In [None]:
# NOTE(review): STOP is presumably undefined on purpose, so this raises
# NameError and halts "Run All" before the experiments below.
STOP

# Now load the saved TextGrid into Praat.

# Experiments: align the training data, also with time-reversed Viterbi

In [None]:
#%%time
# Re-align a slice of the training HMMs with the loaded model (experiment).
# NOTE(review): this passes b_set and no b_log_corr, whereas the Milánek cell
# uses x_set and b_log_corr — confirm which label set the loaded model expects.
# Side effects: each hmm.targets loaded from the TSV is overwritten with the
# new alignment, and the global `hmm` ends up as the last HMM of the slice.
for idx, hmm in enumerate(hmms[32:33]):   # 3177 is quite problematic   4177 nice   4178 forw/backw slight diverg.
    if idx%100==0:
        print(idx)  # progress; idx counts from the start of the slice
    
    alp = viterbi_log_align_nn(hmm, model, b_set)
    hmm.intervals = backward_log_alignment_pass_intervals(hmm, alp) # also modifies alp
    hmm.indices = i = alp.max(1).indices
    s = "".join([hmm.b[ii] for ii in i])
    hmm.troubling = troubling_alignmet(s)
    hmm.targets = s  # reuse the label string instead of rebuilding it

#CPU times: user 39min 35s, sys: 14.8 s, total: 39min 50s
#Wall time: 4min 58s

In [None]:
# Inspect the last aligned HMM: frame labels, feature-matrix size, score matrix.
hmm.targets

In [None]:
hmm.mfcc.size()

In [None]:
alp

In [None]:
# Colormap for the score plot; other maps tried: 'hsv', 'viridis', 'twilight',
# 'brg', 'gist_rainbow', 'gist_ncar'.
cmap = 'nipy_spectral'
# Clamp very negative (presumably log-domain) scores so they don't dominate the colour scale.
clamped = alp.clamp(min=-400)
plot_matrix(clamped, cmap=cmap)

In [None]:
STOP  # NOTE(review): presumably undefined on purpose; NameError halts "Run All" here

In [None]:
# Plot the waveform of the current HMM's recording.
plot_wavfile(hmm.wav)

In [None]:
print(hmm)

In [None]:
# Build the time-reversed HMM (used below as hmm.timrev) for reversed Viterbi.
hmm.add_timrev()

In [None]:
    # Align the time-reversed HMM, then flip the score matrix back to forward time.
    # NOTE: the whole cell is indented; IPython strips a cell's common leading
    # indent so it runs in Jupyter, but as plain Python it is an IndentationError.
    alp = viterbi_log_align_nn(hmm.timrev, model, b_set, timrev=True)
    hmm.timrev.intervals = backward_log_alignment_pass_intervals(hmm.timrev, alp) # also modifies alp
    hmm.timrev.indices = i = alp.max(1).indices
    s = "".join([hmm.timrev.b[ii] for ii in i])
    hmm.timrev.troubling = troubling_alignmet(s)
    hmm.timrev.targets = s  # reuse the label string instead of rebuilding it
    # Reverse both axes so the matrix lines up with the forward-time plots.
    alp = alp.flip(0).flip(1)


In [None]:
# Same colormap as the forward plot; alternatives: 'hsv', 'viridis',
# 'twilight', 'brg', 'gist_rainbow', 'gist_ncar'.
cmap = 'nipy_spectral'
# Plot the flipped time-reversed scores with very negative values clamped.
clamped = alp.clamp(min=-400)
plot_matrix(clamped, cmap=cmap)