# Test the estimator class 

In [None]:
import os
import scanpy as sc
import numpy as np
import pandas as pd
import torch
from celldreamer.estimator.celldreamer_estimator import CellDreamerEstimator
from celldreamer.paths import PERT_DATA_DIR
from celldreamer.data.utils import Args

from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.model_summary import ModelSummary

Initialize the `args` configuration object for the estimator

In [None]:
# Configuration for a perturbation-modelling run. The nested dict is wrapped in
# the project's `Args` helper and handed to `CellDreamerEstimator` further down.
# Each sub-dict groups the keyword arguments for one component (checkpointing,
# early stopping, logger, denoising network, generative model, trainers).
args_pert = Args(
    {
        # General
        "train": True,
        "experiment_name": "try_experiment",
        "task": "perturbation_modelling",
        "data_path": PERT_DATA_DIR / 'sciplex' / 'sciplex_complete_middle_subset.h5ad',
        "batch_size": 512,
        "resume": False,
        "train_autoencoder": False,
        "use_latent_repr": False,
        "pretrained_autoencoder": False,
        "checkpoint_autoencoder": None,
        "pretrained_generative": False,
        # NOTE(review): `checkpoint_autoencoder` is None while this is True —
        # confirm whether a path / None is expected here.
        "checkpoint_generative": True,

        # Perturbation setting specific
        "perturbation_key": "condition",
        "dose_key": "dose",
        "covariate_keys": "cell_type",
        "smile_keys": "SMILES",
        "degs_key": "lincs_DEGs",
        "pert_category": "cov_drug_dose_name",
        "split_key": "split_ho_pathway",
        "use_drugs": True,
        "feature_type": "grover",
        "freeze_embeddings": True,
        "doser_width": 128,
        "doser_depth": 3,
        "embedding_dimensions": 100,
        "one_hot_encode_features": False,

        # General model
        "generative_model": "diffusion",
        "denoising_model": "mlp",

        # Autoencoder
        "autoencoder_kwargs": None,

        # Checkpoint kwargs
        "checkpoint_kwargs": {
            "filename": "epoch_{epoch:01d}",
            "monitor": "loss/valid_loss",
            "mode": "min",
            "save_last": True,
            "auto_insert_metric_name": False,
        },

        # Early stopping kwargs
        "early_stopping_kwargs": {
            "monitor": "loss/valid_loss",
            "patience": 20,
            "mode": "min",
            "min_delta": 0.,
            "verbose": False,
            "strict": True,
            "check_finite": True,
            "stopping_threshold": None,
            "divergence_threshold": None,
            "check_on_train_epoch_end": None,
        },

        # Logger kwargs
        "logger_kwargs": {
            "offline": False,
            "id": None,
            "anonymous": None,
            "project": "PerturbSeq_CMV",
            "log_model": False,
            "prefix": "",
            "tags": [],
            "job_type": "",
        },

        # Denoising model specific
        "denoising_module_kwargs": {
            "dims": [128, 64],
            "time_embed_size": 100,
            "class_emb_size": 100,
            "dropout": 0.0,
        },

        # Diffusion model specific
        "generative_model_kwargs": {
            "T": 10,
            "w": 0.3,
            "v": 0.2,
            "p_uncond": 0.2,
            "logging_freq": 1000,
            "classifier_free": False,
            "learning_rate": 0.001,
            "weight_decay": 0.0001,
        },

        # Autoencoder trainer hparams
        "trainer_autoencoder_kwargs": {
            "max_epochs": 100,
            "gradient_clip_val": 1.,
            "gradient_clip_algorithm": "norm",
            "accelerator": 'gpu',
            "devices": 1,
            "check_val_every_n_epoch": 10,
            "log_every_n_steps": 10,
            "detect_anomaly": False,
            "deterministic": False,
        },

        # Generative model trainer hyperparams
        "trainer_generative_kwargs": {
            "max_epochs": 100,
            "gradient_clip_val": 1.,
            "gradient_clip_algorithm": "norm",
            "accelerator": 'gpu',
            "devices": 1,
            "check_val_every_n_epoch": 10,
            "log_every_n_steps": 10,
            "detect_anomaly": False,
            "deterministic": False,
        },
    }
)

Initialize the cell estimator 

In [None]:
# Construct the estimator from the `args_pert` configuration.
# NOTE(review): presumably this also sets up the datamodule and model used by
# the cells below — confirm against CellDreamerEstimator.__init__.
estimator = CellDreamerEstimator(args_pert)

Train the model, then check the validation set size and the device of a training batch

In [None]:
# Invoke the estimator's training entry point.
# NOTE(review): presumably this fits the generative model with the trainer
# settings from the configuration — confirm in CellDreamerEstimator.train.
estimator.train()

In [None]:
# Length of the dataset behind the validation dataloader.
# `valid_dataloader` is read as an attribute here (not called) — assumes the
# project's datamodule exposes it that way; TODO confirm.
len(estimator.datamodule.valid_dataloader.dataset)

In [None]:
# Take the first batch from the training dataloader and show the device of its
# "X" entry. `train_dataloader` is read as an attribute here (not called) —
# assumes the project's datamodule exposes it that way; TODO confirm.
next(iter(estimator.datamodule.train_dataloader))["X"].device