# Notebook for preprocessing Wikipedia (Vietnamese) dataset

### Initializing phonemizer and tokenizer

In [1]:
import yaml

config_path = "Configs/config.yml" # you can change it to anything else

# Read the YAML config inside a context manager so the file handle is closed
# as soon as parsing finishes instead of being left open for the kernel's lifetime.
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

In [2]:
from phonemize import phonemize

In [3]:
# import phonemizer
# global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True,  with_stress=True)

In [4]:
from transformers import AutoTokenizer

# The tokenizer identifier is read from the config. Any other tokenizer class
# can be used instead if you want to, e.g.:
#   from transformers import TransfoXLTokenizer
#   tokenizer = TransfoXLTokenizer.from_pretrained(config['dataset_params']['tokenizer'])
tokenizer_name = config['dataset_params']['tokenizer']
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Process dataset

In [5]:
from datasets import load_dataset

# The Wikipedia dataset has been changed: the Vietnamese dump is loaded here.
wiki_splits = load_dataset("vietgpt/wikipedia_vi")
dataset = wiki_splits['train']

# Show one record as a sanity check of the available columns.
print(dataset[2])


{'id': 13, 'revid': '647989', 'url': 'https://vi.wikipedia.org/wiki?curid=13', 'title': 'Tiếng Việt', 'text': 'Tiếng Việt, cũng gọi là tiếng Việt Nam hay Việt ngữ là ngôn ngữ của người Việt và là ngôn ngữ chính thức tại Việt Nam. Đây là tiếng mẹ đẻ của khoảng 85% dân cư Việt Nam cùng với hơn 4 triệu người Việt kiều. Tiếng Việt còn là ngôn ngữ thứ hai của các dân tộc thiểu số tại Việt Nam và là ngôn ngữ dân tộc thiểu số được công nhận tại Cộng hòa Séc.\nDựa trên từ vựng cơ bản, tiếng Việt được phân loại là một ngôn ngữ thuộc ngữ hệ Nam Á. Tiếng Việt là ngôn ngữ có nhiều người nói nhất trong ngữ hệ này (nhiều hơn tổng số người nói của tất cả các ngôn ngữ còn lại trong ngữ hệ). Vì Việt Nam thuộc Vùng văn hoá Đông Á, tiếng Việt cũng chịu nhiều ảnh hưởng về từ tiếng Hán, do vậy là ngôn ngữ có ít điểm tương đồng nhất với các ngôn ngữ khác trong ngữ hệ Nam Á.\nLịch sử.\nTheo A. G. Haudricourt giải thích từ năm 1954, nhóm ngôn ngữ Việt-Mường ở thời kỳ khoảng đầu Công nguyên là những ngôn ngữ h

In [6]:
# Every processed shard is written to its own "shard_<i>" sub-directory of this path.
root_directory = "./wiki_phoneme" # set up root directory for multiprocessor processing

In [7]:
import os
num_shards = 50000  # the dataset is split into this many shards, one task per shard


def process_shard(i):
    """Phonemize shard ``i`` of ``dataset`` and save it under ``root_directory``.

    The shard is skipped when its output directory already exists, so the
    function can be called again to pick up shards that were not produced by
    an earlier run.

    Args:
        i: Index of the shard to process, in ``range(num_shards)``.

    Returns:
        None. The result is written to ``<root_directory>/shard_<i>``.
    """
    directory = root_directory + "/shard_" + str(i)
    if os.path.exists(directory):
        print("Shard %d already exists!" % i)
        return
    print('Processing shard %d ...' % i)
    shard = dataset.shard(num_shards=num_shards, index=i)
    # NOTE: global_tokenizer has been removed from this call.
    # The 'text' column is dropped and replaced by whatever phonemize() returns.
    processed_dataset = shard.map(lambda t: phonemize(t['text'], tokenizer), remove_columns=['text'])
    # exist_ok=True replaces the separate exists() check, which was redundant
    # after the early return above and left a check-then-create race.
    os.makedirs(directory, exist_ok=True)
    processed_dataset.save_to_disk(directory)
    print('Done %d ...' % i)

In [8]:
from pebble import ProcessPool
from concurrent.futures import TimeoutError

#### Note: You will need to run the following cell multiple times to process all shards, because some will fail. Depending on how fast each shard is processed, you may need to increase the timeout so that more shards finish before their worker is killed.


In [None]:
# Number of worker processes; change this to the number of CPU cores your machine has.
max_workers = 16

# One task per shard index. timeout=360 is handed to pebble's map, so a shard
# that runs longer than that is abandoned and has to be retried by re-running the cell.
shard_indices = range(num_shards)
with ProcessPool(max_workers=max_workers) as pool:
    pool.map(process_shard, shard_indices, timeout=360)

In [None]:
import os
from datetime import datetime
from tqdm.notebook import tqdm
from pebble import ProcessPool
from concurrent.futures import TimeoutError
import logging
from pathlib import Path

# Logging goes to two places: a timestamped file under logs/ and the stream
# handler (shown in the notebook output).
log_dir = "logs"
Path(log_dir).mkdir(exist_ok=True)

run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_handlers = [
    logging.FileHandler(f"{log_dir}/processing_{run_stamp}.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=log_handlers,
)

# Track progress: progress.txt holds one finished shard index per line.
progress_file = Path(root_directory) / "progress.txt"
processed_shards = set()
if progress_file.exists():
    with open(progress_file, 'r') as f:
        for line in f:
            entry = line.strip()
            # Ignore blank or non-numeric lines (e.g. a stray trailing newline)
            # instead of aborting the whole cell with a ValueError from int('').
            if entry.isdigit():
                processed_shards.add(int(entry))

def process_shard(i):
    """Phonemize shard ``i`` of ``dataset`` and save it under ``root_directory``.

    A shard is skipped when it is listed in ``processed_shards`` (the contents
    of the progress file when this cell started) or when its output directory
    already exists. Otherwise the shard is phonemized, saved, and its index is
    appended to the progress file.

    Args:
        i: Index of the shard to process, in ``range(num_shards)``.

    Returns:
        True if the shard was processed or skipped, False if an exception was
        raised while processing it (the exception is logged, not re-raised).
    """
    import shutil  # local import: only needed to clear a stale staging directory

    try:
        directory = Path(root_directory) / f"shard_{i}"

        # Skip if already processed
        if i in processed_shards:
            logging.info(f"Shard {i} already processed, skipping...")
            return True

        if directory.exists():
            logging.info(f"Shard {i} directory exists, skipping...")
            return True

        logging.info(f'Processing shard {i} ...')
        shard = dataset.shard(num_shards=num_shards, index=i)
        processed_dataset = shard.map(
            lambda t: phonemize(t['text'], tokenizer),
            remove_columns=['text']
        )

        # Save into a staging directory and rename it only after the save has
        # finished. A worker killed mid-save (e.g. by the pool timeout) then
        # leaves no half-written "shard_<i>" directory that later runs would
        # mistake for a finished shard and skip.
        directory.parent.mkdir(parents=True, exist_ok=True)
        staging_directory = directory.with_name(directory.name + ".partial")
        if staging_directory.exists():
            shutil.rmtree(staging_directory)  # leftover from an interrupted run
        processed_dataset.save_to_disk(staging_directory)
        staging_directory.rename(directory)

        # Record progress
        with open(progress_file, 'a') as f:
            f.write(f"{i}\n")

        logging.info(f'Done {i} ...')
        return True

    except Exception as e:
        # logging.exception logs at ERROR level and also keeps the traceback.
        logging.exception(f"Error processing shard {i}: {str(e)}")
        return False

num_shards = 50000
max_workers = 16  # change this to the number of CPU cores your machine has

# Shards still to do: every index in range(num_shards) that is not already in
# processed_shards, in ascending order.
remaining_shards = sorted(set(range(num_shards)) - processed_shards)
logging.info(f"Starting processing of {len(remaining_shards)} remaining shards")

# Outcome bookkeeping for this run. These live in the parent process; the
# processed_shards set is only the snapshot loaded from the progress file at
# startup, so it cannot be used to count what this run achieved.
failed_shards = []
succeeded_count = 0

with ProcessPool(max_workers=max_workers) as pool:
    future = pool.map(process_shard, remaining_shards, timeout=360)
    iterator = future.result()

    # Use tqdm for progress tracking
    with tqdm(total=len(remaining_shards), desc="Processing shards") as pbar:
        # The map iterator yields one outcome per submitted element, in
        # submission order, so stepping through remaining_shards alongside it
        # gives the real shard id for every result, timeout or error.
        for shard_index in remaining_shards:
            try:
                succeeded = next(iterator)
            except StopIteration:
                break
            except TimeoutError:
                logging.error(f"Shard {shard_index} processing timed out")
                succeeded = False
            except Exception as error:
                logging.error(f"Shard {shard_index} processing failed: {error}")
                succeeded = False

            # process_shard returns False when it caught an exception itself,
            # so the return value has to be checked as well.
            if succeeded:
                succeeded_count += 1
            else:
                failed_shards.append(shard_index)
            pbar.update(1)

# Report final status
logging.info(f"Processing completed. Processed: {succeeded_count}, Failed: {len(failed_shards)}")
if failed_shards:
    logging.warning(f"Failed shards: {sorted(failed_shards)}")

### Collect all shards to form the processed dataset

In [9]:
from datasets import load_from_disk, concatenate_datasets

# Every sub-directory of root_directory is treated as a candidate shard.
# Sorting makes the order of the concatenated dataset reproducible, because
# os.listdir() returns entries in arbitrary order.
output = sorted(
    dI for dI in os.listdir(root_directory)
    if os.path.isdir(os.path.join(root_directory, dI))
)
datasets = []
for o in output:
    directory = root_directory + "/" + o
    try:
        shard = load_from_disk(directory)
    except Exception as error:
        # A directory that cannot be loaded (e.g. an interrupted save) is still
        # skipped, but it is reported so missing shards do not go unnoticed.
        # Catching Exception rather than using a bare except also lets
        # KeyboardInterrupt stop the cell.
        print("%s skipped: %s" % (o, error))
        continue
    datasets.append(shard)
    print("%s loaded" % o)

shard_0 loaded
shard_1 loaded
shard_10 loaded
shard_100 loaded
shard_1000 loaded
shard_10000 loaded
shard_10001 loaded
shard_10002 loaded
shard_10003 loaded
shard_10004 loaded
shard_10005 loaded
shard_10006 loaded
shard_10007 loaded
shard_10008 loaded
shard_10009 loaded
shard_1001 loaded
shard_10010 loaded
shard_10011 loaded
shard_10012 loaded
shard_10013 loaded
shard_10014 loaded
shard_10015 loaded
shard_10016 loaded
shard_10017 loaded
shard_10018 loaded
shard_10019 loaded
shard_1002 loaded
shard_10020 loaded
shard_10021 loaded
shard_10022 loaded
shard_10023 loaded
shard_10024 loaded
shard_10025 loaded
shard_10026 loaded
shard_10027 loaded
shard_10028 loaded
shard_10029 loaded
shard_1003 loaded
shard_10030 loaded
shard_10031 loaded
shard_10032 loaded
shard_10033 loaded
shard_10034 loaded
shard_10035 loaded
shard_10036 loaded
shard_10037 loaded
shard_10038 loaded
shard_10039 loaded
shard_1004 loaded
shard_10040 loaded
shard_10041 loaded
shard_10042 loaded
shard_10043 loaded
shard_10044

In [12]:
# Merge all loaded shards into one dataset and write it to the folder named in
# the config. Note that `dataset` is rebound here: from this cell on it refers
# to the processed dataset, not the raw Wikipedia dump loaded earlier.
dataset = concatenate_datasets(datasets)
data_folder = config['data_folder']
dataset.save_to_disk(data_folder)
print('Dataset saved to %s' % data_folder)

Saving the dataset (0/8 shards):   0%|          | 0/1276984 [00:00<?, ? examples/s]

Dataset saved to wikipedia-vi.processed


In [13]:
# check the dataset size
dataset

Dataset({
    features: ['id', 'revid', 'url', 'title', 'input_ids', 'phonemes'],
    num_rows: 1276984
})

### Remove unnecessary tokens from the pre-trained tokenizer
The pre-trained tokenizer contains a lot of tokens that are not used in our dataset, so we need to remove these tokens. We also want to predict words in lowercase, because letter case does not matter much for TTS. Pruning the tokenizer is much faster than training a new tokenizer from scratch.

In [14]:
from simple_loader import FilePathDataset, build_dataloader

# Wrap the processed dataset and iterate over it in batches; the batches are
# consumed further down to collect the token ids that occur in the data.
# NOTE(review): what a batch contains is defined in simple_loader — confirm there.
file_data = FilePathDataset(dataset)
loader = build_dataloader(
    file_data,
    batch_size=128,
    num_workers=32,
)



In [15]:
# Word-separator token from the config; it seeds the list of unique tokens
# collected from the dataset below, so it is always kept.
special_token = config['dataset_params']['word_separator']

In [16]:
# get all unique tokens in the entire dataset

from tqdm import tqdm

# Accumulate into a set and merge each batch in place. The previous version
# rebuilt a list -> set -> list of every token seen so far after every batch.
unique_tokens = {special_token}
for batch in tqdm(loader):
    unique_tokens.update(batch)

# Later cells iterate over unique_index as a list.
unique_index = list(unique_tokens)

print(unique_index)


100%|██████████| 9976/9976 [06:14<00:00, 26.63it/s]

[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 32822, 57, 58, 59, 60, 61, 32830, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 32838, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 32854, 32861, 94, 32860, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 113, 114, 115, 118, 119, 32887, 121, 122, 123, 124, 126, 127, 128, 129, 32894, 131, 32895, 133, 32897, 135, 136, 137, 138, 139, 140, 32906, 142, 143, 32905, 145, 32908, 147, 148, 149, 150, 151, 152, 32920, 154, 155, 156, 159, 32928, 163, 164, 166, 167, 170, 171, 32938, 173, 175, 176, 177, 178, 179, 181, 182, 183, 184, 185, 186, 187, 188, 189, 32950, 32956, 192, 193, 32957, 32963, 196, 197, 199, 200, 201, 202, 203, 204, 205, 207, 208, 212, 32980, 213, 215, 216, 214, 32986, 219, 221, 222, 223, 224, 226, 227, 32994, 229, 230, 231, 232, 3




In [17]:
# get each token's lower case

# For every token id seen in the dataset, keep the id itself when its decoded
# text is already lower case; otherwise use the id that the tokenizer gives the
# lower-cased text (element [1] of the encoding).
# NOTE(review): [1] presumably skips a leading special token — confirm for the
# tokenizer in use.
lower_tokens = []
for token_id in tqdm(unique_index):
    word = tokenizer.decode([token_id])
    lowered = word.lower()
    if lowered != word:
        token_id = tokenizer.encode([lowered])[1]
    lower_tokens.append(token_id)

100%|██████████| 12444/12444 [00:00<00:00, 19002.33it/s]


In [18]:
# Drop duplicate ids; a token's position in this list becomes its new id.
lower_tokens = list(set(lower_tokens))

In [19]:
# redo the mapping for lower number of tokens

# Position of every id in lower_tokens, built once so each lookup in the loop
# below is a dict access instead of a linear list.index() scan. setdefault
# keeps the first position should lower_tokens still contain duplicates, which
# is what list.index() returned.
lower_token_positions = {}
for position, lower_t in enumerate(lower_tokens):
    lower_token_positions.setdefault(lower_t, position)

# token_maps: original token id -> {'word': lower-cased text, 'token': new id}
token_maps = {}
for t in tqdm(unique_index):
    word = tokenizer.decode([t]).lower()
    new_t = tokenizer.encode([word])[1]
    token_maps[t] = {'word': word, 'token': lower_token_positions[new_t]}

100%|██████████| 12444/12444 [00:10<00:00, 1220.16it/s]


In [20]:
import pickle
# Persist the token mapping at the path given in the config.
token_maps_path = config['dataset_params']['token_maps']
with open(token_maps_path, 'wb') as handle:
    pickle.dump(token_maps, handle)
print('Token mapper saved to %s' % token_maps_path)

Token mapper saved to token_maps.pkl


### Test the dataset with dataloader


In [21]:
# NOTE: this import rebinds `build_dataloader`; from here on the name refers to
# the function in `dataloader`, not the one imported from `simple_loader` earlier.
from dataloader import build_dataloader

train_loader = build_dataloader(
    dataset,
    batch_size=32,
    num_workers=0,
    dataset_config=config['dataset_params'],
)

177


In [22]:
# Fetch one batch to check that the dataloader works end to end. iter() is used
# instead of enumerate() so that the unused batch index is not assigned to `_`,
# the name IPython uses for the last cell output.
words, labels, phonemes, input_lengths, masked_indices = next(iter(train_loader))