# Train NN Acoustic Model
https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html

In [None]:
# Execute the project's helper scripts inside this notebook. Later cells use
# names that are not imported in the notebook before their first use (e.g.
# `torch`, `HMM`), so they are presumably provided by these scripts — TODO confirm.
%run ../prongen/hmm_pron.py --in-jupyter
%run ../acmodel/plot.py
# Select the ipympl matplotlib backend for plots in this notebook.
%matplotlib ipympl

In [None]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

## Get training data
We aligned the Czech CommonVoice train set using an ultra-primitive HMM/GMM. Let's use that alignment as a starting point.

In [None]:
import pandas as pd

# Never truncate cell contents when a frame is displayed.
pd.set_option("display.max_colwidth", None)

# Other training tables that were used for separate runs:
#   nn_train.tsv, nn_train_g2.tsv, nn_train_g3.tsv
# keep_default_na=False keeps empty fields as empty strings instead of NaN.
df = pd.read_csv(
    "nn_train_g4.tsv",
    sep="\t",
    keep_default_na=False,
)


In [None]:
#df

In [None]:
# Build one HMM per table row and attach the row's target string to it.
hmms = []
rows = zip(df["wav"].values, df["sentence"].values, df["targets"].values)
for wav, sentence, targets in rows:
    hmm = HMM(sentence, wav=wav, use_DA=True)
    hmm.targets = targets
    hmms.append(hmm)

In [None]:
# Every distinct character found in any HMM's `b` string, in sorted order;
# each one becomes an output class of the network.
b_set = sorted(set("".join(hmm.b for hmm in hmms)))
out_size = len(b_set)
# Input width is the second dimension of the first utterance's feature matrix.
in_size = hmms[0].mfcc.size(1)
" ".join(b_set), out_size, in_size

In [None]:
# Concatenate the per-utterance target strings into one long string and
# remember its total length.
all_targets = "".join(hmm.targets for hmm in hmms)
train_len = len(all_targets)

In [None]:
# Stack every utterance's features along the first axis, convert to float64
# and move the result to the selected device.
all_mfcc = torch.cat([hmm.mfcc for hmm in hmms]).double().to(device)
# There must be exactly one target character per feature row.
assert all_mfcc.size(0) == train_len

## Setup PyTorch training tools

In [None]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
#from torchvision import datasets, transforms

In [None]:
class SpeechDataset(Dataset):
    """Dataset pairing each feature row with a one-hot target vector.

    Args:
        all_mfcc: Tensor whose first dimension indexes the training items.
        all_targets: Sequence (a string in this notebook) holding one target
            symbol per item, aligned with ``all_mfcc``.
        b_set: Ordered collection of all target symbols; a symbol's position
            is the index of the 1 in its one-hot vector.

    Raises:
        ValueError: If ``all_mfcc`` and ``all_targets`` differ in length.
    """

    def __init__(self, all_mfcc, all_targets, b_set):
        # A length mismatch would silently pair features with wrong targets.
        if len(all_mfcc) != len(all_targets):
            raise ValueError(
                f"all_mfcc has {len(all_mfcc)} rows but all_targets has "
                f"{len(all_targets)} entries"
            )
        self.all_mfcc = all_mfcc
        self.all_targets = all_targets

        # Create the targets on the same device as the features instead of
        # relying on a notebook-level `device` variable, so both halves of a
        # batch are always co-located.
        self.wanted_outputs = torch.eye(len(b_set), device=all_mfcc.device).double()
        # Row i of the identity matrix is the one-hot vector for b_set[i].
        self.output_map = {b: self.wanted_outputs[i] for i, b in enumerate(b_set)}

    def __len__(self):
        """Return the number of items (one per target symbol)."""
        return len(self.all_targets)

    def __getitem__(self, idx):
        """Return ``(features, one_hot_target)`` for item ``idx``."""
        return self.all_mfcc[idx], self.output_map[self.all_targets[idx]]

In [None]:
# Wrap the stacked features and the target string as a torch Dataset.
training_data = SpeechDataset(
    all_mfcc=all_mfcc,
    all_targets=all_targets,
    b_set=b_set,
)

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier with three 512-unit ReLU hidden layers.

    The input width is the notebook-level ``in_size`` and the output width is
    ``out_size``. The network returns raw scores; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        hidden = 512
        self.flatten = nn.Flatten()
        # Attribute name and layer order determine the state_dict keys, so
        # they are kept exactly as they are.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_size),
        )

    def forward(self, x):
        """Flatten ``x`` per batch item and return the unnormalised scores."""
        return self.linear_relu_stack(self.flatten(x))

In [None]:
# Create the network and move its parameters to the selected device.
model = NeuralNetwork()
model = model.to(device)
print(model)

In [None]:
from torch.utils.data import DataLoader

# Shuffled mini-batches of 64 items drawn from the training dataset.
train_dataloader = DataLoader(
    training_data,
    batch_size=64,
    shuffle=True,
)
# No test loader is built in this notebook:
#test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [None]:
import torch.optim as optim

# Cross-entropy on the network's raw scores, optimised by SGD with momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
)

In [None]:
# Train for 20 passes over the training set.
for epoch in range(20):

    running_loss = 0.0  # loss summed since the last progress line
    for i, data in enumerate(train_dataloader):
        # Each item from the loader is one mini-batch: [inputs, labels].
        inputs, labels = data

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss once per 20000 mini-batches, then start over.
        running_loss += loss.item()
        if (i + 1) % 20000 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
#[19, 20000] loss: 1.181 1.187 1.190 1.179 1.184 [20, 60000] loss: 1.187
#
# g2:[19, 20000] loss: 0.997 1.003 1.008 0.995 1.002 [20, 60000] loss: 1.002
# g3: [19, 20000] loss: 0.950 0.953 0.955 0.947 0.950 [20, 60000] loss: 0.954
# g4: [19, 20000] loss: 0.922 0.927 0.928 0.920 0.924 [20, 60000] loss: 0.925

In [None]:
# Save only the model's parameters (its state_dict), not the module object.
torch.save(model.state_dict(), 'model_weights_g4_20.pth')

In [None]:
# Train for another 20 passes; `model` and `optimizer` are not re-created
# here, so training continues from their current state.
for epoch in range(20):

    running_loss = 0.0  # loss summed since the last progress line
    for i, data in enumerate(train_dataloader):
        # Each item from the loader is one mini-batch: [inputs, labels].
        inputs, labels = data

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss once per 20000 mini-batches, then start over.
        running_loss += loss.item()
        if (i + 1) % 20000 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
#[19, 20000] loss: 1.149 1.153 1.155 1.146 1.152 #[20, 60000] loss 1.156
#
# g2:[19, 20000] loss: 0.961 0.969 0.973 0.961 0.968 [20, 60000] loss: 0.971
# g3: [19, 20000] loss: 0.911 0.918 0.922 0.910 0.918 [20, 60000] loss: 0.919

In [None]:
# Save only the model's parameters (its state_dict). The commented-out line
# is the file name that was used for a different training table.
#torch.save(model.state_dict(), 'model_weights_g2_40.pth')
torch.save(model.state_dict(), 'model_weights_g4_40.pth')


In [None]:
# NOTE(review): `STOP` is not defined anywhere in the visible notebook, so
# running this cell raises NameError. Presumably a deliberate barrier that
# keeps "Run All" from executing the cells below — confirm before removing.
STOP

In [None]:
# Saves the whole module object rather than just its state_dict; loading it
# back requires the NeuralNetwork class definition to be available.
torch.save(model, 'model.pth')

https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

In [None]:
# A further 20 passes over the training set, continuing from the current
# state of `model` and `optimizer`.
for epoch in range(20):

    running_loss = 0.0  # loss summed since the last progress line
    for i, data in enumerate(train_dataloader):
        # Each item from the loader is one mini-batch: [inputs, labels].
        inputs, labels = data

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss once per 20000 mini-batches, then start over.
        running_loss += loss.item()
        if (i + 1) % 20000 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
# Save only the model's parameters (its state_dict), not the module object.
torch.save(model.state_dict(), 'model_weights_60.pth')

In [None]:
# 40 more passes over the training set, continuing from the current state of
# `model` and `optimizer`.
for epoch in range(40):

    running_loss = 0.0  # loss summed since the last progress line
    for i, data in enumerate(train_dataloader):
        # Each item from the loader is one mini-batch: [inputs, labels].
        inputs, labels = data

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss once per 20000 mini-batches, then start over.
        running_loss += loss.item()
        if (i + 1) % 20000 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
# Save only the model's parameters (its state_dict), not the module object.
torch.save(model.state_dict(), 'model_weights_100.pth')

In [None]:
# Another 40 passes over the training set, continuing from the current state
# of `model` and `optimizer`.
for epoch in range(40):

    running_loss = 0.0  # loss summed since the last progress line
    for i, data in enumerate(train_dataloader):
        # Each item from the loader is one mini-batch: [inputs, labels].
        inputs, labels = data

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss once per 20000 mini-batches, then start over.
        running_loss += loss.item()
        if (i + 1) % 20000 == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 20000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
# Save only the model's parameters (its state_dict), not the module object.
torch.save(model.state_dict(), 'model_weights_140.pth')