In [1]:
import numpy as np
import pandas as pd
import dask.array as da
import os

from os.path import join

import dask
# Run all dask computations in this notebook on the threaded scheduler.
# The trailing ';' keeps the returned config object out of the cell output.
dask.config.set(scheduler='threads');

In [2]:
import sfaira



In [3]:
from sfaira.consts import OC

# Cell-type ontology taken from sfaira's ontology container; queried further
# down (get_ancestors) to decide which cell-type labels are kept.
ontology = OC.cell_type

In [4]:
# Data store that is subset into train/val/test splits in this notebook.
# Alternative store locations (commented out):
# DATA_PATH = '/lustre/groups/ml01/workspace/felix.fischer.2/sfaira/data/store/dao_512_cxg_primary'
# DATA_PATH = '/lustre/groups/ml01/workspace/david.fischer/sfairazero/data/store/dao_512_sfaira_norm'
DATA_PATH = '/lustre/scratch/users/felix.fischer/dao_1024_cxg_primary'

# Open the 'dao' store with only the obs columns used for filtering and
# splitting, then select the 'Homo sapiens' sub-store.
data_store = sfaira.data.load_store(
    cache_path=join(DATA_PATH, 'dao'),
    store_format='dao',
    columns=[
        'id',
        'assay_sc',
        'tech_sample',
        'cell_type',
        'cell_type_ontology_term_id',
        'disease',
        'development_stage',
        'organ',
    ],
).stores['Homo sapiens']

# Display the number of observations in the selected store.
data_store.n_obs

13701695

In [5]:
# Observation-level metadata of the store; narrowed down by the filters in later cells.
obs = data_store.obs

In [6]:
# Feature-level metadata of the store; its index is later passed to the
# genome container as the ids to translate into gene symbols.
var = data_store.var

In [7]:
# Distinct values of the `id` column of obs (one entry per dataset id).
datasets = list(obs.id.unique())

# Define train, val, test split

In [8]:
from sfaira.consts import OC

In [9]:
# Restrict the metadata to cells measured with a 10x protocol.
assays = ["10x 3' v2", "10x 5' v1", "10x 3' v3", "10x 5' v2", "10x 3' v1"]

# Boolean row selection: keep a cell only if its assay is in the list above.
obs = obs.loc[obs['assay_sc'].isin(assays)]

In [10]:
# Number of distinct cell-type labels left after the assay filter.
obs.cell_type.nunique()

467

In [11]:
# Collect every cell-type label that fails at least one of the filters below.

# Filter 1: labels that are not a subtype of 'native cell'.
# NOTE(review): sfaira's get_ancestors() appears to return the *subtypes* of a
# term (the table displayed below reports 2504 for 'native cell'), which is why
# it is used here and why it feeds `n_children` further down - confirm against
# the sfaira ontology API.
cell_types_to_remove = obs[~obs.cell_type_ontology_term_id.isin(ontology.get_ancestors('native cell'))].cell_type.unique().tolist()

# Filter 2: labels with fewer than 1000 cells.
cell_freq = obs.cell_type.value_counts()
cell_types_to_remove += cell_freq[cell_freq < 1000].index.tolist()

# Filter 3: labels observed in 30 or fewer distinct tech_samples
# (the comparison is `<= 30`, i.e. a label needs at least 31 to be kept).
tech_samples_per_cell_type = obs[['cell_type', 'tech_sample']].groupby('cell_type').agg({'tech_sample': 'nunique'}).sort_values('tech_sample')
cell_types_to_remove += tech_samples_per_cell_type[tech_samples_per_cell_type.tech_sample <= 30].index.tolist()

# Filter 4: labels that are too coarse, i.e. that have 7 or fewer parent terms
# in the cell ontology (e.g. 'native cell', 'eukaryotic cell').
cell_ontology = OC.cell_type

cell_types = obs.cell_type.unique().tolist()

n_children = []
n_parents = []

for cell_type in cell_types:
    # Only the number of related terms matters, so the nodes are counted
    # directly instead of being converted to names first.
    n_parents.append(sum(1 for _ in cell_ontology.get_descendants(cell_type)))
    n_children.append(sum(1 for _ in cell_ontology.get_ancestors(cell_type)))

cell_types_to_remove += (
    pd.DataFrame({'n_children': n_children, 'n_parents': n_parents}, index=cell_types)
    .query('n_parents <= 7')
    .index.tolist()
)
# A label can fail several filters; keep each one only once.
cell_types_to_remove = list(set(cell_types_to_remove))

In [12]:
# Number of cell-type labels that pass all filters.
# NOTE(review): this difference is only exact if every entry of
# cell_types_to_remove actually occurs in obs.cell_type - confirm.
obs.cell_type.nunique() - len(cell_types_to_remove)

128

In [17]:
# Show the 50 cell types with the smallest n_parents values, using the
# n_children / n_parents counts computed in the filtering cell.
(
    pd.DataFrame({'n_children': n_children, 'n_parents': n_parents}, index=cell_types)
    .sort_values('n_parents')
    .head(50)
)

Unnamed: 0,n_children,n_parents
cell in vitro,5,1
native cell,2504,1
eukaryotic cell,2206,2
sebaceous gland cell,2,2
contractile cell,169,2
precursor cell,262,2
somatic cell,2125,2
secretory cell,400,2
neoplastic cell,2,2
supporting cell,58,2


In [13]:
# Keep only the cells whose label is not on the removal list; work on a copy
# so the column assignments below do not touch `obs`.
keep_mask = ~obs.cell_type.isin(cell_types_to_remove)
obs_subset = obs[keep_mask].copy()

# Drop category levels that no longer occur in any categorical column.
for col in obs_subset.select_dtypes(include='category').columns:
    obs_subset[col] = obs_subset[col].cat.remove_unused_categories()

obs_subset

Unnamed: 0,id,assay_sc,tech_sample,cell_type,cell_type_ontology_term_id,disease,development_stage,organ
0,0217420c-b31d-4f92-8edf-c6d113573963,10x 3' v2,0217420c-b31d-4f92-8edf-c6d11357396310x 3' v2_...,epithelial cell of proximal tubule,CL:0002306,normal,2-year-old human stage,cortex of kidney
1,0217420c-b31d-4f92-8edf-c6d113573963,10x 3' v2,0217420c-b31d-4f92-8edf-c6d11357396310x 3' v2_...,epithelial cell of proximal tubule,CL:0002306,normal,2-year-old human stage,cortex of kidney
2,0217420c-b31d-4f92-8edf-c6d113573963,10x 3' v2,0217420c-b31d-4f92-8edf-c6d11357396310x 3' v2_...,epithelial cell of proximal tubule,CL:0002306,normal,67-year-old human stage,kidney
3,0217420c-b31d-4f92-8edf-c6d113573963,10x 3' v2,0217420c-b31d-4f92-8edf-c6d11357396310x 3' v2_...,epithelial cell of proximal tubule,CL:0002306,normal,12-year-old human stage,cortex of kidney
4,0217420c-b31d-4f92-8edf-c6d113573963,10x 3' v2,0217420c-b31d-4f92-8edf-c6d11357396310x 3' v2_...,epithelial cell of proximal tubule,CL:0002306,normal,12-year-old human stage,renal medulla
...,...,...,...,...,...,...,...,...
13578479,f6a333c8-6442-4561-85e6-e6333c30e658,10x 3' v2,f6a333c8-6442-4561-85e6-e6333c30e65810x 3' v2_...,endothelial cell,CL:0000115,normal,13th week post-fertilization human stage,kidney
13578488,f6a333c8-6442-4561-85e6-e6333c30e658,10x 3' v2,f6a333c8-6442-4561-85e6-e6333c30e65810x 3' v2_...,conventional dendritic cell,CL:0000990,normal,13th week post-fertilization human stage,kidney
13578493,f6a333c8-6442-4561-85e6-e6333c30e658,10x 3' v2,f6a333c8-6442-4561-85e6-e6333c30e65810x 3' v2_...,macrophage,CL:0000235,normal,embryonic human stage,kidney
13578498,f6a333c8-6442-4561-85e6-e6333c30e658,10x 3' v2,f6a333c8-6442-4561-85e6-e6333c30e65810x 3' v2_...,macrophage,CL:0000235,normal,embryonic human stage,kidney


In [14]:
from math import ceil



def get_split(samples, val_split: float = 0.15, test_split: float = 0.15, seed = 1):
    """Randomly partition ``samples`` into train, validation and test subsets.

    The samples are copied, shuffled with a generator seeded by ``seed`` and
    then cut into three consecutive, non-overlapping slices. The validation
    and test sizes are rounded up, the training set receives the remainder.

    Parameters
    ----------
    samples
        Sequence of sample identifiers. It is not modified.
    val_split
        Fraction of the samples assigned to the validation set.
    test_split
        Fraction of the samples assigned to the test set.
    seed
        Seed for ``np.random.default_rng``; the same seed gives the same split.

    Returns
    -------
    dict
        Keys ``'train'``, ``'val'`` and ``'test'``, each mapping to a numpy
        array holding the samples of that subset.

    Raises
    ------
    ValueError
        If a fraction is negative, or if the rounded-up validation and test
        sizes together exceed the number of samples. Without this check the
        slices below would silently overlap.
    """
    if val_split < 0 or test_split < 0:
        raise ValueError(
            f'val_split and test_split must be non-negative, got {val_split} and {test_split}'
        )

    rng = np.random.default_rng(seed=seed)

    # np.array() copies its input, so the in-place shuffle never affects the caller.
    samples = np.array(samples)
    rng.shuffle(samples)
    n_samples = len(samples)

    n_samples_val = ceil(val_split * n_samples)
    n_samples_test = ceil(test_split * n_samples)
    n_samples_train = n_samples - n_samples_val - n_samples_test

    # A negative train size would turn the slice bounds below into negative
    # indices and make the returned subsets overlap.
    if n_samples_train < 0:
        raise ValueError(
            f'Cannot split {n_samples} samples into {n_samples_val} validation and '
            f'{n_samples_test} test samples'
        )

    return {
        'train': samples[:n_samples_train],
        'val': samples[n_samples_train:(n_samples_train + n_samples_val)],
        'test': samples[(n_samples_train + n_samples_val):]
    }


# For every remaining cell type, the array of distinct tech_samples it occurs in.
tech_samples_per_cell_type = obs_subset[['cell_type', 'tech_sample']].groupby('cell_type').agg({'tech_sample': 'unique'})


# Row labels of obs_subset assigned to each split, accumulated over all cell types.
splits = {'train': [], 'val': [], 'test': []}
for cell_type in tech_samples_per_cell_type.index:
    samples = tech_samples_per_cell_type.loc[cell_type, 'tech_sample']
    # The tech_samples are split separately for each cell type.
    split = get_split(samples)
    # Select the cells of this cell type once instead of re-evaluating the
    # full-table comparison for each of the three splits.
    tech_sample_of_cell_type = obs_subset.loc[obs_subset.cell_type == cell_type, 'tech_sample']
    for x in ['train', 'val', 'test']:
        assert len(split[x]) >= 1, f'no tech_sample assigned to {x} for cell type {cell_type}'
        idxs = tech_sample_of_cell_type[tech_sample_of_cell_type.isin(split[x])].index.tolist()
        splits[x] += idxs



In [15]:
# Turn the accumulated index lists into numpy arrays, one split at a time.
for split_name in ('train', 'val', 'test'):
    splits[split_name] = np.array(splits[split_name])

splits

{'train': array([       7,       29,       44, ..., 13505969, 13505985, 13506049]),
 'val': array([    9122,     9544,    40293, ..., 13505735, 13505968, 13505972]),
 'test': array([     478,      597,     1023, ..., 13505045, 13505691, 13505831])}

In [16]:
# The three splits must be pairwise disjoint: no row label may appear in two of them.
assert len(np.intersect1d(splits['train'], splits['val'])) == 0, 'train and val overlap'
assert len(np.intersect1d(splits['train'], splits['test'])) == 0, 'train and test overlap'
assert len(np.intersect1d(splits['val'], splits['test'])) == 0, 'val and test overlap'

In [17]:
# Number of cells in each split, looked up by row label in obs_subset.
n_cells_train = len(obs_subset.loc[splits['train'], :])
n_cells_val = len(obs_subset.loc[splits['val'], :])
n_cells_test = len(obs_subset.loc[splits['test'], :])

print(f"train: {n_cells_train:,}")
print(f"val: {n_cells_val:,}")
print(f"test: {n_cells_test:,}")

train: 6,632,427
val: 1,534,201
test: 1,404,739


In [18]:
# Number of distinct cell types represented in each split.
n_types_train = obs_subset.loc[splits['train'], 'cell_type'].nunique()
n_types_val = obs_subset.loc[splits['val'], 'cell_type'].nunique()
n_types_test = obs_subset.loc[splits['test'], 'cell_type'].nunique()

print(f"train: {n_types_train}")
print(f"val: {n_types_val}")
print(f"test: {n_types_test}")

train: 128
val: 128
test: 128


In [19]:
# Shuffle the row order inside every split with one seeded generator.
# The splits are visited in the fixed order train, val, test so the sequence
# of draws from the generator is always the same.
rng = np.random.default_rng(seed=1)

for split_name in ('train', 'val', 'test'):
    splits[split_name] = rng.permutation(splits[split_name])

splits

{'train': array([ 2222335, 13059830,  1151732, ...,  5730285, 10563159,  9712153]),
 'val': array([  967154,  3116254,  3366389, ...,  6695963, 11375004, 11774447]),
 'test': array([10653924,  1463170,  7199781, ..., 10798760,  4177025,    81796])}

# Write train, val, test splits to disk

In [20]:
from statistics import mode
from scipy.sparse import csr_matrix

In [21]:
# Convert every block of the store's expression array to a scipy CSR matrix and
# persist the result. NOTE(review): data_store.x is presumably a dask array
# (map_blocks / persist are called on it) - confirm.
X = data_store.x.map_blocks(csr_matrix).persist()
# Re-read the store's metadata: `obs` was reassigned to an assay-filtered
# subset earlier, whereas the split indices are looked up in the store's table.
obs = data_store.obs
var = data_store.var

In [22]:
from sfaira.versions.genomes.genomes import GenomeContainer

# Genome annotation container for 'Homo sapiens', release '104'.
genome_container = GenomeContainer(organism='Homo sapiens', release='104')
# The index of var supplies the ids that get translated into gene symbols.
var_names = var.index.to_numpy().tolist()
gene_names = genome_container.translate_id_to_symbols(var_names)
# Adds the symbols as a new column of var, in place.
# NOTE(review): if data_store.var returns a shared object rather than a copy,
# this also changes the store's table - confirm.
var['gene_names'] = gene_names

In [23]:
# Root output directory; every split is written to its own sub-directory below it.
SAVE_PATH = '/lustre/scratch/users/felix.fischer/merlin_cxg_simple'
# Number of rows per chunk used when re-chunking the matrix of each split.
CHUNK_SIZE = 1024

In [24]:
# Parquet writer settings shared by the var and obs tables.
parquet_kwargs = {'engine': 'pyarrow', 'compression': 'snappy', 'index': None}

# Write one directory per split containing var.parquet, obs.parquet and a zarr
# group with the dense float32 expression matrix under the component 'X'.
for split, idxs in splits.items():
    save_dir = join(SAVE_PATH, split)
    # Raises if the directory already exists, so earlier output is never overwritten.
    os.makedirs(save_dir)

    # `idxs` is used as positional row indices into X and as index labels into obs.
    # NOTE(review): this assumes obs carries a 0..n-1 index aligned with the
    # rows of X - confirm.
    X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
    obs_split = obs.loc[idxs, :]

    var.to_parquet(path=join(save_dir, 'var.parquet'), **parquet_kwargs)
    obs_split.to_parquet(path=join(save_dir, 'obs.parquet'), **parquet_kwargs)

    # Densify each sparse block just before it is written to zarr.
    dense_split = X_split.map_blocks(lambda xx: xx.toarray(), dtype='f4')
    da.to_zarr(
        dense_split,
        join(save_dir, 'zarr'),
        component='X',
        compute=True,
        compressor='default',
        order='C',
    )


  X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
  X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
  X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
