In [1]:
import os 
import pytorch_lightning as pl
import seml
import torch
from sacred import SETTINGS, Experiment
from functools import partial
from PerturbSeq_CMV.paths import EXPERIMENT_FOLDER

from PerturbSeq_CMV.datamodules.distribution_datamodule import TrajectoryDataModule
from PerturbSeq_CMV.models.cfm_module import CFMLitModule
from PerturbSeq_CMV.models.components.augmentation import AugmentationModule
from PerturbSeq_CMV.models.components.simple_mlp import VelocityNet

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from torch.optim import AdamW

import yaml

**Load the experiment configuration** (the `fixed` section of the YAML config file)

In [2]:
# Path to the YAML config. Defaults to the original location, but can be
# overridden with the PERTURBSEQ_CMV_CONFIG environment variable so the notebook
# runs from another machine/checkout without editing code.
CONFIG_PATH = os.environ.get(
    "PERTURBSEQ_CMV_CONFIG",
    "/nfs/homedirs/pala/PerturbSeq_CMV/configs/default_config.yaml",
)

with open(CONFIG_PATH, "r", encoding="utf-8") as stream:
    # Only the `fixed` section is used; the cells below index it with
    # "<group>.<name>" keys (e.g. "training.training") and unpack each entry
    # as keyword arguments for the corresponding component.
    hparams = yaml.safe_load(stream)["fixed"]

In [3]:
task_name = hparams["training.training"]["task_name"]

# Fix seeds for reproducibility. `pl.seed_everything` seeds python `random`,
# numpy and torch (and, with workers=True, dataloader workers), so a separate
# `torch.manual_seed` call is redundant.
# Compare against None explicitly: seed=0 is a valid seed and must not be
# skipped, and a missing/None seed leaves the RNGs unseeded instead of
# crashing inside torch.manual_seed(None).
seed = hparams["training.training"].get("seed")
if seed is not None:
    pl.seed_everything(seed, workers=True)

# Initialize the per-task experiment folder (holds checkpoints and logs)
current_experiment_dir = EXPERIMENT_FOLDER / task_name
current_experiment_dir.mkdir(parents=True, exist_ok=True)
    

# Data module and augmentations, each built from its own section of `hparams`
datamodule = TrajectoryDataModule(**hparams["datamodule.datamodule"])
augmentations = AugmentationModule(**hparams["augmentations.augmentations"])

# Velocity network: bind the config kwargs into a factory (functools.partial);
# the Lightning module receives this factory, not a constructed network.
net_kwargs = hparams["net.net"]
net = partial(VelocityNet, **net_kwargs)

# Lightning module wiring together the network factory, data and augmentations;
# any remaining options come from the `model.model` config section.
model_kwargs = dict(net=net, datamodule=datamodule, augmentations=augmentations)
model = CFMLitModule(**model_kwargs, **hparams["model.model"])
        

# Initialize callbacks: checkpointing (under <experiment>/checkpoints) and early stopping
model_ckpt_callbacks = ModelCheckpoint(
    dirpath=current_experiment_dir / "checkpoints",
    **hparams["model_checkpoint.model_checkpoint"],
)
early_stopping_callbacks = EarlyStopping(**hparams["early_stopping.early_stopping"])

# Initialize logger. The log directory is created up front because nothing else
# in this notebook creates it, and wandb may fall back to a temporary directory
# when `save_dir` does not exist yet.
log_dir = current_experiment_dir / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
logger = WandbLogger(save_dir=log_dir, **hparams["logger.logger"])

# Initialize the lightning trainer
trainer = Trainer(
    default_root_dir=current_experiment_dir,
    callbacks=[model_ckpt_callbacks, early_stopping_callbacks],
    logger=logger,
    **hparams["trainer.trainer"],
)
        

# Fit the model.
# NOTE(review): the recorded run of this cell failed right after "Training: 0it"
# with "Expected all tensors to be on the same device ... cuda:0 and cpu".
# Nothing in this cell moves tensors between devices explicitly, so the cause
# presumably lies inside CFMLitModule / TrajectoryDataModule (e.g. tensors kept
# outside registered parameters/buffers) — TODO investigate.
trainer.fit(model=model, datamodule=datamodule)
# Snapshot the metrics: `trainer.callback_metrics` hands back the trainer's own
# dict, which may be updated or reset by the test run below.
train_metrics = dict(trainer.callback_metrics)

# Test the best checkpoint written by the ModelCheckpoint callback; fall back to
# the in-memory weights (ckpt_path=None) when no checkpoint was saved ("").
ckpt_path = model_ckpt_callbacks.best_model_path or None
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
test_metrics = dict(trainer.callback_metrics)

# Merge train and test metrics (test values win on key collisions)
metric_dict = {**train_metrics, **test_metrics}



[rank: 0] Global seed set to 42


[(446, 3482), (837, 3482), (1354, 3482), (1672, 3482), (1119, 3482), (477, 3482), (532, 3482), (185, 3482)]
[0.8, 0.1, 0.1]


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mallepalma[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                 | Params
-----------------------------------------------------------
0 | net               | VelocityNet          | 457 K 
1 | augmentations     | AugmentationModule   | 0     
2 | aug_net           | AugmentedVectorField | 457 K 
3 | node              | NeuralODE            | 457 K 
4 | val_augmentations | AugmentationModule   | 0     
5 | val_aug_net       | AugmentedVectorField | 457 K 
6 | val_aug_node      | Sequential           | 457 K 
7 | aug_node          | Sequential           | 457 K 
8 | criterion         | MSELoss              | 0     
-----------------------------------------------------------
457 K     Trainable params
0         Non-trainable params
457 K     Total params
1.831     Total estimated model params size (M

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!