# Quickstart `annbatch`

This notebook will walk you through the following steps:
1. How to convert an existing collection of `anndata` files into a shuffled, zarr-based collection of `anndata` datasets
2. How to load the converted collection using `annbatch`
3. How to extend an existing collection with new `anndata` datasets

In [1]:
# !pip install "annbatch[zarrs,torch]"

In [2]:
# Download two example datasets from CELLxGENE
!wget https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad
!wget https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad

zsh:1: command not found: wget
zsh:1: command not found: wget
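
On systems without `wget`, the same two files can be fetched with the Python standard library. The following is a minimal sketch; the helper name `download_if_missing` is made up for this example and is not part of `annbatch`. The download loop is left commented out because the files are large.

```python
from pathlib import Path
from urllib.request import urlretrieve

# The two CELLxGENE example datasets used in this notebook
URLS = [
    "https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad",
    "https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad",
]


def download_if_missing(url: str, dest_dir: Path = Path(".")) -> Path:
    """Download `url` into `dest_dir` unless a file with that name already exists.

    Hypothetical helper for illustration only.
    """
    dest = dest_dir / url.rsplit("/", 1)[-1]  # keep the remote file name
    if not dest.exists():
        urlretrieve(url, dest)
    return dest


# Uncomment to download both files into the current working directory:
# for url in URLS:
#     download_if_missing(url)
```

Because existing files are skipped, re-running the cell does not download the data again.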


**IMPORTANT**: Configure zarrs

This step is required for converting existing `anndata` files into a performant, shuffled collection of datasets for mini-batch loading.

In [7]:
import zarr

zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})

<donfig.config_obj.ConfigSet at 0x10837ba70>

In [8]:
import warnings

# Suppress zarr vlen-utf8 codec warnings
warnings.filterwarnings(
    "ignore",
    message="The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*",
    category=UserWarning,
    module="zarr.codecs.vlen_utf8",
)
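
The `message` argument of `warnings.filterwarnings` is a regular expression that has to match the start of the warning text, so only this specific warning is silenced and unrelated warnings still surface. The sketch below illustrates this with the standard library alone. The two example warning texts are invented for the demonstration, and the `module` restriction is omitted because the warnings are raised from this cell rather than from `zarr`.

```python
import warnings

# Same message pattern as in the cell above
PATTERN = "The codec `vlen-utf8` is currently not part in the Zarr format 3 specification.*"


def count_emitted(messages: list[str]) -> int:
    """Raise each message as a UserWarning and count how many pass the filter."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning by default
        # This rule is inserted in front of the "always" rule, so it wins
        warnings.filterwarnings("ignore", message=PATTERN, category=UserWarning)
        for msg in messages:
            warnings.warn(msg, UserWarning)
    return len(caught)


# Invented example texts: the first starts with the filtered message, the second does not
suppressed = "The codec `vlen-utf8` is currently not part in the Zarr format 3 specification. Example tail."
unrelated = "Some unrelated warning."

print(count_emitted([suppressed]))
print(count_emitted([suppressed, unrelated]))
```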

## Converting existing `anndata` files into a shuffled collection

The conversion code will take care of the following things:
* Align (outer join) the gene spaces across all datasets listed in `adata_paths`
  * The gene spaces are outer-joined based on the gene names provided in the `var_names` field of the individual `AnnData` objects.
  * If you want to subset to a specific gene space, you can provide a list of gene names via the `var_subset` parameter.
* Shuffle the cells across all datasets (this also works on larger-than-memory datasets).
  * This is important for block-wise shuffling during data loading.
* Shuffle the input files across multiple output datasets:
  * The size of each individual output dataset can be controlled via the `n_obs_per_dataset` parameter.
  * We recommend choosing a dataset size that comfortably fits into system memory.
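
The steps listed above can be illustrated with a small pure-Python sketch. It is a conceptual toy, not the `annbatch` implementation. All function names are hypothetical, cells are modelled as `gene -> count` dictionaries, and filling missing genes with zeros is an assumption about how the outer join behaves.

```python
import random


def outer_join_gene_space(var_names_per_dataset, var_subset=None):
    """Union of all gene names in order of first appearance, optionally subset."""
    seen = {}
    for names in var_names_per_dataset:
        for name in names:
            seen.setdefault(name, None)
    genes = list(seen)
    if var_subset is not None:
        keep = set(var_subset)
        genes = [g for g in genes if g in keep]
    return genes


def align_row(row, genes):
    """Express one cell in the shared gene space (assumption: missing genes become 0)."""
    return [row.get(g, 0) for g in genes]


def shuffle_into_shards(cells, n_obs_per_dataset, seed=0):
    """Shuffle all cells, then cut them into shards of at most n_obs_per_dataset."""
    order = list(range(len(cells)))
    random.Random(seed).shuffle(order)
    return [
        [cells[i] for i in order[start:start + n_obs_per_dataset]]
        for start in range(0, len(order), n_obs_per_dataset)
    ]


# Two toy datasets with partially overlapping gene spaces
genes = outer_join_gene_space([["g1", "g2"], ["g2", "g3"]])
cells = [{"g1": 5, "g2": 1}, {"g2": 2, "g3": 7}, {"g3": 4}]
aligned = [align_row(c, genes) for c in cells]
shards = shuffle_into_shards(aligned, n_obs_per_dataset=2)
print(genes)
print([len(s) for s in shards])
```

Each shard ends up holding a random mix of cells from all inputs. That mix is what allows the loader to read contiguous chunks from disk later and still produce well-mixed batches.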


You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`.

In [11]:
import anndata as ad
from annbatch import DatasetCollection


# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).
# To have a store that only contains raw counts, we can write the following load_adata function
def read_lazy_x_and_obs_only(path) -> ad.AnnData:
    """Custom load function to only load raw counts from CxG data."""
    # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint
    adata_ = ad.experimental.read_lazy(path)
    if adata_.raw is not None:
        x = adata_.raw.X
        var = adata_.raw.var
    else:
        x = adata_.X
        var = adata_.var

    return ad.AnnData(
        X=x,
        obs=adata_.obs.to_memory(),
        var=var.to_memory(),
    )


# Path to store the output collection
collection = DatasetCollection(zarr.open("annbatch_collection"))
collection.add(
    # List all the h5ad files you want to include in the collection
    adata_paths=["866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad", "f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad"],
    shuffle=True,  # Whether to pre-shuffle the cells of the collection
    n_obs_per_dataset=2_097_152,  # Number of cells per dataset shard
    var_subset=None,  # Optionally subset the collection to a specific gene space
    load_adata=read_lazy_x_and_obs_only,
)

checking for mismatched keys: 100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
loading: 2it [00:00,  2.26it/s]
processing chunks: 100%|██████████| 1/1 [00:24<00:00, 24.02s/it]


<annbatch.io.DatasetCollection at 0x12e174fb0>

## Data loading example

In [12]:
from pathlib import Path

COLLECTION_PATH = Path("annbatch_collection/")

In [14]:
import anndata as ad

from annbatch import Loader

ds = Loader(
    batch_size=4096,  # Total number of obs per yielded batch
    chunk_size=256,  # Number of obs to load from disk contiguously - default settings should work well
    preload_nchunks=32,  # Number of chunks to preload + shuffle - default settings should work well
    # If True, preloaded chunks are moved to GPU memory via `cupy`; this puts more pressure on GPU memory but accelerates loading by ~20%
    preload_to_gpu=False,
    to_torch=True,
)

# Add in the shuffled data that should be used for training
ds.add_collection(collection)

<annbatch.loader.Loader at 0x12c2bfa40>
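
To build intuition for how `batch_size`, `chunk_size`, and `preload_nchunks` interact, here is a conceptual sketch of block-wise shuffling over plain row indices. It is not the `Loader` implementation. The function name is hypothetical, and the handling of the final, shorter batches is an assumption. With the settings above, one preload holds 256 × 32 = 8192 observations, which is two batches of 4096.

```python
import random
from typing import Iterator


def blockwise_shuffled_batches(
    n_obs: int,
    batch_size: int,
    chunk_size: int,
    preload_nchunks: int,
    seed: int = 0,
) -> Iterator[list[int]]:
    """Yield batches of row indices using chunked reads plus in-memory shuffling."""
    rng = random.Random(seed)
    # Each chunk is a contiguous run of rows, which is cheap to read from disk
    chunk_starts = list(range(0, n_obs, chunk_size))
    rng.shuffle(chunk_starts)  # visit the chunks in random order
    for i in range(0, len(chunk_starts), preload_nchunks):
        # "Preload" a group of chunks into one in-memory buffer
        buffer: list[int] = []
        for start in chunk_starts[i:i + preload_nchunks]:
            buffer.extend(range(start, min(start + chunk_size, n_obs)))
        rng.shuffle(buffer)  # mix rows across the preloaded chunks
        for j in range(0, len(buffer), batch_size):
            yield buffer[j:j + batch_size]


batches = list(blockwise_shuffled_batches(20_000, batch_size=4096, chunk_size=256, preload_nchunks=32))
print(len(batches[0]))
```

Every row is visited exactly once per pass, and randomness comes from two places: the order of the chunks and the shuffle inside each preloaded buffer. This is why pre-shuffling the collection on disk matters: rows that sit next to each other inside a chunk are always loaded together.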

**IMPORTANT:**
* The `Loader` yields batches of sparse tensors.
* The conversion to dense tensors should be done on the GPU, as shown in the example below.
  * First call `.cuda()` and then `.to_dense()`
  * E.g. `x = x.cuda().to_dense()`
  * This is significantly faster than doing the dense conversion on the CPU.


In [18]:
# Iterate over dataloader
import tqdm

for batch in tqdm.tqdm(ds):
    x, obs = batch["data"], batch["labels"]["cell_type"]
    # Important: Convert to dense on GPU
    x = x.cuda().to_dense()
    # Feed data into your model
    ...

  0%|          | 42/171792 [00:10<12:13:16,  3.90it/s]


## Optional: Extend an existing collection with a new dataset

You might want to extend an existing pre-shuffled collection with a new dataset.
This can be done by calling `DatasetCollection.add` again on the existing collection.

This call shuffles the new dataset into the existing collection without re-shuffling the entire collection.

In [19]:
def read_x_and_obs_only(path) -> ad.AnnData:
    """Custom load function to only load raw counts from CxG data."""
    # As it's only a small dataset, we can load the full dataset into memory to speed up computations
    adata_ = ad.read_h5ad(path)  # Replace with ad.experimental.read_lazy if data does not fit into memory anymore
    if adata_.raw is not None:
        x = adata_.raw.X
        var = adata_.raw.var
    else:
        x = adata_.X
        var = adata_.var

    return ad.AnnData(X=x, obs=adata_.obs, var=var)


collection.add(
    adata_paths=[
        "866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad",
    ],
    load_adata=read_x_and_obs_only,
)

checking for mismatched keys: 100%|██████████| 1/1 [00:00<00:00,  2.09it/s]
loading: 1it [00:10, 10.77s/it]
checking for mismatched keys: 100%|██████████| 2/2 [00:01<00:00,  1.18it/s]
    obs: 'reference_genome', 'gene_annotation_version', 'alignment_software', 'intronic_reads_counted', 'donor_id', 'donor_age', 'self_reported_ethnicity_ontology_term_id', 'donor_cause_of_death', 'donor_living_at_sample_collection', 'sample_id', 'sample_preservation_method', 'tissue_ontology_term_id', 'development_stage_ontology_term_id', 'sample_collection_method', 'tissue_source', 'tissue_type', 'sample_collection_year', 'suspension_derivation_process', 'suspension_uuid', 'suspension_type', 'tissue_handling_interval', 'library_id', 'assay_ontology_term_id', 'sequenced_fragment', 'institute', 'library_id_repository', 'sequencing_platform', 'is_primary_data', 'cell_type_ontology_term_id', 'author_cell_type', 'disease_ontology_term_id', 'reported_diseases', 'sex_ontology_term_id', 'nCount_RNA', 'nFeature_

<annbatch.io.DatasetCollection at 0x12e174fb0>