In [1]:
from concurrent.futures import TimeoutError
import os
from pebble import ProcessPool
import pickle

from dataloader import build_dataloader as build_trainloader
import datasets
from datasets import load_from_disk, concatenate_datasets
import pathlib
# import phonemizer
import torch
from tqdm import tqdm
from transformers import BertJapaneseTokenizer
import yaml

from simple_loader import FilePathDataset, build_dataloader
from phonemize import phonemize

In [2]:
##### config #####
config_path = "Configs/config.yml" # you can change it to anything else
config = yaml.safe_load(open(config_path))

##### set tokenizer #####
tokenizer = BertJapaneseTokenizer.from_pretrained(config['dataset_params']['tokenizer'])

##### download dataset #####
# comment out the following line in hogehoge/datasets/wikipedia/wikipedia.py
# | "Distribute" >> beam.transforms.Reshuffle()
datasets.config.DOWNLOADED_DATASETS_PATH = pathlib.Path("./dataset/wikipedia-ja")
dataset = datasets.load_dataset("wiki40b", "ja", split="train")



In [3]:
import ast

def decode_text(sample_text):
    """
    Chuyển đổi chuỗi có dạng b'...' thành chuỗi UTF-8.
    
    Args:
        sample_text (str): Chuỗi có thể chứa biểu diễn bytes dạng string.
        
    Returns:
        str: Chuỗi đã giải mã UTF-8, hoặc giữ nguyên nếu không cần giải mã.
    """
    if isinstance(sample_text, str):
        if (sample_text.startswith("b'") and sample_text.endswith("'")) or (sample_text.startswith('b"') and sample_text.endswith('"')):
            try:
                sample_text = ast.literal_eval(sample_text)  # Chuyển từ string thành bytes thực sự
                sample_text = sample_text.decode("utf-8")   # Giải mã bytes thành UTF-8
            except (SyntaxError, ValueError):
                pass  # Trả về chuỗi gốc nếu gặp lỗi
    return sample_text

decoded_text = decode_text(dataset[3024]['text'])
decoded_text

'\n_START_ARTICLE_\n弘前市立大和沢小学校\n_START_SECTION_\n学区\n_START_PARAGRAPH_\n大和沢、一野渡、狼森'

In [4]:
def clean_wiki_text(text):
    text = ''.join(text.split('_START_PARAGRAPH_')[1:])
    markers = ["_START_ARTICLE_", "_START_SECTION_", "_START_PARAGRAPH_", "_NEWLINE_", "_START_HEADING_", 
               "_START_BULLET_", "_START_LIST_", "_START_TABLE_", "_START_CAPTION_", "_START_IMAGE_"]
    for marker in markers:
        text = text.replace(marker, "")
    return text.strip()

decoded_text = clean_wiki_text(decoded_text)
decoded_text

'大和沢、一野渡、狼森'

In [5]:
def preprocess_text(text):
    text = decode_text(text)
    text = clean_wiki_text(text)
    return text
preprocess_text(dataset[3024]['text'])

'大和沢、一野渡、狼森'

In [6]:
import pyopenjtalk
import unicodedata
from convert_label import openjtalk2julius

_japanese = ['ky','sp', 'sh', 'ch', 'ts','ty', 'ry', 'ny', 'by', 'hy', 'gy', 'kw', 'gw', 'kj', 'gj', 'my', 'py','dy']
japanese = ['$', '%', '&', '「', '」', '=', '~', '^', '|', '[', ']', '{', '}', '*', '+', '#', '<', '>']
_japanese2japanese = {
    'ky': '$',
    'sp': '%',
    'sh': '&',
    'ch': '「',
    'ts': '」',
    'ty': '=',
    'ry': '~',
    'ny': '^',
    'by': '|',
    'hy': '[',
    'gy': ']',
    'kw': '{',
    'gw': '}',
    'kj': '*',
    'gj': '+',
    'my': '#',
    'py': '<',
    'dy': '>',
}

def global_phonemize(text: str):
    phonemes = pyopenjtalk.g2p(text).split(' ')
    print("phonemes: ", phonemes)
    phonemes = [openjtalk2julius(p) for p in phonemes if p != '']
    for i in range(len(phonemes)):
        phoneme = phonemes[i]
        if phoneme in _japanese:
            phonemes[i] = _japanese2japanese[phoneme]
    return phonemes

text = unicodedata.normalize("NFKC", "おはようございます。")
words = tokenizer.tokenize(text)
input_ids_ = tokenizer.convert_tokens_to_ids(words)
phonemes = []
input_ids = []
for i in range(len(words)):
    word = words[i]
    input_id = input_ids_[i]
    print("word: ", word)
    phoneme = global_phonemize(word.replace('#', ''))
    if len(phoneme) != 0:
        phonemes.append(''.join(phoneme))
        input_ids.append(input_id)

word:  おはよう
phonemes:  ['o', 'h', 'a', 'y', 'o', 'o']
word:  ござい
phonemes:  ['g', 'o', 'z', 'a', 'i']
word:  ます
phonemes:  ['m', 'a', 's', 'U']
word:  。
phonemes:  ['']




In [7]:
phonemize(preprocess_text(dataset[3024]['text']), tokenizer)



{'input_ids': [4612, 29331, 52, 28737, 29173, 10082, 29356],
 'phonemes': ['yamato', 'sawa', 'i「i', 'no', 'watari', 'ookami', 'mori']}

In [8]:
root_directory = "./wiki_phoneme"

In [None]:
import os
import gc

num_shards = 1000
def process_shard(i):
    directory = root_directory+"/shard_" + str(i)
    if os.path.exists(directory):
        print("Shard %d already exists!" % i)
        return
    print('Processing shard %d ...' % i)
    try:
        shard = dataset.shard(num_shards=num_shards, index=i)
        processed_dataset = shard.map(lambda t: phonemize(preprocess_text(t['text']), tokenizer), remove_columns=['text'])
        if not os.path.exists(directory):
            os.makedirs(directory)
        processed_dataset.save_to_disk(directory)
        print(f'Shard {i} processed successfully.')
        del processed_dataset  # Free memory
        gc.collect()
    except Exception as e:
        print(f'Error processing shard {i}: {e}')

In [10]:
from pebble import ProcessPool
from concurrent.futures import TimeoutError

In [11]:
import os

# Lấy số core của CPU
num_cores = os.cpu_count()
print(f"Số core của CPU là: {num_cores}")

Số core của CPU là: 80


In [None]:
max_workers = num_cores # change this to the number of CPU cores your machine has 

with ProcessPool(max_workers=max_workers) as pool:
    pool.map(process_shard, range(num_shards), timeout=120)