# Test the estimator class 

In [1]:
import os
import torch
from celldreamer.estimator.celldreamer_estimator import CellDreamerEstimator
from celldreamer.paths import PERT_DATA_DIR
from celldreamer.data.utils import Args

from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.model_summary import ModelSummary

  warn(


Initialize the `args` dict used to configure the estimator class

In [2]:
# Experiment configuration for the perturbation-modelling run. The whole
# dictionary is wrapped in `Args` and handed to `CellDreamerEstimator`.
# Each key appears exactly once: a repeated key in a dict literal is silently
# overwritten by its last occurrence, which makes stale duplicates easy to miss.
args_pert = Args(
    {
        # General
        "train": True,
        "train_autoencoder": False,
        "experiment_name": "try_experiment",
        "task": "perturbation_modelling",
        "freeze_embeddings": True,
        "feature_type": "grover",
        "data_path": PERT_DATA_DIR / 'sciplex' / 'sciplex_complete_middle_subset.h5ad',
        "batch_size": 128,
        "use_latent_repr": False,
        "resume": False,
        "doser_width": 128,
        "doser_depth": 3,
        "pretrained_autoencoder": False,
        "checkpoint_autoencoder": None,

        # Perturbation setting specific
        "use_drugs": True,
        "perturbation_key": "condition",
        "dose_key": "dose",
        "covariate_keys": "cell_type",
        "smile_keys": "SMILES",
        "degs_key": "lincs_DEGs",
        "pert_category": "cov_drug_dose_name",
        "split_key": "split_ho_pathway",
        "use_drugs_idx": True,
        "embedding_dimensions": 100,
        "one_hot_encode_features": False,
        "pretrained_generative": False,
        # NOTE(review): the key is spelled "chekpoint_path" and holds a bool,
        # not a path. Left untouched because the consumer of this key is not
        # visible from this notebook; confirm the expected name and type.
        "chekpoint_path": True,

        # General model
        "generative_model": "diffusion",
        "denoising_model": "mlp",

        # Autoencoder
        "autoencoder_kwargs": {
            "in_dim": 2000,
            "learning_rate": 0.001,
            "hidden_dim_encoder": [256, 128, 64],
            # NOTE(review): the decoder widths do not mirror the encoder
            # ([64, 128, 64] vs. [256, 128, 64]); confirm this is intended.
            "hidden_dim_decoder": [64, 128, 64],
            "batch_norm": True,
            "layer_norm": False,
            # Classes (not instances) are passed for the activations and the optimizer.
            "activation": torch.nn.ReLU,
            "output_activation": torch.nn.Identity,
            "reconst_loss": "mse",
            "dropout": 0.0,
            "weight_decay": 0.1,
            "optimizer": torch.optim.Adam,
            "lr_scheduler": None,
            "lr_scheduler_kwargs": None,
        },

        # Denoising model specific
        "denoising_module_kwargs": {
            "dims": [128, 64],
            "time_embed_size": 100,
            "class_emb_size": 100,
            "dropout": 0.0,
        },

        # Diffusion model specific
        "generative_model_kwargs": {
            "T": 4000,
            "w": 0.3,
            "v": 0.2,
            "p_uncond": 0.2,
            "logging_freq": 1000,
            "classifier_free": False,
            "learning_rate": 0.001,
            "weight_decay": 0.0001,
        },

        # Autoencoder trainer hparams
        "trainer_autoencoder_kwargs": {
            'max_epochs': 1000,
            'gradient_clip_val': 1.,
            'gradient_clip_algorithm': 'norm',
            'default_root_dir': "",
            'accelerator': 'gpu',
            'devices': 1,
            'num_sanity_val_steps': 0,
            'check_val_every_n_epoch': 1,
            'log_every_n_steps': 100,
            'detect_anomaly': False,
            'enable_progress_bar': True,
            'enable_model_summary': False,
            'enable_checkpointing': True,
        },

        # Generative model trainer hyperparams
        "trainer_generative_kwargs": {
            'max_epochs': 1000,
            'gradient_clip_val': 1.,
            'gradient_clip_algorithm': 'norm',
            'default_root_dir': "",
            'accelerator': 'gpu',
            'devices': 1,
            'num_sanity_val_steps': 0,
            'check_val_every_n_epoch': 1,
            'log_every_n_steps': 100,
            'detect_anomaly': False,
            'enable_progress_bar': True,
            'enable_model_summary': False,
            'enable_checkpointing': True,
        },
    }
)

Initialize the cell estimator 

In [3]:
# Build the estimator from the configuration above. Per the recorded output,
# construction reports creating the training folders and initializing the
# data module, the feature embeddings and the model.
estimator = CellDreamerEstimator(args_pert)

Create the training folders...
Initialize data module...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Initialize feature embeddings...
Initialize model...


Run training to check the training batches

In [4]:
# Launch training with the configured estimator.
# NOTE(review): the recorded run of this cell ended with
# "RuntimeError: Tensors must have same number of dimensions: got 3 and 2"
# during this call; the cause is not visible from this notebook and needs to
# be resolved before the notebook can run top to bottom.
estimator.train()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

3999


RuntimeError: Tensors must have same number of dimensions: got 3 and 2

In [None]:
# Display the feature embeddings stored on the estimator's generative model.
estimator.generative_model.feature_embeddings

In [None]:
# Pull one batch from the training dataloader and show the device of its "X" entry.
# NOTE(review): `train_dataloader` is used as an attribute here, not called;
# confirm it is an iterable loader on this datamodule rather than a method.
next(iter(estimator.datamodule.train_dataloader))["X"].device