In [None]:
import os
import seaborn as sns
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning_fabric.utilities.seed import seed_everything

In [None]:
%load_ext autoreload

In [None]:
%autoreload
from cellnet.estimators import EstimatorCellTypeClassifier

# Init model

In [None]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Checkpoints and TensorBoard logs share one directory: <tb_logs root>/<MODEL>.
MODEL = 'cellnet'
_TB_LOGS_ROOT = '/lustre/scratch/users/felix.fischer/tb_logs'
CHECKPOINT_PATH = os.path.join(_TB_LOGS_ROOT, MODEL)
LOGS_PATH = os.path.join(_TB_LOGS_ROOT, MODEL)
DATA_PATH = '/lustre/scratch/users/felix.fischer/merlin_cxg_norm_parquet'

# ---------------------------------------------------------------------------
# Estimator + data
# ---------------------------------------------------------------------------
estim = EstimatorCellTypeClassifier(DATA_PATH)
# NOTE(review): the seed is fixed only after the estimator has been constructed;
# confirm the constructor does not draw random numbers if full reproducibility matters.
seed_everything(1)
estim.init_datamodule(batch_size=4096)

# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
tb_logger = TensorBoardLogger(LOGS_PATH, name='default')

# One checkpoint callback per tracked metric: (filename template, monitored key, mode).
# Each callback is evaluated every epoch and keeps its two best checkpoints.
checkpoint_specs = [
    ('train_loss_{epoch}_{train_loss:.3f}', 'train_loss_epoch', 'min'),
    ('val_f1_macro_{epoch}_{val_f1_macro:.3f}', 'val_f1_macro', 'max'),
    ('val_loss_{epoch}_{val_loss:.3f}', 'val_loss', 'min'),
]
trainer_callbacks = [
    TQDMProgressBar(refresh_rate=100),
    LearningRateMonitor(logging_interval='step'),
]
trainer_callbacks += [
    ModelCheckpoint(filename=template, monitor=metric_key, mode=mode, every_n_epochs=1, save_top_k=2)
    for template, metric_key, mode in checkpoint_specs
]

trainer_kwargs = dict(
    max_epochs=1000,
    # clip gradients by norm at 1.0
    gradient_clip_val=1.,
    gradient_clip_algorithm='norm',
    default_root_dir=CHECKPOINT_PATH,
    # single-GPU run
    accelerator='gpu',
    devices=1,
    # no sanity-validation steps before training; validate after every epoch
    num_sanity_val_steps=0,
    check_val_every_n_epoch=1,
    logger=[tb_logger],
    log_every_n_steps=100,
    detect_anomaly=False,
    enable_progress_bar=True,
    # the model summary is printed explicitly at the end of this cell instead
    enable_model_summary=False,
    enable_checkpointing=True,
    callbacks=trainer_callbacks,
)
estim.init_trainer(trainer_kwargs=trainer_kwargs)

# ---------------------------------------------------------------------------
# Model (TabNet)
# ---------------------------------------------------------------------------
# StepLR settings: multiply the learning rate by 0.9 every 2 scheduler steps.
scheduler_kwargs = dict(step_size=2, gamma=0.9, verbose=True)
model_kwargs = dict(
    # optimisation
    learning_rate=0.005,
    weight_decay=0.1,
    lr_scheduler=torch.optim.lr_scheduler.StepLR,
    lr_scheduler_kwargs=scheduler_kwargs,
    optimizer=torch.optim.AdamW,
    # TabNet hyper-parameters
    lambda_sparse=1e-6,
    n_d=512,
    n_a=128,
    n_steps=5,
    gamma=1.3,
    n_independent=4,
    n_shared=4,
    virtual_batch_size=256,
    mask_type='entmax',
)
estim.init_model(model_type='tabnet', model_kwargs=model_kwargs)

print(ModelSummary(estim.model))


# Find learning rate

In [None]:
# Settings for the learning-rate sweep, handed to the estimator's find_lr().
# NOTE(review): presumably forwarded unchanged to Lightning's LR finder - confirm in cellnet.
lr_sweep_settings = {
    'early_stop_threshold': 10.,
    'min_lr': 1e-8,
    'max_lr': 10.,
    'num_training': 120,
}
lr_find_res = estim.find_lr(lr_find_kwargs=lr_sweep_settings)

In [None]:
# lr_find_res is indexed as: [0] -> suggested learning rate,
# [1] -> sweep results holding the 'lr' and 'loss' sequences.
suggested_lr = lr_find_res[0]
lr_sweep = lr_find_res[1]

# Loss as a function of the learning rate tried during the sweep.
ax = sns.lineplot(x=lr_sweep['lr'], y=lr_sweep['loss'])
ax.set_xscale('log')  # the sweep spans several orders of magnitude
# Hand-picked zoom window for the loss / learning-rate axes.
ax.set_ylim(12., top=20.)
ax.set_xlim(1e-6, 10.)
# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set_xlabel('learning rate')
ax.set_ylabel('loss')
ax.set_title('Learning rate finder')
print(f'Suggested learning rate: {suggested_lr}')

# Fit model

In [None]:
# Start training on the estimator whose datamodule, trainer and model were
# initialised in the "Init model" cell above.
estim.train()