In [1]:
import os
import re
import random
from argparse import Namespace

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import wandb

from tqdm import tqdm
import torchaudio
import numpy as np

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

import textgrids
import sentencepiece as spm

from fairseq.data import Dictionary
from fairseq.models.mST.w2v2_phone_transformer import W2V2Transformer
from fairseq.data.audio.multilingual_triplet_v2_phone_dataset import (
    MultilingualTripletDataConfig,
    MultilingualTripletDataset,
    MultilingualTripletDatasetCreator
)
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform, _collate_frames
from examples.speech_to_text.data_utils import load_df_from_tsv
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE, SentencepieceConfig
from examples.speech_recognition.data.data_utils import padding_mask_to_lengths

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sns.set(font="Noto Sans CJK JP")
device = 'cuda:7'

In [3]:
zh_root = '/mnt/data/siqiouyang/datasets/covost2/zh-CN'

# Produce Word Segmentation Dataset

## Load Model Checkpoint

In [4]:
args = Namespace()
task = Namespace()

In [5]:
def load_dict(vocab_filename):
    _dict_path = vocab_filename
    if not os.path.isfile(_dict_path):
        raise FileNotFoundError(f"Dict not found: {_dict_path}")
    _dict = Dictionary.load(_dict_path)
    for code in codes:
        _dict.add_symbol(MultilingualTripletDataset.LANG_TAG_TEMPLATE.format(code))
    _dict.add_symbol('<mask>')
    return _dict

In [6]:
lang_list_filename = '/mnt/data/siqiouyang/runs/mST/pretrained/mbart50.ft.n1/ML50_langs.txt'
vocab_filename = '/mnt/data/siqiouyang/runs/mST/pretrained/mbart50.ft.n1/dict.txt'
phone_vocab_filename = '/mnt/data/siqiouyang/datasets/covost2/phone_dict.txt'

In [7]:
codes = MultilingualTripletDataset.get_lang_codes(lang_list_filename)
dict = load_dict(vocab_filename)
with open(phone_vocab_filename, 'r') as r:
    phone_list = [l.strip() for l in r.readlines() if l.strip() != '']
    phone_dict = {l: idx + 1 for idx, l in enumerate(phone_list)} # leave 0 as blank
    phone_list = ['|'] + phone_list

In [8]:
task.src_dict = task.tgt_dict = dict
task.phone_dict = phone_dict

In [9]:
args.w2v2_model_path = '/mnt/data/siqiouyang/runs/mST/pretrained/xlsr2_300m.pt'
args.mbart50_dir = '/mnt/data/siqiouyang/runs/mST/pretrained/mbart50.ft.n1'

In [10]:
model = W2V2Transformer.build_model(args, task)

In [14]:
model.eval()
model.training

False

In [14]:
# model.encoder.w2v2_model.encoder.layers[15:]
# model.encoder.w2v2_model.encoder.layer_norm
# model.encoder.text_embedding
# model.encoder.transformer_encoder.embed_tokens
# model.encoder.transformer_encoder.embed_positions
# model.encoder.transformer_encoder.layernorm_embedding
# model.encoder.transformer_encoder.layers[:4]

SinusoidalPositionalEmbedding()

In [11]:
ckpt_path = '/mnt/data/siqiouyang/runs/mST/xlsr_phone_mbart_de_zh/checkpoint_best.pt'
ckpt = load_checkpoint_to_cpu(ckpt_path)

In [12]:
model.load_state_dict(ckpt["model"], strict=False)
model = model.to(device)
model.eval()

W2V2Transformer(
  (encoder): W2V2PhoneTransformerEncoder(
    (w2v2_model): Wav2Vec2Model(
      (feature_extractor): ConvFeatureExtractionModel(
        (conv_layers): ModuleList(
          (0): Sequential(
            (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
            (1): Dropout(p=0.0, inplace=False)
            (2): Sequential(
              (0): TransposeLast()
              (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (2): TransposeLast()
            )
            (3): GELU()
          )
          (1): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
            (1): Dropout(p=0.0, inplace=False)
            (2): Sequential(
              (0): TransposeLast()
              (1): Fp32LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (2): TransposeLast()
            )
            (3): GELU()
          )
          (2): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,

## Build Word Segmentation Dataset

In [13]:
class WSDataset(Dataset):
    def __init__(self, root, split):
        df = load_df_from_tsv(os.path.join(root, '{}_st_zh-CN_en.tsv'.format(split)))

        self.audio_paths = []
        self.transcripts = []
        self.segmentations = []

        for id in tqdm(df['id']):
            if os.path.exists(os.path.join(root, '16kHz/align', '{}.TextGrid'.format(id))):
                audio_path = os.path.join(root, '16kHz', '{}.wav'.format(id))
                self.audio_paths.append(audio_path)
                grids = textgrids.TextGrid('{}/16kHz/align/{}.TextGrid'.format(zh_root, id))
                duration = torchaudio.info(audio_path).num_frames / 16000
                intervals = th.tensor([(grid.xmin, grid.xmax) for grid in grids['words'] if grid.text != '']) / duration
                
                tokens = [grid.text for grid in grids['words'] if grid.text != '']

                self.segmentations.append(intervals)
                self.transcripts.append(tokens)

    def __getitem__(self, idx):
        return self.audio_paths[idx], self.transcripts[idx], self.segmentations[idx]

    def __len__(self):
        return len(self.audio_paths)

In [14]:
train_dataset = WSDataset(zh_root, 'train')
dev_dataset = WSDataset(zh_root, 'dev')

100%|██████████| 7085/7085 [00:03<00:00, 2307.32it/s]
100%|██████████| 4843/4843 [00:02<00:00, 2387.87it/s]


In [65]:
def ws_collate_fn(samples):
    audio_paths = [ap for ap, _, _ in samples]
    sources = [
        th.from_numpy(get_features_or_waveform(
            ap,
            need_waveform=True,
            sample_rate=16000,
        )).float() for ap in audio_paths
    ]
    n_frames = th.tensor([source.size(0) for source in sources], dtype=th.long).to(device)
    frames = _collate_frames(sources, True).to(device)

    transcripts = [t for _, t, _ in samples]
    segmentations = [seg.to(device) for _, _, seg in samples]
    ntokens = sum([len(t) for t in transcripts])
    
    out = {
        "net_input": {
            "src_tokens": frames,
            "src_lengths": n_frames,
        },
        "transcripts": transcripts,
        "segmentations": segmentations,
        "ntokens": ntokens
    }
    return out

In [109]:
train_dataloader = DataLoader(train_dataset, batch_size=10, collate_fn=ws_collate_fn)
dev_dataloader = DataLoader(dev_dataset, batch_size=10, collate_fn=ws_collate_fn)

## Build ASR Dataset

In [13]:
sp = spm.SentencePieceProcessor()
sp.Load('/mnt/data/siqiouyang/runs/mST/pretrained/mbart50.ft.n1/sentence.bpe.model')

True

In [16]:
zh_train_df = load_df_from_tsv(os.path.join(zh_root, 'train_st_zh-CN_en.tsv'))
zh_dev_df = load_df_from_tsv(os.path.join(zh_root, 'dev_st_zh-CN_en.tsv'))
zh_test_df = load_df_from_tsv(os.path.join(zh_root, 'test_st_zh-CN_en.tsv'))

In [17]:
for df in [zh_train_df, zh_dev_df, zh_dev_df]:
    for id, src_text in zip(tqdm(df['id']), df['src_text']):
        tokenized = ' '.join(sp.EncodeAsPieces(src_text)).replace('▁', '')
        with open(os.path.join(zh_root, '16kHz', '{}.txt'.format(id)), 'w') as w:
            w.write(tokenized)

100%|██████████| 7085/7085 [00:00<00:00, 11741.51it/s]
100%|██████████| 4843/4843 [00:00<00:00, 11577.16it/s]
100%|██████████| 4843/4843 [00:00<00:00, 11283.10it/s]


In [14]:
def match(pieces, sp_ids, sp_pieces):
    j = 0
    ids = []
    for piece in pieces:
        while piece not in sp_pieces[j]:
            j += 1
            if j == len(sp_pieces):
                return -1
        ids.append(sp_ids[j])
        j += 1
    return ids

In [15]:
class ASRDataset(Dataset):
    def __init__(self, root, split):
        df = load_df_from_tsv(os.path.join(root, '{}_st_zh-CN_en.tsv'.format(split)))

        self.audio_paths = []
        self.segmentations = []
        self.tokenss = []

        for id, transcript in zip(tqdm(df['id']), df['src_text']):
            if os.path.exists(os.path.join(root, '16kHz/align_sp', '{}.TextGrid'.format(id))):
                audio_path = os.path.join(root, '16kHz', '{}.wav'.format(id))

                grids = textgrids.TextGrid('{}/16kHz/align_sp/{}.TextGrid'.format(zh_root, id))
                duration = torchaudio.info(audio_path).num_frames / 16000
                intervals = th.tensor([(grid.xmin, grid.xmax) for grid in grids['words'] if grid.text != '']) / duration
                
                pieces = [grid.text for grid in grids['words'] if grid.text != '']
                tokens = match(pieces, sp.Encode(transcript), sp.EncodeAsPieces(transcript))

                if tokens != -1:
                    self.segmentations.append(intervals)
                    self.audio_paths.append(audio_path)
                    self.tokenss.append(tokens)

                    assert len(tokens) == intervals.size(0)
                else:
                    print(transcript, pieces)
    
    def __getitem__(self, idx):
        return self.audio_paths[idx], self.segmentations[idx], self.tokenss[idx]

    def __len__(self):
        return len(self.audio_paths)

In [16]:
train_dataset = ASRDataset(zh_root, 'train')
dev_dataset = ASRDataset(zh_root, 'dev')

 14%|█▍        | 1008/7085 [00:00<00:01, 3266.66it/s]

代表作品有《明星金钟》、《脸赞时代》。 ['代表', '作品', '有', '《', '明星', '金', '钟', '》', '《', '脸', '赞', '时代', '》']
萧家怡，居港澳门作家，发表文章于《立场新闻》、《评台》、《香港独立媒体》。 ['萧', '家', '怡', '居', '港', '澳门', '作家', '发表', '文章', '于', '《', '立场', '新闻', '》', '《', '评', '台', '》', '《', '香港', '独立', '媒体', '》']


 29%|██▉       | 2068/7085 [00:00<00:01, 3435.09it/s]

同时他还刊有《海内奇观》、《图绘宗彝》等书籍。 ['同时', '他还', '刊', '有', '《', '海', '内', '奇', '观', '》', '《', '图', '绘', '宗', '彝', '》', '等', '书籍']


 82%|████████▏ | 5806/7085 [00:01<00:00, 3493.26it/s]

曾参演《绅士的品格》、《太阳的后裔》和《阳光先生》担任配角 ['曾', '参', '演', '《', '绅', '士', '的', '品', '格', '》', '《', '太阳', '的', '后', '裔', '》', '和', '《', '阳光', '先生', '》', '担任', '配', '角']
累官兵部右侍郎，蓟辽总督，削职有《汉书评》、《小史论》、《抚吴疏草》等。 ['累', '官兵', '部', '右', '侍', '郎', '蓟', '辽', '总', '督', '削', '职', '有', '《', '汉', '书', '评', '》', '《', '小', '史', '论', '》', '《', '抚', '吴', '疏', '草', '》', '等']


100%|██████████| 7085/7085 [00:01<00:00, 3566.67it/s]
 22%|██▏       | 1071/4843 [00:00<00:01, 3439.54it/s]

有《司空奏议》、《宣慈录》。 ['有', '《', '司', '空', '奏', '议', '》', '《', '宣', '慈', '录', '》']
大和和纪，日本漫画家，代表作有《窈窕淑女》、《源氏物语》等。 ['大', '和', '和', '纪', '日本', '漫画', '家', '代表', '作', '有', '《', '窈', '窕', '淑', '女', '》', '《', '源', '氏', '物', '语', '》', '等']
彭淮栋，台湾翻译家，译有《乡关何处》、《魔山》、《浮士德博士》、《美的历史》。 ['彭', '淮', '栋', '台湾', '翻译', '家', '译', '有', '《', '乡', '关', '何', '处', '》', '《', '魔', '山', '》', '《', '浮', '士', '德', '博士', '》', '《', '美', '的历史', '》']


 52%|█████▏    | 2527/4843 [00:00<00:00, 3567.60it/s]

Т̌ т̌是一个西里尔字母，由Т т与抑扬符组成。 ['т', '̌', 'т', '̌', '是一个', '西', '里', '尔', '字母', '由', 'т', 'т', '与', '抑', '扬', '符', '组成']
请问一下有不错的ㄟＰＰ吗 ['请问', '一下', '有', '不错的', 'ㄟ', 'pp', '吗']


 75%|███████▍  | 3629/4843 [00:01<00:00, 3604.23it/s]

И̃ и̃，是一个西里尔字母。 ['и', '̃', 'и', '̃', '是一个', '西', '里', '尔', '字母']


100%|██████████| 4843/4843 [00:01<00:00, 3432.83it/s]

代表作品为《暗花》、《两个只能活一个》、《非常突然》。 ['代表', '作品', '为', '《', '暗', '花', '》', '《', '两个', '只能', '活', '一个', '》', '《', '非常', '突然', '》']





In [18]:
token_mask = th.zeros(model.encoder.text_embedding.weight.size(0), dtype=th.bool)
for dataset in [train_dataset, dev_dataset]:
    for tokens in dataset.tokenss:
        for tok in tokens:
            token_mask[tok] = True
token_mask[-1] = True

In [20]:
def asr_collate_fn(samples):
    audio_paths = [ap for ap, _, _ in samples]
    sources = [
        th.from_numpy(get_features_or_waveform(
            ap,
            need_waveform=True,
            sample_rate=16000,
        )).float() for ap in audio_paths
    ]
    n_frames = th.tensor([source.size(0) for source in sources], dtype=th.long).to(device)
    frames = _collate_frames(sources, True).to(device)

    segmentations = [seg.to(device) for _, seg, _  in samples]

    tokens = [th.tensor(t, dtype=th.long).to(device) for _, _, t in samples]
    ntokens = sum([t.size(0) for t in tokens])
    
    out = {
        "net_input": {
            "src_tokens": frames,
            "src_lengths": n_frames,
        },
        "segmentations": segmentations,
        "tokens": tokens,
        "ntokens": ntokens
    }
    return out

In [21]:
train_dataloader = DataLoader(train_dataset, batch_size=10, collate_fn=asr_collate_fn)
dev_dataloader = DataLoader(dev_dataset, batch_size=10, collate_fn=asr_collate_fn)

## Losses

In [22]:
def compute_ws_loss(speech_encoder_out, inputs):
    speech_emb = speech_encoder_out['x']
    speech_emb = mlp(speech_emb)
    bsz = speech_emb.size(0)
    
    padding_mask = speech_encoder_out['padding_mask']
    lens = padding_mask_to_lengths(padding_mask)
    
    transcripts = inputs["transcripts"]
    segmentations = inputs["segmentations"]
    ntokens = inputs["ntokens"]

    loss = 0.

    for idx in range(bsz):
        length = lens[idx]
        seg = segmentations[idx]
        scaled_seg = (seg * length).long()

        sim_matrix = F.cosine_similarity(
            speech_emb[idx].unsqueeze(0),
            speech_emb[idx].unsqueeze(1),
            dim=-1
        ) / temp

        p_matrix = F.softmax(sim_matrix, dim=-1)

        for l, r in scaled_seg:
            loss = loss + -p_matrix[l : r, l : r].sum(dim=1).log().sum()
    
    return loss, ntokens

In [34]:
def compute_asr_loss(speech_encoder_out, inputs):
    speech_embs = speech_encoder_out['x']
    speech_embs = mlp(speech_embs)

    logits = th.matmul(speech_embs, model.encoder.text_embedding.weight.T.detach())

    logits[:, :, ~token_mask] = -1e4

    bsz = logits.size(0)

    padding_mask = speech_encoder_out['padding_mask']
    lens = padding_mask_to_lengths(padding_mask)

    segmentations = inputs['segmentations']

    cat_logits = []
    cat_labels = []

    blank_label = logits.size(-1) - 1

    batch_tokens = inputs['tokens']

    for idx in range(bsz):
        length = lens[idx]
        cat_logits.append(logits[idx, :length])

        labels = th.zeros(length, dtype=th.long).to(device) + blank_label

        tokens = batch_tokens[idx]

        seg = segmentations[idx]
        scaled_seg = (seg * length).long()
        for token, (l, r) in zip(tokens, scaled_seg):
            labels[l : r] = token
        
        cat_labels.append(labels)

    flat_logits = th.cat(cat_logits, dim=0)
    flat_labels = th.cat(cat_labels, dim=0)

    assert token_mask[flat_labels].all()

    loss = F.cross_entropy(flat_logits, flat_labels, reduction='sum')

    ntokens = flat_logits.size(0)

    return loss, ntokens

## Model Architecture

In [25]:
class MLP(nn.Module):
    def __init__(self, n_input, n_hidden, n_output, n_layer):
        super(MLP, self).__init__()

        self.layers = []
        for idx in range(n_layer):
            self.layers.append(nn.Linear(n_hidden if idx > 0 else n_input, n_hidden))
        self.layers = nn.ModuleList(self.layers)

        self.final_proj = nn.Linear(n_hidden, n_output)
        self.relu = nn.ReLU()

    def forward(self, x):
        for layer in self.layers:
            x = self.relu(layer(x))
        x = self.final_proj(x)
        return x

In [38]:
mlp = MLP(1024, 2048, 1024, 2).to(device)
optimizer = th.optim.Adam(mlp.parameters(), lr=1e-3)
temp = 0.05

## Training and Evaluation

In [39]:
wandb.init(project="ST", entity="owaski", name='test-1e-3-dict')

0,1
train_loss,██▇▆▅▃▃▂▃▂▁▂

0,1
train_loss,8.66283


In [28]:
loss_fn = compute_asr_loss

In [40]:
def eval(dataloader):
    mlp.eval()
    iterator = tqdm(dataloader)
    sum_loss = 0.
    sum_ntokens = 0
    with th.no_grad():
        for inputs in iterator:
            speech_encoder_out = model.encoder.forward_speech(**inputs["net_input"])
            loss, ntokens = loss_fn(speech_encoder_out, inputs)                
            sum_loss += loss
            sum_ntokens += ntokens
    print('eval loss {:.2f}'.format(sum_loss / sum_ntokens))
    wandb.log({'eval_loss': sum_loss / sum_ntokens})

In [41]:
def run_epoch(dataloader):
    mlp.train()
    iterator = tqdm(dataloader)

    sum_loss = 0.
    sum_ntokens = 0
    for inputs in iterator:
        with th.no_grad():
            speech_encoder_out = model.encoder.forward_speech(**inputs["net_input"])

        optimizer.zero_grad()
        
        loss, ntokens = loss_fn(speech_encoder_out, inputs)

        wandb.log({'train_loss': loss / ntokens})
            
        sum_loss += loss.item()
        sum_ntokens += ntokens

        loss = loss / ntokens

        loss.backward()
        optimizer.step()
        
        iterator.set_description('train_loss: {:.2f}'.format(loss.item()))
    print('train loss {:.2f}'.format(sum_loss / sum_ntokens))

In [42]:
n_epoch = 100
for _ in range(n_epoch):
    run_epoch(train_dataloader)
    eval(dev_dataloader)

train_loss: 6.15: 100%|██████████| 708/708 [06:17<00:00,  1.88it/s]


train loss 6.38


100%|██████████| 482/482 [01:59<00:00,  4.04it/s]


eval loss 5.39


train_loss: 5.26: 100%|██████████| 708/708 [06:19<00:00,  1.86it/s]


train loss 4.87


100%|██████████| 482/482 [01:58<00:00,  4.07it/s]


eval loss 4.80


train_loss: 4.46:  68%|██████▊   | 480/708 [04:23<02:05,  1.82it/s]


KeyboardInterrupt: 

In [32]:
th.cuda.empty_cache()