In [8]:
import os
import math
import warnings
import torch
import torch.nn as nn
import numpy as np
from model import GOPT
import pickle
import random

data_dir = "data/seq_data_librispeech"

In [9]:
from torch.utils.data import Dataset, DataLoader

class GoPDataset(Dataset):
    def __init__(self, mode):
        # if am=='librispeech':
        dir='seq_data_librispeech'
        # norm_mean, norm_std = 3.203, 4.045
        self.mode = mode
            
        if mode == 'train':
            self.feat = torch.tensor(np.load(os.getcwd()+'/data/'+dir+'/tr_feat.npy'), dtype=torch.float)
            self.label = torch.tensor(np.load(os.getcwd()+'/data/'+dir+'/tr_label.npy'), dtype=torch.float)
        elif mode == 'test':
            self.feat = torch.tensor(np.load(os.getcwd()+'/data/'+dir+'/te_feat.npy'), dtype=torch.float)
            self.label = torch.tensor(np.load(os.getcwd()+'/data/'+dir+'/te_label.npy'), dtype=torch.float)

        # normalize the GOP feature using the training set mean and std (only count the valid token features, exclude the padded tokens).
        # self.feat = self.norm_valid(self.feat, norm_mean, norm_std)
        self.scaler = self.load_scaler(path="resources/scaler.pkl")

        # normalize data set
        tmp_feat = self.feat.reshape(-1, 84)
        tmp_feat = self.scaler.transform(tmp_feat)
        tmp_feat = tmp_feat.reshape(self.feat.shape).astype('float')
        self.feat = torch.tensor(tmp_feat, dtype=torch.float)

        self.label[:, :, 1:][self.label[:, :, 1:] !=-1 ] /= 50
        
    def load_scaler(self, path):
        with open(path, "rb") as f:
            scaler = pickle.load(f)

        return scaler

    # only normalize valid tokens, not padded token
    def norm_valid(self, feat, norm_mean, norm_std):
        norm_feat = torch.zeros_like(feat)
        for i in range(feat.shape[0]):
            for j in range(feat.shape[1]):
                if feat[i, j, 0] != 0:
                    norm_feat[i, j, :] = (feat[i, j, :] - norm_mean) / norm_std
                else:
                    break
        return norm_feat

    def __len__(self):
        return self.feat.shape[0]

    def __getitem__(self, idx):
        feature = self.feat[idx, :]
        phoneme_id= self.label[idx, :, 0]
        phone_score = self.label[idx, :, 1]
        word_score = self.label[idx, :, 2]
        word_id = self.label[idx, :, 3]
        utterance_score = self.label[idx, 0:1, 4]

        # if self.mode == "train":
        #     phone_score[phone_score != -1] += random.randint(-5, 5)/50

        return {
            "feature": feature, 
            "phoneme_id": phoneme_id, 
            "phone_score": phone_score,
            "word_score": word_score,
            "word_id": word_id,
            "utterance_score": utterance_score
        }

In [10]:
import sys
import os
import time

def train(audio_model, train_loader, test_loader, args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('running on ' + str(device))

    # best_cum_mAP is checkpoint ensemble from the first epoch to the best epoch
    best_epoch, best_mse = 0, 999
    global_step, epoch = 0, 0
    exp_dir = args.exp_dir

    # if not isinstance(audio_model, nn.DataParallel):
    #     audio_model = nn.DataParallel(audio_model)

    audio_model = audio_model.to(device)
    # Set up the optimizer
    trainables = [p for p in audio_model.parameters() if p.requires_grad]
    print('Total parameter number is : {:.3f} k'.format(sum(p.numel() for p in audio_model.parameters()) / 1e3))
    print('Total trainable parameter number is : {:.3f} k'.format(sum(p.numel() for p in trainables) / 1e3))
    optimizer = torch.optim.Adam(trainables, args.lr, weight_decay=5e-7, betas=(0.95, 0.999))

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, list(range(10, 100, 5)), gamma=0.5, last_epoch=-1)

    loss_fn = nn.MSELoss()

    print("current #steps=%s, #epochs=%s" % (global_step, epoch))
    print("start training...")
    result = np.zeros([args.n_epochs, 32])

    while epoch < args.n_epochs:
        audio_model.train()
        for i, batch in enumerate(train_loader):
            audio_input = batch["feature"].to(device, non_blocking=True)
            phn_id = batch["phoneme_id"].to(device, non_blocking=True)
            phn_label = batch["phone_score"].to(device, non_blocking=True)
            word_label = batch["word_score"].to(device, non_blocking=True)
            word_id = batch["word_id"].to(device, non_blocking=True)
            utt_label = batch["utterance_score"].to(device, non_blocking=True)

            # warmup
            warm_up_step = 100
            if global_step <= warm_up_step and global_step % 5 == 0:
                warm_lr = (global_step / warm_up_step) * args.lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = warm_lr
                #print('warm-up learning rate is {:f}'.format(optimizer.param_groups[0]['lr']))

            # add random noise for augmentation.
            # noise = (torch.rand([audio_input.shape[0], audio_input.shape[1], audio_input.shape[2]]) - 1) * args.noise
            # noise = noise.to(device, non_blocking=True)
            # audio_input = audio_input + noise

            #print(phns.shape)
            u, p, w = audio_model(audio_input, phn_id)

            # filter out the padded tokens, only calculate the loss based on the valid tokens
            # < 0 is a flag of padded tokens
            mask = (phn_label>=0)
            p = p.squeeze(2)
            p = p * mask
            phn_label = phn_label * mask
            
            loss_phn = loss_fn(p, phn_label)

            # avoid the 0 losses of the padded tokens impacting the performance
            loss_phn = loss_phn * (mask.shape[0] * mask.shape[1]) / torch.sum(mask)

            # utterance level loss, also mse
            utt_preds = u
            # print(utt_preds.shape ,utt_label[:,0:1].shape)
            loss_utt = loss_fn(utt_preds ,utt_label[:,0:1])

            # word level loss
            word_label = word_label
            mask = (word_label>=0)
            word_pred = w[:, :, 0]
            
            word_pred = word_pred * mask
            word_label = word_label * mask
            
            loss_word = loss_fn(word_pred, word_label)
            loss_word = loss_word * (mask.shape[0] * mask.shape[1]) / torch.sum(mask)

            loss = args.loss_w_phn * loss_phn + args.loss_w_utt * loss_utt + args.loss_w_word * loss_word

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1

        print('start validation of epoch {:d}'.format(epoch))

        # ensemble results
        # don't save prediction for the training set
        # tr_mse, tr_corr, tr_utt_mse, tr_utt_corr, tr_word_mse, tr_word_corr = validate(audio_model, train_loader, args, -1)
        te_mse, te_corr, te_utt_mse, te_utt_corr, te_word_mse, te_word_corr = validate(audio_model, test_loader, args, best_mse)

        print('Phone: Test MSE: {:.3f}, CORR: {:.3f}'.format(te_mse.item(), te_corr))
        print('Utterance:, MSE: {:.3f}, CORR: {:.3f}'.format(te_utt_mse[0], te_utt_corr[0]))
        print('Word:, MSE: {:.3f}, CORR: {:.3f}'.format(te_word_mse[0], te_word_corr[0]))

        print('-------------------validation finished-------------------')

        if te_mse < best_mse:
            best_mse = te_mse
            best_epoch = epoch

        if best_epoch == epoch:
            if os.path.exists("%s/models/" % (exp_dir)) == False:
                os.mkdir("%s/models" % (exp_dir))
            torch.save(audio_model.state_dict(), "%s/models/best_audio_model.pth" % (exp_dir))

        if global_step > warm_up_step:
            scheduler.step()

        #print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr']))
        epoch += 1

def validate(audio_model, val_loader, args, best_mse):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # if not isinstance(audio_model, nn.DataParallel):
    #     audio_model = nn.DataParallel(audio_model)
    audio_model = audio_model.to(device)
    audio_model.eval()

    A_phn, A_phn_target = [], []
    A_u, A_utt_target = [], []
    A_w, A_word_target , A_word_id= [], [], []
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            audio_input = batch["feature"].to(device, non_blocking=True)
            phn_id = batch["phoneme_id"].to(device, non_blocking=True)
            phn_label = batch["phone_score"]
            word_label = batch["word_score"]
            word_id = batch["word_id"]
            utt_label = batch["utterance_score"]

            # compute output
            u, p, w = audio_model(audio_input, phn_id)
            # print(u.shape, p.shape, w.shape)
            p = p.to('cpu').detach()
            u = u.to('cpu').detach()
            w = w.to('cpu').detach()
            
            A_phn.append(p[:, :, 0])
            A_phn_target.append(phn_label)
            
            A_u.append(u[:, 0:1])
            A_utt_target.append(utt_label)

            A_w.append(w[:, :, 0])
            A_word_target.append(word_label)
            A_word_id.append(word_id)
            
        index = random.randint(0, len(A_phn)-1)
        print("++++++++++++++++++++++++++++++++++++++++++")
        print("## predicted phone: ", A_phn[index][0][0:6])
        print("## label phone: ", A_phn_target[index][0][0:6])

        print("## predicted word: ", A_w[index][0][0:6])
        print("## label word: ", A_word_target[index][0][0:6])

        print("## predicted utt: ", A_u[index][0])
        print("## label utt: ", A_utt_target[index][0])
        print("++++++++++++++++++++++++++++++++++++++++++")

        # phone level
        A_phn, A_phn_target  = torch.vstack(A_phn), torch.vstack(A_phn_target)
        # print("Phone Level: ", A_phn.shape, A_phn_target.shape)
        # utterance level
        A_u, A_utt_target = torch.vstack(A_u), torch.vstack(A_utt_target)
        # print("Utterance Level: ", A_u.shape, A_utt_target.shape)
        # word level
        A_w, A_word_target, A_word_id = torch.vstack(A_w), torch.vstack(A_word_target), torch.vstack(A_word_id)
        # print("Word Level: ", A_w.shape, A_word_target.shape)
        # get the scores
        phn_mse, phn_corr = valid_phn(A_phn, A_phn_target)

        A_utt = A_u
        utt_mse, utt_corr = valid_utt(A_utt, A_utt_target)

        A_word = A_w
        word_mse, word_corr, valid_word_pred, valid_word_target = valid_word(A_word, A_word_target, A_word_id)
        # word_mse, word_corr, valid_word_pred, valid_word_target = 0, 0, 0 , 0

        if phn_mse < best_mse:
            print('new best phn mse {:.3f}, now saving predictions.'.format(phn_mse))

            # create the directory
            if os.path.exists(args.exp_dir + '/preds') == False:
                os.mkdir(args.exp_dir + '/preds')

            # saving the phn target, only do once
            if os.path.exists(args.exp_dir + '/preds/phn_target.npy') == False:
                np.save(args.exp_dir + '/preds/phn_target.npy', A_phn_target)
                np.save(args.exp_dir + '/preds/word_target.npy', valid_word_target)
                np.save(args.exp_dir + '/preds/utt_target.npy', A_utt_target)

            np.save(args.exp_dir + '/preds/phn_pred.npy', A_phn)
            np.save(args.exp_dir + '/preds/word_pred.npy', valid_word_pred)
            np.save(args.exp_dir + '/preds/utt_pred.npy', A_utt)

    return phn_mse, phn_corr, utt_mse, utt_corr, word_mse, word_corr

def valid_phn(audio_output, target):
    valid_token_pred = []
    valid_token_target = []
    # audio_output = audio_output.squeeze(2)
    for i in range(audio_output.shape[0]):
        for j in range(audio_output.shape[1]):
            # only count valid tokens, not padded tokens (represented by negative values)
            if target[i, j] >= 0:
                valid_token_pred.append(audio_output[i, j])
                valid_token_target.append(target[i, j])
    valid_token_target = np.array(valid_token_target)
    valid_token_pred = np.array(valid_token_pred)

    valid_token_mse = np.mean((valid_token_target - valid_token_pred) ** 2)
    # valid_token_mse = np.mean(np.abs(valid_token_target - valid_token_pred))
    corr = np.corrcoef(valid_token_pred, valid_token_target)[0, 1]
    return valid_token_mse, corr

def valid_utt(audio_output, target):
    mse = []
    corr = []
    target = target[:,0:1]
    for i in range(1):
        cur_mse = np.mean(((audio_output[:, i] - target[:, i]) ** 2).numpy())
        # cur_mse = np.mean((np.abs(audio_output[:, i] - target[:, i])).numpy())
        cur_corr = np.corrcoef(audio_output[:, i], target[:, i])[0, 1]
        mse.append(cur_mse)
        corr.append(cur_corr)
    return mse, corr

def valid_word(audio_output, target, word_id):
    # word_id = target[:, :, -1]
    # target = target[:, :, 0:3]
    valid_token_pred = []
    valid_token_target = []

    # unique, counts = np.unique(np.array(target), return_counts=True)
    # print(dict(zip(unique, counts)))

    # for each utterance
    for i in range(target.shape[0]):
        prev_w_id = 0
        start_id = 0
        # for each token
        for j in range(target.shape[1]):
            cur_w_id = word_id[i, j].int()
            # if a new word
            if cur_w_id != prev_w_id:
                # average each phone belongs to the word
                valid_token_pred.append(np.mean(audio_output[i, start_id: j].numpy(), axis=0))
                valid_token_target.append(np.mean(target[i, start_id: j].numpy(), axis=0))
                # sanity check, if the range indeed contains a single word
                # if len(torch.unique(target[i, start_id: j])) != 1:
                #     print(target[i, start_id: j])
                # if end of the utterance
                if cur_w_id == -1:
                    break
                else:
                    prev_w_id = cur_w_id
                    start_id = j

    valid_token_pred = np.array(valid_token_pred)
    # this rounding is to solve the precision issue in the label
    valid_token_target = np.array(valid_token_target).round(2)

    mse_list, corr_list = [], []
    # for each (accuracy, stress, total) word score
    valid_token_mse = np.mean((valid_token_target[:] - valid_token_pred[:]) ** 2)
    corr = np.corrcoef(valid_token_pred[:], valid_token_target[:])[0, 1]
    mse_list.append(valid_token_mse)
    corr_list.append(corr)
    return mse_list, corr_list, valid_token_pred, valid_token_target

In [11]:
import argparse

print("I am process %s, running on %s: starting (%s)" % (os.getpid(), os.uname()[1], time.asctime()))
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--exp-dir", type=str, default=os.getcwd()+"/exp/", help="directory to dump experiments")
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', help='initial learning rate')
parser.add_argument("--n-epochs", type=int, default=50, help="number of maximum training epochs")
parser.add_argument("--goptdepth", type=int, default=3, help="3 depth of gopt models")
parser.add_argument("--goptheads", type=int, default=1, help="heads of gopt models")
parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
parser.add_argument("--embed_dim", type=int, default=24, help="24 gopt transformer embedding dimension")
parser.add_argument("--loss_w_phn", type=float, default=1, help="weight for phoneme-level loss")
parser.add_argument("--loss_w_word", type=float, default=1, help="weight for word-level loss")
parser.add_argument("--loss_w_utt", type=float, default=1, help="weight for utterance-level loss")
parser.add_argument("--model", type=str, default='gopt', help="name of the model")
parser.add_argument("--am", type=str, default='librispeech', help="name of the acoustic models")
parser.add_argument("--noise", type=float, default=0., help="the scale of random noise added on the input GoP feature")

args = parser.parse_args(args=[])

if torch.cuda.is_available() == False:
    raise ValueError('GPU is not enabled. Please go to top menu - edit - notebook settings -hardware accelerator - GPU')

input_dim = 84

audio_mdl = GOPT(embed_dim=args.embed_dim, num_heads=args.goptheads, depth=args.goptdepth, input_dim=input_dim)

tr_dataset = GoPDataset('train')
print("Num train sample: ", len(tr_dataset))
tr_dataloader = DataLoader(tr_dataset, batch_size=25, shuffle=True, drop_last=True, num_workers=4)
te_dataset = GoPDataset('test')
print("Num test sample: ", len(te_dataset))
te_dataloader = DataLoader(te_dataset, batch_size=256, shuffle=False, drop_last=True, num_workers=4)

if os.path.exists(args.exp_dir) == False:
  os.makedirs(args.exp_dir)
train(audio_mdl, tr_dataloader, te_dataloader, args)

I am process 1492575, running on PrepAI: starting (Thu Sep 21 15:50:52 2023)
Num train sample:  97429
Num test sample:  2498
running on cuda
Total parameter number is : 25.947 k
Total trainable parameter number is : 25.947 k
current #steps=0, #epochs=0
start training...
start validation of epoch 0
++++++++++++++++++++++++++++++++++++++++++
## predicted phone:  tensor([0.4907, 0.4504, 1.5153, 1.1892, 1.7540, 0.9267])
## label phone:  tensor([0.6400, 0.0000, 1.8800, 1.9800, 1.9400, 0.7400])
## predicted word:  tensor([0.9437, 0.9083, 0.9889, 1.0623, 1.0925, 0.9637])
## label word:  tensor([0.6800, 0.6800, 0.6800, 0.6800, 0.6800, 0.6800])
## predicted utt:  tensor([1.0107])
## label utt:  tensor([0.6800])
++++++++++++++++++++++++++++++++++++++++++
new best phn mse 0.258, now saving predictions.
Phone: Test MSE: 0.258, CORR: 0.702
Utterance:, MSE: 0.126, CORR: 0.710
Word:, MSE: 0.124, CORR: 0.711
-------------------validation finished-------------------
start validation of epoch 1
++++++++

In [12]:
# ### v1
# -------------------validation finished-------------------
# start validation of epoch 36
# Phone: Test MSE: 539.188, CORR: 0.765
# Utterance:, MSE: 246.776, CORR: 0.789
# Word:, MSE: 293.750, CORR: 0.766
# -------------------validation finished-------------------
# start validation of epoch 37
# Phone: Test MSE: 539.483, CORR: 0.765
# Utterance:, MSE: 247.567, CORR: 0.788
# Word:, MSE: 293.962, CORR: 0.766
# -------------------validation finished-------------------
# start validation of epoch 38
# new best phn mse 538.541, now saving predictions.
# Phone: Test MSE: 538.541, CORR: 0.766
# Utterance:, MSE: 246.970, CORR: 0.789
# Word:, MSE: 293.527, CORR: 0.766

In [13]:
#### v2
## predicted phone:  tensor([ 1.9180,  0.5972,  0.1705,  1.8036,  0.7993, -0.1650])
## label phone:  tensor([ 2.,  0.,  0.,  2., -1., -1.])
## predicted word:  tensor([0.7987, 0.8188, 0.7050, 0.7736, 0.3033, 0.1018])
## label word:  tensor([ 0.6400,  0.6400,  0.6400,  0.6400, -1.0000, -1.0000])
## predicted utt:  tensor([0.7410])
## label utt:  tensor([0.6400])