In [None]:
import os
import sys

# Absolute path of the directory one level above the current working directory.
# It is put on sys.path so the project-local imports in the next cell
# (data, dataloaders, models, utils) can presumably be resolved from there.
ROOT_DIR = os.path.abspath(os.pardir)

# Guarded so that re-running this cell does not keep appending duplicates.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# Set the Weights & Biases WANDB_SILENT environment variable; this happens
# ahead of the `import wandb` in the next cell.
os.environ["WANDB_SILENT"] = "true"

In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger
import lightning as L
import wandb

from data.dataloader import NoReCDataLoader
from data.preprocessor import NoReCDataPreprocessor
from dataloaders.lstm import NoReCDataModule
from models.lstm import LSTM
from utils.utils import init_run

In [None]:
# Run configuration for this experiment. The returned object is read below
# through its `dataloader`, `data`, `general` and `model` sections.
config = init_run(
    config_name="multiclass_lstm",
    run_name="Multiclass-LSTM",
)

# Loading and processing data

In [None]:
# Load the train / validation / test frames of the multiclass dataset, with
# the loader configured from the `dataloader` section of the run config.
dataset_loader = NoReCDataLoader(**config.dataloader)
train_df, val_df, test_df = dataset_loader.load_multiclass_dataset()

preprocessor = NoReCDataPreprocessor()

# Stage 1: sanitize each split; the split's name is passed along with its frame.
train_df, val_df, test_df = (
    preprocessor.sanitize(frame, split_name)
    for frame, split_name in ((train_df, "train"), (val_df, "val"), (test_df, "test"))
)

# The vocabulary and tokenizer are built from the training split only and are
# then reused unchanged for the validation and test splits.
vocab, tokenizer = preprocessor.build_vocabulary(train_df, config.data.vocab_size)

# Stage 2: tokenize every split with that shared vocabulary / tokenizer.
train_df, val_df, test_df = (
    preprocessor.tokenize(frame, vocab, tokenizer)
    for frame in (train_df, val_df, test_df)
)

# Stage 3: pad every split using the configured maximum sequence length.
train_df, val_df, test_df = (
    preprocessor.pad(frame, vocab, config.data.max_seq_len)
    for frame in (train_df, val_df, test_df)
)

In [None]:
# Wrap the three prepared splits in the data module that is later handed to
# the trainer; the batch size comes from the `general` section of the config.
data_module = NoReCDataModule(
    train_df=train_df, val_df=val_df, test_df=test_df,
    batch_size=config.general.batch_size,
)

# Class weights as reported by the data module; they are forwarded to the
# model constructor in the training cell.
class_weights = data_module.get_class_weights()

# Modeling and training

In [None]:
# Build the model from the `model` section of the config, with three output
# classes and the class weights computed by the data module.
model = LSTM(**config.model, n_class=3, class_weights=class_weights)

# Early stopping on the "val_auc" metric (higher is better), with a patience
# of 3 checks.
# NOTE(review): the check is tied to the end of the *training* epoch while the
# monitored metric is a validation metric - confirm this is intended.
early_stopping = EarlyStopping(
    monitor="val_auc",
    patience=3,
    mode="max",
    verbose=True,
    check_on_train_epoch_end=True,
)
trainer = L.Trainer(
    max_epochs=config.general.max_epochs,
    logger=WandbLogger(save_dir=config.general.log_dir),
    callbacks=[early_stopping],
)

# Fit, then evaluate on the validation and test splits. The W&B run is always
# closed, including when a step raises or is interrupted; otherwise it would
# stay open in the kernel and could be picked up again by a later re-run.
try:
    trainer.fit(model, data_module)
    trainer.validate(model, data_module)
    trainer.test(model, data_module)
except BaseException:
    # Mark the run as failed, then let the original error propagate.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()