In [1]:
import glob
import importlib
import os

import librosa
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from IPython.display import Audio, display
from scipy.io.wavfile import read
from torch.utils.data import DataLoader
from tqdm import tqdm

import q_dev_fns
import train_Q2_TD
from infowavegan import WaveGANQNetwork

In [2]:
# Experiment configuration.
NUM_CATEG = 11          # number of target words (timit_words without the trailing 'UNK')
SLICE_LEN = 16384       # length of each audio slice fed to the Q network
# Fall back to CPU so the notebook can still run on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
timit_words = "she had your suit in dark greasy wash water all year".split(' ')+['UNK']
assert len(timit_words) == NUM_CATEG + 1, "timit_words must be the NUM_CATEG target words plus 'UNK'"
datadir = "/home/stephan/notebooks/talker_variability/TIMIT_padded/"
BATCH_SIZE = 192
start_epoch = 0
start_step = 0
NUM_EPOCHS = 25

# Seed every RNG in play (dataloader shuffling, weight init) for reproducibility.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Adult Q network (word classifier over NUM_CATEG classes), its optimizer, and its loss.
Q = WaveGANQNetwork(slice_len=SLICE_LEN, num_categ=NUM_CATEG).to(device).train()
optimizer_Q_to_Q = optim.RMSprop(Q.parameters(), lr=LEARNING_RATE)


def criterion_Q(inpt, target):
    """Cross-entropy between the logits `inpt` and the class index encoded in `target`.

    `target` is reduced with argmax over dim 1, so it is expected to hold one
    (one-hot) score per class in that dimension.
    """
    return torch.nn.functional.cross_entropy(inpt, target.argmax(dim=1))

In [4]:
# Training data for the adult Q network: the padded TIMIT recordings in `datadir`.
# The dataset is loaded when it is constructed, even if the network is later restored from disk.
dataset = train_Q2_TD.AudioDataSet(datadir, SLICE_LEN, NUM_CATEG, timit_words)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=1,
    drop_last=True,
)

Loading data


100%|██████████| 5082/5082 [00:01<00:00, 3961.78it/s]


In [5]:
# Train the adult Q network on the TIMIT data (regenerate = True) or load the
# previously trained copy from disk (default). Either way the network is then
# put in eval mode, frozen, and used for inference only.
regenerate = False
Q_CHECKPOINT = 'saved_networks/adult_pretrained_Q_network.torch'
if regenerate:
    step = start_step
    # NOTE(review): with start_epoch = 0 this loop runs epochs 1..NUM_EPOCHS-1,
    # i.e. NUM_EPOCHS - 1 passes over the data -- confirm that is intended.
    for epoch in range(start_epoch + 1, NUM_EPOCHS):
        print("Epoch {} of {}".format(epoch, NUM_EPOCHS))
        print("-----------------------------------------")

        pbar = tqdm(dataloader)
        for trial in pbar:
            reals = trial[0].to(device)
            labels = trial[1].to(device)
            optimizer_Q_to_Q.zero_grad()
            adult_recovers_from_adult = Q(reals)
            # Note we exclude the UNK label -- child never intends to produce unk
            Q_comprehension_loss = criterion_Q(adult_recovers_from_adult, labels[:, 0:NUM_CATEG])
            Q_comprehension_loss.backward()
            optimizer_Q_to_Q.step()
            # Show the loss on the progress bar rather than printing one line per batch.
            pbar.set_postfix(loss=Q_comprehension_loss.item())
            step += 1
    # Make sure the checkpoint directory exists before saving.
    os.makedirs(os.path.dirname(Q_CHECKPOINT), exist_ok=True)
    torch.save(Q, Q_CHECKPOINT)
else:
    # The checkpoint is a pickled full module (written with torch.save(Q, ...)),
    # so only load trusted files. map_location lets a GPU-saved model load on `device`.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects
    # full-module pickles; pass weights_only=False on those versions.
    Q = torch.load(Q_CHECKPOINT, map_location=device)

# Inference only from here on: eval mode in both branches (the training branch
# would otherwise leave Q in train mode).
Q.eval()

# Freeze the parameters so later forward passes do not track gradients.
for p in Q.parameters():
    p.requires_grad = False

# Run the Adult Q network on a new dataset

In [None]:
# Directory from a network that hasn't learned the word-to-referent mapping
test_folder = '/home/stephan/notebooks/ciwganfiwgan-pytorch/run_log/12_TD_2000+_batchedQupdate/2001/'
X, Y, filenames = q_dev_fns.load_wavs_and_labels(test_folder, SLICE_LEN, NUM_CATEG, device, timit_words)
# Inference only: don't build an autograd graph over the whole test set.
with torch.no_grad():
    Q_vals = torch.nn.Softmax(dim=1)(Q(X))

In [None]:
# Score the Q network's predictions for the run that hasn't learned the mapping.
unlearned_run_results = q_dev_fns.evaluate_asr_system(Q_vals, Y, device, timit_words, filenames)
unlearned_run_results

In [6]:
# Directory with hand-labeled data
test_folder = '/home/stephan/notebooks/ciwganfiwgan-pytorch/q2_dev_data/2001_relabeled/'
X, Y, filenames = q_dev_fns.load_wavs_and_labels(test_folder, SLICE_LEN, NUM_CATEG, device, timit_words)
# Inference only: don't build an autograd graph over the whole test set.
with torch.no_grad():
    Q_vals = torch.nn.Softmax(dim=1)(Q(X))

100%|██████████| 133/133 [00:00<00:00, 1504.81it/s]


In [7]:
# Sanity check: one probability row per loaded clip, one column per target word.
assert Q_vals.shape == (X.shape[0], NUM_CATEG), Q_vals.shape
Q_vals.shape

torch.Size([133, 11])

In [8]:
# Pick up edits to q_dev_fns without restarting the kernel. `imp` is deprecated
# and was removed in Python 3.12; importlib.reload is the replacement.
importlib.reload(q_dev_fns)
# Threshold argument passed to mark_unks_in_Q; see q_dev_fns for its exact meaning.
UNK_THRESHOLD = .1
Q_vals_with_unks = q_dev_fns.mark_unks_in_Q(Q_vals, UNK_THRESHOLD, device)

  import imp


In [9]:
# Score the UNK-marked Q network predictions against the hand labels.
Qnetwork_results = q_dev_fns.evaluate_asr_system(
    Q_vals_with_unks, Y, device, timit_words, filenames,
)
Qnetwork_results

{'df':     human label asr system label  matches   
 0          suit             suit     True  \
 1           UNK             your    False   
 2           UNK              UNK     True   
 3           all              all     True   
 4           UNK              UNK     True   
 ..          ...              ...      ...   
 128         UNK              all    False   
 129        your             your     True   
 130        your              UNK    False   
 131      greasy           greasy     True   
 132         UNK              all    False   
 
                                              filenames   entropy  recognized  
 0    /home/stephan/notebooks/ciwganfiwgan-pytorch/q...  0.000931        True  
 1    /home/stephan/notebooks/ciwganfiwgan-pytorch/q...  0.019612       False  
 2    /home/stephan/notebooks/ciwganfiwgan-pytorch/q...  0.010132        True  
 3    /home/stephan/notebooks/ciwganfiwgan-pytorch/q...  0.006908        True  
 4    /home/stephan/notebooks/ciwganfiwg

# Entropy Threshold

In [None]:
# Load the rpy2 IPython extension so R can be run with %R / %%R, then attach
# ggplot2 (plotting) and pROC in the embedded R session.
# NOTE(review): pROC is not referenced elsewhere in this notebook; the ROC plot uses plotROC.
%reload_ext rpy2.ipython
%R library('ggplot2')
%R library('pROC')

In [None]:
# Copy the per-clip results DataFrame into the R session as `qnr` (via rpy2's
# -i input conversion) and preview its first rows from R.
qnr =  Qnetwork_results['df']
%R -i qnr
%R head(qnr)

In [None]:
%%R
# Density of the per-clip entropy, split by the `recognized` flag produced by evaluate_asr_system.
ggplot(qnr) +
    geom_density(aes(x = entropy, color = recognized)) +
    labs(title = "Q network entropy by recognition outcome",
         x = "entropy", y = "density", color = "recognized")

In [None]:
%%R
# plotROC provides geom_roc(), used for the entropy ROC curve.
library('plotROC')

In [None]:
%%R
# ROC for entropy as a detector of recognition failures: d = 1 when the clip was
# NOT recognized and 0 otherwise. plotROC expects d coded 0/1; any other coding
# only works through its "assuming ... = 0 and ... = 1" warning fallback.
ggplot(qnr, aes(m = entropy, d = as.integer(!recognized))) +
    geom_roc() +
    labs(title = "Entropy as a predictor of Q network failures",
         x = "false positive fraction", y = "true positive fraction")

# Error analysis on the Q network failures

In [None]:
# inspect a single file
# The second argument selects which result to inspect (presumably a row of
# Qnetwork_results['df'] -- confirm in q_dev_fns).
# NOTE: `inpsect_failure` (sic) is the helper's name as defined in q_dev_fns.
q_dev_fns.inpsect_failure(Qnetwork_results, 1)

4, Q network is just wrong  
5, Q network is wrong, but it's ambiguous  
13, "in" vs. "your" -- similar production to #4  
16, Q network is wrong  
17, Q network is just wrong  
20, Q network is just wrong  
26, Q network is just wrong  
27, ambiguous -- "year" and "in" at the same time  
31, Q network is wrong  
44, Q network is wrong  
49, Q network is wrong  
66, Q network is wrong, but a slightly weird "water"  
74, Q network is wrong, but a slightly weird "water"  
76, ambiguous between "year" and "in"  
78, Q network is wrong  

# Error Analysis on the Whisper Failures

In [None]:
# Pick up edits to train_Q2_TD without restarting the kernel (`imp` is
# deprecated and removed in Python 3.12; importlib.reload replaces it).
importlib.reload(train_Q2_TD)

In [None]:
import faster_whisper

# Vocabulary prior handed to Q2_whisper: every word seen more than 20 times,
# with a fixed share of the probability mass reserved for the TIMIT target words.
vocab = pd.read_csv('data/vocab.csv')
# .copy() so the column assignments below modify an independent frame, not a
# view of the unfiltered CSV (avoids SettingWithCopyWarning).
vocab = vocab.loc[vocab['count'] > 20].copy()
vocab['word'] = vocab['word'].astype('str')

# Method 1 (unused): upweight proportionately
#vocab.loc[vocab.word.isin(timit_words),'count'] = 100000. * vocab.loc[vocab.word.isin(timit_words)]['count']
#vocab['probability'] = vocab['count'] / np.sum(vocab['count'])

# Method 2: uniform over timit words, with the remaining mass spread by frequency
prob_to_give_to_timit_words = .8
is_timit_word = vocab['word'].isin(timit_words)
n_timit_in_vocab = int(is_timit_word.sum())
if n_timit_in_vocab == 0:
    raise ValueError("None of timit_words appear in data/vocab.csv after the count filter")
# Divide by the number of target words actually present in the vocab rather than
# len(timit_words), which also counts the 'UNK' placeholder. This way the reserved
# mass really is prob_to_give_to_timit_words. Use a float column from the start
# so the .loc assignment does not upcast an int column.
vocab['upweighted_prob'] = 0.0
vocab.loc[is_timit_word, 'upweighted_prob'] = prob_to_give_to_timit_words / n_timit_in_vocab
vocab['base_probability'] = (vocab['count'] / np.sum(vocab['count'])) * (1. - prob_to_give_to_timit_words)
vocab['probability'] = vocab['base_probability'] + vocab['upweighted_prob']
assert np.isclose(vocab['probability'].sum(), 1.0), vocab['probability'].sum()

# float16 needs a GPU; use int8 when running on CPU.
fast_whisper_model = faster_whisper.WhisperModel(
    'medium.en', device=device, compute_type="float16" if device == "cuda" else "int8")
Q2_GLOBALS = {
    "MIN_DECODING_PROB": .1,
    "MAX_NOSPEECH_PROB": .1,
    "MAX_UNK_PROB": .5,  # this was .2
    "Q2_TIMEOUT": 2,
}

In [None]:
# Spot check: does the target word "greasy" survive the count filter, and with what prior?
vocab[vocab['word'] == 'greasy']

In [None]:
# Run Whisper-based recognition (Q2) on the same clips scored by the Q network.
# X[:,0,:] drops the second axis -- assumes X is (N, 1, SLICE_LEN); TODO confirm.
# NOTE(review): the meaning of the -1 argument is defined in train_Q2_TD.Q2_whisper; confirm.
# NOTE(review): this rebinds `filenames` (previously from load_wavs_and_labels);
# cells that reuse `filenames` afterwards see Q2_whisper's version.
indices_of_recognized_words, Q2_probs, filenames, whisper_recognition_info = train_Q2_TD.Q2_whisper(
    X[:,0,:], Y, fast_whisper_model, timit_words, vocab, -1, Q2_GLOBALS, write_only=False)

In [None]:
# Score Whisper's per-clip probabilities with the same evaluation used for the Q network.
whisper_probs = torch.from_numpy(Q2_probs).to(device)
whisper_results = q_dev_fns.evaluate_asr_system(whisper_probs, Y, device, timit_words, filenames)
whisper_results

In [None]:
# inspect a single file
# The second argument selects which result to inspect (presumably a row of
# whisper_results['df'] -- confirm in q_dev_fns).
q_dev_fns.inpsect_failure(whisper_results, 1)

# Todos

In [None]:
# [X] Error analysis -- which words is it bad at? "year" and "your"; many of them get remapped to "in"
# [X] Compare with the quality of Whisper's output -- Whisper is never wrong about the word identity,
# but it over-assigns UNKs. So maybe it isn't intervening early enough
# [ ] Performance of the adult Q network is poor, but is it acceptable that it is this error-prone?
    # [ ] Could find a way to train it on other intermediate productions -- e.g. use Whisper
    # to identify a batch of intermediate and noisy candidates; then this network will be fast

# there should be no UNKs in this set

    

# [ ] Could try tuning Whisper's parameters -- but Whisper is too slow to run from the beginning
# [ ] How to handle uncertainty in the Q network -- recognition performance depends on how we pick out UNKs
    # Option: an entropy-based criterion
    # Option: give it a 12th category for UNK

    
    


# [ ] This means we can do early intervention with the frozen Q network
# [ ] This means we might have a submission for the 28th
# [ ] Consider VAD (voice activity detection)

# Adding Unks to the Test Set

In [None]:
# Guard: the cells below rename and delete files on disk and only need to run
# once. This stops "Run All" here; run them manually if the relabeling must be redone.
raise ValueError("Don't run this part automatically -- should just need to happen once")

In [None]:
# Inventory of the hand-labeled clips. A basename's first "_"-separated token is
# the word and its last token the uuid; "_keep_" marks a hand-labeled copy.
wav_paths = glob.glob('/home/stephan/notebooks/ciwganfiwgan-pytorch/q2_dev_data/2001_relabeled/*.wav')
basenames = [os.path.basename(path) for path in wav_paths]
stems = [name.replace('.wav', '') for name in basenames]
wavs = pd.DataFrame({
    'filename': wav_paths,
    'basename': basenames,
    'uuid': [stem.split('_')[-1] for stem in stems],
    'word': [stem.split('_')[0] for stem in stems],
    'keep': ['_keep_' in name for name in basenames],
})
wavs

In [None]:
# Split the clips: "_keep_" files carry the hand labels; an original whose uuid
# has no "_keep_" counterpart becomes an UNK candidate.
print(wavs.shape)
keep_mask = wavs['keep']
keeps = wavs.loc[keep_mask]
print(keeps.shape)
originals = wavs.loc[~keep_mask]
unks = originals.loc[~originals['uuid'].isin(keeps['uuid'])]
print(unks.shape)

In [None]:
# Preview the first five UNK candidates
unks.iloc[:5]

In [None]:
import shutil

In [None]:
# Rename each UNK candidate to "UNK_keep_<uuid>.wav" in its own directory so it
# is picked up as a hand-labeled UNK example. One-time, destructive operation.
# .assign builds a new frame, so this doesn't write into a slice of `wavs`.
unks = unks.assign(new_word='UNK')
unks = unks.assign(
    new_basename=unks['new_word'].map(str) + '_keep_' + unks['uuid'].map(str) + '.wav'
)
unks = unks.assign(
    new_filename=[os.path.join(os.path.dirname(src), dst)
                  for src, dst in zip(unks['filename'], unks['new_basename'])]
)

# Check every target up front: os.rename silently replaces an existing file on
# POSIX, and failing midway would leave the directory half-renamed.
collisions = [path for path in unks['new_filename'] if os.path.exists(path)]
if collisions:
    raise FileExistsError(f"Refusing to overwrite existing files: {collisions[:5]}")

for record in unks.to_dict('records'):
    os.rename(record['filename'], record['new_filename'])

In [None]:
# Select the originals that already have a "_keep_" copy; these are the files to delete
has_keep_copy = originals['uuid'].isin(keeps['uuid'])
redundant_files = originals.loc[has_keep_copy]

In [None]:
# Review the paths slated for deletion before removing them
redundant_files['filename']

In [None]:
# Delete the redundant originals (one-time, destructive). A plain loop avoids
# building and displaying a throwaway list of None values.
for redundant_path in redundant_files['filename']:
    os.remove(redundant_path)