In [1]:
import os
import re
import _pickle as cPickle
import gc
import random
from tqdm import tqdm
import numpy as np
import torch

import math
from itertools import groupby, chain
BATCH_SIZE = 16


cache_dir = "/data4/yoomcache"
model_cache_dir = os.path.join(cache_dir, 'huggingfaces')
data_cache_dir = os.path.join(cache_dir, 'datasets')
checkpoint_dir = os.path.join(cache_dir, 'checkpoint')

seed = 0
random.seed(0)
np.random.seed(seed)
torch.manual_seed(seed)

import IPython
import seaborn as sns
sns.set(rc = {'figure.figsize':(16,9)})

In [2]:
from wav2vec2GPTwCTC import *
from configuration_wav2vec2gpt import Wav2Vec2GPTConfig

from transformers import GPT2Tokenizer, AddedToken, BertTokenizer

In [3]:
wav2vec_pretrained = "facebook/wav2vec2-base"
# wav2vec_pretrained = "facebook/wav2vec2-base-960h"
gpt_pretrained = "gpt2"

# Should aware that pad_token_id is used to compute CTC loss, 
# so pad_token configuration for both tokenizer and model should be the same
args = {
    
    'pad_token': "<|endoftext|>", 'pad_token_id': 50256,  # 'Ġ' ... 220,
    'unk_token': "<|endoftext|>", 'unk_token_id': 50256,
    'bos_token': "<|endoftext|>", 'bos_token_id': 50256,
    'eos_token': "<|endoftext|>", 'eos_token_id': 50256,
    'ctc_loss_reduction': 'mean',
    
    
    'select_random': True,
    'loss_ver': 'ctc',  # 'ctc-ce', 'ce-ctc', 'ctc', 'ce'
    
    
    'add_adapter': True,
    'num_adapter_layers': 3,
    'output_hidden_size': [128, 256, 512, 256],
    'adapter_kernel_size': [4, 3, 3, 3, 3], 
    'adapter_stride':      [2, 2, 2, 1, 1],
    'adapter_padding':     [4, 0, 0, 0, 0],
    'adapter_bias': False,
    
    
    # Wav2Vec
    'hidden_dropout': 0.0,
    'activation_dropout': 0.0,
    'attention_dropout': 0.0,
    'feat_proj_dropout': 0.0,
    'feat_quantizer_dropout': 0.0,
    'final_dropout': 0.2,
    'layerdrop': 0.0,
    # GPT
    'resid_pdrop': 0.0,
    'embd_pdrop': 0.0,
    'attn_pdrop': 0.0,
    
    
    
}

config = Wav2Vec2GPTConfig(**args)

In [4]:
tokenizer = GPT2Tokenizer.from_pretrained(gpt_pretrained,
                                          cache_dir=model_cache_dir,
                                          **args)

In [5]:
data_dir = '/data4/TTS/'
data_fname = 'VCTK-Corpus/dataset-vctk-16k.pkl'
# data_fname = 'LibriTTS/train-clean-100.pkl'
# data_fname = 'LibriTTS/dev-clean.pkl'

dataset_ratio = 0.5
#########################################################

with open(os.path.join(data_dir, data_fname), 'rb') as f:
    gc.disable()
    dataset = cPickle.load(f)
    gc.enable()
print('entire dataset length: {}'.format(len(dataset['text'])))


dataset_size = int(len(dataset['text']) * dataset_ratio)
for k in dataset.keys():
    if k == 'sample_rate': continue
    del dataset[k][dataset_size:]
print('dataset length: {}'.format(dataset_size))


min_audio_length, max_audio_length = 1e6, 0
for arr in dataset['audio_array']:
    if len(arr) > max_audio_length:
        max_audio_length = len(arr)
    if len(arr) < min_audio_length:
        min_audio_length = len(arr)
print('maximum audio length: {} ({} sec)'.format(max_audio_length, 
                                                 max_audio_length / dataset['sample_rate']))
print('minimum audio length: {} ({} sec)'.format(min_audio_length, 
                                                 min_audio_length / dataset['sample_rate']))

entire dataset length: 44070
dataset length: 22035
maximum audio length: 308533 (19.2833125 sec)
minimum audio length: 13715 (0.8571875 sec)


In [6]:
_RE_REPLACE_PARENTHESIS = (r'[(){}_\[\]]', '')
_RE_REPLACE_QUESTIONMARK = ('\s\?', '?')
_RE_REPLACE_EXCLAMATIONMARK = ('\s\!', '!')
_RE_REPLACE_DOT = ('\s\.', '.')
_RE_COMBINE_WHITESPACE = (r'\s+', ' ')

re_list = [
    _RE_REPLACE_PARENTHESIS,
    _RE_COMBINE_WHITESPACE, 
    _RE_REPLACE_QUESTIONMARK,
    _RE_REPLACE_EXCLAMATIONMARK,
    _RE_REPLACE_DOT,
]

dataset['retext'] = list()

for txt in dataset['text']:
    retxt = txt
    retxt = retxt.lstrip(' .,?!')
    if retxt[:-1] in ['"', "'"]:
        retxt.strip('\"\'')    
    for pattern, repl in re_list:
        retxt = re.sub(pattern, repl, retxt).strip()
    
    dataset['retext'].append(retxt)
    

In [7]:
dataset['text'] = tokenizer(dataset['retext'],
                            return_tensors="pt",
                            padding='longest'  # VCTK: 42, libritts: 92
                            )

print(dataset['text']['attention_mask'].shape)

torch.Size([22035, 41])


In [8]:
split_ratio = (0.8, 0.9)
indices = np.arange(dataset_size)
np.random.shuffle(indices)

train_idx = indices[:int(dataset_size * split_ratio[0])]
val_idx = indices[int(dataset_size * split_ratio[0]):int(dataset_size * split_ratio[1])]
test_idx = indices[int(dataset_size * split_ratio[1]):]

In [11]:
dataset['text']

{'input_ids': tensor([[ 2990,  1392,   262,  ..., 50256, 50256, 50256],
        [ 1135,   836,   470,  ..., 50256, 50256, 50256],
        [ 1026,   318,   900,  ..., 50256, 50256, 50256],
        ...,
        [  464, 12929,   318,  ..., 50256, 50256, 50256],
        [ 2202,  3909,    11,  ..., 50256, 50256, 50256],
        [ 1532,   484,  8288,  ..., 50256, 50256, 50256]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [12]:
print(dataset['text'].input_ids[train_idx])

tensor([[47230,   468,  3402,  ..., 50256, 50256, 50256],
        [28065,    11,   339,  ..., 50256, 50256, 50256],
        [   40,   815,   892,  ..., 50256, 50256, 50256],
        ...,
        [ 2990,   550,   645,  ..., 50256, 50256, 50256],
        [  464,  4495,   318,  ..., 50256, 50256, 50256],
        [ 1544,  3636,   470,  ..., 50256, 50256, 50256]])


In [17]:
unique_train = dataset['text'].input_ids[train_idx].unique()
unique_train, len(unique_train)

(tensor([    0,     6,    11,  ..., 50147, 50238, 50256]), 6132)

In [34]:
unique_valid = dataset['text'].input_ids[val_idx].unique()
unique_valid, len(unique_valid)

(tensor([    0,    11,    12,  ..., 49685, 49984, 50256]), 2950)

In [35]:
len(set(unique_valid.numpy()) - set(unique_train.numpy()))

161

In [37]:
dataset['text'].input_ids[train_idx]

tensor([[47230,   468,  3402,  ..., 50256, 50256, 50256],
        [28065,    11,   339,  ..., 50256, 50256, 50256],
        [   40,   815,   892,  ..., 50256, 50256, 50256],
        ...,
        [ 2990,   550,   645,  ..., 50256, 50256, 50256],
        [  464,  4495,   318,  ..., 50256, 50256, 50256],
        [ 1544,  3636,   470,  ..., 50256, 50256, 50256]])

In [38]:
from collections import Counter

In [53]:
for key in set(unique_valid.numpy()) - set(unique_train.numpy()):
    print(tokenizer.decode(key), )

Counter({47230: 44,
         468: 766,
         3402: 18,
         262: 6631,
         835: 169,
         13: 16635,
         50256: 566396,
         28065: 23,
         11: 3685,
         339: 707,
         3767: 23,
         284: 3096,
         1394: 66,
         257: 4226,
         1877: 17,
         7034: 12,
         40: 1840,
         815: 292,
         892: 208,
         523: 280,
         1165: 252,
         10128: 12,
         9918: 5,
         7415: 117,
         1234: 81,
         14802: 7,
         1986: 18,
         319: 742,
         12928: 9,
         6385: 48,
         788: 140,
         37898: 41,
         423: 1342,
         1043: 81,
         326: 1077,
         340: 1728,
         318: 3603,
         407: 1295,
         14580: 92,
         475: 296,
         1006: 41,
         7861: 41,
         416: 529,
         6290: 207,
         49253: 78,
         543: 149,
         5640: 42,
         25435: 43,
         1135: 1171,
         1392: 176,
         651: 187,
     

In [55]:
c = Counter(dataset['text'].input_ids.numpy().reshape(-1).tolist())

In [56]:
for key in set(unique_valid.numpy()) - set(unique_train.numpy()):
    print(tokenizer.decode(key), c[key])

 alliance 1
anting 1
 bomber 1
 fairy 1
ams 1
liament 1
 frank 1
 entertain 1
 distortion 1
JOHN 1
 reap 1
 Services 1
iance 1
 sought 2
 telling 1
External 1
 backed 1
appers 2
irement 1
 Leicester 1
 monitor 1
ire 1
 mystical 1
 dealing 1
ash 1
Ret 1
 Neil 3
der 1
 Chris 1
keeper 1
 multitude 1
 sow 1
Law 1
 skin 1
 abusers 1
Web 1
itors 1
ying 1
 inherited 1
 comedian 1
 intentions 2
 flashy 2
 Citizens 1
 urge 2
 UN 1
 Ly 1
 bigger 3
Lots 1
 catching 1
 counts 1
 signings 1
 combat 1
 Tibetan 1
 Cooper 3
 Today 2
 amended 2
 aud 1
 alter 1
 Virgin 1
anc 1
 disappear 1
Cont 1
thing 1
 balanced 1
 district 1
 thinks 3
 uncanny 1
Home 1
 population 1
vine 1
 opportun 1
PR 1
 lab 2
elling 2
 Meat 2
ummy 1
 Lions 1
 cat 1
 attacked 1
 paths 1
 escaped 1
 Augusta 1
Sports 1
 Red 2
 insulting 1
 accused 1
 shortest 2
 west 1
 aims 1
 inherent 1
 gra 1
 satisfactory 4
Ir 1
 fl 1
is 1
icking 1
Note 1
om 3
 amongst 1
 Belgium 2
 resolutions 1
 sevent 1
 treating 1
rollers 1
 declare 1
 uphea

In [59]:
dataset['retext']

['They got the job done.',
 "We don't ask for much.",
 'It is set in Paris.',
 'Mentally, you have to be tough.',
 'It is not really used by many people.',
 "I feel it's time to make the switch.",
 'We owe it to the public of Strathclyde.',
 'She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.',
 'We keep a distance.',
 'I was in tears in the street.',
 'My position on the euro is quite clear.',
 'People look, but no one ever finds it.',
 'Eriksson, though, was having none of it.',
 'The whole thing has been a nightmare.',
 "It's time to say enough is enough.",
 'Those Were The Days, indeed.',
 'Yet what is supposed to be said?',
 'Sadly, in this case, that was not so.',
 "We've not won anything yet.",
 'They said we were out of touch.',
 'And the winners are.',
 "I can't believe we didn't win that game.",
 'For us, that decision was hard to understand.',
 'Awareness in Europe of Scotland is high.',
 'It is the whole package.',
 'The 

In [100]:
tokenizer2 = BertTokenizer.from_pretrained('bert-base-cased',
                                          cache_dir=model_cache_dir,)

In [105]:
text_bert = tokenizer2(dataset['retext'])

maxlen = 0
for t in text_bert:
    if maxlen < len(t):
        maxlen = len(t)
    print(maxlen)

KeyboardInterrupt: 

In [102]:
c2 = Counter(chain(*text_bert.input_ids))
c2

Counter({101: 22035,
         1220: 889,
         1400: 230,
         1103: 8294,
         2261: 136,
         1694: 101,
         119: 20760,
         102: 22035,
         1284: 1499,
         1274: 122,
         112: 2662,
         189: 755,
         2367: 18,
         1111: 1590,
         1277: 134,
         1135: 2629,
         1110: 4435,
         1383: 71,
         1107: 2323,
         2123: 14,
         16725: 5,
         1193: 188,
         117: 4676,
         1128: 402,
         1138: 1688,
         1106: 3880,
         1129: 1549,
         8035: 48,
         1136: 1619,
         1541: 167,
         1215: 97,
         1118: 664,
         1242: 147,
         1234: 303,
         146: 2556,
         1631: 121,
         1122: 2143,
         188: 1338,
         1159: 396,
         1294: 149,
         6878: 12,
         12972: 2,
         1470: 88,
         1104: 2956,
         1457: 24,
         7625: 14,
         1324: 14,
         1665: 57,
         2007: 14,
         1153: 357,


In [103]:
unique_train2 = text_bert.input_ids[train_idx].unique()
unique_valid2 = text_bert.input_ids[val_idx].unique()
unique_train2, len(unique_train2), unique_valid2, len(unique_valid2)

TypeError: only integer scalar arrays can be converted to a scalar index

In [107]:
tokenizer2.vocab_size

28996

In [115]:
idx = 11

print(text_bert.input_ids[idx])
print(tokenizer2.decode(text_bert.input_ids[idx]))

[101, 2563, 1440, 117, 1133, 1185, 1141, 1518, 4090, 1122, 119, 102]
[CLS] People look, but no one ever finds it. [SEP]
