In [1]:
import os
import re
import ast
import glob
import random
import cutlet
import argparse
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import pandas as pd
import numpy as np
import seaborn as sns
from sacrebleu.metrics import BLEU

from sklearn.model_selection import KFold, train_test_split

import tensorflow as tf

import sentencepiece as spm

from tokenizers import ByteLevelBPETokenizer

from transformers import (
    T5Tokenizer,
    TFT5ForConditionalGeneration,
    GradientAccumulator,
    logging)

def seed_everything(SEED):
    """Seed Python's `random`, NumPy and TensorFlow with one shared value.

    Args:
        SEED: Value handed unchanged to each library's seeding function.
    """
    # Same call order as before: stdlib first, then NumPy, then TensorFlow.
    for apply_seed in (random.seed, np.random.seed, tf.random.set_seed):
        apply_seed(SEED)
    print("Random seed set.")

# Seed all three RNGs once, before any data sampling or model code runs.
seed_everything(42)
# Raise the TensorFlow logger threshold to FATAL and restrict the transformers
# logger (imported above as `logging`) to errors.
tf.get_logger().setLevel('FATAL')
logging.set_verbosity_error()
# Expose only GPU 0 to CUDA.
# NOTE(review): this is set after `import tensorflow`; confirm it still takes
# effect in this environment (it must be set before TF initialises its devices).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Random seed set.


# Data Preparation

In [2]:
# Shared romaji converter used by `kanji2romaji`.
# NOTE(review): `use_foreign_spelling` presumably keeps loanwords in their
# original spelling -- confirm against the cutlet documentation.
katsu = cutlet.Cutlet()
katsu.use_foreign_spelling = True

# Allow-list of Japanese characters consumed by `clean_kanji`.
# The path is a raw string: the previous non-raw literal relied on invalid
# escape sequences (`\D`, `\A`, `\J`), which newer Python versions warn about.
JA_unicode = []
with open(r"E:\Datasets\ASR-dataset\JA_unicode.txt", "r", encoding="utf-8") as f:
    for line in f:
        # The first whitespace-separated field of each line is skipped; every
        # remaining field is kept as an allowed entry.
        JA_unicode.extend(line.split()[1:])

def clean_kanji(sentence, allowed_chars=None):
    """Strip everything except allow-listed Japanese characters from `sentence`.

    Two passes are applied:
      1. every character NOT present in the allow-list is removed;
      2. a fixed set of punctuation, symbols and ASCII characters is removed.

    The first pass used to interpolate the *list repr* of `JA_unicode` into the
    pattern (``"[^['a', 'b', ...]]"``). The character class then ended at the
    list's own closing bracket, so the pattern only matched a character
    followed by a literal ``]`` and the allow-list was effectively ignored.
    The class is now built from the escaped characters themselves.

    Args:
        sentence: Raw text to clean.
        allowed_chars: Optional iterable of allowed characters/strings.
            Defaults to the module-level `JA_unicode` list.

    Returns:
        The cleaned string (possibly empty).
    """
    if allowed_chars is None:
        allowed_chars = JA_unicode

    sentence = sentence.strip()

    # re.escape keeps characters such as `]`, `^`, `-` and `\` literal inside
    # the character class. An empty allow-list skips this pass instead of
    # compiling the invalid pattern "[^]".
    allowed = re.escape("".join(allowed_chars))
    if allowed:
        sentence = re.sub(f"[^{allowed}]", "", sentence)

    symbols = r"[（.*?）！-～.,;..._。、-〿・■（）：ㇰ-ㇿ㈠-㉃㊀-㋾㌀-㍿「」『』→ー -~‘–※π—ゐ’“”]"
    sentence = re.sub(symbols, "", sentence.strip())
    return sentence

def clean_romaji(sentence):
    """Normalise romaji text and glue stray "n"/"u" tokens onto the previous word.

    The text is lower-cased, everything except ASCII letters, digits and
    spaces is removed, and the result is split on whitespace. A token is then
    appended to the preceding token when it is exactly "n", or exactly "u" and
    not the final token.

    Fixes over the previous implementation, which popped from the list while
    enumerating it:
      * a leading "n"/"u" no longer pops the *last* token (`pop(-1)`) and
        corrupts the sentence; with nothing before it, it is kept as is;
      * tokens following a merge are no longer skipped, so consecutive stray
        tokens are all merged.

    NOTE(review): the original condition `(a) | (b) & (c)` parses as
    `a | (b & c)`, i.e. "n" always merges while a trailing "u" does not. That
    effective behaviour is preserved here; confirm it is the intended rule.

    Args:
        sentence: Romaji text, e.g. the output of the romaji converter.

    Returns:
        The cleaned, single-space-joined string.
    """
    sentence = sentence.strip().lower()
    sentence = re.sub(r"[^a-zA-Z0-9\ ]", "", sentence)
    tokens = sentence.split()

    merged = []
    last_index = len(tokens) - 1
    for index, mora in enumerate(tokens):
        is_stray = mora == "n" or (mora == "u" and index < last_index)
        if is_stray and merged:
            merged[-1] += mora
        else:
            merged.append(mora)
    return " ".join(merged)

def kanji2romaji(text):
    """Convert Japanese text to cleaned romaji, or return None on failure.

    Runs `clean_kanji`, the shared `katsu` romaji converter and `clean_romaji`
    in sequence. Rows that raise along the way (for example non-string input
    such as NaN) yield None so the caller can drop them.

    Args:
        text: Raw Japanese sentence.

    Returns:
        The romaji string, or None if any step raised an exception.
    """
    try:
        new_line = clean_kanji(text)
        new_line = katsu.romaji(new_line)
        new_line = clean_romaji(new_line)
    except Exception:
        # Best-effort conversion: skip bad rows, but unlike a bare `except:`
        # let KeyboardInterrupt/SystemExit stop a long-running apply.
        new_line = None
    return new_line

def clean_en(sentence):
    """Lower-case English text and drop every character outside a small whitelist.

    Kept characters: a-z, 0-9, space and the punctuation ? . ! , ' "

    Args:
        sentence: Raw English sentence.

    Returns:
        The lower-cased, filtered string.
    """
    disallowed = re.compile(r"[^a-z0-9\ \?\.\!\,\'\"]")
    return disallowed.sub("", sentence.lower())

In [3]:
# Root folder of the parallel corpora. Raw string: the backslashes are literal
# path separators, not escape sequences (the previous non-raw literals relied
# on invalid escapes such as `\S`, `\P`, `\O`, `\*`).
main_dir = r"D:\School-stuff\Sem-2\PR-Project\HoloASR\Datasets"
# Sorted so the row order of the resulting CSV does not depend on directory
# listing order.
opus_ja_paths = sorted(glob.glob(os.path.join(main_dir, "OPUS100-dataset", "*.ja")))
tatoeba_ja_paths = sorted(glob.glob(os.path.join(main_dir, "Tatoeba-dataset", "*.ja")))
coursera_ja_paths = sorted(glob.glob(os.path.join(main_dir, "Coursera-dataset", "*.ja.txt")))

ja_paths = opus_ja_paths + tatoeba_ja_paths + coursera_ja_paths

ja_lines, en_lines = [], []
for ja_path in ja_paths:
    # Derive the English file that sits next to each Japanese file.
    if ja_path.endswith(".ja"):
        en_path = ja_path.rsplit(".", 1)[0] + ".en"
    else:
        # Only swap the ".ja.txt" suffix. The previous `replace("ja", "en")`
        # rewrote every "ja" occurring anywhere in the path.
        en_path = ja_path[:-len(".ja.txt")] + ".en.txt"
    with open(ja_path, "r", encoding="utf-8") as f:
        ja_part = [line.strip("\n") for line in f]
    with open(en_path, "r", encoding="utf-8") as f:
        en_part = [line.strip("\n") for line in f]
    # Sentences are paired by line number, so a length mismatch in one file
    # pair would silently misalign every following pair.
    if len(ja_part) != len(en_part):
        raise ValueError(
            f"Line count mismatch: {ja_path} has {len(ja_part)} lines, "
            f"{en_path} has {len(en_part)}")
    ja_lines.extend(ja_part)
    en_lines.extend(en_part)

tqdm.pandas()
data = pd.DataFrame({'ja_raw': ja_lines, 'en': en_lines})
# Rows whose Japanese side could not be romanised come back as None and are dropped.
data['ja_ro'] = data['ja_raw'].progress_apply(kanji2romaji)
data = data[data['ja_ro'].notnull()].reset_index(drop=True)
data['en'] = data['en'].progress_apply(clean_en)
data.to_csv(
    r"E:\Datasets\ASR-dataset\tokenizer_text\tokenizer_text.csv", 
    index=False, encoding="utf-8")
data

  0%|          | 0/1266032 [00:00<?, ?it/s]

In [None]:
# Re-load the cleaned corpus. `dtype=str` keeps numeric-looking sentences
# (e.g. "2010") as text so the string join below cannot fail on an int/float.
data = pd.read_csv(
    r"E:\Datasets\ASR-dataset\tokenizer_text\tokenizer_text.csv", dtype=str)
# Cleaning can leave empty cells, which are read back as NaN; drop those rows.
data = data.dropna().reset_index(drop=True)

# One line per sentence pair ("<english> <romaji>") as SentencePiece training text.
with open(r"E:\Datasets\ASR-dataset\tokenizer_text\tokenizer_text.txt", "w", encoding="utf-8") as f:
    # Columns are selected by name rather than unpacked by position from
    # `iterrows()`, so the output no longer depends on the CSV column order
    # and no per-row Series is built.
    for en, ja in tqdm(zip(data["en"], data["ja_ro"]), total=len(data)):
        f.write(" ".join([en, ja]) + "\n")

In [None]:
# Inputs/outputs for the SentencePiece tokenizer training run.
text_file = r"E:\Datasets\ASR-dataset\tokenizer_text\tokenizer_text.txt"
model_prefix = r"E:\Datasets\ASR-dataset\tokenizer_text\t5"

# Command-line style flags, one per entry; they are joined with single spaces
# into the exact flag string the trainer receives.
spm_flags = [
    f"--input={text_file}",
    f"--model_prefix={model_prefix}",
    f"--vocab_size={32128}",
    "--pad_id=0",
    "--unk_id=1",
    "--bos_id=-1",
    "--eos_id=2",
    "--pad_piece=<pad>",
    "--unk_piece=<unk>",
    "--eos_piece=</s>",
]

spm.SentencePieceTrainer.train(" ".join(spm_flags))

# Data Loading

In [None]:
class TFRWriter():
    """Tokenise the cleaned corpus and write it to sharded TFRecord files.

    Constructing the object already does the heavy lifting: it loads the
    tokenizer, reads and tokenises the corpus CSV, filters and samples it and
    writes the sampled table to `tokenizer_text2.csv`. `write()` then
    serialises the samples into `n_shards` TFRecord files.
    """

    def __init__(self):
        self.main_dir = "E://Datasets/ASR-dataset"
        self.n_shards = 10
        self.tokenizer = T5Tokenizer(
            vocab_file=f"{self.main_dir}/tokenizer_text/t5.model",
            eos_token="</s>",
            unk_token="<unk>",
            pad_token="<pad>")
        # Prefix prepended to every source sentence before tokenisation.
        self.task_prefix = "translate Japanese to English: "
        self.data = self.get_data()

    def get_data(self):
        """Load, tokenise, filter, sample and length-sort the corpus.

        Returns:
            A DataFrame with columns `ja_ro`, `en`, `ja_token`, `en_token` and
            `ja_len`, sorted by ascending `ja_len` with a fresh 0..n-1 index.
            The same table is also written to `tokenizer_text2.csv`.
        """
        tqdm.pandas()
        data = pd.read_csv(
            f"{self.main_dir}/tokenizer_text/tokenizer_text.csv", 
            encoding="utf-8")
        data = data.dropna().reset_index(drop=True)[['ja_ro', 'en']]
        data['ja_token'] = data['ja_ro'].progress_apply(
            lambda x: self.tokenizer(self.task_prefix + x.lower()).input_ids)
        data['en_token'] = data['en'].progress_apply(
            lambda x: self.tokenizer(x).input_ids)
        data['ja_len'] = data['ja_token'].apply(len)
        # Keep only short source sequences (token count includes the task prefix).
        data = data.query("ja_len <= 21")
        # Cap the sample size at the number of surviving rows; `sample` raises
        # when asked for more rows than exist.
        data = data.sample(
            n=min(50000, len(data)), random_state=42, ignore_index=True)
        data = data.sort_values(by="ja_len", ignore_index=True, ascending=True)
        # Same forward-slash layout as the read path above, so the file lands
        # where later cells look for it.
        data.to_csv(
            f"{self.main_dir}/tokenizer_text/tokenizer_text2.csv",
            index=False)
        return data

    def _bytes_feature(self, value):
        """Returns a bytes_list from a string / byte."""
        # Eager tensors are unwrapped to their raw bytes value first.
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def serialize_example(self, *args):
        """Pack three serialised tensors into one `tf.train.Example` byte string.

        Args:
            *args: Positional `(input_ids, attention_mask, labels)`, each an
                already-serialised tensor.

        Returns:
            The serialised `tf.train.Example`.
        """
        feature = {
            'input_ids': self._bytes_feature(args[0]),
            'attention_mask': self._bytes_feature(args[1]),
            'labels': self._bytes_feature(args[2])}

        example_proto = tf.train.Example(
            features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    def get_shards(self):
        """Split row indices into `n_shards` contiguous groups.

        Only the held-out fold of each unshuffled KFold split is used, so each
        shard is a contiguous slice of the length-sorted table.
        """
        skf = KFold(n_splits=self.n_shards, shuffle=False)
        return [held_out for _, held_out in skf.split(self.data)]

    def get_shard_data(self, samples):
        """Yield serialised `input_ids`, `attention_mask` and `labels` per row.

        Args:
            samples: Iterable of row indices into `self.data`.
        """
        for sample in samples:
            input_ids = tf.convert_to_tensor(
                self.data['ja_token'][sample], dtype=tf.int32)
            # 1 wherever the token id is non-zero (id 0 is the pad token).
            attention_mask = tf.where(input_ids != 0, x=1, y=0)
            labels = tf.convert_to_tensor(
                self.data['en_token'][sample], dtype=tf.int32)
            yield {
                "input_ids": tf.io.serialize_tensor(input_ids),
                "attention_mask": tf.io.serialize_tensor(attention_mask),
                "labels": tf.io.serialize_tensor(labels)
            }

    def write(self):
        """Write every shard to `<main_dir>/bart_tfrec/shard_<k>.tfrec`."""
        out_dir = f"{self.main_dir}/bart_tfrec"
        # The writer does not create missing directories itself.
        os.makedirs(out_dir, exist_ok=True)
        for shard, samples in tqdm(enumerate(self.get_shards()), total=self.n_shards):
            with tf.io.TFRecordWriter(f"{out_dir}/shard_{shard+1}.tfrec") as f:
                for sample in self.get_shard_data(samples):
                    example = self.serialize_example(
                        sample['input_ids'],
                        sample['attention_mask'],
                        sample['labels'],
                        )
                    f.write(example)

# Tokenise the corpus (done in the constructor) and write all TFRecord shards.
TFRWriter().write()

In [None]:
def ArgParser():
    """Build the run configuration as an `argparse.Namespace`.

    Base options are parsed first (unknown arguments, such as those injected
    by the notebook kernel, are ignored). The dataset size is then read from
    `tokenizer_text2.csv` and the derived sample and step counts are registered
    as further options, so they can still be overridden explicitly.

    Every option declares a `type`, so values supplied on the command line are
    converted instead of arriving as strings and breaking the arithmetic below.

    Returns:
        Namespace with the base options plus `n_samples`, `n_train`, `n_val`,
        `train_steps` and `val_steps`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--random_state", default=42, type=int)
    parser.add_argument("--main_dir", default="E://Datasets/ASR-dataset", type=str)
    parser.add_argument("--n_shards", default=10, type=int)
    parser.add_argument("--test_size", default=0.1, type=float)
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--buffer_size", default=512, type=int)

    # Trainer
    parser.add_argument("--accum_steps", default=2, type=int)

    # Scheduler
    parser.add_argument("--learning_rate", default=5e-5, type=float)
    parser.add_argument("--epochs", default=5, type=int)

    args = parser.parse_known_args()[0]

    # Number of rows in the sampled dataset written during TFRecord creation.
    n_samples = len(pd.read_csv(
        f"{args.main_dir}/tokenizer_text/tokenizer_text2.csv"))

    n_train = int(n_samples * (1 - args.test_size))
    n_val = int(n_samples * args.test_size)
    train_steps = int(np.ceil(n_train / args.batch_size))
    val_steps = int(np.ceil(n_val / args.batch_size))

    parser.add_argument("--n_samples", default=n_samples, type=int)
    parser.add_argument("--n_train", default=n_train, type=int)
    parser.add_argument("--n_val", default=n_val, type=int)
    parser.add_argument("--train_steps", default=train_steps, type=int)
    parser.add_argument("--val_steps", default=val_steps, type=int)

    return parser.parse_known_args()[0]

# Parse once at notebook scope; the bare `args` below displays the Namespace.
args = ArgParser()
args

In [None]:
class DataLoader:
    """Build the training and validation `tf.data` pipelines from TFRecord shards.

    Shard files under `<main_dir>/bart_tfrec` are split at file level into a
    training and a validation set. Both pipelines parse the records, pad each
    batch to its longest sequence, shuffle the batches and prefetch.

    Attributes are populated in the constructor: `files`, `train_files`,
    `val_files`, `train` and `val`.
    """

    def __init__(self, args):
        # Sorted so the seeded split below picks the same files regardless of
        # the order in which the filesystem lists them.
        self.files = sorted(glob.glob(args.main_dir + "/bart_tfrec/*.tfrec"))
        if not self.files:
            raise ValueError(
                f"No .tfrec files found in {args.main_dir}/bart_tfrec")
        self.args = args
        self.AUTOTUNE = tf.data.AUTOTUNE
        self.train_files, self.val_files = train_test_split(
            self.files, test_size=args.test_size, shuffle=True, 
            random_state=args.random_state)
        self.train = self.get_train()
        self.val = self.get_val()

    def read_tfrecord(self, example):
        """Parse one serialised example into a dict of int32 tensors.

        Each feature is stored as a serialised tensor string and decoded back
        with `tf.io.parse_tensor`.
        """
        feature_description = {
            'input_ids': tf.io.FixedLenFeature([], tf.string),
            'attention_mask': tf.io.FixedLenFeature([], tf.string),
            'labels': tf.io.FixedLenFeature([], tf.string)
            }
        
        example = tf.io.parse_single_example(example, feature_description)
        example['input_ids'] = tf.io.parse_tensor(
            example['input_ids'], out_type=tf.int32)
        example['attention_mask'] = tf.io.parse_tensor(
            example['attention_mask'], out_type=tf.int32) 
        example['labels'] = tf.io.parse_tensor(
            example['labels'], out_type=tf.int32)
        return example

    def load_dataset(self, files):
        """Return an unbatched dataset of parsed examples read from `files`."""
        ignore_order = tf.data.Options()
        # Allow elements to be produced out of order.
        ignore_order.experimental_deterministic = False
        dataset = tf.data.TFRecordDataset(files)
        dataset = dataset.with_options(ignore_order)
        dataset = dataset.map(self.read_tfrecord, num_parallel_calls=self.AUTOTUNE)
        return dataset

    def _padded_batch(self, dataset):
        """Batch `dataset`, padding every feature to the longest item in the batch.

        Inputs and attention masks are padded with 0; labels are padded with
        -100 (presumably so padded label positions can be told apart from real
        token ids downstream -- confirm against the training code).
        """
        return dataset.padded_batch(
            self.args.batch_size,
            padded_shapes={
                'input_ids': [None],
                'attention_mask': [None],
                'labels': [None]
            },
            padding_values={
                'input_ids': tf.constant(0, dtype=tf.int32),
                'attention_mask': tf.constant(0, dtype=tf.int32),
                'labels': tf.constant(-100, dtype=tf.int32)
            })

    def get_train(self):
        """Training pipeline: parse -> padded batch -> shuffle batches -> prefetch."""
        dataset = self.load_dataset(self.train_files)
        dataset = self._padded_batch(dataset)
        dataset = dataset.shuffle(self.args.buffer_size)
        dataset = dataset.prefetch(self.AUTOTUNE)
        return dataset

    def get_val(self):
        """Validation pipeline: as training, plus an in-memory cache before prefetch."""
        dataset = self.load_dataset(self.val_files)
        dataset = self._padded_batch(dataset)
        dataset = dataset.shuffle(self.args.buffer_size)
        # NOTE(review): the cache sits after the shuffle, so the batch order
        # produced on the first pass is presumably what later passes replay.
        dataset = dataset.cache()
        dataset = dataset.prefetch(self.AUTOTUNE)
        return dataset

# train = DataLoader(args).train

# inputs = next(iter(train))
# input_values = inputs['input_ids']
# labels = inputs['labels']
# attention_mask = inputs['attention_mask']
# print(inputs)

In [None]:
class BLEUMetric(tf.keras.metrics.Metric):
    """Running average of per-batch corpus BLEU scores.

    Every `update_state` call scores one batch of decoded strings with
    sacreBLEU and adds the score to an accumulator; `result` returns the mean
    over all batches seen since the last reset.
    """

    def __init__(self, name="BLEU", **kwargs):
        super().__init__(name=name, **kwargs)
        self.bleu = BLEU()
        self.accumulator = self.add_weight(name="total_bleu", initializer="zeros")
        self.counter = self.add_weight(name="counter", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the BLEU score of one batch.

        Args:
            y_true: Sequence of reference (target) strings.
            y_pred: Sequence of hypothesis (predicted) strings, aligned with
                `y_true`.
            sample_weight: Unused; accepted for Keras API compatibility.
        """
        # Predictions are the hypotheses and targets the references (they were
        # swapped before). sacreBLEU expects references as a list of reference
        # streams, hence the extra list around the single stream.
        bleu_score = self.bleu.corpus_score(
            hypotheses=list(y_pred), references=[list(y_true)]).score
        self.accumulator.assign_add(bleu_score)
        self.counter.assign_add(1)

    def result(self):
        """Mean BLEU over the batches seen so far (0 if none)."""
        return tf.math.divide_no_nan(self.accumulator, self.counter)

    def reset_state(self):
        """Zero the accumulator and the batch counter."""
        self.accumulator.assign(0.0)
        self.counter.assign(0.0)

    def reset_states(self):
        """Older Keras spelling of `reset_state`; kept for existing callers."""
        self.reset_state()

In [None]:
class Trainer:
    """Fine-tune `t5-base` on the romaji -> English TFRecord dataset.

    Handles data loading, the optimisation loop with gradient accumulation,
    per-epoch validation, BLEU tracking, checkpointing, CSV logging and weight
    export.
    """

    def __init__(self, args):
        self.args = args
        self.tokenizer = T5Tokenizer(
            vocab_file=f"{self.args.main_dir}/tokenizer_text/t5.model",
            eos_token="</s>",
            unk_token="<unk>",
            pad_token="<pad>")
        # Build the pipelines once and share the loader between both splits.
        loader = DataLoader(args)
        self.train_dataset = loader.train
        self.val_dataset = loader.val
        # Constant learning rate. NOTE(review): an ExponentialDecay schedule
        # used to be constructed here but was never passed to the optimizer,
        # so training has always run with a fixed rate.
        self.optimizer = tf.keras.optimizers.Adam(args.learning_rate)
        self.bleu_metric = BLEUMetric()
        self.gradient_accumulator = GradientAccumulator()
        # Kept for backward compatibility; the accumulator itself does not
        # read this attribute.
        self.gradient_accumulator.accum_steps = args.accum_steps
        # Number of micro-batches accumulated before each optimizer update.
        self.accum_steps = max(1, int(args.accum_steps))
        self.model = TFT5ForConditionalGeneration.from_pretrained(
            "t5-base",
            output_hidden_states=False,
            output_attentions=False,
            use_cache=True)

        self.model_name = f"model_{int(self.args.n_samples/1000)}k"
        self.weights_dir = f"{self.args.main_dir}/bart_model_weights"
        # Log file and exported weights live here; create it up front.
        os.makedirs(self.weights_dir, exist_ok=True)
        self.log_path = f"{self.weights_dir}/{self.model_name}.csv"
        if not os.path.exists(self.log_path):
            print("Log file created.")
            columns = "epoch,loss,bleu,val_loss,val_bleu\n"
            with open(self.log_path, "a") as f:
                f.write(columns)

    def decoder(self, labels, logits):
        """Decode label ids and arg-max logits into lists of strings.

        Negative label ids (the -100 padding) are mapped to 0 before decoding.
        Empty predictions are replaced by "." so no hypothesis is empty.

        Returns:
            `(label_strings, predicted_strings)`.
        """
        labels = tf.where(labels < 0, x=0, y=labels)
        labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)        
        logits = tf.argmax(logits, axis=-1)
        logits = ['.' if x == '' else x for x in self.tokenizer.batch_decode(logits, skip_special_tokens=True)]
        return labels, logits

    def display(self, epoch, t_labels, t_logits, v_labels, v_logits):
        """Print target/prediction pairs for one training and one validation batch."""
        print("-" * 129)
        print("Training")
        for y_true, y_pred in zip(t_labels, t_logits):
            print(f"Target:    {y_true}")
            print(f"Predicted: {y_pred}") 

        print("\nValidation")
        for y_true, y_pred in zip(v_labels, v_logits):
            print(f"Target:    {y_true}")
            print(f"Predicted: {y_pred}")
        print("-" * 129)

    def _apply_accumulated_gradients(self):
        """Apply the averaged accumulated gradients, then clear the accumulator.

        Does nothing when no gradients have been accumulated since the last
        reset.
        """
        n_accumulated = int(self.gradient_accumulator.step)
        if n_accumulated == 0:
            return
        grads_and_vars = [
            (grad / tf.cast(n_accumulated, grad.dtype), weight)
            for grad, weight in zip(
                self.gradient_accumulator.gradients,
                self.model.trainable_weights)
            if grad is not None]
        self.optimizer.apply_gradients(grads_and_vars)
        # Without this reset every update would re-apply all earlier gradients.
        self.gradient_accumulator.reset()

    def fit(self):
        """Run training from the latest checkpoint (if any) up to `args.epochs`.

        After every epoch: validate, print sample predictions, save a
        checkpoint numbered with the epoch, append one row to the CSV log and
        export the model weights.
        """
        # Checkpointing
        self.ckpt_dir = f"{self.args.main_dir}/bart_checkpoints_{int(self.args.n_samples/1000)}k"
        self.ckpt = tf.train.Checkpoint(self.model)
        self.ckpt_manager = tf.train.CheckpointManager(
            checkpoint=self.ckpt, directory=self.ckpt_dir, max_to_keep=5,
            checkpoint_name=self.model_name)

        if self.ckpt_manager.latest_checkpoint:
            # Checkpoint paths end in "-<number>", one checkpoint per epoch.
            self.start_epoch = int(self.ckpt_manager.latest_checkpoint.split("-")[-1])
            self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
            print(f"Resuming from epoch {self.start_epoch + 1}...")
        else:
            self.start_epoch = 0
            print("Starting from epoch 1...")

        # Start at the restored epoch instead of repeating finished epochs.
        for epoch in range(self.start_epoch, self.args.epochs):
            print(f"Epoch {epoch+1}/{self.args.epochs}")
            stateful_metrics = ["loss", "bleu", "val_loss", "val_bleu"]
            progbar = tf.keras.utils.Progbar(
                self.args.train_steps, interval=0.05,
                stateful_metrics=stateful_metrics)

            # Epoch-level loss averages for the final progress line and the log.
            train_loss = tf.keras.metrics.Mean()
            val_loss = tf.keras.metrics.Mean()
            # Defaults so an empty dataset cannot leave these names undefined.
            t_labels, t_logits, v_labels, v_logits = [], [], [], []
            t_bleu = 0.0

            # Training loop
            for step, t_batch in enumerate(self.train_dataset):
                with tf.GradientTape() as tape:
                    t_loss, t_logits = self.model(
                        input_ids=t_batch['input_ids'],
                        attention_mask=t_batch['attention_mask'],
                        labels=t_batch['labels'],
                        training=True)[:2]
                self.gradient_accumulator(
                    tape.gradient(t_loss, self.model.trainable_weights))
                # Update the weights only once enough micro-batches are in.
                if int(self.gradient_accumulator.step) >= self.accum_steps:
                    self._apply_accumulated_gradients()

                batch_loss = tf.reduce_mean(t_loss)
                train_loss.update_state(batch_loss)
                t_labels, t_logits = self.decoder(t_batch['labels'], t_logits)
                self.bleu_metric.update_state(t_labels, t_logits)
                t_bleu = self.bleu_metric.result()
                t_values = [
                    ("loss", batch_loss),
                    ("bleu", t_bleu)]
                progbar.update(step, values=t_values, finalize=False)
            # Flush a trailing partial accumulation so no gradients are lost
            # or carried into the next epoch.
            self._apply_accumulated_gradients()
            self.bleu_metric.reset_states()

            # Validation loop
            for v_batch in self.val_dataset:
                v_loss, v_logits = self.model(
                    input_ids=v_batch['input_ids'],
                    attention_mask=v_batch['attention_mask'],
                    labels=v_batch['labels'],
                    training=False)[:2]
                val_loss.update_state(tf.reduce_mean(v_loss))
                v_labels, v_logits = self.decoder(v_batch['labels'], v_logits)
                self.bleu_metric.update_state(v_labels, v_logits)

            v_bleu = self.bleu_metric.result()
            epoch_loss = float(train_loss.result())
            epoch_val_loss = float(val_loss.result())
            v_values = [
                ("loss", epoch_loss), 
                ("bleu", t_bleu),
                ("val_loss", epoch_val_loss),
                ("val_bleu", v_bleu)]
            progbar.update(self.args.train_steps, values=v_values, finalize=True)
            self.bleu_metric.reset_states()

            # Print sample transcriptions for both loops
            self.display(epoch, t_labels, t_logits, v_labels, v_logits)

            # Checkpointing: saving through the manager enforces `max_to_keep`
            # and ties the checkpoint number to the finished epoch, which is
            # what the resume logic above parses.
            self.ckpt_manager.save(checkpoint_number=epoch + 1)

            # Logging: plain floats, because tensor reprs contain commas and
            # would corrupt the CSV row.
            log = f"{epoch+1},{epoch_loss},{float(t_bleu)},{epoch_val_loss},{float(v_bleu)}\n"
            with open(self.log_path, "a") as f:
                f.write(log)

            self.model.save_weights(
                f"{self.weights_dir}/{self.model_name}_{epoch+1}.h5")

# Build the trainer (loads tokenizer, datasets and the pretrained model) and start training.
Trainer(args).fit()