In [1]:
import torch
import sys
import os
from model import GOPT
import numpy as np
import json
import pandas as pd 
import pickle

In [2]:
device = "cuda:0"
gopt = GOPT(embed_dim=24, num_heads=1, depth=3, input_dim=84)

state_dict = torch.load('exp/models/best_audio_model.pth', map_location='cpu')
gopt.load_state_dict(state_dict)
gopt.eval()
gopt = gopt.to(device)

In [3]:
def load_scaler(path):
    with open(path, "rb") as f:
        scaler = pickle.load(f)

    return scaler

scaler = load_scaler(path="resources/scaler.pkl")

In [4]:
input_feat = np.load("data/seq_data_librispeech/if_feat.npy")
input_phn = np.load("data/seq_data_librispeech/if_label.npy")

normed_feat = scaler.transform(input_feat[0])
input_feat = torch.from_numpy(normed_feat).unsqueeze(0)
phoneme_length = np.sum(input_phn!=-1)

In [5]:
with torch.no_grad():
    input_feats = input_feat.to(device)
    input_phone_ids = torch.from_numpy(input_phn[:,:,0]).to(device)

    utt_score, phn_scores, wrd_scores = gopt(input_feats.float(),input_phone_ids.float())

In [6]:
phn_scores = phn_scores.view(-1)[0: phoneme_length] * 50
utt_score = utt_score.view(-1) * 50
wrd_scores = wrd_scores.view(-1)[0:phoneme_length] * 50
print(wrd_scores.shape, utt_score.shape, wrd_scores.shape)

torch.Size([47]) torch.Size([1]) torch.Size([47])


In [7]:
def load_lexicon(path="librispeech-lexicon.txt"):
    with open(path, 'r') as f:
        lexicon_raw = f.read()
        rows = lexicon_raw.splitlines()
    clean_rows = [row.split() for row in rows]
    lexicon_dict_l = dict()
    for row in clean_rows:
        c_row = row.copy()
        key = c_row.pop(0)
        if len(c_row) == 1:
            c_row[0] = c_row[0] + '_S'
        if len(c_row) >= 2:
            c_row[0] = c_row[0] + '_B'
            c_row[-1] = c_row[-1] + '_E'
        if len(c_row) > 2:
            for i in range(1,len(c_row)-1):
                c_row[i] = c_row[i] + '_I'
        val = " ".join(c_row)
        lexicon_dict_l[key] = val
    return lexicon_dict_l

lexicon_path = "resources/lexicon.txt"
lexicon_dict_l = load_lexicon(lexicon_path)

with open("resources/phoneme_dict.json", "r") as f:
    phone2id = json.load(f)
id2phone = {value:key for key, value in phone2id.items()}

In [8]:
def load_force_alignment_result(alignment_path, phones_path):
    alignment = pd.read_csv(alignment_path, sep="\s", names=["file_utt","utt","start","duration","id"], engine='python')
    
    id2phoneme = pd.read_csv(phones_path, sep="\s", names=["phonemes", "id"], engine='python')
    id2phoneme = id2phoneme.set_index(keys="id").to_dict()["phonemes"]
    alignment["phonemes"] = alignment.id.apply(lambda x: id2phoneme[int(x)])
    
    return alignment

alignment_path = "/data/codes/prep_gopt/egs/gop_speechocean762/s5/exp/ali_infer/merged_alignment.txt"
phones_path = "/data/codes/prep_gopt/egs/gop_speechocean762/s5/data/lang_nosp/phones.txt"

alignment = load_force_alignment_result(alignment_path=alignment_path, phones_path=phones_path)
alignment.phonemes = alignment.phonemes.apply(lambda x: x.split(" "))
alignment = alignment.explode(column="phonemes").reset_index()
alignment["pure_phonemes"] = alignment.phonemes.apply(lambda x: x.split("_")[0])
alignment = alignment[alignment.phonemes != "SIL"].reset_index()
alignment.head()

Unnamed: 0,level_0,index,file_utt,utt,start,duration,id,phonemes,pure_phonemes
0,1,1,237-126133-0013,1,0.42,0.14,102,AY1_S,AY1
1,2,2,237-126133-0013,1,0.56,0.1,231,N_B,N
2,3,3,237-126133-0013,1,0.66,0.31,248,OW1_E,OW1
3,5,5,237-126133-0013,1,1.21,0.11,175,G_B,G
4,6,6,237-126133-0013,1,1.32,0.16,37,AE1_I,AE1


In [9]:
text = "JUMPER"
words = text.split(' ')
path = "egs/gop_speechocean762/s5/data/local/text-phone"

text_phone_df = pd.read_csv(path, sep="\t", names=["word_id", "phonemes"], dtype={"word_id":str})

text_phone_df.word_id = text_phone_df.word_id.apply(lambda x: x.split(".")[-1])
text_phone_df.phonemes = text_phone_df.phonemes.apply(lambda x: x.split(" "))
text_phone_df = text_phone_df.explode(column="phonemes").reset_index()
text_phone_df["pure_phonemes"] = text_phone_df.phonemes.apply(lambda x: x.split("_")[0])
text_phone_df.head()

Unnamed: 0,index,word_id,phonemes,pure_phonemes
0,0,0,AY1_S,AY1
1,1,1,N_B,N
2,1,1,OW1_E,OW1
3,2,2,G_B,G
4,2,2,AE1_I,AE1


In [10]:
joined_metadata = pd.concat([text_phone_df, alignment[["start", "duration", "phonemes", "pure_phonemes"]]], axis=1)
joined_metadata.head()

Unnamed: 0,index,word_id,phonemes,pure_phonemes,start,duration,phonemes.1,pure_phonemes.1
0,0,0,AY1_S,AY1,0.42,0.14,AY1_S,AY1
1,1,1,N_B,N,0.56,0.1,N_B,N
2,1,1,OW1_E,OW1,0.66,0.31,OW1_E,OW1
3,2,2,G_B,G,1.21,0.11,G_B,G
4,2,2,AE1_I,AE1,1.32,0.16,AE1_I,AE1


In [11]:
def validate_phoneme(phoneme_1, phoneme_2):
    assert phoneme_1 == phoneme_2

def validate_pure_phoneme(pure_phoneme_1, pure_phoneme_2):
    assert pure_phoneme_1 == pure_phoneme_2

columns = joined_metadata.columns
assert columns[2] == columns[6]
assert columns[3] == columns[7]

joined_metadata.apply(lambda x: validate_phoneme(x[2], x[6]), axis=1)
joined_metadata.apply(lambda x: validate_pure_phoneme(x[3], x[7]), axis=1)

0     None
1     None
2     None
3     None
4     None
5     None
6     None
7     None
8     None
9     None
10    None
11    None
12    None
13    None
14    None
15    None
16    None
17    None
18    None
19    None
20    None
21    None
22    None
23    None
24    None
25    None
26    None
27    None
28    None
29    None
30    None
31    None
32    None
33    None
34    None
35    None
36    None
37    None
38    None
39    None
40    None
41    None
42    None
43    None
44    None
45    None
46    None
dtype: object

In [12]:
word_ids = torch.tensor([int(i) for i in text_phone_df["word_id"].to_list()])
assert input_phone_ids.shape[0] == 1
input_phone_ids = input_phone_ids.view(-1)
phone_ids = input_phone_ids[input_phone_ids != -1]

In [13]:
import torch.nn.functional as F

one_hot = F.one_hot(word_ids, num_classes=word_ids.max().item()+1).float()
one_hot = one_hot / one_hot.sum(0, keepdim=True)
word_scores = torch.matmul(one_hot.transpose(0, 1), phn_scores.cpu())

In [14]:
from dataclasses import dataclass
from typing import List

@dataclass
class Phoneme:
    arpabet: str
    start_time: float
    end_time: float
    score: float
    
    def __init__(self, arpabet, start_time, end_time, score):
        self.arpabet = arpabet
        self.start_time = start_time
        self.end_time = end_time
        self.score = score
    
    def to_dict(self):
        return {
            "arpabet": self.arpabet,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "score": self.score
        }
        
@dataclass
class Word:
    text: str
    arpabet: str
    start_time: float
    phonemes: List[Phoneme]
    end_time: float
    score: float
    
    def __init__(self, text, arpabet, start_time, end_time, score, phonemes):
        self.arpabet = arpabet
        self.text = text
        self.start_time = start_time
        self.end_time = end_time
        self.score = score
        self.phonemes = phonemes
    
    def append_phone(self, phone):
        self.phonemes.append(phone)
    
    def to_dict(self):
        return {
            "text": self.text,
            "arpabet": self.arpabet,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "score": self.score,
            "phonemes": [phoneme.to_dict() for phoneme in self.phonemes]

        }

@dataclass
class Sentence:
    arpabet: str
    duration: float
    score: float
    words: List[Word]
    
    def __init__(self, text, arpabet, duration, score, words):
        self.arpabet = arpabet
        self.duration = duration
        self.text = text
        self.score = score
        self.words = words
        
    def append_word(self, word):
        self.words.append(word)
    
    def append_phoneme(self, word_index, phoneme):
        self.words[word_index].append_phone(phoneme)
        self.words[word_index].end_time = phoneme.end_time 
    
    def to_dict(self):
        return {
            "text": self.text,
            "arpabet": self.arpabet,
            "duration": self.duration,
            "phonemes": [word.to_dict() for word in self.words],
            "score": self.score
        }


In [16]:
def parse_score(transcript, utt_score, word_scores, phn_scores):
    words = transcript.split(" ")
    _tmp_word = Word(
        text=words[0], 
        arpabet=None,
        score=word_scores[0].item(), 
        end_time=0, 
        start_time=0, 
        phonemes=[]
    )
    
    utterance = Sentence(
        text=transcript,
        arpabet=None, 
        score=utt_score.item(), 
        duration=0.0, 
        words=[],
    )
    curr_word_id = -1
    for index in range(len(phone_ids)):
        word_id = int(joined_metadata["word_id"][index])
        start_time = joined_metadata["start"][index]
        end_time = start_time + joined_metadata["duration"][index]
        
        _tmp_phone = Phoneme(
            arpabet=id2phone[int(phone_ids[index])],
            end_time=end_time, start_time=start_time,
            score=phn_scores[index].item()
        )

        if word_id == curr_word_id:
            utterance.append_phoneme(word_index=word_id, phoneme=_tmp_phone)
        else:
            _tmp_word = Word(
                arpabet=None, start_time=_tmp_phone.start_time, end_time=_tmp_phone.end_time, text=words[word_id],
                score=word_scores[word_id].item(), phonemes=[_tmp_phone, ]
            )
            if len(utterance.words) == word_id:
                utterance.append_word(_tmp_word)
            else:
                utterance.append_phoneme(word_index=word_id, phoneme=_tmp_phone)
    scores = {
        "version": "None",
        "utterance": utterance.to_dict()
        }
    return scores

scores = parse_score(
    transcript=text, utt_score=utt_score, word_scores=word_scores, phn_scores=phn_scores)

IndexError: list index out of range

In [None]:
# def parse_score(transcript, utt_score, word_scores, phn_scores):
#     curr_word_id = -1
    
#     _tmp_word = {
#             "text": words[0],
#             "score": word_scores[0].item(),
#             "phonemes": []
#         }

#     scores = {
#         "version": "None",
#         "utterance": {
#             "text": text,
#             "score": utt_score.item(),
#             "words": []
#         }
#     }
#     for index in range(len(phone_ids)):
#         word_id = int(phone_df["word_id"][index])

#         _tmp_phone = {
#             "text": id2phone[int(phone_ids[index])],
#             "score": phn_scores[index].item(),
#         }

#         if word_id == curr_word_id:
#             scores["utterance"]["words"][word_id]["phonemes"].append(
#                 _tmp_phone
#             )
#         else:
#             _tmp_word = {
#                 "text": words[word_id],
#                 "score": word_scores[word_id].item(),
#                 "phonemes": []
#             }

#             _tmp_word["phonemes"].append(_tmp_phone)
#             if len(scores["utterance"]["words"]) == word_id:
#                 scores["utterance"]["words"].append(_tmp_word)
#             else:
#                 scores["utterance"]["words"][word_id]["phonemes"].append(
#                     _tmp_phone
#                 )
    
#     return scores

# scores = parse_score(
#     transcript=text, utt_score=utt_score, word_scores=word_scores, phn_scores=phn_scores)

In [None]:
with open("result.json", "w", encoding="utf-8") as f:
    json_obj = json.dumps(scores, indent=4, ensure_ascii=False)
    f.write(json_obj)