In [None]:
!pip install cellxgene-census

In [None]:
!pip install obonet

In [1]:
from os.path import join

import pandas as pd
import numpy as np

from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Utility code: Cell Ontology (CL) helpers

In [2]:
import obonet
import networkx


url = 'http://purl.obolibrary.org/obo/cl/cl-simple.obo'
graph = obonet.read_obo(url, ignore_obsolete=True)

# only use "is_a" edges
# The graph is a networkx multigraph whose edge keys are the relationship names (that is what
# the 'is_a' comparison on the key relies on). Remove edges by their full (u, v, key) triple:
# remove_edge(u, v) without a key deletes whichever u->v edge was added last, which is not
# necessarily the non-"is_a" one when two terms are linked by several relationships.
edges_to_delete = [(u, v, key) for u, v, key in graph.edges(keys=True) if key != 'is_a']
graph.remove_edges_from(edges_to_delete)

# define mapping from id to name
id_to_name = {id_: data.get('name') for id_, data in graph.nodes(data=True)}
# define inverse mapping from name to id
name_to_id = {v: k for k, v in id_to_name.items()}


def find_child_nodes(cell_type):
    """Return the names of all transitive subtypes of ``cell_type`` in the ontology graph.

    networkx ``ancestors`` are the nodes that have a path *to* the given node (the node itself
    is excluded); with edges assumed to point from a term to its parent, those are the subtypes.
    Raises KeyError if ``cell_type`` is not a known term name.
    """
    node_id = name_to_id[cell_type]
    subtype_ids = networkx.ancestors(graph, node_id)
    return [id_to_name[subtype_id] for subtype_id in subtype_ids]


def find_parent_nodes(cell_type):
    """Return the names of all transitive supertypes of ``cell_type`` in the ontology graph.

    networkx ``descendants`` are the nodes reachable *from* the given node (the node itself is
    excluded); with edges assumed to point from a term to its parent, those are the supertypes.
    Raises KeyError if ``cell_type`` is not a known term name.
    """
    node_id = name_to_id[cell_type]
    supertype_ids = networkx.descendants(graph, node_id)
    return [id_to_name[supertype_id] for supertype_id in supertype_ids]

# Select data to download

In [3]:
import cellxgene_census

# Pin the Census release so the cell selection below is reproducible.
CENSUS_VERSION = "2023-05-08"
census = cellxgene_census.open_soma(census_version=CENSUS_VERSION)

In [4]:
# Assays whose cells are selected. Kept as a list on purpose: its repr is interpolated
# into the SOMA value_filter as `assay in [...]`.
PROTOCOLS = [
    "10x 5' v2", "10x 3' v3", "10x 3' v2", "10x 5' v1", "10x 3' v1",
    "10x 3' transcription profiling", "10x 5' transcription profiling",
]

# obs (cell metadata) columns to download.
COLUMN_NAMES = [
    "soma_joinid", "is_primary_data",
    "dataset_id", "donor_id", "assay",
    "cell_type", "development_stage", "disease",
    "tissue", "tissue_general",
]

In [5]:
# Read the selected metadata columns for primary-data cells measured with one of PROTOCOLS.
homo_sapiens_obs = census["census_data"]["homo_sapiens"].obs
obs_filter = f"is_primary_data == True and assay in {PROTOCOLS}"
obs_query = homo_sapiens_obs.read(column_names=COLUMN_NAMES, value_filter=obs_filter)
obs = obs_query.concat().to_pandas()

In [6]:
# A "tech_sample" is one donor within one dataset.
obs['tech_sample'] = (obs['dataset_id'] + '_' + obs['donor_id']).astype('category')

# Store the remaining string (object) columns as categoricals to save memory.
string_cols = [col for col in COLUMN_NAMES if obs[col].dtype == object]
for col in string_cols:
    obs[col] = obs[col].astype('category')


In [7]:
# Inspect column dtypes after the categorical conversion of the previous cell.
obs.dtypes

soma_joinid             int64
is_primary_data          bool
dataset_id           category
donor_id             category
assay                category
cell_type            category
development_stage    category
disease              category
tissue               category
tissue_general       category
tech_sample          category
dtype: object

In [8]:
# Collect every cell type to drop. A type may be flagged by several filters;
# duplicates are removed at the end.

# remove all cell types which are not a (strict) subtype of native cell
cell_types_to_remove = obs[~obs.cell_type.isin(find_child_nodes('native cell'))].cell_type.unique().tolist()

# remove all cell types which have less than 5000 cells
cell_freq = obs.cell_type.value_counts()
cell_types_to_remove += cell_freq[cell_freq < 5000].index.tolist()

# remove cell types which occur in 30 or fewer tech_samples (dataset + donor combinations)
tech_samples_per_cell_type = obs[['cell_type', 'tech_sample']].groupby('cell_type').agg({'tech_sample': 'nunique'}).sort_values('tech_sample')
cell_types_to_remove += tech_samples_per_cell_type[tech_samples_per_cell_type.tech_sample <= 30].index.tolist()

# filter out labels that are too coarse (too close to the root of the ontology):
# remove all cell types that have <= 7 parents (transitive supertypes) in the cell ontology
cell_types = obs.cell_type.unique().tolist()
n_parents = pd.Series(
    [len(find_parent_nodes(cell_type)) for cell_type in cell_types],
    index=cell_types,
    dtype=int,
)
cell_types_to_remove += n_parents[n_parents <= 7].index.tolist()

cell_types_to_remove = list(set(cell_types_to_remove))

In [9]:
# Number of cell types that survive all filters. Every entry of cell_types_to_remove was taken
# from obs.cell_type and the list is de-duplicated, so the difference is the count of kept types.
obs.cell_type.nunique() - len(cell_types_to_remove)

157

In [10]:
# Keep only cells whose type survived every filter, then drop categories that became empty.
obs_subset = obs.loc[~obs.cell_type.isin(cell_types_to_remove)].copy()
categorical_cols = [col for col in obs_subset.columns if obs_subset[col].dtype == 'category']
for col in categorical_cols:
    obs_subset[col] = obs_subset[col].cat.remove_unused_categories()
obs_subset

Unnamed: 0,soma_joinid,is_primary_data,dataset_id,donor_id,assay,cell_type,development_stage,disease,tissue,tissue_general,tech_sample
0,0,True,b0e547f0-462b-4f81-b31b-5b0a5d96f537,SG_HEL_H02a,10x 5' v2,"CD16-positive, CD56-dim natural killer cell, h...",57-year-old human stage,normal,blood,blood,b0e547f0-462b-4f81-b31b-5b0a5d96f537_SG_HEL_H02a
1,1,True,b0e547f0-462b-4f81-b31b-5b0a5d96f537,SG_HEL_H02a,10x 5' v2,"CD16-positive, CD56-dim natural killer cell, h...",57-year-old human stage,normal,blood,blood,b0e547f0-462b-4f81-b31b-5b0a5d96f537_SG_HEL_H02a
2,2,True,b0e547f0-462b-4f81-b31b-5b0a5d96f537,SG_HEL_H02a,10x 5' v2,CD14-positive monocyte,57-year-old human stage,normal,blood,blood,b0e547f0-462b-4f81-b31b-5b0a5d96f537_SG_HEL_H02a
3,3,True,b0e547f0-462b-4f81-b31b-5b0a5d96f537,SG_HEL_H02a,10x 5' v2,CD14-positive monocyte,57-year-old human stage,normal,blood,blood,b0e547f0-462b-4f81-b31b-5b0a5d96f537_SG_HEL_H02a
4,4,True,b0e547f0-462b-4f81-b31b-5b0a5d96f537,SG_HEL_H02a,10x 5' v2,"CD8-positive, alpha-beta memory T cell",57-year-old human stage,normal,blood,blood,b0e547f0-462b-4f81-b31b-5b0a5d96f537_SG_HEL_H02a
...,...,...,...,...,...,...,...,...,...,...,...
26605250,50248977,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,pericyte,51-year-old human stage,normal,lung,lung,8c42cfd0-0b0a-46d5-910c-fc833d83c45e_3
26605251,50248978,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,pericyte,51-year-old human stage,normal,lung,lung,8c42cfd0-0b0a-46d5-910c-fc833d83c45e_3
26605252,50248979,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,pericyte,51-year-old human stage,normal,lung,lung,8c42cfd0-0b0a-46d5-910c-fc833d83c45e_3
26605253,50248980,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,pericyte,51-year-old human stage,normal,lung,lung,8c42cfd0-0b0a-46d5-910c-fc833d83c45e_3


In [11]:
# Names of the cell types that remain after filtering.
cell_types_to_keep = obs_subset.cell_type.unique().tolist()

# Download data

In [14]:
# Gene names to download. NOTE(review): features.parquet is a local file not produced in this
# notebook; presumably it lists protein-coding genes in its `gene_names` column — confirm.
features = pd.read_parquet('features.parquet')
protein_coding_genes = features['gene_names'].tolist()

In [15]:
BASE_PATH = '/mnt/dssfs02/cxg_census/slices'
# Number of slices the selected cells are split into. Each slice is written to its own
# '<i>.h5ad' file so the whole selection never has to be held in memory at once.
N_SLICES = 20

# The gene filter does not change between slices, so build the (long) string once.
gene_filter = f"feature_name in {protein_coding_genes}"

# download in batches to not run out of memory
joinid_slices = np.array_split(obs_subset.soma_joinid.to_numpy(), N_SLICES)
# Wrap the list itself (not enumerate(...), which has no len) so tqdm knows the total.
for i, idxs in enumerate(tqdm(joinid_slices)):
    adata = cellxgene_census.get_anndata(
        census=census,
        organism="Homo sapiens",
        X_name='raw',
        obs_coords=idxs.tolist(),
        var_value_filter=gene_filter,
        column_names={"obs": COLUMN_NAMES, "var": ['feature_id', 'feature_name']},
    )
    adata.write_h5ad(join(BASE_PATH, f'{i}.h5ad'))


# Store data

## Convert to zarr + DataFrame

In [14]:
# Same directory the download loop wrote its '<i>.h5ad' slices to. Presumably repeated so the
# following cells can run without re-executing the download section in this kernel.
BASE_PATH = '/mnt/dssfs02/cxg_census/slices'

In [12]:
import os
from os.path import join

import anndata
import dask
import dask.array as da

from scipy.sparse import csr_matrix

# Run dask computations with the threaded scheduler.
dask.config.set({'scheduler': 'threads'});

In [21]:
def read_X(path):
    """Fully load the .h5ad file at ``path`` and return only its expression matrix ``X``."""
    adata = anndata.read_h5ad(path)
    return adata.X


def read_obs(path):
    """Return the cell metadata (obs) of one slice file with an added ``tech_sample`` column.

    The file is opened in backed, read-only mode so X stays on disk. The HDF5 handle is
    closed before returning; the original left one handle open per call.

    ``tech_sample`` is ``"<dataset_id>_<donor_id>"`` as a plain string column.
    """
    adata = anndata.read_h5ad(path, backed='r')
    try:
        obs = adata.obs.copy()
    finally:
        adata.file.close()
    obs['tech_sample'] = obs.dataset_id.astype(str) + '_' + obs.donor_id.astype(str)
    return obs


def read_var(path):
    """Return the gene metadata (var) of one slice file.

    The file is opened in backed, read-only mode so X stays on disk. The HDF5 handle is
    closed before returning instead of being left open.
    """
    adata = anndata.read_h5ad(path, backed='r')
    try:
        return adata.var.copy()
    finally:
        adata.file.close()


# Slice files are named '<i>.h5ad'. Filter on the extension *before* sorting: the sort key
# calls int() on the file stem and would raise on any other directory entry
# (e.g. '.ipynb_checkpoints').
h5ad_names = [file for file in os.listdir(BASE_PATH) if file.endswith('.h5ad')]
files = [join(BASE_PATH, file) for file in sorted(h5ad_names, key=lambda x: int(x.split('.')[0]))]

# read obs
print('Loading obs...')
obs_parts = [read_obs(file) for file in files]
# Rows per slice, taken from the files themselves. This removes the dependency on `obs_subset`
# from the download section, so this section also works in a fresh kernel.
split_lens = [len(part) for part in obs_parts]
obs = pd.concat(obs_parts).reset_index(drop=True)
del obs_parts
for col in obs.columns:
    if obs[col].dtype == object:
        obs[col] = obs[col].astype('category')
        # remove_unused_categories returns a new Series; the result must be assigned
        obs[col] = obs[col].cat.remove_unused_categories()
# read var (assumes every slice has the same genes, as they were downloaded with one filter)
print('Loading var...')
var = read_var(files[0])
# read X
print('Loading X...')
X = da.concatenate([
    da.from_delayed(dask.delayed(read_X)(file), (split_len, len(var)), dtype='f4')
    for file, split_len in zip(files, split_lens)
]).persist()


In [22]:
X

Unnamed: 0,Array,Chunk
Bytes,1.49 TiB,76.26 GiB
Shape,"(21179590, 19331)","(1058980, 19331)"
Dask graph,20 chunks in 1 graph layer,20 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.49 TiB 76.26 GiB Shape (21179590, 19331) (1058980, 19331) Dask graph 20 chunks in 1 graph layer Data type float32 numpy.ndarray",19331  21179590,

Unnamed: 0,Array,Chunk
Bytes,1.49 TiB,76.26 GiB
Shape,"(21179590, 19331)","(1058980, 19331)"
Dask graph,20 chunks in 1 graph layer,20 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray


## Create train, val, test splits (per cell type, tech_samples are assigned to exactly one split)

In [43]:
from statistics import mode
from scipy.sparse import csr_matrix

In [44]:
from math import ceil


def get_split(samples, val_split: float = 0.15, test_split: float = 0.15, seed: int = 1):
    """Randomly partition ``samples`` into disjoint train / val / test groups.

    Args:
        samples: array-like of sample identifiers (here: the tech_samples of one cell type).
        val_split: fraction of samples assigned to validation (rounded up).
        test_split: fraction of samples assigned to test (rounded up).
        seed: seed of the shuffle, so the split is reproducible.

    Returns:
        dict with keys 'train', 'val', 'test' mapping to numpy arrays of samples.

    Raises:
        ValueError: if rounding up the val and test counts exceeds the number of samples.
            Previously this left a negative train count and produced wrong, inconsistent
            slices (e.g. a single sample ended up in test only).
    """
    rng = np.random.default_rng(seed=seed)

    samples = np.array(samples)  # copy, so the caller's array is not shuffled in place
    rng.shuffle(samples)
    n_samples = len(samples)

    n_samples_val = ceil(val_split * n_samples)
    n_samples_test = ceil(test_split * n_samples)
    n_samples_train = n_samples - n_samples_val - n_samples_test
    if n_samples_train < 0:
        raise ValueError(
            f'cannot split {n_samples} samples: val ({n_samples_val}) + test '
            f'({n_samples_test}) exceed the number of samples'
        )

    return {
        'train': samples[:n_samples_train],
        'val': samples[n_samples_train:(n_samples_train + n_samples_val)],
        'test': samples[(n_samples_train + n_samples_val):]
    }


tech_samples_per_cell_type = obs[['cell_type', 'tech_sample']].groupby('cell_type').agg({'tech_sample': 'unique'})


# For each cell type, split its tech_samples into train/val/test and assign that type's cells
# accordingly. obs is grouped by cell type once instead of being scanned in full for every
# (cell type, split) pair. Cell types come in category order and row indices within a type
# keep their obs order, as before.
splits = {'train': [], 'val': [], 'test': []}
cells_by_type = obs[['cell_type', 'tech_sample']].groupby('cell_type', observed=True, sort=True)
for cell_type, obs_cell_type in cells_by_type:
    samples = tech_samples_per_cell_type.loc[cell_type, 'tech_sample']
    split = get_split(samples)
    for x in ['train', 'val', 'test']:
        assert len(split[x]) >= 1
        idxs = obs_cell_type.index[obs_cell_type.tech_sample.isin(split[x])].tolist()
        splits[x] += idxs


In [45]:
# Turn the collected index lists into numpy arrays (for set operations and fancy indexing).
splits = {split_name: np.array(split_idxs) for split_name, split_idxs in splits.items()}

splits

{'train': array([    211,     418,    1151, ..., 2578190, 2578201, 2578240]),
 'val': array([   5696,    5990,    6085, ..., 2578045, 2578065, 2578224]),
 'test': array([   2445,   15793,   17027, ..., 2578101, 2578133, 2578175])}

In [46]:
# The three splits must be pairwise disjoint ...
assert len(np.intersect1d(splits['train'], splits['val'])) == 0
assert len(np.intersect1d(splits['train'], splits['test'])) == 0
assert len(np.intersect1d(splits['val'], splits['test'])) == 0
# ... and together assign every row of obs to exactly one split
# (no cell dropped, none duplicated inside a split).
all_split_idxs = np.concatenate([splits['train'], splits['val'], splits['test']])
assert len(all_split_idxs) == len(obs)
assert len(np.unique(all_split_idxs)) == len(obs)

In [47]:
# Number of cells per split.
for split_name in ('train', 'val', 'test'):
    n_cells = len(obs.loc[splits[split_name], :])
    print(f"{split_name}: {n_cells:,}")

train: 14,843,199
val: 3,091,012
test: 3,245,379


In [48]:
# Number of distinct cell types per split (every kept type should appear in each split).
for split_name in ('train', 'val', 'test'):
    n_types = obs.loc[splits[split_name], :].cell_type.nunique()
    print(f"{split_name}: {n_types}")

train: 157
val: 157
test: 157


In [49]:
rng = np.random.default_rng(seed=1)

# Shuffle the rows inside each split; the rng is consumed in train, val, test order.
for split_name in ('train', 'val', 'test'):
    splits[split_name] = rng.permutation(splits[split_name])

splits

{'train': array([ 1371708, 11273954, 15187976, ..., 16091352, 19464233, 18237670]),
 'val': array([11607099, 15186124, 13612104, ...,  3928036, 21129690,  8283678]),
 'test': array([15528101,  8082559,   697058, ..., 11167194,  5659546,   416708])}

## Save data

In [53]:
SAVE_PATH = '/mnt/dssfs02/cxg_census/data'  # output root; each split gets its own sub-directory
CHUNK_SIZE = 2 ** 14  # rows per chunk of the saved X array (16384)

In [54]:
def _to_dense(block):
    """Densify one block of X before writing (blocks read from .h5ad expose .toarray())."""
    return block.toarray()


PARQUET_KWARGS = dict(engine='pyarrow', compression='snappy', index=None)

for split, idxs in splits.items():
    # out-of-order indexing is on purpose here as we want to shuffle the data to break up data sets
    X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
    obs_split = obs.loc[idxs, :]

    save_dir = join(SAVE_PATH, split)
    os.makedirs(save_dir)

    var.to_parquet(path=join(save_dir, 'var.parquet'), **PARQUET_KWARGS)
    obs_split.to_parquet(path=join(save_dir, 'obs.parquet'), **PARQUET_KWARGS)

    X_dense = X_split.map_blocks(_to_dense, dtype='f4')
    da.to_zarr(
        X_dense,
        join(save_dir, 'zarr'),
        component='X',
        compute=True,
        compressor='default',
        order='C',
    )


  X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
  X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
  X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
