In [1]:
import cellxgene_census
from tiledbsoma import AxisQuery

from tiledbsoma_ml import ExperimentDataset
from tiledbsoma_ml.x_locator import XLocator
from tiledbsoma_ml._query_ids import QueryIDs, Chunks
from tiledbsoma_ml._io_batch_iterable import IOBatchIterable
from tiledbsoma_ml._mini_batch_iterable import MiniBatchIterable

In [2]:
# Pin the Census release explicitly so every re-run reads the same data.
CENSUS_VERSION = '2025-01-30'
# NOTE(review): nothing in this notebook closes `c`; consider `c.close()`
# (or wrapping usage in a `with` block) once the analysis is finished.
c = cellxgene_census.open_soma(census_version=CENSUS_VERSION)
c

<Collection 's3://cellxgene-census-public-us-west-2/cell-census/2025-01-30/soma/' (open for 'r') (3 items)
    'census_data': 's3://cellxgene-census-public-us-west-2/cell-census/2025-01-30/soma/census_data' (unopened)
    'census_info': 's3://cellxgene-census-public-us-west-2/cell-census/2025-01-30/soma/census_info' (unopened)
    'census_spatial_sequencing': 's3://cellxgene-census-public-us-west-2/cell-census/2025-01-30/soma/census_spatial_sequencing' (unopened)>

In [3]:
# Raw X matrix of the *human* experiment, inspected here only for reference.
# It gets its own name because the analysis below uses the mouse experiment,
# and later cells rebind `X` (the IO-batch `with` block and summary loop).
hs_X_raw = c['census_data']['homo_sapiens'].ms['RNA'].X['raw']
hs_X_raw

<SparseNDArray 's3://cellxgene-census-public-us-west-2/cell-census/2025-01-30/soma/census_data/homo_sapiens/ms/RNA/X/raw' (open for 'r')>

In [4]:
# Organism (experiment) to analyze; set to 'homo_sapiens' for the human data.
ORGANISM = 'mus_musculus'
exp = c['census_data'][ORGANISM]
exp

<Experiment 's3://cellxgene-census-public-us-west-2/cell-census/2025-01-30/soma/census_data/mus_musculus' (open for 'r') (2 items)
    'ms': 's3://cellxgene-census-public-us-west-2/cell-census/2025-01-30/soma/census_data/mus_musculus/ms' (unopened)
    'obs': 's3://cellxgene-census-public-us-west-2/cell-census/2025-01-30/soma/census_data/mus_musculus/obs' (unopened)>

In [5]:
%%time
value_filter = 'is_primary_data == True and tissue_general in ["spleen", "kidney"] and nnz > 1000'
query = exp.axis_query(
    'RNA',
    obs_query=AxisQuery(value_filter=value_filter),
)
query.n_obs

CPU times: user 4.48 s, sys: 1.5 s, total: 5.98 s
Wall time: 1.31 s


137250

In [6]:
# Wrap the query's 'raw' layer as a dataset (batch_size=1024), then take a
# seeded 90% / 10% random split; `ds` (the 90% part) drives the cells below.
ds0 = ExperimentDataset(query, 'raw', batch_size=1024)
split_fractions = (0.9, 0.1)
ds, dt = ds0.random_split(*split_fractions, seed=111)

# Shuffle chunks of the training split, and how many there are.
chunks = ds.shuffle_chunks
len(chunks)

1931

In [7]:
# Partitioned query IDs (their `var_joinids` feed the IO-batch cell) and the
# shuffle chunks; `chunks` repeats the previous cell's lookup so this cell can
# be re-run on its own.
query_ids, chunks = ds.partitioned_query_ids, ds.shuffle_chunks

In [8]:
%%time
with ds.x_locator.open() as (X, obs):
    io_batch_iter = IOBatchIterable(
        chunks=chunks,
        io_batch_size=ds.io_batch_size,
        obs=obs,
        var_joinids=query_ids.var_joinids,
        X=X,
        obs_column_names=ds.obs_column_names,
        seed=111,
    )
    io_batches = list(io_batch_iter)

CPU times: user 1min 26s, sys: 21.9 s, total: 1min 48s
Wall time: 25.1 s


In [9]:
# Print each IO batch's shape and nnz, then the totals across all batches.
rows = nnzs = 0
n_cols = 0  # defined up front so the total line works even with zero batches
for i, (X_batch, _obs) in enumerate(io_batches):
    R, C = X_batch.shape
    nnz = X_batch.nnz
    # Every IO batch is built from the same var_joinids, so widths must agree.
    assert i == 0 or C == n_cols, f'IO batch {i} has {C:,} columns, expected {n_cols:,}'
    n_cols = C
    rows += R
    nnzs += nnz
    print(f'IO batch {i}: {R:,} x {C:,}, {nnz:,} nnz')
print(f'Total: {rows:,} x {n_cols:,}, {nnzs:,} nnz')

IO batch 0: 65,536 x 52,483, 141,691,631 nnz
IO batch 1: 57,989 x 52,483, 125,036,984 nnz
Total: 123,525 x 52,483, 266,728,615 nnz


In [10]:
%%time
mini_batch_iter = MiniBatchIterable(
    io_batch_iter=iter(io_batches),
    batch_size=ds.batch_size,
)
mini_batches = list(mini_batch_iter)
len(mini_batches)

CPU times: user 3.97 s, sys: 18.1 s, total: 22.1 s
Wall time: 1.13 s


121