# Test the `CellDreamerEstimator` class

In [1]:
import os
import torch
from celldreamer.estimator.celldreamer_estimator import CellDreamerEstimator
from celldreamer.paths import PERT_DATA_DIR
from celldreamer.data.utils import Args

from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.model_summary import ModelSummary

  warn(


Define the `args` configuration dict used to initialize the estimator class

In [3]:
# Configuration for a perturbation-modelling experiment; it is handed to
# `CellDreamerEstimator` in the next cell.
#
# NOTE: "one_hot_encode_features" was previously listed twice in this literal
# (True in the general section, False in the perturbation section). Python keeps
# the LAST value of a duplicated key, so the effective setting was always False.
# The duplicate is collapsed into a single entry carrying that effective value,
# which leaves the resulting dict (keys, order and values) unchanged.
args_pert = Args(
    {
        # General
        "train": True,
        "experiment_name": "try_experiment",
        "task": "perturbation_modelling",
        "freeze_embeddings": True,
        "feature_type": "grover",
        "data_path": PERT_DATA_DIR / 'sciplex' / 'sciplex_complete_middle_subset.h5ad',
        "batch_size": 128,
        "use_latent_repr": True,
        "one_hot_encode_features": False,  # single definition; see NOTE above
        "resume": False,

        # Perturbation setting specific
        "perturbation_key": "condition",
        "dose_key": "dose",
        "covariate_keys": "cell_type",
        "smile_keys": "SMILES",
        "degs_key": "lincs_DEGs",
        "pert_category": "cov_drug_dose_name",
        "split_key": "split_ho_pathway",
        "use_drugs_idx": True,
        "embedding_dimensions": 100,
        "doser_width": 128,
        "doser_depth": 3,

        # General model
        "generative_model": "diffusion",
        "denoising_model": "mlp",

        # Autoencoder
        "autoencoder_kwargs": {
            "in_dim": 2000,
            "batch_size": 32,
            "hidden_dim_encoder": [256, 128, 64],
            "hidden_dim_decoder": [64, 128, 64],
            "batch_norm": True,
            "layer_norm": False,
            # Classes (not instances) are passed for activations and optimizer.
            "activation": torch.nn.ReLU,
            "output_activation": torch.nn.Identity,
            "reconst_loss": "mse",
            "dropout": 0.0,
            "weight_decay": 0.1,
            "learning_rate": 0.001,
            "optimizer": torch.optim.Adam,
            "lr_scheduler": None,
            "lr_scheduler_kwargs": None,
        },

        # Denoising model specific
        "denoising_module_kwargs": {
            "dims": [128, 64],
            "time_embed_size": 100,
            "class_emb_size": 100,
            "dropout": 0.0,
        },

        # Diffusion model specific
        "generative_model_kwargs": {
            "T": 4.000,
            "w": 0.3,
            "v": 0.2,
            "p_uncond": 0.2,
            "logging_freq": 1000,
            "classifier_free": False,
        },

        # Autoencoder trainer hparams
        "trainer_autoencoder_kwargs": {
            'max_epochs': 1000,
            'gradient_clip_val': 1.,
            'gradient_clip_algorithm': 'norm',
            'default_root_dir': "",
            'accelerator': 'gpu',
            'devices': 1,
            'num_sanity_val_steps': 0,
            'check_val_every_n_epoch': 1,
            'log_every_n_steps': 100,
            'detect_anomaly': False,
            'enable_progress_bar': True,
            'enable_model_summary': False,
            'enable_checkpointing': True,
        },

        # Generative model trainer hyperparams
        "trainer_generative_kwargs": {
            'max_epochs': 1000,
            'gradient_clip_val': 1.,
            'gradient_clip_algorithm': 'norm',
            'default_root_dir': "",
            'accelerator': 'gpu',
            'devices': 1,
            'num_sanity_val_steps': 0,
            'check_val_every_n_epoch': 1,
            'log_every_n_steps': 100,
            'detect_anomaly': False,
            'enable_progress_bar': True,
            'enable_model_summary': False,
            'enable_checkpointing': True,
        },
    }
)

Initialize the cell estimator 

In [4]:
# Build the estimator from `args_pert`. The log printed by this cell reports the
# steps it runs: creating the training folders, then initializing the data module,
# the feature embeddings and the model.
estimator = CellDreamerEstimator(args_pert)

Create the training folders...
Initialize data module...




Initialize feature embeddings...
Initialize model...


Check feature embeddings 

In [12]:
# Inspect the `num_classes` entry of the denoising-module kwargs. This key is not
# part of the config dict passed to `Args`, so it is presumably filled in while the
# estimator is constructed (TODO confirm where it is set).
estimator.args.denoising_module_kwargs["num_classes"]

{'drug': 3400, 'cell_type': 100}

In [6]:
# Display the featurizer modules held by the estimator, keyed by covariate name.
estimator.feature_embeddings

{'drug': DrugsFeaturizer(
   (features): Embedding(188, 3400)
   (dosers): MLP(
     (0): Linear(in_features=3401, out_features=128, bias=True)
     (1): ReLU(inplace=True)
     (2): Dropout(p=0.0, inplace=True)
     (3): Linear(in_features=128, out_features=128, bias=True)
     (4): ReLU(inplace=True)
     (5): Dropout(p=0.0, inplace=True)
     (6): Linear(in_features=128, out_features=128, bias=True)
     (7): ReLU(inplace=True)
     (8): Dropout(p=0.0, inplace=True)
     (9): Linear(in_features=128, out_features=1, bias=True)
   )
 ),
 'cell_type': CategoricalFeaturizer(
   (embeddings): Embedding(3, 100)
 )}

Check training batches 

In [8]:
# Pull a single batch from the training dataloader to inspect its structure.
# NOTE(review): `train_dataloader` is accessed as an attribute rather than called;
# this assumes the project's data module exposes it as an attribute/property — confirm.
next(iter(estimator.datamodule.train_dataloader))

{'X': tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0428, 0.0000, 0.4773,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.5781,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.6152,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.5222,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),
 'X_degs': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 'y': {'y_drug': [tensor([ 40, 187, 162,  68,  56,  54, 132, 185,  46, 161,  73, 117,   3, 136,
           159, 178, 179,  76,  79, 105,  48, 123, 111,  11, 140, 182, 149, 105,
           120, 125,  70,  85, 172,  41, 172, 167, 117, 119,  60, 170,  88, 181,
            87, 120,  91, 148, 130,  82, 129,   2, 100,   2