In [1]:
import pandas as pd
import matplotlib.pylab as plt

from swat import *
from dlpy import Model, Sequential
from dlpy.model import Optimizer, MomentumSolver, AdamSolver, Gpu
from dlpy.layers import * 
from dlpy.speech import *
from dlpy.splitting import two_way_split
from dlpy.metrics import (accuracy_score, confusion_matrix, plot_roc, 
                          plot_precision_recall, roc_auc_score, f1_score, average_precision_score)
%matplotlib inline

In [2]:
# Raise the pandas display limits (rows, columns, and per-cell text width) to 500
# for every table rendered in this notebook.
pd.options.display.max_rows = 500
pd.options.display.max_columns = 500
pd.options.display.max_colwidth = 500

# Load the processed data into CAS

In [3]:
# Open the CAS session used by every later cell (bound to `conn`).
#
# The original cell passed the bare name `port_number`, which no visible cell
# defines, so a fresh "Restart & Run All" stopped here with a NameError.
# The connection settings are now read from the environment so the notebook
# runs unattended and no deployment details are hard-coded:
#   CASHOST - CAS server host name (falls back to the 'host_name' placeholder)
#   CASPORT - CAS server port (falls back to 5570)
# NOTE(review): 5570 is the customary CAS binary port - confirm for your deployment.
import os

cas_host = os.environ.get('CASHOST', 'host_name')
cas_port = int(os.environ.get('CASPORT', '5570'))
conn = CAS(cas_host, cas_port)

In [4]:
from dlpy.audio import AudioTable 

In [5]:
# A sashdat file that contains the audio data is generated beforehand with AudioTable.load_audio_files.
# AudioTable.from_audio_sashdat then loads that sashdat file into CAS from a server-side data path.
train_audio = AudioTable.from_audio_sashdat(conn, '/data_dir/data/spoken_language_identification/train_wav.sashdat')

In [6]:
# Show the column metadata (name, type, lengths) of the loaded training audio table.
train_audio.columnInfo()

Unnamed: 0,Column,Label,ID,Type,RawLength,FormattedLength,Format,NFL,NFD
0,_id_,,1,int64,8,12,,0,0
1,_path_,,2,varchar,120,120,,0,0
2,_type_,,3,char,8,8,,0,0
3,_audio_,,4,varbinary(sound),320046,320046,,0,0


In [7]:
# Report how many rows the loaded training audio table has.
train_audio.numrows()

In [8]:
# Generate MFCC features for the training audio.
# n_output_frames=1000 yields 1000 frames per clip; the resulting table has 40 values per frame
# (columns _f0_v0_ ... _f999_v39_), i.e. 40000 feature columns per row.
# NOTE(review): label_level=-2 presumably takes the label from the second-to-last component of
# _path_ (the language directory, e.g. ".../train/en/<file>.wav" -> "en") - confirm in the DLPy docs.
train_feature_table = train_audio.create_audio_feature_table(label_level=-2, n_output_frames=1000, random_shuffle=True)

In [9]:
# Preview the first two rows of the training feature table.
train_feature_table.fetch(to=2)

Unnamed: 0,_path_,_num_frames_,_f0_v0_,_f0_v1_,_f0_v2_,_f0_v3_,_f0_v4_,_f0_v5_,_f0_v6_,_f0_v7_,_f0_v8_,_f0_v9_,_f0_v10_,_f0_v11_,_f0_v12_,_f0_v13_,_f0_v14_,_f0_v15_,_f0_v16_,_f0_v17_,_f0_v18_,_f0_v19_,_f0_v20_,_f0_v21_,_f0_v22_,_f0_v23_,_f0_v24_,_f0_v25_,_f0_v26_,_f0_v27_,_f0_v28_,_f0_v29_,_f0_v30_,_f0_v31_,_f0_v32_,_f0_v33_,_f0_v34_,_f0_v35_,_f0_v36_,_f0_v37_,_f0_v38_,_f0_v39_,_f1_v0_,_f1_v1_,_f1_v2_,_f1_v3_,_f1_v4_,_f1_v5_,_f1_v6_,_f1_v7_,_f1_v8_,_f1_v9_,_f1_v10_,_f1_v11_,_f1_v12_,_f1_v13_,_f1_v14_,_f1_v15_,_f1_v16_,_f1_v17_,_f1_v18_,_f1_v19_,_f1_v20_,_f1_v21_,_f1_v22_,_f1_v23_,_f1_v24_,_f1_v25_,_f1_v26_,_f1_v27_,_f1_v28_,_f1_v29_,_f1_v30_,_f1_v31_,_f1_v32_,_f1_v33_,_f1_v34_,_f1_v35_,_f1_v36_,_f1_v37_,_f1_v38_,_f1_v39_,_f2_v0_,_f2_v1_,_f2_v2_,_f2_v3_,_f2_v4_,_f2_v5_,_f2_v6_,_f2_v7_,_f2_v8_,_f2_v9_,_f2_v10_,_f2_v11_,_f2_v12_,_f2_v13_,_f2_v14_,_f2_v15_,_f2_v16_,_f2_v17_,_f2_v18_,_f2_v19_,_f2_v20_,_f2_v21_,_f2_v22_,_f2_v23_,_f2_v24_,_f2_v25_,_f2_v26_,_f2_v27_,_f2_v28_,_f2_v29_,_f2_v30_,_f2_v31_,_f2_v32_,_f2_v33_,_f2_v34_,_f2_v35_,_f2_v36_,_f2_v37_,_f2_v38_,_f2_v39_,_f3_v0_,_f3_v1_,_f3_v2_,_f3_v3_,_f3_v4_,_f3_v5_,_f3_v6_,_f3_v7_,_f3_v8_,_f3_v9_,_f3_v10_,_f3_v11_,_f3_v12_,_f3_v13_,_f3_v14_,_f3_v15_,_f3_v16_,_f3_v17_,_f3_v18_,_f3_v19_,_f3_v20_,_f3_v21_,_f3_v22_,_f3_v23_,_f3_v24_,_f3_v25_,_f3_v26_,_f3_v27_,_f3_v28_,_f3_v29_,_f3_v30_,_f3_v31_,_f3_v32_,_f3_v33_,_f3_v34_,_f3_v35_,_f3_v36_,_f3_v37_,_f3_v38_,_f3_v39_,_f4_v0_,_f4_v1_,_f4_v2_,_f4_v3_,_f4_v4_,_f4_v5_,_f4_v6_,_f4_v7_,_f4_v8_,_f4_v9_,_f4_v10_,_f4_v11_,_f4_v12_,_f4_v13_,_f4_v14_,_f4_v15_,_f4_v16_,_f4_v17_,_f4_v18_,_f4_v19_,_f4_v20_,_f4_v21_,_f4_v22_,_f4_v23_,_f4_v24_,_f4_v25_,_f4_v26_,_f4_v27_,_f4_v28_,_f4_v29_,_f4_v30_,_f4_v31_,_f4_v32_,_f4_v33_,_f4_v34_,_f4_v35_,_f4_v36_,_f4_v37_,_f4_v38_,_f4_v39_,_f5_v0_,_f5_v1_,_f5_v2_,_f5_v3_,_f5_v4_,_f5_v5_,_f5_v6_,_f5_v7_,_f5_v8_,_f5_v9_,_f5_v10_,_f5_v11_,_f5_v12_,_f5_v13_,_f5_v14_,_f5_v15_,_f5_v16_,_f5_v17_,_f5_v18_,_f5_v19_,_f5_v20_,_f5_v21_,_f5_v22_,_f5_v23_,_f5_v24_,_f5_
v25_,_f5_v26_,_f5_v27_,_f5_v28_,_f5_v29_,_f5_v30_,_f5_v31_,_f5_v32_,_f5_v33_,_f5_v34_,_f5_v35_,_f5_v36_,_f5_v37_,_f5_v38_,_f5_v39_,_f6_v0_,_f6_v1_,_f6_v2_,_f6_v3_,_f6_v4_,_f6_v5_,_f6_v6_,_f6_v7_,...,_f993_v32_,_f993_v33_,_f993_v34_,_f993_v35_,_f993_v36_,_f993_v37_,_f993_v38_,_f993_v39_,_f994_v0_,_f994_v1_,_f994_v2_,_f994_v3_,_f994_v4_,_f994_v5_,_f994_v6_,_f994_v7_,_f994_v8_,_f994_v9_,_f994_v10_,_f994_v11_,_f994_v12_,_f994_v13_,_f994_v14_,_f994_v15_,_f994_v16_,_f994_v17_,_f994_v18_,_f994_v19_,_f994_v20_,_f994_v21_,_f994_v22_,_f994_v23_,_f994_v24_,_f994_v25_,_f994_v26_,_f994_v27_,_f994_v28_,_f994_v29_,_f994_v30_,_f994_v31_,_f994_v32_,_f994_v33_,_f994_v34_,_f994_v35_,_f994_v36_,_f994_v37_,_f994_v38_,_f994_v39_,_f995_v0_,_f995_v1_,_f995_v2_,_f995_v3_,_f995_v4_,_f995_v5_,_f995_v6_,_f995_v7_,_f995_v8_,_f995_v9_,_f995_v10_,_f995_v11_,_f995_v12_,_f995_v13_,_f995_v14_,_f995_v15_,_f995_v16_,_f995_v17_,_f995_v18_,_f995_v19_,_f995_v20_,_f995_v21_,_f995_v22_,_f995_v23_,_f995_v24_,_f995_v25_,_f995_v26_,_f995_v27_,_f995_v28_,_f995_v29_,_f995_v30_,_f995_v31_,_f995_v32_,_f995_v33_,_f995_v34_,_f995_v35_,_f995_v36_,_f995_v37_,_f995_v38_,_f995_v39_,_f996_v0_,_f996_v1_,_f996_v2_,_f996_v3_,_f996_v4_,_f996_v5_,_f996_v6_,_f996_v7_,_f996_v8_,_f996_v9_,_f996_v10_,_f996_v11_,_f996_v12_,_f996_v13_,_f996_v14_,_f996_v15_,_f996_v16_,_f996_v17_,_f996_v18_,_f996_v19_,_f996_v20_,_f996_v21_,_f996_v22_,_f996_v23_,_f996_v24_,_f996_v25_,_f996_v26_,_f996_v27_,_f996_v28_,_f996_v29_,_f996_v30_,_f996_v31_,_f996_v32_,_f996_v33_,_f996_v34_,_f996_v35_,_f996_v36_,_f996_v37_,_f996_v38_,_f996_v39_,_f997_v0_,_f997_v1_,_f997_v2_,_f997_v3_,_f997_v4_,_f997_v5_,_f997_v6_,_f997_v7_,_f997_v8_,_f997_v9_,_f997_v10_,_f997_v11_,_f997_v12_,_f997_v13_,_f997_v14_,_f997_v15_,_f997_v16_,_f997_v17_,_f997_v18_,_f997_v19_,_f997_v20_,_f997_v21_,_f997_v22_,_f997_v23_,_f997_v24_,_f997_v25_,_f997_v26_,_f997_v27_,_f997_v28_,_f997_v29_,_f997_v30_,_f997_v31_,_f997_v32_,_f997_v33_,_f997_v34_,_f997_v35_,_f997_v36_,_f997_v37_,_f997_v38_,_f99
7_v39_,_f998_v0_,_f998_v1_,_f998_v2_,_f998_v3_,_f998_v4_,_f998_v5_,_f998_v6_,_f998_v7_,_f998_v8_,_f998_v9_,_f998_v10_,_f998_v11_,_f998_v12_,_f998_v13_,_f998_v14_,_f998_v15_,_f998_v16_,_f998_v17_,_f998_v18_,_f998_v19_,_f998_v20_,_f998_v21_,_f998_v22_,_f998_v23_,_f998_v24_,_f998_v25_,_f998_v26_,_f998_v27_,_f998_v28_,_f998_v29_,_f998_v30_,_f998_v31_,_f998_v32_,_f998_v33_,_f998_v34_,_f998_v35_,_f998_v36_,_f998_v37_,_f998_v38_,_f998_v39_,_f999_v0_,_f999_v1_,_f999_v2_,_f999_v3_,_f999_v4_,_f999_v5_,_f999_v6_,_f999_v7_,_f999_v8_,_f999_v9_,_f999_v10_,_f999_v11_,_f999_v12_,_f999_v13_,_f999_v14_,_f999_v15_,_f999_v16_,_f999_v17_,_f999_v18_,_f999_v19_,_f999_v20_,_f999_v21_,_f999_v22_,_f999_v23_,_f999_v24_,_f999_v25_,_f999_v26_,_f999_v27_,_f999_v28_,_f999_v29_,_f999_v30_,_f999_v31_,_f999_v32_,_f999_v33_,_f999_v34_,_f999_v35_,_f999_v36_,_f999_v37_,_f999_v38_,_f999_v39_,_fName_,_label_
0,/data/spoken_language_identification/train/en/en_m_81995ee8a5e990193b7858ec4b158e48.fragment29.speed3.wav,998,-0.401382,-0.0905,0.499974,-0.863431,0.062703,-0.155239,-1.193368,-1.516721,-0.771279,-0.089794,1.066978,0.730874,0.716292,0.260554,0.589915,1.592134,0.270134,-0.18505,0.204198,1.226295,0.998201,0.265763,-0.196391,0.668495,0.554744,0.339948,0.344902,-0.232378,0.318514,-0.022201,-0.587462,0.347319,1.61093,1.238718,1.61172,1.155736,1.48392,0.234506,1.245488,-0.17371,-0.533552,-0.161163,0.789961,-0.594621,0.149567,-0.321517,-0.691131,-1.198514,-1.406269,-0.947985,0.512904,0.21167,0.548767,0.446316,0.77785,1.512466,0.466725,-1.838404,-0.535647,0.273908,-0.607209,-0.924573,-0.968981,-0.786672,-0.410684,0.017856,1.254016,0.757635,-0.036202,-1.366315,-0.486148,-0.434969,0.657506,-1.047745,-0.532606,-1.045389,0.436845,0.455426,1.298858,0.017757,-0.729107,-0.357094,0.907247,-0.44567,0.265761,-0.268973,-0.821363,-0.92043,-0.905536,-0.449521,0.617802,-0.319385,0.54112,0.274803,0.669846,1.854313,-0.341296,-1.629043,-1.821463,-0.964932,0.649594,-0.593339,-0.479689,-0.226621,1.105314,-0.310908,0.924652,1.899225,2.136959,0.124222,0.395439,-1.954394,-0.929972,-0.850254,-0.786531,0.383144,1.457303,1.585484,0.129589,-0.489631,-0.713769,-0.264721,1.088414,-0.20135,0.616712,0.072594,-0.744468,-0.59375,-0.402944,0.324578,0.706088,0.297363,1.113343,0.307367,0.079054,0.725078,-0.30056,-1.184934,-1.533819,-0.459755,1.759617,-0.341924,0.68583,0.154507,0.267625,-0.448941,0.16327,1.320231,-0.005,-0.771534,-0.177095,-1.411141,-0.489507,-0.34208,0.041218,-0.090423,1.422985,0.286583,0.262448,-0.737011,-0.219911,0.483709,0.693597,-0.697917,-0.733186,1.236604,1.582004,0.298755,0.814567,0.041269,-0.472485,0.447362,-1.430058,0.30076,1.032918,-1.323783,-0.039781,1.556521,-1.126116,-0.535371,0.533851,-0.413649,1.106613,0.679872,-1.006909,0.661321,-0.752,0.004164,1.332849,-0.806989,0.482545,-0.071032,-0.52447,-0.123292,-0.223254,-0.432232,-0.361364,-0.140044,0.135359,-0.459584,-0.00061,0.2047
01,-0.42714,-2.224644,-2.323942,0.230041,2.563873,0.175922,0.195622,-0.828701,-2.200543,1.029942,0.770062,1.516353,1.118874,-0.404669,-0.730828,1.978918,-0.409425,1.013658,1.057568,-0.464217,-0.096576,-0.147099,-0.018418,0.180824,-1.739154,-0.4233,-0.527497,0.83548,-0.526007,0.262476,0.430227,-0.127376,0.750469,-0.313372,0.117933,0.581653,0.474612,-0.270367,0.313315,0.590989,0.27538,-1.434505,-1.675759,0.787974,3.295438,0.205652,...,0.188981,-0.0219,0.457061,0.194625,1.214503,0.733196,0.906545,-0.068872,0.833874,0.963735,0.475659,-0.513223,-1.353145,-1.121056,0.425471,1.470345,0.749463,-0.665421,-0.12064,0.296586,-0.56539,-0.636219,-0.141742,0.743943,-1.275848,-1.728568,0.528237,0.095615,0.251202,0.148895,0.745939,1.265329,0.756757,0.304348,2.111539,-1.155278,0.038701,1.104133,0.0156,0.769958,-0.082938,-1.270421,0.55744,-0.273531,-0.815186,0.247041,1.854825,-1.858873,0.992324,1.101764,0.515853,-0.139459,-2.155543,-1.848257,-0.369203,2.273232,0.387775,0.331158,0.168417,0.530757,0.342249,0.162043,0.147225,0.637539,-3.587652,-1.368598,1.447147,0.023902,0.971137,-0.104893,1.496636,1.580972,-0.17192,-0.427363,2.702132,-1.482821,-1.112856,0.387454,0.077003,1.115102,-1.216358,-0.920266,0.289974,-0.747657,-0.985983,-0.148524,-0.190957,-2.119234,1.007737,0.961921,0.581527,-0.020005,-2.115873,-2.402903,-0.296,1.932455,0.043789,0.083608,0.303158,0.673758,0.805532,0.147438,0.422615,0.482776,-3.47992,-0.17147,1.084502,-1.179877,1.875524,1.117714,1.497774,1.334282,-1.076616,-0.576721,1.906674,-2.214571,-1.724966,-1.245445,-1.030289,0.203237,-0.817721,-0.385704,0.033037,-0.06773,0.136348,-1.017381,-0.796381,-1.433372,0.954707,0.78735,0.718754,-0.265642,-2.078419,-1.862943,-0.233338,1.78804,-0.264371,0.937119,0.659766,-0.029541,0.489509,-0.612226,0.167401,0.203247,-3.394292,-0.533802,0.147672,-2.574462,0.995788,0.604719,0.965778,0.928301,-1.41038,-1.01966,1.280889,-2.420313,-1.269905,-0.905756,-0.293689,0.989802,0.412759,0.351098,-0.286522,-0.401607,-0.460747,-0.948241,-1.061584,-0
.864115,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,en_m_81995ee8a5e990193b7858ec4b158e48.fragment29.speed3.wav,en
1,/data/spoken_language_identification/train/es/es_f_db16e07d9bfc39deb1792ec84d9e61fe.fragment13.speed2.wav,998,0.758823,1.289812,-0.792434,-1.850318,-0.31023,-0.664289,-0.76221,0.54326,-0.114935,-1.958132,2.32448,1.277998,-0.710826,-1.655149,-0.820182,1.24693,-0.087125,-0.273885,1.817423,-0.770753,-0.519106,-0.084212,0.888045,1.224435,2.265453,-1.323783,1.294395,0.449034,-1.072501,0.255899,-2.172779,-1.737266,-1.137583,-0.697628,-0.519776,0.860098,-0.454468,-0.124529,1.493866,0.295573,0.769017,1.310299,-0.915534,-1.819716,-0.520683,-0.310699,-1.148315,-0.133478,0.082597,-2.486916,1.357335,1.318095,-0.537119,-1.933477,-1.027397,1.498715,0.117499,-0.669832,1.15609,-0.88542,-0.264786,-0.117169,0.81411,-0.125023,2.058067,-0.108528,0.955846,0.764034,-0.53335,0.821312,0.032678,-1.2615,-1.436406,-1.891779,-1.797789,0.765365,-0.080715,0.57091,2.527094,0.790733,0.73398,1.47034,-0.798325,-1.640666,-0.43447,0.213661,-0.858561,-0.32982,0.487013,-2.405541,1.718601,1.466591,0.225019,-0.947975,-1.000089,1.632211,0.953967,-0.428788,0.810583,-0.304791,0.142696,-0.939375,1.318141,-0.326614,2.281747,0.083635,1.16123,0.919736,0.356314,1.188861,2.074399,-0.716603,-0.275756,-1.270666,-2.580058,0.798489,-0.512941,-0.0643,1.607258,0.804814,0.674764,1.628726,-0.72664,-1.310286,-0.523921,0.605517,-0.676964,-1.758683,-0.013384,-1.850511,0.924876,0.263631,0.851436,-0.551739,-1.245427,1.239578,1.120852,-0.07602,-0.110467,-0.104601,-0.062349,-0.252166,0.316135,-0.909251,0.989025,1.299435,1.108688,0.902897,0.441613,0.123964,0.460765,-0.533091,-0.256526,-0.443429,-1.49158,-1.258049,-1.812418,-2.04602,-0.221143,-0.068378,0.67303,1.729223,-0.311056,-0.882198,-0.234481,1.464942,-0.396817,-2.176959,-0.744548,-1.912855,1.239668,0.098117,0.678322,0.41344,0.312666,1.027563,0.968561,1.266985,0.570457,-0.38357,0.147764,1.224055,1.689897,-0.66728,1.277213,0.319321,1.47128,1.810762,0.85915,0.340321,1.01199,-0.721222,-0.316742,0.329919,-0.224404,0.110883,-1.897244,-0.83687,-0.59763,0.079325,0.668951,1.768769
,0.024995,-0.60731,-0.727141,1.181433,-0.20045,-2.056345,-1.183103,-2.411107,1.59613,0.139199,-0.353184,-0.285819,-0.661837,0.807536,0.804287,0.786408,0.652174,0.148894,0.018268,0.37734,1.977884,-0.977112,1.161786,0.145835,0.804273,1.163927,0.800884,2.080672,1.369019,0.14258,-0.771961,-0.003097,0.047081,0.346421,-1.060031,-0.799647,-0.676854,0.018505,0.685833,1.900413,0.241678,-0.625197,-1.142341,1.634518,0.121638,-2.284338,...,-0.771239,-0.900528,-0.754799,-1.564539,-1.977688,-1.182362,-1.71208,-1.080225,0.58925,0.079478,0.716974,1.290385,-0.618795,-1.924797,-1.392908,-0.129718,-0.464201,-1.455773,0.606796,1.545465,0.020859,0.991509,-0.757697,1.534717,-0.230189,-0.721851,-1.612286,1.989713,0.158195,-0.725397,2.639496,0.521604,-0.146678,0.253398,0.399241,0.529692,2.135937,1.038928,0.427965,-0.601985,-1.63451,-0.987409,-0.923253,-0.966782,-1.234906,-1.367018,-0.548863,0.107806,0.554594,0.059341,0.695055,1.565249,0.014018,-1.764103,-1.326569,-0.213259,-0.189263,-0.863233,0.56289,1.591204,-0.12011,1.11775,-0.527016,0.906876,-0.095988,-0.173561,-0.267228,1.638619,-1.233882,-1.193214,2.591806,1.340976,0.308897,0.564614,0.895941,0.390066,1.882648,0.277218,0.545511,-0.714779,-1.995053,-0.999731,-1.792493,-0.698841,-1.878931,-0.97376,-0.195174,-0.623257,0.489821,-0.055877,0.665002,1.437672,0.075234,-1.457325,-1.134994,-0.516025,-0.624658,-0.999319,0.260354,1.216774,-0.140291,0.805745,-0.133481,0.903927,-0.60701,-0.543236,0.02935,0.732903,-0.622446,-1.016634,2.209512,2.238872,0.660367,0.467141,1.82592,1.504455,1.99322,-0.621676,0.892596,-0.454533,-1.556856,-0.612576,-0.515915,-1.21799,-1.17473,-0.349424,-0.097358,-1.094416,0.407726,-0.123734,0.680123,1.434053,0.047698,-0.959694,-0.718586,0.105179,-0.52398,-1.342333,0.226525,0.739828,-0.713778,0.246308,-0.887197,1.092469,-0.844808,-0.484486,0.492091,0.144346,-0.725626,-1.872338,2.122771,2.318206,-0.279019,0.52287,2.533697,1.03061,1.161286,-0.720772,1.601728,0.342754,-1.853476,-0.882844,-0.532791,-0.057035,-1.102324,-1.140502,
-0.354166,-0.261096,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,es_f_db16e07d9bfc39deb1792ec84d9e61fe.fragment13.speed2.wav,es


In [10]:
# Show the number of training samples per language label.
train_feature_table.label_freq

Unnamed: 0,Level,Frequency
de,1,24360
en,2,24360
es,3,24360


In [11]:
# Load the test audio from its sashdat file, using the same loader as for the training audio.
test_audio = AudioTable.from_audio_sashdat(conn, '/data_dir/data/spoken_language_identification/test_wav.sashdat')

In [12]:
# Generate the test features with the same label level and frame count as the training features.
# random_shuffle is not passed here, so it keeps its default value.
test_feature_table = test_audio.create_audio_feature_table(label_level=-2, n_output_frames=1000)

In [13]:
# Show the number of test samples per language label.
test_feature_table.label_freq

Unnamed: 0,Level,Frequency
de,1,180
en,2,180
es,3,180


# Build the CNN+RNN Model

In [14]:
# Build the CNN + RNN model: three Conv/BN/Pooling stages feed a two-layer LSTM,
# followed by a 3-way output layer.

# Weight initializer shared by the convolution and recurrent layers.
init = 'msra'

model_cnnrnn = Sequential(conn=conn, model_table='cnnrnn')

# Input is treated as a single-channel map: w40 * h1000.
model_cnnrnn.add(InputLayer(name='input', n_channels=1, width=40, height=1000))

# A reshape layer is required to convert the sequence input into a fixed-size tensor for the CNN layers.
model_cnnrnn.add(Reshape(width=40, height=1000, depth=1))

# Three stages that differ only in their filter count. Per the original size notes,
# the map shrinks after each stage: w40*h1000 -> w20*h500 -> w10*h250 -> w5*h125 (d64).
for stage_filters in (16, 32, 64):
    model_cnnrnn.add(Conv2D(n_filters=stage_filters, width=3, height=3, stride=1,
                            init=init, act='identity', include_bias=False))
    model_cnnrnn.add(BN())
    model_cnnrnn.add(Pooling(width=3, height=3, stride=2))

# Reshape the fixed-size tensor (w5*h125*d64) back into sequence data:
# 125 tokens, each of size 5*64.
model_cnnrnn.add(Reshape(order='DWH', width=5*64, height=125, depth=1))

# Two stacked LSTM layers: the first uses output_type='samelength', the second 'encoding'.
for lstm_output_type in ('samelength', 'encoding'):
    model_cnnrnn.add(Recurrent(n=50, init=init, rnn_type='lstm',
                               output_type=lstm_output_type, dropout=0.8))

# Output layer with three units (the tables above list three labels: de, en, es).
model_cnnrnn.add(OutputLayer(n=3, name='output'))

NOTE: Input layer added.
NOTE: Reshape layer added.
NOTE: Convolution layer added.
NOTE: Batch normalization layer added.
NOTE: Pooling layer added.
NOTE: Convolution layer added.
NOTE: Batch normalization layer added.
NOTE: Pooling layer added.
NOTE: Convolution layer added.
NOTE: Batch normalization layer added.
NOTE: Pooling layer added.
NOTE: Reshape layer added.
NOTE: Recurrent layer added.
NOTE: Recurrent layer added.
NOTE: Output layer added.
NOTE: Model compiled successfully.


In [15]:
#%debug

In [None]:
#model_cnnrnn.plot_network()

In [17]:
# Keep the list of feature-column names of the training table for inspection.
myvars = train_feature_table.feature_vars

In [18]:
# Number of feature columns: 1000 tokens (frames) * 40 values per token = 40000.
len(myvars)

40000

In [19]:
#myvars

# Train the Model

In [20]:
# NOTE(review): these wildcard imports presumably supply ReduceLROnPlateau here and
# DataSpec / DataSpecNumNomOpts for the data-spec cell below - confirm, and consider
# moving explicit imports to the first cell.
from dlpy.lr_scheduler import *
from dlpy.model import *

# Learning-rate schedule, starting at 0.01.
# NOTE(review): presumably the rate is multiplied by `gamma` after `patience` epochs
# without improvement, then held for `cool_down_iters` - confirm in the DLPy docs.
lr_scheduler = ReduceLROnPlateau(
    conn=conn,
    cool_down_iters=5,
    gamma=0.1,
    learning_rate=0.01,
    patience=3,
)

# Momentum solver driven by the schedule above, with gradients clipped to [-100, 100].
# DLPy reports that the solver's own learning_rate/gamma/step_size/power arguments are
# overwritten by the ones specified in lr_scheduler.
solver = MomentumSolver(
    lr_scheduler=lr_scheduler,
    clip_grad_max=100,
    clip_grad_min=-100,
)

# Training options: mini-batch size 2, up to 60 epochs, L2 regularization 0.0005.
optimizer = Optimizer(
    algorithm=solver,
    mini_batch_size=2,
    log_level=2,
    max_epochs=60,
    reg_l2=0.0005,
)

The following argument(s) learning_rate, gamma, step_size, power are overwritten by the according arguments specified in lr_scheduler.


In [21]:
# Sequence options for the input: the per-row frame-count column gives the sequence
# length, and feature_size gives the size of each token.
sequence_opts = DataSpecNumNomOpts(
    length=train_feature_table.num_of_frames_col,
    token_size=train_feature_table.feature_size,
)

# Bind all feature columns to the layer named 'input'.
data_specs_input = DataSpec(
    layer='input',
    type_='numnom',
    data=train_feature_table.feature_vars,
    numeric_nominal_parms=sequence_opts,
)

# Bind the label column to the output layer.
# NOTE(review): the layer is referenced as 'Output' although it was added with
# name='output'; presumably the name is matched case-insensitively - confirm.
data_specs_output = DataSpec(layer='Output', type_='numnom', data='_label_')

In [22]:
# Train the model on the training feature table, using the input/output data specs
# and the optimizer defined in the preceding cells.
model_cnnrnn.fit(data=train_feature_table,
              data_specs=[data_specs_input, data_specs_output],
              # The training log reports 16 threads per worker, each with mini-batch size 2
              # (a maximum synchronous mini-batch size of 32).
              n_threads=16,
              # Fixed seeds for this run.
              seed=12598,
              record_seed=13544,
              # NOTE(review): the training log reports 1 of 4 GPUs in use, so `devices=2`
              # presumably selects a device ID rather than a device count - confirm.
              gpu=dict(devices=2),
              # The log confirms "Training from scratch".
              train_from_scratch=True,
              force_equal_padding=True,
              optimizer=optimizer
             )

NOTE: Training from scratch.
NOTE: Using dlgrd009.unx.sas.com: 1 out of 4 available GPU devices.
NOTE:  Synchronous mode is enabled.
NOTE:  The total number of parameters is 118361.
NOTE:  The approximate memory cost is 171.00 MB.
NOTE:  Loading weights cost       0.00 (s).
NOTE:  Initializing each layer cost       7.62 (s).
NOTE:  The total number of threads on each worker is 16.
NOTE:  The total mini-batch size per thread on each worker is 2.
NOTE:  The maximum mini-batch size across all workers for the synchronous mode is 32.
NOTE:  Target variable: _label_
NOTE:  Number of levels for the target variable:      3
NOTE:  Levels for the target variable:
NOTE:  Level      0: de
NOTE:  Level      1: en
NOTE:  Level      2: es
NOTE:  Number of input variables: 40000
NOTE:  Number of numeric input variables:  40000
NOTE:  Epoch Learning Rate        Loss  Fit Error   Time(s)
NOTE:  0          0.01           1.105     0.6569    49.76
NOTE:  1          0.01           1.089       0.62    49.30

Unnamed: 0,Descr,Value
0,Model Name,cnnrnn
1,Model Type,Recurrent Neural Network
2,Number of Layers,15
3,Number of Input Layers,1
4,Number of Output Layers,1
5,Number of Convolutional Layers,3
6,Number of Pooling Layers,3
7,Number of Fully Connected Layers,0
8,Number of Recurrent Layers,2
9,Number of Batch Normalization Layers,3

Unnamed: 0,Epoch,LearningRate,Loss,FitError,L2Norm
0,1,0.01,1.104924,0.656852,0.633947
1,2,0.01,1.088826,0.620033,0.405222
2,3,0.01,1.04225,0.561871,0.275051
3,4,0.01,0.940514,0.465152,0.202303
4,5,0.01,0.912761,0.443671,0.158927
5,6,0.01,0.800655,0.357651,0.13142
6,7,0.01,0.701642,0.297053,0.123122
7,8,0.01,0.552791,0.222047,0.129162
8,9,0.01,0.422153,0.163857,0.138897
9,10,0.01,0.374011,0.145947,0.144743

Unnamed: 0,casLib,Name,Rows,Columns,casTable
0,CASUSER(username),cnnrnn_weights,118585,3,"CASTable('cnnrnn_weights', caslib='CASUSER(username)')"


# Assess the Model

In [23]:
# Score the test feature table as a classification task; the output below reports the
# misclassification error and the loss, plus the CAS table holding the scored rows.
model_cnnrnn.evaluate(data=test_feature_table, gpu=dict(devices=2), model_task='CLASSIFICATION')

NOTE: Due to data distribution, miniBatchSize has been limited to 9.
NOTE: Using dlgrd009.unx.sas.com: 1 out of 4 available GPU devices.


Unnamed: 0,Descr,Value
0,Number of Observations Read,540.0
1,Number of Observations Used,540.0
2,Misclassification Error (%),2.592593
3,Loss Error,0.072839

Unnamed: 0,casLib,Name,Rows,Columns,casTable
0,CASUSER(username),Valid_Res_0rCOAU,540,40010,"CASTable('Valid_Res_0rCOAU', caslib='CASUSER(username)')"


In [24]:
# Show the confusion matrix stored on the model (one row per true label).
# NOTE(review): presumably populated by the evaluate() call above, with Col1..Col3
# following the de/en/es label order - confirm.
model_cnnrnn.valid_conf_mat

Unnamed: 0,_label_,Col1,Col2,Col3
0,de,172.0,3.0,5.0
1,en,3.0,176.0,1.0
2,es,2.0,0.0,178.0


In [25]:
# Keep a handle to the scored-results table stored on the model; the metric cells below read it.
# NOTE(review): presumably the Valid_Res_* table produced by the evaluate() call - confirm.
test_result_table = model_cnnrnn.valid_res_tbl

In [26]:
# Show the confusion matrix computed from the '_label_' and 'I__label_' columns of the
# scored test table (presumably the true and the predicted label, respectively).
display(confusion_matrix(test_result_table['_label_'], test_result_table['I__label_']))

Unnamed: 0_level_0,de,en,es
_label_,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
de,172.0,3.0,5.0
en,3.0,176.0,1.0
es,2.0,0.0,178.0


In [27]:
# Calculate the average (unweighted, per-class) f1 score: each of the three classes
# is taken in turn as the positive label and the three scores are averaged.
actual_labels = test_result_table['_label_']
predicted_labels = test_result_table['I__label_']

per_class_f1 = [f1_score(actual_labels, predicted_labels, pos_label=class_pos)
                for class_pos in (0, 1, 2)]

# `f1` holds the sum of the per-class scores; the mean is taken when printing.
f1 = sum(per_class_f1)
print('the f1 score is {:.6f}'.format(f1/3))

the f1 score is 0.974036


In [28]:
# Calculate the accuracy score from the '_label_' and 'I__label_' columns of the
# scored test table, then print it with six decimals.
actual_labels = test_result_table['_label_']
predicted_labels = test_result_table['I__label_']

acc_score = accuracy_score(actual_labels, predicted_labels)
print('the accuracy score is {:.6f}'.format(acc_score))

the accuracy score is 0.974074


# Deploy the Model

In [29]:
# Save the trained model in 'table' format under the given path.
# NOTE(review): for the table format the path is presumably resolved on the CAS server - confirm.
model_cnnrnn.deploy(output_format='table', 
                      path='/data/spoken_language_identification')

NOTE: Model table saved successfully.


In [32]:
# Save the trained model as an astore file under the given path.
# NOTE(review): this path is Windows-style, whereas the 'table' deployment uses a POSIX-style
# path; presumably the astore is written on the client machine and the table on the server -
# confirm against the DLPy deploy documentation and adjust the path for your environment.
model_cnnrnn.deploy(output_format='astore', 
                      path=r'\data\DeepLearn\data\spoken_language_identification')

NOTE: Model astore file saved successfully.


In [33]:
# End the CAS session opened at the top of the notebook; `conn` cannot be used for further actions afterwards.
conn.endsession()

In [34]:
#conn.shutdown()