In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import concurrent.futures
import csv
import os
import pickle
import random
import shutil
from itertools import islice
from pathlib import Path

from IPython.display import Audio, display
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pylab
import torch
import torchaudio
from fastai.vision import *
from fastai.metrics import error_rate

In [None]:
from exp.nb_AudioCommon import *
from exp.nb_DataBlock import *
from exp.nb_DataAugmentation import *
from exp.nb_TransformsManager import *
from exp.nb_AudioTransformsManager import *

In [None]:
# Root of the local TIMIT copy; the phoneme/synth/spectrogram folders below hang off it.
# NOTE(review): absolute, machine-specific paths — consider a single configurable data root.
path = Path("/home/jupyter/rob/TIMIT/timit")
path.ls()
# Test-audio folder; only used by the commented-out add_test_folder cell further down.
path_test = Path("/home/jupyter/rob/test_augment/audio")

In [None]:
# One sub-folder per TIMIT phoneme label, each holding that phoneme's audio clips.
path_phoneme = path/"PHONEMES"
# Not referenced by any cell shown here.
path_spectrogram = path/"spectrogram"
# Output root for synthesized words (one sub-folder per word); read back with AudioList.from_folder.
path_synth = path/"synth"
# Google speech-commands recordings; loaded with AudioList.from_folder and merged into the training list.
path_google = Path('/home/jupyter/rob/googlespeech/train/audio3')

In [None]:
# TIMIT vowel labels -> IPA. Several labels intentionally collapse onto the same
# IPA symbol (e.g. 'ih' and 'ix' -> 'ɪ', 'uw' and 'ux' -> 'u').
vowel_maps = {
    'aa': 'ɑ', 'ae':'æ', 'ah':'ʌ', 'ao':'ɔ', 'aw':'aʊ', 'ax':'ə',
    'axr':'ɚ', 'ay':'aɪ', 'eh':'ɛ', 'er':'ɝ', 'ey':'eɪ', 'ih':'ɪ',
    'ix':'ɪ', 'iy':'i', 'ow':'oʊ', 'oy':'ɔɪ', 'uh':'ʊ', 'uw':'u', 'ux':'u',
}

# TIMIT consonant labels -> IPA.
# dx is the flap, like the "tt" in "butter"; ARPAbet says it translates to ɾ in IPA,
# but I'm not entirely sure about that.
# nx is another one to be careful with: it can stand for either "ng" or "n", as in "winner".
# wh is meant to be the "wh" of why/when/where, but most IPA transcriptions treat it as a plain w.
cons_maps = {
    'ch':'tʃ', 'dh':'ð', 'dx':'ɾ', 'el':'l', 'em':'m', 'en':'n', 'hh':'h',
    'jh':'dʒ', 'ng':'ŋ', 'nx':'n', 'q':'ʔ', 'r':'ɹ', 'sh':'ʃ', 'th':'θ',
    'wh':'w', 'y':'j', 'zh':'ʒ'
}

# Labels that only TIMIT uses (not part of ARPAbet). Labels ending in 'cl' map to the
# matching stop letter, and 'pau', 'epi' and 'h#' all map to the pseudo-symbol 'silence'.
timit_specific_maps = {
    'ax-h':'ə', 'bcl':'b', 'dcl':'d', 'eng':'ŋ', 'gcl':'g', 'hv':'h', 'kcl':'k',
    'pcl':'p', 'tcl':'t', 'pau':'silence', 'epi':'silence', 'h#':'silence',
}

def get_timit_to_ipa_dict():
    """Build a mapping from every TIMIT phoneme label to an IPA symbol.

    Starts from the phoneme folder names under ``path_phoneme`` (each mapped to
    itself), then overrides them with ``vowel_maps``, ``cons_maps`` and
    ``timit_specific_maps``. Labels that appear only in those tables are appended.

    The folder names are sorted first. The dict's key order decides which label wins
    when the mapping is inverted, and sorting keeps it independent of the order the
    filesystem happens to list directories in.

    Returns:
        dict[str, str]: TIMIT label -> IPA symbol (or 'silence').
    """
    timit_to_ipa_dict = {stem: stem for stem in sorted(x.stem for x in path_phoneme.ls())}
    # update() keeps existing keys in place and appends new ones, like item assignment.
    for table in (vowel_maps, cons_maps, timit_specific_maps):
        timit_to_ipa_dict.update(table)
    return timit_to_ipa_dict

In [None]:
def get_timit_dict(dict_path=None):
    """Parse the TIMIT pronunciation dictionary into ``{word: "ph ph ph"}``.

    Only lines starting with a letter are treated as entries; anything else
    (e.g. comment lines) is skipped. An entry is the word, a double space,
    then the pronunciation.

    Args:
        dict_path: path to the dictionary file. Defaults to ``path/"DOC/TIMITDIC.TXT"``.

    Returns:
        dict[str, str]: word -> space-separated TIMIT phoneme labels, with '/'
        and the digits '1'/'2' (presumably stress marks) removed.

    Raises:
        ValueError: if an entry line contains no double space.
    """
    if dict_path is None:
        dict_path = path/"DOC/TIMITDIC.TXT"
    strip_marks = str.maketrans('', '', '/12')
    timit_dict = {}
    with open(dict_path) as f:
        for line in f:
            if not line[0].isalpha():
                continue
            # Note the separator is a double space, not a single one. Split on the
            # first one only, so stray double spaces later in the line (e.g.
            # trailing padding) don't break the unpacking.
            word, timit_string = line.split('  ', 1)
            timit_dict[word] = timit_string.translate(strip_marks).strip()
    return timit_dict
    

In [None]:
# Word -> TIMIT phoneme string, loaded once so it can be passed to synthesize_word.
timit_string_dict = get_timit_dict()

In [None]:
# Spot-check one dictionary entry.
timit_string_dict["zoologist"]

In [None]:
# TIMIT label -> IPA symbol.
timit_to_ipa_dict = get_timit_to_ipa_dict()

In [None]:
# Inverse lookup IPA -> TIMIT, used by convert_ipa_to_timit. Several TIMIT labels share
# an IPA symbol (e.g. ih/ix -> ɪ, uw/ux -> u, ax/ax-h -> ə, ng/eng -> ŋ, hh/hv -> h,
# l/el -> l, n/en/nx -> n). When inverting, the last key wins, which depended on dict
# iteration order. Skipping the closure labels ('*cl') and the secondary variants
# below leaves one canonical TIMIT label per IPA symbol. 'silence' keeps several
# sources, but no 1-2 character IPA lookup can reach it.
_secondary_timit_labels = {'wh', 'nx', 'en', 'em', 'el', 'ix', 'ux', 'ax-h', 'eng', 'hv'}
ipa_to_timit_dict = {v:k for k,v in timit_to_ipa_dict.items() if 'cl' not in k and k not in _secondary_timit_labels}

In [None]:
def get_lengths_by_phoneme():
    """Collect the duration of every clip, grouped by phoneme folder.

    Loads each file in each folder of ``path_phoneme.ls()`` with ``librosa.load``
    and records ``len(y) / sr``. This is slow because it decodes every clip.
    Progress is printed: each folder as it starts, and a running file count every
    10000 files.

    Returns:
        dict: phoneme folder (as returned by ``path_phoneme.ls()``) -> list of clip
        durations (``len(y) / sr``).
    """
    len_dict = {}
    count = 0
    for p in path_phoneme.ls():
        len_dict[p] = []
        print(p)
        with os.scandir(p) as sd:
            for entry in sd:
                if count % 10000 == 9999:
                    print(count)
                count += 1
                # entry.path is already the clip's full path. Joining it onto `path`
                # again only worked when `path` was absolute.
                y, sr = librosa.load(entry.path)
                len_dict[p].append(len(y)/sr)
    return len_dict

In [None]:
def get_all_speaker_ids(phoneme_dir=None):
    """Return the unique speaker IDs found in the phoneme clip file names.

    The speaker ID is taken as the second '-'-separated field of each file name
    inside every sub-folder of ``phoneme_dir``.

    Args:
        phoneme_dir: root folder with one sub-folder per phoneme. Defaults to
            ``path_phoneme``.

    Returns:
        list[str]: unique speaker IDs, sorted. Set iteration order of strings
        varies between Python processes; sorting keeps ``random.choice`` over this
        list reproducible under a fixed seed.
    """
    root = path_phoneme if phoneme_dir is None else Path(phoneme_dir)
    speaker_ids = set()
    for p in root.iterdir():
        with os.scandir(p) as sd:
            # Parse entry.name, not str(entry). The latter is the DirEntry repr
            # ("<DirEntry '...'>"), whose trailing "'>" can leak into the ID.
            speaker_ids.update(entry.name.split('-')[1] for entry in sd)
    return sorted(speaker_ids)

In [None]:
# All speaker IDs. Pass this to synthesize_word so it doesn't call
# get_all_speaker_ids() (a full rescan of the phoneme folders) on every call.
speaker_ids = get_all_speaker_ids()

In [None]:
# Peek at a few IDs.
speaker_ids[0:10]

In [None]:
# Takes a few minutes to run. Builds a dict with one key per phoneme whose value is the
# list of durations of every sample we have of that phoneme. Used to compute stats.
# len_dict = get_lengths_by_phoneme()

In [None]:
# The keys above are folder paths; keep only the last path component (the phoneme label).
# fixed_dict = {str(k).split('/')[-1]:v for k,v in len_dict.items()}

In [None]:
# Cache the durations to disk (loaded back in the next cell).
# NOTE(review): the open() handle below is never closed; prefer
# `with open("phon_len.p", "wb") as f: pickle.dump(fixed_dict, f)`.
# import pickle
# pickle.dump(fixed_dict, open( "phon_len.p", "wb" ) )

In [None]:
# Per-phoneme duration lists cached by the (commented-out) cells above.
# NOTE: pickle.load can execute arbitrary code — only load cache files you created yourself.
with open("phon_len.p", 'rb') as f:
    len_dict = pickle.load(f)

In [None]:
def rd4(x):
    """Return ``x`` rounded to 4 decimal places (tidies the duration statistics)."""
    return round(x, ndigits=4)

In [None]:
stats_dict = {k:list(map(rd4, [min(v), max(v), sum(v)/len(v)])) for k,v in len_dict.items()}

In [None]:
stats_dict

In [None]:
def convert_ipa_to_timit(ipa_string, mapping=None):
    """Convert an IPA string to a list of TIMIT phoneme labels.

    Scans left to right, preferring a two-character symbol (e.g. 'aʊ', 'tʃ') when
    it is in the mapping, otherwise looking up the single character. The ASCII
    stand-ins 'a' and 'r' are normalised to 'ɑ' and 'ɹ', but only when they are
    looked up on their own. Normalising the whole string first turned the
    diphthongs 'aʊ'/'aɪ' into 'ɑʊ'/'ɑɪ', so 'daʊn' came out as aa + uh instead
    of aw.

    Args:
        ipa_string: the word in IPA.
        mapping: IPA symbol -> TIMIT label. Defaults to ``ipa_to_timit_dict``.

    Returns:
        list[str]: TIMIT labels, in order.

    Raises:
        KeyError: if a symbol has no TIMIT equivalent.
    """
    mapping = ipa_to_timit_dict if mapping is None else mapping
    aliases = {'a': 'ɑ', 'r': 'ɹ'}
    timit_list = []
    i = 0
    while i < len(ipa_string):
        pair = ipa_string[i:i+2]
        if len(pair) == 2 and pair in mapping:
            timit_list.append(mapping[pair])
            i += 2
            continue
        symbol = ipa_string[i]
        timit_list.append(mapping[aliases.get(symbol, symbol)])
        i += 1
    return timit_list

In [None]:
# 'yes' in IPA.
convert_ipa_to_timit('jɛs')

In [None]:
# 'down' — exercises a two-character IPA symbol (aʊ).
convert_ipa_to_timit('daʊn')

In [None]:
# 'dog' — written with an ASCII g.
convert_ipa_to_timit('dɔg')

In [None]:
def _load_phoneme_clip(phoneme_dir, fnames, min_seconds=None):
    """Load one randomly chosen clip from ``phoneme_dir``.

    Args:
        phoneme_dir: folder containing the candidate clips.
        fnames: candidate file names inside ``phoneme_dir``.
        min_seconds: if given, prefer a clip whose ``len(y[0]) / sr`` is at least
            this value. Candidates are tried in random order without replacement,
            which picks uniformly among the long-enough clips. If none qualifies,
            the longest candidate is returned instead of looping forever.

    Returns:
        The audio tensor returned by ``torchaudio.load`` for the chosen clip.

    Raises:
        IndexError: if ``fnames`` is empty.
    """
    if not fnames:
        raise IndexError(f"no candidate clips in {phoneme_dir} for the requested speakers")
    if min_seconds is None:
        y, _ = torchaudio.load(phoneme_dir/random.choice(fnames))
        return y
    longest, longest_seconds = None, -1.0
    for fname in random.sample(fnames, len(fnames)):
        y, sr = torchaudio.load(phoneme_dir/fname)
        seconds = len(y[0])/sr
        if seconds >= min_seconds:
            return y
        if seconds > longest_seconds:
            longest, longest_seconds = y, seconds
    return longest


def synthesize_word(word, timit_string_dict = None, speaker_ids=None, by_sex = True, above_mean = True, one_speaker=False, verbose=True):
    """Build a synthetic recording of ``word`` by concatenating TIMIT phoneme clips.

    Args:
        word: key into ``timit_string_dict``. "bird" uses the hard-coded "b er d".
        timit_string_dict: word -> TIMIT phoneme string. Loaded with
            ``get_timit_dict()`` when not given; pass a cached one when calling in bulk.
        speaker_ids: speaker IDs to draw clips from. Defaults to
            ``get_all_speaker_ids()``, which rescans all phoneme folders.
        by_sex: if True, pick 'M' or 'F' at random and only use clips whose file-name
            character 4 is that letter. Otherwise both letters are allowed.
        above_mean: if True, prefer clips at least as long as the phoneme's mean
            duration (``stats_dict[phoneme][2]``). Falls back to the longest
            matching clip when none is long enough.
        one_speaker: if True, draw every phoneme from one randomly chosen speaker
            (of the chosen sex when ``by_sex``).
        verbose: print the TIMIT phoneme string before synthesizing.

    Returns:
        torch.Tensor: the phoneme clips concatenated along dim 1.

    Raises:
        KeyError: if ``word`` is not in the dictionary.
        IndexError: if no clip matches the speaker/sex filter for some phoneme.
    """
    timit_string_dict = timit_string_dict or get_timit_dict()
    speaker_ids = speaker_ids or get_all_speaker_ids()
    timit_string = "b er d" if word == "bird" else timit_string_dict[word]
    if verbose:
        print(timit_string)
    speaker_sex_list = random.choice([['M'], ['F']]) if by_sex else ['M', 'F']
    if one_speaker:
        # The clip filter checks file-name character 4, the first character of the
        # 5-character speaker ID at [4:9], against the sex letter. Drawing the single
        # speaker from IDs with the chosen first letter avoids an empty candidate list.
        same_sex_ids = [s for s in speaker_ids if s[:1] in speaker_sex_list]
        allowed_ids = {random.choice(same_sex_ids)}
    else:
        # A set keeps the per-file membership test O(1) instead of scanning the list.
        allowed_ids = set(speaker_ids)
    tensor_list = []
    for timit_phoneme in timit_string.split():
        phoneme_dir = path_phoneme/timit_phoneme
        fnames = [fname for fname in os.listdir(phoneme_dir)
                  if fname[4] in speaker_sex_list and fname[4:9] in allowed_ids]
        min_seconds = stats_dict[timit_phoneme][2] if above_mean else None
        tensor_list.append(_load_phoneme_clip(phoneme_dir, fnames, min_seconds))
    return torch.cat(tensor_list, dim=1)
    

In [None]:
# Listen to one synthesized word. No cached dictionary/speaker list is passed, so
# synthesize_word reloads both on every call.
# NOTE(review): 16000 assumes the TIMIT clips are 16 kHz — confirm.
x = synthesize_word("down", above_mean=True, one_speaker=False)
display(AudioItem(AudioData(x, 16000)))

In [None]:
# The ten target command words, in English and (roughly) in IPA.
# ipa_commands is not used by any cell shown here.
eng_commands = "bed bird dog down no off on one three tree".split()
ipa_commands = 'bɛd bɝrd dɔg daʊn noʊ ɔf ɑn wʌn θri tri'.split()

In [None]:
# One synthesized sample of each command word, for a quick listening check.
for word in eng_commands:
    print(word)
    x = synthesize_word(word)
    display(AudioItem(AudioData(x, 16000)))

In [None]:
# TODO: only grab phoneme clips of a suitable length when constructing words here,
# using the mean length of each phoneme as a guide, then build a large dataset.
# Could also condition on gender or individual speaker — check the stats to see how
# that would affect the combinatorics.
def synthesize_timit_word_from_ipa(ipa_string):
    """Synthesize a word from IPA by concatenating randomly chosen TIMIT phoneme clips.

    The IPA string is converted with ``convert_ipa_to_timit``. For each resulting
    label, one file is picked uniformly at random from that phoneme's folder, with
    no speaker, sex or length filtering.

    Returns:
        torch.Tensor: the loaded clips concatenated along dim 1.
    """
    clips = [
        torchaudio.load(random.choice((path_phoneme/label).ls()))[0]
        for label in convert_ipa_to_timit(ipa_string)
    ]
    return torch.cat(clips, dim=1)

In [None]:
def synthesize_n_examples_of_word(n, word, by_sex = True, above_mean = True, one_speaker=False, sample_rate=16000):
    """Write ``n`` synthesized recordings of ``word`` into ``path_synth/word``.

    Files are named ``{word}-{i}.wav`` (i = 0..n-1), so the folder name can serve
    as the label for ``label_from_folder``. The TIMIT dictionary and speaker list
    are loaded once and reused for every sample.

    Args:
        n: number of examples to generate.
        word: the word to synthesize (see ``synthesize_word``).
        by_sex, above_mean, one_speaker: forwarded to ``synthesize_word``.
        sample_rate: sample rate passed to ``torchaudio.save``. Defaults to 16000,
            the previously hard-coded value.
    """
    timit_string_dict = get_timit_dict()
    speaker_ids = get_all_speaker_ids()
    out_dir = path_synth/word
    # parents=True also creates path_synth itself on a fresh setup.
    out_dir.mkdir(parents=True, exist_ok=True)
    for i in range(n):
        x = synthesize_word(word, timit_string_dict=timit_string_dict, speaker_ids=speaker_ids,
                            by_sex=by_sex, above_mean=above_mean, one_speaker=one_speaker)
        # Save into the per-word folder created above (the old code referenced an
        # undefined name, `path_command`).
        torchaudio.save(str(out_dir/f'{word}-{i}.wav'), x, sample_rate)

In [None]:
# Generate 5000 clips per command word (10 words -> 50000 files). Slow; re-running
# presumably overwrites files with the same names, so skip it if path_synth is
# already populated.
for word in eng_commands:
    synthesize_n_examples_of_word(5000, word, by_sex=True, above_mean=True, one_speaker=False)

In [None]:
# Synthesized clips, one sub-folder per word.
audio_list_train = AudioList.from_folder(path_synth)

In [None]:
# Inspect the Google speech-commands folder.
path_google.ls()

In [None]:
# Real recordings. Despite the name, these are merged into the training list below;
# only the files selected by valid_func end up in the validation split.
audio_list_valid = AudioList.from_folder(path_google)

In [None]:
# NOTE(review): ItemList.add appears to extend audio_list_train in place, so
# re-running this cell appends the Google files again. Re-run the
# audio_list_train cell above first.
audio_list_train.add(audio_list_valid)

In [None]:
audio_list_valid

In [None]:
audio_list = audio_list_train.split_by_valid_func(valid_func).label_from_folder()

In [None]:
audio_list

In [None]:
def valid_func(o):
    return 'nohash' in o.stem and o.stem[3] in '2 4 6 8'.split()


In [None]:
# Disabled. Without a test folder, the later cells that read test_ds / audio_list.test
# have no test data (presumably they then fail on a fresh run).
#audio_list.add_test_folder(path_test);

In [None]:
# Only the spectrogram masking (spec_augment) is enabled; every waveform-level
# augmentation flag below is False.
# NOTE(review): 16127 samples is slightly over 1 s at 16 kHz — confirm the intended pad length.
tm_speech = AudioTfmsManager.get_audio_tfms_manager(
                            spec_augment=True, pct_hori=.2, num_vert=0, num_hori=1,
                            spectro=True, #We're going to replace it...
                            mx_to_pad=16127, pad_type="middle", #1 sec window
                            white_noise=False, noise_scl=0.001, # Small noise
                            modulate_volume=False, lower_gain=.95, upper_gain=1.05, # Not big volume variation
                            random_cutout=False,
                            pad_with_silence=False,
                            pitch_warp=False,
                            down_and_up=False)

In [None]:
tfms = tm_speech.get_tfms()
# Drop the last transform of the second list. Presumably these are the validation
# transforms, i.e. fastai's (train, valid) pair — confirm which transform this removes.
del tfms[1][-1]
tfms

In [None]:
data_speech = audio_list.transform(tfms).databunch(bs=64)

In [None]:
data_speech

In [None]:
def adapt_first_layer(src_model, nChannels):
    '''
    Replace the first conv layer (``src_model[0][0]``) so the network accepts
    ``nChannels`` input channels.

    The new ``nn.Conv2d`` copies the old layer's out_channels, kernel_size, stride,
    padding, dilation and bias (if any). Its weights come from the first input
    channel of the old filters, repeated across the new channels and divided by
    ``nChannels``. For ``nChannels == 1`` that is exactly the old first-channel
    filter, as before. The old code hard-coded 64 filters of 7x7 and only worked
    for one channel.

    The new layer is placed on the old layer's device and dtype instead of forcing
    the whole model onto CUDA, so this also works without a GPU.

    Args:
        src_model: model whose ``[0][0]`` element is an ``nn.Conv2d``; modified in place.
        nChannels: number of input channels the new first layer should take.
    '''
    old_layer = src_model[0][0]
    with torch.no_grad():
        # First input channel of every filter: shape (out_channels, 1, kh, kw).
        seed = old_layer.weight[:, :1].detach().clone()
        new_weights = seed.repeat(1, nChannels, 1, 1) / nChannels

    new_layer = nn.Conv2d(nChannels, old_layer.out_channels,
                          kernel_size=old_layer.kernel_size, stride=old_layer.stride,
                          padding=old_layer.padding, dilation=old_layer.dilation,
                          bias=old_layer.bias is not None)
    new_layer.weight = nn.Parameter(new_weights)
    if old_layer.bias is not None:
        new_layer.bias = nn.Parameter(old_layer.bias.detach().clone())

    # Replace the layer, keeping it on the same device/dtype as the rest of the model.
    src_model[0][0] = new_layer.to(device=old_layer.weight.device, dtype=old_layer.weight.dtype)

In [None]:
# Visual check of one training item and one validation item.
data_speech.train_ds[0][0].show()
data_speech.valid_ds[0][0].show()


In [None]:
# ResNet-50 learner; its first conv layer is swapped for a single-channel one below.
learn_speech = cnn_learner(data_speech, models.resnet50, metrics=accuracy)

In [None]:
# Input channels of the new first layer (presumably single-channel spectrograms — confirm).
nChannels=1

# Alter existing model
adapt_first_layer(learn_speech.model,nChannels)
#print(f'First layer shape: {learn_speech.model[0][0].weight.shape}')

In [None]:
# Learning-rate range test, first training round.
learn_speech.lr_find()

In [None]:
learn_speech.recorder.plot()

In [None]:
learn_speech.fit_one_cycle(8, 2e-2)

In [None]:
interp = ClassificationInterpretation.from_learner(learn_speech)

In [None]:
interp.plot_confusion_matrix()

In [None]:
# Second round at a lower learning rate.
learn_speech.lr_find()

In [None]:
learn_speech.recorder.plot()

In [None]:
learn_speech.fit_one_cycle(9, 1e-3)

In [None]:
learn_speech.save("synth-99")

In [None]:
# NOTE(review): this loads "synth-stage2", not the "synth-99" checkpoint saved just
# above, so on a fresh run it only works if that older checkpoint exists. Confirm
# which one is intended.
learn_speech.load("synth-stage2");

In [None]:
# NOTE(review): test_ds is only populated when a test folder is added, and the
# add_test_folder cell is commented out, so this and the test-set cells below
# presumably fail on a fresh run.
data_speech.test_ds[1300][0].show()

In [None]:
# Count how many of the first 250 test items are predicted as class index 5
# (presumably the true class of those files — confirm; the index is hard-coded).
correct = 0
for d in data_speech.test_ds[0:250]:
    if learn_speech.predict(d[0])[1].item() == 5:
        correct+=1 

In [None]:
correct

In [None]:
learn_speech.predict(data_speech.test_ds[4][0])

In [None]:
audio_list.test[0]

In [None]:
def display_audio_prediction(min_pct=0.1):
    """Pick a random test clip, print the model's prediction, and play the audio.

    TODO: this relies on notebook globals that no cell shown here defines:
    ``test_files_list`` (clip file names), ``path_test_audio`` (folder with the
    clips), ``path_test_spectrogram`` (folder with ``<clip name>.png`` images)
    and ``learn`` (the fitted learner). Define them before calling.

    Args:
        min_pct: only classes whose predicted probability exceeds this are listed.
            0.1 was the previously hard-coded threshold.
    """
    # random.choice avoids the separate (undefined) `num_files` count and can't
    # index past the end of the list.
    rand_file = random.choice(test_files_list)
    clip, sr = librosa.load(path_test_audio/rand_file, sr=None)
    print(rand_file)
    img_filename = rand_file + ".png"
    image = open_image(path_test_spectrogram/img_filename)
    pred = learn.predict(image)
    print(f"Prediction: {pred[0]}")
    # Use the learner's own class list so the indices line up with its probabilities
    # (the old code read an undefined global `data`).
    classes = learn.data.classes
    for idx, pct in enumerate(pred[2]):
        if pct.item() > min_pct:
            print(f"{classes[idx]}: {round(pct.item()*100, 2)}%")
    display(Audio(clip, rate=sr))