### Imports

In [1]:
# Load IPython's autoreload extension; the `%autoreload` magic is used in a later cell
# so edits to the local celldreamer package are picked up without restarting the kernel.
%load_ext autoreload

In [5]:
import os
import seaborn as sns
import torch
import pandas as pd
from os.path import join
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning_fabric.utilities.seed import seed_everything

In [6]:
%autoreload
from celldreamer.estimator.get_estim import CellDreamerEstimator
from celldreamer.models.base.variance_scheduler.cosine import CosineScheduler

ImportError: cannot import name 'CellDreamerEstimator' from 'celldreamer.estimator.get_estim' (/home/icb/alessandro.palma/environment/celldreamer/celldreamer/estimator/get_estim.py)

In [4]:
# Config parameters: names and paths a reader might want to tune.
MODEL = 'ddpm_cellnet'
# Project root = parent of the current working directory
# (assumes the notebook is launched from a subdirectory of the repo — TODO confirm).
root = os.path.dirname(os.path.abspath(os.getcwd()))
# Checkpoints and TensorBoard logs are written to the same directory.
CHECKPOINT_PATH = os.path.join(root, 'trained_models/tb_logs', MODEL)
LOGS_PATH = CHECKPOINT_PATH
# CellNet parquet store. The default is an absolute cluster path; set the
# CELLNET_DATA_PATH environment variable to run elsewhere.
DATA_PATH = os.environ.get(
    'CELLNET_DATA_PATH',
    '/lustre/scratch/users/felix.fischer/merlin_cxg_norm_parquet',
)

### Explore a diffusion model (DDPM) with an MLP denoiser on CellNet

In [None]:
# Seed all RNGs (via Lightning's seed_everything) *before* building the estimator,
# so any randomness during construction, datamodule setup and model init is reproducible.
SEED = 1
seed_everything(SEED)
estim = EstimatorCellDreamer(DATA_PATH)

In [None]:
# Build the data module over DATA_PATH with a batch size of 4_096
# (presumably cells per batch — confirm in celldreamer).
estim.init_datamodule(batch_size=4_096)

In [None]:
# Assemble the Lightning Trainer configuration: a TensorBoard logger, a progress bar,
# per-step LR logging, and two "keep best" checkpoints (train loss and val loss).
tb_logger = TensorBoardLogger(LOGS_PATH, name='default')
best_ckpt_kwargs = dict(mode='min', every_n_epochs=1, save_top_k=1)
trainer_callbacks = [
    TQDMProgressBar(refresh_rate=100),
    LearningRateMonitor(logging_interval='step'),
    # Single best checkpoint by epoch-level training loss ...
    ModelCheckpoint(filename='best_train_loss', monitor='train_loss_epoch', **best_ckpt_kwargs),
    # ... and, independently, the single best by validation loss.
    ModelCheckpoint(filename='best_val_loss', monitor='val_loss', **best_ckpt_kwargs),
]

estim.init_trainer(
    trainer_kwargs=dict(
        max_epochs=1_000,
        gradient_clip_val=1.0,
        gradient_clip_algorithm='norm',
        default_root_dir=CHECKPOINT_PATH,
        accelerator='gpu',
        devices=1,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=1,
        logger=[tb_logger],
        log_every_n_steps=100,
        detect_anomaly=False,
        enable_progress_bar=True,
        enable_model_summary=False,
        enable_checkpointing=True,
        callbacks=trainer_callbacks,
    )
)

In [None]:
# Diffusion model with an MLP denoiser.
# The step count is defined once so the model's `T` and the variance scheduler's `T`
# cannot drift apart.
N_DIFFUSION_STEPS = 4_000

# Input height = number of rows in var.parquet (presumably one per gene);
# number of classes = number of rows in the cell-type lookup table.
# Both files are resolved relative to DATA_PATH rather than a duplicated absolute path.
n_vars = len(pd.read_parquet(join(DATA_PATH, 'var.parquet')))
n_cell_types = len(pd.read_parquet(join(DATA_PATH, 'categorical_lookup', 'cell_type.parquet')))

estim.init_model(
    generative_model='diffusion',
    denoiser_module='mlp',
    hidden_layers=3,
    model_kwargs={
        'T': N_DIFFUSION_STEPS,
        # NOTE(review): 'w' / 'p_uncond' look like classifier-free-guidance weight and
        # label-drop probability — confirm against the celldreamer model.
        'w': 0.3,
        'p_uncond': 0.2,
        'width': 1,
        'height': n_vars,
        'input_channels': 1,
        'num_classes': n_cell_types,
        'logging_freq': 1_000,
        'v': 0.2,
        'variance_scheduler': CosineScheduler(T=N_DIFFUSION_STEPS),
    }
)

In [None]:
# Display the model created by `init_model` above (presumably stored on `estim.model`);
# shown via the cell's rich repr.
estim.model

### Find learning rate

In [27]:
# Learning-rate range test between 1e-8 and 10 over 120 training steps; the kwargs are
# presumably forwarded to Lightning's LR finder (confirm in celldreamer's find_lr).
# NOTE(review): the last recorded run raised
#   TypeError: forward() takes 3 positional arguments but 4 were given
# from below this call — the model's forward signature needs fixing before this cell works.
# NOTE(review): an earlier draft passed a checkpoint path under the misspelled key
# 'cpkt_path'; if re-added, check the key name the finder actually accepts.
lr_find_res = estim.find_lr(
    lr_find_kwargs=dict(
        early_stop_threshold=10.0,
        min_lr=1e-8,
        max_lr=10.0,
        num_training=120,
    )
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


TypeError: forward() takes 3 positional arguments but 4 were given

In [None]:
# Plot loss against learning rate from the range test, on a log-scaled x-axis.
# Assumes lr_find_res[0] is the suggested LR and lr_find_res[1] holds 'lr' and 'loss'
# sequences (as indexed by the original cell).
lr_curve = lr_find_res[1]
ax = sns.lineplot(x=lr_curve['lr'], y=lr_curve['loss'])
ax.set(xscale='log', xlabel='learning rate', ylabel='loss', title='Learning-rate range test')
print(f'Suggested learning rate: {lr_find_res[0]}')

### Train

In [12]:
# Training takes a while, so it is not run here; in a script it is launched with:
# estim.train()