In [None]:
%cd /data/codes/sb-apa/

from src.brain import get_brain_class
from hyperpyyaml import load_hyperpyyaml
import speechbrain as sb
import torch
import json
import sys
import os

from utils.arpa import arpa_to_ipa
import pandas as pd
import re


In [None]:
def load_state_dict(hparams):
    """Load the saved wav2vec2 and model weights into the modules in ``hparams``.

    The checkpoint directory is taken from ``hparams["ckpt_path"]`` when that
    key is present; otherwise the notebook-level ``ckpt_path`` variable is
    used, which is what the function relied on originally.

    Args:
        hparams: Mapping that holds the ``"wav2vec2"`` and ``"model"`` modules
            and, optionally, ``"ckpt_path"``.

    Returns:
        The same ``hparams`` object; its two modules are updated in place.
    """
    # Prefer the explicit entry so the function does not silently depend on a
    # global that is only defined in a later notebook cell.
    checkpoint_dir = hparams["ckpt_path"] if "ckpt_path" in hparams else ckpt_path

    wav2vec2_ckpt_path = f'{checkpoint_dir}/wav2vec2.ckpt'
    model_ckpt_path = f'{checkpoint_dir}/model.ckpt'

    # Deserialize onto CPU first: a checkpoint saved from a GPU run can then be
    # read on any host, and ``load_state_dict`` copies the tensors onto
    # whatever device each module already lives on.
    wav2vec2_state_dict = torch.load(wav2vec2_ckpt_path, map_location="cpu")
    model_state_dict = torch.load(model_ckpt_path, map_location="cpu")

    hparams["wav2vec2"].load_state_dict(wav2vec2_state_dict)
    hparams["model"].load_state_dict(model_state_dict)

    return hparams

def init_model(hparams):
    """Build the brain object, restore checkpoint weights and call ``eval()``.

    The class returned by ``get_brain_class(hparams)`` is instantiated with
    the modules and checkpointer found in ``hparams``. ``run_opts`` is read
    from the notebook's global namespace, so it must be defined before this
    function is called.

    Args:
        hparams: Mapping with at least the ``"modules"`` and
            ``"checkpointer"`` entries, plus whatever ``load_state_dict``
            needs.

    Returns:
        A ``(model, hparams)`` tuple: the instantiated brain object and the
        mapping returned by ``load_state_dict``.
    """
    brain_cls = get_brain_class(hparams)
    brain = brain_cls(
        modules=hparams["modules"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # Weights are restored after the brain has been constructed.
    hparams = load_state_dict(hparams)

    # Only the module objects are needed here, not their names.
    for module in hparams["modules"].values():
        module.eval()

    return brain, hparams


In [None]:
# --- Data and result locations -------------------------------------------
DATA_DIR = "data"
APR_DATA_FOLDER = DATA_DIR + "/apr/"

# RESULTS_FOLDER keeps its trailing slash, so the paths derived from it
# contain a double slash ("data/results//...").
RESULTS_FOLDER = DATA_DIR + "/results/"
EXP_METADATA_FILE = RESULTS_FOLDER + "/exp_metadata.csv"
APR_RESULTS_FILE = RESULTS_FOLDER + "/results_scoring.csv"
EPOCH_RESULTS_DIR = RESULTS_FOLDER + "/epoch_results"
PARAMS_DIR = RESULTS_FOLDER + "/params"

# --- Experiment identifiers ----------------------------------------------
MODEL_TYPE = "w2v2"
SCORING_TYPE = ""

# --- Model and hyperparameter locations ----------------------------------
APR_MODEL_DIR = "pretrained/apr"
PRETRAINED_MODEL_DIR = "pretrained/apr"
SCORING_HPARAM_FILE = "hparams/apr.yml"

# Command-line style argument list: the hyperparameter file first, followed
# by "--flag value" pairs flattened in declaration order.
argv = [SCORING_HPARAM_FILE] + [
    token
    for flag, value in (
        ("--data_folder", APR_DATA_FOLDER),
        ("--exp_folder", APR_MODEL_DIR),
        ("--batch_size", "4"),
        ("--exp_metadata_file", EXP_METADATA_FILE),
        ("--results_file", APR_RESULTS_FILE),
        ("--epoch_results_dir", EPOCH_RESULTS_DIR),
        ("--params_dir", PARAMS_DIR),
    )
    for token in (flag, value)
]

In [None]:
# Locations of the resources and saved artefacts used by this notebook.
lexicon_path = "resources/lexicon"
ckpt_path = "results/apr/save/best"
label_encoder_path = "results/apr/save/label_encoder.txt"

# Split the argument list into the hyperparameter file, the run options and
# the overrides, then load the YAML with those overrides applied.
hparams_file, run_opts, overrides = sb.parse_arguments(argv)
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Record the paths in hparams as well.
hparams["ckpt_path"] = ckpt_path
hparams["label_encoder_path"] = label_encoder_path

# Build the model with restored weights, then load the saved label encoder.
prep_model, hparams = init_model(hparams)
label_encoder = sb.dataio.encoder.CTCTextEncoder.from_saved(
    hparams["label_encoder_path"]
)


In [None]:
from src.data import apr_dataio_prep

# Build the three dataset splits. The call also returns a label encoder,
# which replaces the `label_encoder` that was passed in.
(train_data, valid_data, test_data, label_encoder) = apr_dataio_prep(
    hparams,
    label_encoder,
)

In [None]:
# Take the first item of the validation split and display it.
sample = valid_data[0]
sample

In [None]:
# Decode the sample's encoded phoneme field with the label encoder and
# display the result.
phns = label_encoder.decode_ndim(sample["phn_encoded"])
phns

In [None]:
# Turn the single sample into a batch of one: each tensor gets a leading
# batch axis and is moved to the GPU (this cell requires a CUDA device).
phn_encoded = sample["phn_encoded"].unsqueeze(0).cuda()
phns_bos = sample["phn_encoded_bos"].unsqueeze(0).cuda()
phns_eos = sample["phn_encoded_eos"].unsqueeze(0).cuda()

wavs = sample["sig"].unsqueeze(0).cuda()
# NOTE(review): this passes the absolute size of dim 1 of `wavs`. SpeechBrain
# components commonly expect relative lengths in [0, 1] — confirm what
# `prep_model.infer` expects.
wav_lens = torch.tensor([wavs.size(1)]).cuda()

In [None]:
# Run inference on the one-item batch. `wav_lens` is rebound to the value
# returned by `infer`.
p_ctc, p_seq, wav_lens = prep_model.infer(wavs, wav_lens, phns_bos)

In [None]:
# Display the shape of the `p_seq` output.
p_seq.shape

In [None]:
# Display the shape of the `p_ctc` output.
p_ctc.shape