In [None]:
##################
# Library Imports
##################
import gc
import os
import sys
sys.path.append('..')  # Add parent directory to Python path
import time
from datetime import datetime
import yaml
from tqdm import tqdm_notebook as tqdm
import random
import numpy as np
import pandas as pd

import torch
torch.set_float32_matmul_precision('medium')
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

from data.dataset import CustomDataset, CustomDataModule
from model.danets_baseline import CustomDANETs
from model.lit import LightningModel

from utils.config_loader import load_config
# Experiment configuration. The cells below read the keys 'seed', 'n_folds',
# 'meta_feats', 'model', 'patience' and 'max_epochs' from it.
CFG = load_config('../config/experiment/danets_baseline.yaml')

def seed_everything(seed=CFG['seed']):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED``, and puts cuDNN into deterministic
    mode.

    Args:
        seed: Integer seed. Defaults to ``CFG['seed']`` as evaluated when this
            function is defined.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible CUDA device, not only the current one
    # (silently ignored when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick potentially different
    # (non-deterministic) algorithms per run, which defeats the
    # `deterministic` flag above, so it has to be off for reproducibility.
    torch.backends.cudnn.benchmark = False

seed_everything()

### Training: cross-validation with out-of-fold (OOF) prediction results

The following code runs cross-validation using the fold numbers pre-defined in the training data.
`CustomDataModule` automatically splits the data into training and validation sets for each fold before training starts.
Early stopping is enabled and monitors the custom loss score.

DANETs with a custom loss function has an advantage over gradient boosting models such as LGBM and XGBoost, where loss function customization is limited, and DANETs also outperformed them, as shown in the research.

In [None]:
# Get the current local date and time
current_date_time = datetime.now()

# Format the date and time as a string; used to give this run's log folders a unique name
formatted_date_time = current_date_time.strftime("%Y-%m-%d_%H-%M-%S")

# Placeholder: replace with the training DataFrame before running. The loop below
# reads a 'fold' column from it; the literal string will fail on the next line.
train_df = "YOUR_TRAIN_DATA"
train_df["oof_pred"] = 0.0

# Checkpointing and early stopping must optimise the monitored metric in the SAME
# direction. Previously the checkpoint used mode='min' while early stopping used
# mode='max' on the same metric, so one of the two was working against the other.
# NOTE(review): 'min' is chosen because the metric is described as a loss score and
# the best checkpoint (which feeds the OOF predictions) already used 'min' -- confirm
# that lower `valid_score` is better.
MONITOR_METRIC = "valid_score"
MONITOR_MODE = "min"

for fold in tqdm(range(CFG['n_folds']), desc="Folds", total=CFG['n_folds']):

    # Data for this fold (train/validation split is handled by the data module)
    data_module = CustomDataModule(train_df, fold, CFG['meta_feats'])

    # Fresh model for every fold so that no weights leak between folds
    pytorch_model = CustomDANETs(CFG['meta_feats'])
    lightning_model = LightningModel(model=pytorch_model)

    callbacks = [
        # Keep only the single best checkpoint of this fold
        ModelCheckpoint(dirpath="path/checkpoints/",
                        filename=f"{CFG['model']}_fold{fold}" + "-{epoch:02d}-{valid_score:.5f}",
                        save_top_k=1, mode=MONITOR_MODE, monitor=MONITOR_METRIC, verbose=True),
        EarlyStopping(monitor=MONITOR_METRIC, patience=CFG['patience'], mode=MONITOR_MODE),
        LearningRateMonitor(logging_interval='epoch'),
        # SWA with a custom averaging rule: 0.3 * running average + 0.7 * current weights
        StochasticWeightAveraging(swa_lrs=1e-3, swa_epoch_start=0.5, annealing_epochs=5, annealing_strategy='linear',
                                  avg_fn=lambda averaged_model_parameter, model_parameter, num_averaged: \
                                    0.3 * averaged_model_parameter + 0.7 * model_parameter)
        ]
    # Both loggers write to the same per-fold run directory name under logs/
    csvlogger = CSVLogger(save_dir="logs/", name=f"{CFG['model']}_{formatted_date_time}_fold{fold}", version=0)
    tblogger = TensorBoardLogger(save_dir="logs/", name=f"{CFG['model']}_{formatted_date_time}_fold{fold}", version=0)

    trainer = pl.Trainer(
        max_epochs=CFG['max_epochs'],
        callbacks=callbacks,
        precision="16-mixed",
        accelerator="auto",  # Uses GPUs or TPUs if available
        devices=1,  # Train on a single device
        logger=[csvlogger, tblogger],
        log_every_n_steps=100,
        )

    start_time = time.time()
    trainer.fit(model=lightning_model, datamodule=data_module)

    runtime = (time.time() - start_time) / 60
    print(f"Training took {runtime:.2f} min in total.")

    # Restore the best checkpoint of this fold before producing out-of-fold predictions
    lightning_model.load_state_dict(torch.load(trainer.checkpoint_callback.best_model_path)['state_dict'])
    preds = trainer.predict(lightning_model, data_module.val_dataloader())
    # Assumes the validation dataloader yields rows in the same order as the rows of
    # train_df with this fold number (i.e. no shuffling) -- TODO confirm.
    train_df.loc[train_df['fold'] == fold, 'oof_pred'] = np.concatenate(preds, axis=0)

    # Release per-fold objects and cached GPU memory before the next fold
    del trainer, lightning_model, pytorch_model, data_module; gc.collect()
    torch.cuda.empty_cache()


### Tensorboard Experiment Tracking

Each fold's metrics and loss behavior are logged and can be visualized by executing the next command in VS Code.

In [None]:

#%tensorboard --logdir logs/