In [None]:
import os
import src
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
class config:
    """Static settings consumed by the training cell of this notebook.

    The class is used purely as a namespace: attributes are read directly
    (``config.SEED``) and it is never instantiated in the visible cells.
    """

    # --- Reproducibility / hardware -------------------------------------
    SEED = 42              # handed to seed_everything()
    ACCELERATOR = "cpu"    # Trainer(accelerator=...)

    # --- Trainer ----------------------------------------------------------
    EPOCHS = 200                # Trainer(max_epochs=...)
    BATCH_SIZE = 4              # batch size given to the data module
    VAL_EACH_EPOCH = 2          # Trainer(check_val_every_n_epoch=...)
    LEARNING_RATE = 1e-4        # learning rate given to OCRTrainer
    ENCODER_FEATURE_DIM = 256   # used for both hidden_size and out_channels of CRNN

    # --- Data -------------------------------------------------------------
    DATA_DIR = "./data1"

    # --- Tensorboard ------------------------------------------------------
    # DIR / NAME / VERSION map onto TensorBoardLogger(save_dir, name, version).
    TENSORBOARD = dict(DIR="", NAME="LOG", VERSION="0")

    # --- Checkpoint -------------------------------------------------------
    # Checkpoints are placed in a "CKPT" folder under <DIR>/<NAME>/<VERSION>.
    CHECKPOINT_DIR = os.path.join(
        TENSORBOARD["DIR"],
        TENSORBOARD["NAME"],
        TENSORBOARD["VERSION"],
        "CKPT",
    )

    # Checkpoint path for testing a trained model (None = not set; not
    # referenced by the visible cells).
    TEST_CKPT_PATH = None

    # Checkpoint path to resume training from; forwarded to
    # trainer.fit(ckpt_path=...). None starts a fresh run.
    CONTINUE_TRAINING = None

In [None]:
# Seed the global RNGs before anything stochastic is constructed.
# workers=True additionally lets Lightning derive distinct, reproducible seeds
# for dataloader worker processes; without it the deterministic=True request
# below does not cover any randomness inside workers (only relevant if the
# data module uses num_workers > 0 -- not visible from this notebook).
seed_everything(config.SEED, workers=True)

# Data pipeline.
dm = src.CaptchaDataModule(data_dir=config.DATA_DIR, batch_size=config.BATCH_SIZE)

# Network and the Lightning module that wraps it for training.
# The same feature dimension is used for hidden_size and out_channels.
model = src.CRNN(
    hidden_size=config.ENCODER_FEATURE_DIM,
    out_channels=config.ENCODER_FEATURE_DIM,
)
system = src.OCRTrainer(model, learning_rate=config.LEARNING_RATE)

# Keep the 3 checkpoints with the lowest "val_loss"; stop early when
# "val_loss" stops improving (patience is left at the library default).
# NOTE(review): both callbacks rely on the module logging a metric named
# "val_loss" -- confirm in src.OCRTrainer.
checkpoint_callback = ModelCheckpoint(
    dirpath=config.CHECKPOINT_DIR,
    monitor="val_loss",
    save_top_k=3,
    mode="min",
)
early_stopping = EarlyStopping(monitor="val_loss", mode="min")

logger = TensorBoardLogger(
    save_dir=config.TENSORBOARD["DIR"],
    name=config.TENSORBOARD["NAME"],
    version=config.TENSORBOARD["VERSION"],
)

trainer = Trainer(
    accelerator=config.ACCELERATOR,
    max_epochs=config.EPOCHS,
    check_val_every_n_epoch=config.VAL_EACH_EPOCH,
    gradient_clip_val=1.0,
    accumulate_grad_batches=5,  # gradients are accumulated over 5 batches per optimizer step
    log_every_n_steps=10,
    deterministic=True,
    enable_checkpointing=True,
    default_root_dir=config.CHECKPOINT_DIR,
    callbacks=[checkpoint_callback, early_stopping],
    logger=logger,
)

# ckpt_path=None starts from scratch; a path resumes from that checkpoint.
trainer.fit(model=system, datamodule=dm, ckpt_path=config.CONTINUE_TRAINING)