# Test the `CellDreamerEstimator` class

In [1]:
import os
import scanpy as sc
import numpy as np
import pandas as pd
import torch
from celldreamer.estimator.celldreamer_estimator import CellDreamerEstimator
from celldreamer.paths import DATA_DIR
from celldreamer.data.utils import Args

from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.model_summary import ModelSummary

Initialize the `args` configuration dict that will be passed to the estimator class

In [2]:
# Configuration for a perturbation-modelling run with a diffusion generative
# model and an MLP denoiser. Everything is wrapped in the project's `Args`
# container and consumed by `CellDreamerEstimator`.
args_pert = Args(
    {
        # General
        "train": True,
        "experiment_name": "try_experiment",
        "task": "perturbation_modelling",
        # presumably resolved relative to the project's DATA_DIR — verify
        "dataset_path": "sciplex/sciplex_complete_middle_subset.h5ad",
        "batch_size": 256,
        "resume": False,
        "train_autoencoder": False,
        "use_latent_repr": False,
        "pretrained_autoencoder": False,
        "checkpoint_autoencoder": None,
        "pretrained_generative": False,
        # NOTE(review): the sibling `checkpoint_autoencoder` is None while this
        # is a boolean — confirm whether a path/None is expected here instead.
        "checkpoint_generative": True,

        # Perturbation setting specific
        "perturbation_key": "condition",
        "dose_key": "dose",
        "covariate_keys": "cell_type",
        "smile_keys": "SMILES",
        "degs_key": "lincs_DEGs",
        "pert_category": "cov_drug_dose_name",
        "split_key": "split_ho_pathway",
        "use_drugs": False,
        "feature_type": "rdkit",
        "freeze_embeddings": True,
        "doser_width": 128,
        "doser_depth": 3,
        "embedding_dimensions": 100,
        "one_hot_encode_features": True,

        # General model
        "generative_model": "diffusion",
        "denoising_model": "mlp",

        # Autoencoder (unused in this run: `train_autoencoder` is False)
        "autoencoder_kwargs": None,

        # Checkpoint kwargs
        "checkpoint_kwargs": {
            "filename": "epoch_{epoch:01d}",
            "monitor": "loss/valid_loss",
            "mode": "min",
            "save_last": True,
            "auto_insert_metric_name": False,
        },

        # Early stopping kwargs (monitors the same metric as checkpointing)
        "early_stopping_kwargs": {
            "monitor": "loss/valid_loss",
            "patience": 20,
            "mode": "min",
            "min_delta": 0.,
            "verbose": False,
            "strict": True,
            "check_finite": True,
            "stopping_threshold": None,
            "divergence_threshold": None,
            "check_on_train_epoch_end": None,
        },

        # Logger kwargs — the keys look like Weights & Biases logger options
        # (the estimator's output shows a wandb login); verify against the estimator.
        "logger_kwargs": {
            "offline": False,
            "id": None,
            "anonymous": None,
            "project": "celldreamer",
            "log_model": False,
            "prefix": "",
            "tags": [],
            "job_type": "",
        },

        # Denoising model specific
        "denoising_module_kwargs": {
            "dims": [64],
            "time_embed_size": 128,
            "class_emb_size": 10,
            "dropout": 0.0,
            "encode_class": False,
        },

        # Diffusion model specific
        "generative_model_kwargs": {
            "T": 1000,
            "w": 0.3,
            "p_uncond": 0.2,
            "classifier_free": False,
            "learning_rate": 0.0001,
            "weight_decay": 0.0001,
        },

        # Trainer hparams — presumably forwarded to the PyTorch Lightning
        # Trainer used for the training run; verify against the estimator.
        "trainer_kwargs": {
            "max_epochs": 100,
            "gradient_clip_val": 1.,
            "gradient_clip_algorithm": "norm",
            # Fall back to CPU when no CUDA device is visible so the notebook
            # still runs on GPU-less machines; unchanged when a GPU is present.
            "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
            "devices": 1,
            "check_val_every_n_epoch": 1,
            "log_every_n_steps": 1,
            "detect_anomaly": False,
            "deterministic": False,
        },
    }
)

Initialize the CellDreamer estimator from the configuration

In [3]:
# Build the estimator from the configuration defined above. Judging by the
# messages logged below, construction creates the training folders, sets up
# the data module, feature embeddings and model, and starts a wandb session.
estimator = CellDreamerEstimator(args_pert)

Create the training folders...
Initialize data module...


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mallepalma[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Initialize feature embeddings...
Initialize model...


Inspect the generative model before training

In [4]:
# Bare last expression: displays the module repr of the generative model
# held by the estimator so its layer structure can be inspected.
estimator.generative_model

ConditionalGaussianDDPM(
  (denoising_model): MLPTimeStep(
    (encoder): ModuleList(
      (0): MLPTimeEmbedCond(
        (linear_map_class): Identity()
        (l_embedding): Sequential(
          (0): GELU(approximate='none')
          (1): Linear(in_features=128, out_features=64, bias=True)
        )
        (net): Sequential(
          (0): Linear(in_features=2003, out_features=64, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=64, out_features=64, bias=True)
        )
        (relu): ReLU()
        (out_layer): Sequential(
          (0): GELU(approximate='none')
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=64, out_features=64, bias=True)
        )
        (skip_connection): Sequential(
          (0): Linear(in_features=2003, out_features=64, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=64, out_features=64, bias=True)
        )
      )
    )
    (middle_block): MLPTimeEm

In [5]:
# Launch training. The output below shows the Lightning model summary followed
# by the sanity-check, training and validation progress bars.
estimator.train()

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type        | Params
------------------------------------------------
0 | denoising_model | MLPTimeStep | 12.9 M
1 | mse             | MSELoss     | 0     
------------------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.503    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]