In [1]:
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import random
import copy
import os
import graphviz
import pickle
import scipy.io.wavfile as wav
from src.voice_activity_detection.extract_features import extract_features

In [None]:
# NOTE(review): pickle.load can execute arbitrary code; only load feature files you generated yourself.
with open("src/data/noise-train/features_df_5s.pickle", "rb") as pickle_handle:
    voice_noise_df = pickle.load(pickle_handle)
voice_noise_df.info()

In [3]:
# Keep only the rows whose RSE value is present (not NaN/None).
has_rse = voice_noise_df['RSE'].notnull()
voice_noise_df = voice_noise_df[has_rse]
voice_noise_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 76846 entries, 1 to 77663
Data columns (total 13 columns):
RMS       76846 non-null float64
SE        76846 non-null float64
ZCR       76846 non-null float64
LEFR      76846 non-null float64
SF        76846 non-null float64
SRF       76846 non-null float64
SC        76846 non-null float64
BW        76846 non-null float64
NWPD      76846 non-null float64
RSE       76846 non-null float64
type      76846 non-null object
name      76846 non-null object
number    76846 non-null int64
dtypes: float64(10), int64(1), object(2)
memory usage: 8.2+ MB


In [4]:
# Summary statistics of the numeric columns; used here to spot non-finite values (inf / -inf).
voice_noise_df.describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,number
count,76846.0,76846.0,76846.0,76846.0,76846.0,76846.0,76846.0,76846.0,76846.0,76846.0,76846.0
mean,4027.386623,inf,0.106162,0.336986,0.018741,4454.837767,757.398939,686381.5,0.271261,-inf,42.876845
std,2470.621936,,0.053201,0.217449,0.007833,886.254567,433.981223,592655.3,1.490372,,34.674104
min,0.49616,2839.631,5e-05,0.0,0.00153,1753.576807,55.599252,6881.946,-32.282923,-inf,0.0
25%,2264.165466,3257.9,0.070601,0.120482,0.01365,3645.456827,477.869557,296253.0,-0.410469,-0.2992179,15.0
50%,3342.738647,3404.282,0.098514,0.389558,0.017666,4319.088855,671.782718,528365.5,0.312929,-0.2534488,34.0
75%,5294.822388,3587.1,0.132314,0.508032,0.022452,5097.719001,947.220822,887765.4,1.02084,-0.1863887,65.0
max,27287.986328,inf,0.860161,0.993976,0.092539,7535.391566,7007.05765,7441943.0,51.500733,-0.02729339,207.0


In [5]:
# Inspect the rows whose RSE is negative infinity.
rse_is_neg_inf = voice_noise_df['RSE'].eq(-np.inf)
voice_noise_df[rse_is_neg_inf].describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,number
count,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0
mean,3894.241669,2924.087528,0.165486,0.507028,0.019334,5284.732681,1388.688601,1056925.0,0.727845,-inf,0.0
std,2809.456942,28.781054,0.183082,0.278002,0.013689,1094.253513,1408.936775,819981.5,1.051611,,0.0
min,1420.806152,2887.367726,0.036113,0.096386,0.003947,4105.170683,499.139786,464266.9,0.143988,-inf,0.0
25%,1817.082916,2914.278508,0.076295,0.462349,0.010123,4674.683107,646.239637,598861.4,0.184715,-inf,0.0
50%,3317.779236,2925.720581,0.094451,0.614458,0.019801,5169.490462,782.903417,750140.8,0.231948,-inf,0.0
75%,5394.937988,2935.5296,0.183643,0.659137,0.029012,5779.540035,1525.352381,1208205.0,0.775077,-inf,0.0
max,7520.602051,2957.541225,0.43693,0.702811,0.033788,6694.779116,3489.807785,2263152.0,2.303495,-inf,0.0


In [6]:
# Inspect the rows whose SE is positive infinity.
se_is_pos_inf = voice_noise_df['SE'].eq(np.inf)
voice_noise_df[se_is_pos_inf].describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,number
count,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0
mean,8431.458496,inf,0.109108,0.23996,0.017978,4092.934237,932.592424,477436.448851,2.142212,-0.288774,43.0
std,8117.857228,,0.033473,0.319476,0.005182,672.408143,242.518121,269602.15415,0.152605,0.172011,60.811183
min,2691.266602,inf,0.085439,0.014056,0.014314,3617.46988,761.106216,286798.937429,2.034304,-0.410404,0.0
25%,5561.362549,inf,0.097273,0.127008,0.016146,3855.202058,846.84932,382117.69314,2.088258,-0.349589,21.5
50%,8431.458496,inf,0.109108,0.23996,0.017978,4092.934237,932.592424,477436.448851,2.142212,-0.288774,43.0
75%,11301.554443,inf,0.120942,0.352912,0.019811,4330.666416,1018.335528,572755.204562,2.196166,-0.227959,64.5
max,14171.650391,inf,0.132777,0.465863,0.021643,4568.398594,1104.078631,668073.960273,2.25012,-0.167144,86.0


In [7]:
# Turn the non-finite sentinels (-inf in RSE, +inf in SE) into NaN and drop those rows.
# Built as a new frame with assign(): the previous chained `Series.replace(..., inplace=True)`
# operated on a column of a filtered frame, which raises SettingWithCopyWarning and does
# nothing at all when pandas Copy-on-Write is enabled.
voice_noise_df = voice_noise_df.assign(
    RSE=voice_noise_df['RSE'].replace(-np.inf, np.nan),
    SE=voice_noise_df['SE'].replace(np.inf, np.nan),
).dropna(subset=['RSE', 'SE'])
voice_noise_df.describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,number
count,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0
mean,4027.278925,3440.216378,0.106159,0.33698,0.018741,4454.803985,757.361517,686367.7,0.271188,-0.247531,42.879073
std,2470.380094,269.45641,0.053189,0.217444,0.007833,886.237323,433.883156,592648.5,1.490381,0.07061,34.673384
min,0.49616,2839.630946,5e-05,0.0,0.00153,1753.576807,55.599252,6881.946,-32.282923,-0.583062,0.0
25%,2264.211731,3257.935306,0.070601,0.120482,0.01365,3645.456827,477.840842,296252.1,-0.410489,-0.299206,15.0
50%,3342.738647,3404.285039,0.098514,0.389558,0.017666,4319.088855,671.758851,528360.8,0.312929,-0.253445,34.0
75%,5294.77478,3587.097837,0.132314,0.508032,0.022452,5097.624874,947.213312,887771.1,1.020803,-0.186386,65.0
max,27287.986328,6376.835783,0.860161,0.993976,0.092539,7535.391566,7007.05765,7441943.0,51.500733,-0.027293,207.0


In [8]:
# Encode the string class names in `type` as integers; `le.classes_` lists the names in
# the order of their integer codes.
# assign() returns a new frame, so no value is written into a filtered slice of the
# original data (the direct column assignment triggered SettingWithCopyWarning).
le = LabelEncoder()
voice_noise_df = voice_noise_df.assign(type=le.fit_transform(voice_noise_df["type"]))
list(le.classes_)


['music', 'noise', 'speech']

In [9]:
# Summary statistics again; `type` is numeric after label encoding, so it now appears as a column.
voice_noise_df.describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,type,number
count,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0,76840.0
mean,4027.278925,3440.216378,0.106159,0.33698,0.018741,4454.803985,757.361517,686367.7,0.271188,-0.247531,1.162142,42.879073
std,2470.380094,269.45641,0.053189,0.217444,0.007833,886.237323,433.883156,592648.5,1.490381,0.07061,0.960448,34.673384
min,0.49616,2839.630946,5e-05,0.0,0.00153,1753.576807,55.599252,6881.946,-32.282923,-0.583062,0.0,0.0
25%,2264.211731,3257.935306,0.070601,0.120482,0.01365,3645.456827,477.840842,296252.1,-0.410489,-0.299206,0.0,15.0
50%,3342.738647,3404.285039,0.098514,0.389558,0.017666,4319.088855,671.758851,528360.8,0.312929,-0.253445,2.0,34.0
75%,5294.77478,3587.097837,0.132314,0.508032,0.022452,5097.624874,947.213312,887771.1,1.020803,-0.186386,2.0,65.0
max,27287.986328,6376.835783,0.860161,0.993976,0.092539,7535.391566,7007.05765,7441943.0,51.500733,-0.027293,2.0,207.0


In [10]:
# Number of non-null entries per column for each encoded class label.
voice_noise_df.groupby('type').count()

Unnamed: 0_level_0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,name,number
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
0,30221,30221,30221,30221,30221,30221,30221,30221,30221,30221,30221,30221
1,3939,3939,3939,3939,3939,3939,3939,3939,3939,3939,3939,3939
2,42680,42680,42680,42680,42680,42680,42680,42680,42680,42680,42680,42680


In [55]:
# Per-class statistics for label 0 ('music' in the encoder's class list).
music_rows = voice_noise_df["type"] == 0
voice_noise_df[music_rows].describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,type,number
count,30221.0,30221.0,30221.0,30221.0,30221.0,30221.0,30221.0,30221.0,30221.0,30221.0,30221.0,30221.0
mean,5387.262578,3411.997153,0.094444,0.145269,0.020558,4702.142313,606.957559,683252.2,0.131743,-0.188543,0.0,28.421396
std,2808.979193,254.35002,0.046912,0.149633,0.009523,945.739368,326.185968,623577.8,1.460622,0.046116,0.0,24.580952
min,0.500887,2888.534515,0.0019,0.0,0.001992,2679.028614,57.676472,41190.78,-14.455547,-0.387462,0.0,0.0
25%,3340.237305,3242.426032,0.061488,0.026104,0.013889,3959.839357,396.600224,196530.1,-0.592873,-0.215385,0.0,11.0
50%,4944.258789,3376.188593,0.085964,0.100402,0.018895,4706.325301,543.68779,499275.8,0.206022,-0.182518,0.0,23.0
75%,7294.654297,3541.40048,0.120464,0.220884,0.025305,5461.157129,744.175755,975981.1,0.956641,-0.154172,0.0,38.0
max,17142.080078,5093.415907,0.720872,0.937751,0.086942,7488.014558,6361.525399,6210831.0,10.424738,-0.063611,0.0,196.0


In [56]:
# Per-class statistics for label 1 ('noise' in the encoder's class list).
noise_rows = voice_noise_df["type"] == 1
voice_noise_df[noise_rows].describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,type,number
count,3939.0,3939.0,3939.0,3939.0,3939.0,3939.0,3939.0,3939.0,3939.0,3939.0,3939.0,3939.0
mean,3369.012304,3275.565459,0.145967,0.21543,0.013245,5292.193065,1076.163003,1193787.0,0.43968,-0.180767,1.0,9.146484
std,3198.695913,248.314363,0.129469,0.283063,0.011207,1046.941712,1073.866213,1286374.0,3.367188,0.042567,0.0,13.7202
min,0.49616,2839.630946,5e-05,0.0,0.00153,1753.576807,55.599252,6881.946,-32.282923,-0.583062,1.0,0.0
25%,1146.377258,3117.061103,0.057419,0.0,0.005833,4477.660643,344.047161,288158.0,-0.143104,-0.193973,1.0,1.0
50%,2446.968994,3261.25445,0.103951,0.042169,0.00969,5234.626004,669.553227,635645.0,0.504699,-0.173504,1.0,4.0
75%,4873.787109,3396.0598,0.195684,0.411647,0.017645,6208.082329,1441.45716,1669975.0,1.278514,-0.163857,1.0,11.0
max,27287.986328,6376.835783,0.860161,0.993976,0.092539,7535.391566,7007.05765,7441943.0,51.500733,-0.027293,1.0,97.0


In [57]:
# Per-class statistics for label 2 ('speech' in the encoder's class list).
speech_rows = voice_noise_df["type"] == 2
voice_noise_df[speech_rows].describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,type,number
count,42680.0,42680.0,42680.0,42680.0,42680.0,42680.0,42680.0,42680.0,42680.0,42680.0,42680.0,42680.0
mean,3125.049455,3475.393819,0.11078,0.483945,0.017962,4202.383949,834.437406,641743.1,0.354376,-0.295461,2.0,56.229545
std,1514.147894,273.714578,0.041517,0.113539,0.005378,712.634883,354.754532,422580.9,1.191223,0.046161,0.0,35.889229
min,0.498347,2894.223681,0.0001,0.0,0.001992,3142.382028,77.455102,59564.13,-15.620562,-0.496837,2.0,0.0
25%,2063.22699,3286.078776,0.079463,0.413655,0.014067,3576.24247,571.234417,337753.8,-0.320243,-0.324219,2.0,25.0
50%,2794.189453,3440.527066,0.106939,0.485944,0.017389,4022.998243,786.179359,535196.4,0.35849,-0.291457,2.0,53.0
75%,3835.129639,3634.419471,0.136564,0.554217,0.021263,4742.595382,1031.449498,822812.6,1.045511,-0.263614,2.0,85.0
max,9767.736328,4971.523754,0.38028,0.985944,0.067104,7382.028112,4047.635507,5280397.0,12.758825,-0.046602,2.0,207.0


In [58]:
# Restrict to label 2 (speech), then show the single row with the smallest RMS.
is_speech = voice_noise_df["type"] == 2
slice_df = voice_noise_df[is_speech]
quietest_index = slice_df['RMS'].idxmin()
slice_df.loc[quietest_index]

RMS                       0.498347
SE                         2901.19
ZCR                      0.0300504
LEFR                             0
SF                      0.00199168
SRF                        7382.03
SC                         3978.85
BW                      5.2804e+06
NWPD                    -0.0210402
RSE                      -0.168358
type                             2
name      speech-librivox-0150.wav
number                         157
Name: 46668, dtype: object

In [65]:
# Speech windows below RMS 103.6 (annotated as -50dB in the original cell), loudest first.
below_threshold = slice_df['RMS'] < 103.6
slice_df[below_threshold].sort_values(by='RMS', ascending=False)

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,type,name,number
60219,103.12056,2987.351805,0.087939,0.012048,0.00788,5190.700301,545.865228,703929.4,2.045611,-0.177452,2,speech-us-gov-0098.wav,0
51388,102.030716,3308.936383,0.106314,0.558233,0.009584,4221.26004,786.415287,455914.0,-0.417626,-0.22206,2,speech-us-gov-0021.wav,75
59660,100.534538,3393.851997,0.080314,0.014056,0.008532,5253.012048,488.529448,619960.7,0.493476,-0.176372,2,speech-us-gov-0093.wav,36
49489,97.205772,3609.073264,0.102289,0.0,0.00776,3364.897088,622.113843,463617.5,-0.938354,-0.176897,2,speech-us-gov-0004.wav,31
59904,95.927872,3290.071136,0.090114,0.004016,0.008055,5108.74749,513.250355,588306.7,0.555993,-0.179184,2,speech-us-gov-0095.wav,42
60028,94.502502,3205.345319,0.079638,0.0,0.008916,5220.444277,457.004325,504189.9,0.125435,-0.173533,2,speech-us-gov-0096.wav,47
60789,93.907173,3436.654402,0.096001,0.018072,0.006671,5185.554719,612.439902,725098.5,-0.577803,-0.183607,2,speech-us-gov-0102.wav,94
60027,93.900169,3318.29973,0.102576,0.0,0.007084,5101.342871,598.744948,665634.4,0.0555,-0.180742,2,speech-us-gov-0096.wav,46
58523,89.184189,3435.77694,0.121339,0.385542,0.008513,4519.076305,899.405759,535545.3,-1.085438,-0.223437,2,speech-us-gov-0083.wav,36
57075,79.643112,3264.715642,0.135452,0.343373,0.007033,4211.408133,1015.997488,556093.2,0.600275,-0.216334,2,speech-us-gov-0070.wav,60


In [63]:
# The 100 speech windows with the largest SE values.
speech_by_se = slice_df.sort_values(by='SE', ascending=False)
speech_by_se.head(100)

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,type,name,number
70923,2901.761230,4971.523754,0.049988,0.411647,0.024633,3713.792671,368.878227,248202.004100,3.208899,-0.256656,2,speech-us-gov-0192.wav,27
73997,6542.815430,4915.859237,0.083064,0.528112,0.022799,3543.047189,591.584774,367184.801826,1.854023,-0.320512,2,speech-us-gov-0219.wav,98
64160,2468.087891,4914.577551,0.080176,0.399598,0.020653,3247.615462,562.785705,245384.816764,-0.102498,-0.271981,2,speech-us-gov-0133.wav,34
69444,2560.391602,4869.550351,0.073313,0.389558,0.019592,3547.377008,490.835985,301816.276242,-0.261435,-0.243363,2,speech-us-gov-0179.wav,54
74613,3381.856689,4851.552997,0.106814,0.552209,0.019100,3552.899096,806.592253,407849.366768,1.637455,-0.316093,2,speech-us-gov-0224.wav,114
62769,4697.837402,4773.863135,0.125302,0.526104,0.013397,3518.135040,878.051943,561470.494683,-0.119965,-0.287992,2,speech-us-gov-0120.wav,87
73947,6560.067871,4760.942331,0.093414,0.497992,0.018815,3487.700803,686.332193,336958.142630,0.369715,-0.296347,2,speech-us-gov-0219.wav,48
68979,2920.130371,4726.519611,0.109514,0.277108,0.015881,3435.366466,818.596762,319792.802767,-1.211309,-0.359717,2,speech-us-gov-0175.wav,15
74012,6802.117676,4720.636815,0.088364,0.580321,0.018792,3414.784137,642.026511,336272.203575,-0.282158,-0.284705,2,speech-us-gov-0219.wav,113
73441,5764.408691,4703.371263,0.085064,0.363454,0.021404,3570.092871,585.864204,418175.735914,1.866210,-0.273281,2,speech-us-gov-0215.wav,7


In [66]:
# Drop silent windows: keep rows at or above the RMS 103.6 cut-off, then count rows per class.
is_audible = voice_noise_df['RMS'].ge(103.6)
filtered_df = voice_noise_df[is_audible]
filtered_df.groupby('type').count()

Unnamed: 0_level_0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,name,number
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
0,30182,30182,30182,30182,30182,30182,30182,30182,30182,30182,30182,30182
1,3866,3866,3866,3866,3866,3866,3866,3866,3866,3866,3866,3866
2,42630,42630,42630,42630,42630,42630,42630,42630,42630,42630,42630,42630


In [67]:
# Summary statistics after the silent windows were removed.
filtered_df.describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,type,number
count,76678.0,76678.0,76678.0,76678.0,76678.0,76678.0,76678.0,76678.0,76678.0,76678.0,76678.0,76678.0
mean,4035.68495,3440.687911,0.10615,0.337235,0.018755,4452.002178,755.914856,683891.2,0.271199,-0.24768,1.162341,42.885247
std,2466.20201,269.354468,0.053056,0.217247,0.007817,883.578553,429.142322,585792.9,1.48869,0.070572,0.960853,34.661247
min,104.475784,2839.630946,5e-05,0.0,0.00153,1753.576807,55.599252,6881.946,-32.282923,-0.583062,0.0,0.0
25%,2271.308105,3258.408742,0.070651,0.12249,0.013668,3644.954819,477.796551,296085.8,-0.410469,-0.299289,0.0,15.0
50%,3347.783569,3404.700165,0.098539,0.389558,0.017675,4316.923946,671.556575,527783.5,0.312983,-0.253665,2.0,34.0
75%,5300.560059,3587.518634,0.132314,0.508032,0.022458,5093.734312,946.644763,886438.9,1.020808,-0.186634,2.0,65.0
max,27287.986328,6376.835783,0.860161,0.993976,0.092539,7535.391566,7007.05765,7441943.0,51.500733,-0.027293,2.0,207.0


In [68]:
# Drop music: it is label 0 in the encoder's class list, so labels >= 1 keep noise and speech only.
is_not_music = filtered_df['type'].ge(1)
music_dropped_df = filtered_df[is_not_music]

In [69]:
# Model inputs: every column except RMS, SE, the label and the bookkeeping columns.
excluded_columns = ['RMS', 'SE', 'type', 'name', 'number']
X = music_dropped_df.drop(excluded_columns, axis='columns')
y = music_dropped_df['type']
# 10 % hold-out with a fixed seed so the split is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.1,
    random_state=666,
)

In [70]:
# Depth-3 decision tree. random_state is pinned because scikit-learn permutes the features
# at every split; without a fixed seed, ties between equally good splits can be resolved
# differently from run to run and the exported tree would not be reproducible.
classifier = tree.DecisionTreeClassifier(max_depth=3, random_state=666)
classifier = classifier.fit(X_train, y_train)

In [71]:
# Hold-out accuracy: share of test rows whose predicted label equals the true label.
prediction = classifier.predict(X_test)
hits = np.equal(prediction, y_test).astype(np.float32)
print(np.mean(hits))

0.9815053939819336


## For testing new audio

In [72]:
# `join` imported here is also used by later cells in this notebook.
from os.path import dirname,abspath,join
# Built from the current working directory, so the kernel must be started in the
# directory that contains `src/`.
TEST_AUDIO_FOLDER = join(os.getcwd(),'src','data','testwav')
TEST_AUDIO_FOLDER

'C:\\Users\\tianr\\Programming\\Python\\Project Speaker Recog\\speaker_recognition_GMM_UBM\\src\\data\\testwav'

In [73]:
def create_dataset(DATA_FOLDER,WINDOW_LENGTH = 5,FRAME_LENGTH = 25):
    """Cut every matching audio file under ``DATA_FOLDER`` into windows and extract features.

    Files are found recursively with ``os.walk``; a file is processed when its name
    contains one of ``"noise"``, ``"music"``, ``"speech"`` or ``"audio"``. Each file is
    read with ``scipy.io.wavfile.read`` and split into consecutive, non-overlapping
    windows of ``WINDOW_LENGTH`` seconds. A trailing remainder shorter than one window
    is discarded.

    Parameters
    ----------
    DATA_FOLDER : str
        Root directory to scan.
    WINDOW_LENGTH : int or float, default 5
        Window duration in seconds (multiplied by the file's sampling rate).
    FRAME_LENGTH : int, default 25
        Passed unchanged to ``extract_features``
        (presumably a frame size in milliseconds - TODO confirm against that module).

    Returns
    -------
    pandas.DataFrame
        One row per window with the columns
        ``RMS, SE, ZCR, LEFR, SF, SRF, SC, BW, NWPD, RSE, type, name, number``.
        ``type`` is the part of the file name before the first ``"-"``, ``name`` is the
        file name and ``number`` is the zero-based window index within the file.

    Raises
    ------
    ValueError
        If ``WINDOW_LENGTH`` yields a window of zero or fewer samples (this previously
        caused an endless loop).
    """
    feature_name = "RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,type,name,number".split(",")
    features_dict = {feature: [] for feature in feature_name}
    name_markers = ("noise", "music", "speech", "audio")

    for root, _dirs, files in os.walk(DATA_FOLDER):
        for audio in files:
            if not any(marker in audio for marker in name_markers):
                continue
            print("****************************")
            print("reading:", audio)
            # os.path.join keeps the function independent of the separately imported `join`.
            sampling_rate, sig = wav.read(os.path.join(root, audio))
            print("sampling rate:", sampling_rate, "signal length", len(sig))

            # int() lets fractional window lengths work as slice bounds.
            window_size = int(sampling_rate * WINDOW_LENGTH)
            if window_size <= 0:
                raise ValueError(
                    "WINDOW_LENGTH must yield at least one sample per window, got {!r}".format(WINDOW_LENGTH)
                )

            # The `+ 1` keeps the last window when the signal length is an exact multiple
            # of the window size (the former strict `<` comparison dropped it).
            window_starts = range(0, len(sig) - window_size + 1, window_size)
            for number, index in enumerate(window_starts):
                sample = sig[index:index + window_size]
                ef = extract_features(sample, FRAME_LENGTH, sampling_rate)
                rms, se, zcr, lefr, sf, srf, sc, bd, nwpd, rse = ef.return_()
                # The first four values are stored as returned; the rest are averaged.
                features_dict["RMS"].append(rms)
                features_dict["SE"].append(se)
                features_dict["ZCR"].append(zcr)
                features_dict["LEFR"].append(lefr)
                features_dict["SF"].append(np.mean(sf))
                features_dict["SC"].append(np.mean(sc))
                features_dict["SRF"].append(np.mean(srf))
                features_dict["BW"].append(np.mean(bd))
                features_dict["NWPD"].append(np.mean(nwpd))
                features_dict["RSE"].append(np.mean(rse))
                features_dict["type"].append(audio.split("-")[0])
                features_dict["name"].append(audio)
                features_dict["number"].append(number)

    features_df = pd.DataFrame.from_dict(features_dict)
    # Fix the column order explicitly.
    return features_df[feature_name]

In [74]:
# Extract features for the recordings in TEST_AUDIO_FOLDER using the default window and frame lengths.
test_df = create_dataset(TEST_AUDIO_FOLDER)

****************************
reading: audiotest08-06-2018-13-12-39.wav
sampling rate: 16000 signal length 84480
****************************
reading: audiotest08-06-2018-16-49-40.wav
sampling rate: 16000 signal length 84480
****************************
reading: audiotest08-06-2018-17-01-12.wav
sampling rate: 16000 signal length 84480
****************************
reading: audiotest08-06-2018-17-01-26.wav
sampling rate: 16000 signal length 84480
****************************
reading: audiotest08-06-2018-17-01-43.wav
sampling rate: 16000 signal length 84480
****************************
reading: audiotest08-06-2018-17-01-56.wav
sampling rate: 16000 signal length 84480
****************************
reading: audiotest08-06-2018-17-02-11.wav
sampling rate: 16000 signal length 84480
****************************
reading: audiotest08-06-2018-17-02-44.wav
sampling rate: 16000 signal length 84480


In [75]:
# Summary statistics of the features extracted from the test recordings.
test_df.describe()

Unnamed: 0,RMS,SE,ZCR,LEFR,SF,SRF,SC,BW,NWPD,RSE,number
count,8.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0,8.0
mean,227.739219,3067.51495,0.138881,0.43248,0.019258,5694.967369,1023.248453,1693560.0,0.820804,-0.272634,0.0
std,83.481081,183.753294,0.050681,0.193466,0.006145,618.055152,337.02392,1003788.0,1.163321,0.038125,0.0
min,97.444267,2938.700814,0.085739,0.062249,0.008497,4864.646084,641.419533,668675.2,-0.566549,-0.337949,0.0
25%,179.243271,2961.295149,0.104339,0.351908,0.015903,5216.365462,798.55177,1071056.0,0.244993,-0.304492,0.0
50%,223.084564,2993.866023,0.128089,0.478916,0.019909,5574.171687,957.427323,1303808.0,0.517611,-0.256206,0.0
75%,271.824837,3096.953598,0.155996,0.542671,0.022354,6265.358308,1106.870288,2038066.0,1.194605,-0.246046,0.0
max,353.374756,3490.060427,0.239553,0.696787,0.027444,6469.879518,1700.323351,3614656.0,3.289868,-0.235335,0.0


In [76]:
# Select exactly the columns the classifier was trained on, in training order.
# Unlike dropping a hard-coded list, this is safe to re-run (the drop raised KeyError on
# a second execution) and cannot drift out of sync with the training feature set.
test_df = test_df[list(X_train.columns)]
test_predictions = classifier.predict(test_df)
print(test_predictions)

[1 2 2 2 2 2 2 2]


In [78]:
from sklearn.tree import _tree
import json

# Destination of the exported tree (written next to the test recordings).
JSON_FILE_NAME=join(TEST_AUDIO_FOLDER,'tree_model.json')
# Take the feature names from the training frame instead of a hand-maintained string,
# so node labels always match the columns the classifier was fitted on.
feature_names = list(X_train.columns)
tree_ = classifier.tree_


# Per-node feature label; nodes without a split feature (leaves) get a placeholder.
feature_name = [
    feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
    for i in tree_.feature
]
print("def tree({}):".format(", ".join(feature_names)))

# Root dictionary that recurse() fills in place.
json_file = {}

def recurse(node, depth, json_file):
    """Print the subtree rooted at ``node`` as pseudo code and mirror it into ``json_file``.

    Reads the module-level ``tree_`` (fitted tree arrays), ``feature_name`` (per-node
    feature labels) and ``_tree.TREE_UNDEFINED``.

    Parameters
    ----------
    node : int
        Index of the current node in the arrays of ``tree_``.
    depth : int
        Nesting level; only used to indent the printed pseudo code.
    json_file : dict
        Filled in place with the keys ``feature``, ``threshold``, ``decision``,
        ``left`` and ``right``. Existing ``left``/``right`` sub-dictionaries are reused.

    Returns
    -------
    dict
        The ``json_file`` object that was passed in (for leaves and split nodes alike).
    """
    indent = "  " * depth

    if tree_.feature[node] == _tree.TREE_UNDEFINED:
        # Leaf: no split feature, so record the decision only.
        print("{}return {}".format(indent, tree_.value[node]))
        # "True" when the second entry of the node's value array is the largest.
        # NOTE(review): that entry presumably corresponds to the speech class - confirm
        # against classifier.classes_.
        json_file["decision"] = str(np.argmax(tree_.value[node]) == 1)
        json_file["threshold"] = 0.0
        json_file["feature"] = None
        json_file["left"] = None
        json_file["right"] = None
        return json_file

    name = feature_name[node]
    threshold = tree_.threshold[node]
    json_file["feature"] = name
    json_file["threshold"] = threshold
    json_file["decision"] = None

    print("{}if {} <= {}:".format(indent, name, threshold))
    # setdefault creates the child dictionary only when it is missing; this replaces a
    # bare `except:` that would also have hidden unrelated errors.
    recurse(tree_.children_left[node], depth + 1, json_file.setdefault("left", {}))

    print("{}else:  # if {} > {}".format(indent, name, threshold))
    recurse(tree_.children_right[node], depth + 1, json_file.setdefault("right", {}))
    return json_file

# Walk the tree from the root (node 0), show the resulting structure, then persist it.
recurse(0, 1, json_file)
serialized_tree = json.dumps(json_file, sort_keys=True, indent=4)
print(serialized_tree)
with open(JSON_FILE_NAME, "w") as json_handle:
    json.dump(json_file, json_handle)


def tree(ZCR, LEFR, SF, SRF, SC, BW, NWPD, RSE):
  if RSE <= -0.20199668407440186:
    if SF <= 0.005660966970026493:
      if SRF <= 4844.87939453125:
        return [[ 1. 11.]]
      else:  # if SRF > 4844.87939453125
        return [[200.  11.]]
    else:  # if SF > 0.005660966970026493
      if LEFR <= 0.07630522549152374:
        return [[74. 10.]]
      else:  # if LEFR > 0.07630522549152374
        return [[  429. 37865.]]
  else:  # if RSE > -0.20199668407440186
    if RSE <= -0.18610996007919312:
      if BW <= 386300.25:
        return [[ 35. 220.]]
      else:  # if BW > 386300.25
        return [[302.  53.]]
    else:  # if RSE > -0.18610996007919312
      if LEFR <= 0.11144578456878662:
        return [[1798.   54.]]
      else:  # if LEFR > 0.11144578456878662
        return [[651. 132.]]
{
    "decision": null,
    "feature": "RSE",
    "left": {
        "decision": null,
        "feature": "SF",
        "left": {
            "decision": null,
            "feature": "SRF