In [None]:
import os
import sys

# Make the parent of the current working directory importable so the
# project-local packages used by this notebook can be resolved.
ROOT_DIR = os.path.abspath("..")
# Guard the append so re-running this cell does not pile up duplicate entries.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# Ask Weights & Biases to suppress its console output for this session.
os.environ["WANDB_SILENT"] = "true"

In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from transformers import AutoTokenizer
import lightning as L
import wandb

from data.dataloader import NoReCDataLoader
from dataloaders.bert import NoReCDataModule
from models.bert import BERT
from utils.utils import init_run

In [None]:
# Classification task to run; the other supported value is "multiclass".
task = "binary"

# Model identifier used for both the tokenizer and the BERT model below.
# Larger alternatives: "ltg/norbert3-small", "ltg/norbert3-base".
MODEL_NAME = "ltg/norbert3-xs"

In [None]:
# Initialise the run with the config that matches the selected task; the run
# name is the task prefix followed by the model identifier. An unrecognised
# task is rejected here instead of silently being treated as multiclass.
if task == "binary":
    config = init_run(config_name="binary_bert", run_name="Binary-" + MODEL_NAME)
elif task == "multiclass":
    config = init_run(config_name="multiclass_bert", run_name="Multiclass-" + MODEL_NAME)
else:
    raise ValueError(f"Unknown task {task!r}; expected 'binary' or 'multiclass'.")

# Loading and processing data

In [None]:
# Reject an unrecognised task up front; previously anything other than
# "binary" (including a typo) silently loaded the multiclass dataset.
if task not in ("binary", "multiclass"):
    raise ValueError(f"Unknown task {task!r}; expected 'binary' or 'multiclass'.")

# Build the loader from the dataloader section of the run config and fetch the
# train/validation/test splits for the selected task.
norec_loader = NoReCDataLoader(**config.dataloader)
if task == "binary":
    train_df, val_df, test_df = norec_loader.load_binary_dataset()
else:
    train_df, val_df, test_df = norec_loader.load_multiclass_dataset()

# Keep only the input text and its label in every split.
train_df, val_df, test_df = (
    split[["text", "label"]] for split in (train_df, val_df, test_df)
)

In [None]:
# Load the tokenizer published under the same identifier as the model, so the
# text is encoded with the vocabulary that model checkpoint expects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
# Collect the data module arguments first: the three splits, the tokenizer,
# and the batch/sequence/class settings taken from the run config.
datamodule_kwargs = {
    "train_df": train_df,
    "val_df": val_df,
    "test_df": test_df,
    "batch_size": config.general.batch_size,
    "tokenizer": tokenizer,
    "max_seq_len": config.model.max_seq_len,
    "n_classes": config.model.n_classes,
}
data_module = NoReCDataModule(**datamodule_kwargs)

# Modeling and training

In [None]:
# Instantiate the classifier for the chosen model identifier; optimisation
# settings come from the "general" config section and the number of output
# classes from the "model" section.
general_cfg = config.general
model = BERT(
    model_name=MODEL_NAME,
    learning_rate=general_cfg.learning_rate,
    max_epochs=general_cfg.max_epochs,
    n_classes=config.model.n_classes,
)

In [None]:
# Stop training once "val_auc" (higher is better) has gone 3 checks without
# improving; the check is run at the end of each training epoch.
early_stopping = EarlyStopping(
    monitor="val_auc",
    mode="max",
    patience=3,
    verbose=True,
    check_on_train_epoch_end=True,
)
# Log the learning rate at every optimisation step.
lr_monitor = LearningRateMonitor(logging_interval="step")

# Trainer capped at the configured number of epochs, logging to Weights &
# Biases under the configured log directory.
trainer = L.Trainer(
    max_epochs=config.general.max_epochs,
    logger=WandbLogger(save_dir=config.general.log_dir),
    callbacks=[early_stopping, lr_monitor],
)

In [None]:
# Train, then run the validation and test loops. The W&B run is closed on
# every exit path: on failure (including a kernel interrupt) it is finished
# with a non-zero exit code instead of being left open in the kernel, and the
# original exception is re-raised unchanged.
try:
    trainer.fit(model, data_module)
    trainer.validate(model, data_module)
    trainer.test(model, data_module)
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()