In [6]:
import sys 
import pytorch_lightning as pl
import seml
import torch
from sacred import SETTINGS, Experiment

sys.path.insert(0,"../")
from paths import EXPERIMENT_FOLDER

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.models.base.vae import VAE, AE
from scCFM.models.base.geometric_vae import GeometricNBAE,GeometricNBVAE

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger 

In [7]:
# Single source of truth for this run: every section below is read by the
# set-up cell that builds the datamodule, model, callbacks, logger and trainer.
config = {
    # Run identity and reproducibility.
    "training": {
        "task_name": "geom_vae_lib",
        "seed": 42,
    },
    # Keyword arguments forwarded to scDataModule.
    "datamodule": {
        # NOTE(review): absolute, user-specific path — consider deriving it from
        # a configurable data directory.
        "path": "/nfs/homedirs/pala/scCFM/project_dir/data/eb/processed/eb_phate.h5ad",
        "x_layer": "X_norm",
        "cond_keys": "experimental_time",
        "use_pca": False,
        "n_dimensions": None,
        "train_val_test_split": [0.8, 0.2],
        "num_workers": 2,
        "batch_size": 32,
    },
    # Settings collected into `vae_kwargs` for the underlying (V)AE.
    "model": {
        "n_epochs_anneal_kl": 1000,
        "likelihood": "nb",
        "dropout": False,
        "learning_rate": 0.001,
        # NOTE(review): the name suggests a probability, but the value is a bool — confirm.
        "dropout_p": False,
        "model_library_size": True,
        "batch_norm": True,
        "library_size_regression": False,
        "data_library_size": True,
        # NOTE(review): a "fraction" of 2 looks unusual — confirm how the model uses it.
        "kl_warmup_fraction": 2,
        "kl_weight": None,
        "model_type": "geometric_vae",
        "hidden_dims": [256, 10],
    },
    # Settings passed directly to GeometricNBAE / GeometricNBVAE.
    "geometric_vae": {
        "compute_metrics_every": 1,
        "use_c": False,
        "l2": True,
        "eta_interp": 0,
        "interpolate_z": False,
        "start_jac_after": 0,
        "detach_theta": False,
        "fl_weight": 0.01,
        "anneal_fl_weight": False,
        "max_fl_weight": None,
        "n_epochs_anneal_fl": None,
        "fl_anneal_fraction": None,
    },
    # ModelCheckpoint callback settings.
    "model_checkpoint": {
        "filename": "epoch_{epoch:04d}",
        "monitor": "val/lik",
        "mode": "min",
        "save_last": True,
        "auto_insert_metric_name": False,
    },
    # EarlyStopping callback settings; only used when perform_early_stopping is True.
    "early_stopping": {
        "perform_early_stopping": False,
        "monitor": "val/loss",
        "patience": 50,
        "mode": "min",
        "min_delta": 0.0,
        "verbose": False,
        "strict": True,
        "check_finite": True,
        "stopping_threshold": None,
        "divergence_threshold": None,
        "check_on_train_epoch_end": None,
    },
    # WandbLogger settings.
    "logger": {
        "offline": False,
        "id": None,
        "project": "geom_vae_lib_eb_reg",
        "log_model": True,
        "prefix": "",
        "group": "",
        "tags": [],
        "job_type": "",
    },
    # Lightning Trainer settings.
    "trainer": {
        "max_epochs": 1000,
        "accelerator": "gpu",
        "devices": 1,
        "log_every_n_steps": 10,
    },
}


In [8]:
# Training configuration
# Task name; it also names the sub-folder created under EXPERIMENT_FOLDER below.
task_name = config["training"]["task_name"]

# Fix seed for reproducibility.
# `pl.seed_everything` seeds torch together with numpy and Python's `random`,
# so no separate `torch.manual_seed` call is needed. The guard is
# `is not None` rather than a truthiness test so that seed 0 is honoured, and
# a missing seed (None) leaves the run unseeded instead of raising.
seed = config["training"]["seed"]
if seed is not None:
    pl.seed_everything(seed, workers=True)

# Initialize folder (idempotent: parents are created, an existing folder is fine).
current_experiment_dir = EXPERIMENT_FOLDER / task_name
current_experiment_dir.mkdir(parents=True, exist_ok=True)

# Datamodule initialization
# The listed config entries are forwarded one-to-one as keyword arguments.
_datamodule_cfg = config["datamodule"]
datamodule = scDataModule(
    **{
        key: _datamodule_cfg[key]
        for key in (
            "path",
            "x_layer",
            "cond_keys",
            "use_pca",
            "n_dimensions",
            "train_val_test_split",
            "batch_size",
            "num_workers",
        )
    }
)

# Model initialization
# Pull every model setting out of the config into a module-level name.
_model_cfg = config["model"]
model_type = _model_cfg["model_type"]
hidden_dims = _model_cfg["hidden_dims"]
batch_norm = _model_cfg["batch_norm"]
dropout = _model_cfg["dropout"]
dropout_p = _model_cfg["dropout_p"]
n_epochs_anneal_kl = _model_cfg["n_epochs_anneal_kl"]
kl_warmup_fraction = _model_cfg["kl_warmup_fraction"]
kl_weight = _model_cfg["kl_weight"]
likelihood = _model_cfg["likelihood"]
learning_rate = _model_cfg["learning_rate"]
model_library_size = _model_cfg["model_library_size"]
library_size_regression = _model_cfg["library_size_regression"]
data_library_size = _model_cfg["data_library_size"]

# Keyword arguments later handed to the geometric model as `vae_kwargs`;
# the input dimension comes from the datamodule rather than the config.
vae_kwargs = {
    "in_dim": datamodule.in_dim,
    "hidden_dims": hidden_dims,
    "batch_norm": batch_norm,
    "dropout": dropout,
    "dropout_p": dropout_p,
    "likelihood": likelihood,
    "learning_rate": learning_rate,
    "model_library_size": model_library_size,
    "library_size_regression": library_size_regression,
    "data_library_size": data_library_size,
}

# The KL-related settings are only added for the "geometric_vae" model type.
if model_type == "geometric_vae":
    vae_kwargs.update(
        n_epochs_anneal_kl=n_epochs_anneal_kl,
        kl_warmup_fraction=kl_warmup_fraction,
        kl_weight=kl_weight,
    )

# Geometric VAE initialization
geometric_vae = None

# Keyword arguments that both geometric model classes receive.
_geom_cfg = config["geometric_vae"]
_shared_geom_kwargs = dict(
    l2=_geom_cfg["l2"],
    fl_weight=_geom_cfg["fl_weight"],
    interpolate_z=_geom_cfg["interpolate_z"],
    eta_interp=_geom_cfg["eta_interp"],
    start_jac_after=_geom_cfg["start_jac_after"],
    use_c=_geom_cfg["use_c"],
    compute_metrics_every=_geom_cfg["compute_metrics_every"],
    vae_kwargs=vae_kwargs,
    detach_theta=_geom_cfg["detach_theta"],
)

if model_type == "geometric_ae":
    model = GeometricNBAE(**_shared_geom_kwargs)
else:
    # Every model type other than "geometric_ae" ends up here; this class
    # additionally receives the fl-weight annealing settings.
    # NOTE(review): the KL entries in `vae_kwargs` are only added when
    # model_type == "geometric_vae", so any other value builds this model
    # without them — confirm that is intended.
    model = GeometricNBVAE(
        **_shared_geom_kwargs,
        anneal_fl_weight=_geom_cfg["anneal_fl_weight"],
        max_fl_weight=_geom_cfg["max_fl_weight"],
        n_epochs_anneal_fl=_geom_cfg["n_epochs_anneal_fl"],
        fl_anneal_fraction=_geom_cfg["fl_anneal_fraction"],
    )

# Model checkpoint initialization
# Checkpoints are written to a "checkpoints" folder inside the experiment
# directory; the remaining options are taken verbatim from the config.
_ckpt_cfg = config["model_checkpoint"]
_ckpt_option_names = (
    "filename",
    "monitor",
    "mode",
    "save_last",
    "auto_insert_metric_name",
)
model_ckpt_callbacks = ModelCheckpoint(
    dirpath=current_experiment_dir / "checkpoints",
    **{name: _ckpt_cfg[name] for name in _ckpt_option_names},
)

# Early stopping initialization
# Stays None unless early stopping is switched on in the config; the trainer
# set-up only adds the callback when this value is truthy.
early_stopping_callbacks = None

_early_stopping_cfg = config["early_stopping"]
if _early_stopping_cfg["perform_early_stopping"]:
    early_stopping_callbacks = EarlyStopping(
        # `monitor` and `mode` are read from the config like every other
        # option; the previous free variables `monitor_early_stopping` and
        # `mode_early_stopping` were never defined and raised NameError as
        # soon as early stopping was enabled.
        monitor=_early_stopping_cfg["monitor"],
        patience=_early_stopping_cfg["patience"],
        mode=_early_stopping_cfg["mode"],
        min_delta=_early_stopping_cfg["min_delta"],
        verbose=_early_stopping_cfg["verbose"],
        strict=_early_stopping_cfg["strict"],
        check_finite=_early_stopping_cfg["check_finite"],
        stopping_threshold=_early_stopping_cfg["stopping_threshold"],
        divergence_threshold=_early_stopping_cfg["divergence_threshold"],
        check_on_train_epoch_end=_early_stopping_cfg["check_on_train_epoch_end"],
    )

# Logger initialization
# Weights & Biases logger saving into the experiment directory; all other
# options come straight from the "logger" section of the config.
_logger_cfg = config["logger"]
logger = WandbLogger(
    save_dir=current_experiment_dir,
    offline=_logger_cfg["offline"],
    id=_logger_cfg["id"],
    project=_logger_cfg["project"],
    log_model=_logger_cfg["log_model"],
    prefix=_logger_cfg["prefix"],
    group=_logger_cfg["group"],
    tags=_logger_cfg["tags"],
    job_type=_logger_cfg["job_type"],
)

# Trainer initialization
# The checkpoint callback is always registered; the early-stopping callback
# is appended only when one was created.
callbacks = [model_ckpt_callbacks]
if early_stopping_callbacks:
    callbacks.append(early_stopping_callbacks)

_trainer_cfg = config["trainer"]
_trainer_option_names = ("max_epochs", "accelerator", "devices", "log_every_n_steps")
trainer = Trainer(
    callbacks=callbacks,
    default_root_dir=current_experiment_dir,
    logger=logger,
    **{name: _trainer_cfg[name] for name in _trainer_option_names},
)

[rank: 0] Global seed set to 42
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mallepalma[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# Train the model using the datamodule's training loader, validating on its
# validation loader.
trainer.fit(
    model=model,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader(),
)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(

  | Name                 | Type   | Params
------------------------------------------------
0 | encoder_layers       | MLP    | 318 K 
1 | decoder_layers       | MLP    | 3.3 K 
2 | library_size_decoder | Linear | 11    
3 | decoder_mu_lib       | Linear | 318 K 
4 | mu_logvar            | Linear | 5.1 K 
------------------------------------------------
647 K     Trainable params
0         Non-trainable params
647 K     Total params
2.588     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [10]:
# Display the model's `theta` parameter after exponentiation.
# NOTE(review): presumably theta is stored on the log scale (e.g. a
# negative-binomial dispersion) — confirm against the model definition.
torch.exp(model.theta)

tensor([0.7675, 0.1999, 0.4274,  ..., 0.4748, 1.2398, 0.8061], device='cuda:0',
       grad_fn=<ExpBackward0>)

In [None]:
# Scratch inspection of an attribute on the trained model.
# NOTE(review): this cell was never executed (`In [None]`) and nothing in this
# notebook sets `fl_weight_decrease` — confirm the attribute exists or drop the cell.
model.fl_weight_decrease

In [None]:
# Scratch check: difference between the model's `fl_weight` and `min_fl_weight`.
# NOTE(review): this cell was never executed (`In [None]`); `min_fl_weight` is
# not part of the config above — confirm the attribute exists or drop the cell.
model.fl_weight - model.min_fl_weight

In [11]:
# Scratch arithmetic left over from exploration (evaluates to 100.0).
# NOTE(review): presumably a fraction of the 1000 configured epochs — confirm
# the intent or remove the cell.
0.1*1000

100.0