## Tutorial on the perturbation data loader

In [11]:
from celldreamer.data.utils import Args
from celldreamer.data.pert_loader import PertDataset
from celldreamer.paths import PERT_DATA_DIR
import torch 

Define the hyperparameters used to configure the perturbation dataset and its data loaders.

In [12]:
# Hyperparameters for the sciplex perturbation dataset and the data loaders built on it.
# Keys are passed as keyword arguments to `dict`, so each entry reads as `name=value`.
args = Args(dict(
    data_path=PERT_DATA_DIR / "sciplex" / "sciplex_complete_middle_subset.h5ad",
    perturbation_key="condition",
    dose_key="dose",
    covariate_keys="cell_type",
    smiles_key="SMILES",
    degs_key="lincs_DEGs",
    pert_category="cov_drug_dose_name",
    split_key="split_ho_pathway",
    use_drugs_idx=True,
    batch_size=32,
))

Initialize the dataset class from the hyperparameters above

In [13]:
# Build the perturbation dataset. Every option comes from `args` (including
# `use_drugs_idx`, previously hardcoded), so the config cell is the single source of truth.
dataset = PertDataset(data=args.data_path,
                      perturbation_key=args.perturbation_key,
                      dose_key=args.dose_key,
                      covariate_keys=args.covariate_keys,
                      smiles_key=args.smiles_key,
                      degs_key=args.degs_key,
                      pert_category=args.pert_category,
                      split_key=args.split_key,
                      use_drugs_idx=args.use_drugs_idx)



Define data loaders over the train, test and OOD splits of the initialized dataset

In [14]:
# One DataLoader per split: the "test" split serves as validation data and the
# "ood" split as the held-out test set.
# Only the training loader is shuffled; evaluation loaders keep a fixed order so that
# evaluation passes and inspected batches are reproducible across runs.
datamodule = Args({
    "train_dataloader": torch.utils.data.DataLoader(
        dataset.subset("train", "all"),
        batch_size=args.batch_size,
        shuffle=True,
    ),
    "valid_dataloader": torch.utils.data.DataLoader(
        dataset.subset("test", "all"),
        batch_size=args.batch_size,
        shuffle=False,
    ),
    "test_dataloader": torch.utils.data.DataLoader(
        dataset.subset("ood", "all"),
        batch_size=args.batch_size,
        shuffle=False,
    ),
})

Collect a batch for inspection

In [15]:
# Draw a single batch from the training loader (shuffled, so its contents vary between runs).
batch = next(iter(datamodule.train_dataloader))

```batch["X"]``` contains the cell expression profiles, a tensor of shape ```n_batch x n_genes```

In [16]:
# Expression profiles of the observations in the batch.
batch["X"]

tensor([[0.0000, 0.0000, 0.9085,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.3214,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6893,  ..., 0.0000, 0.0000, 0.0000]])

```batch["X_degs"]``` contains a (pre-computed) binary mask over genes, marking the differentially expressed genes of each observation's perturbation

In [17]:
# Differentially-expressed-gene mask for each observation in the batch.
batch["X_degs"]

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

```batch["y"]``` contains a dictionary of the covariates associated with each observation

In [18]:
# Iterating a dictionary yields its keys, so `list(...)` lists the covariate names directly.
print(f"Keys in the covariate dictionary {list(batch['y'])}")

Keys in the covariate dictionary ['y_drug', 'y_cell_type']


The ```y_drug``` entry is a list of two tensors:
* the indices of the drugs, used to look up their structural encodings
* the doses, used for the dose encoding

In [19]:
# Drug covariates: a list holding the drug-index tensor and the dose tensor.
batch["y"]["y_drug"]

[tensor([  7,  65, 106, 117, 176, 163, 101, 131,  77, 163,  75, 161,  77,  21,
          93,  30, 176,   0, 163, 179, 178,  52,  35,  11, 140,  63,  76, 150,
         174, 123,  57,  88]),
 tensor([ 1000.,    10.,   100.,    10.,  1000.,  1000.,   100.,   100.,   100.,
          1000.,   100.,    10.,  1000.,   100.,   100., 10000.,  1000., 10000.,
           100., 10000., 10000., 10000., 10000., 10000., 10000.,  1000.,   100.,
         10000.,  1000., 10000.,  1000.,   100.])]

```y_cell_type``` is a one-hot encoded tensor with one row per observation and one column per cell type

In [20]:
# One-hot encoded cell-type covariate for each observation in the batch.
batch["y"]["y_cell_type"]

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])