In [1]:
# Editable (development) install of the local cellnet checkout.
# `%pip` installs into the environment of the running kernel; `!pip` would use whatever
# `pip` happens to be first on PATH, which may belong to a different interpreter.
# `--no-deps` skips dependency resolution, so cellnet's requirements must already be installed.
# Restart the kernel after running this cell so the newly installed package is importable.
# NOTE(review): user-specific absolute path — adjust for your own checkout.
%pip install -e /dss/dsshome1/04/di93zer/git/cellnet --no-deps

Obtaining file:///dss/dsshome1/04/di93zer/git/cellnet
  Preparing metadata (setup.py) ... done
Installing collected packages: cellnet
  Running setup.py develop for cellnet
Successfully installed cellnet

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [1]:
import os
import seaborn as sns
import torch

from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.model_summary import ModelSummary
from lightning.pytorch import seed_everything

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sanity check: the trainer below is configured with accelerator='gpu', so this should be True.
torch.cuda.is_available()

True

In [3]:
# Trade a little float32 matmul precision for speed: per the PyTorch docs, 'high' lets
# matmuls use faster reduced-precision internals (e.g. TF32) on hardware that supports it.
torch.set_float32_matmul_precision('high')

In [4]:
# Load IPython's autoreload extension so edits to the editable cellnet install can be
# picked up without restarting the kernel (configured in the next cell).
%load_ext autoreload

In [5]:
# Mode 2: reload all modules (except those excluded via %aimport) before every cell
# execution, so edits to the editable cellnet checkout take effect automatically.
# A bare `%autoreload` only triggers a single reload right now.
%autoreload 2
from cellnet.estimators import EstimatorCellTypeClassifier

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")


# Initialize data module, trainer and model

In [6]:
# Config parameters
MODEL = 'cxg_2023_05_15_tabnet'
TB_LOGS_ROOT = '/mnt/dssfs02/tb_logs'
# Checkpoints and TensorBoard logs currently go to the same per-model directory.
CHECKPOINT_PATH = os.path.join(TB_LOGS_ROOT, MODEL)
LOGS_PATH = os.path.join(TB_LOGS_ROOT, MODEL)
DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'
SEED = 1
BATCH_SIZE = 2048
MAX_EPOCHS = 250

# Seed before anything is constructed, so any randomness in the estimator
# constructor is covered as well as the data module / trainer / model setup.
seed_everything(SEED)

estim = EstimatorCellTypeClassifier(DATA_PATH)
estim.init_datamodule(batch_size=BATCH_SIZE)
estim.init_trainer(
    trainer_kwargs={
        'max_epochs': MAX_EPOCHS,
        # Clip gradients by norm, with a threshold of 1.
        'gradient_clip_val': 1.,
        'gradient_clip_algorithm': 'norm',
        'default_root_dir': CHECKPOINT_PATH,
        'accelerator': 'gpu',
        'devices': 1,
        'num_sanity_val_steps': 0,
        'check_val_every_n_epoch': 1,
        'logger': [TensorBoardLogger(LOGS_PATH, name='default')],
        'log_every_n_steps': 100,
        'detect_anomaly': False,
        'enable_progress_bar': True,
        # The model summary is printed explicitly at the end of this cell instead.
        'enable_model_summary': False,
        'enable_checkpointing': True,
        'callbacks': [
            TQDMProgressBar(refresh_rate=50),
            LearningRateMonitor(logging_interval='step'),
            # Keep the two best checkpoints by macro-F1 and the two best by validation loss.
            ModelCheckpoint(filename='val_f1_macro_{epoch}_{val_f1_macro:.3f}', monitor='val_f1_macro', mode='max',
                            every_n_epochs=1, save_top_k=2),
            ModelCheckpoint(filename='val_loss_{epoch}_{val_loss:.3f}', monitor='val_loss', mode='min',
                            every_n_epochs=1, save_top_k=2),
        ],
    }
)
estim.init_model(
    model_type='tabnet',
    model_kwargs={
        'learning_rate': 0.005,
        'weight_decay': 0.05,
        'lr_scheduler': torch.optim.lr_scheduler.StepLR,
        # Multiply the LR by `gamma` on every scheduler.step() call (step_size=1).
        # StepLR's `verbose` flag is intentionally not passed: it is deprecated since
        # PyTorch 2.2, and LearningRateMonitor above already logs the learning rate.
        'lr_scheduler_kwargs': {
            'step_size': 1,
            'gamma': 0.9,
        },
        'optimizer': torch.optim.AdamW,
        # TabNet hyper-parameters. Note: the `gamma` below is a TabNet parameter,
        # unrelated to the LR scheduler's `gamma` above.
        'lambda_sparse': 1e-5,
        'n_d': 128,
        'n_a': 64,
        'n_steps': 3,
        'gamma': 1.3,
        'n_independent': 5,
        'n_shared': 3,
        'virtual_batch_size': 256,
        'mask_type': 'entmax',
        'augment_training_data': True,
    },
)
print(ModelSummary(estim.model))


[rank: 0] Global seed set to 1
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


  | Name          | Type             | Params
---------------------------------------------------
0 | train_metrics | MetricCollection | 0     
1 | val_metrics   | MetricCollection | 0     
2 | test_metrics  | MetricCollection | 0     
3 | classifier    | TabNet           | 13.0 M
---------------------------------------------------
13.0 M    Trainable params
0         Non-trainable params
13.0 M    Total params
51.824    Total estimated model params size (MB)


# Find learning rate

Run a learning-rate range test and plot loss against learning rate to choose a starting value.

In [None]:
# Learning-rate range test: sweep the LR from 1e-8 to 10 over 100 training steps.
# early_stop_threshold=10 aborts the sweep once the loss grows far beyond the best seen
# (presumably forwarded to Lightning's Tuner.lr_find — confirm in cellnet).
lr_find_res = estim.find_lr(
    lr_find_kwargs=dict(
        early_stop_threshold=10.,
        min_lr=1e-8,
        max_lr=10.,
        num_training=100,
    )
)

In [None]:
# Element 0 of the find_lr result is the suggested LR; element 1 holds the sweep's 'lr' and 'loss' series.
lr_suggestion = lr_find_res[0]
lr_sweep = lr_find_res[1]

ax = sns.lineplot(x=lr_sweep['lr'], y=lr_sweep['loss'])
ax.set_xscale('log')
# Zoomed view: the sweep starts at min_lr=1e-8, but the 1e-8..1e-6 range is cut off here,
# and losses outside [6, 9] are hidden — widen these limits if the curve disappears.
ax.set_xlim(1e-6, 10.)
ax.set_ylim(6., 9.)
ax.set(xlabel='learning rate', ylabel='loss', title='Learning-rate range test')
if lr_suggestion is not None:
    # Mark the suggested learning rate on the curve (the suggestion may be None if none was found).
    ax.axvline(lr_suggestion, color='red', linestyle='--', label=f'suggested: {lr_suggestion:.3g}')
    ax.legend()
print(f'Suggested learning rate: {lr_suggestion}')

# Fit model

In [None]:
# Train with the trainer and model set up via estim.init_trainer / estim.init_model above
# (max_epochs=250 is configured there, so expect a long-running cell).
estim.train()