In [None]:
import os

import wandb
from lightning import seed_everything
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

from utils.extract_patches import extract_and_save_patches
from utils.config_utils import load_config, get_image_paths_from_dir, load_panel, fetch_marker_indices
from utils.callbacks import get_callbacks
from utils.loggers import get_loggers
from utils.extract_patches import extract_and_save_patches

from datamodules.ws_datamodule_cached import WSDataModuleCached

from models.cae_resnet import ResNetEncoder, Decoder, ConvAutoencoder
from models.cae_lightning_module import CAELightningModule

In [None]:
## Load and parse the configuration file
CONFIG_PATH = 'config/config.yaml'
config = load_config(CONFIG_PATH)

# Sections that later cells read via `.get(...)`. `config.get` returns None for a
# missing section, which would otherwise only surface cells later as an opaque
# "'NoneType' object has no attribute 'get'"; fail here with a clear message instead.
REQUIRED_SECTIONS = ('directories', 'model', 'training', 'preprocessing', 'logging')
missing_sections = [name for name in REQUIRED_SECTIONS if config.get(name) is None]
if missing_sections:
    raise KeyError(f"{CONFIG_PATH} is missing required section(s): {missing_sections}")

dir_cfg = config.get('directories')
model_cfg = config.get('model')
train_cfg = config.get('training')
preproc_cfg = config.get('preprocessing')
# NOTE(review): `aug_cfg` is not passed to the datamodule or model below; optional
# for now — TODO confirm whether augmentation should be wired in.
aug_cfg = config.get('augmentation')
log_cfg = config.get('logging')

In [None]:
## Set the seed for all subsequent processes
# The seed comes from the `training` section of the config. `workers=True` asks
# Lightning to also seed dataloader worker processes (see Lightning's seed_everything docs).
global_seed = train_cfg.get('seed')
seed_everything(seed=global_seed, workers=True)

In [None]:
## Build the datamodule
# Training reads cached .npy patches from `patch_dir` (presumably produced by
# `extract_and_save_patches`), not the raw images.
patch_paths = get_image_paths_from_dir(dir_cfg.get('patch_dir'), {'.npy'})
if not patch_paths:
    # Fail fast: an empty patch list would otherwise only break later, inside the
    # datamodule / trainer, with a much less obvious error.
    raise FileNotFoundError(
        f"No .npy patches found in {dir_cfg.get('patch_dir')!r}; "
        "run the patch extraction step (extract_and_save_patches) first."
    )
panel = load_panel(dir_cfg.get('panel'))

datamodule = WSDataModuleCached(
    image_paths=patch_paths,
    patch_size=preproc_cfg.get('patch_size'),
    stride=preproc_cfg.get('stride'),
    preproc_cfg=preproc_cfg,
    panel=panel,
    batch_size=train_cfg.get('batch_size'),
    num_workers=train_cfg.get('num_workers'),
)

In [None]:
## Initialize loggers and callbacks: TensorBoard, CSV, Checkpoints, wandb
# Run-level hyperparameters handed to the loggers.
hyperparams = {
    "epochs": train_cfg.get('epochs'),
    "batch_size": train_cfg.get('batch_size'),
    "seed": train_cfg.get('seed')
}
loggers = get_loggers(dir_cfg.get('logs'), model_cfg.get('name'), hyperparams)

# The W&B run name is passed to `get_callbacks` (presumably to name the checkpoint
# output). Select the W&B logger by type rather than by position (`loggers[2]`),
# so this cell does not depend on the order in which `get_loggers` returns loggers.
wandb_logger = next((lg for lg in loggers if isinstance(lg, WandbLogger)), None)
if wandb_logger is None:
    raise RuntimeError(
        "get_loggers() returned no WandbLogger; it is required for the run name "
        "and the W&B summary written after training."
    )
callbacks = get_callbacks(dir_cfg.get('checkpoints'), wandb_logger.experiment.name)

In [None]:
## Initialize patch reconstruction logging
# Resolve the configured `log_channels` entries against the panel; presumably these
# are marker names mapped to channel indices — TODO confirm in fetch_marker_indices.
log_channels = fetch_marker_indices(log_cfg.get('log_channels'), panel)

# The configured patch file names are relative to `patch_dir`; join each one onto it.
log_patch_root = dir_cfg.get('patch_dir')
log_patches = list(map(lambda patch_name: os.path.join(log_patch_root, patch_name),
                       log_cfg.get('log_patches')))

In [None]:
## Extract and save patches
# NOTE(review): this step is currently disabled. Uncommenting it as-is would raise a
# NameError: `image_paths` is not defined anywhere in this notebook. It also runs
# after the "Build the datamodule" cell has already collected `patch_paths` from
# `patch_dir`, so enabling it here would not change the data the current datamodule
# sees — it needs to run before that cell.
# extract_and_save_patches(image_paths, panel, preproc_cfg, dir_cfg.get('patch_dir'))

In [None]:
## Initialize the model and trainer
# Modules are built in a fixed order (encoder, decoder, wrapper). They presumably
# draw their initial weights from the RNG seeded above, so keep this order stable
# for reproducible runs.
encoder = ResNetEncoder()
decoder = Decoder()
cae = ConvAutoencoder(encoder, decoder)

# LightningModule wrapping the autoencoder; `log_patches` / `log_channels` select
# which patches and channels get reconstruction logging.
model = CAELightningModule(
    autoencoder=cae,
    lr=train_cfg.get('learning_rate'),
    log_patches=log_patches,
    log_channels=log_channels,
)

trainer = Trainer(
    # Hardware
    accelerator="auto",
    devices=train_cfg.get('num_devices'),
    benchmark=False,  # cuDNN autotuning explicitly disabled
    # Schedule
    max_epochs=train_cfg.get('epochs'),
    # Logging / callbacks
    logger=loggers,
    callbacks=callbacks,
    log_every_n_steps=log_cfg.get('log_step_n'),
)

In [None]:
## Fit the data
# `try`/`finally` so the W&B run is closed even if training raises.
try:
    trainer.fit(model, datamodule)

    # `trainer.checkpoint_callback` is the first ModelCheckpoint registered with the
    # trainer; use it instead of assuming the checkpoint callback is `callbacks[0]`.
    checkpoint_callback = trainer.checkpoint_callback
    best_model_score = None if checkpoint_callback is None else checkpoint_callback.best_model_score

    if best_model_score is None:
        # No best score was recorded (no checkpoint callback, or its monitored
        # metric was never logged); calling `.item()` on None would crash here.
        print("No best checkpoint score available; skipping 'best_val_loss' summary.")
    elif wandb.run is None:
        print("No active wandb run; skipping 'best_val_loss' summary.")
    else:
        best_val_loss = best_model_score.item()
        wandb.run.summary["best_val_loss"] = best_val_loss
finally:
    wandb.finish()