In [1]:
import librosa
import numpy as np
import os
import math
import operator
from sklearn.cluster import KMeans
import hmmlearn.hmm

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.[0m
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.[0m
  from numba.decorators import jit as optional_jit


In [2]:
def get_mfcc(file_path):
    y, sr = librosa.load(file_path) # read .wav file
    hop_length = math.floor(sr*0.010) # 10ms hop
    win_length = math.floor(sr*0.025) # 25ms frame
    # mfcc is 12 x T matrix
    mfcc = librosa.feature.mfcc(
        y, sr, n_mfcc=12, n_fft=1024,
        hop_length=hop_length, win_length=win_length)
    # substract mean from mfcc --> normalize mfcc
    mfcc = mfcc - np.mean(mfcc, axis=1).reshape((-1,1)) 
    # delta feature 1st order and 2nd order
    delta1 = librosa.feature.delta(mfcc, order=1)
    delta2 = librosa.feature.delta(mfcc, order=2)
    # X is 36 x T
    X = np.concatenate([mfcc, delta1, delta2], axis=0) # O^r
    # return T x 36 (transpose of X)
    return X.T # hmmlearn use T x N matrix

In [3]:
def get_class_data(data_dir):
    files = os.listdir(data_dir)
    mfcc = [get_mfcc(os.path.join(data_dir,f)) for f in files if f.endswith(".wav")]
    return mfcc

In [4]:
def clustering(X, n_clusters=10):
    kmeans = KMeans(n_clusters=n_clusters, n_init=50, random_state=0, verbose=0)
    kmeans.fit(X)
    print("centers", kmeans.cluster_centers_.shape)
    return kmeans 

In [5]:
class_names = ["nguoi", "viet_nam", "lam_viec", "mot", "khong", "test_nguoi", "test_viet_nam", "test_lam_viec", "test_mot", "test_khong"]
dataset = {}
for cname in class_names:
    print(f"Load {cname} dataset")
    dataset[cname] = get_class_data(os.path.join("data_hmm", cname))
# Get all vectors in the datasets
all_vectors = np.concatenate([np.concatenate(v, axis=0) for k, v in dataset.items()], axis=0)
print("vectors", all_vectors.shape)
# Run K-Means algorithm to get clusters
kmeans = clustering(all_vectors)
print("centers", kmeans.cluster_centers_.shape)
# for cname in class_names:
#     dataset[cname] = list([kmeans.predict(v).reshape(-1,1) for v in dataset[cname]])

Load nguoi dataset
Load viet_nam dataset
Load lam_viec dataset
Load mot dataset
Load khong dataset
Load test_nguoi dataset
Load test_viet_nam dataset
Load test_lam_viec dataset
Load test_mot dataset
Load test_khong dataset
vectors (17729, 36)
centers (10, 36)
centers (10, 36)


In [23]:
def train_GMMHMM(train_set, n_component, startprob, transmat):
    model = hmmlearn.hmm.GMMHMM(n_components=n_component, n_mix=2, random_state=0, n_iter=1000, verbose=True)
    model.startprob_ = startprob
    model.transmat_ = transmat
    X = np.concatenate(train_set)
    lengths = list([len(x) for x in train_set])
    # fit dat 
    print(X.shape, lengths, len(lengths))
    model.fit(X)
    return model

In [33]:
startprob6 = np.array([0.7, 0.2, 0.1, 0.0, 0.0, 0.0])
startprob12 = np.array([0.7, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
transmat6 = np.array([[0.1 for x in range(6)] for y in range(6)])
transmat12 = np.array([[0.1 for x in range(12)] for y in range(12)])
for x in range(5):
    transmat6[x][x+1] = 0.5
transmat6[5][5] = 0.5
for x in range(11):
    transmat12[x][x+1] = 0.5
transmat12[11][11] = 0.5

models = {}
models["khong"] = train_GMMHMM(dataset["khong"], 6, startprob6, transmat6)
models["lam_viec"] = train_GMMHMM(dataset["lam_viec"], 12, startprob12, transmat12)
models["mot"] = train_GMMHMM(dataset["mot"], 6, startprob6, transmat6)
models["nguoi"] = train_GMMHMM(dataset["nguoi"], 6, startprob6, transmat6)
models["viet_nam"] = train_GMMHMM(dataset["viet_nam"], 12, startprob12, transmat12)

(1887, 36) [25, 21, 24, 27, 25, 21, 18, 24, 24, 21, 24, 29, 15, 31, 19, 21, 20, 25, 29, 32, 31, 23, 18, 35, 23, 21, 24, 28, 33, 28, 34, 23, 17, 17, 17, 23, 25, 25, 22, 25, 28, 22, 24, 19, 17, 21, 19, 23, 19, 35, 26, 24, 27, 21, 20, 23, 20, 21, 24, 16, 25, 19, 26, 21, 22, 26, 27, 19, 17, 35, 28, 22, 25, 29, 31, 21, 19, 21, 13, 25] 80


         1     -194908.8475             +nan
         2     -184402.4650      +10506.3825
         3     -183005.7096       +1396.7554
         4     -182689.6940        +316.0156
         5     -182525.9826        +163.7114
         6     -182424.6425        +101.3401
         7     -182317.0132        +107.6293
         8     -182183.8027        +133.2106
         9     -182093.3159         +90.4867
        10     -182038.8408         +54.4751
        11     -182001.6367         +37.2042
        12     -181975.1401         +26.4966
        13     -181952.4495         +22.6906
        14     -181925.8769         +26.5725
        15     -181882.3176         +43.5593
        16     -181849.8939         +32.4237
        17     -181811.0312         +38.8628
        18     -181781.0160         +30.0152
        19     -181755.7168         +25.2992
        20     -181740.0062         +15.7105
        21     -181725.7535         +14.2527
        22     -181693.5640         +32.1895
        23

(5932, 36) [35, 40, 46, 43, 36, 62, 1644, 36, 35, 46, 41, 36, 1303, 54, 40, 37, 34, 129, 39, 39, 43, 35, 50, 51, 34, 25, 29, 44, 37, 35, 32, 45, 36, 31, 38, 30, 30, 36, 35, 44, 27, 29, 52, 33, 47, 47, 27, 45, 34, 51, 40, 39, 45, 32, 31, 36, 30, 34, 32, 29, 32, 25, 25, 38, 29, 31, 31, 58, 28, 24, 30, 25, 38, 28, 28, 28, 38, 30, 41, 31, 39] 81


         1     -649586.9194             +nan
         2     -607273.5617      +42313.3577
         3     -590728.7178      +16544.8439
         4     -587372.0709       +3356.6469
         5     -586218.2241       +1153.8468
         6     -585554.6795        +663.5446
         7     -585036.8434        +517.8360
         8     -584499.5774        +537.2660
         9     -583855.1926        +644.3848
        10     -583264.9216        +590.2710
        11     -582805.1945        +459.7271
        12     -582515.5891        +289.6054
        13     -582356.1034        +159.4858
        14     -582235.6614        +120.4420
        15     -582166.1012         +69.5603
        16     -582116.1263         +49.9749
        17     -582058.7686         +57.3577
        18     -582024.3006         +34.4679
        19     -582004.5841         +19.7165
        20     -581980.1898         +24.3943
        21     -581943.0331         +37.1567
        22     -581916.1887         +26.8444
        23

(1365, 36) [19, 19, 19, 25, 21, 23, 22, 17, 12, 11, 13, 13, 17, 22, 13, 15, 16, 23, 15, 18, 20, 17, 12, 14, 14, 18, 20, 18, 16, 19, 16, 13, 13, 21, 15, 9, 29, 17, 15, 24, 11, 11, 12, 22, 16, 20, 13, 14, 15, 29, 23, 16, 12, 10, 21, 20, 16, 21, 18, 19, 14, 12, 18, 14, 15, 12, 17, 15, 21, 15, 17, 18, 11, 25, 19, 19, 14, 29, 15, 13] 80


         1     -149127.1388             +nan
         2     -143132.8644       +5994.2744
         3     -142176.5009        +956.3635
         4     -141634.2399        +542.2610
         5     -141412.9522        +221.2877
         6     -141260.3951        +152.5571
         7     -141149.1332        +111.2619
         8     -141046.3344        +102.7988
         9     -140967.8078         +78.5266
        10     -140938.2531         +29.5547
        11     -140929.4928          +8.7603
        12     -140921.2052          +8.2876
        13     -140912.9341          +8.2710
        14     -140906.6992          +6.2349
        15     -140897.4156          +9.2836
        16     -140882.4675         +14.9482
        17     -140866.1531         +16.3143
        18     -140846.4729         +19.6802
        19     -140822.2888         +24.1841
        20     -140800.9199         +21.3689
        21     -140781.6868         +19.2331
        22     -140764.5217         +17.1651
        23

(2020, 36) [29, 30, 15, 43, 30, 30, 14, 25, 24, 30, 35, 23, 17, 24, 21, 31, 22, 28, 29, 28, 19, 16, 16, 31, 24, 22, 19, 13, 21, 23, 30, 19, 32, 47, 18, 34, 26, 21, 10, 23, 29, 25, 21, 25, 25, 12, 27, 24, 37, 22, 26, 29, 17, 26, 24, 33, 36, 22, 21, 30, 31, 31, 22, 22, 26, 23, 36, 32, 23, 17, 32, 28, 21, 32, 21, 16, 25, 26, 35, 18] 80


         1     -212930.1022             +nan
         2     -203694.2749       +9235.8273
         3     -201289.2407       +2405.0342
         4     -200621.4881        +667.7526
         5     -200358.0859        +263.4022
         6     -200237.8852        +120.2006
         7     -200186.2418         +51.6434
         8     -200159.6248         +26.6170
         9     -200141.4407         +18.1841
        10     -200113.2752         +28.1655
        11     -200066.4044         +46.8707
        12     -199990.1863         +76.2181
        13     -199918.9445         +71.2418
        14     -199878.2890         +40.6555
        15     -199857.8301         +20.4589
        16     -199835.6672         +22.1629
        17     -199815.6771         +19.9901
        18     -199762.9028         +52.7743
        19     -199750.5816         +12.3212
        20     -199737.4995         +13.0821
        21     -199720.4024         +17.0972
        22     -199702.8356         +17.5668
        23

(3536, 36) [52, 43, 35, 55, 34, 55, 47, 51, 41, 49, 34, 58, 41, 41, 75, 30, 50, 53, 39, 42, 40, 42, 49, 40, 39, 36, 40, 41, 54, 47, 46, 55, 38, 49, 37, 35, 45, 39, 58, 42, 40, 43, 34, 31, 47, 48, 37, 58, 38, 37, 50, 41, 37, 43, 50, 49, 48, 44, 51, 32, 42, 51, 42, 37, 30, 50, 45, 35, 39, 55, 47, 45, 39, 40, 61, 56, 42, 40, 46, 39] 80


         1     -369544.1271             +nan
         2     -344672.9725      +24871.1546
         3     -341249.8605       +3423.1120
         4     -340294.7605        +955.1000
         5     -339857.3397        +437.4208
         6     -339616.5546        +240.7851
         7     -339435.2682        +181.2863
         8     -339278.8061        +156.4621
         9     -339149.5378        +129.2684
        10     -339071.8011         +77.7367
        11     -338976.5754         +95.2257
        12     -338853.0881        +123.4873
        13     -338771.1568         +81.9313
        14     -338693.4312         +77.7256
        15     -338629.1224         +64.3088
        16     -338575.3123         +53.8101
        17     -338531.8482         +43.4642
        18     -338496.5190         +35.3292
        19     -338461.2113         +35.3077
        20     -338422.7311         +38.4802
        21     -338403.2609         +19.4703
        22     -338381.1993         +22.0616
        23

In [34]:
print("Testing")
result = 0
length = 0
for true_cname in class_names:
    if true_cname[:4] == "test":
        for O in dataset[true_cname]:
            length+=1
            score = {cname: model.score(O, [len(O)]) for cname, model in models.items()}
            print(true_cname, score, max(score.items(), key=operator.itemgetter(1))[0])
            if true_cname[5:] == max(score.items(), key=operator.itemgetter(1))[0]:
                result = result + 1
print('Acc', result/length, result, length)

Testing
test_nguoi {'khong': -3718.920813226113, 'lam_viec': -3509.25924802559, 'mot': -3629.568211446534, 'nguoi': -3235.6631110393005, 'viet_nam': -3596.8607060732807} nguoi
test_nguoi {'khong': -3745.888859947988, 'lam_viec': -3581.33246051301, 'mot': -3914.45268015696, 'nguoi': -3290.053205493709, 'viet_nam': -3754.4591838066467} nguoi
test_nguoi {'khong': -2572.5428865546023, 'lam_viec': -2381.630901585078, 'mot': -2451.3475155089914, 'nguoi': -2225.2971998967414, 'viet_nam': -2435.9562975450704} nguoi
test_nguoi {'khong': -4279.372717043243, 'lam_viec': -4000.088252519357, 'mot': -4391.05162805815, 'nguoi': -3700.7475853791925, 'viet_nam': -4023.3987601057515} nguoi
test_nguoi {'khong': -3046.139538441775, 'lam_viec': -2981.3578345504166, 'mot': -3016.770318051231, 'nguoi': -2619.9761880838405, 'viet_nam': -2985.0682778017517} nguoi
test_nguoi {'khong': -3884.6154787636474, 'lam_viec': -3683.812626691417, 'mot': -4053.9219639841112, 'nguoi': -3442.437080505848, 'viet_nam': -3871.

In [26]:
import pickle

In [27]:
with open("modelHMM.pkl", "wb") as file:
    pickle.dump(models, file)