In [None]:
import os
import yaml

from src.models import *
from src.dataset import DataModule
from src.trainers import PhoBERTModel, FastTextLSTMModel

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

In [None]:
class config:
    """Single place for every tunable value used by the cells below."""

    # ---- data ----
    root_data_dir = "./data"
    model_type = "bert"  # use "lstm" when training the FastText-LSTM model
    batch_size = 16
    num_workers = 2
    # None, or a path to a FastText embedding such as
    # src/embedding/fasttext_train_dev.model
    fasttext_embedding = None

    # ---- model ----
    # Accepted names: [BERT | LSTM]-[FF | LSTM]-[BASE | LARGE],
    # or FASTTEXT-LSTM for the FastText + LSTM model.
    model_name = "BERT-FF-BASE"
    from_pretrained = True
    freeze_backbone = False
    drop_out = 0.1
    out_channels = 3
    vector_size = 300  # FastText only

    # ---- trainer ----
    seed = 42
    max_epochs = 100
    val_each_epoch = 2  # forwarded to Trainer(check_val_every_n_epoch=...)
    learning_rate = 1e-4
    # Spelling kept as-is: the Trainer cell reads `config.accelarator`.
    accelarator = "gpu"

    # ---- TensorBoard logging ----
    tensorboard = dict(dir="logging", name="experiment", version=0)

    # ---- checkpoints ----
    # Directory handed to ModelCheckpoint and to Trainer(default_root_dir=...).
    # NOTE(review): TensorBoardLogger may name integer versions `version_<n>`;
    # if so this path is not inside the logger's run folder — confirm.
    ckpt_dir = "logging/experiment/0/ckpt"

    # Checkpoint passed to trainer.test(); None means no explicit checkpoint.
    test_ckpt = None

    # Checkpoint passed to trainer.fit() to resume training; None starts fresh.
    keep_training_path = None

In [None]:
# Apply the configured seed before any data or model object is created.
# `config.seed` was declared and `seed_everything` imported, but the seed was
# never applied, so runs were not repeatable even though the Trainer is built
# with `deterministic=True`. `workers=True` also seeds dataloader workers.
seed_everything(config.seed, workers=True)

# Build the data module from the data-related settings in `config`.
dm = DataModule(
    root_data_dir=config.root_data_dir,
    model_type=config.model_type,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    fasttext_embedding=config.fasttext_embedding,
)

In [None]:
# Run the 'fit' stage setup now so the training split exists before the model is built.
dm.setup('fit')
# Weights read from the training split; passed as `loss_weight` to the training system below.
# NOTE(review): presumably per-class weights to counter label imbalance — confirm in src.dataset.
loss_weight = dm.train_data.class_weights

In [None]:
# Keyword arguments shared by every PhoBERT-based backbone constructor.
backbone_kwargs = dict(
    from_pretrained=config.from_pretrained,
    freeze_backbone=config.freeze_backbone,
    drop_out=config.drop_out,
    out_channels=config.out_channels,
)

# Pick the backbone class that matches `config.model_name`.
if config.model_name == "BERT-FF-BASE":
    model = PhoBertFeedForward_base(**backbone_kwargs)
elif config.model_name == "BERT-FF-LARGE":
    model = PhoBertFeedForward_large(**backbone_kwargs)
elif config.model_name == "BERT-LSTM-BASE":
    model = PhoBERTLSTM_base(**backbone_kwargs)
elif config.model_name == "BERT-LSTM-LARGE":
    model = PhoBERTLSTM_large(**backbone_kwargs)
elif config.model_name == "FASTTEXT-LSTM":
    # No separate backbone object: the FastText system below is built
    # directly from config values.
    pass
else:
    raise ValueError(f"Not support {config.model_name}")

# system configuration: wrap the backbone (or FastText settings) in the
# Lightning module that the Trainer will run.
if not config.model_name.startswith("FASTTEXT"):
    system = PhoBERTModel(
        model=model,
        out_channels=config.out_channels,
        loss_weight=loss_weight,
    )
else:
    system = FastTextLSTMModel(
        dropout=config.drop_out,
        out_channels=config.out_channels,
        hidden_size=config.vector_size,
        loss_weight=loss_weight,
    )

In [None]:
# Keep the three checkpoints with the lowest monitored "val_loss".
checkpoint_callback = ModelCheckpoint(
    dirpath=config.ckpt_dir,
    monitor="val_loss",
    save_top_k=3,
    mode="min",
)

# Stop training when "val_loss" stops decreasing (no patience is passed, so
# the library default applies).
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
)

# TensorBoard logger configured from `config.tensorboard`.
logger = TensorBoardLogger(
    save_dir=config.tensorboard['dir'],
    name=config.tensorboard['name'],
    version=config.tensorboard['version'],
)

# Trainer wiring: device, validation cadence, gradient handling, callbacks
# and logging.
trainer = Trainer(
    accelerator=config.accelarator,
    max_epochs=config.max_epochs,
    check_val_every_n_epoch=config.val_each_epoch,
    gradient_clip_val=1.0,
    accumulate_grad_batches=4,
    deterministic=True,
    enable_checkpointing=True,
    default_root_dir=config.ckpt_dir,
    callbacks=[checkpoint_callback, early_stopping],
    logger=logger,
    log_every_n_steps=1,
)

In [None]:
# Train the system; `config.keep_training_path` is forwarded as the
# checkpoint to resume from (None in the default config).
trainer.fit(
    model=system,
    datamodule=dm,
    ckpt_path=config.keep_training_path,
)

In [None]:
# Evaluate on the test split, using `config.test_ckpt` as the checkpoint.
# NOTE(review): with `test_ckpt = None` and a model instance passed, Lightning
# is documented to evaluate the current in-memory weights rather than the best
# saved checkpoint — confirm this is the intended evaluation.
trainer.test(
    model=system,
    datamodule=dm,
    ckpt_path=config.test_ckpt,
)