In [278]:
import json
import os
import pickle
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from model import GOPT

In [279]:
# Run on the first GPU when CUDA is available; otherwise fall back to the CPU
# so the notebook still executes on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Build the network with the hyper-parameters used for this checkpoint.
gopt = GOPT(embed_dim=24, num_heads=1, depth=3, input_dim=84)

# Deserialize the weights onto the CPU first so loading does not depend on the
# device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load('exp/models/best_audio_model.pth', map_location='cpu')
gopt.load_state_dict(state_dict)

# Inference only: switch to eval mode, then move the model to the chosen device.
gopt.eval()
gopt = gopt.to(device)

In [280]:
def load_scaler(path):
    """Unpickle and return the object stored in the file at ``path``.

    Args:
        path: Location of the pickle file to read (opened in binary mode).

    Returns:
        Whatever object the pickle file contains.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing, so this must only be used on trusted files.
    """
    with open(path, "rb") as scaler_file:
        return pickle.load(scaler_file)

# Scaler object used in the next cell to normalise the input features.
scaler = load_scaler("resources/scaler.pkl")

In [281]:
# Load the pre-extracted feature array and the matching phone-label array.
input_feat = np.load("data/seq_data_librispeech/if_feat.npy")
input_phn = np.load("data/seq_data_librispeech/if_label.npy")

# Only the first utterance is scored: normalise its features with the fitted
# scaler and add a leading batch dimension of size 1.
normed_feat = scaler.transform(input_feat[0])
input_feat = torch.from_numpy(normed_feat).unsqueeze(0)

# Count the phones of that same utterance only, using channel 0 of the label
# array (the channel later passed to the model as phone ids). Entries equal to
# -1 are excluded (presumably padding -- TODO confirm). Counting over the whole
# array would also include other utterances and other label channels.
phoneme_length = int(np.sum(input_phn[0, :, 0] != -1))

In [282]:
# Forward pass without gradient tracking (inference only).
with torch.no_grad():
    t_input_feat = input_feat.to(device)
    # The features hold a single utterance (batch of 1), so take the phone ids
    # (channel 0 of the label array) of the first utterance only; this keeps
    # both model inputs at the same batch size.
    t_phn = torch.from_numpy(input_phn[:1, :, 0]).to(device)

    # Both inputs are passed to the model as float tensors.
    utt_score, phn_score, wrd_score = gopt(t_input_feat.float(), t_phn.float())

In [283]:
# Flatten each model output and multiply it by 50; the phone- and word-level
# outputs are additionally truncated to the first `phoneme_length` entries.
# NOTE: this cell overwrites its own inputs, so re-running it without
# re-running the inference cell applies the factor of 50 a second time.
utt_score = 50 * utt_score.view(-1)
phn_score = 50 * phn_score.view(-1)[:phoneme_length]
wrd_score = 50 * wrd_score.view(-1)[:phoneme_length]
print(phn_score.shape, utt_score.shape, wrd_score.shape)

torch.Size([7]) torch.Size([1]) torch.Size([7])


In [284]:
def load_lexicon(path="librispeech-lexicon.txt"):
    """Read a pronunciation lexicon and tag each phone with its word position.

    Every non-blank line is split on whitespace: the first token is the word,
    the remaining tokens are its phones. Each phone receives a suffix:

    * ``_S`` -- the only phone of the word
    * ``_B`` -- first phone of a multi-phone word
    * ``_I`` -- interior phone
    * ``_E`` -- last phone of a multi-phone word

    Args:
        path: Path of the lexicon text file.

    Returns:
        dict mapping each word to its tagged phones joined by single spaces
        (an empty string if the line holds only the word). If a word appears
        on several lines, the last line wins.
    """
    with open(path, 'r') as f:
        rows = f.read().splitlines()

    lexicon_dict_l = dict()
    for row in rows:
        tokens = row.split()
        if not tokens:
            # Blank / whitespace-only line: there is no word to index, and
            # popping the key from an empty token list would raise IndexError.
            continue

        key, phones = tokens[0], tokens[1:]
        if len(phones) == 1:
            tagged = [phones[0] + '_S']
        else:
            last = len(phones) - 1
            tagged = [
                phone + ('_B' if pos == 0 else '_E' if pos == last else '_I')
                for pos, phone in enumerate(phones)
            ]
        lexicon_dict_l[key] = " ".join(tagged)
    return lexicon_dict_l

# Position-tagged pronunciation lexicon.
lexicon_path = "resources/lexicon.txt"
lexicon_dict_l = load_lexicon(path=lexicon_path)

# Phone -> id table read from JSON, plus the inverse id -> phone lookup.
with open("resources/phoneme_dict.json", "r") as phone_dict_file:
    phone2id = json.load(phone_dict_file)
id2phone = dict(zip(phone2id.values(), phone2id.keys()))

In [285]:
# Flattened phone ids that were fed to the model.
phn_id = t_phn.view(-1)

# One row per phone: column 0 holds the phone id, column 1 its score.
# The tensors are moved to host memory explicitly and copied column-wise,
# instead of assigning (possibly CUDA) 0-d tensors into the NumPy array one
# element at a time.
phone_with_scores = -1 * np.ones((phoneme_length, 2))
phone_with_scores[:, 0] = phn_id[:phoneme_length].detach().cpu().numpy()
phone_with_scores[:, 1] = phn_score[:phoneme_length].detach().cpu().numpy()

In [286]:
# Transcript of the utterance being scored and its individual words.
text = "IT TWO ONE"
words = text.split(' ')
path = "egs/gop_speechocean762/s5/data/local/text-phone"

# Tab-separated file with one row per word: "<id>.<word index>" and the
# space-separated, position-tagged phones of that word.
phone_df = pd.read_csv(
    path,
    sep="\t",
    names=["word_id", "phonemes"],
    dtype={"word_id": str},
)
print(phone_df.head())

# Keep only the word index, split the phone string into a list, then expand to
# one row per phone (reset_index keeps the pre-explode row number as "index").
phone_df = (
    phone_df
    .assign(
        word_id=phone_df["word_id"].map(lambda raw_id: raw_id.split(".")[-1]),
        phonemes=phone_df["phonemes"].map(lambda phones: phones.split(" ")),
    )
    .explode(column="phonemes")
    .reset_index()
)
# Drop the position suffix (everything after the first "_") from each phone.
phone_df["phonemes"] = phone_df["phonemes"].map(lambda phone: phone.split("_")[0])
phone_df.head()

       word_id       phonemes
0  000940032.0      IH1_B T_E
1  000940032.1      T_B UW1_E
2  000940032.2  W_B AH1_I N_E


Unnamed: 0,index,word_id,phonemes
0,0,0,IH1
1,0,0,T
2,1,1,T
3,1,1,UW1
4,2,2,W


In [287]:
# Integer word index of every phone row, as a tensor.
word_ids = torch.tensor(list(map(int, phone_df["word_id"])))

In [288]:
# Membership matrix with one row per phone and one column per word:
# entry (p, w) is 1 when phone p belongs to word w.
one_hot = F.one_hot(word_ids, num_classes=word_ids.max().item()+1).float()

# Divide each column by the number of phones in that word so the product below
# yields the mean phone score per word. The count is clamped to 1 so a word
# index that has no phones gives a score of 0 instead of NaN (0 / 0).
one_hot = one_hot / one_hot.sum(0, keepdim=True).clamp(min=1)

# (words x phones) @ (phones,) -> one averaged score per word.
word_scores = torch.matmul(one_hot.transpose(0, 1), phn_score.cpu())

In [289]:
# Illustrative sketch of the nested layout built as `scores` further down:
# utterance -> words -> phonemes, each level carrying "text" and "score".
# It is a plain string shown for reference only, not valid Python/JSON
# (the "..." placeholders and some separating commas are omitted).
"""
sample = {
    "version": None,
    "utterance":
        {
            "text": "...",
            "score": ...
            "words": [
                {
                    "text": "...",
                    "score": ...
                    "phonemes":[
                        {
                            "text": "...",
                            "score": ...
                        },
                        {
                            "text": "...",
                            "score": ...
                        },
                    ]
                }

                {
                    "text": "...",
                    "score": ...
                    "phonemes":[
                        {
                            "text": "...",
                            "score": ...
                        },
                        {
                            "text": "...",
                            "score": ...
                        },
                    ]
                }
            ]
        }
}
"""

'\nsample = {\n    "version": None,\n    "utterance":\n        {\n            "text": "...",\n            "score": ...\n            "words": [\n                {\n                    "text": "...",\n                    "score": ...\n                    "phonemes":[\n                        {\n                            "text": "...",\n                            "score": ...\n                        },\n                        {\n                            "text": "...",\n                            "score": ...\n                        },\n                    ]\n                }\n\n                {\n                    "text": "...",\n                    "score": ...\n                    "phonemes":[\n                        {\n                            "text": "...",\n                            "score": ...\n                        },\n                        {\n                            "text": "...",\n                            "score": ...\n                        },\n  

In [290]:
# Assemble the nested report: utterance -> words -> phonemes.
scores = {
    "version": "None",
    "utterance": {
        "text": text,
        "score": utt_score.item(),
        "words": []
    }
}
word_entries = scores["utterance"]["words"]

for index in range(len(phone_with_scores)):
    # Word index of the phone in this row.
    word_id = int(phone_df["word_id"].iloc[index])

    _tmp_phone = {
        "text": id2phone[int(phone_with_scores[index][0])],
        "score": phone_with_scores[index][1].item(),
    }

    if word_id == len(word_entries):
        # First phone of the next word: open a new word entry.
        word_entries.append({
            "text": words[word_id],
            "score": word_scores[word_id].item(),
            "phonemes": []
        })
    elif word_id > len(word_entries):
        # Word entries are addressed by position, so indices must start at 0
        # and grow without gaps; fail with a clear message otherwise.
        raise IndexError(
            f"word index {word_id} at phone row {index} skips ahead of the "
            f"{len(word_entries)} word(s) collected so far"
        )

    # Attach the phone to its word (the entry just created or an earlier one).
    word_entries[word_id]["phonemes"].append(_tmp_phone)

In [291]:
# Show the assembled score report as the cell output for inspection.
scores

{'version': 'None',
 'utterance': {'text': 'IT TWO ONE',
  'score': 69.40826416015625,
  'words': [{'text': 'IT',
    'score': 48.026214599609375,
    'phonemes': [{'text': 'IH', 'score': 10.327770233154297},
     {'text': 'T', 'score': 85.72466278076172}]},
   {'text': 'TWO',
    'score': 94.78627014160156,
    'phonemes': [{'text': 'T', 'score': 99.65254211425781},
     {'text': 'UW', 'score': 89.92000579833984}]},
   {'text': 'ONE',
    'score': 81.3423080444336,
    'phonemes': [{'text': 'W', 'score': 90.68363952636719},
     {'text': 'AH', 'score': 70.45494842529297},
     {'text': 'N', 'score': 82.88832092285156}]}]}}

In [292]:
# Save the score report as pretty-printed UTF-8 JSON (non-ASCII kept verbatim).
with open("result.json", "w", encoding="utf-8") as result_file:
    result_file.write(json.dumps(scores, indent=4, ensure_ascii=False))