In [None]:
!pip install -q cellxgene-census

In [2]:
!pip install -q obonet


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [3]:
!pip install -q zarr


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [1]:
import os
from os.path import join

import numpy as np
import pandas as pd

from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Utils code

In [2]:
import obonet
import networkx


# Load the Cell Ontology from its OBO file; obsolete terms are skipped.
url = 'http://purl.obolibrary.org/obo/cl/cl-simple.obo'
graph = obonet.read_obo(url, ignore_obsolete=True)

# only use "is_a" edges
# The graph is a multigraph whose edge key is the relationship type, so the key
# has to be part of the removal request. Removing by (u, v) alone drops a single
# parallel edge chosen by insertion order, which can be the "is_a" edge that
# should be kept.
edges_to_delete = [
    (u, v, key) for u, v, key in graph.edges(keys=True) if key != 'is_a'
]
graph.remove_edges_from(edges_to_delete)

# define mapping from id to name
id_to_name = {id_: data.get('name') for id_, data in graph.nodes(data=True)}
# define inverse mapping from name to id
name_to_id = {v: k for k, v in id_to_name.items()}


def find_child_nodes(cell_type):
    """Return the names of all terms reachable *towards* ``cell_type`` in ``graph``.

    These are the networkx ancestors of the term, collected transitively (not only
    direct neighbours). obonet orients edges from the more specific term to the
    more general one, so these are the subtypes of ``cell_type``.

    Raises ``KeyError`` if ``cell_type`` is not a term name in the ontology.
    """
    term_id = name_to_id[cell_type]
    subtype_ids = networkx.ancestors(graph, term_id)
    return [id_to_name[subtype_id] for subtype_id in subtype_ids]


def find_parent_nodes(cell_type):
    """Return the names of all terms reachable *from* ``cell_type`` in ``graph``.

    These are the networkx descendants of the term, collected transitively. obonet
    orients edges from the more specific term to the more general one, so these
    are the supertypes of ``cell_type``.

    Raises ``KeyError`` if ``cell_type`` is not a term name in the ontology.
    """
    term_id = name_to_id[cell_type]
    supertype_ids = networkx.descendants(graph, term_id)
    return [id_to_name[supertype_id] for supertype_id in supertype_ids]


# Select data to download

In [3]:
import cellxgene_census

# Open the CELLxGENE Census at a pinned release; every query below uses this handle.
census = cellxgene_census.open_soma(census_version="2023-05-15")

In [4]:
# Assay names a cell must match to be selected. The list is interpolated into the
# `assay in [...]` value filter, so it has to stay a list of plain strings.
PROTOCOLS = [
    "10x 5' v2", 
    "10x 3' v3", 
    "10x 3' v2", 
    "10x 5' v1", 
    "10x 3' v1", 
    "10x 3' transcription profiling", 
    "10x 5' transcription profiling"
]


# obs columns requested from the census, both for the metadata queries and for
# the downloaded AnnData objects.
COLUMN_NAMES = [
    "soma_joinid",
    "is_primary_data",
    "dataset_id", 
    "donor_id",
    "assay", 
    "cell_type", 
    "development_stage", 
    "disease", 
    "tissue", 
    "tissue_general"
]

### Select data for training

In [5]:
# Read the human obs table, restricted to primary-data cells measured with one
# of the selected 10x assays, and materialise it as a pandas DataFrame.
training_filter = f"is_primary_data == True and assay in {PROTOCOLS}"
human_obs = census["census_data"]["homo_sapiens"].obs
obs = (
    human_obs
    .read(column_names=COLUMN_NAMES, value_filter=training_filter)
    .concat()
    .to_pandas()
)

In [6]:
# A tech_sample is the combination of dataset and donor.
obs['tech_sample'] = (obs.dataset_id + '_' + obs.donor_id).astype('category')

# Store the string-valued metadata columns as categoricals.
object_columns = [col for col in COLUMN_NAMES if obs[col].dtype == object]
for col in object_columns:
    obs[col] = obs[col].astype('category')


In [7]:
# Check the column dtypes after the categorical conversion.
obs.dtypes

soma_joinid             int64
is_primary_data          bool
dataset_id           category
donor_id             category
assay                category
cell_type            category
development_stage    category
disease              category
tissue               category
tissue_general       category
tech_sample          category
dtype: object

In [8]:
# remove all cell types which are not a subtype of native cell
cell_types_to_remove = obs[~obs.cell_type.isin(find_child_nodes('native cell'))].cell_type.unique().tolist()

# remove all cell types which have less than 5000 cells
cell_freq = obs.cell_type.value_counts()
cell_types_to_remove += cell_freq[cell_freq < 5000].index.tolist()

# remove cell types which have less than 30 tech_samples
tech_samples_per_cell_type = obs[['cell_type', 'tech_sample']].groupby('cell_type').agg({'tech_sample': 'nunique'}).sort_values('tech_sample')
cell_types_to_remove += tech_samples_per_cell_type[tech_samples_per_cell_type.tech_sample <= 30].index.tolist()

# filter out too granular labels
# remove all cells that have <= 7 parents in the cell ontology
cell_types = obs.cell_type.unique().tolist()

n_children = []
n_parents = []

for cell_type in cell_types:
    n_parents.append(len(find_parent_nodes(cell_type)))
    n_children.append(len(find_child_nodes(cell_type)))

cell_types_to_remove += (
    pd.DataFrame({'n_children': n_children, 'n_parents': n_parents}, index=cell_types)
    .query('n_parents <= 7')
    .index.tolist()
)
cell_types_to_remove = list(set(cell_types_to_remove))

In [9]:
# Number of cell types that remain once the removal list is applied.
obs.cell_type.nunique() - len(cell_types_to_remove)

164

In [10]:
# Keep only the cells whose cell type is not on the removal list.
keep_mask = ~obs.cell_type.isin(cell_types_to_remove)
obs_subset = obs[keep_mask].copy()

# Drop category levels that no longer occur in the subset.
categorical_columns = [
    col for col in obs_subset.columns if obs_subset[col].dtype == 'category'
]
for col in categorical_columns:
    obs_subset[col] = obs_subset[col].cat.remove_unused_categories()
obs_subset

Unnamed: 0,soma_joinid,is_primary_data,dataset_id,donor_id,assay,cell_type,development_stage,disease,tissue,tissue_general,tech_sample
15,15,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,macrophage,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue,9d8e5dca-03a3-457d-b7fb-844c75735c83_donor-GOLD
16,16,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,macrophage,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue,9d8e5dca-03a3-457d-b7fb-844c75735c83_donor-GOLD
18,18,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,endothelial cell,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue,9d8e5dca-03a3-457d-b7fb-844c75735c83_donor-GOLD
19,19,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,macrophage,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue,9d8e5dca-03a3-457d-b7fb-844c75735c83_donor-GOLD
20,20,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,macrophage,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue,9d8e5dca-03a3-457d-b7fb-844c75735c83_donor-GOLD
...,...,...,...,...,...,...,...,...,...,...,...
28169969,53794723,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,pericyte,51-year-old human stage,normal,lung,lung,8c42cfd0-0b0a-46d5-910c-fc833d83c45e_3
28169970,53794724,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,pericyte,51-year-old human stage,normal,lung,lung,8c42cfd0-0b0a-46d5-910c-fc833d83c45e_3
28169971,53794725,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,pericyte,51-year-old human stage,normal,lung,lung,8c42cfd0-0b0a-46d5-910c-fc833d83c45e_3
28169972,53794726,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,pericyte,51-year-old human stage,normal,lung,lung,8c42cfd0-0b0a-46d5-910c-fc833d83c45e_3


In [11]:
# Names of the cell types present in the training subset.
cell_types_to_keep = obs_subset.cell_type.unique().tolist()

### Select data for out-of-distribution evaluation

In [12]:
# Query the same assays again, this time selecting exactly the cell types that
# were excluded from the training data.
ood_filter = f"is_primary_data == True and assay in {PROTOCOLS} and cell_type in {cell_types_to_remove}"
human_obs_table = census["census_data"]["homo_sapiens"].obs
obs_ood = (
    human_obs_table
    .read(column_names=COLUMN_NAMES, value_filter=ood_filter)
    .concat()
    .to_pandas()
)

In [13]:
# filter out too granular labels
# remove all cells that have <= 7 parents in the cell ontology
cell_types = obs.cell_type.unique().tolist()

n_children = []
n_parents = []

for cell_type in cell_types:
    n_parents.append(len(find_parent_nodes(cell_type)))
    n_children.append(len(find_child_nodes(cell_type)))

cell_types_to_remove = (
    pd.DataFrame({'n_children': n_children, 'n_parents': n_parents}, index=cell_types)
    .query('n_parents <= 7')
    .index.tolist()
)
cell_types_to_remove = list(set(cell_types_to_remove))

In [14]:
# Keep only the out-of-distribution cells whose cell type is not on the removal list.
ood_keep_mask = ~obs_ood.cell_type.isin(cell_types_to_remove)
obs_ood_subset = obs_ood[ood_keep_mask].copy()

# Drop category levels that no longer occur in the subset.
ood_categorical_columns = [
    col for col in obs_ood_subset.columns if obs_ood_subset[col].dtype == 'category'
]
for col in ood_categorical_columns:
    obs_ood_subset[col] = obs_ood_subset[col].cat.remove_unused_categories()
obs_ood_subset

Unnamed: 0,soma_joinid,is_primary_data,dataset_id,donor_id,assay,cell_type,development_stage,disease,tissue,tissue_general
19,38,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,fibro/adipogenic progenitor cell,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue
22,44,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,fibro/adipogenic progenitor cell,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue
23,49,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,fibro/adipogenic progenitor cell,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue
24,50,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,fibro/adipogenic progenitor cell,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue
27,56,True,9d8e5dca-03a3-457d-b7fb-844c75735c83,donor-GOLD,10x 3' v3,fibro/adipogenic progenitor cell,53-year-old human stage,normal,subcutaneous abdominal adipose tissue,adipose tissue
...,...,...,...,...,...,...,...,...,...,...
5979347,53794324,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,bronchial smooth muscle cell,51-year-old human stage,normal,lung,lung
5979348,53794325,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,bronchial smooth muscle cell,51-year-old human stage,normal,lung,lung
5979349,53794326,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,bronchial smooth muscle cell,51-year-old human stage,normal,lung,lung
5979350,53794327,True,8c42cfd0-0b0a-46d5-910c-fc833d83c45e,3,10x 3' v2,bronchial smooth muscle cell,51-year-old human stage,normal,lung,lung


# Download data

In [12]:
# Gene names used to restrict the downloaded features; read from the `gene_names`
# column of a local parquet file (presumably the protein-coding genes - confirm).
protein_coding_genes = pd.read_parquet('features.parquet').gene_names.tolist()

### Download training data

In [None]:
BASE_PATH = '/mnt/dssfs02/cxg_census/h5ad_raw_2023_05_15'

# make sure the target folder exists before the long-running download starts
os.makedirs(BASE_PATH, exist_ok=True)

# download in batches to not run out of memory
# tqdm wraps the list of batches (not the enumerate generator) so that the
# progress bar knows the total number of batches
batches = np.array_split(obs_subset.soma_joinid.to_numpy(), 20)
for i, idxs in enumerate(tqdm(batches)):
    adata = cellxgene_census.get_anndata(
        census=census,
        organism="Homo sapiens",
        X_name='raw',
        obs_coords=idxs.tolist(),
        var_value_filter=f"feature_name in {protein_coding_genes}",
        column_names={"obs": COLUMN_NAMES, "var": ['feature_id', 'feature_name']},
    )
    # one file per batch, named after the batch index
    adata.write_h5ad(join(BASE_PATH, f'{i}.h5ad'))


19it [19:39:25, 3355.84s/it]

### Download out-of-distribution data

In [None]:
BASE_PATH = '/mnt/dssfs02/cxg_census/h5ad_raw_2023_05_15_ood'

# make sure the target folder exists before the long-running download starts
os.makedirs(BASE_PATH, exist_ok=True)

# download in batches to not run out of memory
# tqdm wraps the list of batches (not the enumerate generator) so that the
# progress bar knows the total number of batches
batches_ood = np.array_split(obs_ood_subset.soma_joinid.to_numpy(), 2)
for i, idxs in enumerate(tqdm(batches_ood)):
    adata = cellxgene_census.get_anndata(
        census=census,
        organism="Homo sapiens",
        X_name='raw',
        obs_coords=idxs.tolist(),
        var_value_filter=f"feature_name in {protein_coding_genes}",
        column_names={"obs": COLUMN_NAMES, "var": ['feature_id', 'feature_name']},
    )
    # one file per batch, named after the batch index
    adata.write_h5ad(join(BASE_PATH, f'{i}.h5ad'))


# Store data

In [15]:
import os
from os.path import join

import anndata
import dask
import dask.array as da

from scipy.sparse import csr_matrix

# Use dask's threaded scheduler; the trailing semicolon suppresses the cell output.
dask.config.set(scheduler='threads');

In [16]:
def read_X(path):
    """Read the h5ad file at ``path`` into memory and return its ``X`` matrix."""
    adata = anndata.read_h5ad(path)
    return adata.X


def read_obs(path):
    """Return the ``obs`` table of the h5ad file at ``path`` plus a ``tech_sample`` column.

    The file is opened in backed mode, so ``X`` is not loaded into memory.
    ``tech_sample`` is the string ``"<dataset_id>_<donor_id>"`` of each row.
    """
    adata = anndata.read_h5ad(path, backed='r')
    try:
        # copy so the returned frame does not depend on the AnnData object
        obs = adata.obs.copy()
    finally:
        # backed mode keeps the underlying file open; release the handle
        adata.file.close()
    obs['tech_sample'] = obs.dataset_id.astype(str) + '_' + obs.donor_id.astype(str)
    return obs


def read_var(path):
    """Return the ``var`` table of the h5ad file at ``path``.

    The file is opened in backed mode, so ``X`` is not loaded into memory.
    """
    adata = anndata.read_h5ad(path, backed='r')
    try:
        # copy so the returned frame does not depend on the AnnData object
        return adata.var.copy()
    finally:
        # backed mode keeps the underlying file open; release the handle
        adata.file.close()


## Training data

In [17]:
# Folder holding the raw training h5ad batches (one `<index>.h5ad` file per batch).
BASE_PATH = '/mnt/dssfs02/cxg_census/h5ad_raw_2023_05_15'

### Convert to zarr + DataFrame

In [18]:
# Collect the batch files in numeric order (0.h5ad, 1.h5ad, ...).
# The extension filter runs before the numeric sort so that unrelated directory
# entries cannot break the int() conversion.
h5ad_names = [name for name in os.listdir(BASE_PATH) if name.endswith('.h5ad')]
files = [
    join(BASE_PATH, name)
    for name in sorted(h5ad_names, key=lambda x: int(x.split('.')[0]))
]

# read obs
print('Loading obs...')
obs_per_file = [read_obs(file) for file in files]
# row count of each file, taken from the files themselves so this section does
# not rely on variables from the data-selection section
split_lens = [len(file_obs) for file_obs in obs_per_file]
obs = pd.concat(obs_per_file).reset_index(drop=True)
del obs_per_file
for col in obs.columns:
    if obs[col].dtype == object:
        # remove_unused_categories() returns a new Series, so assign its result
        obs[col] = obs[col].astype('category').cat.remove_unused_categories()
# read var
print('Loading var...')
var = read_var(files[0])
# read X
print('Loading X...')
# one delayed block per file; the declared shape is (rows of that file, number of genes)
X = da.concatenate([
    da.from_delayed(dask.delayed(read_X)(file), (split_len, len(var)), dtype='f4')
    for file, split_len in zip(files, split_lens)
]).persist()


Loading obs...
Loading var...
Loading X...


In [19]:
# Display the summary (size, shape, chunking, dtype) of the concatenated dask array.
X

Unnamed: 0,Array,Chunk
Bytes,1.56 TiB,79.90 GiB
Shape,"(22190622, 19331)","(1109532, 19331)"
Count,20 Tasks,20 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 1.56 TiB 79.90 GiB Shape (22190622, 19331) (1109532, 19331) Count 20 Tasks 20 Chunks Type float32 numpy.ndarray",19331  22190622,

Unnamed: 0,Array,Chunk
Bytes,1.56 TiB,79.90 GiB
Shape,"(22190622, 19331)","(1109532, 19331)"
Count,20 Tasks,20 Chunks
Type,float32,numpy.ndarray


### Create train, val, test split

In [21]:
from statistics import mode
from scipy.sparse import csr_matrix

In [39]:
from math import ceil


def get_split(samples, val_split: float = 0.15, test_split: float = 0.15, seed = 1):
    """Randomly partition ``samples`` into disjoint train/val/test groups.

    Args:
        samples: Sequence of sample identifiers. It is copied into a new numpy
            array before shuffling, so the caller's object is left untouched.
        val_split: Fraction of samples assigned to validation (rounded up).
        test_split: Fraction of samples assigned to test (rounded up).
        seed: Seed for ``np.random.default_rng``; the same seed gives the same split.

    Returns:
        Dict with the keys ``'train'``, ``'val'`` and ``'test'``, each holding a
        numpy array of samples. The three arrays are disjoint and together
        contain every sample exactly once.

    Raises:
        ValueError: If a fraction is negative or the fractions sum to more than 1.
    """
    if val_split < 0 or test_split < 0 or val_split + test_split > 1:
        raise ValueError(
            'val_split and test_split must be non-negative and sum to at most 1, '
            f'got val_split={val_split} and test_split={test_split}'
        )

    rng = np.random.default_rng(seed=seed)

    samples = np.array(samples)
    rng.shuffle(samples)
    n_samples = len(samples)

    # Rounding up can request more val/test samples than exist for small inputs.
    # Cap the counts so the train count never goes negative; a negative count
    # would make the slices below overlap.
    n_samples_val = min(ceil(val_split * n_samples), n_samples)
    n_samples_test = min(ceil(test_split * n_samples), n_samples - n_samples_val)
    n_samples_train = n_samples - n_samples_val - n_samples_test

    return {
        'train': samples[:n_samples_train],
        'val': samples[n_samples_train:(n_samples_train + n_samples_val)],
        'test': samples[(n_samples_train + n_samples_val):]
    }


In [40]:
# Split on tech_sample level, then map each group of tech_samples to the row
# indices of its cells.
tech_sample_splits = get_split(obs.tech_sample.unique().tolist())
splits = {
    split_name: obs[obs.tech_sample.isin(tech_sample_splits[split_name])].index.to_numpy()
    for split_name in ['train', 'val', 'test']
}

splits

{'train': array([       0,        1,        2, ..., 22190619, 22190620, 22190621]),
 'val': array([   20620,    20621,    20622, ..., 22168031, 22168032, 22168033]),
 'test': array([    7740,     7741,     7742, ..., 22158930, 22158931, 22158932])}

In [41]:
# every pair of splits must be disjoint
for first, second in [('train', 'val'), ('train', 'test'), ('val', 'test')]:
    n_shared = len(np.intersect1d(splits[first], splits[second]))
    assert n_shared == 0, f'{first} and {second} share {n_shared} cells'

# together the splits must account for every cell
assert sum(len(split_idxs) for split_idxs in splits.values()) == len(obs)

In [42]:
# Number of cells per split.
for split_name in ('train', 'val', 'test'):
    n_cells = len(obs.loc[splits[split_name], :])
    print(f"{split_name}: {n_cells:,}")

train: 15,241,127
val: 3,500,170
test: 3,449,325


In [43]:
# Number of distinct cell types per split.
for split_name in ('train', 'val', 'test'):
    n_cell_types = len(np.unique(obs.loc[splits[split_name], 'cell_type']))
    print(f"{split_name}: {n_cell_types}")

train: 164
val: 164
test: 164


In [44]:
# Shuffle the row indices within each split. One generator is used for all three
# splits, in the order train, val, test.
rng = np.random.default_rng(seed=1)

for split_name in ('train', 'val', 'test'):
    splits[split_name] = rng.permutation(splits[split_name])

splits

{'train': array([15380693, 14071968, 21476582, ...,  7830057,   639705,  3287217]),
 'val': array([12641949,   192203, 17223332, ..., 11466818, 14458591,  1666073]),
 'test': array([ 9719568,  9310073, 10488308, ...,  5888438,  6370338,  1240529])}

### Save data

In [48]:
# Root folder for the processed data; one sub-folder is written per split.
SAVE_PATH = f'/mnt/dssfs02/cxg_census/data_2023_05_15'
# Number of rows per chunk when X is rechunked before being written to zarr.
CHUNK_SIZE = 16384

In [49]:
for split_name, split_idxs in splits.items():
    # out-of-order indexing is on purpose here as we want to shuffle the data to break up data sets
    X_shuffled = X[split_idxs, :].rechunk((CHUNK_SIZE, -1))
    obs_shuffled = obs.loc[split_idxs, :]

    # os.makedirs raises if the folder already exists, so an earlier export is
    # never silently overwritten
    target_dir = join(SAVE_PATH, split_name)
    os.makedirs(target_dir)

    # gene table and cell metadata go next to the matrix
    var.to_parquet(path=join(target_dir, 'var.parquet'), engine='pyarrow', compression='snappy', index=None)
    obs_shuffled.to_parquet(path=join(target_dir, 'obs.parquet'), engine='pyarrow', compression='snappy', index=None)

    # each block is converted with .toarray() before it is written to zarr
    X_dense = X_shuffled.map_blocks(lambda block: block.toarray(), dtype='f4')
    da.to_zarr(
        X_dense,
        join(target_dir, 'zarr'),
        component='X',
        compute=True,
        compressor='default',
        order='C'
    )


  X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
  X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))


## Out-of-distribution data

In [26]:
# Folder holding the raw out-of-distribution h5ad batches.
BASE_PATH_OOD = '/mnt/dssfs02/cxg_census/h5ad_raw_2023_05_15_ood'

In [27]:
# Collect the batch files in numeric order (0.h5ad, 1.h5ad, ...).
# The extension filter runs before the numeric sort so that unrelated directory
# entries cannot break the int() conversion.
h5ad_names_ood = [name for name in os.listdir(BASE_PATH_OOD) if name.endswith('.h5ad')]
files_ood = [
    join(BASE_PATH_OOD, name)
    for name in sorted(h5ad_names_ood, key=lambda x: int(x.split('.')[0]))
]

# read obs
print('Loading obs...')
obs_ood_per_file = [read_obs(file) for file in files_ood]
# row count of each file, taken from the files themselves so this section does
# not rely on variables from the data-selection section
split_lens = [len(file_obs) for file_obs in obs_ood_per_file]
obs_ood = pd.concat(obs_ood_per_file).reset_index(drop=True)
del obs_ood_per_file
for col in obs_ood.columns:
    if obs_ood[col].dtype == object:
        # remove_unused_categories() returns a new Series, so assign its result
        obs_ood[col] = obs_ood[col].astype('category').cat.remove_unused_categories()
# read var
print('Loading var...')
var_ood = read_var(files_ood[0])
# read X
print('Loading X...')
# one delayed block per file; the declared shape is (rows of that file, number of genes)
X_ood = da.concatenate([
    da.from_delayed(dask.delayed(read_X)(file), (split_len, len(var_ood)), dtype='f4')
    for file, split_len in zip(files_ood, split_lens)
]).persist()

Loading obs...
Loading var...
Loading X...


In [29]:
# Output folder for the out-of-distribution data.
# NOTE: this rebinds SAVE_PATH, which earlier pointed at the training output folder.
SAVE_PATH = f'/mnt/dssfs02/cxg_census/data_2023_05_15_ood'
# Number of rows per chunk when X is rechunked before being written to zarr.
CHUNK_SIZE = 16384

In [33]:
# os.makedirs raises if the folder already exists, so an earlier export is never
# silently overwritten
os.makedirs(SAVE_PATH)

# gene table and cell metadata go next to the matrix
var_ood_file = join(SAVE_PATH, 'var.parquet')
obs_ood_file = join(SAVE_PATH, 'obs.parquet')
var_ood.to_parquet(path=var_ood_file, engine='pyarrow', compression='snappy', index=None)
obs_ood.to_parquet(path=obs_ood_file, engine='pyarrow', compression='snappy', index=None)

# rechunk row-wise, then convert each block with .toarray() before writing to zarr
X_ood_rechunked = X_ood.rechunk((CHUNK_SIZE, -1))
X_ood_dense = X_ood_rechunked.map_blocks(lambda block: block.toarray(), dtype='f4')
da.to_zarr(
    X_ood_dense,
    join(SAVE_PATH, 'zarr'),
    component='X',
    compute=True,
    compressor='default',
    order='C'
)
