# Notebook for preprocessing Wikipedia (Japanese) dataset

### Initilizing phonemizer and tokenizer

In [1]:
import yaml

config_path = "Configs/config_ja.yml" # you can change it to anything else
config = yaml.safe_load(open(config_path))

In [2]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained(config['dataset_params']['tokenizer']) # you can use any other tokenizers if you want to



In [3]:
tokenizer.tokenize("おはようございます。") # just to make sure it works

['おはよう', 'ござい', 'ます', '。']

In [4]:
from utils import hiragana2IPA, phonemize

# hiragana2IPA("おはようございます。")
phonemize("おはようございます。", tokenizer)

{'input_ids': [26750, 30714, 12995, 385],
 'phonemes': ['オハヨウ', 'ゴザイ', 'マス', '。']}

### Process dataset

In [5]:
from datasets import load_dataset
dataset = load_dataset("wiki40b", "ja", split="train") # you can use other version of this dataset

In [6]:
dataset

Dataset({
    features: ['wikidata_id', 'text', 'version_id'],
    num_rows: 745392
})

In [7]:
print(dataset[4]['text'])

b'\n_START_ARTICLE_\n\xe5\x8d\x97\xe9\x83\xa8\xe7\x85\x8e\xe9\xa4\x85\n_START_SECTION_\n\xe6\xa6\x82\xe8\xa6\x81\n_START_PARAGRAPH_\n\xe5\x85\x83\xe3\x80\x85\xe3\x81\xaf\xe5\x85\xab\xe6\x88\xb8\xe8\x97\xa9\xe3\x81\xa7\xe4\xbd\x9c\xe3\x82\x89\xe3\x82\x8c\xe3\x81\x9f\xe9\x9d\x9e\xe5\xb8\xb8\xe9\xa3\x9f\xe3\x81\xa7\xe3\x81\x82\xe3\x82\x8a\xe3\x80\x81\xe5\xb0\x8f\xe9\xba\xa6\xe7\xb2\x89\xe3\x82\x92\xe6\xb0\xb4\xe3\x81\xa7\xe7\xb7\xb4\xe3\x81\xa3\xe3\x81\xa6\xe5\x86\x86\xe5\xbd\xa2\xe3\x81\xae\xe5\x9e\x8b\xe3\x81\xab\xe5\x85\xa5\xe3\x82\x8c\xe3\x81\xa6\xe5\xa0\x85\xe3\x81\x8f\xe7\x84\xbc\xe3\x81\x84\xe3\x81\xa6\xe4\xbd\x9c\xe3\x82\x89\xe3\x82\x8c\xe3\x82\x8b\xe3\x80\x82\xe7\xb8\x81\xe3\x81\xab\xe3\x80\x8c\xe3\x81\xbf\xe3\x81\xbf\xe3\x80\x8d\xe3\x81\xa8\xe5\x91\xbc\xe3\x81\xb0\xe3\x82\x8c\xe3\x82\x8b\xe8\x96\x84\xe3\x81\x8f\xe3\x82\xab\xe3\x83\xaa\xe3\x83\x83\xe3\x81\xa8\xe3\x81\x97\xe3\x81\x9f\xe9\x83\xa8\xe5\x88\x86\xe3\x81\x8c\xe3\x81\x82\xe3\x82\x8b\xe3\x81\xae\xe3\x81\x8c\xe7\x89\xb9\xe

In [34]:
import ast

def decode_text(sample_text):
    """
    Chuyển đổi chuỗi có dạng b'...' thành chuỗi UTF-8.
    
    Args:
        sample_text (str): Chuỗi có thể chứa biểu diễn bytes dạng string.
        
    Returns:
        str: Chuỗi đã giải mã UTF-8, hoặc giữ nguyên nếu không cần giải mã.
    """
    if isinstance(sample_text, str):
        if (sample_text.startswith("b'") and sample_text.endswith("'")) or (sample_text.startswith('b"') and sample_text.endswith('"')):
            try:
                sample_text = ast.literal_eval(sample_text)  # Chuyển từ string thành bytes thực sự
                sample_text = sample_text.decode("utf-8")   # Giải mã bytes thành UTF-8
            except (SyntaxError, ValueError):
                pass  # Trả về chuỗi gốc nếu gặp lỗi
    return sample_text

decoded_text = decode_text(dataset[3024]['text'])
decoded_text

'\n_START_ARTICLE_\n弘前市立大和沢小学校\n_START_SECTION_\n学区\n_START_PARAGRAPH_\n大和沢、一野渡、狼森'

In [35]:
def clean_wiki_text(text):
    text = ''.join(text.split('_START_PARAGRAPH_')[1:])
    markers = ["_START_ARTICLE_", "_START_SECTION_", "_START_PARAGRAPH_", "_NEWLINE_", "_START_HEADING_", 
               "_START_BULLET_", "_START_LIST_", "_START_TABLE_", "_START_CAPTION_", "_START_IMAGE_"]
    for marker in markers:
        text = text.replace(marker, "")
    return text.strip()

decoded_text = clean_wiki_text(decoded_text)
decoded_text

'大和沢、一野渡、狼森'

In [36]:
def preprocess_text(text):
    text = decode_text(text)
    text = clean_wiki_text(text)
    return text

In [37]:
preprocess_text(dataset[3024]['text'])

'大和沢、一野渡、狼森'

In [38]:
phonemize(preprocess_text(dataset[3024]['text']), tokenizer)

{'input_ids': [1, 384, 1, 384, 1, 385],
 'phonemes': ['オオワサワ', '、', 'イチノワタリ', '、', 'オイノモリ', '。']}

In [5]:
root_directory = "./wiki_phoneme" # set up root directory for multiprocessor processing

In [6]:
import os
import gc

num_shards = 1000
processed_shards = []
def process_shard(i):
    directory = root_directory+"/shard_" + str(i)
    if os.path.exists(directory):
        print("Shard %d already exists!" % i)
        return
    print('Processing shard %d ...' % i)
    try:
        shard = dataset.shard(num_shards=num_shards, index=i)
        processed_dataset = shard.map(lambda t: phonemize(preprocess_text(t['text']), tokenizer), remove_columns=['text'])
        if not os.path.exists(directory):
            os.makedirs(directory)
        processed_dataset.save_to_disk(directory)
        processed_shards.append(i)
        print(f'Shard {i} processed successfully.')
        del processed_dataset  # Free memory
        gc.collect()
    except Exception as e:
        print(f'Error processing shard {i}: {e}')

In [8]:
from pebble import ProcessPool
from concurrent.futures import TimeoutError

In [9]:
import os

# Lấy số core của CPU
num_cores = os.cpu_count()
print(f"Số core của CPU là: {num_cores}")

Số core của CPU là: 80


#### Note: You will need to run the following cell multiple times to process all shards because some will fail. Depending on how fast you process each shard, you will need to change the timeout to a longer value to make more shards processed before being killed.

In [None]:
max_workers = num_cores # change this to the number of CPU cores your machine has 

with ProcessPool(max_workers=max_workers) as pool:
    pool.map(process_shard, range(num_shards), timeout=120)

Processing shard 0 ...
Processing shard 1 ...
Processing shard 2 ...
Processing shard 3 ...Processing shard 4 ...

Processing shard 5 ...
Processing shard 6 ...Processing shard 7 ...
Processing shard 8 ...Processing shard 9 ...

Processing shard 10 ...
Processing shard 11 ...
Processing shard 12 ...
Processing shard 13 ...
Processing shard 14 ...

Processing shard 15 ...Processing shard 16 ...

Processing shard 17 ...Processing shard 18 ...

Processing shard 19 ...Processing shard 20 ...

Processing shard 21 ...
Processing shard 22 ...Processing shard 23 ...Processing shard 24 ...

Processing shard 25 ...
Processing shard 26 ...
Processing shard 27 ...

Processing shard 28 ...
Processing shard 29 ...
Processing shard 30 ...
Processing shard 32 ...Processing shard 33 ...Processing shard 34 ...

Processing shard 35 ...
Processing shard 36 ...Processing shard 37 ...


Processing shard 39 ...Processing shard 40 ...
Processing shard 41 ...

Processing shard 42 ...Processing shard 43 ...
Pro

### Collect all shards to form the processed dataset

In [8]:
import os
output = [dI for dI in os.listdir(root_directory) if os.path.isdir(os.path.join(root_directory,dI))]

from datasets import load_from_disk, concatenate_datasets

datasets = []

for o in output:
    directory = f"{root_directory}/{o}"
    try:
        shard = load_from_disk(directory)
        datasets.append(shard)
        print("%s loaded" % o)
    except:
        continue

shard_979 loaded
shard_990 loaded
shard_197 loaded
shard_417 loaded
shard_782 loaded
shard_650 loaded
shard_654 loaded
shard_268 loaded
shard_784 loaded
shard_495 loaded
shard_340 loaded
shard_996 loaded
shard_336 loaded
shard_141 loaded
shard_711 loaded
shard_651 loaded
shard_129 loaded
shard_188 loaded
shard_997 loaded
shard_880 loaded
shard_883 loaded
shard_663 loaded
shard_250 loaded
shard_323 loaded
shard_229 loaded
shard_520 loaded
shard_763 loaded
shard_515 loaded
shard_555 loaded
shard_848 loaded
shard_605 loaded
shard_359 loaded
shard_760 loaded
shard_465 loaded
shard_753 loaded
shard_193 loaded
shard_363 loaded
shard_289 loaded
shard_117 loaded
shard_118 loaded
shard_464 loaded
shard_414 loaded
shard_709 loaded
shard_51 loaded
shard_665 loaded
shard_678 loaded
shard_457 loaded
shard_579 loaded
shard_801 loaded
shard_817 loaded
shard_471 loaded
shard_257 loaded
shard_274 loaded
shard_155 loaded
shard_451 loaded
shard_908 loaded
shard_112 loaded
shard_905 loaded
shard_287 loade

In [9]:
dataset = concatenate_datasets(datasets)
dataset.save_to_disk(config['data_folder'])
print('Dataset saved to %s' % config['data_folder'])

Saving the dataset (0/13 shards):   0%|          | 0/684501 [00:00<?, ? examples/s]

Dataset saved to wiki40b_ja.processed


In [10]:
# check the dataset size
dataset

Dataset({
    features: ['wikidata_id', 'version_id', 'input_ids', 'phonemes'],
    num_rows: 684501
})

In [11]:
dataset[0]['wikidata_id'], dataset[0]['version_id']

("b'Q11575576'", "b'14594515120598057630'")

In [14]:
print(dataset[0]['input_ids'])
print(tokenizer.decode(dataset[0]['input_ids']))

[14406, 15, 1, 23, 32425, 15, 458, 1, 384, 1129, 643, 4035, 1128, 601, 2002, 1, 1, 15, 220, 15, 1129, 643, 4035, 938, 1128, 656, 2002, 1, 1, 24, 465, 384, 12500, 464, 12973, 1781, 385, 13199, 14636, 385, 15135, 4092, 12758, 385, 19881, 13799, 13070, 12533, 12801, 2127, 384, 13199, 2011, 461, 1, 385, 1, 6104, 17319, 19164, 1, 384, 935, 1, 15450, 13868, 12792, 6341, 458, 12510, 385, 1129, 643, 4035, 1422, 1128, 604, 2002, 1, 384, 13199, 2011, 14683, 12909, 461, 12500, 1, 12488, 16519, 384, 1019, 464, 1, 593, 13199, 14636, 458, 441, 456, 14483, 385, 13199, 14636, 461, 12493, 449, 2734, 464, 14697, 465, 607, 1128, 660, 3181, 457, 384, 12647, 12961, 1, 457, 12517, 384, 12565, 461, 12594, 456, 484, 15874, 1, 464, 1, 14636, 457, 12485, 385, 23, 1, 500, 22993, 458, 656, 1128, 604, 3181, 457, 14322, 1, 461, 12958, 441, 449, 15134, 29869, 24, 1129, 643, 4035, 660, 1128, 1422, 2002, 464, 12548, 13444, 384, 1, 18093, 464, 2849, 6046, 2734, 461, 465, 935, 20623, 461, 1, 441, 449, 385, 13199, 14636,

In [None]:
print(dataset[0]['phonemes'])

['タナカ', ' ', 'トシフミ', '(', 'タナカ', ' ', 'ト', 'シブミ', '、', 'セン', 'キュウ', 'ヒャク', 'ジュウ', 'イチ', 'ネン', 'ジュウイチガツ', 'ココノカ', ' ', '−', ' ', 'セン', 'キュウ', 'ヒャク', 'ハチ', 'ジュウ', 'ニ', 'ネン', 'ジュウニガツ', 'ハツカ', ')', 'ハ', '、', 'ニッポン', 'ノ', 'セイジ', 'カ', '。', 'ホッカイドウ', 'チジ', '。', 'アオモリ', 'ケン', 'シュッシン', '。', 'ケイレキ', 'キュウシュウ', 'テイコク', 'ダイガク', 'ソツギョウ', 'ゴ', '、', 'ホッカイドウ', 'チョウ', 'ニ', 'ニュウチョウ', '。', 'リンセイ', 'ブ', 'シンリン', 'ドボク', 'カカリチョウ', '、', 'ゼン', 'ドウチョウ', 'ショクイン', 'クミアイ', 'イイン', 'チョウ', 'ト', 'ナル', '。', 'セン', 'キュウ', 'ヒャク', 'ヨン', 'ジュウ', 'ナナ', 'ネン', 'シガツ', '、', 'ホッカイドウ', 'チョウ', 'チョウカン', 'センキョ', 'ニ', 'ニッポン', 'シャカイトウ', 'カラ', 'シュツバ', '、', 'ハツ', 'ノ', 'コウセン', '・', 'ホッカイドウ', 'チジ', 'ト', 'シ', 'テ', 'トウセン', '。', 'ホッカイドウ', 'チジ', 'ニ', 'ナッ', 'タ', 'トキ', 'ノ', 'ネンレイ', 'ハ', 'サン', 'ジュウ', 'ゴ', 'サイ', 'デ', '、', 'トウジ', 'ゼンコク', 'サイネンショウ', 'デ', 'アリ', '、', 'ゲンザイ', 'ニ', 'オイ', 'テ', 'モ', 'シジョウ', 'サイネンショウ', 'ノ', 'コウセン', 'チジ', 'デ', 'アル', '。', '(', 'カンセン', 'ヲ', 'フクメル', 'ト', 'ニ', 'ジュウ', 'ナナ', 'サイ', 'デ', 'ヒョウゴ', 'ケンチジ', 'ニ', 'シュウニン', 'シ', 'タ', 'イトウ',

### Remove unneccessary tokens from the pre-trained tokenizer
The pre-trained tokenizer contains a lot of tokens that are not used in our dataset, so we need to remove these tokens. We also want to predict the word in lower cases because cases do not matter that much for TTS. Pruning the tokenizer is much faster than training a new tokenizer from scratch. 

In [17]:
from simple_loader import FilePathDataset, build_dataloader

file_data = FilePathDataset(dataset)
loader = build_dataloader(file_data, num_workers=32, batch_size=128)

In [18]:
special_token = config['dataset_params']['word_separator']
special_token

3039

In [19]:
# get all unique tokens in the entire dataset

from tqdm import tqdm

unique_index = [special_token]
for _, batch in enumerate(tqdm(loader)):
    unique_index.extend(batch)
    unique_index = list(set(unique_index))

100%|██████████| 5347/5347 [02:15<00:00, 39.59it/s]


In [20]:
# create the mapping

token_maps = {}
for t in tqdm(unique_index):
    word = tokenizer.decode([t])
    token_maps[t] = {'word': word, 'token': unique_index.index(t)}

100%|██████████| 22008/22008 [00:33<00:00, 657.37it/s] 


In [21]:
token_maps

{1: {'word': '[UNK]', 'token': 0},
 15: {'word': '', 'token': 1},
 16: {'word': '!', 'token': 2},
 18: {'word': '#', 'token': 3},
 19: {'word': '$', 'token': 4},
 20: {'word': '%', 'token': 5},
 21: {'word': '&', 'token': 6},
 23: {'word': '(', 'token': 7},
 24: {'word': ')', 'token': 8},
 25: {'word': '*', 'token': 9},
 26: {'word': '+', 'token': 10},
 27: {'word': ',', 'token': 11},
 29: {'word': '.', 'token': 12},
 30: {'word': '/', 'token': 13},
 31: {'word': '0', 'token': 14},
 41: {'word': ':', 'token': 15},
 42: {'word': ';', 'token': 16},
 43: {'word': '<', 'token': 17},
 44: {'word': '=', 'token': 18},
 45: {'word': '>', 'token': 19},
 46: {'word': '?', 'token': 20},
 47: {'word': '@', 'token': 21},
 48: {'word': 'A', 'token': 22},
 49: {'word': 'B', 'token': 23},
 50: {'word': 'C', 'token': 24},
 51: {'word': 'D', 'token': 25},
 52: {'word': 'E', 'token': 26},
 53: {'word': 'F', 'token': 27},
 54: {'word': 'G', 'token': 28},
 55: {'word': 'H', 'token': 29},
 56: {'word': 'I',

In [22]:
import pickle
with open(config['dataset_params']['token_maps'], 'wb') as handle:
    pickle.dump(token_maps, handle)
print('Token mapper saved to %s' % config['dataset_params']['token_maps'])

Token mapper saved to token_maps_ja.pkl


### Test the dataset with dataloader


In [23]:
from dataloader import build_dataloader

train_loader = build_dataloader(dataset, batch_size=32, num_workers=0, dataset_config=config['dataset_params'])

177


In [24]:
_, (words, labels, phonemes, input_lengths, masked_indices) = next(enumerate(train_loader))

In [25]:
words.shape

torch.Size([32, 512])

In [26]:
labels.shape # labels are the original phoneme tokens

torch.Size([32, 512])

In [27]:
phonemes.shape # phonemes are the phoneme tokens after masked

torch.Size([32, 512])

In [28]:
len(input_lengths)

32

In [29]:
len(masked_indices)

32