In [None]:
%load_ext autotime
# %unload_ext autotime

In [None]:
import sys

# Make the parent directory importable (presumably the repo root that
# contains the `sonorus` package). Guarded so that re-running this cell does
# not append a duplicate entry every time.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
import os

import soundfile as sf
import torch
from datasets import load_dataset
from jiwer import wer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

from sonorus.speech.lm import (
    FairseqTokenDictionary,
    W2lKenLMDecoder,
    W2lViterbiDecoder,
    W2lFairseqLMDecoder,
)

In [None]:
# LibriSpeech "clean" validation split. Use split="test" for final numbers,
# and download_mode="force_redownload" if the local cache is corrupted.
librispeech_eval = load_dataset(
    "librispeech_asr",
    "clean",
    split="validation",
    ignore_verifications=True,
)

In [None]:
# Acoustic model (on GPU) and its matching tokenizer, from one checkpoint.
w2v2_checkpoint = "facebook/wav2vec2-base-960h"
model = Wav2Vec2ForCTC.from_pretrained(w2v2_checkpoint).to("cuda")
tokenizer = Wav2Vec2Tokenizer.from_pretrained(w2v2_checkpoint)

In [None]:
def map_to_array(batch):
    """Load the audio file named by ``batch["file"]`` into ``batch["speech"]``.

    The sample rate reported by ``soundfile`` is discarded. The input dict is
    mutated in place and returned.
    """
    waveform, _sample_rate = sf.read(batch["file"])
    batch["speech"] = waveform
    return batch

In [None]:
# Attach the decoded waveform of every utterance as a "speech" column.
librispeech_eval = librispeech_eval.map(function=map_to_array)

In [None]:
def map_to_pred(batch):
    """Greedy (argmax) CTC transcription of a batch of raw waveforms.

    Runs the acoustic model on ``batch["speech"]``, takes the highest-scoring
    token per frame and stores the strings returned by
    ``tokenizer.batch_decode`` in ``batch["transcription"]``. The batch is
    mutated in place and returned.
    """
    features = tokenizer(batch["speech"], return_tensors="pt", padding="longest")
    with torch.no_grad():
        output = model(features.input_values.to("cuda"))
    frame_ids = output.logits.argmax(dim=-1)
    batch["transcription"] = tokenizer.batch_decode(frame_ids)
    return batch

In [None]:
def get_wer(result, batch_size=-1, lm=False):
    """Compute the corpus-level word error rate of a decoded dataset.

    Args:
        result: Dataset or mapping with a ``"text"`` column of reference strings
            and a ``"transcription"`` column of hypotheses.
        batch_size: If > 0, score the corpus in chunks of this many utterances
            (presumably to keep each ``wer`` call small). Chunk scores are
            combined weighted by their reference word counts, so the result
            matches the unchunked corpus WER even when the last chunk is
            partial. If <= 0, score everything in a single ``wer`` call.
        lm: If True, each ``"transcription"`` entry is a sequence of hypotheses
            and only its first element (presumably the best one) is scored.

    Returns:
        The WER as a Python float.

    Raises:
        ValueError: If ``result`` contains no utterances.
    """
    # Materialize both columns once instead of once per chunk; this also makes
    # len() count utterances for plain mappings (len(dict) counts its keys).
    references = list(result["text"])
    raw_hypotheses = result["transcription"]
    hypotheses = [x[0] for x in raw_hypotheses] if lm else list(raw_hypotheses)

    if not references:
        raise ValueError("cannot compute WER on an empty result")

    if batch_size <= 0:
        return float(wer(references, hypotheses))

    weighted_errors = 0.0
    total_words = 0
    for start in range(0, len(references), batch_size):
        refs = references[start:start + batch_size]
        hyps = hypotheses[start:start + batch_size]
        # wer() normalizes by this chunk's reference word count; undo that
        # weighting so uneven chunks do not bias the average. Whitespace
        # splitting matches the word count used for ordinary transcripts.
        n_words = sum(len(ref.split()) for ref in refs)
        weighted_errors += wer(refs, hyps) * n_words
        total_words += n_words
    return weighted_errors / total_words

In [None]:
# Baseline: greedy CTC decoding without a language model.
result = librispeech_eval.map(
    map_to_pred,
    batched=True,
    batch_size=1,
    remove_columns=["speech"],
)
print("WER:", get_wer(result, batch_size=1000, lm=False))

In [None]:
def map_to_pred_lm(batch, lm_decoder=None):
    """Transcribe a batch of raw waveforms with an external beam-search decoder.

    Args:
        batch: Batch dict from ``Dataset.map(batched=True)`` containing a
            ``"speech"`` column.
        lm_decoder: Object exposing ``decode(logits)`` and
            ``post_process(decoded)``. Defaults to the notebook-level
            ``decoder`` global, so existing calls behave as before. Pass one
            explicitly (e.g. ``Dataset.map(..., fn_kwargs={"lm_decoder": d})``)
            to avoid depending on hidden global state.

    Returns:
        The batch, mutated in place, with ``batch["transcription"]`` set to the
        decoder's post-processed output. ``get_wer(..., lm=True)`` scores element
        ``[0]`` of each entry, presumably the best hypothesis.
    """
    # Resolved up front so a missing decoder fails before the model forward pass.
    active_decoder = decoder if lm_decoder is None else lm_decoder

    input_values = tokenizer(batch["speech"], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits

    # Hand the decoder a contiguous float32 CPU tensor (presumably what the
    # wav2letter-style bindings expect).
    logits = logits.float().cpu().contiguous()
    decoded = active_decoder.decode(logits)
    batch["transcription"] = active_decoder.post_process(decoded)
    return batch

In [None]:
# Expose the tokenizer's vocabulary (token -> id mapping from get_vocab())
# in the dictionary form the sonorus decoders take.
token_dict = FairseqTokenDictionary(indexed_symbols=tokenizer.get_vocab())

In [None]:
# Directory holding the fairseq LibriSpeech lexicon and KenLM binary. Set the
# FAIRSEQ_DATA_DIR environment variable instead of editing a user-specific path.
fairseq_data_dir = os.environ.get(
    "FAIRSEQ_DATA_DIR",
    "/home/harold/Documents/IISc-work/imperio/data/speech/fairseq",
)
lexicon_path = os.path.join(fairseq_data_dir, "librispeech_lexicon.lst")
lm_path = os.path.join(fairseq_data_dir, "lm_librispeech_kenlm_word_4g_200kvocab.bin")

# Fixed-weight KenLM decoder for the standalone LM evaluation cell, which reads
# the global `decoder` through map_to_pred_lm. It must be defined (not commented
# out) for the notebook to survive Restart & Run All.
decoder = W2lKenLMDecoder(
    token_dict=token_dict,
    lexicon=lexicon_path,
    lang_model=lm_path,
    beam=1500,
    beam_size_token=100,
    beam_threshold=25,
    lm_weight=1.5,
    word_weight=-1,
    unk_weight=float("-inf"),
    sil_weight=0,
)

In [None]:
# Evaluate beam-search decoding with the KenLM language model.
result = librispeech_eval.map(
    map_to_pred_lm,
    batched=True,
    batch_size=1,
    remove_columns=["speech"],
)
print("WER:", get_wer(result, batch_size=1000, lm=True))

In [None]:
import optuna
from optuna.integration import BoTorchSampler

In [None]:
n_startup_trials = 10
sampler_seed = 42
# Seed the sampler so the search can be reproduced. How much of BoTorch's own
# randomness this covers depends on the Optuna version.
bayes_opt_sampler = BoTorchSampler(n_startup_trials=n_startup_trials, seed=sampler_seed)
# The objective returns WER, so lower is better. "minimize" is Optuna's default;
# it is stated explicitly here.
study = optuna.create_study(sampler=bayes_opt_sampler, direction="minimize")

In [None]:
def objective(trial):
    """Optuna objective: WER of KenLM beam-search decoding for one trial.

    Samples the LM weight, word weight and silence weight, builds a
    ``W2lKenLMDecoder`` with them, decodes ``librispeech_eval`` through
    ``map_to_pred_lm`` and returns the resulting WER (to be minimized).

    Side effect: the module-level ``decoder`` is left set to the last trial's
    decoder.
    """
    # map_to_pred_lm reads the module-level name `decoder`. Binding the trial's
    # decoder to a plain local would leave that global untouched, so every
    # trial would be scored with the same pre-existing decoder (or raise
    # NameError when none exists).
    global decoder

    lm_weight = trial.suggest_float("lm_weight", 0, 5)
    word_weight = trial.suggest_float("word_weight", -5, 5)
    sil_weight = trial.suggest_float("sil_weight", -5, 5)

    # Drop the previous decoder before building the next one. It presumably
    # holds a loaded KenLM model, so this lowers peak memory.
    decoder = None
    decoder = W2lKenLMDecoder(
        token_dict=token_dict,
        lexicon=lexicon_path,
        lang_model=lm_path,
        beam=500,
        beam_size_token=100,
        beam_threshold=25,
        lm_weight=lm_weight,
        word_weight=word_weight,
        unk_weight=float("-inf"),
        sil_weight=sil_weight,
    )

    result = librispeech_eval.map(
        map_to_pred_lm,
        batched=True,
        batch_size=1,
        remove_columns=["speech"],
        # The mapped function is the same object in every trial. Never reuse
        # cached outputs that were computed with another trial's decoder.
        load_from_cache_file=False,
    )

    return get_wer(result, batch_size=1000, lm=True)

In [None]:
# Each trial builds a new KenLM decoder and decodes the whole evaluation split.
n_trials = 128
study.optimize(
    objective,
    n_trials=n_trials,
    show_progress_bar=True,
)

In [None]:
import joblib
# Persist the finished study (pickle-based; only reload files you trust).
joblib.dump(study, filename="speech-lm-study.jb")