In [1]:
%load_ext autoreload
%autoreload 2

### State Data Loader Example

Here, we integrate scplode into the data loader used by Arc Institute's State model.
The original data loader lives in Arc Institute's cell-load repository.
For demonstration purposes, I have forked the cell-load code here: https://github.com/rkita/cell-load/tree/scplode-integration
and made only minimal changes for the scplode integration.

Note that the fork does not include the scplode conversion step, so the memory maps are created upfront in this notebook (see the `sp.read_h5ad` cell below).


This is the primary change to the cell-load code (in addition to parameter updates and a `read_h5ad` line):

```python
if self.use_scplode:
    #### SCPLODE DATA ACCESS
    row_data = self.scadata.get([int(idx)])
    data = torch.tensor(row_data, dtype=torch.float32)
else:
    #### ORIGINAL STATE CELL-LOAD CODE BELOW
    attrs = dict(self.h5_file["X"].attrs)
    if attrs["encoding-type"] == "csr_matrix":
        indptr = self.h5_file["/X/indptr"]
        start_ptr = indptr[idx]
        end_ptr = indptr[idx + 1]
        sub_data = torch.tensor(
            self.h5_file["/X/data"][start_ptr:end_ptr], dtype=torch.float32
        )
        sub_indices = torch.tensor(
            self.h5_file["/X/indices"][start_ptr:end_ptr], dtype=torch.long
        )
        counts = torch.sparse_csr_tensor(
            torch.tensor([0], dtype=torch.long),
            sub_indices,
            sub_data,
            (1, self.n_genes),
        )
        data = counts.to_dense().squeeze()
    else:
        row_data = self.h5_file["/X"][idx]
        data = torch.tensor(row_data, dtype=torch.float32)
```
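For readers unfamiliar with the original branch: when `X` is CSR-encoded, it reads the slice of `X/data` and `X/indices` between `indptr[idx]` and `indptr[idx + 1]` and scatters those values into a dense row of length `n_genes`. The sketch below shows just that densification step in plain NumPy. The function name `csr_row_to_dense` is hypothetical, and in-memory arrays stand in for the h5py datasets. With real h5py datasets, slicing reads only the selected range (or the chunks that contain it) rather than the whole array.

```python
import numpy as np

def csr_row_to_dense(data, indices, indptr, idx, n_genes):
    """Return row `idx` of a CSR matrix as a dense float32 vector.

    `data`, `indices` and `indptr` play the role of X/data, X/indices and
    X/indptr in the .h5 file (hypothetical helper, not part of cell-load).
    """
    start, end = int(indptr[idx]), int(indptr[idx + 1])
    row = np.zeros(n_genes, dtype=np.float32)
    # Scatter the stored non-zero values into their gene columns
    row[np.asarray(indices[start:end])] = data[start:end]
    return row

# Toy 3 x 4 matrix: [[0, 1, 0, 2], [0, 0, 0, 0], [3, 0, 4, 0]]
data = np.array([1, 2, 3, 4], dtype=np.float32)
indices = np.array([1, 3, 0, 2])
indptr = np.array([0, 2, 2, 4])
row = csr_row_to_dense(data, indices, indptr, 2, n_genes=4)  # values 3, 0, 4, 0
```

Each cell access therefore costs an `indptr` lookup plus two small reads from the HDF5 file, which is the per-row I/O the scplode branch replaces.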

In [1]:
# Install the forked cell-load (uncomment to run)
#%pip install -q git+https://github.com/rkita/cell-load.git@scplode-integration

In [45]:
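# Download and unzip the Competition Support Set (uncomment to run):
# !curl -LO https://storage.googleapis.com/vcc_data_prod/datasets/state/competition_support_set.zip
# !unzip -q competition_support_set.zip

# Then set the path to competition_train.h5 in the TOML config below.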

In [47]:
# Make sure the installed PerturbationDataModule is the forked version
import inspect

from cell_load.data_modules import PerturbationDataModule

sig = inspect.signature(PerturbationDataModule.__init__)
assert "use_scplode" in sig.parameters, "Install the scplode-integration fork of cell-load"

In [48]:
import toml
config = {
    "datasets": {
        "h1": "/Volumes/T7/vcc_pp/data/raw/competition_support_set/competition_train.h5"
    },
    "training": {
        "h1": "train"
    },
    "zeroshot": {},
    "fewshot": {}
}

with open("config.toml", "w") as f:
    toml.dump(config, f)
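Before building the data module, it can help to sanity-check the config so that a wrong path fails fast. The helper below is hypothetical (it is not part of cell-load) and assumes, as in this config, that the keys under `training` are dataset names declared under `datasets`.

```python
import os
from typing import List

def check_perturbation_config(config: dict) -> List[str]:
    """Return a list of problems found in a cell-load style config dict.

    Hypothetical helper: it only checks that every dataset named under
    "training" is declared under "datasets" and that each declared path exists.
    """
    problems = []
    datasets = config.get("datasets", {})
    for name in config.get("training", {}):
        if name not in datasets:
            problems.append(f"training refers to undeclared dataset {name!r}")
    for name, path in datasets.items():
        if not os.path.exists(path):
            problems.append(f"dataset {name!r}: path not found: {path}")
    return problems
```

In the cell above, `assert not check_perturbation_config(config), check_perturbation_config(config)` would catch, for example, an external drive that is not mounted.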

In [55]:
#Check whether the h5 file is compressed or not. 

import h5py

with h5py.File("/Volumes/T7/vcc_pp/data/raw/competition_support_set/competition_train.h5", "r") as f:
    def print_compression(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(f"{name}: compression = {obj.compression}")

    f.visititems(print_compression)

X/data: compression = None
X/indices: compression = None
X/indptr: compression = None
obs/_index: compression = None
obs/batch/categories: compression = None
obs/batch/codes: compression = None
obs/batch_var/categories: compression = None
obs/batch_var/codes: compression = None
obs/cell_type/categories: compression = None
obs/cell_type/codes: compression = None
obs/guide_id/categories: compression = None
obs/guide_id/codes: compression = None
obs/target_gene/categories: compression = None
obs/target_gene/codes: compression = None
var/_index: compression = None


In [49]:
# Create the scplode memory maps (mmaps) upfront, before building the data module
import scplode as sp
sdata = sp.read_h5ad("/Volumes/T7/vcc_pp/data/raw/competition_support_set/competition_train.h5")
sdata[0:10]

[INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)


View of AnnData object with n_obs × n_vars = 10 × 18080
    obs: 'target_gene', 'guide_id', 'batch', 'batch_var', 'cell_type'
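This notebook does not describe scplode's on-disk format. As a generic illustration of why memory-mapped per-cell access can be cheap, the sketch below uses `numpy.memmap` on a hypothetical raw float32 file. It is not scplode's implementation; the file name, dense layout, and helper names are all assumptions. Reading one row only requires the OS to load the pages that hold that row (plus any read-ahead), not the whole matrix.

```python
import os
import tempfile

import numpy as np

def write_dense_matrix(path: str, matrix: np.ndarray) -> None:
    """Write a 2-D matrix to `path` as raw row-major float32 bytes (hypothetical layout)."""
    np.ascontiguousarray(matrix, dtype=np.float32).tofile(path)

def read_row(path: str, n_rows: int, n_cols: int, idx: int) -> np.ndarray:
    """Map the file read-only and copy out a single row."""
    mm = np.memmap(path, dtype=np.float32, mode="r", shape=(n_rows, n_cols))
    return np.array(mm[idx])  # copy so the result outlives the mapping

# Usage with a toy matrix of 5 cells x 4 genes
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "toy_X.f32")
    toy = np.arange(20, dtype=np.float32).reshape(5, 4)
    write_dense_matrix(path, toy)
    row = read_row(path, 5, 4, 2)  # row 2 of the toy matrix: values 8, 9, 10, 11
```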

In [50]:
%%time

import tracemalloc

tracemalloc.start()

dm = PerturbationDataModule(
    toml_config_path="config.toml",
    embed_key=None,
    num_workers=4,
    batch_col="batch",
    cell_sentence_len=16,
    pert_col="target_gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    use_scplode=True,
    output_space="all",
)
dm.setup()

# Get the training data and pull only the first batch
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))

# Peak memory traced by tracemalloc (Python allocations in the main process only;
# DataLoader worker processes and memory-mapped pages are not included)
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
peak = peak / 1024**2  # MB
print(f"peak memory: {peak:.2f} MB")

/Volumes/T7/vcc_pp/data/raw/competition_support_set/competition_train.h5


Processing h1:   0%|                                                                                         | 0/1 [00:00<?, ?it/s][INFO] Loading index: obs
[INFO] Loading index: var
[INFO] Loading index: dat (implicitly)
Processing h1: 100%|█████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.81it/s]


Processed competition_train: 221273 train, 0 val, 0 test
peak memory: 139.50 MB
CPU times: user 13.4 s, sys: 586 ms, total: 13.9 s
Wall time: 34.8 s


In [51]:
%%time

import tracemalloc

tracemalloc.start()

dm = PerturbationDataModule(
    toml_config_path="config.toml",
    embed_key=None,
    num_workers=4,
    batch_col="batch",
    cell_sentence_len=16,
    pert_col="target_gene",
    cell_type_key="cell_type",
    control_pert="non-targeting",
    use_scplode=False,
    output_space="all",
)
dm.setup()

# Get the training data and pull only the first batch
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))

# Peak memory traced by tracemalloc (Python allocations in the main process only;
# DataLoader worker processes and memory-mapped pages are not included)
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
peak = peak / 1024**2  # MB
print(f"peak memory: {peak:.2f} MB")

/Volumes/T7/vcc_pp/data/raw/competition_support_set/competition_train.h5


Processing h1: 100%|█████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.74it/s]


Processed competition_train: 221273 train, 0 val, 0 test
peak memory: 114.58 MB
CPU times: user 13 s, sys: 502 ms, total: 13.5 s
Wall time: 57.3 s
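The two profiling cells above differ only in `use_scplode`. A small helper keeps the measurement identical across settings and makes repeated runs easy. The sketch below is hypothetical (`profile_first_batch` is not part of cell-load or scplode). It times building a loader plus fetching one batch with `time.perf_counter`, and it reports `tracemalloc`'s peak, which covers only Python allocations in the current process.

```python
import time
import tracemalloc
from typing import Any, Callable, Dict

def profile_first_batch(make_loader: Callable[[], Any]) -> Dict[str, Any]:
    """Build a loader, pull one batch, and report wall time and traced peak memory.

    tracemalloc only sees allocations made through Python's allocators in this
    process, so DataLoader worker processes and memory-mapped pages are not counted.
    """
    tracemalloc.start()
    try:
        t0 = time.perf_counter()
        loader = make_loader()
        batch = next(iter(loader))
        wall_s = time.perf_counter() - t0
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()  # always stop tracing, even if loading fails
    return {"wall_s": wall_s, "peak_mb": peak / 1024**2, "batch": batch}
```

In this notebook, `make_loader` could wrap the `PerturbationDataModule(...)`, `dm.setup()` and `dm.train_dataloader()` calls from the cells above, with `use_scplode` as the only varying argument. Alternating the two settings over a few repeats would make the comparison less sensitive to run order and OS file caching than a single run of each.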


### Conclusion

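In this example (a single run of each setting), using scplode reduced the wall time to set up the data module and load the first batch from 57.3 s to 34.8 s (about 1.65× faster). Main-process CPU time was similar in both runs (13.9 s with scplode vs. 13.5 s without), and traced peak memory was slightly higher with scplode (139.50 MB vs. 114.58 MB).
Further work is needed to evaluate this across other contexts. Important variables include the size/compression of original data the 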