In [2]:
import os
import re
import random
from argparse import Namespace
from string import punctuation
from collections import defaultdict

import IPython

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import wandb

from tqdm import tqdm
import torchaudio
import numpy as np

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import sentencepiece as spm

import textgrids
import faiss

from sentence_transformers import SentenceTransformer
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

from fairseq.data import Dictionary
from fairseq.models.mST.w2v2_phone_transformer import W2V2Transformer
from fairseq.data.audio.multilingual_triplet_v2_phone_dataset import (
    MultilingualTripletDataConfig,
    MultilingualTripletDataset,
    MultilingualTripletDatasetCreator
)
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
from examples.speech_to_text.data_utils import load_df_from_tsv, save_df_to_tsv
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE, SentencepieceConfig

from fairseq.models.mST.w2v2_phone_transformer import W2V2Transformer
from fairseq.data.audio.multilingual_triplet_v2_phone_dataset import (
    MultilingualTripletDataConfig,
    MultilingualTripletDataset,
    MultilingualTripletDatasetCreator
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = 'cuda:7'
model = SentenceTransformer('sentence-transformers/LaBSE').to(device)

# Build Code-Switch for CV 9.0 Language

In [4]:
root = '/mnt/data/siqiouyang/datasets/cv-corpus-9.0-2022-04-27/'

## Build Dictionary

In [5]:
dict_lang = 'de'
dict_lang_root = os.path.join(root, dict_lang)

In [6]:
dict_lang_df = load_df_from_tsv(os.path.join(dict_lang_root, 'validated.tsv'))

In [6]:
# Optionally resample the audio to 16kHz
os.makedirs(os.path.join(dict_lang_root, '16kHz'))
for path in tqdm(dict_lang_df['path']):
    audio_path = os.path.join(dict_lang_root, 'clips/{}'.format(path))
    waveform, sample_rate = torchaudio.load(audio_path)
    resampled_waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
    torchaudio.save(os.path.join(dict_lang_root, '16kHz/{}.wav'.format(path[:-4])), resampled_waveform, sample_rate=16000)

FileExistsError: [Errno 17] File exists: '/mnt/data/siqiouyang/datasets/cv-corpus-9.0-2022-04-27/zh-CN/16kHz'

In [7]:
sentences = dict_lang_df['sentence'].tolist()
sentences

['Zieht euch bitte draußen die Schuhe aus.',
 'Es gibt auch mehrere Campingplätze.',
 'Es kommt zum Showdown in Gstaad.',
 'Ihre Fotostrecken erschienen in Modemagazinen wie der Vogue, Harper’s Bazaar und Marie Claire.',
 'Aber weißt du, wer den Stein wirklich ins Rollen gebracht hat?',
 'Felipe hat eine auch für Monarchen ungewöhnlich lange Titelliste.',
 'Was solls, ich bin bereit.',
 'Sein Atelier diente nebenher zur Tarnung unerlaubter politischer Aktivitäten.',
 'Das Internet besteht aus vielen Computern, die miteinander verbunden sind.',
 'Jörg van Essen ist stets über die Landesliste Nordrhein-Westfalen in den Deutschen Bundestag eingezogen.',
 'Der Uranus ist der siebente Planet in unserem Sonnensystem.',
 'Sie gewannen das Match mit sechs Löchern Vorsprung.',
 'Die Wagen erhielten ein einheitliches Erscheinungsbild in weiß mit rotem Fensterband.',
 'Seinen Vornamen erhielt er in Gedenken an seinen früh verstorbenen Onkel.',
 'Mit den Senators nahm er an drei All-Star-Spielen t

In [8]:
tokenized_sentences = []

# specifically for zh-CN
# pattern = re.compile(",|\?|・|\u3002|\uff1f|\uff01|\uff0c|\u3001|\uff1b|\uff1a|\u201c|\u201d|\u2018|\u2019|\uff08|\uff09|\u300a|\u300b|\u3008|\u3009|\u3010|\u3011|\u300e|\u300f|\u300c|\u300d|\ufe43|\ufe44|\u3014|\u3015|\u2026|\u2014|\uff5e|\ufe4f|\uffe5")
# with open('/home/siqiouyang/work/projects/WMSeg/cv9_zh.txt.tok', 'r') as r:
#     for line in r.readlines():
#         line = line.strip()
#         if line != '':
#             line = pattern.sub(' ', line)
#             tokens = [tok for tok in line.split(' ') if tok != '']
#             tokenized_sentences.append(tokens)

# for other languages
for sent in sentences:
    sent = ''.join(c for c in sent if c not in punctuation)
    tokens = [tok for tok in sent.split(' ') if tok != '']
    tokenized_sentences.append(tokens)

In [9]:
final_tokenized_sentences = []
final_segmentations = []
final_audio_paths = []
for i, path in enumerate(tqdm(dict_lang_df['path'])):
    id = path[:-4]
    # grid_path = os.path.join(dict_lang_root, '16kHz/align_wmseg/{}.TextGrid'.format(id))
    grid_path = os.path.join(dict_lang_root, '16kHz/align/{}.TextGrid'.format(id))
    if os.path.exists(grid_path):
        grid = textgrids.TextGrid(grid_path)
        filtered_grid = [tok for tok in grid['words'] if tok.text != '']

        if len(filtered_grid) != len(tokenized_sentences[i]):
            # print(filtered_grid, tokenized_sentences[i], sep='\n', end='\n' + '-'*10 + '\n')
            continue

        interval = np.array([(word.xmin, word.xmax) for word in filtered_grid])
        audio_path = os.path.join(dict_lang_root, '16kHz/{}.wav'.format(id))
        info = torchaudio.info(audio_path)
        duration = info.num_frames / info.sample_rate
        interval = interval / duration
        final_segmentations.append(interval)

        final_audio_paths.append(audio_path)
        final_tokenized_sentences.append(tokenized_sentences[i])


100%|██████████| 765790/765790 [02:44<00:00, 4644.89it/s]


In [13]:
class NgramDictionary:
    def __init__(self, tokenized_sentences, segmentations, audio_paths, gram=3, sep=''):
        self.tokenized_sentences = tokenized_sentences

        self.id2tok = []
        self.tok2id = {}
        self.embs = []
        self.waveforms = {}

        assert len(tokenized_sentences) == len(segmentations) and len(segmentations) == len(audio_paths)

        for tok_sent, seg, path in zip(tqdm(tokenized_sentences), segmentations, audio_paths):
            waveform = torchaudio.load(path)[0][0]
            len_wf = waveform.size(0)
            assert len(tok_sent) == len(seg)
            for l in range(1, gram + 1):
                for idx in range(len(tok_sent) - l + 1):
                    tok = sep.join(tok_sent[idx : idx + l])
                    xmin, xmax = seg[idx][0], seg[idx + l - 1][1]
                    
                    tok_wf = waveform[int(xmin * len_wf) : int(xmax * len_wf)]
                    if tok not in self.tok2id:
                        self.tok2id[tok] = len(self.id2tok)
                        self.id2tok.append(tok)
                        self.waveforms[self.tok2id[tok]] = [tok_wf]
                    else:
                        self.waveforms[self.tok2id[tok]].append(tok_wf)

        batch_size = 10000
        for idx in tqdm(range(0, len(self.id2tok), batch_size)):
            self.embs.append(model.encode(self.id2tok[idx : idx + batch_size], device=device))
        
        self.embs = np.concatenate(self.embs, axis=0)
        self.index = faiss.IndexFlatIP(self.embs.shape[-1])
        self.index.add(self.embs)

        self.id2tok = np.array(self.id2tok)

    def search(self, x, k):
        return self.index.search(x, k)

In [16]:
dict = NgramDictionary(final_tokenized_sentences, final_segmentations, final_audio_paths, gram=3, sep=' ')

100%|██████████| 761690/761690 [13:43<00:00, 925.07it/s]  
100%|██████████| 473/473 [26:04<00:00,  3.31s/it]  


In [15]:
th.save(dict, 'dict/{}_dict.pt'.format(dict_lang))

OverflowError: serializing a string larger than 4 GiB requires pickle protocol 4 or higher

In [10]:
dict = th.load('dict/{}_dict.pt'.format(dict_lang))

## Build Features

In [21]:
target_lang = 'zh-CN'

In [22]:
target_lang_punctuation = punctuation + '„“”‚’«»ʿ‹›‘'

In [23]:
target_lang_root = os.path.join(root, target_lang)
target_lang_df = load_df_from_tsv(os.path.join(target_lang_root, 'validated.tsv'))

In [25]:
os.path.join(target_lang_root, 'validated.tsv')

'/mnt/data/siqiouyang/datasets/cv-corpus-9.0-2022-04-27/zh-CN/validated.tsv'

In [20]:
# Optionally resample the audio to 16kHz
# os.makedirs(os.path.join(target_lang_root, '16kHz'))
# for path in tqdm(target_lang_df['path']):
#     audio_path = os.path.join(target_lang_root, 'clips/{}'.format(path))
#     waveform, sample_rate = torchaudio.load(audio_path)
#     resampled_waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
#     torchaudio.save(os.path.join(target_lang_root, '16kHz/{}.wav'.format(path[:-4])), resampled_waveform, sample_rate=16000)

  0%|          | 796/765790 [00:18<4:52:30, 43.59it/s]


KeyboardInterrupt: 

In [26]:
target_lang_sentences = target_lang_df['sentence'].tolist()

In [27]:
target_tokenized_sentences = []

# specifically for zh-CN
pattern = re.compile(",|\?|・|\u3002|\uff1f|\uff01|\uff0c|\u3001|\uff1b|\uff1a|\u201c|\u201d|\u2018|\u2019|\uff08|\uff09|\u300a|\u300b|\u3008|\u3009|\u3010|\u3011|\u300e|\u300f|\u300c|\u300d|\ufe43|\ufe44|\u3014|\u3015|\u2026|\u2014|\uff5e|\ufe4f|\uffe5")
with open('/home/siqiouyang/work/projects/WMSeg/cv9_zh.txt.tok', 'r') as r:
    for line in r.readlines():
        line = line.strip()
        if line != '':
            line = pattern.sub(' ', line)
            tokens = [tok for tok in line.split(' ') if tok != '']
            target_tokenized_sentences.append(tokens)

# for sent in target_lang_sentences:
#     sent = ''.join(c for c in sent if c not in target_lang_punctuation)
#     tokens = [tok for tok in sent.split(' ') if tok != '']
#     target_tokenized_sentences.append(tokens)

In [47]:
for path, tokens in zip(target_lang_df['path'], target_tokenized_sentences):
    id = path[:-4]
    with open(os.path.join(target_lang_root, '16kHz', '{}.txt'.format(id)), 'w') as w:
        w.write(' '.join(tokens))

In [30]:
target_gram = 1

In [31]:
all_target_tokens = set()
target_freq = defaultdict(int)
for tokens in target_tokenized_sentences:
    for idx in range(len(tokens)):
        for g in range(min(target_gram, len(tokens) - idx)):
            token = ' '.join(tokens[idx: idx + g + 1])
            all_target_tokens.add(token)
            target_freq[token] += 1

In [32]:
n_target_token = sum(target_freq.values())
n_target_token

377439

In [33]:
target_embs = model.encode(list(all_target_tokens), device=device)

In [34]:
target_embs.shape, 

((60168, 768),)

In [35]:
k = 5
D, I = dict.search(target_embs, k=k)

In [36]:
mask = D > 0.9

In [37]:
all_target_tokens = list(all_target_tokens)
n_target_found = 0
for idx in range(len(all_target_tokens)):
    if mask[idx][0]:
        n_target_found += target_freq[all_target_tokens[idx]]

In [38]:
n_target_found / n_target_token

0.4364546324041766

In [39]:
tokens = np.array(all_target_tokens)

In [40]:
match = {}
for idx in range(tokens.shape[0]):
    if mask[idx][0]:
        choices = []
        for j in range(k):
            if mask[idx][j]:
                choices.append(dict.id2tok[I[idx][j]])
        match[tokens[idx]] = choices

In [41]:
match

{'音质': ['Sound'],
 '简便': ['Easy'],
 '正确性': ['Richtigkeit', 'Correctness'],
 '有效': ['Effective'],
 '警察厅': ['Polizeiamt', 'Polizeidepartement'],
 '接收机': ['Receiver'],
 '旗子': ['Flaggen', 'Flag'],
 '隐私': ['Privacy', 'Privatsphäre', 'Datenschutz'],
 '年数': ['Jahre', 'Year', 'Jahr', 'Jahreszahl', 'years'],
 '模板': ['Template', 'Templates'],
 '小心': ['Beware', 'Vorsichtig', 'Vorsicht', 'Aufgepasst'],
 '妇': ['Frau', 'femme', 'Woman'],
 '圣米切尔': ['SaintMichel'],
 '提供者': ['Provider', 'Anbieter'],
 '赞助商': ['Sponsoren'],
 '值': ['Value', 'Wert', 'values', 'Arvo'],
 '双人': ['Double', 'Doppel', 'Doubles'],
 '照亮': ['Auflicht'],
 '宁波市': ['Ningbo'],
 '帐户': ['Konto', 'Account'],
 '其它': ['Other', 'Andere', 'Anderweitige', 'Sonstige', 'Alte'],
 '大王': ['King'],
 '姐弟': ['Sister', 'Schwester'],
 '它': ['it', 'It'],
 '外面': ['Außen', 'Outside', 'Draußen', 'außen', 'Aussen'],
 '创立人': ['Gründer'],
 '外来语': ['Fremdsprache'],
 '祖母': ['Großmutter'],
 '沙滩': ['Beach'],
 '滩上': ['auf Strand', 'in der Beach'],
 '申请人': ['Antrags

## Mix Audio

In [42]:
new_df = pd.DataFrame(columns=['id', 'audio', 'n_frames', \
    'src_text', 'speaker', 'src_lang', 'tgt_lang'])

In [43]:
silence = th.zeros(1600)

In [73]:
ids = []
audios = []
n_frames = []
src_texts = []
speakers = []
src_langs = []
for idx in tqdm(range(len(target_lang_df))):
    id = target_lang_df['path'][idx][:-4]
    audio = id + '.wav'
    audio_path = os.path.join(target_lang_root, '16kHz/{}.wav'.format(id))
    info = torchaudio.info(audio_path)
    n_frame = info.num_frames
    src_text = target_lang_df['sentence'][idx]
    speaker = target_lang_df['client_id'][idx]

    ids.append(id)
    audios.append(audio)
    n_frames.append(n_frame)
    src_texts.append(src_text)
    speakers.append(speaker)
    src_langs.append(target_lang + '-cv')

100%|██████████| 48516/48516 [00:01<00:00, 30891.39it/s]


In [48]:
# mfa forced align

In [65]:
ids = []
audios = []
n_frames = []
src_texts = []
speakers = []
src_langs = []
for o_idx in tqdm(range(len(target_lang_df))):
    id = target_lang_df['path'][o_idx][:-4]
    audio = id + '.wav'
    audio_path = os.path.join(target_lang_root, '16kHz/{}.wav'.format(id))
    info = torchaudio.info(audio_path)
    n_frame = info.num_frames
    src_text = target_lang_df['sentence'][o_idx]
    speaker = target_lang_df['client_id'][o_idx]

    grid_path = os.path.join(target_lang_root, '16kHz/align_wmseg/{}.TextGrid'.format(id))
    if os.path.exists(grid_path):
        grid = textgrids.TextGrid(grid_path)
        tokens = [tok for tok in grid['words'] if tok.text != '']

        transcript = src_text

        # transcript = ''.join(c for c in transcript if c not in target_lang_punctuation)
        # ori_tokens = [tok for tok in transcript.split(' ') if tok != '']
        ori_tokens = target_tokenized_sentences[o_idx]

        string = ' '.join([tok.text.lower() for tok in tokens])
        ori_string = ' '.join([tok.lower() for tok in ori_tokens])

        if string == ori_string:
            
            waveform = torchaudio.load(audio_path)[0][0]
            n_frame = waveform.size(0)
            frame_rate = 16000
            duration = n_frame / frame_rate

            new_waveform = []
            new_tokens = []
            
            intervals = (np.array([(tok.xmin, tok.xmax) for tok in tokens]) * frame_rate).astype(int)

            average_volume = 0.
            word_duration = 0.
            for xmin, xmax in intervals:
                word_duration += xmax - xmin
                average_volume += waveform[xmin : xmax].abs().sum()
            average_volume /= word_duration

            mask = []
            orig_intervals = []
            mixed_intervals = []
            last_end = idx = 0
            prefix_length = 0.
            while idx < len(intervals):
                xmin, xmax = intervals[idx]
                new_waveform.append(waveform[last_end : xmin])
                prefix_length += xmin - last_end
                replace = False
                for g in range(min(target_gram, len(intervals) - idx), 0, -1):
                    token = ' '.join(ori_tokens[idx : idx + g])
                    if token in match:
                        replace = True
                        token_zh = random.choice(match[token])
                        selectable_waveforms = dict.waveforms[dict.tok2id[token_zh]]

                        selected_waveform = random.choice(selectable_waveforms)
                        selected_volume = selected_waveform.abs().mean()
                        selected_waveform *= average_volume / selected_volume

                        sil_selected_waveform = th.cat([silence, selected_waveform, silence], dim=0)

                        new_waveform.append(sil_selected_waveform)
                        new_tokens.append(token_zh)
                        last_end = intervals[idx + g - 1][1]
                        idx += g
                        mask.append(True)
                        orig_intervals.append((xmin, last_end))
                        mixed_intervals.append((
                            prefix_length + silence.size(0), 
                            prefix_length + silence.size(0) + selected_waveform.size(0)
                        ))
                        prefix_length += silence.size(0) * 2 + selected_waveform.size(0)
                        break
                if not replace:
                    new_waveform.append(waveform[xmin : xmax])
                    new_tokens.append(ori_tokens[idx])
                    last_end = xmax
                    idx += 1
                    mask.append(False)
                    orig_intervals.append((xmin, xmax))
                    mixed_intervals.append((
                        prefix_length,
                        prefix_length + xmax - xmin
                    ))
                    prefix_length += xmax - xmin

            new_waveform = th.cat(new_waveform, dim=0).unsqueeze(0)
            new_audio_path = os.path.join(target_lang_root, '16kHz/{}-mixed.wav'.format(id))
            torchaudio.save(new_audio_path, new_waveform, sample_rate=16000)

            mask = th.tensor(mask, dtype=bool)
            orig_intervals = th.tensor(orig_intervals, dtype=float) / n_frame
            mixed_intervals = th.tensor(mixed_intervals, dtype=float) / new_waveform.size(1)
            match_path = os.path.join(target_lang_root, '16kHz/{}.pt'.format(id))
            th.save([mask, orig_intervals, mixed_intervals], match_path)

100%|██████████| 48516/48516 [02:27<00:00, 328.93it/s]


In [74]:
new_df['id'] = ids
new_df['audio'] = audios
new_df['n_frames'] = n_frames
new_df['src_text'] = src_texts
new_df['speaker'] = speakers
new_df['src_lang'] = src_langs

## Forward Translation

In [67]:
device2 = 'cuda:1'

In [68]:
mbart = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device2)
mbart_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")

In [69]:
mbart_tokenizer.src_lang = "zh_CN"

In [70]:
mbart_tokenizer('你好')

{'input_ids': [250025, 6, 124084, 2], 'attention_mask': [1, 1, 1, 1]}

In [76]:
batch_size = 50
tgt_texts = []
with th.no_grad():
    for idx in tqdm(range(0, len(src_texts), batch_size)):
        encoded_src_text = mbart_tokenizer(src_texts[idx : idx + batch_size], padding=True, return_tensors="pt").to(device2)
        generated_tokens = mbart.generate(**encoded_src_text)
        batch_tgt_text = mbart_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        tgt_texts.extend(batch_tgt_text)

100%|██████████| 971/971 [36:41<00:00,  2.27s/it] 


In [16]:
tgt_texts = []
for idx in range(8):
    t = th.load('translation/de_DE/{}.pt'.format(idx))
    tgt_texts.extend(t)

In [83]:
new_df['tgt_lang'] = ['en'] * len(new_df)
new_df['tgt_text'] = tgt_texts

In [85]:
save_df_to_tsv(new_df, os.path.join(target_lang_root, 'train_st_{}-cv_en.tsv'.format(target_lang)))