In [20]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
from celldreamer.estimator.celldreamer_estimator import CellDreamerEstimator
from celldreamer.data.utils import Args
from celldreamer.paths import PERT_DATA_DIR
from pathlib import Path 

**Load scRNAseq**

In [10]:
args_scrnaseq = Args({"task": "cell_generation",
                "data_path": "/lustre/scratch/users/felix.fischer/merlin_cxg_simple_norm_parquet",
                 "batch_size": 128,
                 "drop_last": False,
                 "one_hot_encode_features": False,
                 "embedding_dimensions":100, 
                 "categories":["cell_type", "disease"],
})

In [12]:
estimator_scrnaseq_embeddings = CellDreamerEstimator(args_scrnaseq)
estimator_scrnaseq_embeddings.init_datamodule()
estimator_scrnaseq_embeddings.init_feature_embeddings()

In [13]:
batch = next(iter(estimator_scrnaseq_embeddings.datamodule.train_dataloader()))

In [14]:
batch[0]["X"].shape

torch.Size([128, 19357])

In [15]:
# Test featurizer 
estimator_scrnaseq_embeddings.feature_embeddings

{'cell_type': CategoricalFeaturizer(
   (embeddings): Embedding(128, 100)
 ),
 'disease': CategoricalFeaturizer(
   (embeddings): Embedding(29, 100)
 )}

In [16]:
batch_disease = batch[0]["disease"].squeeze()

In [17]:
embeddings_batch = estimator_scrnaseq_embeddings.feature_embeddings["disease"](batch_disease)
print(embeddings_batch.shape)

torch.Size([128, 100])


In [18]:
embeddings_batch

tensor([[-0.4553, -0.7296, -0.2910,  ...,  0.7050,  1.7850, -0.5511],
        [-0.2434,  1.1698, -0.4538,  ..., -0.3709, -0.7903,  1.8038],
        [ 0.1290,  0.8190,  0.7355,  ...,  0.8869,  1.4369, -0.4915],
        ...,
        [-0.4553, -0.7296, -0.2910,  ...,  0.7050,  1.7850, -0.5511],
        [ 0.2846,  1.5765, -0.6417,  ...,  1.4187, -1.3890,  1.7056],
        [ 0.1290,  0.8190,  0.7355,  ...,  0.8869,  1.4369, -0.4915]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

**Load perturbation dataset**

In [2]:
path = Path(PERT_DATA_DIR)

In [3]:
args_pert = Args({"task": "perturbation_modelling",
                  "freeze_embeddings": True,
                  "feature_type": "grover",
                  "data_path": path / 'sciplex' / 'sciplex_complete_middle_subset.h5ad',
                  "perturbation_key": "condition",
                  "dose_key": "dose",
                  "covariate_keys": "cell_type",
                  "smile_keys": "SMILES",
                  "degs_key": "lincs_DEGs",
                  "pert_category": "cov_drug_dose_name",
                  "split_key": "split_ho_pathway",
                  "batch_size": 128, 
                  "use_drugs_idx":True})

In [4]:
estimator_drugs = CellDreamerEstimator(args_pert)
estimator_drugs.init_datamodule()
estimator_drugs.init_feature_embeddings()



In [10]:
batch = next(iter(estimator_drugs.datamodule.train_dataloader))

In [11]:
batch[0].shape

torch.Size([128, 2000])

In [12]:
batch[1]

tensor([173, 187,  85, 127,  88,  74,  16,  30, 171, 162,  28, 102, 113,  25,
        161, 152, 135, 107, 161,  55, 103,  96,  56,  28, 181,  18, 105,  12,
         55, 102, 110, 103, 183, 102, 134, 167, 141, 140,  57,  57,  71,  60,
        122,  20,  20,   6, 117, 119,  73,  43, 137,  35, 109, 147,  70,  38,
        140, 106, 107, 110,  98,  11, 187, 187, 156,  20, 115, 170,  46, 128,
         73,  20,  38, 186,  60,  78, 151, 176, 135, 112, 168, 170,   3,  63,
        186, 151,  35, 114, 175, 173, 180,  85, 113, 111,  18, 121,   8,  90,
         23, 148,  62,  25, 165,  54, 112, 187, 187,  76,  73, 109,  64,  94,
        181,  17, 107,  55, 182, 140,  59, 126, 157, 135,  15, 133,  54,  47,
        106, 170])

In [13]:
batch[2]

tensor([10000.,     0.,  1000., 10000.,  1000.,    10.,   100., 10000.,  1000.,
          100.,  1000., 10000.,   100.,   100.,   100.,   100.,   100., 10000.,
        10000.,  1000.,    10.,  1000.,  1000.,    10.,   100., 10000.,  1000.,
        10000., 10000., 10000.,  1000., 10000.,  1000.,   100.,   100.,   100.,
         1000.,  1000., 10000., 10000.,  1000., 10000.,   100.,  1000.,    10.,
          100., 10000., 10000., 10000.,  1000.,   100., 10000.,    10.,  1000.,
        10000.,    10.,    10.,    10., 10000.,  1000., 10000.,  1000.,     0.,
            0.,    10., 10000., 10000.,   100.,   100.,   100., 10000., 10000.,
         1000.,  1000.,  1000.,  1000.,  1000., 10000.,  1000.,    10., 10000.,
          100.,    10.,   100.,    10., 10000., 10000.,   100.,   100.,  1000.,
          100.,  1000.,  1000.,    10., 10000., 10000.,   100.,  1000.,    10.,
         1000., 10000.,   100.,    10.,  1000.,   100.,     0.,     0.,    10.,
        10000.,  1000., 10000.,  1000., 

In [9]:
batch[3]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [14]:
batch[4]

tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1

In [18]:
embeddings_batch = estimator_drugs.feature_embeddings(batch[1])
print(embeddings_batch.shape)

torch.Size([128, 3400])
