In [1]:
import sys 
import pytorch_lightning as pl
import seml
import torch
from sacred import SETTINGS, Experiment
from functools import partial

sys.path.insert(0,"../")
from paths import EXPERIMENT_FOLDER

from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.cfm.cfm_module import CFMLitModule
from scCFM.models.cfm.components.mlp import MLP

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  self.seed = seed
  self.dl_pin_memory_gpu_training = (


## Define configurations

In [2]:
# Experiment configuration, grouped into one sub-dict per component.
# The key names of each section mirror the parameter names of the
# corresponding Solver.init_* method so a section can be unpacked with **.
config = dict(
    # Parameters of Solver.init_general
    training=dict(
        task_name="1_OFFICIAL_cfm_eb_latent_vae",
        seed=42,
    ),
    # Parameters of Solver.init_datamodule
    datamodule=dict(
        path="/nfs/homedirs/pala/scCFM/project_dir/data/eb/flat/eb_lib.h5ad",
        x_layer="X_latents",
        time_key="experimental_time",
        use_pca=False,
        n_dimensions=None,
        train_val_test_split=[0.90, 0.1],
        num_workers=2,
        batch_size=256,
        model_library_size=True,
    ),
    # Parameters of Solver.init_net
    net=dict(
        w=64,
        time_varying=True,
    ),
    # Parameters of Solver.init_model
    model=dict(
        ot_sampler="exact",
        sigma=0.1,
        use_real_time=False,
        lr=0.001,
        antithetic_time_sampling=False,
        leaveout_timepoint=1,
    ),
    # Parameters of Solver.init_checkpoint_callback
    model_checkpoint=dict(
        filename="epoch_{epoch:04d}",
        monitor="train/loss",
        mode="min",
        save_last=True,
        auto_insert_metric_name=False,
    ),
    # Parameters of Solver.init_early_stopping_callback
    early_stopping=dict(
        perform_early_stopping=False,
        monitor="train/loss",
        patience=200,
        mode="min",
        min_delta=0.0,
        verbose=False,
        strict=True,
        check_finite=True,
        stopping_threshold=None,
        divergence_threshold=None,
        check_on_train_epoch_end=None,
    ),
    # Parameters of Solver.init_logger
    logger=dict(
        offline=True,
        id=None,
        project="1_OFFICIAL_cfm_eb_latent_vae",
        log_model=False,
        prefix="",
        group="",
        tags=[],
        job_type="",
    ),
    # Parameters of Solver.init_trainer
    trainer=dict(
        max_epochs=None,
        max_steps=20000,
        accelerator="gpu",
        devices=1,
        log_every_n_steps=50,
    ),
)

In [3]:
class torch_wrapper(torch.nn.Module):
    """Adapter exposing a model through a ``forward(t, x)`` signature (torchdyn style).

    The wrapped model receives a single tensor: ``x`` with the time value
    appended as one extra trailing column.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x):
        """Append ``t`` as the last column of ``x`` and evaluate the wrapped model.

        Args:
            t: time tensor; it is tiled ``x.shape[0]`` times, so a scalar or
                one-element tensor yields one time entry per row of ``x``.
            x: 2-D state tensor; its first dimension sets the number of rows.

        Returns:
            Whatever ``self.model`` returns for the concatenated input.
        """
        n_rows = x.shape[0]
        # One time value per row, shaped as a column so it can be concatenated.
        time_column = t.repeat(n_rows).unsqueeze(1)
        model_input = torch.cat((x, time_column), dim=1)
        return self.model(model_input)

## Define the solver and train the CFM model

In [4]:
# Training function 
class Solver:
    """Assemble and train a CFM Lightning model from keyword-argument config sections.

    Each ``init_*`` method builds one component and stores it on the instance.
    Later stages read attributes set by earlier ones, so the expected order is
    ``init_general`` -> ``init_datamodule`` -> ``init_net`` -> ``init_model`` ->
    (optionally ``init_checkpoint_callback`` / ``init_early_stopping_callback`` /
    ``init_logger``) -> ``init_trainer`` -> ``train``.
    """

    def __init__(self):
        # Optional components: they stay None unless the matching init_* method
        # is called before init_trainer, which only forwards the ones that are set.
        self.model_ckpt_callbacks = None
        self.early_stopping_callbacks = None
        self.logger = None

    def init_general(self,
                     task_name,
                     seed):
        """Record the task name, seed the RNGs and create the experiment folder.

        Args:
            task_name: name of the sub-folder created under ``EXPERIMENT_FOLDER``.
            seed: integer seed, or ``None`` to skip seeding entirely.
        """
        self.task_name = task_name

        # Fix seed for reproducibility. The check is `is not None` rather than
        # truthiness so that seed=0 is honoured, and nothing is seeded with None
        # (torch.manual_seed(None) raises TypeError). seed_everything also seeds
        # torch, so a separate torch.manual_seed call is not needed.
        if seed is not None:
            pl.seed_everything(seed, workers=True)

        # Initialize folder
        self.current_experiment_dir = EXPERIMENT_FOLDER / self.task_name
        self.current_experiment_dir.mkdir(parents=True, exist_ok=True)

    def init_datamodule(self,
                        path,
                        x_layer,
                        time_key,
                        use_pca,
                        n_dimensions,
                        train_val_test_split,
                        batch_size,
                        num_workers,
                        model_library_size):
        """Create ``self.datamodule``; all arguments are forwarded unchanged to
        ``TrajectoryDataModule``."""
        self.datamodule = TrajectoryDataModule(path=path,
                                               x_layer=x_layer,
                                               time_key=time_key,
                                               use_pca=use_pca,
                                               n_dimensions=n_dimensions,
                                               train_val_test_split=train_val_test_split,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               model_library_size=model_library_size)

    def init_net(self,
                 w,
                 time_varying):
        """Create ``self.net`` (an ``MLP``).

        The ``dim`` argument is taken from ``self.datamodule.dim``, so
        ``init_datamodule`` must have been called first.
        """
        net_hparams = {"dim": self.datamodule.dim,
                       "w": w,
                       "time_varying": time_varying}

        self.net = MLP(**net_hparams)

    def init_model(self,
                   ot_sampler,
                   sigma,
                   lr,
                   use_real_time,
                   antithetic_time_sampling,
                   leaveout_timepoint):
        """Create ``self.model`` (a ``CFMLitModule``) around ``self.net`` and
        ``self.datamodule``; the remaining arguments are forwarded unchanged."""
        self.model = CFMLitModule(
                            net=self.net,
                            datamodule=self.datamodule,
                            ot_sampler=ot_sampler,
                            sigma=sigma,
                            lr=lr,
                            use_real_time=use_real_time,
                            antithetic_time_sampling=antithetic_time_sampling,
                            leaveout_timepoint=leaveout_timepoint)

    def init_checkpoint_callback(self,
                                 filename,
                                 monitor,
                                 mode,
                                 save_last,
                                 auto_insert_metric_name):
        """Create ``self.model_ckpt_callbacks`` (a ``ModelCheckpoint``) writing to
        ``<experiment dir>/checkpoints``; call before ``init_trainer`` to use it."""
        self.model_ckpt_callbacks = ModelCheckpoint(dirpath=self.current_experiment_dir / "checkpoints",
                                                    filename=filename,
                                                    monitor=monitor,
                                                    mode=mode,
                                                    save_last=save_last,
                                                    auto_insert_metric_name=auto_insert_metric_name)

    def init_early_stopping_callback(self,
                                     perform_early_stopping,
                                     monitor,
                                     patience,
                                     mode,
                                     min_delta,
                                     verbose,
                                     strict,
                                     check_finite,
                                     stopping_threshold,
                                     divergence_threshold,
                                     check_on_train_epoch_end):
        """Create ``self.early_stopping_callbacks``.

        When ``perform_early_stopping`` is falsy the attribute is set to ``None``;
        otherwise the remaining arguments are forwarded unchanged to
        ``EarlyStopping``. Call before ``init_trainer`` to use it.
        """
        if perform_early_stopping:
            self.early_stopping_callbacks = EarlyStopping(monitor=monitor,
                                                          patience=patience,
                                                          mode=mode,
                                                          min_delta=min_delta,
                                                          verbose=verbose,
                                                          strict=strict,
                                                          check_finite=check_finite,
                                                          stopping_threshold=stopping_threshold,
                                                          divergence_threshold=divergence_threshold,
                                                          check_on_train_epoch_end=check_on_train_epoch_end
                                                          )
        else:
            self.early_stopping_callbacks = None

    def init_logger(self,
                    offline,
                    id,
                    project,
                    log_model,
                    prefix,
                    group,
                    tags,
                    job_type):
        """Create ``self.logger`` (a ``WandbLogger``) saving under the experiment
        directory; call before ``init_trainer`` to use it."""
        self.logger = WandbLogger(save_dir=self.current_experiment_dir,
                                  offline=offline,
                                  id=id,
                                  project=project,
                                  log_model=log_model,
                                  prefix=prefix,
                                  group=group,
                                  tags=tags,
                                  job_type=job_type)

    def init_trainer(self,
                     max_epochs,
                     max_steps,
                     accelerator,
                     devices,
                     log_every_n_steps):
        """Create ``self.trainer`` (a Lightning ``Trainer``).

        The checkpoint callback, early-stopping callback and logger are passed to
        the ``Trainer`` only if they were initialized beforehand; when none of
        them is set, the ``Trainer`` is built with exactly the five arguments
        below plus ``default_root_dir``.
        """
        trainer_kwargs = {"default_root_dir": self.current_experiment_dir,
                          "max_epochs": max_epochs,
                          "max_steps": max_steps,
                          "accelerator": accelerator,
                          "devices": devices,
                          "log_every_n_steps": log_every_n_steps}

        # Forward only the optional components that exist, so Lightning's own
        # defaults stay in effect for anything that was not configured.
        callbacks = [callback
                     for callback in (self.model_ckpt_callbacks, self.early_stopping_callbacks)
                     if callback is not None]
        if callbacks:
            trainer_kwargs["callbacks"] = callbacks
        if self.logger is not None:
            trainer_kwargs["logger"] = self.logger

        self.trainer = Trainer(**trainer_kwargs)

    def train(self):
        """Fit ``self.model`` on the datamodule's train/val dataloaders.

        Returns:
            ``self.trainer.callback_metrics`` as available after fitting.
        """
        self.trainer.fit(model=self.model,
                         train_dataloaders=self.datamodule.train_dataloader(),
                         val_dataloaders=self.datamodule.val_dataloader())

        train_metrics = self.trainer.callback_metrics
        return train_metrics

In [5]:
# Instantiate the solver; its components are attached by the init_* calls that follow.
solver = Solver()

In [6]:
# Build the pipeline stage by stage; each call unpacks one section of `config`.
# Order matters: init_net reads self.datamodule, init_model reads self.net and
# self.datamodule, and init_trainer reads the experiment directory set by init_general.
# NOTE(review): the "model_checkpoint", "early_stopping" and "logger" sections of
# `config` are not passed to any init_* call in this cell.
solver.init_general(**config["training"])
solver.init_datamodule(**config["datamodule"])
solver.init_net(**config["net"])
solver.init_model(**config["model"])
solver.init_trainer(**config["trainer"])

[rank: 0] Global seed set to 42
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Run training; train() returns the trainer's callback metrics, shown as the cell output.
solver.train()

  rank_zero_warn(
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params
----------------------------------------
0 | net       | MLP       | 9.9 K 
1 | node      | NeuralODE | 9.9 K 
2 | criterion | MSELoss   | 0     
----------------------------------------
9.9 K     Trainable params
0         Non-trainable params
9.9 K     Total params
0.039     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


0 2
2 3
3 4


  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4


Validation: 0it [00:00, ?it/s]

0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4


Validation: 0it [00:00, ?it/s]

0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4
0 2
2 3
3 4


Validation: 0it [00:00, ?it/s]

0 2
2 3
3 4


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


{'train/loss': tensor(0.4473, device='cuda:0'),
 'val/loss': tensor(0.4791, device='cuda:0'),
 'val/t1/1-Wasserstein': tensor(2.6931, device='cuda:0'),
 'val/t1/2-Wasserstein': tensor(2.7870, device='cuda:0'),
 'val/t1/Linear_MMD': tensor(0.0945, device='cuda:0'),
 'val/t1/Poly_MMD': tensor(0.3074, device='cuda:0'),
 'val/t1/RBF_MMD': tensor(0.2460, device='cuda:0'),
 'val/t1/Mean_MSE': tensor(0.1355, device='cuda:0'),
 'val/t1/Mean_L2': tensor(0.3682, device='cuda:0'),
 'val/t1/Mean_L1': tensor(0.2791, device='cuda:0'),
 'val/t1/Median_MSE': tensor(2.9048, device='cuda:0'),
 'val/t1/Median_L2': tensor(3.0316, device='cuda:0'),
 'val/t1/Median_L1': tensor(0.1730, device='cuda:0'),
 'val/t2/1-Wasserstein': tensor(0.4159, device='cuda:0'),
 'val/t2/2-Wasserstein': tensor(0.3283, device='cuda:0'),
 'val/t2/Linear_MMD': tensor(0.2271, device='cuda:0'),
 'val/t2/Poly_MMD': tensor(0.4766, device='cuda:0'),
 'val/t2/RBF_MMD': tensor(0.4108, device='cuda:0'),
 'val/t2/Mean_MSE': tensor(2.5310,

In [None]:
sc.