In [1]:
import argparse
import glob
import heapq
import json
import operator
import os
import random
import re
from collections import Counter, defaultdict

from transformers import (
    Wav2Vec2Processor, TFWav2Vec2ForCTC,
)
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras import Model
import jiwer
import pykakasi
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold, train_test_split

def seed_everything(SEED):
    """Seed Python's `random`, NumPy's global RNG and TensorFlow with one value."""
    seeders = (random.seed, np.random.seed, tf.random.set_seed)
    for apply_seed in seeders:
        apply_seed(SEED)
    print("Random seed set.")

# Seed once, up front, so the random corpus sampling and train/val split below are reproducible.
seed_everything(SEED=42)

Random seed set.


In [2]:
def Parser():
    """Build the experiment configuration namespace.

    ``parse_known_args`` is used so unrecognised flags (e.g. the ones a
    Jupyter kernel is launched with) are ignored instead of aborting.
    ``test_size`` and ``train_steps`` are derived from the parsed values and
    then registered as arguments themselves, so they can still be overridden
    on the command line.

    Returns:
        argparse.Namespace with all hyper-parameters plus ``test_size`` and
        ``train_steps``.
    """
    parser = argparse.ArgumentParser()
    # Explicit ``type=``: without it, values passed on the command line stay
    # strings and the arithmetic below raises TypeError.
    parser.add_argument('--main_dir', type=str, default=r"Datasets\WIKI-corpus-ja-dataset")
    parser.add_argument("--n_samples", type=int, default=79848)
    parser.add_argument('--input_length', type=int, default=32)
    parser.add_argument('--vocab_size', type=int, default=50000)
    parser.add_argument('--embedding_dim', type=int, default=400)
    parser.add_argument('--rnn_units', type=int, default=400)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--n_splits', type=int, default=5)
    parser.add_argument('--random_state', type=int, default=42)
    parser.add_argument('--buffer_size', type=int, default=1024)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--learning_rate', type=float, default=1e-2)
    parser.add_argument('--epochs', type=int, default=10)

    args = parser.parse_known_args()[0]

    # One of the n_splits shards is held out for validation.
    test_size = 1 / args.n_splits
    n_train = int(args.n_samples * (1 - test_size))
    # NOTE(review): the "- 1" presumably matches the 0-based step index the
    # training loop reports to its progress bar — confirm.
    train_steps = int(np.ceil(n_train / args.batch_size)) - 1
    parser.add_argument("--test_size", type=float, default=test_size)
    parser.add_argument("--train_steps", type=int, default=train_steps)
    return parser.parse_known_args()[0]

# Parser() uses parse_known_args, so extra kernel flags in sys.argv are ignored.
args = Parser()
args

Namespace(batch_size=64, buffer_size=1024, dropout=0.2, embedding_dim=400, epochs=10, input_length=32, learning_rate=0.01, main_dir='Datasets\\WIKI-corpus-ja-dataset', n_samples=79848, n_splits=5, random_state=42, rnn_units=400, test_size=0.2, train_steps=998, vocab_size=50000)

## WikiCorpus

In [13]:
def XMLExtract(xml_path):
    """Extract Japanese/English sentence pairs from one WIKI-corpus XML file.

    Japanese text is taken from lines starting with ``<j>``. Symbols and
    punctuation are stripped, then the text is converted to hiragana with
    pykakasi. English text is taken from lines starting with
    ``<e type="trans" ver="2">``. Parenthesised asides, HTML entities,
    newlines and tabs are removed from it.

    Args:
        xml_path: path to a UTF-8 XML file.

    Returns:
        dict with ``"hira"`` (one list of hiragana segments per Japanese
        sentence, as produced by ``kakasi.convert``) and ``"en"`` (list of
        English strings). Pairing is positional: it assumes each ``<j>`` line
        is followed by exactly one version-2 ``<e>`` line.

    Raises:
        IndexError: if a ``<j>``/``<e>`` line does not contain its closing tag.
    """
    symbols = r'[（.*?）！-～.,;..._。、-〿・■（）：ㇰ-ㇿ㈠-㉃㊀-㋾㌀-㍿「」『』→ー -~]'
    ja_pattern = r"<j>(.*?)</j>"
    en_pattern = r"<e.*?>(.*?)</e>"
    # Parenthesised asides, runs of HTML entities (&amp; ...), newlines, tabs.
    en_noise = r"\([^()]*\)|(&.*?;)+|\n|\t"
    # Building the converter is comparatively expensive: do it once per file,
    # not once per sentence.
    kks = pykakasi.kakasi()

    data = []
    with open(xml_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("<j>"):
                text = re.findall(ja_pattern, line)[0]
                text = re.sub(symbols, "", text)
                # strip() takes a character set; the "|" are just extra chars.
                data.append(text.strip("\n|\t| |　"))
            elif line.startswith('<e type="trans" ver="2">'):
                text = re.findall(en_pattern, line)[0]
                data.append(re.sub(en_noise, "", text))

    return {
        "hira": [[item['hira'] for item in kks.convert(text)] for text in data[0::2]],
        "en": data[1::2]}

# Convert every per-article XML file under args.main_dir into two aligned
# text files: line i of all_hira.txt is the hiragana reading of line i of
# all_en.txt (alignment relies on XMLExtract's positional pairing).
# sorted(): glob order is filesystem-dependent; sorting keeps the output
# files identical between runs.
xml_paths = sorted(glob.glob(os.path.join(args.main_dir, "*", "*.xml")))
all_hira, all_en = [], []
for xml_path in tqdm(xml_paths):
    extracted = XMLExtract(xml_path)
    all_hira += ["".join(words) + "\n" for words in extracted["hira"]]
    all_en += [words + "\n" for words in extracted["en"]]

with open(os.path.join(args.main_dir, "all_hira.txt"), "w", encoding="utf-8") as f:
    f.writelines(all_hira)

with open(os.path.join(args.main_dir, "all_en.txt"), "w", encoding="utf-8") as f:
    f.writelines(all_en)

  0%|          | 0/14111 [00:00<?, ?it/s]

In [None]:
# lines = open(r"Datasets\WIKI-corpus-ja-dataset\all_data.txt", 'r', encoding='utf-8').readlines()
# seq_lens = [len(line.split()) for line in lines]
# sns.countplot(x=seq_lens)
# plt.show()

In [None]:
# Distribution of whitespace-separated tokens per line in the MeCab corpus;
# useful for choosing args.input_length.
corpus_path = os.path.join(args.main_dir, "wiki_dataset_mecab_80000.txt")
with open(corpus_path, 'r', encoding='utf-8') as f:
    lines = f.readlines()
seq_lens = [len(line.split()) for line in lines]
print("Number of samples:", len(lines))

fig, ax = plt.subplots(figsize=(12, 4))
sns.countplot(x=seq_lens, ax=ax)
ax.set(title="Tokens per corpus line", xlabel="number of tokens", ylabel="number of lines")
plt.show()

In [None]:
class WikiMeCabDataset:
    """Encode the MeCab-tokenised Wiki corpus and write it as TFRecord shards.

    Pipeline:
    1. Read and shuffle the corpus lines.
    2. Wrap each line in ``<BOS>``/``<EOS>`` and pad it with ``<PAD>`` to
       ``input_length + 1`` tokens (longer lines are truncated).
    3. Map tokens to integer ids. The vocabulary is capped at ``vocab_size``
       and saved to ``vocab_dict.json``.
    4. Split the encoded lines into ``n_splits`` shards. Each record stores
       ``input = ids[:-1]`` and ``label = ids[1:]``.
    """

    # Special tokens, in the same "target/source" form load_dict splits on.
    # <PAD> must stay first so it gets id 0 (the Embedding uses mask_zero).
    MARKERS = ["<PAD>/<PAD>", "<BOS>/<BOS>", "<EOS>/<EOS>", "<UNK>/<UNK>"]

    def __init__(self, args):
        self.args = args
        self.buffer_size = args.buffer_size
        self.AUTOTUNE = tf.data.AUTOTUNE
        self.dict_path = f"{args.main_dir}/vocab_dict.json"
        self.corpus_path = f"{args.main_dir}/wiki_dataset_mecab_80000.txt"
        self.tfrec_path = f"{args.main_dir}/wiki_tfrec"
        self.lines = self.apply_encoding(self.random_sample())

    def random_sample(self):
        """Read the corpus and return its lines in random order.

        NOTE(review): ``- 3`` discards three random lines; the reason is not
        visible here — confirm it is intentional.
        """
        with open(self.corpus_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        return random.sample(lines, max(len(lines) - 3, 0))

    def get_length(self, lines):
        """Number of non-<PAD> tokens per line (PAD has id 0).

        Used as the stratification label. Lines are already padded when this
        runs, so their raw ``len`` is constant and would not stratify anything.
        """
        return [sum(1 for token in line if token != 0) for line in lines]

    def preprocess(self, lines):
        """Tokenise (whitespace split) and pad every line, in place."""
        for i, line in enumerate(lines):
            lines[i] = self.apply_padding(line.strip("\n").split())
        return lines

    def apply_padding(self, line):
        """Return ``[<BOS>] + line + [<EOS>]`` padded to ``input_length + 1``.

        The extra position lets ``ids[:-1]`` and ``ids[1:]`` both have length
        ``input_length``. Lines with more than ``input_length - 1`` tokens are
        truncated (previously np.pad raised on the negative pad width).
        """
        max_tokens = max(self.args.input_length - 1, 0)
        tokens = ["<BOS>/<BOS>"] + list(line)[:max_tokens] + ["<EOS>/<EOS>"]
        pad_len = self.args.input_length + 1 - len(tokens)
        # np.pad yields a fixed-width unicode array; the 11-char BOS/EOS
        # markers set that width, so the <PAD> fill is never truncated.
        return np.pad(tokens, pad_width=(0, pad_len), constant_values="<PAD>/<PAD>")

    def get_vocab(self, lines):
        """Map tokens to ids: markers first, then words by descending frequency.

        At most ``vocab_size`` entries are kept so every id is valid for the
        model's Embedding and the one-hot labels; rarer words become <UNK>.
        """
        markers = self.MARKERS
        counts = Counter(word for line in lines for word in line)
        words = sorted(counts, key=counts.get, reverse=True)
        words = [word for word in words if word not in markers]
        max_words = max(self.args.vocab_size - len(markers), 0)
        return {word: i for i, word in enumerate(markers + words[:max_words])}

    def apply_encoding(self, lines):
        """Pad ``lines`` in place, build and save the vocabulary, map tokens to ids."""
        lines = self.preprocess(lines)
        vocab = self.get_vocab(lines)
        with open(self.dict_path, "w", encoding="utf-8") as f:
            json.dump(vocab, f, sort_keys=False, indent=4, ensure_ascii=False)
        unk_id = vocab["<UNK>/<UNK>"]
        for i, line in enumerate(lines):
            lines[i] = [vocab.get(token, unk_id) for token in line]
        return lines

    def get_shards(self, lines):
        """Split ``lines`` into ``n_splits`` folds stratified by real length."""
        skf = StratifiedKFold(
            n_splits=self.args.n_splits, shuffle=True,
            random_state=self.args.random_state)
        self.length = self.get_length(lines)
        return [
            [lines[idx] for idx in fold_idx]
            for _, fold_idx in skf.split(lines, self.length)]

    def get_shard_data(self, samples):
        """Yield serialised (input, next-token label) tensors per sample."""
        for sample in tqdm(samples):
            yield {
                'input': tf.io.serialize_tensor(sample[:-1]),
                'label': tf.io.serialize_tensor(sample[1:])}

    def _bytes_feature(self, value):
        """Returns a bytes_list from a string / byte."""
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def serialize_example(self, *args):
        """Serialise ``(input_bytes, label_bytes)`` into a tf.train.Example."""
        feature = {
            'input': self._bytes_feature(args[0]),
            'label': self._bytes_feature(args[1])}
        example_proto = tf.train.Example(
            features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    def write(self):
        """Write one ``shard_{k}.tfrec`` file per fold under ``tfrec_path``."""
        # TFRecordWriter does not create missing directories.
        os.makedirs(self.tfrec_path, exist_ok=True)
        for shard, samples in enumerate(self.get_shards(self.lines), start=1):
            with tf.io.TFRecordWriter(f"{self.tfrec_path}/shard_{shard}.tfrec") as f:
                for sample in self.get_shard_data(samples):
                    f.write(self.serialize_example(sample['input'], sample['label']))

# WikiMeCabDataset(args).write()

In [None]:
class DataLoader:
    """Train / validation ``tf.data`` pipelines over the ``wiki_tfrec`` shards.

    After construction, ``self.train`` (shuffled, batched, prefetched) and
    ``self.val`` (batched, prefetched) are datasets. Each yields
    ``(input_ids, one_hot_labels)`` batches of ``args.batch_size``.
    """

    def __init__(self, args):
        self.args = args
        self.tfrec_path = f"{args.main_dir}/wiki_tfrec"
        # os.listdir order is arbitrary: sort it, and keep only shard files.
        self.files = sorted(
            os.path.join(self.tfrec_path, name)
            for name in os.listdir(self.tfrec_path)
            if name.endswith(".tfrec"))
        self.AUTOTUNE = tf.data.AUTOTUNE
        # Fixed random_state: every DataLoader(args) holds out the same shards.
        # Otherwise each instance consumes the global RNG and gets a different
        # split, so a preview or evaluation loader could see training shards.
        self.train_files, self.val_files = train_test_split(
            self.files, test_size=1 / args.n_splits, shuffle=True,
            random_state=args.random_state)
        self.train = self.build_train()
        self.val = self.build_val()

    def read_tfrecord(self, example):
        """Decode one record into ``(input_ids, one_hot(label_ids, vocab_size))``."""
        feature_description = {
            'input': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.string)}

        example = tf.io.parse_single_example(example, feature_description)
        inputs = tf.io.parse_tensor(example['input'], out_type=tf.int32)
        labels = tf.io.parse_tensor(example['label'], out_type=tf.int32)
        return inputs, tf.one_hot(labels, self.args.vocab_size)

    def load_dataset(self, files):
        """Read ``files`` in parallel without preserving record order."""
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False
        dataset = tf.data.TFRecordDataset(files)
        dataset = dataset.with_options(ignore_order)
        return dataset.map(self.read_tfrecord, num_parallel_calls=self.AUTOTUNE)

    def build_train(self):
        """Shuffled, batched and prefetched training pipeline."""
        dataset = self.load_dataset(self.train_files)
        dataset = dataset.shuffle(self.args.buffer_size)
        dataset = dataset.batch(self.args.batch_size)
        return dataset.prefetch(self.AUTOTUNE)

    def build_val(self):
        """Batched and prefetched validation pipeline (no shuffling)."""
        dataset = self.load_dataset(self.val_files)
        dataset = dataset.batch(self.args.batch_size)
        return dataset.prefetch(self.AUTOTUNE)

val = DataLoader(args).val
# Peek at one validation batch. Show shapes rather than dumping the
# (batch, input_length, vocab_size) one-hot label tensor into the notebook.
sample_inputs, sample_labels = next(iter(val))
sample_inputs.shape, sample_labels.shape

In [None]:
class HiraKanji(tf.keras.Model):
    """Bidirectional-LSTM token model over the joint vocabulary.

    ``call`` returns a softmax distribution over ``vocab_size`` for every
    input position (``return_sequences=True``). It also returns the four
    final LSTM states ``[forward_h, forward_c, backward_h, backward_c]``.

    NOTE(review): the backward direction sees later input tokens. The
    training labels are the inputs shifted by one (see
    WikiMeCabDataset.get_shard_data), so for next-token prediction this
    likely leaks the targets. Confirm whether a unidirectional LSTM was
    intended.
    """

    def __init__(self, args, name='HiraKanji'):
        super().__init__(name=name)
        self.args = args
        self.embedding = Embedding(
            input_dim=args.vocab_size,
            output_dim=args.embedding_dim,
            input_length=args.input_length,
            mask_zero=True)
        self.lstm = Bidirectional(LSTM(
            args.rnn_units,
            dropout=args.dropout,
            return_sequences=True,
            return_state=True))
        self.dropout = Dropout(args.dropout)
        self.dense = Dense(args.vocab_size, activation='softmax')

    def call(self, inputs, hidden_states, training=None):
        """Run one forward pass.

        Args:
            inputs: integer token ids; id 0 (<PAD>) is masked out.
            hidden_states: list of 4 initial states, each (batch, rnn_units).
            training: enables dropout when True.

        Returns:
            ``(probs, new_hidden_states)``.
        """
        x = self.embedding(inputs)
        mask = tf.not_equal(inputs, 0)
        x, forward_h, forward_c, backward_h, backward_c = self.lstm(
            x, mask=mask, initial_state=hidden_states)
        hidden_states = [forward_h, forward_c, backward_h, backward_c]
        x = self.dropout(x, training=training)
        return self.dense(x), hidden_states

    def initialize_hidden_states(self, batch_size=None):
        """Zero initial states; ``batch_size`` defaults to ``args.batch_size``.

        Pass an explicit size for partial batches or single-sequence decoding.
        """
        if batch_size is None:
            batch_size = self.args.batch_size
        return [tf.zeros((batch_size, self.args.rnn_units)) for _ in range(4)]

In [None]:
class Trainer:
    """Custom training loop for the HiraKanji model.

    Loss and accuracy are computed only on non-padding target positions.
    <PAD> has id 0, so its one-hot label has argmax 0.
    """

    def __init__(self, args):
        self.args = args
        self.model = HiraKanji(args)
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=args.learning_rate)
        # The model ends in a softmax, hence from_logits=False.
        self.loss_fn = tf.keras.losses.CategoricalCrossentropy(
            from_logits=False, label_smoothing=0.2)
        self.metric = tf.keras.metrics.CategoricalAccuracy()
        self.dataloader = DataLoader(args)

    def _zero_states(self, inputs):
        """Zero initial values for the 4 BiLSTM states, sized to this batch."""
        batch_size = tf.shape(inputs)[0]
        return [tf.zeros((batch_size, self.args.rnn_units)) for _ in range(4)]

    @staticmethod
    def _label_mask(labels):
        """1.0 where the target token is not <PAD> (id 0), else 0.0."""
        return tf.cast(tf.not_equal(tf.math.argmax(labels, axis=-1), 0), tf.float32)

    def train(self):
        """Train for ``args.epochs`` epochs and return the trained model."""
        stateful_metrics = ['loss', 'acc', 'val_loss', 'val_acc']
        for epoch in range(self.args.epochs):
            print(f"Epoch {epoch+1}/{self.args.epochs}")
            progbar = tf.keras.utils.Progbar(
                self.args.train_steps, interval=0.05,
                stateful_metrics=stateful_metrics)
            t_loss, t_acc = float("nan"), float("nan")

            for step, (t_inputs, t_labels) in enumerate(self.dataloader.train):
                t_mask = self._label_mask(t_labels)
                # Batches hold shuffled, independent sentences, so start every
                # batch from zero states. Carrying the previous batch's states
                # mixes unrelated sentences and fails on the smaller last batch.
                hidden_states = self._zero_states(t_inputs)
                with tf.GradientTape() as tape:
                    t_logits, _ = self.model(t_inputs, hidden_states, training=True)
                    t_loss = self.loss_fn(t_labels, t_logits, sample_weight=t_mask)

                grads = tape.gradient(t_loss, self.model.trainable_weights)
                self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
                self.metric.update_state(t_labels, t_logits, sample_weight=t_mask)
                t_acc = self.metric.result()
                progbar.update(step, values=[('loss', t_loss), ('acc', t_acc)], finalize=False)
                self.metric.reset_states()

            val_losses = []
            for v_inputs, v_labels in self.dataloader.val:
                v_mask = self._label_mask(v_labels)
                v_logits, _ = self.model(
                    v_inputs, self._zero_states(v_inputs), training=False)
                val_losses.append(float(self.loss_fn(v_labels, v_logits, sample_weight=v_mask)))
                self.metric.update_state(v_labels, v_logits, sample_weight=v_mask)
            # Report the mean over all validation batches, not just the last.
            v_loss = float(np.mean(val_losses)) if val_losses else float("nan")
            values = [
                ('loss', t_loss), ('acc', t_acc), ('val_loss', v_loss),
                ('val_acc', self.metric.result())]
            progbar.update(self.args.train_steps, values=values, finalize=True)
            self.metric.reset_states()
        return self.model

# Build the trainer (model, optimizer, data pipelines) and train it.
trainer = Trainer(args)
model = trainer.train()

In [None]:
def load_dict(vocab_path):
    """Load ``vocab_dict.json`` into a reading -> candidates lookup.

    Each key has the form ``"target/source"`` (e.g. ``"<BOS>/<BOS>"``). The
    result maps ``source`` to a list of ``(target, word_id)`` tuples in file
    order. ``word_id`` is the id stored in the file, i.e. the one the model
    was trained with, rather than the key's position.

    Returns:
        collections.defaultdict(list)
    """
    dictionary = defaultdict(list)
    with open(vocab_path, 'r', encoding='utf-8') as j:
        vocab = json.load(j)
    for token, word_id in vocab.items():
        # NOTE(review): assumes exactly one "/" per key; a surface or reading
        # that itself contains "/" raises ValueError here.
        target, source = token.split("/")
        dictionary[source].append((target, word_id))
    return dictionary

def create_lattice(inputs, dictionary):
    """Build the segmentation lattice for a hiragana input.

    ``lattice[end][start]`` lists the ``(target, word_id)`` candidates whose
    reading is ``text[start:end]``. An unknown single character gets a
    ``(char, <UNK> id)`` edge so every position stays reachable. The extra
    last row holds one ``('', <EOS> id)`` edge leaving the final position.

    Args:
        inputs: hiragana as a str or as a sequence of single characters
            (e.g. ``list("...")``).
        dictionary: reading -> [(target, word_id), ...], as built by load_dict.

    Returns:
        list of ``len(text) + 2`` rows, each with ``len(text) + 1`` lists.
    """
    # A list slice is unhashable and would fail the dictionary lookup, so
    # normalise list-of-chars input to a string first.
    text = "".join(inputs)
    n = len(text)
    lattice = [[[] for _ in range(n + 1)] for _ in range(n + 2)]
    unk_id = dictionary['<UNK>'][0][1]
    for end in range(1, n + 1):
        for start in range(end):
            key = text[start:end]
            if key in dictionary:
                lattice[end][start].extend(
                    (target, word_id) for target, word_id in dictionary[key])
            elif len(key) == 1:
                lattice[end][start].append((key, unk_id))

    eos_id = dictionary['<EOS>'][0][1]
    lattice[-1][-1].append(('', eos_id))
    return lattice

def initalize_queues(lattice, model, dictionary):
    """Create one hypothesis queue per lattice row, seeded with the <BOS> step.

    A hypothesis is ``(cost, string, states, neg_log_probs)``:
    - ``cost``: accumulated negative log-probability.
    - ``string``: text decoded so far.
    - ``states``: list of the 4 LSTM states after the last token.
    - ``neg_log_probs``: ``-log p(next token)`` over the vocabulary.

    Only ``queues[0]`` is populated here.
    """
    bos_id = dictionary['<BOS>'][0][1]
    inputs = tf.constant([[bos_id]])
    # Batch of one: build zero states directly instead of using the
    # args.batch_size-sized initial states.
    hidden_states = [tf.zeros((1, model.args.rnn_units)) for _ in range(4)]
    # The model returns (probs, new_states); probs is (1, 1, vocab_size).
    probs, hidden_states = model(inputs, hidden_states, training=False)
    # The model already applies softmax, so take -log of the probabilities
    # directly (log_softmax would apply a second softmax). Clamp avoids log(0).
    bos_pred = -tf.math.log(tf.maximum(tf.reshape(probs, [-1]), 1e-12))

    queues = [[] for _ in range(len(lattice))]
    queues[0].append((0.0, '', hidden_states, bos_pred))
    return queues

def search(lattice, queues, model, beam_size, viterbi_size):
    """Beam search over the lattice, filling ``queues`` row by row.

    For every edge ``(target, word_id)`` that ends at position ``i`` and
    starts at ``j``, each hypothesis in ``queues[j]`` is extended by that
    word. The search then:
    - keeps the ``viterbi_size`` cheapest extensions per edge;
    - keeps the ``beam_size`` cheapest per row (0 disables either limit);
    - runs the model one step on each survivor to get its next-token costs.

    Hypotheses are ``(cost, string, states, neg_log_probs)``. ``states`` may
    be a list of the 4 LSTM states or a legacy stacked ``(4, 1, units)``
    tensor. Returns ``queues``.
    """
    for i, row in enumerate(lattice):
        candidates = []
        for j, edges in enumerate(row):
            for target, word_id in edges:
                # float(): keep costs as plain Python numbers so heapq
                # comparisons do not operate on tensors.
                word_queue = [
                    (prev_cost + float(prev_pred[word_id]), prev_str + target,
                     word_id, prev_states)
                    for prev_cost, prev_str, prev_states, prev_pred in queues[j]]
                if viterbi_size > 0:
                    word_queue = heapq.nsmallest(
                        viterbi_size, word_queue, key=operator.itemgetter(0))
                candidates += word_queue

        if beam_size > 0:
            candidates = heapq.nsmallest(beam_size, candidates, key=operator.itemgetter(0))

        for cost, string, word_id, prev_states in candidates:
            # The BiLSTM needs 4 separate initial-state tensors.
            if isinstance(prev_states, (list, tuple)):
                states = list(prev_states)
            else:
                states = tf.unstack(prev_states, axis=0)
            inputs = tf.constant([[word_id]])
            probs, hidden_states = model(inputs, states, training=False)
            # Softmax outputs: -log(prob), not log_softmax (double softmax).
            pred = -tf.math.log(tf.maximum(tf.reshape(probs, [-1]), 1e-12))
            queues[i].append((cost, string, hidden_states, pred))
    return queues

# Vocabulary written by WikiMeCabDataset.apply_encoding.
vocab_path = "{}/vocab_dict.json".format(args.main_dir)

In [None]:
# idx = 0
# data = next(iter(val))
# y_pred = model.predict(data[0])[idx]
# y_pred = tf.argmax(y_pred, axis=-1)
# y_true = tf.argmax(data[1][idx], axis=-1)
# print(y_pred, y_true)

In [None]:
# Example for the lattice decoder: `hira` is a hiragana sentence split into
# single characters; `kanji` is presumably the expected kanji/kana conversion.
hira = list("このまえさがったときはとちゆうにはんこんのりゆうきがあったのでついそこがいきどまりだとばかりおもってああいったんですが")
kanji = "この前探った時は途中に瘢痕の隆起があったのでついそこが行きどまりだとばかり思って、ああ云ったんですが"