In [1]:
%load_ext autoreload

In [2]:
from celldreamer.estimator.celldreamer_estimator import CellDreamerEstimator
from celldreamer.models.base.variance_scheduler.cosine import CosineScheduler
from celldreamer.data.utils import Args
from celldreamer.paths import PERT_DATA_DIR
from pathlib import Path 
import pandas as pd
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning_fabric.utilities.seed import seed_everything
from os.path import join
import os

In [3]:
root = os.path.dirname(os.path.abspath(os.getcwd()))
MODEL = 'ddpm_cellnet'

**Load scRNAseq**

In [4]:
DATA_DIR = "/lustre/scratch/users/felix.fischer/merlin_cxg_simple_norm"
METADATA_PATH = "/lustre/scratch/users/felix.fischer/merlin_cxg_simple_norm_parquet"

In [5]:
# args_scrnaseq = Args({"task": "cell_generation",
#                 "data_path": DATA_DIR,
#                 "metadata_path": METADATA_PATH, 
#                  "batch_size": 128,
#                  "drop_last": False,
#                  "one_hot_encode_features": False,
#                  "embedding_dimensions":100, 
#                  "categories":["cell_type"],
#                  "generative_model":"diffusion", 
#                  "denoiser_model": "mlp",
                 
#                 "denoising_module_kwargs": {
#                     "hidden_layers":3
#                 }, 
                      
#                 "model_kwargs": {
#                      "T": 4.000, 
#                      "w": 0.3, 
#                      "v": 0.2,
#                      "p_uncond": 0.2, 
#                      "number_of_genes": len(pd.read_parquet(join(METADATA_PATH, "var.parquet"))),
#                      "logging_freq": 1000,  
#                      "classifier_free": False, 
#                      "variance_scheduler": CosineScheduler(T=4_000)},
                
#                 "trainer_kwargs": {
#                         'max_epochs': 1000,
#                         'gradient_clip_val': 1.,
#                         'gradient_clip_algorithm': 'norm',
#                         'default_root_dir': "",
#                         'accelerator': 'gpu',
#                         'devices': 1,
#                         'num_sanity_val_steps': 0,
#                         'check_val_every_n_epoch': 1,
#                         'logger': [TensorBoardLogger(LOGS_PATH = os.path.join(root, 'trained_models/tb_logs', MODEL), name='default', 
#                                                     save_dir="")],
#                         'log_every_n_steps': 100,
#                         'detect_anomaly': False,
#                         'enable_progress_bar': True,
#                         'enable_model_summary': False,
#                         'enable_checkpointing': True,
#                         'callbacks': [
#                             TQDMProgressBar(refresh_rate=100),
#                             LearningRateMonitor(logging_interval='step'),
#                             ModelCheckpoint(filename='best_train_loss', monitor='train_loss_epoch', mode='min',
#                                             every_n_epochs=1, save_top_k=1),
#                             ModelCheckpoint(filename='best_val_loss', monitor='val_loss', mode='min',
#                                             every_n_epochs=1, save_top_k=1)
#                         ]}
#                      })

In [6]:
# estimator_scrnaseq_embeddings = CellDreamerEstimator(args_scrnaseq)
# estimator_scrnaseq_embeddings.init_datamodule()
# estimator_scrnaseq_embeddings.init_feature_embeddings()

In [7]:
# batch = next(iter(estimator_scrnaseq_embeddings.datamodule.train_dataloader()))

In [8]:
# batch[0]["X"].shape

In [9]:
# # Test featurizer 
# estimator_scrnaseq_embeddings.feature_embeddings

In [10]:
# batch_disease = batch[0]["disease"].squeeze()

In [11]:
# embeddings_batch = estimator_scrnaseq_embeddings.feature_embeddings["disease"](batch_disease)
# print(embeddings_batch.shape)

In [12]:
# embeddings_batch

**Load perturbation dataset**

In [13]:
pert_path = Path(PERT_DATA_DIR)

In [1]:
args_pert = Args({"task": "perturbation_modelling",
                "freeze_embeddings": True,
                "feature_type": "grover",
                "data_path": pert_path / 'sciplex' / 'sciplex_complete_middle_subset.h5ad',
                "perturbation_key": "condition",
                "dose_key": "dose",
                "covariate_keys": "cell_type",
                "smile_keys": "SMILES",
                "degs_key": "lincs_DEGs",
                "pert_category": "cov_drug_dose_name",
                "split_key": "split_ho_pathway",
                "batch_size": 128, 
                "use_drugs_idx":True,
                "embedding_dimensions": 100,
                "one_hot_encode_features": False,
                 
                "generative_model":"diffusion", 
                "denoising_model": "mlp",

                "denoising_module_kwargs": {
                    "dims": [128, 64],
                    "time_embed_size": 100, 
                    "class_emb_size": 100,
                    "dropout": 0.0
                }, 

                "model_kwargs": {
                     "T": 4.000, 
                     "w": 0.3, 
                     "v": 0.2,
                     "p_uncond": 0.2, 
                     "logging_freq": 1000,  
                     "classifier_free": False, 
                     "variance_scheduler": CosineScheduler(T=4_000)},

                "trainer_kwargs": {
                        'max_epochs': 1000,
                        'gradient_clip_val': 1.,
                        'gradient_clip_algorithm': 'norm',
                        'default_root_dir': "",
                        'accelerator': 'gpu',
                        'devices': 1,
                        'num_sanity_val_steps': 0,
                        'check_val_every_n_epoch': 1,
                        'logger': [TensorBoardLogger(LOGS_PATH = os.path.join(root, 'trained_models/tb_logs', MODEL), name='default', 
                                                    save_dir="")],
                        'log_every_n_steps': 100,
                        'detect_anomaly': False,
                        'enable_progress_bar': True,
                        'enable_model_summary': False,
                        'enable_checkpointing': True,
                        'callbacks': [
                            TQDMProgressBar(refresh_rate=100),
                            LearningRateMonitor(logging_interval='step'),
                            ModelCheckpoint(filename='best_train_loss', monitor='train_loss_epoch', mode='min',
                                            every_n_epochs=1, save_top_k=1),
                            ModelCheckpoint(filename='best_val_loss', monitor='val_loss', mode='min',
                                            every_n_epochs=1, save_top_k=1)
                        ]}                 
                 })

NameError: name 'Args' is not defined

In [15]:
estimator_drugs = CellDreamerEstimator(args_pert)
estimator_drugs.init_datamodule()
estimator_drugs.init_feature_embeddings()



In [16]:
batch = next(iter(estimator_drugs.datamodule.train_dataloader))

In [17]:
batch[0]

tensor([[0.0000, 0.0000, 1.0058,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5067,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.6080,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.7673,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5353,  ..., 0.0000, 0.0000, 0.0000]])

In [18]:
batch[1]

tensor([ 50,  62,  55, 181, 121,  84, 179, 132, 130,  88, 141, 117,  23,  34,
         27,  96, 166,  92, 157, 161, 183,   8, 186,  87, 177, 187, 117,   4,
          5,  68,  81, 128, 187, 138, 116, 147,  21,  52, 132, 166,   5, 185,
        116, 187,  14,  84,  96,  96, 113,  76, 145,  69, 118, 140,  83, 148,
        106,  11, 113,  74, 117,  39, 142,  76,  20,  50,  89,  15, 176,  67,
        128,  19, 111,  35, 158,  40,  20, 187, 125,  45, 106, 132, 130, 129,
         98,  68, 111,  71, 116, 170, 187,  68,  40, 164, 139, 187,  37,  46,
          0, 165, 117, 158, 164,  87,  56,   1,  30,  82,  62, 141,  20,  73,
        106, 168,  99, 140,  57,  93, 163,  19, 108, 121, 187, 153, 180, 120,
        187, 144])

In [19]:
batch[2]

tensor([  100.,    10.,  1000.,   100.,    10., 10000.,   100.,  1000.,    10.,
          100.,   100.,  1000.,    10.,  1000.,   100.,    10.,    10.,  1000.,
           10.,  1000.,   100., 10000.,    10., 10000.,    10.,     0.,  1000.,
          100., 10000.,    10.,  1000.,   100.,     0., 10000.,  1000.,   100.,
         1000., 10000.,  1000.,  1000.,  1000.,  1000.,  1000.,     0.,  1000.,
        10000.,   100.,  1000.,  1000., 10000.,   100., 10000.,   100.,   100.,
        10000., 10000.,   100.,  1000.,    10.,  1000., 10000.,  1000.,  1000.,
           10.,  1000.,   100., 10000.,  1000.,    10.,    10.,    10.,    10.,
          100.,  1000., 10000.,    10.,   100.,     0.,   100., 10000.,   100.,
        10000., 10000.,  1000., 10000., 10000.,    10.,  1000.,   100.,  1000.,
            0., 10000.,  1000.,    10.,    10.,     0.,  1000., 10000., 10000.,
        10000., 10000.,    10.,    10.,    10., 10000., 10000.,   100., 10000.,
         1000.,    10., 10000.,    10., 

In [20]:
batch[3]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.]])

In [21]:
batch[4]

tensor([[1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0

In [22]:
estimator_drugs.get_fixed_rna_model_params()
estimator_drugs.init_model()