In [1]:
import transformers 
import soundfile as sf
import os
import pandas as pd
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer
from src.models import Wav2Vec2ForFrameClassificationSAT
from alignment_helper_fns import *
from src.Charsiu import charsiu_sat_forced_aligner

In [2]:
def get_all_audiofiles_in_dir(audio_dir):
    audiofiles = []
    for root, dirs, files in os.walk(audio_dir):
        for fname in files:
            if '.wav' in fname:
                _fpath = os.path.join(root, fname)
                audiofiles.append(_fpath)
    return audiofiles

In [3]:
''' define directories '''
audio_dir = '/home/prad/datasets/ChildSpeechDataset/child_speech_16_khz'
manual_textgrids_dir = '/home/prad/datasets/ChildSpeechDataset/manually-aligned-text-grids/'
mfa_sat_dir = '/home/prad/datasets/ChildSpeechDataset/mfa_adapted'
output_path = './phone_matched_xvec_proj_textgrids'


EXCLUDE_FILES = ['0505_M_EKs4T10', '0411_M_LMwT32']

audiofiles = get_all_audiofiles_in_dir(audio_dir)
audiofiles = [filename for filename in audiofiles if not any([exf in filename for exf in EXCLUDE_FILES])]     

manual_textgrids = get_all_textgrids_in_directory(manual_textgrids_dir)
manual_textgrids = [filename for filename in manual_textgrids if not any([exf in filename for exf in EXCLUDE_FILES])]

SATVECTOR_SIZE=128

Extracting all textgrids in directory:	 /home/prad/datasets/ChildSpeechDataset/manually-aligned-text-grids/


43it [00:00, 6924.22it/s]


In [42]:
satvectors_csv = pd.read_csv('./extracted_xvectors_proj_libri.csv', index_col='Filename')

In [48]:
def extract_satvectors(satvectorcsv, audiofile):
    relevant_satvectors = satvectors[satvectors.index==audiofile]
    return satvectors[satvectors.index==audiofile].values[0]

In [4]:
def get_transcripts_for_audiofiles(audiofiles):
    transcripts = {}
    for filename in audiofiles:
        fname = filename.split('/')[-1]
        # speaker_dir = filename.split('/')[-2]
        # print(fname)
        # fname = os.path.join(manual_textgrids_dir, speaker_dir, fname[:-8]+'lab')
        fname = filename[:-3] + 'lab'
        # print(fname)
        # break
        f = open(fname)
        _transcript = f.read()
        # print(_transcript)
        transcripts[filename] = _transcript[:-1]
    return list(transcripts.values())

In [5]:
def load_aligner_from_modelpath(modelpath, tokenizer_name='charsiu/tokenizer_en_cmu'):
    charsiu = charsiu_sat_forced_aligner(model_path, ixvector_size=SATVECTOR_SIZE)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(tokenizer_name)
    feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0,
                                             do_normalize=True,
                                             return_attention_mask=False)
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    model = Wav2Vec2ForFrameClassificationSAT.from_pretrained(modelpath, pad_token_id=processor.tokenizer.pad_token_id,
                                                              vocab_size=len(processor.tokenizer.decoder),
                                                              ivector_size=SATVECTOR_SIZE)
    charsiu.aligner = model
    return charsiu

In [6]:
def is_sil_or_unk(phone):
    silunk = ['sil', '[SIL]', '[UNK]', '<UNK>']
    return any([su in phone for su in silunk])
    
def extract_phones_from_textgrid(tg):
    phones = list(tg.phone.values)
    phones = [phone for phone in phones if not is_sil_or_unk(phone)]
    return phones

def get_phoneseqs_from_textgridpaths(textgridpaths):
    phone_seqs = []
    for tgpth in textgridpaths:
        tg = textgridpath_to_phonedf(tgpth, phone_key='ha phones')
        _phoneseq = extract_phones_from_textgrid(tg)
        phone_seqs.append(_phoneseq)
    return phone_seqs

In [50]:
def run_aligner_on_files(audiopaths, transcripts, output_dir, satvectorcsv=None, gt_phoneme_sequences=None):
    aligner = load_aligner_from_modelpath(model_path)
    print('Evaluating Aligner...')
    for ii in tqdm.tqdm(range(len(audiofiles))):
        audiofilepath = audiopaths[ii]
        if satvectorcsv is not None:
            ixvec = extract_satvectors(satvectorcsv, audiofilepath)
        
        _fname = audiofilepath.split('/')[-1].split('.')[-2]
        output_tg = os.path.join(output_dir, _fname+'.TextGrid')
        if gt_phoneme_sequences is not None:
            target_phones = gt_phoneme_sequences[ii]
            aligner.serve(audiofilepath, ixvector=ixvec, text=transcripts[ii], target_phones=target_phones, save_to=output_tg)
        else:
            aligner.serve(audiofilepath, ixvector=ixvec, text=transcripts[ii], save_to=output_tg)

In [51]:
speaker_ids = [dirs for root, dirs, files in os.walk(audio_dir)][0]
for kk, speaker_id in enumerate(speaker_ids):
    print(f"Ruinning speaker {speaker_id}, {kk+1}/{len(speaker_ids)}")
    model_path = os.path.join('./sat_xvector_proj_models/' + speaker_id)
    speaker_audiofiles = [audfile for audfile in audiofiles if speaker_id in audfile]
    speaker_manual_textgridfiles = [tgfile for tgfile in manual_textgrids if speaker_id in tgfile]
    speaker_transcripts = get_transcripts_for_audiofiles(audiofiles)
    
    speaker_textgrids_outputdir = os.path.join(output_path, speaker_id)
    gtphone_seqs = get_phoneseqs_from_textgridpaths(speaker_manual_textgridfiles)
    
    run_aligner_on_files(speaker_audiofiles, speaker_transcripts, speaker_textgrids_outputdir, satvectorcsv=satvectors_csv, gt_phoneme_sequences=gtphone_seqs)
    break

Ruinning speaker 0407_M_SJ, 1/42


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Evaluating Aligner...


  0%|                                                                                                                                                                                                                          | 0/3762 [00:00<?, ?it/s]


RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same