# Align using NN Acoustic Model
https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html

In [None]:
%run ../prongen/hmm_pron.py --in-jupyter
%run ../acmodel/plot.py
%matplotlib ipympl

In [None]:
%run ../acmodel/matrix.py
%run ../acmodel/praat_ifc.py

%run ../acmodel/hmm_acmodel.py

In [None]:
from IPython.core.magic import register_line_cell_magic

@register_line_cell_magic
def writetemplate(line, cell=None):
    """
    Cell magic: fill in the cell body as a template and write it to a file.

    Usage:
        %%writetemplate path/to/file
        text with {placeholders}

    line -- the rest of the magic line, used as the output file path
    cell -- the cell body, used as a str.format template; every {name}
            placeholder is replaced by the global variable of that name,
            so literal braces in the text have to be doubled

    The magic is registered for line mode too, and in line mode IPython
    passes no cell body.  That used to fail with an opaque TypeError about
    a missing argument; now a ValueError explains the expected usage.
    """
    if cell is None:
        raise ValueError(
            "writetemplate needs a cell body: use %%writetemplate <file>")
    with open(line, 'w') as f:
        f.write(cell.format(**globals()))

In [None]:
# Everything in this notebook runs on the CPU; the CUDA auto-detection is
# kept commented out for reference.
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f"Using {device} device")

## Get training data
We aligned the Czech CommonVoice train set using an ultra-primitive HMM/GMM. Let's use it as a starting point.

In [None]:
import pandas as pd
# Show full cell contents when a DataFrame is displayed (no truncation).
pd.set_option('display.max_colwidth', None)
#df = pd.read_csv("mega_training_0020.tsv", sep="\t", keep_default_na=False)

# Tab-separated table; the cells below read its wav, sentence and targets
# columns.  keep_default_na=False stops pandas from turning empty or
# "NA"-like fields into NaN.
df = pd.read_csv("mini.tsv", sep="\t", keep_default_na=False)



In [None]:
df

In [None]:
# Build one sentence HMM per table row.  This variant attaches no target
# string, for tables that carry no `targets` column.
hmms = []
for wav, sentence in zip(df["wav"].values, df["sentence"].values):
    hmm = HMM(sentence, wav=wav, derivatives=3)
    hmms.append(hmm)

In [None]:
# Build one sentence HMM per table row and attach the existing alignment
# (the row's `targets` string) to it.
hmms = []
for wav, sentence, targets in zip(
        df["wav"].values, df["sentence"].values, df["targets"].values):
    hmm = HMM(sentence, wav=wav, derivatives=3)
    hmm.targets = targets
    hmms.append(hmm)

In [None]:
# Sorted list of every distinct phone symbol (character) occurring in any
# hmm.b of the training set.  Its length is the number of NN outputs; the
# number of NN inputs is the width of one MFCC frame.
b_set = sorted(set("".join(hmm.b for hmm in hmms)))
out_size = len(b_set)
in_size = hmms[0].mfcc.size(1)
" ".join(b_set), out_size, in_size

In [None]:
# Concatenate the per-sentence target strings into one long string; it is
# later indexed one character per training sample.
all_targets = "".join([hmm.targets for hmm in hmms])
train_len = len(all_targets)

In [None]:
# Stack the MFCC rows of all sentences into one float64 tensor on `device`.
# NOTE(review): `torch` is only imported in a later cell; presumably it is
# already in scope through one of the %run scripts -- confirm.
all_mfcc = torch.cat([hmm.mfcc for hmm in hmms]).double().to(device)
#all_mfcc.to(device)
# Sanity check: exactly one target character per MFCC row.
assert all_mfcc.size()[0]==train_len

## Setup PyTorch training tools

In [None]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [None]:
class SpeechDataset(Dataset):
    """
    Frame-level dataset pairing feature rows with one-hot phone targets.

    all_mfcc    -- tensor of features, indexed by sample along dimension 0
    all_targets -- string holding one phone symbol per sample
    b_set       -- ordered phone symbols; position i defines output class i

    Item idx is the tuple (all_mfcc[idx], one-hot float64 vector of the
    phone all_targets[idx]).  A target symbol missing from b_set raises
    KeyError on access.
    """
    def __init__(self, all_mfcc, all_targets, b_set):
        self.all_mfcc = all_mfcc
        self.all_targets = all_targets

        # Row i of the identity matrix is the wanted output for phone i.
        # It is created on the same device as the features, so inputs and
        # labels of a batch always live together instead of depending on a
        # notebook-global `device` variable.
        self.wanted_outputs = torch.eye(
            len(b_set), device=all_mfcc.device).double()
        self.output_map = {
            b: self.wanted_outputs[i] for i, b in enumerate(b_set)
        }

    def __len__(self):
        return len(self.all_targets)

    def __getitem__(self, idx):
        return self.all_mfcc[idx], self.output_map[self.all_targets[idx]]

In [None]:
# Frame-level dataset over the stacked features and targets built above.
training_data = SpeechDataset(all_mfcc, all_targets, b_set)

In [None]:
class NeuralNetwork(nn.Module):
    """
    Fully connected frame classifier: one feature row in, one logit per
    phone class out.  Three hidden ReLU layers; forward() returns raw
    logits, no softmax is applied inside the model.
    """
    def __init__(self, input_size=None, output_size=None, hidden_size=512):
        """
        input_size  -- number of features per frame; when omitted, the
                       notebook global `in_size` is used (original behavior)
        output_size -- number of phone classes; when omitted, the notebook
                       global `out_size` is used (original behavior)
        hidden_size -- width of each hidden layer; the default of 512 keeps
                       the original architecture and parameter names
        """
        super().__init__()
        if input_size is None:
            input_size = in_size
        if output_size is None:
            output_size = out_size

        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
            #nn.LogSoftmax()
        )

    def forward(self, x):
        """Return the logits for a batch x (first dimension = batch)."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [None]:
# Instantiate the network, move its parameters to `device` and print the
# layer structure.
model = NeuralNetwork().to(device)
print(model)

In [None]:
from torch.utils.data import DataLoader

# Serve the training samples in shuffled mini-batches of 64.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

## Load trained model
https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

In [None]:
# Earlier checkpoints, kept for reference:
#model.load_state_dict(torch.load('model_weights_40.pth'))
#model.load_state_dict(torch.load('model_weights_g2_40.pth'))
#model.load_state_dict(torch.load('model_weights_g3_40.pth'))
#model.load_state_dict(torch.load('model_weights_g4_40.pth'))

#model.load_state_dict(torch.load('mega_training_0020.pth'))
#model.load_state_dict(torch.load('mega_training_0010.pth'))

# map_location remaps the stored tensors onto `device`, so a checkpoint
# written on a GPU machine also loads in this CPU-only notebook.
model.load_state_dict(torch.load('mega_training_0025.pth', map_location=device))
# Switch to evaluation mode; the notebook only runs inference from here on.
model.eval()

## Alignment code variant for NN

## Run inference
https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html

In [None]:
def compute_hmm_nn_b(hmm, nn_model, full_b_set):
    """
    For a sentence hmm model with an attached mfcc, compute b() values
    for every sound frame and every model state, using NN phone model.

    hmm         -- sentence model; hmm.mfcc holds one feature row per frame
                   and hmm.b one phone symbol per state
    nn_model    -- network mapping feature rows to one logit per phone
    full_b_set  -- phone symbols in the order of the network outputs

    Returns a (frames x states) tensor of softmax outputs: column k is the
    network output for the phone of state k, so a phone used by several
    states has its column repeated.  The result carries no autograd graph.
    Raises KeyError if hmm.b contains a phone missing from full_b_set.
    """
    # Pure inference: without no_grad every call would record an autograd
    # graph that is never backpropagated through.
    with torch.no_grad():
        logits = nn_model(hmm.mfcc.double())
        pred_probab = nn.Softmax(dim=1)(logits)

    # Now select b() columns as needed for this hmm
    ph_to_i = {ph: i for i, ph in enumerate(full_b_set)}  # map phone to column

    # Explicit dtype/device: keeps the index usable for an empty hmm.b and
    # when the model output does not live on the CPU.
    idx = torch.tensor([ph_to_i[ph] for ph in hmm.b],
                       dtype=torch.long, device=pred_probab.device)
    return pred_probab[:, idx]  # repeat each b() column as needed

In [None]:
def viterbi_align_nn(hmm, nn_model, full_b_set):
    """
    Run the scaled alpha recursion of a sentence hmm over its mfcc frames,
    taking the b() values from the NN phone model (compute_hmm_nn_b).

    hmm         -- sentence model providing hmm.A, hmm.mfcc and hmm.b
    nn_model    -- network passed on to compute_hmm_nn_b
    full_b_set  -- phone symbols in the order of the network outputs

    Returns (alpha_m, exponent): alpha_m holds one renormalized state row
    per frame, exponent is the sum of the logs of the scale factors that
    were divided out.

    NOTE(review): whether the propagation step is a sum (forward) or a max
    (Viterbi) depends on how the project matrix type `m` implements `@`
    -- confirm against its definition.
    """
    b = m(compute_hmm_nn_b(hmm, nn_model, full_b_set))
    A = m(hmm.A)
    tmax = hmm.mfcc.size()[0]
    len_x = A.size()[0]
    # start vector: all weight in the first state
    x_list = [1]+[0]*(len_x-1)
    x_m = m([x_list])
    exponent = 0
    # allocate space for mantissa-like (kept in range) and row-exponent values
    alpha_m = m.rowlist((tmax,len_x))

    for row in range(tmax):
        s = x_m.val.sum() # renormalize
        x_m.val *= 1/s
        exponent += s.log()
        # The stored row is the vector *before* b[row] is applied; b[row]
        # only enters the vector stored at row+1.
        # NOTE(review): as a result b[tmax-1] never reaches a stored row and
        # the last propagation below is computed but unused -- confirm this
        # is intended.
        alpha_m[row] = x_m
        x_m = x_m@A*b[row]
    return alpha_m, exponent

In [None]:
%%time
# Re-align every training sentence with the NN acoustic model and store the
# results on the hmm objects.
for idx, hmm in enumerate(hmms):
    if not idx % 100:
        print(idx)  # progress report every 100 sentences

    alpha_m, alpha_exp = viterbi_align_nn(hmm, model, b_set)
    alp = torch.cat(alpha_m.val)
    hmm.intervals, bap = backward_alignment_pass_intervals(hmm, alp)

    # per frame: index of the state with the largest backward-pass value
    i = bap.max(1).indices
    hmm.indices = i

    # the chosen states spelled out as phone symbols, one per frame
    s = "".join(hmm.b[ii] for ii in i)
    hmm.troubling = troubling_alignmet(s)
    hmm.targets = s

#CPU times: user 39min 35s, sys: 14.8 s, total: 39min 50s
#Wall time: 4min 58s

In [None]:
# Write the new NN-based alignments back into the table, one target string
# per row (same order as hmms).
df['targets'] = [hmm.targets for hmm in hmms]

# Uncomment one line to save the re-aligned table under a new name.
#df.to_csv("nn_train_g2.tsv", sep="\t", index=False)
#df.to_csv("nn_train_g3.tsv", sep="\t", index=False)
#df.to_csv("nn_train_g4.tsv", sep="\t", index=False)
#df.to_csv("nn_train_g5.tsv", sep="\t", index=False)

In [None]:
hmm = hmms[-1]

In [None]:
# b() matrix (frames x states) of the selected sentence, for inspection.
bbb = compute_hmm_nn_b(hmm, model, b_set)

In [None]:
bbb.log()

In [None]:
# Plot the log of the b() values; detach() drops any autograd history so
# that the tensor can be converted to a numpy array.
plot_matrix(bbb.log().detach().numpy())

https://pytorch.org/docs/stable/generated/torch.nonzero.html

In [None]:
hmm = hmms[3]

In [None]:
# hmm.A (presumably the state transition matrix -- confirm) as a torch
# tensor, for the nonzero/masking experiments below.
A = torch.tensor(hmm.A)
A

In [None]:
A.nonzero(as_tuple=True)

In [None]:
A.nonzero()

In [None]:
# Values of the nonzero entries of A, selected through the index tuple.
A[A.nonzero(as_tuple=True)]

In [None]:
print(hmm)

In [None]:
# Column 6 of A.  Assuming A[i, j] belongs to the transition i -> j (TODO
# confirm), its nonzero rows are the states that can lead into state 6.
col = A[:,6]
col

https://discuss.pytorch.org/t/find-indices-with-value-zeros/10151

In [None]:
# Indices of the nonzero entries of the selected column.
torch.nonzero(col.view(-1).data).squeeze()

In [None]:
# Boolean mask over the rows of A: True where column 1 is positive.
col = A[:,1]>0
col

In [None]:
len_x = A.size()[0]

# Toy state vector for the masking experiment: 5 in the first position and
# 0 everywhere else, wrapped as a single-row tensor.
x_list = [5]+[0]*(len_x-1)
x_m = torch.tensor([x_list])

In [None]:
# The single row of x_m as a 1-D vector.
x = x_m[0]
x

In [None]:
col

In [None]:
x[col]

In [None]:
# Maximum over the entries of x selected by the boolean mask.
x[col].max()

In [None]:



def viterbi_align_log_nn(hmm, nn_model, full_b_set):
    """
    Align hmm states with mfcc, using the NN phone model.

    Work in progress: this is meant to become a variant of viterbi_align_nn
    that works with logprobs to avoid numeric problems.  The log-domain
    token passing is not written yet, so for now the recursion below is the
    scaled probability-domain one from viterbi_align_nn.

    hmm         -- sentence model providing hmm.A, hmm.mfcc and hmm.b
    nn_model    -- network passed on to compute_hmm_nn_b
    full_b_set  -- phone symbols in the order of the network outputs

    Returns (alpha_m, exponent): one renormalized state row per frame and
    the sum of the logs of the scale factors that were divided out.
    """
    # b() values for every frame and state.  The earlier draft kept only
    # their logarithm in an unused variable, while the loop below read a
    # name `b` that this function never defined.
    b = m(compute_hmm_nn_b(hmm, nn_model, full_b_set).detach())

    A = m(hmm.A)
    len_x = A.size()[0]

    # TODO: log-domain token passing.  Plan from the draft: take the log of
    # the b() values, convert A to a special sparse format, and for every
    # column of A find the nonzero elements, which indicate the source
    # cells in token passing.

    tmax = hmm.mfcc.size()[0]
    # start vector: all weight in the first state
    x_list = [1]+[0]*(len_x-1)
    x_m = m([x_list])
    exponent = 0
    # allocate space for mantissa-like (kept in range) and row-exponent values
    alpha_m = m.rowlist((tmax,len_x))

    for row in range(tmax):
        s = x_m.val.sum() # renormalize
        x_m.val *= 1/s
        exponent += s.log()
        alpha_m[row] = x_m
        x_m = x_m@A*b[row]
    return alpha_m, exponent


In [None]:
#print(hmm)

In [None]:
# Plot the b() values of the selected sentence on a linear scale.
plot_matrix(bbb.detach().numpy())

In [None]:
plot_fun(bbb.detach().numpy())

In [None]:
# Align the single sentence currently held in `hmm`: alpha recursion,
# rows concatenated into one tensor, then the backward alignment pass.
alpha_m, alpha_exp = viterbi_align_nn(hmm, model, b_set)
alp = torch.cat(alpha_m.val)
intervals, bap = backward_alignment_pass_intervals(hmm, alp) # intervals will be sent to praat
#intervals

In [None]:
plot_matrix(alp)

In [None]:
alp

In [None]:
# Show the backward-pass matrix and the alpha matrix side by side
# (concatenated along the column dimension).
plot_matrix(torch.cat([bap, alp], dim=1))

In [None]:
# Text of a TextGrid file built from the intervals under the key "segmenty"
# (presumably the tier name -- confirm); the next cell writes it to
# test.TextGrid.
tft = textgrid_file_text({"segmenty": intervals})
!cp {hmm.wav} test.wav

In [None]:
%%writetemplate test.TextGrid
{tft}

In [None]:
# NOTE(review): STOP is not defined in any visible cell; presumably this is a
# deliberate error that halts "Run All" here -- the cells below are scratch.
STOP

In [None]:
print(hmms[-1])

In [None]:
#hmm.mfcc

In [None]:
# Run the network on all frames of the sentence at once and turn the logits
# into per-frame softmax outputs (dim=1 normalizes across the classes).
logits = model(hmm.mfcc.double())
pred_probab = nn.Softmax(dim=1)(logits)

In [None]:
pred_probab.size()

In [None]:
plot_matrix(pred_probab.detach().numpy())

In [None]:
# Classify one single training frame.
X = all_mfcc[10] # likely a silence sample
logits = model(X[None]) # needs added dimension, otherwise flatten() inside breaks
pred_probab = nn.Softmax(dim=1)(logits)
# index of the largest output; a 1-element tensor because of the added
# batch dimension
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

In [None]:
# Phone symbol of output class 30 (presumably the class index printed by the
# previous cell).
b_set[30]

In [None]:
hmm = hmms[2]

In [None]:
hmm.targets

In [None]:

mfcc_double = hmm.mfcc.double()

# Classify all frames of the sentence in a single batched forward pass
# instead of one model call per frame.  The softmax is left out because it
# does not change which output is the largest.
with torch.no_grad():
    frame_classes = model(mfcc_double).argmax(1).tolist()

# Spell the per-frame decisions out as one string of phone symbols, to be
# compared with hmm.targets.
t = "".join(b_set[c] for c in frame_classes)

t

In [None]:
print(hmm)

In [None]:
hmm.mfcc.size()[0]

# BELOW IS TRAINING - JUST FYI, IT IS DONE IN ANOTHER NOTEBOOK

In [None]:
# NOTE(review): STOP is not defined in any visible cell, so this line
# presumably fails on purpose and keeps "Run All" out of the training cells.
STOP # stop here if we somehow get to this part

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [None]:
import torch.optim as optim

# Cross-entropy on the raw logits of the model; the dataset supplies the
# targets as one-hot vectors.
criterion = nn.CrossEntropyLoss()
# Plain SGD with momentum over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
for epoch in range(2):  # two passes over the training data
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_dataloader):
        # one optimization step: clear the old gradients, run the forward
        # pass, compute the loss, backpropagate, update the parameters
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # report the mean loss after every 2000 mini-batches
        running_loss += loss.item()
        if (i + 1) % 2000 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
for epoch in range(20):  # twenty passes over the training data
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_dataloader):
        # one optimization step: clear the old gradients, run the forward
        # pass, compute the loss, backpropagate, update the parameters
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # report the mean loss after every 20000 mini-batches
        running_loss += loss.item()
        if (i + 1) % 20000 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
# Save only the learned parameters (the state dict), not the model object.
torch.save(model.state_dict(), 'model_weights.pth')

In [None]:
for epoch in range(20):  # another twenty passes over the training data
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_dataloader):
        # one optimization step: clear the old gradients, run the forward
        # pass, compute the loss, backpropagate, update the parameters
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # report the mean loss after every 20000 mini-batches
        running_loss += loss.item()
        if (i + 1) % 20000 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
# Save the parameters again after the additional training run.
torch.save(model.state_dict(), 'model_weights_40.pth')

In [None]:
# Save the whole model object rather than just its state dict.  This pickles
# the module, so loading it back needs the NeuralNetwork class definition to
# be available.
torch.save(model, 'model.pth')

https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html