In [None]:
from lightning import seed_everything
from datamodules.ws_datamodule import WSDataModule
from src.train import train_cae
from utils.loggers import get_loggers
from utils.callbacks import get_callbacks
from src.augmentations import build_augmentation_pipeline
from utils.config_utils import load_config, get_image_paths_from_dir, load_panel
from lightning.pytorch import Trainer
from models.cae_resnet import ConvAutoencoder, ResNetEncoder, Decoder
from models.cae_lightning_module import CAELightningModule

## Read the YAML configuration once, then split it into its sections
config = load_config('config/config.yaml')

# One lookup per top-level section, in the order the names are listed.
# NOTE(review): assuming `config` is dict-like, a section missing from the
# file yields None here and only fails later, on first use -- confirm.
dir_cfg, model_cfg, train_cfg, preproc_cfg, aug_cfg = (
    config.get(section)
    for section in (
        'directories',
        'model',
        'training',
        'preprocessing',
        'augmentation',
    )
)

## Seed the random number generators before anything else is built
# `workers=True` asks Lightning to also seed dataloader worker processes
# (per the Lightning docs for seed_everything).
seed_everything(seed=train_cfg.get('seed'), workers=True)

# Build the experiment loggers and trainer callbacks
# (TensorBoard, CSV and checkpointing, per the original author's note).
loggers = get_loggers(
    dir_cfg.get('logs'),
    model_cfg.get('name'),
)
callbacks = get_callbacks(
    dir_cfg.get('checkpoints'),
)

## Gather the dataset inputs: augmentations, image files and the panel
data_transforms = build_augmentation_pipeline(aug_cfg)
image_paths = get_image_paths_from_dir(dir_cfg.get('mcd_dir'))
panel = load_panel(dir_cfg.get('panel'))

# Hand everything to the datamodule, which (per the original author's note)
# sets up the train/test datasets and their loaders.
# `patch_size` and `stride` are passed explicitly in addition to the whole
# preprocessing section they are read from.
datamodule = WSDataModule(
    image_paths=image_paths,
    patch_size=preproc_cfg.get('patch_size'),
    stride=preproc_cfg.get('stride'),
    preproc_cfg=preproc_cfg,
    transforms=data_transforms,
    panel=panel,
    batch_size=train_cfg.get('batch_size'),
    num_workers=train_cfg.get('num_workers'),
)

## Assemble the autoencoder and wrap it for Lightning
# Encoder and decoder are built with their default arguments.
encoder = ResNetEncoder()
decoder = Decoder()
cae = ConvAutoencoder(encoder, decoder)
model = CAELightningModule(cae)

## Configure the trainer from the training section of the config
trainer = Trainer(
    # Hardware: let Lightning pick the backend, device count from config.
    accelerator="auto",
    devices=train_cfg.get('num_devices'),
    max_epochs=train_cfg.get('epochs'),
    # Original author's note: fixed input size, so benchmark mode speeds up
    # training.
    # NOTE(review): cuDNN benchmark mode may pick different kernels between
    # runs, which can weaken the reproducibility requested via
    # seed_everything -- confirm this trade-off is intended.
    benchmark=True,
    logger=loggers,
    callbacks=callbacks,
    log_every_n_steps=32,
)

In [None]:
## Fit the model on the data supplied by the datamodule
# NOTE(review): `datamodule` is passed positionally, i.e. in Trainer.fit's
# `train_dataloaders` slot. Lightning documents that a LightningDataModule is
# accepted there; `datamodule=datamodule` would be more explicit -- TODO
# confirm WSDataModule derives from LightningDataModule.
trainer.fit(model, datamodule)

In [None]:
## Evaluate the model and return the metrics
# As the last expression of the cell, the value returned by trainer.test is
# what the notebook displays.
# NOTE(review): because the model object is passed and no `ckpt_path` is
# given, Lightning evaluates the weights currently in memory (the state after
# the fit above), not necessarily the best saved checkpoint -- confirm that
# this is the intended evaluation target.
trainer.test(model, datamodule)