In [1]:
import os
from argparse import Namespace

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from fairseq.data import Dictionary
from fairseq.models.mST.w2v2_phone_transformer import W2V2Transformer
from fairseq.data.audio.multilingual_triplet_v2_phone_dataset import (
    MultilingualTripletDataConfig,
    MultilingualTripletDataset,
    MultilingualTripletDatasetCreator
)
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
from examples.speech_to_text.data_utils import load_df_from_tsv
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE, SentencepieceConfig

# Load checkpoints

In [2]:
args = Namespace()
task = Namespace()

In [3]:
def load_dict(vocab_filename):
    _dict_path = vocab_filename
    if not os.path.isfile(_dict_path):
        raise FileNotFoundError(f"Dict not found: {_dict_path}")
    _dict = Dictionary.load(_dict_path)
    for code in codes:
        _dict.add_symbol(MultilingualTripletDataset.LANG_TAG_TEMPLATE.format(code))
    _dict.add_symbol('<mask>')
    return _dict

In [4]:
lang_list_filename = '/mnt/raid5/siqi/checkpoints/pretrained/mbart50.ft.n1/ML50_langs.txt'
vocab_filename = '/mnt/raid5/siqi/checkpoints/pretrained/mbart50.ft.n1/dict.txt'
phone_vocab_filename = '/mnt/raid5/siqi/datasets/covost2/phone_dict.txt'

In [5]:
codes = MultilingualTripletDataset.get_lang_codes(lang_list_filename)
dict = load_dict(vocab_filename)
with open(phone_vocab_filename, 'r') as r:
    phone_list = [l.strip() for l in r.readlines() if l.strip() != '']
    phone_dict = {l: idx + 1 for idx, l in enumerate(phone_list)} # leave 0 as blank
    phone_list = ['|'] + phone_list

In [6]:
task.src_dict = task.tgt_dict = dict
task.phone_dict = phone_dict

In [7]:
args.w2v2_model_path = '/mnt/raid5/siqi/checkpoints/pretrained/xlsr2_300m.pt'
args.mbart50_dir = '/mnt/raid5/siqi/checkpoints/pretrained/mbart50.ft.n1'

In [9]:
model = W2V2Transformer.build_model(args, task)

In [10]:
ckpt_path = '/mnt/raid5/siqi/checkpoints/xlsr_phone_mbart_de_zh_ctc_phone_text/checkpoint_best.pt'
ckpt = load_checkpoint_to_cpu(ckpt_path)

In [11]:
model.load_state_dict(ckpt["model"])

<All keys matched successfully>

# Inference

In [12]:
with open(lang_list_filename, 'r') as r:
    lang_codes = [line.strip() for line in r.readlines() if line.strip() != ""]
lang2langcode = {code[:2] : code for code in lang_codes}

In [13]:
spm_config = SentencepieceConfig(sentencepiece_model='/mnt/raid0/siqi/checkpoints/pretrained/mbart50.ft.n1/sentence.bpe.model')
spm = SentencepieceBPE(spm_config)

In [14]:
data_root = '/mnt/raid5/siqi/datasets/covost2'

In [17]:
df = load_df_from_tsv('/mnt/raid5/siqi/datasets/covost2/zh-CN/train_st_zh-CN_en_phone.tsv')

In [18]:
idx = 0
df.iloc[idx]

id                                common_voice_zh-CN_18536372
audio                         common_voice_zh-CN_18536372.wav
n_frames                                                88704
src_text                                对于更高阶的导数，我们可以继续同样的过程。
tgt_text    For derivatives of higher order, we can use th...
speaker     c69453a9ae8cdd8ce47e767209e28e254719861de178ac...
src_lang                                                zh-CN
tgt_lang                                                   en
phones      t uə ɤɐ̞ ɪ ɥ uə k ɤɐ̞ ŋ k æ o tɕ ɪ ɤɐ̞ t ɪ p æ...
Name: 0, dtype: object

In [20]:
source = get_features_or_waveform(
    os.path.join(data_root, 'zh-CN', '16kHz', df.iloc[idx]['audio']),
    need_waveform=True,
    sample_rate=16000,
)
source = th.from_numpy(source).float()

lang_tag = "<lang:{}>".format(lang2langcode['zh'])
src_lang_tag_idx = dict.index(lang_tag)

In [21]:
src_tokens = source.unsqueeze(0)
src_lengths = th.LongTensor([source.size(0)])
src_lang_tag_indices = th.LongTensor([src_lang_tag_idx]).unsqueeze(-1)

In [22]:
encoder_out = model.forward_encoder(src_tokens, src_lengths, src_lang_tag_indices=src_lang_tag_indices)

In [23]:
print(*[phone_list[tok] for tok in encoder_out.phone_logp.argmax(dim=-1).squeeze().unique_consecutive()])
print(df.iloc[idx]['phones'])
print(df.iloc[idx]['src_text'])

| t uə ɤɐ̞ ɪ | ɥ uə | k | ɤɐ̞ | ŋ | k | æ o | tɕ ɪ ɤɐ̞ | t ɪ | t | æ | o | ʂ | uə | o | m | ɤɐ̞ | n | kʰ | ɤɐ̞ | i | tɕ | ɪ | ɕ | y | tʰ | o | ŋ | j | æ | ŋ | t ɪ | k | uə o | tʂʰ | ɤɐ̞ |
t uə ɤɐ̞ ɪ ɥ uə k ɤɐ̞ ŋ k æ o tɕ ɪ ɤɐ̞ t ɪ p æ o ʂ uə w uə o m ɤɐ̞ n kʰ ɤɐ̞ i tɕ ɪ ɕ y tʂʰ o ŋ l j æ ŋ t ɪ k uə o tʂʰ ɤɐ̞ ŋ
对于更高阶的导数，我们可以继续同样的过程。


In [24]:
print(*[dict[tok] for tok in encoder_out.text_logp.argmax(dim=-1).squeeze().unique_consecutive()])
print(spm.encode(df.iloc[idx]['src_text']))
print(df.iloc[idx]['src_text'])

<lang:zh_CN> ▁ 对 <mask> 于 <mask> 高 <mask> 的 <mask> 可以 <mask> 集 <mask> 的过程 <mask> 。 </s>
▁对于 更高 阶 的 导 数 , 我们可以 继续 同样 的过程 。
对于更高阶的导数，我们可以继续同样的过程。
