In [1]:
import os
import re
import random
from argparse import Namespace
from string import punctuation
from collections import defaultdict

import IPython

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import wandb

from tqdm import tqdm
import torchaudio
import numpy as np

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import sentencepiece as spm

import textgrids
import faiss

from sentence_transformers import SentenceTransformer

from fairseq.data import Dictionary
from fairseq.models.mST.w2v2_phone_transformer import W2V2Transformer
from fairseq.data.audio.multilingual_triplet_v2_phone_dataset import (
    MultilingualTripletDataConfig,
    MultilingualTripletDataset,
    MultilingualTripletDatasetCreator
)
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
from examples.speech_to_text.data_utils import load_df_from_tsv, save_df_to_tsv
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE, SentencepieceConfig

from fairseq.models.mST.w2v2_phone_transformer import W2V2Transformer
from fairseq.data.audio.multilingual_triplet_v2_phone_dataset import (
    MultilingualTripletDataConfig,
    MultilingualTripletDataset,
    MultilingualTripletDatasetCreator
)

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
device = 'cpu'
model = SentenceTransformer('sentence-transformers/LaBSE').to(device)

# Process CV-9.0 Zh

In [2]:
cv_zh_root = '/mnt/data/siqiouyang/datasets/cv-corpus-9.0-2022-04-27/zh-CN'

In [3]:
cv_zh_df = load_df_from_tsv(os.path.join(cv_zh_root, 'validated.tsv'))

In [4]:
cv_zh_df

Unnamed: 0,client_id,path,sentence,up_votes,down_votes,age,gender,accents,locale,segment
0,02ec74191c6ccc7dcf6ecaa217268263c477273b4de93f...,common_voice_zh-CN_22069600.mp3,宋朝末年年间定居粉岭围。,2,0,,,,zh-CN,
1,0431cf00d4491b99a93700d7aa0b1948a057b2c162a620...,common_voice_zh-CN_22006851.mp3,渐渐行动不便,2,0,,,,zh-CN,
2,04742f27bccab99619bd4ec3f256b36c639afd058c8664...,common_voice_zh-CN_22115132.mp3,二十一年去世。,2,0,,,,zh-CN,
3,0648def3862cbb968eec23fad967f50e35fc8e0eea67b4...,common_voice_zh-CN_22120171.mp3,他们自称恰哈拉。,2,0,,,,zh-CN,
4,0697ece1f99a08477906d0f3b4e74e1d6ffca76c20a7db...,common_voice_zh-CN_18646658.mp3,局部干涩的例子包括有口干、眼睛干燥、及阴道干燥。,2,1,,,,zh-CN,
...,...,...,...,...,...,...,...,...,...,...
48511,25bc975d06200b7b1c9135db090561cb0d9b28d172e51c...,common_voice_zh-CN_19717327.mp3,这被共和派和社会主义者称为一次巨大胜利。,2,0,thirties,female,出生地：31 上海市,zh-CN,
48512,25bc975d06200b7b1c9135db090561cb0d9b28d172e51c...,common_voice_zh-CN_19717330.mp3,汉默史密斯是伦敦的一大波兰人聚居地。,2,1,thirties,female,出生地：31 上海市,zh-CN,
48513,25bc975d06200b7b1c9135db090561cb0d9b28d172e51c...,common_voice_zh-CN_19717333.mp3,被处理的晶片试样放置于真空室中的样品架上。,2,0,thirties,female,出生地：31 上海市,zh-CN,
48514,25bc975d06200b7b1c9135db090561cb0d9b28d172e51c...,common_voice_zh-CN_19717335.mp3,曾连获三届国家新闻出版总署颁发的国家期刊奖。,2,1,thirties,female,出生地：31 上海市,zh-CN,


In [5]:
# for path in tqdm(cv_zh_df['path']):
#     audio_path = os.path.join(cv_zh_root, 'clips/{}'.format(path))
#     waveform, sample_rate = torchaudio.load(audio_path)
#     resampled_waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
#     torchaudio.save(os.path.join(cv_zh_root, '16kHz/{}.wav'.format(path[:-4])), resampled_waveform, sample_rate=16000)

In [6]:
sentences = cv_zh_df['sentence'].tolist()
sentences

['宋朝末年年间定居粉岭围。',
 '渐渐行动不便',
 '二十一年去世。',
 '他们自称恰哈拉。',
 '局部干涩的例子包括有口干、眼睛干燥、及阴道干燥。',
 '嘉靖三十八年，登进士第三甲第二名。',
 '这一名称一直沿用至今。',
 '阿列河畔贝赛。',
 '同时乔凡尼还得到包税合同和许多明矾矿的经营权。',
 '为了惩罚西扎城和塞尔柱的结盟，盟军在抵达后将外城烧毁。',
 '河内盛产黄色无鱼鳞的鳍射鱼。',
 '毛柄新木姜子为樟科新木姜子属下的一个变种。',
 '他主要演出泰米尔语电影。',
 '粗体字表示为主演。',
 '福崎町是位于日本兵库县中部的行政区划。',
 '下行月台设有厕所。',
 '耶尔河畔圣伊莱尔人口变化图示',
 '光绪八年再中举人。',
 '台湾北部地区家庭多以在农历年前时段包润饼则是在清明期间。',
 '蔡声白。',
 '该区舰队主要负责为公海舰队的战列分舰队提供屏护。',
 '雷诺在回归的第一年比赛中以第四名的成绩完成了比赛。',
 '这样都可以啊',
 '此原理也广泛应用于家庭之中用于生产软水。',
 '本片的导演是赵秀贤和梁铉锡。',
 '大戟亚科是大戟科旗下三个亚科的其中之一。',
 '奥特拉德诺耶农村居民点是俄罗斯联邦沃罗涅日州新乌斯曼区所属的一个农村居民点。',
 '吉内斯塔。',
 '其名称来源于当时图书馆的创始人。',
 '长期信用银行法。',
 '主要用户是法国陆军以及阿拉伯联合大公国。',
 '如今该符号已变成一个表示迷因的网络词汇。',
 '戴七宝冠作通身光。',
 '最少数百名购不到门票的影迷要求加场。',
 '她有两个哥哥。',
 '七',
 '尽力和武市半平太一起创立土佐勤王党。',
 '他也是蚬壳电器创办人翁佑的女婿。',
 '厚实的墙壁让旧机内部冬暖夏凉。',
 '槽果扁莎为莎草科扁莎属下的一个种。',
 '正模标本是多个可能的生物型中的一个。',
 '乌大经清朝军事人物。',
 '卡尔雷尔。',
 '滨江县先后隶属于吉林省西北路道和滨江道。',
 '大使馆全年都为各类政治人物和公众举办了许多活动，包括庆祝以色列独立日独立日。',
 '苏光彩之子。',
 '感冒茶起源于中国岭南一带。',
 '拉沙尔克。',
 '伍斯特座堂是英格兰伍斯特郡的一座教堂。',
 '电子浩室是浩室音乐的一个类型。',
 '韦

In [7]:
# with open('/home/siqiouyang/work/projects/WMSeg/zh.txt', 'w') as w:
#     w.write('\n'.join(sentences).replace(' ', ''))

In [8]:
punc = re.compile(",|\?|・|\u3002|\uff1f|\uff01|\uff0c|\u3001|\uff1b|\uff1a|\u201c|\u201d|\u2018|\u2019|\uff08|\uff09|\u300a|\u300b|\u3008|\u3009|\u3010|\u3011|\u300e|\u300f|\u300c|\u300d|\ufe43|\ufe44|\u3014|\u3015|\u2026|\u2014|\uff5e|\ufe4f|\uffe5")

In [9]:
tokenized_sentences = []
with open('/home/siqiouyang/work/projects/WMSeg/cv9_zh.txt.tok', 'r') as r:
    for line in r.readlines():
        line = line.strip()
        if line != '':
            line = punc.sub(' ', line)
            tokens = [tok for tok in line.split(' ') if tok != '']
            tokenized_sentences.append(tokens)

In [10]:
# for path, tokens in zip(cv_zh_df['path'], tokenized_sentences):
#     with open(os.path.join(cv_zh_root, '16kHz', '{}.txt'.format(path[:-4])), 'w') as w:
#         w.write(' '.join(tokens))

In [11]:
final_tokenized_sentences = []
final_segmentations = []
final_audio_paths = []
for i, path in enumerate(tqdm(cv_zh_df['path'])):
    id = path[:-4]
    grid_path = os.path.join(cv_zh_root, '16kHz/align_wmseg/{}.TextGrid'.format(id))
    if os.path.exists(grid_path):
        grid = textgrids.TextGrid(grid_path)
        filtered_grid = [tok for tok in grid['words'] if tok.text != '']

        if len(filtered_grid) != len(tokenized_sentences[i]):
            print(filtered_grid, tokenized_sentences[i], sep='\n', end='\n' + '-'*10 + '\n')

        interval = np.array([(word.xmin, word.xmax) for word in filtered_grid])
        audio_path = os.path.join(cv_zh_root, '16kHz/{}.wav'.format(id))
        info = torchaudio.info(audio_path)
        duration = info.num_frames / info.sample_rate
        interval = interval / duration
        final_segmentations.append(interval)

        final_audio_paths.append(audio_path)
        final_tokenized_sentences.append(tokenized_sentences[i])


100%|██████████| 48516/48516 [00:28<00:00, 1722.70it/s]


In [12]:
class NgramDictionary:
    def __init__(self, tokenized_sentences, segmentations, audio_paths, gram=3, sep=''):
        self.tokenized_sentences = tokenized_sentences

        self.id2tok = []
        self.tok2id = {}
        self.embs = []
        self.waveforms = {}

        assert len(tokenized_sentences) == len(segmentations) and len(segmentations) == len(audio_paths)

        for tok_sent, seg, path in zip(tqdm(tokenized_sentences), segmentations, audio_paths):
            waveform = torchaudio.load(path)[0][0]
            len_wf = waveform.size(0)
            assert len(tok_sent) == len(seg)
            for l in range(1, gram + 1):
                for idx in range(len(tok_sent) - l + 1):
                    tok = sep.join(tok_sent[idx : idx + l])
                    xmin, xmax = seg[idx][0], seg[idx + l - 1][1]
                    
                    tok_wf = waveform[int(xmin * len_wf) : int(xmax * len_wf)]
                    if tok not in self.tok2id:
                        self.tok2id[tok] = len(self.id2tok)
                        self.id2tok.append(tok)
                        self.waveforms[self.tok2id[tok]] = [tok_wf]
                    else:
                        self.waveforms[self.tok2id[tok]].append(tok_wf)

        batch_size = 10000
        for idx in tqdm(range(0, len(self.id2tok), batch_size)):
            self.embs.append(model.encode(self.id2tok[idx : idx + batch_size]))
        
        self.embs = np.concatenate(self.embs, axis=0)
        self.index = faiss.IndexFlatIP(self.embs.shape[-1])
        self.index.add(self.embs)

        self.id2tok = np.array(self.id2tok)

    def search(self, x, k):
        return self.index.search(x, k)

In [None]:
cv_zh_dict = NgramDictionary(final_tokenized_sentences, final_segmentations, final_audio_paths, gram=3)

In [13]:
cv_zh_dict = th.load('dict/zh-CN_dict.pt')

# Build Chinese Dictionary

In [None]:
root = '/mnt/data/siqiouyang/datasets/covost2/zh-CN'
df = load_df_from_tsv(os.path.join(root, 'train_st_zh-CN_en.tsv'))

In [None]:
audio_paths = []
for id in df['id']:
    audio_paths.append(os.path.join(root, '16kHz/{}.wav'.format(id)))

In [None]:
sentences = df['src_text'].tolist()
sentences

In [None]:
# with open('/home/siqiouyang/work/projects/WMSeg/zh.txt', 'w') as w:
#     w.write('\n'.join(sentences).replace(' ', ''))

In [None]:
tokenized_sentences = []
with open('/home/siqiouyang/work/projects/WMSeg/zh.txt.tok', 'r') as r:
    for line in r.readlines():
        line = line.strip()
        if line != '':
            line = punc.sub(' ', line)
            tokens = [tok for tok in line.split(' ') if tok != '']
            tokenized_sentences.append(tokens)

In [None]:
# for id, tokens in zip(df['id'], tokenized_sentences):
#     with open(os.path.join(root, '16kHz', '{}.txt'.format(id)), 'w') as w:
#         w.write(' '.join(tokens))

In [None]:
# mfa to force align

In [None]:
filtered_grids = []
for i, id in enumerate(df['id']):
    grid = textgrids.TextGrid(os.path.join(root, '16kHz/align_wmseg/{}.TextGrid'.format(id)))
    filtered_grid = [tok for tok in grid['words'] if tok.text != '']

    if len(filtered_grid) != len(tokenized_sentences[i]):
        print(filtered_grid, tokenized_sentences[i], sep='\n', end='\n' + '-'*10 + '\n')

    filtered_grids.append(filtered_grid)

In [None]:
segmentations = []
for grid, path in zip(filtered_grids, audio_paths):
    interval = np.array([(word.xmin, word.xmax) for word in grid])
    info = torchaudio.info(path)
    duration = info.num_frames / info.sample_rate
    interval = interval / duration
    segmentations.append(interval)

In [None]:
class Dictionary:
    def __init__(self, tokenized_sentences, segmentations, audio_paths):
        self.tokenized_sentences = tokenized_sentences

        self.id2tok = []
        self.tok2id = {}
        self.embs = []
        self.waveforms = {}

        assert len(tokenized_sentences) == len(segmentations) and len(segmentations) == len(audio_paths)

        for tok_sent, seg, path in zip(tqdm(tokenized_sentences), segmentations, audio_paths):
            waveform = torchaudio.load(path)[0][0]
            len_wf = waveform.size(0)
            assert len(tok_sent) == len(seg), print(tok_sent, seg)
            for tok, (xmin, xmax) in zip(tok_sent, seg):
                tok_wf = waveform[int(xmin * len_wf) : int(xmax * len_wf)]
                if tok not in self.tok2id:
                    self.tok2id[tok] = len(self.id2tok)
                    self.id2tok.append(tok)
                    self.waveforms[self.tok2id[tok]] = [tok_wf]
                else:
                    self.waveforms[self.tok2id[tok]].append(tok_wf)

        batch_size = 10000
        for idx in tqdm(range(0, len(self.id2tok), batch_size)):
            self.embs.append(model.encode(self.id2tok[idx : idx + batch_size]))
        
        self.embs = np.concatenate(self.embs, axis=0)
        self.index = faiss.IndexFlatIP(self.embs.shape[-1])
        self.index.add(self.embs)

        self.id2tok = np.array(self.id2tok)

    def search(self, x, k):
        return self.index.search(x, k)

In [None]:
# zh_dict = Dictionary(tokenized_sentences, segmentations, audio_paths)
zh_dict = NgramDictionary(tokenized_sentences, segmentations, audio_paths, gram=3)

# Build German Dictionary

In [None]:
root = '/mnt/data/siqiouyang/datasets/covost2/de'
df = load_df_from_tsv(os.path.join(root, 'train_st_de_en.tsv'))

In [None]:
sentences = df['src_text'].tolist()
sentences

In [None]:
de_punctuation = punctuation + '„“”‚’«»ʿ‹›‘'

In [None]:
tokenized_sentences = []
for sent in sentences:
    sent = ''.join(c for c in sent if c not in punctuation)
    tokens = [tok for tok in sent.split(' ') if tok != '']
    tokenized_sentences.append(tokens)

In [None]:
# mfa to force align

In [None]:
filtered_grids = []
audio_paths = []
filtered_tokenized_sentences = []
for i, id in enumerate(tqdm(df['id'])):
    grid_path = os.path.join(root, '16kHz/align/{}.TextGrid'.format(id))
    if os.path.exists(grid_path):
        grid = textgrids.TextGrid(grid_path)
        filtered_grid = [tok for tok in grid['words'] if tok.text != '']

        if len(filtered_grid) != len(tokenized_sentences[i]):
            # print(filtered_grid, tokenized_sentences[i], sep='\n', end='\n' + '-'*10 + '\n')
            continue
        
        filtered_grids.append(filtered_grid)
        filtered_tokenized_sentences.append(tokenized_sentences[i])
        audio_paths.append(os.path.join(root, '16kHz/{}.wav'.format(id)))

In [None]:
segmentations = []
for grid, path in zip(filtered_grids, audio_paths):
    interval = np.array([(word.xmin, word.xmax) for word in grid])
    info = torchaudio.info(path)
    duration = info.num_frames / info.sample_rate
    interval = interval / duration
    segmentations.append(interval)

In [None]:
# de_dict = Dictionary(filtered_tokenized_sentences, segmentations, audio_paths)
de_dict = NgramDictionary(filtered_tokenized_sentences, segmentations, audio_paths, gram=3, sep=' ')

# Build German Features

In [14]:
de_root = '/mnt/data/siqiouyang/datasets/covost2/de'
de_df = load_df_from_tsv(os.path.join(de_root, 'train_st_de_en.tsv'))

In [15]:
de_sentences = de_df['src_text'].tolist()

In [16]:
de_tokenized_sentences = []
for sent in de_sentences:
    sent = ''.join(c for c in sent if c not in punctuation)
    tokens = [tok for tok in sent.split(' ') if tok != '']
    de_tokenized_sentences.append(tokens)

In [17]:
# for id, tokens in zip(de_df['id'], tokenized_sentences):
#     with open(os.path.join(de_root, '16kHz', '{}.txt'.format(id)), 'w') as w:
#         w.write(' '.join(tokens))

In [18]:
de_gram = 1

In [19]:
all_de_tokens = set()
de_freq = defaultdict(int)
for tokens in de_tokenized_sentences:
    for idx in range(len(tokens)):
        for g in range(min(de_gram, len(tokens) - idx)):
            token = ' '.join(tokens[idx: idx + g + 1])
            all_de_tokens.add(token)
            de_freq[token] += 1

In [20]:
n_total_token = sum(de_freq.values())
n_total_token

1110670

In [None]:
th.save([all_de_tokens, de_freq], 'dict/de_token.pt')

In [24]:
de_embs = model.encode(list(all_de_tokens), device=device)

KeyboardInterrupt: 

In [None]:
de_embs.shape, 

In [None]:
D, I = zh_dict.search(de_embs, k=1)

In [None]:
mask = (D > 0.9).flatten()

In [None]:
all_de_tokens = list(all_de_tokens)
n_total_found = 0
for idx in range(len(all_de_tokens)):
    if mask[idx]:
        n_total_found += de_freq[all_de_tokens[idx]]

In [None]:
n_total_found / n_total_token

In [None]:
tokens = np.array(all_de_tokens)

In [None]:
match = {}
for w_de, w_zh in zip(tokens[mask], zh_dict.id2tok[I.flatten()[mask]]):
    match[w_de] = w_zh
    print(w_de, w_zh)

In [None]:
# need to place a language tag

# Build Chinese Features

In [None]:
zh_root = '/mnt/data/siqiouyang/datasets/covost2/zh-CN'
zh_df = load_df_from_tsv(os.path.join(zh_root, 'train_st_zh-CN_en.tsv'))

In [None]:
zh_sentences = zh_df['src_text'].tolist()

In [None]:
zh_tokenized_sentences = []
with open('/home/siqiouyang/work/projects/WMSeg/zh.txt.tok', 'r') as r:
    for line in r.readlines():
        line = line.strip()
        if line != '':
            line = punc.sub(' ', line)
            tokens = [tok for tok in line.split(' ') if tok != '']
            zh_tokenized_sentences.append(tokens)

In [None]:
all_zh_tokens = set()
zh_freq = defaultdict(int)
for tokens in zh_tokenized_sentences:
    for token in tokens:
        all_zh_tokens.add(token)
        zh_freq[token] += 1

In [None]:
n_total_token = sum(zh_freq.values())
n_total_token

In [None]:
zh_embs = model.encode(list(all_zh_tokens))

In [None]:
zh_embs.shape, 

In [None]:
D, I = de_dict.search(zh_embs, k=1)

In [None]:
mask = (D > 0.9).flatten()

In [None]:
all_zh_tokens = list(all_zh_tokens)
n_total_found = 0
for idx in range(len(all_zh_tokens)):
    if mask[idx]:
        n_total_found += zh_freq[all_zh_tokens[idx]]

In [None]:
n_total_found / n_total_token

In [None]:
tokens = np.array(all_zh_tokens)

In [None]:
match = {}
for w_zh, w_de in zip(tokens[mask], de_dict.id2tok[I.flatten()[mask]]):
    match[w_zh] = w_de
    print(w_zh, w_de)

# Mix Chinese Audio into German

In [None]:
de_punctuation = punctuation + '„“”‚’«»ʿ‹›‘'

In [None]:
version = 'de-zh-3g'

In [None]:
de_zh_root = '/mnt/data/siqiouyang/datasets/covost2/{}'.format(version)

In [None]:
silence = th.zeros(1600)

In [None]:
n_pass = 0
n_notpass = 0
for id, transcript in zip(tqdm(de_df['id']), de_df['src_text']):
    grid_path = os.path.join(de_zh_root, '16kHz/align/{}.TextGrid'.format(id))
    if os.path.exists(grid_path):
        grid = textgrids.TextGrid(grid_path)
        tokens = [tok for tok in grid['words'] if tok.text != '']

        transcript = ''.join(c for c in transcript if c not in de_punctuation)
        ori_tokens = [tok for tok in transcript.split(' ') if tok != '']

        string = ' '.join([tok.text.lower() for tok in tokens])
        ori_string = ' '.join([tok.lower() for tok in ori_tokens])

        if string == ori_string:
            audio_path = os.path.join(de_zh_root, '16kHz/{}.wav'.format(id))
            waveform = torchaudio.load(audio_path)[0][0]
            n_frame = waveform.size(0)
            frame_rate = 16000
            duration = n_frame / frame_rate

            new_waveform = []
            new_tokens = []
            
            intervals = (np.array([(tok.xmin, tok.xmax) for tok in tokens]) * frame_rate).astype(int)

            average_volume = 0.
            word_duration = 0.
            for xmin, xmax in intervals:
                word_duration += xmax - xmin
                average_volume += waveform[xmin : xmax].abs().sum()
            average_volume /= word_duration

            mask = []
            orig_intervals = []
            mixed_intervals = []
            last_end = idx = 0
            prefix_length = 0.
            while idx < len(intervals):
                xmin, xmax = intervals[idx]
                new_waveform.append(waveform[last_end : xmin])
                prefix_length += xmin - last_end
                replace = False
                for g in range(min(de_gram, len(intervals) - idx), 0, -1):
                    token = ' '.join(ori_tokens[idx : idx + g])
                    if token in match:
                        replace = True
                        token_zh = match[token]
                        selectable_waveforms = zh_dict.waveforms[zh_dict.tok2id[token_zh]]

                        selected_waveform = random.choice(selectable_waveforms)
                        selected_volume = selected_waveform.abs().mean()
                        selected_waveform *= average_volume / selected_volume

                        sil_selected_waveform = th.cat([silence, selected_waveform, silence], dim=0)

                        new_waveform.append(sil_selected_waveform)
                        new_tokens.append(token_zh)
                        last_end = intervals[idx + g - 1][1]
                        idx += g
                        mask.append(True)
                        orig_intervals.append((xmin, last_end))
                        mixed_intervals.append((
                            prefix_length + silence.size(0), 
                            prefix_length + silence.size(0) + selected_waveform.size(0)
                        ))
                        prefix_length += silence.size(0) * 2 + selected_waveform.size(0)
                        break
                if not replace:
                    new_waveform.append(waveform[xmin : xmax])
                    new_tokens.append(ori_tokens[idx])
                    last_end = xmax
                    idx += 1
                    mask.append(False)
                    orig_intervals.append((xmin, xmax))
                    mixed_intervals.append((
                        prefix_length,
                        prefix_length + xmax - xmin
                    ))
                    prefix_length += xmax - xmin

            new_waveform = th.cat(new_waveform, dim=0).unsqueeze(0)
            new_audio_path = os.path.join(de_zh_root, '16kHz/{}-mixed.wav'.format(id))
            torchaudio.save(new_audio_path, new_waveform, sample_rate=16000)

            mask = th.tensor(mask, dtype=bool)
            orig_intervals = th.tensor(orig_intervals, dtype=float) / n_frame
            mixed_intervals = th.tensor(mixed_intervals, dtype=float) / new_waveform.size(1)
            match_path = os.path.join(de_zh_root, '16kHz/{}.pt'.format(id))
            th.save([mask, orig_intervals, mixed_intervals], match_path)

In [None]:
n_total_matched = 0
for id, transcript in zip(tqdm(de_df['id']), de_df['src_text']):
    grid_path = os.path.join(de_zh_root, '16kHz/align/{}.TextGrid'.format(id))
    if os.path.exists(grid_path):
        grid = textgrids.TextGrid(grid_path)
        tokens = [tok for tok in grid['words'] if tok.text != '']

        transcript = ''.join(c for c in transcript if c not in de_punctuation)
        ori_tokens = [tok for tok in transcript.split(' ') if tok != '']

        string = ' '.join([tok.text.lower() for tok in tokens])
        ori_string = ' '.join([tok.lower() for tok in ori_tokens])

        if string == ori_string:
            for token in ori_tokens:
                replace_flag = np.random.rand() > 0.5
                if token in match and replace_flag:
                    n_total_matched += 1

In [None]:
n_total_matched / n_total_token

In [None]:
de_zh_df1 = load_df_from_tsv(os.path.join(de_zh_root, 'train_st_de-zh-3g_en.tsv'))
de_zh_df2 = load_df_from_tsv(os.path.join(de_zh_root, 'train_st_de-zh-3g_en_phone.tsv'))

In [None]:
retain = []
for id in de_zh_df1['id']:
    if os.path.exists(os.path.join(de_zh_root, '16kHz/{}-mixed.wav'.format(id))):
        retain.append(True)
    else:
        retain.append(False)

In [None]:
de_zh_df1_filtered = de_zh_df1.loc[retain]
de_zh_df2_filtered = de_zh_df2.loc[retain]

In [None]:
de_zh_df1_filtered['src_lang'] = [version] * len(de_zh_df1_filtered)
de_zh_df2_filtered['src_lang'] = [version] * len(de_zh_df2_filtered)

In [None]:
save_df_to_tsv(de_zh_df1_filtered, os.path.join(de_zh_root, 'train_st_{}_en.tsv'.format(version)))
save_df_to_tsv(de_zh_df2_filtered, os.path.join(de_zh_root, 'train_st_{}_en_phone.tsv'.format(version)))

# Mix German Audio into Chinese

In [None]:
zh_de_root = '/mnt/data/siqiouyang/datasets/covost2/zh-de-3g'

In [None]:
n_pass = 0
n_notpass = 0
for id, ori_tokens in zip(tqdm(zh_df['id']), zh_tokenized_sentences):
    grid_path = os.path.join(zh_de_root, '16kHz/align_wmseg/{}.TextGrid'.format(id))
    if os.path.exists(grid_path):
        grid = textgrids.TextGrid(grid_path)
        tokens = [tok for tok in grid['words'] if tok.text != '']

        assert len(tokens) == len(ori_tokens)

        audio_path = os.path.join(zh_de_root, '16kHz/{}.wav'.format(id))
        waveform = torchaudio.load(audio_path)[0][0]
        n_frame = waveform.size()
        frame_rate = 16000

        new_waveform = []
        
        intervals = (np.array([(tok.xmin, tok.xmax) for tok in tokens]) * frame_rate).astype(int)
        last_end = 0
        for token, (xmin, xmax) in zip(ori_tokens, intervals):
            new_waveform.append(waveform[last_end : xmin])

            if token in match:
                token_de = match[token]
                selectable_waveforms = de_dict.waveforms[de_dict.tok2id[token_de]]
                new_waveform.append(random.choice(selectable_waveforms))
            else:
                new_waveform.append(waveform[xmin : xmax])

            last_end = xmax

        new_waveform = th.cat(new_waveform, dim=0).unsqueeze(0)
        torchaudio.save(audio_path, new_waveform, sample_rate=16000)

# Adjust Number of Frame

In [None]:
version = 'de-zh-3g-hf'
adj_root = '/mnt/data/siqiouyang/datasets/covost2/' + version

In [None]:
adj_df1 = load_df_from_tsv(os.path.join(adj_root, 'train_st_{}_en.tsv'.format(version)))
adj_df2 = load_df_from_tsv(os.path.join(adj_root, 'train_st_{}_en_phone.tsv'.format(version)))

In [None]:
for adj_df in [adj_df1, adj_df2]:
    n_frames = []
    for id in tqdm(adj_df['id']):
        info = torchaudio.info(os.path.join(adj_root, '16kHz/{}.wav'.format(id)))
        n_frames.append(info.num_frames)
    adj_df['n_frames'] = n_frames

In [None]:
save_df_to_tsv(adj_df1, os.path.join(adj_root, 'train_st_{}_en.tsv'.format(version)))
save_df_to_tsv(adj_df2, os.path.join(adj_root, 'train_st_{}_en_phone.tsv'.format(version)))

# Generate Examples

In [None]:
de_punctuation = punctuation + '„“”‚’«»ʿ‹›‘'
de_zh_root = '/mnt/data/siqiouyang/datasets/covost2/de'

In [None]:
indices = list(range(len(de_df)))
random.shuffle(indices)

In [None]:
n_max = 100

In [None]:
silence = th.zeros(1600)

In [None]:
n_replaced = 0
n_total = 0

In [None]:
de_origs = []
de_mixeds = []
de_audios = []
de_orig_audios = []
for idx in tqdm(indices[:n_max]):
    id = de_df['id'][idx]
    transcript = de_df['src_text'][idx]
    grid_path = os.path.join(de_zh_root, '16kHz/align/{}.TextGrid'.format(id))
    if os.path.exists(grid_path):
        grid = textgrids.TextGrid(grid_path)
        tokens = [tok for tok in grid['words'] if tok.text != '']

        transcript = ''.join(c for c in transcript if c not in de_punctuation)
        ori_tokens = [tok for tok in transcript.split(' ') if tok != '']

        string = ' '.join([tok.text.lower() for tok in tokens])
        ori_string = ' '.join([tok.lower() for tok in ori_tokens])

        if string == ori_string:
            audio_path = os.path.join(de_zh_root, '16kHz/{}.wav'.format(id))
            waveform = torchaudio.load(audio_path)[0][0]
            n_frame = waveform.size()
            frame_rate = 16000

            new_waveform = []
            new_tokens = []
            
            intervals = (np.array([(tok.xmin, tok.xmax) for tok in tokens]) * frame_rate).astype(int)

            average_volume = 0.
            total_duration = 0.
            for xmin, xmax in intervals:
                total_duration += xmax - xmin
                average_volume += waveform[xmin : xmax].abs().sum()
            average_volume /= total_duration

            last_end = 0

            idx = 0
            while idx < len(intervals):
                xmin, xmax = intervals[idx]
                new_waveform.append(waveform[last_end : xmin])
                replace = False
                for g in range(min(de_gram, len(intervals) - idx), 0, -1):
                    token = ' '.join(ori_tokens[idx : idx + g])
                    if token in match:
                        replace = True
                        token_zh = match[token]
                        selectable_waveforms = zh_dict.waveforms[zh_dict.tok2id[token_zh]]

                        selected_waveform = random.choice(selectable_waveforms)
                        selected_volume = selected_waveform.abs().mean()
                        selected_waveform *= average_volume / selected_volume

                        selected_waveform = th.cat([silence, selected_waveform, silence], dim=0)

                        new_waveform.append(selected_waveform)
                        new_tokens.append(token_zh)
                        last_end = intervals[idx + g - 1][1]
                        idx += g
                        n_replaced += g
                        n_total += g
                        break
                if not replace:
                    new_waveform.append(waveform[xmin : xmax])
                    new_tokens.append(ori_tokens[idx])
                    last_end = xmax
                    idx += 1
                    n_total += 1

            # for token, (xmin, xmax) in zip(ori_tokens, intervals):
            #     new_waveform.append(waveform[last_end : xmin])

            #     replace_flag = True # np.random.rand() > 0.5

            #     if token in match and replace_flag:
            #         token_zh = match[token]
            #         selectable_waveforms = zh_dict.waveforms[zh_dict.tok2id[token_zh]]
            #         new_waveform.append(random.choice(selectable_waveforms))
            #         new_tokens.append(match[token])
            #     else:
            #         new_waveform.append(waveform[xmin : xmax])
            #         new_tokens.append(token)

            #     last_end = xmax

            new_waveform = th.cat(new_waveform, dim=0).unsqueeze(0)
            new_audio_path = 'de-zh-sample/{}.wav'.format(len(de_origs))
            orig_audio_path = 'de-zh-sample/{}-orig.wav'.format(len(de_origs))
            torchaudio.save(new_audio_path, new_waveform, sample_rate=16000)
            torchaudio.save(orig_audio_path, waveform.unsqueeze(0), sample_rate=16000)

            de_origs.append(' '.join(ori_tokens))
            de_mixeds.append(' '.join(new_tokens))
            de_audios.append(new_audio_path)
            de_orig_audios.append(orig_audio_path)

In [None]:
n_replaced / n_total

In [None]:
string = '<table>\n'
string += '\t<tr>\n\t\t<th>de-orig</th>\n\t\t<th>audio-orig</th>\n\t\t<th>de-mixed</th>\n\t\t<th>audio-mixed</th>\n\t</tr>\n'
for de_orig, de_mixed, de_audio, de_orig_audio in zip(de_origs, de_mixeds, de_audios, de_orig_audios):
    string += '\t<tr>\n\t\t<td>{}</td>\n\t\t<td>{}</td>\n\t\t<td>{}</td>\n\t\t<td>{}</td>\n\t</tr>\n'.format(
        de_orig, 
        '<audio controls><source src="{}" type="audio/wav"></audio>'.format(de_orig_audio),
        de_mixed,
        '<audio controls><source src="{}" type="audio/wav"></audio>'.format(de_audio)
    )
string += '</table>'

In [None]:
with open('de-zh-sample.html', 'w', encoding='utf-8') as w:
    w.write(string)

# Evaluate Trained Models

In [None]:
de_origs = []
de_orig_intervals = []
de_mixeds = []
de_mixed_intervals = []
de_audios = []
de_orig_audios = []
for idx in tqdm(indices[:n_max]):
    id = de_df['id'][idx]
    transcript = de_df['src_text'][idx]
    grid_path = os.path.join(de_zh_root, '16kHz/align/{}.TextGrid'.format(id))
    if os.path.exists(grid_path):
        grid = textgrids.TextGrid(grid_path)
        tokens = [tok for tok in grid['words'] if tok.text != '']

        transcript = ''.join(c for c in transcript if c not in de_punctuation)
        ori_tokens = [tok for tok in transcript.split(' ') if tok != '']

        string = ' '.join([tok.text.lower() for tok in tokens])
        ori_string = ' '.join([tok.lower() for tok in ori_tokens])

        if string == ori_string:
            audio_path = os.path.join(de_zh_root, '16kHz/{}.wav'.format(id))
            waveform = torchaudio.load(audio_path)[0][0]
            n_frame = waveform.size()
            frame_rate = 16000

            new_waveform = []
            new_tokens = []

            orig_intervals = []
            new_intervals = []
            
            intervals = (np.array([(tok.xmin, tok.xmax) for tok in tokens]) * frame_rate).astype(int)
            last_end = 0
            sum_lengths = 0
            for token, (xmin, xmax) in zip(ori_tokens, intervals):
                new_waveform.append(waveform[last_end : xmin])
                sum_lengths += xmin - last_end

                replace_flag = True # np.random.rand() > 0.5

                orig_intervals.append((xmin, xmax))

                if token in match and replace_flag:
                    token_zh = match[token]
                    selectable_waveforms = zh_dict.waveforms[zh_dict.tok2id[token_zh]]
                    selected_waveform = random.choice(selectable_waveforms)
                    new_waveform.append(selected_waveform)
                    new_tokens.append(match[token])
                    new_intervals.append((sum_lengths, sum_lengths + selected_waveform.size(0)))
                    sum_lengths += selected_waveform.size(0)
                else:
                    new_waveform.append(waveform[xmin : xmax])
                    new_tokens.append(token)
                    new_intervals.append((sum_lengths, sum_lengths + xmax - xmin))
                    sum_lengths += xmax - xmin

                last_end = xmax

            new_waveform = th.cat(new_waveform, dim=0).unsqueeze(0)
            new_audio_path = 'de-zh-sample/{}.wav'.format(len(de_origs))
            orig_audio_path = 'de-zh-sample/{}-orig.wav'.format(len(de_origs))
            torchaudio.save(new_audio_path, new_waveform, sample_rate=16000)
            torchaudio.save(orig_audio_path, waveform.unsqueeze(0), sample_rate=16000)

            de_origs.append(ori_tokens)
            de_orig_intervals.append(orig_intervals)
            de_mixeds.append(new_tokens)
            de_mixed_intervals.append(new_intervals)
            de_audios.append(new_audio_path)
            de_orig_audios.append(orig_audio_path)

## Load Model

In [None]:
args = Namespace()
task = Namespace()

In [None]:
def load_dict(vocab_filename):
    _dict_path = vocab_filename
    if not os.path.isfile(_dict_path):
        raise FileNotFoundError(f"Dict not found: {_dict_path}")
    _dict = Dictionary.load(_dict_path)
    for code in codes:
        _dict.add_symbol(MultilingualTripletDataset.LANG_TAG_TEMPLATE.format(code))
    _dict.add_symbol('<mask>')
    return _dict

In [None]:
lang_list_filename = '/mnt/data/siqiouyang/runs/mST/pretrained/mbart50.ft.n1/ML50_langs.txt'
vocab_filename = '/mnt/data/siqiouyang/runs/mST/pretrained/mbart50.ft.n1/dict.txt'
phone_vocab_filename = '/mnt/data/siqiouyang/datasets/covost2/phone_dict.txt'

In [None]:
codes = MultilingualTripletDataset.get_lang_codes(lang_list_filename)
dict = load_dict(vocab_filename)
with open(phone_vocab_filename, 'r') as r:
    phone_list = [l.strip() for l in r.readlines() if l.strip() != '']
    phone_dict = {l: idx + 1 for idx, l in enumerate(phone_list)} # leave 0 as blank
    phone_list = ['|'] + phone_list

In [None]:
task.src_dict = task.tgt_dict = dict
task.phone_dict = phone_dict

In [None]:
args.w2v2_model_path = '/mnt/data/siqiouyang/runs/mST/pretrained/xlsr2_300m.pt'
args.mbart50_dir = '/mnt/data/siqiouyang/runs/mST/pretrained/mbart50.ft.n1'

In [None]:
model = W2V2Transformer.build_model(args, task)

In [None]:
ckpt_path = '/mnt/data/siqiouyang/runs/mST/xlsr_phone_mbart_zh_de-zh/checkpoint_13_10000.pt'
ckpt = load_checkpoint_to_cpu(ckpt_path)

In [None]:
model.load_state_dict(ckpt["model"], strict=False)
model = model.to(device)
model.eval()

## Compute Features

In [None]:
dict.symbols[-55:]

In [None]:
def compute(path, lang_tag):
    source = get_features_or_waveform(
        path,
        need_waveform=True,
        sample_rate=16000,
    )
    source = th.from_numpy(source).float()
    src_tokens = source.unsqueeze(0).to(device)
    src_lengths = th.LongTensor([source.size(0)]).to(device)
    src_lang_tag_idx = dict.index(lang_tag)
    src_lang_tag_indices = th.LongTensor([src_lang_tag_idx]).unsqueeze(-1).to(device)
    encoder_out = model.encoder(src_tokens, src_lengths, src_lang_tag_indices)
    return encoder_out

In [None]:
path = 'de-zh-sample/1-orig.wav'
info = torchaudio.info(path)
orig_encoder_out = compute(path, '<lang:de_DE>')
orig_length, _, d = orig_encoder_out.encoder_out.size()
orig_interval = ((np.array(de_orig_intervals[1]) / info.num_frames) * orig_length).astype(int)

In [None]:
path = 'de-zh-sample/1.wav'
info = torchaudio.info(path)
mixed_encoder_out = compute(path, '<lang:de_DE>')
mixed_length, _, d = mixed_encoder_out.encoder_out.size()
mixed_interval = ((np.array(de_orig_intervals[1]) / info.num_frames) * mixed_length).astype(int)

In [None]:
orig_interval

In [None]:
mixed_interval

In [None]:
A = orig_encoder_out.encoder_out[47:56].mean(dim=0).squeeze()

In [None]:
B = mixed_encoder_out.encoder_out[56:65].mean(dim=0).squeeze()

In [None]:
C = mixed_encoder_out.encoder_out[47:55].mean(dim=0).squeeze()

In [None]:
(A * B).sum() / A.norm() / B.norm(), ((A - B) ** 2).sum()

In [None]:
(A * C).sum() / A.norm() / C.norm(), ((A - C) ** 2).sum()