In [None]:
import sys
sys.path.append("..")

In [None]:
# import libraries
import os
import gc
import pandas as pd
import tensorflow as tf

from arjun import ArjunModel, ArjunTokenizer
from generator import BatchGenerator
from loss import masked_loss, masked_accuracy
from scheduler import LinearLRSchedule

In [None]:
# config
DATA_PATH = "../resources/data/raw_text.txt"
TOKENIZER_PATH = "../resources/vac-arjun-base"
CHECKPOINT_DIR = "../resources/checkpoints"
SEED = 42

# Seed Python's `random`, NumPy and TensorFlow in one call so weight initialisation, dropout
# and any NumPy/`random`-based sampling (e.g. the unseeded `df.sample` shuffle, or masking in the
# batch generator if it uses those RNGs) repeat across Restart & Run All.
tf.keras.utils.set_random_seed(SEED)

In [None]:
# Load the pretrained Arjun tokenizer from disk and display its length
# (presumably the vocabulary size; compare with config["vocab_size"] in the model cell).
tokenizer = ArjunTokenizer.from_pretrained(
    TOKENIZER_PATH
)
len(tokenizer)

In [None]:
# read data: one training example per non-blank line, deduplicated, then shuffled
SHUFFLE_SEED = 42  # fixed so the example order is reproducible across runs

# explicit encoding: the platform default codec is not UTF-8 everywhere (e.g. Windows)
with open(DATA_PATH, "r", encoding="utf-8") as f:
    # drop blank lines, including the empty string a trailing newline leaves behind,
    # which would otherwise become an empty training example
    data = [line for line in f.read().split("\n") if line.strip()]

# dict.fromkeys removes duplicates while keeping first-seen order; unlike set(), that order
# does not depend on the per-process string hash seed, so the seeded shuffle is reproducible
data = list(dict.fromkeys(data))
df = pd.DataFrame({"text": data}).sample(frac=1.0, random_state=SHUFFLE_SEED, ignore_index=True)
del data
gc.collect()

assert not df.empty, f"no training text found in {DATA_PATH}"
len(df)

In [None]:
# define model to train from scratch
config = {
    "num_layers": 6,
    "d_model": 512,
    "dff": 2048,
    "num_heads": 8,
    "dropout_rate": 0.1,
    "max_len": 256,      # also used as the generator's max_len in the training cell
    "corr_prob": 0.15,   # also passed to the generator in the training cell
    "vocab_size": 50057,
}

# Tokenizer ids presumably index the model's token embedding, so the model vocabulary must
# cover every id the tokenizer can emit. A mismatch is dangerous: on GPU, TensorFlow gathers
# with out-of-range indices return zeros instead of raising, so training would proceed silently.
assert config["vocab_size"] >= len(tokenizer), (
    f"config['vocab_size']={config['vocab_size']} is smaller than the tokenizer "
    f"vocabulary ({len(tokenizer)})"
)

model = ArjunModel(config)
model.summary()

In [None]:
# set optimizer, loss and compile
batch_size = 32
num_epochs = 20
initial_learning_rate = 5e-5
learning_rate = LinearLRSchedule(df, batch_size, num_epochs, initial_learning_rate)
# `tf.keras.optimizers.AdamW` is the stable name (available since TF 2.11); the
# `optimizers.experimental` namespace no longer exists under Keras 3.
optimizer = tf.keras.optimizers.AdamW(learning_rate=learning_rate)

# every model output is trained with the same masked loss and tracked with the same masked accuracy
OUTPUT_NAMES = ("ms_output", "mt_output", "nt_output", "cs_output")
model.compile(
    loss={name: masked_loss for name in OUTPUT_NAMES},
    optimizer=optimizer,
    metrics={name: masked_accuracy for name in OUTPUT_NAMES},
)

In [None]:
# define callbacks
# CSVLogger opens its file when training starts; create the output directory up front so a
# fresh checkout (where it may not exist yet) does not fail at the start of fit().
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
# weights only, one file per epoch, named "<epoch>.weights.h5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(os.path.join(CHECKPOINT_DIR, "{epoch}.weights.h5"), save_weights_only=True)
# per-epoch loss/metric values written as CSV
logs = tf.keras.callbacks.CSVLogger(os.path.join(CHECKPOINT_DIR, "logs.csv"))

In [None]:
# train model
# BatchGenerator (project code) is fed the shuffled text frame with the same max_len and
# corruption probability the model was configured with.
gen = BatchGenerator(df=df, batch_size=batch_size, tokenizer=tokenizer, max_len=config["max_len"], corr_prob=config["corr_prob"])
# keep the History instead of discarding it, so per-epoch losses/metrics stay inspectable
history = model.fit(gen, epochs=num_epochs, callbacks=[checkpoint, logs])
pd.DataFrame(history.history)