In [None]:
import logging
from pathlib import Path

import clearml
import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from callbacks.lr_callbacks import LrDecay, LrWarmup, LrExponential
from data.nih_data_module import NIHDataModule
from loggers.clearml_logger import ClearMLLogger
from models.efficient_net_v2_module import EfficientNetV2Module

# Module-level logger for this notebook.
logger = logging.getLogger(__name__)

# Seed the global RNGs through Lightning so that repeated runs are comparable.
pl.seed_everything(42)

# All tunable settings for this run live in the YAML config.
cfg = OmegaConf.load('config/example/train.yaml')

# Register the run with ClearML under the cluster job name. Framework
# auto-logging is switched off; metrics are reported through the
# ClearMLLogger handed to the Trainer further down.
# NOTE(review): output_uri=True presumably uploads artifacts to ClearML's
# default file server — confirm this is the intended destination.
task = clearml.Task.init(
    project_name='Nih-classification',
    task_name=cfg.cluster.job_name,
    auto_connect_frameworks=False,
    output_uri=True,
)

#
# Extract and set up configuration from the config file
#
log_ver = cfg.training.version
rfc = cfg.training.restore_from_ckpt
exp_root_dir = Path(cfg.training.log_dir) / cfg.cluster.job_name

# Integer versions become the conventional "version_<n>" directory name;
# any other value is used verbatim as the directory name.
if isinstance(log_ver, int):
    version = f'version_{log_ver}'
else:
    version = log_ver

exp_str = 'no_exp'

# Layout: <log_dir>/<job_name>/<exp_str>/<version>/checkpoints
log_dir = exp_root_dir.joinpath(exp_str, version)
checkpoint_dir = log_dir.joinpath('checkpoints')

# The checkpoint to resume from is given relative to the experiment root;
# a falsy value means "start from scratch".
if rfc:
    checkpoint_file = exp_root_dir / rfc
else:
    checkpoint_file = None

es_cfg = cfg.training.early_stopping

lr_decay_cfg, lr_exponential_cfg, lr_warmup_cfg = (
    cfg.hparams.lr_decay,
    cfg.hparams.lr_exponential,
    cfg.hparams.lr_warmup,
)

#
# Define callbacks
#
callbacks = []

# Validation metric that drives BOTH checkpoint selection and early stopping.
# The two callbacks must watch the same logged key: a callback monitoring a
# key that is never logged either errors out or silently never fires.
# NOTE(review): assumes the model logs exactly this key — confirm against
# EfficientNetV2Module.
VAL_METRIC = 'auroc_avg/val'

ckpt_params = dict(
    dirpath=checkpoint_dir,
    verbose=True,
    save_top_k=3,
    # The metric key contains '/', so automatic "name=" insertion is disabled;
    # the filename template below spells the names out explicitly.
    auto_insert_metric_name=False
)
max_auc_ckpt_cb = ModelCheckpoint(
    # The placeholder inside the template must use the same key as VAL_METRIC.
    filename='epoch={epoch}_val_auroc={auroc_avg/val:.3f}_top',
    monitor=VAL_METRIC,
    mode='max',
    **ckpt_params
)
callbacks.append(max_auc_ckpt_cb)

# Stop when the validation AUROC stops improving (higher is better).
es_cb = EarlyStopping(
    monitor=VAL_METRIC,
    min_delta=es_cfg.min_delta,
    patience=es_cfg.patience,
    verbose=True,
    mode="max"
)
if es_cfg.enabled:
    callbacks.append(es_cb)

# Note: LrDecay callback should be executed before LrWarmup, which is why it
# is appended to `callbacks` first.
lr_decay_cb = LrDecay(
    rate=lr_decay_cfg.rate,
    interval=lr_decay_cfg.interval,
    initial_lr=cfg.hparams.lr_initial
)
if lr_decay_cfg.enabled:
    callbacks.append(lr_decay_cb)

# Note: LrExponential callback should be executed before LrWarmup. It is told
# the warmup length only when warmup is actually enabled.
lr_exponential_cb = LrExponential(
    gamma=lr_exponential_cfg.gamma,
    warmup_steps=lr_warmup_cfg.warmup_steps if lr_warmup_cfg.enabled else None,
    phases=cfg.hparams.phases,
    initial_lr=cfg.hparams.lr_initial
)
if lr_exponential_cfg.enabled:
    callbacks.append(lr_exponential_cb)

# Appended last so that it runs after the decay/exponential schedulers above.
lr_warmup_cb = LrWarmup(
    warmup_steps=lr_warmup_cfg.warmup_steps,
    phases=cfg.hparams.phases,
    initial_lr=cfg.hparams.lr_initial
)
if lr_warmup_cfg.enabled:
    callbacks.append(lr_warmup_cb)

#
# Instantiate modules
#
dm = NIHDataModule(
    dataset_path=cfg.data.dataset_path,
    df_prefix=cfg.data.df_prefix,
    phases=cfg.hparams.phases,
    num_workers=cfg.cluster.cpus_per_node,
    merge_train_val=cfg.data.merge_train_val
)

# Only EfficientNetV2 is wired up; fail with an informative message on any
# other architecture name instead of an empty ValueError.
if cfg.hparams.architecture == 'eff_net_v2':
    model = EfficientNetV2Module(num_classes=NIHDataModule.NUM_CLASSES,
                                 class_freq=dm.get_train_class_freq(),
                                 hparams=cfg.hparams)
else:
    raise ValueError(
        f"Unsupported architecture {cfg.hparams.architecture!r}; "
        "expected 'eff_net_v2'"
    )

trainer = Trainer(
    max_epochs=cfg.hparams.epochs,
    logger=ClearMLLogger(task),
    deterministic=True,
    num_sanity_val_steps=0,
    callbacks=callbacks,
    default_root_dir=checkpoint_dir,
    # NOTE(review): `resume_from_checkpoint` is deprecated in newer Lightning
    # releases in favour of `trainer.fit(..., ckpt_path=...)` — confirm
    # against the pinned pytorch_lightning version.
    resume_from_checkpoint=checkpoint_file,
    reload_dataloaders_every_n_epochs=1,
    # NOTE(review): caps every training epoch at 100 batches — looks like an
    # example/debug limit; confirm before a full training run.
    limit_train_batches=100,
    log_every_n_steps=25
)

In [None]:
# Train the model; data loaders come from the NIH data module.
trainer.fit(model=model, datamodule=dm)

In [None]:
# Evaluate on the test split. The model object is passed explicitly, so this
# scores its current in-memory weights (those from the end of fit).
# NOTE(review): to score the best checkpoint kept by ModelCheckpoint instead,
# consider `ckpt_path='best'` — confirm which one the report should use.
trainer.test(model=model, datamodule=dm)

In [None]:
# Push any buffered ClearML reporting for this task.
# NOTE(review): presumably this does not block on pending artifact uploads
# (output_uri=True) — confirm whether waiting for uploads is needed here.
task.flush()