In [1]:
import os
import re
import ast
import glob
import json
import random
import cutlet
import jiwer
import argparse
import MeCab
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from convert_romaji import Romaji2Kana

import pandas as pd
import numpy as np
import seaborn as sns
from sacrebleu.metrics import BLEU, CHRF

from sklearn.model_selection import KFold, train_test_split

import tensorflow as tf

import sentencepiece as spm

from transformers import (
    MarianConfig,
    MarianTokenizer,
    TFMarianMTModel,
    GradientAccumulator,
    logging)

# policy = tf.keras.mixed_precision.Policy('mixed_float16')
# tf.keras.mixed_precision.set_global_policy(policy)

def seed_everything(SEED):
    """Seed Python's ``random``, NumPy and TensorFlow with the same value.

    Args:
        SEED: Seed value handed unchanged to each library's seeding function.
    """
    # Same order as before: stdlib RNG, NumPy global RNG, TensorFlow global seed.
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(SEED)
    print("Random seed set.")

# Seed every RNG once, before any data sampling or model construction happens.
seed_everything(42)
# Quieten TensorFlow's logger and the Transformers logger (`logging` here is
# `transformers.logging`, imported above — not the stdlib module).
tf.get_logger().setLevel('FATAL')
logging.set_verbosity_error()
# Restrict CUDA to device 0.
# NOTE(review): this is set after TensorFlow has been imported — confirm it still
# takes effect in this environment (it must be set before the GPU is initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Random seed set.


# Data Preparation

### Clean text

In [2]:
# katsu = cutlet.Cutlet()
# katsu.use_foreign_spelling = False

# def clean_kanji(sentence):
#     symbols = r"\（.*\）|\(.*\)|\「.*\」|\『.*\』"
#     sentence = re.sub(symbols, "", sentence.strip())
#     return sentence

# def clean_romaji(sentence):
#     sentence = sentence.strip().lower()
#     sentence = re.sub(r"[^a-zA-Z0-9\ ]", "", sentence)
#     sentence = sentence.split()
#     for i, mora in enumerate(sentence):
#         if (mora == "n") | (mora == "u") & (i < len(sentence) - 1):
#             prev_mora = sentence.pop(i-1)
#             sentence[i-1] = "".join([prev_mora, mora])
#     sentence = " ".join(sentence)
#     return sentence.strip()

# def kanji2romaji(text):
#     try:
#         new_line = katsu.romaji(text)
#         new_line = clean_romaji(new_line)
#     except:
#         new_line = None
#     return new_line

# def clean_en(sentence):
#     sentence = re.sub(r"\(.*\)|\[.*\]|\{.*\}", "", sentence.strip().lower()) # Parenthesis
#     sentence = re.sub(r"\.{3,}", ",", sentence) # Ellipsis
#     sentence = re.sub(r"[^A-Za-z0-9\ \?\.\!\,\'\"]", "", sentence) # Non alphanumeric
#     return sentence.strip()

In [3]:
# main_dir = "D:\School-stuff\Sem-2\PR-Project\HoloASR\Datasets"
# opus_ja_paths = glob.glob(f"{main_dir}\OPUS100-dataset\*.ja")
# tatoeba_ja_paths = glob.glob(f"{main_dir}\Tatoeba-dataset\*.ja")
# jesc_path = f"{main_dir}/JESC-dataset/raw"
# ja_paths = opus_ja_paths + tatoeba_ja_paths

# ja_lines, en_lines = [], []
# for ja_path in ja_paths:
#     if ja_path.endswith(".ja"):
#         en_path = ja_path.rsplit(".", 1)[0] + ".en"
#     else:
#         en_path = ja_path.replace("ja", "en")
#     with open(ja_path, "r", encoding="utf-8") as f:
#         lines = [line.strip("\n") for line in f.readlines()]
#         ja_lines.extend(lines)
#     with open(en_path, "r", encoding="utf-8") as f:
#         lines = [line.strip("\n") for line in f.readlines()]
#         en_lines.extend(lines)

# with open(jesc_path, "r", encoding="utf-8") as f:
#     texts = [text.split("\t") for text in f.readlines()]
#     for en, ja in tqdm(texts):
#         ja_lines.append(ja)
#         en_lines.append(en)

# tqdm.pandas()
# data = pd.DataFrame({'ja_raw': ja_lines, 'en': en_lines})
# data['ja_raw'] = data['ja_raw'].progress_apply(clean_kanji)
# data['ja_ro'] = data['ja_raw'].progress_apply(kanji2romaji)
# data = data.dropna().reset_index(drop=True)
# data['ja_hira'] = data['ja_ro'].progress_apply(lambda x: Romaji2Kana(x).strip())
# data['en'] = data['en'].progress_apply(clean_en)
# data = data.dropna().reset_index(drop=True)
# data.to_csv(
#     r"E:\Datasets\Language_model\text_data\tokenizer_text.csv",
#     index=False, encoding="utf-8")
# data

  0%|          | 0/2801388 [00:00<?, ?it/s]

  0%|          | 0/4014254 [00:00<?, ?it/s]

  0%|          | 0/4014254 [00:00<?, ?it/s]

  0%|          | 0/4013985 [00:00<?, ?it/s]

  0%|          | 0/4013985 [00:00<?, ?it/s]

Unnamed: 0,ja_raw,en,ja_ro,ja_hira
0,アレックス・ダンバースを殺すか お前が殺されるか。,kill alex danvers or let her kill you.,arekkusudan baasu wo korosu ka omae ga korosar...,あれっくすだん ばあす を ころす か おまえ が ころされる か
1,あいつに聞こえるわ,"shut up. he can hear you, he can hear you.",aitsu ni kikoeru wa,あいつ に きこえる わ
2,Vは人間のDNAを解析して 皮膚を生成 人間に疑われないように,"the v's used the human d.n.a. to replicate it,...",v wa ningen no dna wo kaiseki shite hifu wo se...,は にんげん の な を かいせき して ひふ を せいせい にんげん に うたがわれない ...
3,スーパーマンが登場以来、超最悪の男の 戦争はその都市で繰り広げられている。,"since superman's arrival, a veritable super ba...",suupaaman ga toujou irai chousaiaku no otoko n...,すうぱあまん が とうじょう いらい ちょうさいあく の おとこ の せんそう は その と...
4,こんにちは？,hello?,konnichiha,こんにちは
...,...,...,...,...
4013980,知って何の意味がある?,what good would that have done?,shitte nan no imi ga aru,しって なん の いみ が ある
4013981,二之助さんに駆け寄った 一之助さんは,when ichinosukesan rushed over to ninosukesan..,ni no sukesan ni kakeyotta ichinosukesan wa,に の すけさん に かけよった いちのすけさん わ
4013982,そういった お言葉を力に...。,your words are so encouraging.,sou itta okotoba wo chikara ni,そう いった おことば を ちから に
4013983,切るぞ,"goodbye, george.",kiru zo,きる ぞ


### Train SPM model

In [6]:
# main_dir = "E://Datasets/Language_model/text_data"
# data = pd.read_csv(f"{main_dir}/tokenizer_text.csv")
# data = data.dropna().reset_index(drop=True)[['en', 'ja_hira']]

# with open(f"{main_dir}/ja_text.txt", "w", encoding="utf-8") as f:
#     ja_lines = data['ja_hira'].apply(lambda x: x + "\n").tolist()
#     f.writelines(ja_lines)

# with open(f"{main_dir}\en_text.txt", "w", encoding="utf-8") as f:
#     en_lines = data['en'].apply(lambda x: x + "\n").tolist()
#     f.writelines(en_lines)

# ja_file = f"{main_dir}/ja_text.txt"
# en_file = f"{main_dir}/en_text.txt"
# ja_prefix = f"{main_dir}/ja_spm"
# en_prefix = f"{main_dir}/en_spm"
# vocab_prefix = f"{main_dir}/vocab_spm"

# spm.SentencePieceTrainer.train(
#     f"--input={ja_file} --model_prefix={ja_prefix} --vocab_size={10003} --model_type=unigram --pad_id=0 --unk_id=2 --bos_id=-1 --eos_id=1 --pad_piece=<pad> --unk_piece=<unk> --eos_piece=</s>"
# )

# spm.SentencePieceTrainer.train(
#     f"--input={en_file} --model_prefix={en_prefix} --vocab_size={50716} --model_type=unigram --pad_id=0 --unk_id=2 --bos_id=-1 --eos_id=1 --pad_piece=<pad> --unk_piece=<unk> --eos_piece=</s>"
# )

# vocab = []
# with open(f"{main_dir}/ja_spm.vocab", "r", encoding="utf-8") as f:
#     for text in f.readlines():
#         token, _ = text.split("\t")
#         vocab.append(token)

# with open(f"{main_dir}/en_spm.vocab", "r", encoding="utf-8") as f:
#     for text in f.readlines()[3:]:
#         token, _ = text.split("\t")
#         vocab.append(token)

# with open(f"{main_dir}/vocab.json", "w", encoding="utf-8") as f:
#     vocab = {k: v for v, k in enumerate(vocab)}
#     json.dump(vocab, f, ensure_ascii=True, indent=4)

### Tokenize

In [20]:
main_dir = "E://Datasets/Language_model/text_data"

# Marian tokenizer wired to the separately trained ja / en SentencePiece models
# and the merged vocabulary file.
tokenizer = MarianTokenizer(
    vocab=f"{main_dir}/vocab.json",
    source_spm=f"{main_dir}/ja_spm.model",
    target_spm=f"{main_dir}/en_spm.model",
    source_lang="ja",
    target_lang="en",
    unk_token="<unk>",
    eos_token="</s>",
    pad_token="<pad>")

tqdm.pandas()

# Load the cleaned corpus, drop incomplete rows and keep only the two text columns.
data = (
    pd.read_csv(f"{main_dir}/tokenizer_text.csv", encoding="utf-8")
    .dropna()
    .reset_index(drop=True)[['en', 'ja_hira']]
)

# Source side (hiragana) is encoded with the tokenizer in its default mode ...
data['ja_token'] = data['ja_hira'].progress_apply(
    lambda text: tokenizer(text).input_ids)

# ... and the English side inside the target-tokenizer context.
with tokenizer.as_target_tokenizer():
    data['en_token'] = data['en'].progress_apply(
        lambda text: tokenizer(text).input_ids)

# Keep only rows whose source sequence has at most 17 token ids.
data['ja_len'] = data['ja_token'].map(len)
data = data[data['ja_len'] <= 17]

data.to_csv(f"{main_dir}/marian_text.csv", index=False)
data

  0%|          | 0/3987235 [00:00<?, ?it/s]

  0%|          | 0/3987235 [00:00<?, ?it/s]

Unnamed: 0,en,ja_hira,ja_token,en_token,ja_len
0,kill alex danvers or let her kill you.,あれっくすだん ばあす を ころす か おまえ が ころされる か,"[3016, 638, 678, 86, 6, 574, 11, 60, 7, 2075, ...","[10220, 11581, 29914, 10093, 10084, 10068, 102...",12
1,"shut up. he can hear you, he can hear you.",あいつ に きこえる わ,"[475, 5, 1493, 19, 1]","[10511, 10062, 10003, 10024, 10038, 10256, 100...",5
2,"the v's used the human d.n.a. to replicate it,...",は にんげん の な を かいせき して ひふ を せいせい にんげん に うたがわれない ...,"[4, 231, 3, 13, 6, 5164, 21, 4142, 6, 4763, 23...","[10006, 11704, 10005, 10009, 10260, 10006, 103...",17
4,hello?,こんにちは,"[1416, 1]","[10433, 10011, 1]",2
5,"have a good night, dear.",よい よる を,"[184, 353, 6, 1]","[10028, 10012, 10086, 10196, 10004, 11040, 100...",4
...,...,...,...,...,...
3987230,what good would that have done?,しって なん の いみ が ある,"[282, 17, 3, 324, 7, 32, 1]","[10023, 10086, 10092, 10014, 10028, 10228, 100...",7
3987231,when ichinosukesan rushed over to ninosukesan..,に の すけさん に かけよった いちのすけさん わ,"[5, 3, 3612, 77, 5, 767, 5128, 107, 269, 3171,...","[10077, 55162, 10843, 16500, 10130, 10010, 557...",13
3987232,your words are so encouraging.,そう いった おことば を ちから に,"[28, 112, 6263, 6, 397, 5, 1]","[10033, 10588, 10031, 10041, 19016, 10003, 1]",7
3987233,"goodbye, george.",きる ぞ,"[2103, 72, 1]","[11143, 10004, 11641, 10003, 1]",3


# Data Loading

In [None]:
def ArgParser():
    """Build the run configuration as an ``argparse.Namespace``.

    Every option now declares its ``type`` so that values supplied on the command
    line are converted to numbers; without it they arrive as strings and the
    derived-size arithmetic below fails. Defaults are unchanged.

    Unknown arguments are ignored (``parse_known_args``), which keeps this usable
    inside a notebook kernel whose ``sys.argv`` carries launcher flags.

    Returns:
        argparse.Namespace: the declared options plus the derived fields
        ``n_train``, ``n_val``, ``train_steps`` and ``val_steps``.
    """
    parser = argparse.ArgumentParser()

    # Data / pipeline
    parser.add_argument("--random_state", type=int, default=42)
    parser.add_argument("--main_dir", type=str, default="E://Datasets/Language_model")
    parser.add_argument("--n_shards", type=int, default=10)
    parser.add_argument("--n_samples", type=int, default=1200000)
    parser.add_argument("--test_size", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--buffer_size", type=int, default=1024)

    # Scheduler
    parser.add_argument("--epochs", type=int, default=15)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--lr_start", type=float, default=1e-8)
    parser.add_argument("--lr_min", type=float, default=1e-8)
    parser.add_argument("--lr_max", type=float, default=3e-4)
    parser.add_argument("--n_cycles", type=float, default=0.5)
    parser.add_argument("--warmup_epochs", type=int, default=2)
    parser.add_argument("--sustain_epochs", type=int, default=1)

    # First pass: resolve the base options so the derived sizes can be computed.
    args = parser.parse_known_args()[0]

    n_train = int(args.n_samples * (1 - args.test_size))
    n_val = int(args.n_samples * args.test_size)
    train_steps = int(np.ceil(n_train / args.batch_size))
    val_steps = int(np.ceil(n_val / args.batch_size))

    # The derived values are registered as options too, so they appear on the
    # namespace and can still be overridden explicitly from the command line.
    parser.add_argument("--n_train", type=int, default=n_train)
    parser.add_argument("--n_val", type=int, default=n_val)
    parser.add_argument("--train_steps", type=int, default=train_steps)
    parser.add_argument("--val_steps", type=int, default=val_steps)

    # Second pass picks up the derived options.
    return parser.parse_known_args()[0]

# Resolve the run configuration once; TFRWriter, DataLoader, CosineDecayWithWarmup
# and Trainer below all take this namespace as their only constructor argument.
args = ArgParser()
args

In [None]:
class TFRWriter():
    """Write the tokenized ja->en corpus to TFRecord shards.

    Reads ``{args.main_dir}/text_data/marian_text.csv``, samples ``args.n_samples``
    rows, sorts them by source length and writes ``args.n_shards`` files named
    ``{args.main_dir}/marian_tfrec/shard_<k>.tfrec``.

    Because the rows are sorted by ``ja_len`` and split without shuffling, every
    shard holds one contiguous range of source lengths.

    Args:
        args: Namespace providing ``main_dir``, ``n_samples``, ``random_state``
            and ``n_shards``.
    """
    def __init__(self, args):
        self.args = args
        self.data = self.get_data()

    def get_data(self):
        """Load, sample and length-sort the tokenized corpus.

        Returns:
            pandas.DataFrame: ``n_samples`` rows with ``ja_token`` / ``en_token``
            parsed back into Python lists, ordered by ascending ``ja_len`` with a
            fresh 0..n-1 index.
        """
        tqdm.pandas()
        data = pd.read_csv(
            f"{self.args.main_dir}/text_data/marian_text.csv", encoding="utf-8")
        # The token lists were stored in the CSV as their string repr;
        # literal_eval turns them back into lists of ints.
        data['ja_token'] = data['ja_token'].progress_apply(ast.literal_eval)
        data['en_token'] = data['en_token'].progress_apply(ast.literal_eval)
        data = data.sample(
            n=self.args.n_samples,
            random_state=self.args.random_state,
            ignore_index=True)
        data = data.sort_values(by="ja_len", ignore_index=True, ascending=True)
        return data

    def _bytes_feature(self, value):
        """Returns a bytes_list from a string / byte."""
        # Eager tensors must be unwrapped to raw bytes before going into a BytesList.
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def serialize_example(self, *args):
        """Pack three serialized tensors into one ``tf.train.Example`` byte string.

        Args:
            *args: Positional values, in order: serialized ``input_ids``,
                serialized ``attention_mask``, serialized ``label_ids``.

        Returns:
            bytes: The serialized ``tf.train.Example``.
        """
        feature = {
            'input_ids': self._bytes_feature(args[0]),
            'attention_mask': self._bytes_feature(args[1]),
            'label_ids': self._bytes_feature(args[2])}

        example_proto = tf.train.Example(
            features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    def get_shards(self):
        """Split the row positions into ``n_shards`` contiguous chunks.

        Returns:
            list: One index array per shard (the held-out fold of an unshuffled KFold).
        """
        skf = KFold(n_splits=self.args.n_shards, shuffle=False)
        return [fold_idx for _, fold_idx in skf.split(self.data)]

    def get_shard_data(self, samples):
        """Yield the serialized tensors for every row index in ``samples``.

        Args:
            samples: Iterable of row labels into ``self.data``.

        Yields:
            dict: ``input_ids``, ``attention_mask`` and ``label_ids``, each produced
            by ``tf.io.serialize_tensor`` from an int32 tensor.
        """
        # Look the columns up once instead of once per row.
        ja_tokens = self.data['ja_token']
        en_tokens = self.data['en_token']
        for sample in samples:
            input_ids = tf.convert_to_tensor(ja_tokens[sample], dtype=tf.int32)
            # Mask is 1 wherever the token id is non-zero, 0 otherwise.
            attention_mask = tf.where(input_ids != 0, x=1, y=0)
            label_ids = tf.convert_to_tensor(en_tokens[sample], dtype=tf.int32)
            yield {
                "input_ids": tf.io.serialize_tensor(input_ids),
                "attention_mask": tf.io.serialize_tensor(attention_mask),
                "label_ids": tf.io.serialize_tensor(label_ids)}

    def write(self):
        """Write every shard to ``{main_dir}/marian_tfrec/shard_<k>.tfrec`` (k from 1)."""
        out_dir = f"{self.args.main_dir}/marian_tfrec"
        # Create the output directory up front; the record writer does not.
        os.makedirs(out_dir, exist_ok=True)
        for shard, samples in tqdm(enumerate(self.get_shards()), total=self.args.n_shards):
            with tf.io.TFRecordWriter(f"{out_dir}/shard_{shard+1}.tfrec") as f:
                for sample in self.get_shard_data(samples):
                    example = self.serialize_example(
                        sample['input_ids'],
                        sample['attention_mask'],
                        sample['label_ids'])
                    f.write(example)

# TFRWriter(args).write()

In [None]:
class DataLoader:
    """Build the training and validation ``tf.data`` pipelines from TFRecord shards.

    Shard files under ``{args.main_dir}/marian_tfrec`` are split at file level into
    train and validation sets with ``train_test_split``; the resulting pipelines are
    exposed as ``self.train`` and ``self.val``.

    Args:
        args: Namespace providing ``main_dir``, ``test_size``, ``random_state``,
            ``batch_size`` and ``buffer_size``.
    """
    def __init__(self, args):
        # glob returns files in arbitrary, filesystem-dependent order; sorting makes
        # the seeded split below pick the same shards on every machine and run.
        self.files = sorted(glob.glob(args.main_dir + "/marian_tfrec/*.tfrec"))
        self.args = args
        self.AUTOTUNE = tf.data.AUTOTUNE
        self.train_files, self.val_files = train_test_split(
            self.files, test_size=args.test_size, shuffle=True,
            random_state=args.random_state)
        self.train = self.get_train()
        self.val = self.get_val()

    def read_tfrecord(self, example):
        """Decode one serialized example into three int32 tensors.

        Args:
            example: A scalar string tensor holding a serialized ``tf.train.Example``.

        Returns:
            dict: ``input_ids``, ``attention_mask`` and ``label_ids`` parsed with
            ``tf.io.parse_tensor`` as ``tf.int32``.
        """
        feature_description = {
            'input_ids': tf.io.FixedLenFeature([], tf.string),
            'attention_mask': tf.io.FixedLenFeature([], tf.string),
            'label_ids': tf.io.FixedLenFeature([], tf.string)
            }

        example = tf.io.parse_single_example(example, feature_description)
        for key in ('input_ids', 'attention_mask', 'label_ids'):
            example[key] = tf.io.parse_tensor(example[key], out_type=tf.int32)
        return example

    def load_dataset(self, files):
        """Create an un-batched dataset of decoded examples from ``files``.

        Args:
            files: List of TFRecord file paths.

        Returns:
            tf.data.Dataset: Decoded examples, with deterministic ordering switched off.
        """
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False
        dataset = tf.data.TFRecordDataset(files)
        dataset = dataset.with_options(ignore_order)
        dataset = dataset.map(self.read_tfrecord, num_parallel_calls=self.AUTOTUNE)
        return dataset

    def _pad_and_batch(self, dataset):
        """Batch ``dataset`` with per-batch padding; shared by both splits.

        Inputs and attention mask are padded with 0, labels with -100.
        NOTE(review): -100 is presumably the label value the model's loss ignores —
        confirm against the Transformers version in use.
        """
        return dataset.padded_batch(
            self.args.batch_size,
            padded_shapes={
                'input_ids': [None],
                'attention_mask': [None],
                'label_ids': [None]
            },
            padding_values={
                'input_ids': tf.constant(0, dtype=tf.int32),
                'attention_mask': tf.constant(0, dtype=tf.int32),
                'label_ids': tf.constant(-100, dtype=tf.int32)
            })

    def get_train(self):
        """Return the training pipeline: padded batches, shuffled, prefetched."""
        dataset = self.load_dataset(self.train_files)
        dataset = self._pad_and_batch(dataset)
        # Shuffling happens after batching, so whole batches are reordered and the
        # examples inside a batch stay together.
        dataset = dataset.shuffle(self.args.buffer_size)
        dataset = dataset.prefetch(self.AUTOTUNE)
        return dataset

    def get_val(self):
        """Return the validation pipeline: padded batches, shuffled, cached, prefetched."""
        dataset = self.load_dataset(self.val_files)
        dataset = self._pad_and_batch(dataset)
        dataset = dataset.shuffle(self.args.buffer_size)
        dataset = dataset.cache()
        dataset = dataset.prefetch(self.AUTOTUNE)
        return dataset

# train = DataLoader(args).train

# inputs = next(iter(train))
# input_values = inputs['input_ids']
# attention_mask = inputs['attention_mask']
# labels = inputs['label_ids']
# print(inputs)

# Model Training

In [None]:
class BLEUMetric(tf.keras.metrics.Metric):
    """Running mean of per-batch sacreBLEU corpus scores.

    Each ``update_state`` call scores one batch with ``BLEU.corpus_score`` and adds
    it to an accumulator; ``result`` is the accumulated score divided by the number
    of batches. This is a mean of batch scores, not one corpus-level BLEU over all
    sentences.
    """
    def __init__(self, name="BLEU", **kwargs):
        super(BLEUMetric, self).__init__(name=name, **kwargs)
        self.bleu = BLEU()
        self.accumulator = self.add_weight(name="total_bleu", initializer="zeros")
        self.counter = self.add_weight(name="counter", initializer="zeros")

    @staticmethod
    def _as_reference_streams(y_true):
        """Return ``y_true`` in the list-of-reference-streams layout sacreBLEU expects.

        A flat list of strings (one reference per hypothesis) is wrapped into a
        single stream; anything already nested is passed through unchanged.
        """
        refs = list(y_true)
        if refs and isinstance(refs[0], str):
            return [refs]
        return refs

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Score one batch and fold it into the running mean.

        Args:
            y_true: Reference sentences, either a flat list of strings aligned with
                ``y_pred`` or an already nested list of reference streams.
            y_pred: Hypothesis sentences (list of strings).
            sample_weight: Accepted for Keras API compatibility; not used.
        """
        hypotheses = list(y_pred)
        # sacreBLEU takes references as a list of streams, each aligned with the
        # hypotheses. Passing the flat list would make every reference *string* a
        # stream of one-character "sentences" and produce a meaningless score.
        references = self._as_reference_streams(y_true)
        try:
            bleu_score = self.bleu.corpus_score(
                hypotheses=hypotheses, references=references).score
        except Exception:
            # Same fallback as before: retry once with empty hypotheses replaced
            # by a single period. A second failure propagates to the caller.
            hypotheses = ["." if x == "" else x for x in hypotheses]
            bleu_score = self.bleu.corpus_score(
                hypotheses=hypotheses, references=references).score
        self.accumulator.assign_add(bleu_score)
        self.counter.assign_add(1)

    def result(self):
        """Return the mean batch score (0 when nothing has been accumulated)."""
        return tf.math.divide_no_nan(self.accumulator, self.counter)

    def reset_state(self):
        """Zero the accumulated score and the batch counter."""
        self.accumulator.assign(0.0)
        self.counter.assign(0.0)

class WERMetric(tf.keras.metrics.Metric):
    """Running mean of per-batch word error rate computed with ``jiwer.wer``."""
    def __init__(self, name="WER", **kwargs):
        super(WERMetric, self).__init__(name=name,  **kwargs)
        self.accumulator = self.add_weight(name="total_wer", initializer="zeros")
        self.counter = self.add_weight(name="wer_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add one batch's WER to the running totals.

        Args:
            y_true: Reference text(s), passed to ``jiwer.wer`` as its first argument.
            y_pred: Hypothesis text(s), passed as its second argument.
            sample_weight: Accepted for Keras API compatibility; not used.
        """
        wer = jiwer.wer(y_true, y_pred)

        # Add the batch score and bump the batch count.
        self.accumulator.assign_add(wer)
        self.counter.assign_add(1)

    def result(self):
        """Return accumulated WER divided by the number of batches (0 if none)."""
        # Mimics a mean reduction over batches.
        return tf.math.divide_no_nan(self.accumulator, self.counter)

    def reset_state(self):
        """Zero the accumulated score and the batch counter.

        Named ``reset_state`` to match ``BLEUMetric`` and the current Keras metric API.
        """
        self.accumulator.assign(0.0)
        self.counter.assign(0.0)

    def reset_states(self):
        """Legacy spelling kept so existing callers keep working."""
        self.reset_state()

class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Epoch-indexed learning-rate curve: linear warm-up, plateau, cosine decay.

    The argument of ``__call__`` is an *epoch* number, and all phase lengths are
    read from ``args`` in epochs.
    NOTE(review): Keras optimizers call a schedule object with their iteration
    (step) counter, so passing this object straight to an optimizer would make it
    treat step numbers as epochs — verify how callers drive it.

    Args:
        args: Namespace providing ``epochs``, ``warmup_epochs``, ``sustain_epochs``,
            ``lr_start``, ``lr_max``, ``lr_min`` and ``n_cycles``.
    """
    def __init__(self, args):
        self.args = args

    def __call__(self, epoch):
        """Return the learning rate for ``epoch``.

        * ``epoch < warmup_epochs``: linear ramp from ``lr_start`` towards ``lr_max``.
        * next ``sustain_epochs`` epochs: constant ``lr_max``.
        * afterwards: cosine curve scaled by ``lr_max - lr_min``, clamped from below
          at ``lr_min``.

        Returns:
            A Python float during warm-up / plateau, a scalar tensor during decay.
        """
        if epoch < self.args.warmup_epochs:
            lr = ((self.args.lr_max - self.args.lr_start) / self.args.warmup_epochs) * epoch + self.args.lr_start
        elif epoch < (self.args.warmup_epochs + self.args.sustain_epochs):
            lr = self.args.lr_max
        else:
            # Fraction of the decay phase completed (0 at its start, 1 at `epochs`).
            progress = ((epoch - self.args.warmup_epochs - self.args.sustain_epochs) /
            (self.args.epochs - self.args.warmup_epochs - self.args.sustain_epochs))
            # Use the real value of pi; the former 22/7 approximation is about 0.04% too
            # large, so the curve did not land exactly on the cosine minimum.
            lr = (self.args.lr_max-self.args.lr_min) * (0.5 * (1.0 + tf.math.cos(np.pi *
                self.args.n_cycles * 2.0 * progress)))
            if self.args.lr_min is not None:
                lr = tf.math.maximum(self.args.lr_min, lr)
        return lr

    def plot(self):
        """Plot the learning rate for epochs ``0 .. args.epochs`` inclusive."""
        epochs = range(self.args.epochs+1)
        # Convert to plain floats so the mix of Python floats and tensors plots uniformly.
        lr = [float(self(epoch)) for epoch in epochs]
        plt.plot(epochs, lr)
        # Epochs are on the x-axis and the rate on the y-axis (labels were swapped).
        plt.xlabel("epochs")
        plt.ylabel("learning_rate")
        plt.show()

# Preview the learning-rate curve produced by the current `args` before training.
CosineDecayWithWarmup(args).plot()

In [None]:
class Trainer:
    """Fine-tune ``Helsinki-NLP/opus-mt-ja-en`` on the TFRecord data with a custom loop.

    Construction builds the tokenizer, both data pipelines, the optimizer, the BLEU
    metric and the pretrained model, and creates the CSV log file if it is missing.
    ``fit`` runs the epochs, checkpoints after each one and appends to the log.

    Args:
        args: Namespace providing ``main_dir``, ``n_samples``, ``epochs``,
            ``train_steps`` and everything ``DataLoader`` / ``CosineDecayWithWarmup`` read.
    """
    def __init__(self, args):
        self.args = args
        self.tokenizer = MarianTokenizer(
            source_spm=f"{args.main_dir}/text_data/ja_spm.model",
            target_spm=f"{args.main_dir}/text_data/en_spm.model",
            source_lang="ja",
            target_lang="en",
            eos_token="</s>",
            unk_token="<unk>",
            pad_token="<pad>",
            vocab=f"{args.main_dir}/text_data/vocab.json")

        # Build the pipelines once; constructing DataLoader twice repeated the file
        # listing, the split and both dataset graphs for no benefit.
        loader = DataLoader(args)
        self.train_dataset = loader.train
        self.val_dataset = loader.val

        # The schedule is indexed by epoch. A Keras optimizer given a schedule object
        # calls it with its iteration counter, which would burn through the whole
        # warm-up/decay curve within the first few training steps. The optimizer
        # therefore holds a plain float that fit() overwrites once per epoch.
        self.schedule = CosineDecayWithWarmup(args)
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=float(self.schedule(0)))
        self.bleu_metric = BLEUMetric()

        self.model = TFMarianMTModel.from_pretrained(
            "Helsinki-NLP/opus-mt-ja-en",
            bad_words_ids=[[0]],
            bos_token_id=1,
            eos_token_id=1,
            pad_token_id=0,
            decoder_start_token_id=0,
            from_pt=True,
            use_cache=False
        )

        self.model_name = f"model_{int(self.args.n_samples/1000)}k"
        self.weights_dir = f"{self.args.main_dir}/model_weights"
        # The log and the per-epoch weight files both live here; make sure it exists.
        os.makedirs(self.weights_dir, exist_ok=True)
        self.log_path = f"{self.weights_dir}/{self.model_name}.csv"
        if not os.path.exists(self.log_path):
            print("Log file created.")
            columns = "epoch,loss,bleu,val_loss,val_bleu\n"
            with open(self.log_path, "a") as f:
                f.write(columns)

    def decoder(self, labels, logits):
        """Turn label ids and model logits into decoded strings.

        Args:
            labels: Batch of label ids in which padding is marked with -100.
            logits: Model output scores; the argmax over the last axis is decoded.

        Returns:
            tuple: ``(reference_strings, predicted_strings)``.
        """
        # -100 is not a vocabulary id; map it to 0 (the pad id) before decoding.
        labels = tf.where(labels == -100, x=0, y=labels)
        labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        logits = tf.argmax(logits, axis=-1)
        logits = self.tokenizer.batch_decode(logits, skip_special_tokens=True)
        return labels, logits

    def display(self, epoch, t_labels, t_logits, v_labels, v_logits):
        """Print target / predicted pairs for one training and one validation batch.

        ``epoch`` is accepted but not used in the output.
        """
        print("-" * 129)
        print("Training")
        for y_true, y_pred in zip(t_labels, t_logits):
            print(f"Target:    {y_true}")
            print(f"Predicted: {y_pred}")

        print("\nValidation")
        for y_true, y_pred in zip(v_labels, v_logits):
            print(f"Target:    {y_true}")
            print(f"Predicted: {y_pred}")
        print("-" * 129)

    def fit(self):
        """Run training from the last completed epoch up to ``args.epochs``.

        Per epoch: set the learning rate, train over the training pipeline, evaluate
        on the validation pipeline, print sample outputs, write a checkpoint, append
        a CSV log row and save the model weights as ``.h5``.
        """
        # Checkpointing. The manager writes `{ckpt_dir}/{model_name}-<n>` (same naming
        # as before) and is now also the one saving, so `max_to_keep` is enforced.
        self.ckpt_dir = f"{self.args.main_dir}/checkpoints"
        self.ckpt = tf.train.Checkpoint(self.model)
        self.ckpt_manager = tf.train.CheckpointManager(
            checkpoint=self.ckpt, directory=self.ckpt_dir, max_to_keep=5,
            checkpoint_name=self.model_name)

        if self.ckpt_manager.latest_checkpoint:
            # The numeric suffix of the checkpoint path is the number of finished epochs.
            self.start_epoch = int(self.ckpt_manager.latest_checkpoint.split("-")[-1])
            self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
            print(f"Resuming from epoch {self.start_epoch + 1}...")
        else:
            self.start_epoch = 0
            print("Starting from epoch 1...")

        # Start at the restored epoch; previously the loop always began at 0 and
        # repeated epochs that had already been trained.
        for epoch in range(self.start_epoch, self.args.epochs):
            lr = float(self.schedule(epoch))
            self.optimizer.learning_rate = lr
            print(f"Epoch {epoch+1}/{self.args.epochs}: Learning rate @ {lr:.2e}")
            stateful_metrics = ["loss", "bleu", "val_loss", "val_bleu"]
            progbar = tf.keras.utils.Progbar(
                self.args.train_steps, interval=0.05,
                stateful_metrics=stateful_metrics)

            # Training loop
            t_loss_sum, t_batches = 0.0, 0
            for step, t_batch in enumerate(self.train_dataset):
                with tf.GradientTape() as tape:
                    t_loss, t_logits = self.model(
                        input_ids=t_batch['input_ids'],
                        attention_mask=t_batch['attention_mask'],
                        labels=t_batch['label_ids'],
                        training=True)[:2]

                gradients = tape.gradient(t_loss, self.model.trainable_weights)
                self.optimizer.apply_gradients(zip(gradients, self.model.trainable_weights))

                # Track the mean loss over the epoch instead of only the last batch.
                t_loss_sum += float(tf.reduce_mean(t_loss))
                t_batches += 1
                t_loss_mean = t_loss_sum / t_batches

                t_labels, t_logits = self.decoder(t_batch['label_ids'], t_logits)
                self.bleu_metric.update_state(t_labels, t_logits)
                t_bleu = self.bleu_metric.result()
                t_values = [
                    ("loss", t_loss_mean),
                    ("bleu", t_bleu)]
                # step is 0-based; the bar counts completed batches.
                progbar.update(step + 1, values=t_values, finalize=False)
            self.bleu_metric.reset_state()

            # Validation loop
            v_loss_sum, v_batches = 0.0, 0
            for v_batch in self.val_dataset:
                v_loss, v_logits = self.model(
                    input_ids=v_batch['input_ids'],
                    attention_mask=v_batch['attention_mask'],
                    labels=v_batch['label_ids'],
                    training=False)[:2]
                v_loss_sum += float(tf.reduce_mean(v_loss))
                v_batches += 1
                v_labels, v_logits = self.decoder(v_batch['label_ids'], v_logits)
                self.bleu_metric.update_state(v_labels, v_logits)

            v_loss_mean = v_loss_sum / max(v_batches, 1)
            v_bleu = self.bleu_metric.result()
            v_values = [
                ("loss", t_loss_mean),
                ("bleu", t_bleu),
                ("val_loss", v_loss_mean),
                ("val_bleu", v_bleu)]
            progbar.update(self.args.train_steps, values=v_values, finalize=True)
            self.bleu_metric.reset_state()

            # Print sample transcriptions for both loops (last batch of each)
            self.display(epoch, t_labels, t_logits, v_labels, v_logits)

            # Checkpointing: number the checkpoint by finished epochs so that resume
            # can read the epoch straight from the file name.
            self.ckpt_manager.save(checkpoint_number=epoch + 1)

            # Logging (explicit floats so the CSV holds plain numbers)
            log = f"{epoch+1},{t_loss_mean},{float(t_bleu)},{v_loss_mean},{float(v_bleu)}\n"
            with open(self.log_path, "a") as f:
                f.write(log)

            self.model.save_weights(f"{self.weights_dir}/{self.model_name}_{epoch+1}.h5")

# Build the trainer (tokenizer, data pipelines, optimizer, pretrained model) and
# start the training loop; this is the long-running cell of the notebook.
Trainer(args).fit()