# Notebook #3 for Filtering of low-quality generations with STT

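This notebook filters out low-quality generated audio recordings that were trimmed with the VAD model (see notebook #2). Each recording is transcribed by three STT systems (Whisper, NVIDIA NeMo Conformer-CTC and Vosk), and it is kept only if at least two of them recognize its target word.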

In [None]:
import pickle
import shutil
import datetime
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict, Counter

In [None]:
def save_obj(obj, path='object.pkl'):
    """Serialize ``obj`` into the file at ``path`` using pickle.

    The highest pickle protocol available is used, and any existing file
    at ``path`` is overwritten.
    """
    with open(path, mode='wb') as out_file:
        pickle.dump(obj, out_file, protocol=pickle.HIGHEST_PROTOCOL)

def load_obj(path='object.pkl'):
    """Load and return the object pickled in the file at ``path``.

    Counterpart of ``save_obj``. Unpickling can execute arbitrary code, so
    only point this at files this pipeline wrote itself, never at untrusted
    data.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)

def transcription_normalizer(string: str):
    for symbol in ['.','!','?',',']:
        string = string.replace(symbol, '')
    string = string.replace('ё', 'е')
    for i, j in zip([str(k) for k in range(10)], 
                    ['ноль', 'один', 'два', 'три', 'четыре', 'пять', 
                     'шесть', 'семь', 'восемь', 'девять']):
        string = string.replace(i, j)
    return string.strip().lower()

def batch_gen(df: list, bs: int = 10):
    """Yield consecutive slices of ``df`` holding at most ``bs`` items each.

    The last batch is shorter when ``len(df)`` is not a multiple of ``bs``;
    nothing is yielded for an empty ``df``.

    Args:
        df: any sliceable sequence (in this notebook, a list of audio arrays
            or a list of file paths).
        bs: batch size; must be a positive integer.

    Raises:
        ValueError: if ``bs`` is smaller than 1. Because this is a generator,
            the error surfaces on the first iteration.
    """
    if bs < 1:
        # Without this check a negative bs silently yields nothing and
        # bs == 0 fails with an unhelpful range() error.
        raise ValueError(f'bs must be a positive integer, got {bs}')
    for start in range(0, len(df), bs):
        yield df[start:start + bs]

## 1. Configs

In [None]:
path_to_data = Path('./data/Synt-RuSC/interim/')        # input: VAD-trimmed recordings (notebook #2)
output_data_path = Path('./data/Synt-RuSC/processed/')  # output: recordings that pass the STT filter
results_path = Path('./results/')                       # pickled STT results and temporary Vosk output

# check path_to_data path. Raise explicitly instead of using `assert`:
# assertions are skipped when Python runs with -O.
if not path_to_data.is_dir():
    raise FileNotFoundError(
        f'path does not exist, check "path_to_data": {path_to_data}')

# create output_data_path and results_path
output_data_path.mkdir(parents=True, exist_ok=True)
results_path.mkdir(parents=True, exist_ok=True)

# English word-folder name -> Russian label that the normalized STT
# transcriptions are compared with. Groups: command words, digits, and the
# words that ru_ind_label_enc_oov indexes.
en_ru_label_enc = {'yes': 'да', 'no': 'нет', 'up': 'вверх', 'down': 'вниз',
                   'left': 'налево', 'right': 'направо', 'on': 'включи',
                   'off': 'выключи', 'stop': 'стоп', 'go': 'иди',
                   'forward': 'вперед', 'backward': 'назад',
                   'follow': 'следуй', 'visual': 'наблюдай', 'learn': 'изучай',

                   'zero': 'ноль', 'one': 'один', 'two': 'два', 'three': 'три',
                   'four': 'четыре', 'five': 'пять', 'six': 'шесть',
                   'seven': 'семь', 'eight': 'восемь', 'nine': 'девять',

                   'create': 'создай', 'cry': 'зарыдай', 'over': 'сверх',
                   'discord': 'разлад', 'harm': 'вред', 'dies': 'гибнет',
                   'nails': 'гвозди', 'rustier': 'ржавее', 'exclude': 'исключи',
                   'motto': 'девиз', 'grief': 'беда', 'newer': 'новее',
                   'knock': 'стучи', 'blow_off': 'сдуй'}

# Russian label -> class index for the command words and digits, plus an
# 'out_of_vocabulary' class (not referenced by the cells of this notebook).
ru_ind_label_enc = {'да': 0, 'нет': 1, 'вверх': 2, 'вниз': 3, 'налево': 4,
                    'направо': 5, 'включи': 6, 'выключи': 7, 'стоп': 8,
                    'иди': 9, 'вперед': 10, 'назад': 11, 'следуй': 12,
                    'наблюдай': 13, 'изучай': 14,

                    'ноль': 15, 'один': 16, 'два': 17, 'три': 18, 'четыре': 19,
                    'пять': 20, 'шесть': 21, 'семь': 22, 'восемь': 23,
                    'девять': 24,

                    'out_of_vocabulary': 25}

# Russian label -> class index for the third group of en_ru_label_enc, plus
# an 'out_of_vocabulary' class (not referenced by the cells of this notebook).
ru_ind_label_enc_oov = {'создай': 0, 'зарыдай': 1, 'сдуй': 2, 'сверх': 3,
                        'разлад': 4, 'вред': 5, 'гибнет': 6, 'гвозди': 7,
                        'ржавее': 8, 'исключи': 9, 'девиз': 10, 'беда': 11,
                        'новее': 12, 'стучи': 13, 'out_of_vocabulary': 14}

batch_size = 50   # recordings per STT call
target_sr = 16000  # sample rate (Hz) audio is resampled to before Whisper

## 2. Filtering of low-quality generations with STT

### 2.1 Whisper

In [None]:
import torch
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration


device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on the GPU. On the CPU many float16 ops are
# unsupported or very slow (depending on the PyTorch version), so fall back
# to float32 there.
torch_dtype = torch.float16 if device == "cuda" else torch.float32


def get_whisper_transcriptions(audio_list: list, sr: int):
    """Transcribe a batch of waveforms with the global Whisper model.

    Relies on the module-level ``processor`` and ``model`` created in the
    model-loading cell, and on ``device`` / ``torch_dtype`` defined above.

    Args:
        audio_list: waveforms to transcribe (in this notebook, arrays
            returned by ``librosa.load``), all sampled at ``sr``.
        sr: sampling rate of every waveform in ``audio_list``, in Hz.

    Returns:
        list[str]: one decoded transcription per input waveform, with
        special tokens removed.
    """
    # get input feats and move them to the model's device and dtype
    inputs = processor(audio_list, sampling_rate=sr, return_tensors="pt")
    inputs = inputs.to(device, torch_dtype)

    # generate token ids
    predicted_ids = model.generate(inputs.input_features)

    # decode token ids to text
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)

In [None]:
# Load model

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
# Keep the feature extractor's expected rate in sync with target_sr, the rate
# the audio is resampled to before transcription.
processor.feature_extractor.sampling_rate = target_sr

# Load the weights directly in `torch_dtype` instead of loading them as
# float32 and casting afterwards; this avoids materializing a full-precision
# copy of the model first.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3", torch_dtype=torch_dtype)
model.to(device)
# Force Russian transcription (no language detection, no translation).
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="russian", task="transcribe")

In [None]:
# Transcribe every recording with Whisper. Layout of the result:
# res_whisper[subset][f'{word}_{stem}'] = {'true': Russian label,
#                                          'pred': normalized transcription}
res_whisper = defaultdict(dict)

# sorted(): deterministic processing order across runs and platforms.
for subset_pathway in sorted(path_to_data.iterdir()):
    # Stray files next to the subset folders (e.g. .DS_Store) would make the
    # iterdir() below fail with NotADirectoryError.
    if not subset_pathway.is_dir():
        continue
    subset = subset_pathway.name.lower()

    audio_names, inp_audio_list = [], []
    for word_pathway in sorted(subset_pathway.iterdir()):
        word = word_pathway.name.lower()

        # only folders of known words are transcribed
        if not word_pathway.is_dir() or word not in en_ru_label_enc:
            continue

        for fp in sorted(word_pathway.iterdir()):
            name = f'{word}_{fp.stem}'
            audio_names.append(name)

            res_whisper[subset][name] = {'true': en_ru_label_enc[word],
                                         'pred': None}

            # load audio file and resample it to target_sr if needed
            audio, sr_sample = librosa.load(fp, sr=None)
            if sr_sample != target_sr:
                audio = librosa.resample(audio, orig_sr=sr_sample,
                                         target_sr=target_sr)
            inp_audio_list.append(audio)

    # Get STT transcribe
    start_time = datetime.datetime.now()
    print(f'Start transcribe {len(inp_audio_list)} audio from {subset}')

    n_batches = -(-len(inp_audio_list) // batch_size)  # ceil division
    transcriptions = []
    for batch in tqdm(batch_gen(df=inp_audio_list, bs=batch_size),
                      total=n_batches):
        transcriptions.extend(
            get_whisper_transcriptions(audio_list=batch, sr=target_sr))

    end_time = datetime.datetime.now()
    print(f'Get transcriptions with: {str(end_time-start_time)}')

    # transcriptions are in the same order as audio_names
    for ind, name in enumerate(audio_names):
        res_whisper[subset][name]['pred'] = transcription_normalizer(transcriptions[ind])

save_obj(obj=res_whisper, path=(results_path/'res-stt_whisper.pkl'))

### 2.2 NVIDIA Conformer-CTC Large (NeMo)

In [None]:
import nemo.collections.asr as nemo_asr

In [None]:
# Load model: the pretrained Russian Conformer-CTC Large checkpoint
# (CTC model with a BPE tokenizer), fetched by its model name.
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    model_name="stt_ru_conformer_ctc_large",
)

In [None]:
# Transcribe every recording with NeMo Conformer-CTC. Layout of the result:
# res_nemo[subset][f'{word}_{stem}'] = {'true': Russian label,
#                                       'pred': normalized transcription}
res_nemo = defaultdict(dict)

# sorted(): deterministic processing order across runs and platforms.
for subset_pathway in sorted(path_to_data.iterdir()):
    # Skip stray files next to the subset folders (iterdir() would fail).
    if not subset_pathway.is_dir():
        continue
    subset = subset_pathway.name.lower()

    audio_names, audio_pathways = [], []
    for word_pathway in sorted(subset_pathway.iterdir()):
        word = word_pathway.name.lower()

        # only folders of known words are transcribed
        if not word_pathway.is_dir() or word not in en_ru_label_enc:
            continue

        for fp in sorted(word_pathway.iterdir()):
            name = f'{word}_{fp.stem}'
            audio_pathways.append(str(fp))
            audio_names.append(name)

            res_nemo[subset][name] = {'true': en_ru_label_enc[word],
                                      'pred': None}

    start_time = datetime.datetime.now()
    print(f'Start transcribe {len(audio_pathways)} audio from {subset}')

    n_batches = -(-len(audio_pathways) // batch_size)  # ceil division
    transcriptions = []
    for batch in tqdm(batch_gen(df=audio_pathways, bs=batch_size),
                      total=n_batches):
        # NOTE(review): the file list is passed positionally because the
        # keyword is `paths2audio_files` in older NeMo releases and `audio`
        # in newer ones.
        batch_transcript = asr_model.transcribe(
            batch, batch_size=batch_size, verbose=False)
        # Depending on the NeMo version, results are plain strings or
        # Hypothesis objects exposing the decoded text as `.text`.
        transcriptions.extend(getattr(hyp, 'text', hyp)
                              for hyp in batch_transcript)

    end_time = datetime.datetime.now()
    print(f'Get transcriptions with: {str(end_time-start_time)}')

    # transcriptions are in the same order as audio_names
    for ind, name in enumerate(audio_names):
        res_nemo[subset][name]['pred'] = transcription_normalizer(transcriptions[ind])

save_obj(obj=res_nemo, path=(results_path/'res-stt_nemo.pkl'))

### 2.3 Vosk

In [None]:
import subprocess

In [None]:
# Run the vosk-transcriber CLI once per word folder. Its output goes to
# tmp_vosk/<subset>/<word>/, the layout the next cell reads back.
tmp_vosk = results_path/'tmp_vosk'
tmp_vosk.mkdir(parents=True, exist_ok=True)

for subset_pathway in sorted(path_to_data.iterdir()):
    # skip stray files next to the subset folders
    if not subset_pathway.is_dir():
        continue

    for word_pathway in sorted(subset_pathway.iterdir()):
        # only folders of known words are transcribed (others are ignored
        # when the results are read back anyway)
        if (not word_pathway.is_dir()
                or word_pathway.name.lower() not in en_ru_label_enc):
            continue

        # One output folder per word: a single shared folder would let
        # same-named recordings from different words overwrite each other.
        out_dir = tmp_vosk / subset_pathway.name / word_pathway.name
        out_dir.mkdir(parents=True, exist_ok=True)

        # Pass an argument list (no shell, no str.split): paths containing
        # spaces stay intact and option values carry no literal quotes.
        command = ['vosk-transcriber',
                   '--input', str(word_pathway),
                   '--output', str(out_dir),
                   '--model-name', 'vosk-model-ru-0.42',
                   '--lang', 'ru']

        process = subprocess.run(command, stdout=subprocess.PIPE)
        if process.returncode != 0:
            # report instead of failing silently; the other folders still run
            print(f'vosk-transcriber failed with code {process.returncode} '
                  f'for {word_pathway}')

In [None]:
# Read the Vosk transcriptions back from tmp_vosk/<subset>/<word>/. Layout
# of the result (same as res_whisper / res_nemo):
# res_vosk[subset][f'{word}_{stem}'] = {'true': Russian label,
#                                       'pred': normalized transcription}
res_vosk = defaultdict(dict)

for subset_pathway in sorted(tmp_vosk.iterdir()):
    # skip stray files (iterdir() on a file would fail)
    if not subset_pathway.is_dir():
        continue
    subset = subset_pathway.name.lower()

    for word_pathway in sorted(subset_pathway.iterdir()):
        word = word_pathway.name.lower()

        if not word_pathway.is_dir() or word not in en_ru_label_enc:
            continue

        for fp in sorted(word_pathway.iterdir()):
            # fp.stem matches the stem of the source recording
            name = f'{word}_{fp.stem}'

            # Read the Russian text as UTF-8 explicitly rather than with the
            # platform's locale encoding (e.g. cp1251 on Windows).
            # NOTE(review): assumes vosk-transcriber writes UTF-8 — confirm.
            transcription = fp.read_text(encoding='utf-8')

            res_vosk[subset][name] = {
                'true': en_ru_label_enc[word],
                'pred': transcription_normalizer(transcription)
            }

save_obj(obj=res_vosk, path=(results_path/'res-stt_vosk.pkl'))

### 2.4 Filtering

In [None]:
# Optional: reload the STT results pickled by sections 2.1-2.3 instead of
# re-running the models; uncomment when starting from this point.
# Section 2.2 keeps the NeMo results in `res_nemo`; binding both names keeps
# this reload compatible with code that reads either name.
#res_whisper = load_obj(path=results_path/'res-stt_whisper.pkl')
#res_nemo = res_nvidia = load_obj(path=results_path/'res-stt_nemo.pkl')
#res_vosk = load_obj(path=results_path/'res-stt_vosk.pkl')

In [None]:
# Copy a recording to output_data_path only if its label is recognized by
# at least two of the three STT systems ("good" group). Exactly one vote is
# the "hard" group, copied only when save_hard_group is True. No votes is
# the "bad" group, which is dropped.
save_hard_group = False

# Section 2.2 stores the NeMo results as `res_nemo`, while the optional reload
# cell may name them `res_nvidia`; accept whichever one is defined.
if 'res_nemo' not in globals():
    res_nemo = res_nvidia

# Names are f'{word}_{stem}' and some words contain '_' (e.g. 'blow_off'), so
# the word folder is recovered from the Russian label instead of by splitting.
ru_en_label_enc = {ru: en for en, ru in en_ru_label_enc.items()}

for subset in res_whisper.keys():
    (output_data_path / subset).mkdir(parents=True, exist_ok=True)
    group_counter = Counter()

    for name in res_whisper[subset].keys():
        # check true words for all stt results
        assert res_whisper[subset][name]['true'] == \
               res_nemo[subset][name]['true'] == \
               res_vosk[subset][name]['true'], f'err: {name}'

        true = res_whisper[subset][name]['true']
        word = ru_en_label_enc[true]
        # strip the f'{word}_' prefix; interim recordings are assumed to be
        # .wav files (as in the original path reconstruction)
        fn = name[len(word) + 1:] + '.wav'
        fp = path_to_data / subset / word / fn

        # check pathway to interim file; skip it rather than crash on copy
        if not fp.is_file():
            print(f'File not found! Check pathway to interim file: {fp}!')
            group_counter['missing'] += 1
            continue

        # number of STT systems whose normalized transcription equals the label
        votes = Counter([
            res_vosk[subset][name]['pred'],
            res_whisper[subset][name]['pred'],
            res_nemo[subset][name]['pred'],
        ])[true]

        if votes >= 2:
            group = 'good'
        elif votes == 1:
            group = 'hard'
        else:
            group = 'bad'
        group_counter[group] += 1

        if group == 'good' or (group == 'hard' and save_hard_group):
            # the per-word output folder does not exist yet on first use
            out_dir = output_data_path / subset / word
            out_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy(fp, out_dir / fn)

    print(f'{subset}: {dict(group_counter)}')