In [None]:
import os
import re
import glob
import json
import random
import argparse
import pandas as pd
import numpy as np

import librosa
import librosa.display
import soundfile as sf

from tqdm import tqdm
import subprocess
from functools import partial
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns

import MeCab
import cutlet

from sklearn.model_selection import train_test_split, KFold

import tensorflow as tf
import tensorflow_io as tfio
from tensorflow.keras.layers import *

from transformers import (
    Wav2Vec2CTCTokenizer,
    TFWav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2FeatureExtractor)

def seed_everything(SEED):
    """Seed Python's ``random``, NumPy's global RNG and TensorFlow with ``SEED``.

    The three generators are seeded in that order, then a confirmation line
    is printed.
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(SEED)
    print("Random seed set.")

# Select the GPU before any TensorFlow runtime call so the setting is honoured
# when TF initializes its devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
seed_everything(42)
tf.get_logger().setLevel('FATAL')

# Preprocessing

In [None]:
class Dataset:
    """Build the combined Japanese ASR manifest and save it as ``ASRDataset.csv``.

    Construction gathers the KOKORO, JSUT and Common Voice corpora into one
    DataFrame (``path``, ``sentence``, ``corpus``). JSUT and Common Voice audio
    is converted with ffmpeg to 16-bit PCM WAV at ``sample_rate`` under
    ``<main_dir>/wav_cleaned`` when the target file is missing.

    The constructor then:

    * cleans the sentences and tokenises them with MeCab;
    * romanises them with cutlet;
    * filters by audio length and sentence length;
    * draws a seeded random subsample and sorts it longest-first;
    * writes the result to ``<main_dir>/ASRDataset.csv``.

    The final frame is kept in ``self.data``.

    Args:
        n_samples: number of rows to keep, capped at the number of rows that
            survive filtering.
        min_length, max_length: inclusive bounds on audio length in samples
            (3-5 s at 16 kHz with the defaults).
        min_chars: minimum length of the cleaned, space-tokenised sentence.
    """
    def __init__(self, n_samples=10000, min_length=48000, max_length=80000, min_chars=5):
        self.main_dir = "E://Datasets/ASR-dataset"
        self.sample_rate = 16000
        self.n_shards = 20
        self.data = pd.concat(
            [self.get_kokoro(), self.get_jsut(), self.get_commonvoice()],
            ignore_index=True)
        self.katsu = cutlet.Cutlet()
        self.wakati = MeCab.Tagger("-Owakati")

        tqdm.pandas()
        data = self.data
        data['sentence'] = data['sentence'].progress_apply(self.clean_kanji)
        data['romaji'] = (data['sentence']
                          .progress_apply(self.katsu.romaji)
                          .progress_apply(self.clean_romaji)
                          .str.lower())
        data['length'] = data['path'].progress_apply(self.get_length)

        data = data[data['length'].between(min_length, max_length)]
        data = data[data['sentence'].str.len() >= min_chars]
        data = data.dropna()
        data = data.sample(n=min(n_samples, len(data)), random_state=42, ignore_index=True)
        data = data.sort_values(by="length", ascending=False, ignore_index=True)
        data.to_csv(os.path.join(self.main_dir, "ASRDataset.csv"), encoding="utf-8", index=False)
        self.data = data

    def get_kokoro(self):
        """Load KOKORO transcripts from ``Datasets/KOKORO-dataset/transcripts``.

        Each ``*.metadata.txt`` line is ``|``-separated into text_id, path,
        start_idx, end_idx, sentence and phonemes. The audio is expected to be
        already sliced into ``<main_dir>/wav_cleaned/<text_id>.wav``.

        Returns:
            DataFrame with ``path`` (``<text_id>.wav``), ``sentence`` and
            ``corpus`` (``"kokoro"``).
        """
        in_dir = os.path.join("Datasets", "KOKORO-dataset")

        rows = []
        for transcript in glob.glob(os.path.join(in_dir, "transcripts", "*.metadata.txt")):
            with open(transcript, "r", encoding="utf-8") as f:
                rows.extend(line.split("|") for line in f)

        data = pd.DataFrame(
            rows, columns=[
                'text_id', 'path', 'start_idx',
                'end_idx', 'sentence', 'phonemes'])

        # One-off extraction step, kept for reference (disabled): slice each
        # source recording into per-utterance 16 kHz WAVs named after text_id.
        # paths = data['path'].unique()
        # for path in tqdm(paths, total=len(paths)):
        #     folder_name = path.split("_", 1)[0]
        #     in_path = os.path.join(in_dir, folder_name, path)
        #     y, sr = librosa.load(in_path, sr=None)
        #     for text_id in data.loc[data['path']==path, 'text_id']:
        #         out_path = os.path.join(self.main_dir, 'wav_cleaned', text_id) + ".wav"
        #         if not os.path.exists(out_path):
        #             start_idx = int(data.loc[data['text_id']==text_id, 'start_idx'].item())
        #             end_idx = int(data.loc[data['text_id']==text_id, 'end_idx'].item())
        #             y_slice = librosa.resample(
        #                 y[start_idx:end_idx], orig_sr=sr, target_sr=self.sample_rate)
        #             sf.write(out_path, y_slice, samplerate=self.sample_rate, subtype='PCM_16')

        return pd.DataFrame({
            'path': data['text_id'] + ".wav",
            'sentence': data['sentence'],
            'corpus': 'kokoro',
        })

    def get_jsut(self):
        """Load JSUT transcripts and convert their audio into ``wav_cleaned``.

        Reads every ``Datasets/JSUT-dataset/*/transcript_utf8.txt``. Each
        non-blank line is ``<file id>:<sentence>``, and the matching audio is
        ``<subset>/wav/<file id>.wav``.

        Returns:
            DataFrame with ``path`` (output WAV file name), ``sentence`` and
            ``corpus`` (``"jsut"``).
        """
        in_paths, sentences = [], []
        for transcript in glob.glob(r"Datasets/JSUT-dataset/*/transcript_utf8.txt"):
            # dirname handles either separator (rsplit("\\") only worked on Windows).
            subset_dir = os.path.dirname(transcript)
            with open(transcript, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    # maxsplit=1 keeps any ':' that appears inside the sentence.
                    file_id, sentence = line.split(":", 1)
                    in_paths.append(os.path.join(subset_dir, "wav", file_id) + ".wav")
                    sentences.append(sentence.strip("\n"))
        paths = [self._convert_to_wav(p) for p in tqdm(in_paths)]
        return pd.DataFrame({'path': paths, 'sentence': sentences, 'corpus': 'jsut'})

    def get_commonvoice(self):
        """Load Common Voice ``validated.tsv`` and convert its MP3s into ``wav_cleaned``.

        Returns:
            DataFrame with ``path`` (output WAV file name), ``sentence`` and
            ``corpus`` (``"common_voice"``).
        """
        meta = pd.read_csv(r"Datasets/CommonVoice-dataset/validated.tsv", sep="\t")
        in_paths = [r"Datasets/CommonVoice-dataset/mp3/" + p for p in meta['path']]
        paths = [self._convert_to_wav(p) for p in tqdm(in_paths)]
        return pd.DataFrame({
            'path': paths,
            'sentence': meta['sentence'].tolist(),
            'corpus': 'common_voice',
        })

    def _convert_to_wav(self, in_path):
        """Convert one file to 16-bit PCM WAV in ``<main_dir>/wav_cleaned``.

        The output keeps the input's stem with a ``.wav`` extension. Conversion
        is skipped when the output already exists.

        Args:
            in_path: path of the source audio file (either separator).

        Returns:
            The output file name (no directory).

        Raises:
            subprocess.CalledProcessError: if ffmpeg exits with an error.
        """
        in_path = in_path.replace("\\", "/")
        stem = os.path.splitext(in_path.rsplit("/", 1)[-1])[0]
        filename = stem + ".wav"
        out_dir = os.path.join(self.main_dir, "wav_cleaned")
        out_path = os.path.join(out_dir, filename)
        if not os.path.exists(out_path):
            os.makedirs(out_dir, exist_ok=True)
            subprocess.run([
                "ffmpeg", "-i", in_path, "-acodec", "pcm_s16le",
                "-ar", str(self.sample_rate), out_path], check=True)
        return filename

    def clean_kanji(self, sentence):
        """Strip symbols, then insert MeCab word boundaries (wakati).

        Note that the character class includes the range `` -~``, i.e. every
        printable ASCII character, plus the long-vowel mark ``ー``.
        """
        symbols = r"[（.*?）！-～.,;..._。、-〿・■（）：ㇰ-ㇿ㈠-㉃㊀-㋾㌀-㍿「」『』→ー -~‘–※π—ゐ’“”]"
        sentence = re.sub(symbols, "", sentence)
        sentence = self.wakati.parse(sentence).strip("\n")
        return sentence

    def clean_romaji(self, sentence):
        """Remove ``. , " ' / ?`` from a romanised sentence."""
        return re.sub(r'[.,"\'\/?]', "", sentence)

    def get_length(self, path):
        """Return the number of samples in ``<main_dir>/wav_cleaned/<path>``.

        Only the file header is read. For WAV this equals
        ``len(librosa.load(path, sr=None)[0])`` without decoding the audio.
        """
        full_path = os.path.join(self.main_dir, 'wav_cleaned', path)
        return sf.info(full_path).frames

# Run the full preprocessing pipeline (writes ASRDataset.csv) and show the result.
dataset = Dataset()
data = dataset.data
data

In [None]:
# fig, ax = plt.subplots(1,1,figsize=(10, 4))
# sns.histplot(x=data['length'], hue=data['corpus'], ax=ax, palette="bright")
# plt.show()

# Arguments

In [None]:
def ArgParser():
    """Parse hyper-parameters and add values derived from the prepared dataset.

    ``parse_known_args`` is used so the extra arguments Jupyter passes to the
    kernel are ignored. Every option has an explicit ``type=`` so values given
    on the command line are converted. Without it argparse returns strings,
    e.g. ``--batch_size 8`` would yield ``"8"``.

    After the first parse, ``<main_dir>/vocab.json`` and
    ``<main_dir>/ASRDataset.csv`` are read to add ``vocab_size``,
    ``n_samples``, ``n_train``, ``n_val`` and ``train_steps``. The arguments
    are then parsed again, so those values can also be overridden.

    Returns:
        argparse.Namespace holding all options.
    """
    parser = argparse.ArgumentParser()

    # DataLoader
    parser.add_argument("--main_dir", type=str, default="E://Datasets/ASR-dataset")
    parser.add_argument("--sample_rate", type=int, default=16000)
    parser.add_argument("--test_size", type=float, default=0.1)
    parser.add_argument("--random_state", type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--n_shards", type=int, default=20)
    parser.add_argument("--buffer_size", type=int, default=512)

    # Trainer
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_start", type=float, default=1e-7)
    parser.add_argument("--lr_min", type=float, default=1e-7)
    parser.add_argument("--lr_max", type=float, default=1e-4)
    parser.add_argument("--n_cycles", type=float, default=0.5)
    parser.add_argument("--warmup_epochs", type=int, default=4)
    parser.add_argument("--sustain_epochs", type=int, default=0)

    args = parser.parse_known_args()[0]

    with open(os.path.join(args.main_dir, "vocab.json"), "r", encoding="utf-8") as f:
        vocab_size = len(json.load(f))

    n_samples = len(pd.read_csv(os.path.join(args.main_dir, "ASRDataset.csv")))
    n_train = int(n_samples * (1 - args.test_size))
    n_val = int(n_samples * args.test_size)
    train_steps = int(np.ceil(n_train / args.batch_size))

    parser.add_argument("--vocab_size", type=int, default=vocab_size)
    parser.add_argument("--n_samples", type=int, default=n_samples)
    parser.add_argument("--n_train", type=int, default=n_train)
    parser.add_argument("--n_val", type=int, default=n_val)
    parser.add_argument("--train_steps", type=int, default=train_steps)

    return parser.parse_known_args()[0]

# Resolve the run configuration (needs vocab.json and ASRDataset.csv in main_dir).
args = ArgParser()
args

# Data Loading

In [None]:
class TFRWriter():
    """Serialize the ASR manifest into sharded TFRecord files.

    Reads ``<main_dir>/ASRDataset.csv`` (uses its ``path`` and ``romaji``
    columns) and tokenizes each transcript with a CTC tokenizer built from
    ``<main_dir>/vocab.json``. It writes ``args.n_shards`` files named
    ``shard_<k>.tfrec`` to ``<main_dir>/wav2vec2_tfrec``.

    Each example holds two serialized tensors: ``input_values`` (float32
    waveform) and ``labels`` (int32 token ids).

    NOTE(review): shards are contiguous slices of the CSV, which is sorted by
    length. A file-level train/val split therefore yields length-skewed
    validation data. Labels are also wrapped in bos/eos tokens, which is
    unusual for CTC targets; confirm both are intended.
    """
    def __init__(self, args):
        self.data = pd.read_csv(os.path.join(args.main_dir, "ASRDataset.csv"))
        self.args = args
        self.tokenizer = Wav2Vec2CTCTokenizer(
            f"{args.main_dir}/vocab.json",
            word_delimiter_token=' ',
            do_lower_case=False)
        # Build the path -> romaji lookup once instead of scanning the
        # DataFrame for every sample (which is O(n^2) over a full write).
        if not self.data['path'].is_unique:
            raise ValueError("ASRDataset.csv contains duplicate 'path' entries")
        self._romaji_by_path = dict(zip(self.data['path'], self.data['romaji']))

    def _bytes_feature(self, value):
        """Returns a bytes_list from a string / byte."""
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(self, value):
        """Returns an int64_list from a bool / enum / int / uint."""
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _float_feature(self, value):
        """Returns a float_list from a float / double."""
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    def serialize_example(self, *args):
        """Build a serialized ``tf.train.Example`` from (input_values, labels) bytes."""
        feature = {
            'input_values': self._bytes_feature(args[0]),
            'labels': self._bytes_feature(args[1]),
            }

        example_proto = tf.train.Example(
            features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    def get_labels(self, sample):
        """Tokenize ``<bos> + romaji + <eos>`` for one manifest path.

        Raises:
            KeyError: if ``sample`` is not a path in the manifest.
        """
        romaji = self._romaji_by_path[sample]
        text = self.tokenizer.bos_token + romaji + self.tokenizer.eos_token
        labels = self.tokenizer(text)['input_ids']
        return tf.convert_to_tensor(labels, dtype=tf.int32)

    def get_audio(self, sample):
        """Load ``<main_dir>/wav_cleaned/<sample>`` at its native rate as float32."""
        path = os.path.join(self.args.main_dir, "wav_cleaned", sample)
        audio = librosa.load(path, sr=None)[0]
        return tf.convert_to_tensor(audio, dtype=tf.float32)

    def get_shards(self):
        """Split the manifest paths into ``args.n_shards`` contiguous lists, in CSV order."""
        kfold = KFold(n_splits=self.args.n_shards, shuffle=False)
        paths = self.data['path']
        return [paths.iloc[idx].tolist() for _, idx in kfold.split(paths)]

    def get_shard_data(self, samples):
        """Yield serialized audio/label tensors for each path in ``samples``."""
        for sample in samples:
            audio = self.get_audio(sample)
            labels = self.get_labels(sample)
            yield {
                'input_values': tf.io.serialize_tensor(audio),
                'labels': tf.io.serialize_tensor(labels),
            }

    def write(self):
        """Write every shard to ``<main_dir>/wav2vec2_tfrec/shard_<k>.tfrec`` (k from 1)."""
        out_dir = f"{self.args.main_dir}/wav2vec2_tfrec"
        # TFRecordWriter does not create missing parent directories.
        tf.io.gfile.makedirs(out_dir)
        for shard, samples in tqdm(enumerate(self.get_shards()), total=self.args.n_shards):
            with tf.io.TFRecordWriter(f"{out_dir}/shard_{shard+1}.tfrec") as f:
                for sample in self.get_shard_data(samples):
                    f.write(self.serialize_example(
                        sample['input_values'],
                        sample['labels'],
                        ))

# TFRWriter(args).write()

In [None]:
class DataLoader:
    """Build batched train/validation ``tf.data`` pipelines from TFRecord shards.

    The shard files in ``<main_dir>/wav2vec2_tfrec`` are split at file level
    with ``train_test_split`` (``args.test_size``, ``args.random_state``).
    Batches are padded to the longest example: audio with 0.0, labels with
    -1. The validation pipeline is cached after its first pass.

    NOTE(review): ``args.buffer_size`` is not used, so the training order is
    identical every epoch. Add a shuffle if that is not intended.

    Args:
        args: namespace with main_dir, test_size, random_state and batch_size.
        normalize: if True (default), scale each waveform to zero mean and unit
            variance before batching. This matches the
            ``Wav2Vec2FeatureExtractor(do_normalize=True)`` configured for the
            model, which is never applied to these tensors otherwise.

    Raises:
        FileNotFoundError: if no ``*.tfrec`` files are found.
    """
    def __init__(self, args, normalize=True):
        # Sorted so the seeded split does not depend on directory listing order.
        self.files = sorted(glob.glob(os.path.join(args.main_dir, "wav2vec2_tfrec", "*.tfrec")))
        if not self.files:
            raise FileNotFoundError(
                f"No .tfrec files in {os.path.join(args.main_dir, 'wav2vec2_tfrec')}; "
                "run TFRWriter(args).write() first.")
        self.args = args
        self.normalize = normalize
        self.AUTOTUNE = tf.data.AUTOTUNE
        self.train_files, self.val_files = train_test_split(
            self.files, test_size=args.test_size, shuffle=True,
            random_state=args.random_state)
        self.train = self.get_train()
        self.val = self.get_val()

    def read_tfrecord(self, example):
        """Parse one serialized example into float32 ``input_values`` and int32 ``labels``."""
        feature_description = {
            'input_values': tf.io.FixedLenFeature([], tf.string),
            'labels': tf.io.FixedLenFeature([], tf.string),
            }

        example = tf.io.parse_single_example(example, feature_description)
        audio = tf.io.parse_tensor(example['input_values'], out_type=tf.float32)
        if self.normalize:
            # Per-utterance zero-mean / unit-variance; epsilon guards silent clips.
            mean, variance = tf.nn.moments(audio, axes=[0])
            audio = (audio - mean) / tf.sqrt(variance + 1e-7)
        example['input_values'] = audio
        example['labels'] = tf.io.parse_tensor(
            example['labels'], out_type=tf.int32)
        return example

    def load_dataset(self, files):
        """Read and parse ``files``; element order may vary for throughput."""
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False
        dataset = tf.data.TFRecordDataset(files)
        dataset = dataset.with_options(ignore_order)
        dataset = dataset.map(self.read_tfrecord, num_parallel_calls=self.AUTOTUNE)
        return dataset

    def _padded_batch(self, dataset):
        """Batch with per-batch padding: audio with 0.0, labels with -1."""
        return dataset.padded_batch(
            self.args.batch_size,
            padded_shapes={
                'input_values': [None],
                'labels': [None],
                },
            padding_values={
                'input_values': tf.constant(0, dtype=tf.float32),
                'labels': tf.constant(-1, dtype=tf.int32),
                })

    def get_train(self):
        """Training pipeline: parse -> pad/batch -> prefetch."""
        dataset = self._padded_batch(self.load_dataset(self.train_files))
        return dataset.prefetch(self.AUTOTUNE)

    def get_val(self):
        """Validation pipeline: parse -> pad/batch -> cache -> prefetch."""
        dataset = self._padded_batch(self.load_dataset(self.val_files))
        dataset = dataset.cache()
        return dataset.prefetch(self.AUTOTUNE)

# Prepare Model

In [None]:
# Reuse the data pipeline's configuration (args) so the processor sees the same
# sample rate and vocabulary file the TFRecords were built with, instead of a
# hard-coded Windows path.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=args.sample_rate,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=False
)

tokenizer = Wav2Vec2CTCTokenizer(
    os.path.join(args.main_dir, "vocab.json"),
    word_delimiter_token=" ",
    do_lower_case=False
)

processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor, tokenizer=tokenizer
)

# The output layer is resized to the custom vocabulary; the pad token id is
# passed to the model config alongside it.
model = TFWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean",
    from_pt=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

# Keep the convolutional feature encoder frozen during fine-tuning.
# NOTE(review): newer transformers releases name this freeze_feature_encoder();
# confirm against the installed version.
model.freeze_feature_extractor()

In [None]:
class PER(tf.keras.metrics.Metric):
    """Phoneme Error Rate

    Mean normalized Levenshtein distance between the CTC beam-search decoding
    of the model logits and the target label ids. With a character-level
    vocabulary this is effectively a character error rate over tokenizer ids.

    ``result()`` is the mean per-sample distance over every sample passed to
    ``update_state`` since the last reset.

    Args:
        beam_width: (Optional) beam width of ``tf.nn.ctc_beam_search_decoder``.
        top_paths: (Optional) number of paths the decoder returns; only the
            best one is scored.
        name: (Optional) string name of the metric instance
        blank_index: (Optional) class id of the CTC blank. ``None`` keeps the
            TensorFlow decoder's convention, where the last class is the blank.
            If the model was trained with another blank (Hugging Face's CTC
            head appears to use the pad token id; confirm for your version),
            pass it here.
    """
    def __init__(self, beam_width=10, top_paths=1, name="PER", blank_index=None, **kwargs):
        super(PER, self).__init__(name=name, **kwargs)
        self.beam_width = beam_width
        self.top_paths = top_paths
        self.blank_index = blank_index
        self.per_accumulator = self.add_weight(name="total_per", initializer="zeros")
        self.counter = self.add_weight(name="per_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Decode ``y_pred`` and accumulate its distance to ``y_true``.

        Args:
            y_true: integer label ids, shape [batch_size, label_length].
                Negative values are treated as padding (the DataLoader pads
                with -1).
            y_pred: logits, shape [batch_size, sequence_length, num_features].
            sample_weight: ignored.

        Returns:
            None
        """
        y_true = tf.convert_to_tensor(y_true)
        y_pred = tf.cast(y_pred, tf.float32)
        shape = tf.shape(y_pred)
        batch_size, max_time, num_classes = shape[0], shape[1], shape[2]

        if self.blank_index is not None:
            # Move the blank column last, where the TF decoder expects it.
            blank = tf.constant(self.blank_index, dtype=tf.int32)
            class_order = tf.concat(
                [tf.range(blank), tf.range(blank + 1, num_classes),
                 tf.expand_dims(blank, 0)], axis=0)
            y_pred = tf.gather(y_pred, class_order, axis=-1)

        # The decoder is time-major. Transpose to [time, batch, classes];
        # a reshape would scramble frames across batch entries.
        logits = tf.transpose(y_pred, [1, 0, 2])
        sequence_length = tf.fill([batch_size], max_time)

        # Decode logits into a sparse tensor using the beam search decoder.
        decoded = tf.nn.ctc_beam_search_decoder(
            logits, sequence_length=sequence_length, beam_width=self.beam_width,
            top_paths=self.top_paths)[0][0]
        hyp_values = tf.cast(decoded.values, tf.int32)
        if self.blank_index is not None:
            # Map decoder class ids back to the original vocabulary ids.
            hyp_values = tf.gather(class_order, hyp_values)
        hypothesis = tf.SparseTensor(decoded.indices, hyp_values, decoded.dense_shape)

        # Keep every non-negative id, including 0. tf.sparse.from_dense would
        # keep the -1 padding and drop genuine id-0 tokens.
        label_positions = tf.where(y_true >= 0)
        truth = tf.SparseTensor(
            label_positions,
            tf.cast(tf.gather_nd(y_true, label_positions), tf.int32),
            tf.shape(y_true, out_type=tf.int64))

        # Per-sample Levenshtein distance normalized by reference length.
        distance = tf.edit_distance(hypothesis, truth, normalize=True)
        self.per_accumulator.assign_add(tf.reduce_sum(distance))
        self.counter.assign_add(tf.cast(tf.size(distance), self.counter.dtype))

    def result(self):
        """Mean normalized edit distance per sample (0 before any update)."""
        return tf.math.divide_no_nan(self.per_accumulator, self.counter)

    def reset_state(self):
        """Clear the accumulated distance and sample count."""
        self.per_accumulator.assign(0.0)
        self.counter.assign(0.0)

    def reset_states(self):
        """Alias of ``reset_state`` (older Keras name, used by the training loop)."""
        self.reset_state()

class CosineDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up, optional plateau, then cosine decay, expressed in epochs.

    Keras calls a learning-rate schedule with the optimizer's step counter.
    The step is divided by ``steps_per_epoch``, so ``warmup_epochs``,
    ``sustain_epochs`` and ``epochs`` are measured in epochs and the schedule
    spans the whole run.

    With ``e`` = fractional epoch:
      * ``e < warmup_epochs``: linear from ``lr_start`` towards ``lr_max``.
      * ``e < warmup_epochs + sustain_epochs``: constant ``lr_max``.
      * otherwise: ``lr_min + (lr_max - lr_min) * 0.5 * (1 + cos(2*pi*n_cycles*progress))``,
        where ``progress`` goes 0 -> 1 over the remaining epochs (clipped).

    Args:
        args: namespace with lr_start, lr_min, lr_max, n_cycles, epochs,
            warmup_epochs and sustain_epochs. Its ``train_steps``, if present,
            is the default ``steps_per_epoch``.
        steps_per_epoch: optimizer steps per epoch. Defaults to
            ``args.train_steps``. If that is missing, 1 is used, so the
            argument is treated as an epoch index.
    """
    def __init__(self, args, steps_per_epoch=None):
        super().__init__()
        self.args = args
        if steps_per_epoch is None:
            steps_per_epoch = getattr(args, "train_steps", None) or 1
        self.steps_per_epoch = steps_per_epoch

    def __call__(self, epoch):
        """Return the learning rate as a float32 tensor.

        ``epoch`` keeps its original name, but Keras passes the optimizer
        iteration count here. It is converted to a fractional epoch. All
        branches use tensor ops, so this also works when traced.
        """
        a = self.args
        e = tf.cast(epoch, tf.float32) / float(self.steps_per_epoch)
        warmup = float(a.warmup_epochs)
        plateau_end = warmup + float(a.sustain_epochs)
        decay_epochs = float(a.epochs) - plateau_end

        slope = (a.lr_max - a.lr_start) / warmup if warmup > 0 else 0.0
        warmup_lr = a.lr_start + slope * e

        if decay_epochs > 0:
            progress = tf.clip_by_value((e - plateau_end) / decay_epochs, 0.0, 1.0)
        else:
            progress = tf.zeros_like(e)
        cosine = 0.5 * (1.0 + tf.math.cos(np.pi * 2.0 * a.n_cycles * progress))
        decay_lr = a.lr_min + (a.lr_max - a.lr_min) * cosine

        return tf.where(
            e < warmup, warmup_lr,
            tf.where(e < plateau_end, tf.constant(a.lr_max, tf.float32), decay_lr))

In [None]:
# Build the shard split once and take both pipelines from it.
data_loader = DataLoader(args)
train_dataset = data_loader.train
val_dataset = data_loader.val

optimizer = tf.keras.optimizers.Adam(CosineDecaySchedule(args))
loss_metric = tf.keras.metrics.Mean(name="loss")          # running train loss for the epoch
val_loss_metric = tf.keras.metrics.Mean(name="val_loss")  # mean over all validation batches
val_per_metric = tf.keras.metrics.Mean(name="val_per")    # mean of per-batch validation PER
per_metric = PER()
# Progbar shows stateful entries as given and averages the others (train 'per').
stateful_metrics = ['loss', 'val_loss', 'val_per']

checkpoint_dir = os.path.join(args.main_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)


def print_examples(title, logits, labels):
    """Print greedy (argmax) predictions next to the reference transcripts."""
    print(title)
    y_preds = processor.batch_decode(tf.argmax(logits, axis=-1))
    # Labels are padded with -1. Map the padding to the pad token (dropped by
    # the decoder), and do not CTC-group the references, which would merge
    # genuinely doubled letters such as "tt".
    labels = tf.where(labels < 0, processor.tokenizer.pad_token_id, labels)
    y_trues = processor.batch_decode(labels, group_tokens=False)
    for y_true, y_pred in zip(y_trues, y_preds):
        print(f"Target:    {y_true}")
        print(f"Predicted: {y_pred}\n")


for epoch in range(args.epochs):
    progbar = tf.keras.utils.Progbar(
                args.train_steps, interval=0.05,
                stateful_metrics=stateful_metrics)
    print(f"Epoch {epoch+1}/{args.epochs}")
    loss_metric.reset_state()
    val_loss_metric.reset_state()
    val_per_metric.reset_state()

    # Training loop
    for step, batch in enumerate(train_dataset):
        X_train = batch['input_values']
        y_train = batch['labels']
        with tf.GradientTape() as tape:
            t_loss, t_logits = model(
                input_values=X_train,
                labels=y_train, training=True)[:2]

        gradients = tape.gradient(t_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        loss_metric.update_state(t_loss)
        per_metric.update_state(y_train, t_logits)
        t_per = per_metric.result()
        per_metric.reset_states()

        progbar.update(
            step + 1,
            values=[('loss', loss_metric.result()), ('per', t_per)],
            finalize=False)

    # Validation loop: aggregate over every batch, not just the last one.
    for batch in val_dataset:
        X_val = batch['input_values']
        y_val = batch['labels']
        v_loss, v_logits = model(
            input_values=X_val,
            labels=y_val, training=False)[:2]
        val_loss_metric.update_state(v_loss)
        per_metric.update_state(y_val, v_logits)
        val_per_metric.update_state(per_metric.result())
        per_metric.reset_states()

    progbar.update(args.train_steps, values=[
        ('loss', loss_metric.result()),
        ('val_loss', val_loss_metric.result()),
        ('val_per', val_per_metric.result())])

    # Qualitative check on the last train and validation batches.
    print_examples("Training", t_logits, y_train)
    print_examples("Validation", v_logits, y_val)

    model.save_weights(os.path.join(checkpoint_dir, f"model_{epoch+1:02d}.h5"))