In [1]:
from TIMIT.lightning_model_uncertainty_loss import LightningModel
import os
import pandas as pd
import numpy as np
import torch
import torchaudio
%matplotlib inline
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, accuracy_score
from IPython import embed

################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################



Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.



In [2]:
# Checkpoint of the trained model that this notebook evaluates.
model_checkpoint = 'checkpoints/epoch=21-step=10779-v26.ckpt'

# Restore the model, place it on the GPU and switch it to evaluation mode.
model = LightningModel.load_from_checkpoint(model_checkpoint)
model = model.to('cuda')
model = model.eval()
print()  # emits a blank line; also keeps the model repr from being echoed as cell output

Using cache found in /root/.cache/torch/hub/s3prl_cache/3a990c945fbe378df95598eec534e91ba22a5d9eab0b2f88777a7a696d1344e9
for https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt
Model Details: #Params = 164364294	#Trainable Params = 164364294



In [3]:
# model.model.transformer_encoder_M.layers[-5]

In [4]:
# Sorted directory listings of the phone-annotation files and the test audio files.
phnFiles = sorted(os.listdir('./TIMIT_Dataset/wav_data/phn'))
testWavFiles = sorted(os.listdir('./TIMIT_Dataset/wav_data/TEST'))
print(len(phnFiles), len(testWavFiles))

6301 1680


In [5]:
# Sanity check: first few entries of the sorted annotation listing.
phnFiles[:5]

['.ipynb_checkpoints',
 'FADG0_SA1.PHN',
 'FADG0_SA2.PHN',
 'FADG0_SI1279.PHN',
 'FADG0_SI1909.PHN']

In [6]:
# Sanity check: first few entries of the sorted test-audio listing.
testWavFiles[:5]

['FADG0_SA1.WAV',
 'FADG0_SA2.WAV',
 'FADG0_SI1279.WAV',
 'FADG0_SI1909.WAV',
 'FADG0_SI649.WAV']

In [7]:
# Pair every test utterance with its phone-annotation file.
# Each entry is [wav_path, phn_path]; utterances without a matching .PHN file
# are skipped.
#
# Membership is checked against a set: testing `in` on the full list for every
# wav file made the original loop quadratic in the number of files.
phnFileSet = set(phnFiles)

testPhnWavFiles = []
for testWavFile in testWavFiles:
    # Annotation files share the utterance stem and use a '.PHN' extension.
    phnName = os.path.splitext(testWavFile)[0] + '.PHN'
    if phnName in phnFileSet:
        testPhnWavFiles.append(['./TIMIT_Dataset/wav_data/TEST/' + testWavFile,
                                './TIMIT_Dataset/wav_data/phn/' + phnName])
len(testPhnWavFiles)

1680

In [8]:
# Sanity check: first few [wav_path, phn_path] pairs.
testPhnWavFiles[:5]

[['./TIMIT_Dataset/wav_data/TEST/FADG0_SA1.WAV',
  './TIMIT_Dataset/wav_data/phn/FADG0_SA1.PHN'],
 ['./TIMIT_Dataset/wav_data/TEST/FADG0_SA2.WAV',
  './TIMIT_Dataset/wav_data/phn/FADG0_SA2.PHN'],
 ['./TIMIT_Dataset/wav_data/TEST/FADG0_SI1279.WAV',
  './TIMIT_Dataset/wav_data/phn/FADG0_SI1279.PHN'],
 ['./TIMIT_Dataset/wav_data/TEST/FADG0_SI1909.WAV',
  './TIMIT_Dataset/wav_data/phn/FADG0_SI1909.PHN'],
 ['./TIMIT_Dataset/wav_data/TEST/FADG0_SI649.WAV',
  './TIMIT_Dataset/wav_data/phn/FADG0_SI649.PHN']]

# Phone map

In [9]:
# Phone symbols, grouped one category per row.
# Reference: https://catalog.ldc.upenn.edu/docs/LDC93S1/PHONCODE.TXT
symbolList = [
    # stops
    'b', 'd', 'g', 'p', 't', 'k', 'dx', 'q',
    # affricates
    'jh', 'ch',
    # fricatives
    's', 'sh', 'z', 'zh', 'f', 'th', 'v', 'dh',
    # nasals
    'm', 'n', 'ng', 'em', 'en', 'eng', 'nx',
    # semivowels and glides
    'l', 'r', 'w', 'y', 'hh', 'hv', 'el',
    # vowels
    'iy', 'ih', 'eh', 'ey', 'ae', 'aa', 'aw', 'ay', 'ah', 'ao',
    'oy', 'ow', 'uh', 'uw', 'ux', 'er', 'ax', 'ix', 'axr', 'ax-h',
    # others
    'pau', 'epi', 'h#', '1', '2',
    # "*cl" labels
    # NOTE(review): 'tck' looks like a typo (every other entry in this row ends
    # in 'cl' and 'tcl' is already present) - confirm against PHONCODE.TXT.
    # Kept as-is so the list content and length are unchanged.
    'bcl', 'dcl', 'gcl', 'pcl', 'tck', 'kcl', 'tcl',
]

In [10]:
# Number of entries in the phone symbol list.
len(symbolList)

64

In [11]:
from IPython import embed

# Making Plots

In [17]:
# Phone symbols split by category; each list can be used to select the
# segments that get masked before inference.
# Reference: https://catalog.ldc.upenn.edu/docs/LDC93S1/PHONCODE.TXT
stopsSymbolList = ['b', 'd', 'g', 'p', 't', 'k', 'dx', 'q']

# The comma between the two entries matters: adjacent string literals are
# concatenated by Python, so ['ch' 'jh'] is the one-element list ['chjh'],
# which never matches a label and turns affricate masking into a no-op.
affricatesSymbolList = ['ch', 'jh']

fricativesSymbolList = ['s', 'sh', 'z', 'zh', 'f', 'th', 'v', 'dh']

nasalsSymbolList = ['m', 'n', 'ng', 'em', 'en', 'eng', 'nx']

semivowelsGlidesSymbolList = ['l', 'r', 'w', 'y', 'hh', 'hv', 'el']

vowelsSymbolList = [
    'iy', 'ih', 'eh', 'ey', 'ae', 'aa', 'aw', 'ay', 'ah', 'ao',
    'oy', 'ow', 'uh', 'uw', 'ux', 'er', 'ax', 'ix', 'axr', 'ax-h',
]

othersSymbolList = ['pau', 'epi', 'h#', '1', '2']

In [18]:
# Speaker metadata table.
df = pd.read_csv('Dataset/data_info_height_age.csv')

# Mean / standard deviation of height and age over the rows marked 'TRN';
# these are used later to de-normalise the model outputs.
h_mean = df.loc[df['Use'] == 'TRN', 'height'].mean()
h_std = df.loc[df['Use'] == 'TRN', 'height'].std()
a_mean = df.loc[df['Use'] == 'TRN', 'age'].mean()
a_std = df.loc[df['Use'] == 'TRN', 'age'].std()

# Index by the 'ID' column so rows can be fetched with df.loc[<id>, <column>].
df = df.set_index('ID')
gender_dict = {'M': 0, 'F': 1}

# Per-gender means over the 'TRN' rows.
h_mean_male = df.loc[(df['Use'] == 'TRN') & (df['Sex'] == 'M'), 'height'].mean()
a_mean_male = df.loc[(df['Use'] == 'TRN') & (df['Sex'] == 'M'), 'age'].mean()

h_mean_female = df.loc[(df['Use'] == 'TRN') & (df['Sex'] == 'F'), 'height'].mean()
a_mean_female = df.loc[(df['Use'] == 'TRN') & (df['Sex'] == 'F'), 'age'].mean()

In [19]:
# Mean height and mean age over the 'TRN' rows: male line first, then female.
print(h_mean_male, a_mean_male)
print(h_mean_female, a_mean_female)

179.84134969325135 30.393343558282208
165.36147058823522 29.80727941176471


In [46]:
# Run the model over every test utterance with one phone category silenced,
# collecting de-normalised predictions and the ground truth per utterance.

# Phone symbols whose segments are zeroed before inference. Swap in another
# *SymbolList (or an empty set for the unmasked baseline) to run a different
# ablation. A set gives constant-time membership tests.
maskedSymbols = set(vowelsSymbolList)

# Feed inputs on whatever device the model lives on instead of assuming 'cuda'.
device = next(model.parameters()).device

height_pred = []
height_true = []
age_pred = []
age_true = []
gender_pred = []
gender_true = []

for wavPath, phnPath in testPhnWavFiles:
    wav, _ = torchaudio.load(wavPath)
    wav_len = [wav.shape[1]]
    wav = wav.to(device)

    # Each annotation line is parsed as "<lo> <hi> <symbol>"; lo/hi are used
    # as sample indices into the first channel of the waveform.
    with open(phnPath) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank line (e.g. trailing newline at end of file)
            lo, hi, symbol = fields  # still raises on a malformed line
            if symbol in maskedSymbols:
                wav[0, int(lo):int(hi)] = 0

    # Inference only: skip building the autograd graph for every utterance.
    with torch.no_grad():
        y_hat_h, y_hat_a, y_hat_g = model(wav, wav_len)
    y_hat_h = y_hat_h.cpu()
    y_hat_a = y_hat_a.cpu()
    y_hat_g = y_hat_g.cpu()

    # Undo the training-set normalisation of the regression targets.
    height_pred.append((y_hat_h * h_std + h_mean).item())
    age_pred.append((y_hat_a * a_std + a_mean).item())
    gender_pred.append(y_hat_g > 0.5)

    # File names look like '<G><speaker>_<utterance>.WAV'; dropping the first
    # character of the part before '_' gives the key used to index df.
    # Derived from the base name so it does not depend on the directory length.
    speaker_id = os.path.basename(wavPath).split('_')[0][1:]

    height_true.append(df.loc[speaker_id, 'height'])
    age_true.append(df.loc[speaker_id, 'age'])
    gender_true.append(gender_dict[df.loc[speaker_id, 'Sex']])

In [47]:
# Per-gender regression errors for height and age.
# gender_true holds 0 for male and 1 for female (see gender_dict).
female_idx = np.flatnonzero(np.array(gender_true) == 1).tolist()
male_idx = np.flatnonzero(np.array(gender_true) == 0).tolist()

height_true = np.array(height_true)
height_pred = np.array(height_pred)
age_true = np.array(age_true)
age_pred = np.array(age_pred)

# One printed line per gender, male first then female, in the order:
#   height RMSE, height MAE, age RMSE, age MAE.
# RMSE is computed as sqrt(MSE) rather than via `squared=False`, which is
# deprecated and removed in newer scikit-learn releases.
for idx in (male_idx, female_idx):
    hmae = mean_absolute_error(height_true[idx], height_pred[idx])
    hrmse = np.sqrt(mean_squared_error(height_true[idx], height_pred[idx]))
    amae = mean_absolute_error(age_true[idx], age_pred[idx])
    armse = np.sqrt(mean_squared_error(age_true[idx], age_pred[idx]))
    print(hrmse, hmae, armse, amae)

7.380669872114995 5.655118870326452 7.763985580250762 5.016331099169594
6.34729016434969 4.955917169843401 8.071302909828747 6.04548358127049


In [34]:
# Recorded evaluation results, one array per masking condition.
# vowelMask matches the two lines printed by the per-gender metrics cell, i.e.
# the layout is [height RMSE, height MAE, age RMSE, age MAE] for male speakers
# followed by the same four values for female speakers. The other arrays are
# presumably pasted from re-runs with a different masked category - TODO confirm.

# With no masking:
noMask = np.array([7.233415244325654, 5.551037321908133, 5.5894605275559694, 4.0340787514959064,
6.3426227432005735, 5.0193275048392145, 6.700610621889334, 4.668831104278564,])

# With vowel masking:
vowelMask = np.array([7.380669872114995, 5.655118870326452, 7.763985580250762, 5.016331099169594,
6.34729016434969, 4.955917169843401, 8.071302909828747, 6.04548358127049,])

# With nasal masking:
nasalMask = np.array([7.19552751521195, 5.525854961395263, 5.729919756395839 ,4.090414264610836,
6.347995603789645, 5.0413198852539045, 6.682625703454871, 4.752221087455749,])

# With semi vowel masking:
semiVowelMask = np.array([7.265708043165038, 5.558168871743338, 6.271451785077046, 4.247193532875605,
6.322400202082834, 5.0032027882167265, 6.655142298550157, 4.861911989484514,])

# With affricates masking:
# NOTE(review): these values are identical to noMask, i.e. affricate masking
# had no effect in the recorded run. Verify that the affricate list contained
# 'ch' and 'jh' as two separate entries when this was produced, and re-run.
affricatesMask = np.array([7.233415244325654, 5.551037321908133, 5.5894605275559694, 4.0340787514959064,
6.3426227432005735, 5.0193275048392145, 6.700610621889334, 4.668831104278564,])

# With fricatives masking:
fricativesMask = np.array([7.326146191429742, 5.660307633536203, 5.929453835115323, 4.276259066513607,
6.524805962265495, 5.1516073869977665, 6.53886097498481, 4.760440237317766,])

# With stop masking:
stopMask = np.array([7.276793178466237, 5.570184124537876, 5.820922343011691, 4.0579382479531425,
6.526454619386045, 5.149866971697126, 6.904805373957733, 4.968452154976981,])

# With other masking:
otherMask = np.array([7.310862375889352, 5.584817414419992, 5.915792169549828, 4.154623309748513,
6.158481982347664, 4.923283308846609, 7.51586216541237, 5.009144041606358,])


In [55]:
# Percent change of every vowel-masked metric relative to the unmasked
# baseline, rounded to two decimals. Positions 0, 4, 2, 6 are printed
# (male height RMSE, female height RMSE, male age RMSE, female age RMSE).
a = np.round((vowelMask - noMask) / noMask * 100, 2)
print(*a[[0, 4, 2, 6]])

2.04 0.07 38.9 20.46


In [56]:
# Percent change of every nasal-masked metric relative to the unmasked
# baseline, rounded to two decimals. Positions 0, 4, 2, 6 are printed
# (presumably male/female height RMSE, then male/female age RMSE - confirm array layout).
a = np.round((nasalMask - noMask) / noMask * 100, 2)
print(*a[[0, 4, 2, 6]])

-0.52 0.08 2.51 -0.27


In [57]:
# Percent change of every semivowel-masked metric relative to the unmasked
# baseline, rounded to two decimals. Positions 0, 4, 2, 6 are printed
# (presumably male/female height RMSE, then male/female age RMSE - confirm array layout).
a = np.round((semiVowelMask - noMask) / noMask * 100, 2)
print(*a[[0, 4, 2, 6]])

0.45 -0.32 12.2 -0.68


In [58]:
# Percent change of every affricate-masked metric relative to the unmasked
# baseline, rounded to two decimals. Positions 0, 4, 2, 6 are printed
# (presumably male/female height RMSE, then male/female age RMSE - confirm array layout).
a = np.round((affricatesMask - noMask) / noMask * 100, 2)
print(*a[[0, 4, 2, 6]])

0.0 0.0 0.0 0.0


In [59]:
# Percent change of every fricative-masked metric relative to the unmasked
# baseline, rounded to two decimals. Positions 0, 4, 2, 6 are printed
# (presumably male/female height RMSE, then male/female age RMSE - confirm array layout).
a = np.round((fricativesMask - noMask) / noMask * 100, 2)
print(*a[[0, 4, 2, 6]])

1.28 2.87 6.08 -2.41


In [60]:
# Percent change of every stop-masked metric relative to the unmasked
# baseline, rounded to two decimals. Positions 0, 4, 2, 6 are printed
# (presumably male/female height RMSE, then male/female age RMSE - confirm array layout).
a = np.round((stopMask - noMask) / noMask * 100, 2)
print(*a[[0, 4, 2, 6]])

0.6 2.9 4.14 3.05


In [61]:
# Percent change of every "other"-masked metric relative to the unmasked
# baseline, rounded to two decimals. Positions 0, 4, 2, 6 are printed
# (presumably male/female height RMSE, then male/female age RMSE - confirm array layout).
a = np.round((otherMask - noMask) / noMask * 100, 2)
print(*a[[0, 4, 2, 6]])

1.07 -2.9 5.84 12.17


In [76]:
# Horizontal bar chart with one bar per entry of `categories`, saved to
# 'phoneGraph.png' and then displayed.
# NOTE(review): `categories` and `values` are not defined in any cell visible
# here - confirm they are created earlier, otherwise this cell fails on a
# fresh kernel.
fig, ax = plt.subplots()
ax.barh(categories, values)
ax.set_xlabel("Relative importance")
ax.set_ylabel("Types of phone")
ax.set_title("")
fig.savefig('phoneGraph.png', dpi=600, bbox_inches='tight')
plt.show()

In [15]:
# Largest entry of `values`.
# NOTE(review): `values` is not defined in any cell visible here - confirm where it comes from.
max(values)

122898

In [156]:
# Inspect the `log_var_height` parameter of the model's `uncertainty_loss` module
# (presumably the learned log-variance weighting the height loss - confirm in the model code).
model.uncertainty_loss.log_var_height

Parameter containing:
tensor(-0.0056, device='cuda:0', requires_grad=True)

In [157]:
# Inspect the `log_var_age` parameter of the model's `uncertainty_loss` module
# (presumably the learned log-variance weighting the age loss - confirm in the model code).
model.uncertainty_loss.log_var_age

Parameter containing:
tensor(-0.0035, device='cuda:0', requires_grad=True)

In [158]:
# Inspect the `log_var_gender` parameter of the model's `uncertainty_loss` module
# (presumably the learned log-variance weighting the gender loss - confirm in the model code).
model.uncertainty_loss.log_var_gender

Parameter containing:
tensor(-0.0065, device='cuda:0', requires_grad=True)