In [None]:
!pip install -q zarr

In [None]:
import os
from os.path import join

import anndata
import dask
import dask.array as da
import pandas as pd
import numpy as np

from scipy.sparse import csr_matrix

# run all dask computations in this notebook on the local thread pool
dask.config.set({'scheduler': 'threads'});

In [None]:
def read_X(path):
    """Load an .h5ad file (not backed, i.e. fully into memory) and return only its matrix ``X``."""
    adata = anndata.read_h5ad(path)
    return adata.X


def read_obs(path):
    """Return the ``obs`` table of an .h5ad file, extended by a ``tech_sample`` column.

    The file is opened read-only in backed mode (``backed='r'``), so the expression matrix
    is not loaded. ``tech_sample`` is ``'<dataset_id>_<donor_id>'`` and is used further down
    as the grouping key for the train / val / test split.
    """
    obs = anndata.read_h5ad(path, backed='r').obs
    dataset_ids = obs.dataset_id.astype(str)
    donor_ids = obs.donor_id.astype(str)
    obs['tech_sample'] = dataset_ids + '_' + donor_ids
    return obs


def read_var(path):
    """Return the ``var`` table of an .h5ad file (opened read-only in backed mode)."""
    adata = anndata.read_h5ad(path, backed='r')
    return adata.var


## Training data

In [None]:
# directory with the raw census .h5ad files; file names are expected to start with an
# integer (used as numeric sort key when the files are listed below)
BASE_PATH = '/mnt/dssfs02/cxg_census/h5ad_raw_2023_05_15'

### Convert to zarr + DataFrame

In [None]:
# keep only .h5ad files *before* sorting: the sort key parses the integer prefix of each
# file name and would raise ValueError on any other file that happens to live in BASE_PATH
h5ad_names = [file for file in os.listdir(BASE_PATH) if file.endswith('.h5ad')]
if not h5ad_names:
    raise FileNotFoundError(f'no .h5ad files found in {BASE_PATH}')
# numeric (not lexicographic) order of the integer file-name prefix
files = [join(BASE_PATH, file) for file in sorted(h5ad_names, key=lambda x: int(x.split('.')[0]))]

# read obs -- kept per file for a moment so the exact row count of every file is known
print('Loading obs...')
obs_per_file = [read_obs(file) for file in files]
split_lens = [len(file_obs) for file_obs in obs_per_file]
# reset_index -> obs gets a RangeIndex, so obs row i corresponds to row i of X below
obs = pd.concat(obs_per_file).reset_index(drop=True)
del obs_per_file  # only the concatenated copy is needed from here on
for col in obs.columns:
    if obs[col].dtype == object:
        # remove_unused_categories returns a new Series -> it must be assigned to take effect
        obs[col] = obs[col].astype('category').cat.remove_unused_categories()
# read var (taken from the first file -- assumes all files share the same gene set)
print('Loading var...')
var = read_var(files[0])
# read X
print('Loading X...')
# one lazily loaded block per file; the declared row count of each block comes from that
# file's obs instead of assuming 20 equally sized np.array_split chunks, so the dask
# metadata always matches the real data and lines up with the rows of `obs`.
# NOTE(review): dtype='f4' assumes X is stored as float32 -- confirm
X = da.concatenate([
    da.from_delayed(dask.delayed(read_X)(file), (split_len, len(var)), dtype='f4')
    for file, split_len in zip(files, split_lens)
]).rechunk((32768, -1)).persist()


In [None]:
# inspect shape / chunking / dtype of the persisted dask array
X

### Create train, val, test split

In [None]:
from statistics import mode
from scipy.sparse import csr_matrix

In [None]:
from math import ceil


def get_split(samples, val_split: float = 0.15, test_split: float = 0.15, seed=1):
    """Randomly partition ``samples`` into train / val / test groups.

    The samples are shuffled with a seeded ``np.random.default_rng`` and then cut into three
    contiguous slices: ``ceil(val_split * n)`` samples go to val, ``ceil(test_split * n)`` to
    test and the remainder to train.

    Args:
        samples: sequence of sample identifiers (here: tech_sample ids).
        val_split: fraction of samples assigned to the validation split.
        test_split: fraction of samples assigned to the test split.
        seed: seed of the random number generator; the same seed gives the same split.

    Returns:
        dict with keys 'train', 'val', 'test' mapping to numpy arrays of samples.

    Raises:
        ValueError: if the rounded-up val + test sizes exceed the number of samples.
    """
    rng = np.random.default_rng(seed=seed)

    samples = np.array(samples)  # copy -> the in-place shuffle does not touch the caller's data
    rng.shuffle(samples)
    n_samples = len(samples)

    n_samples_val = ceil(val_split * n_samples)
    n_samples_test = ceil(test_split * n_samples)
    n_samples_train = n_samples - n_samples_val - n_samples_test
    if n_samples_train < 0:
        # a negative train size would make the slices below silently return wrongly sized
        # or even overlapping splits
        raise ValueError(
            f'cannot split {n_samples} samples into {n_samples_val} val and '
            f'{n_samples_test} test samples'
        )

    return {
        'train': samples[:n_samples_train],
        'val': samples[n_samples_train:(n_samples_train + n_samples_val)],
        'test': samples[(n_samples_train + n_samples_val):]
    }


def subset(splits, frac):
    """Return the leading ``ceil(frac * len(splits))`` elements of ``splits``.

    ``frac`` must lie in ``(0, 1]``; for ``frac == 1`` the input object itself is returned.
    Because the inputs are already shuffled by ``get_split``, taking a prefix is a random
    subsample.
    """
    assert 0. < frac <= 1.
    if frac != 1.:
        splits = splits[:ceil(frac * len(splits))]
    return splits


In [None]:
# fraction of *training* tech_samples (donors) to keep; val / test are never subsampled
# subsample_fracs used: 0.15, 0.3, 0.5, 0.7, 1.
SUBSAMPLE_FRAC = 1.

In [None]:
tech_sample_splits = get_split(obs.tech_sample.unique().tolist())
# tech_samples are already shuffled inside get_split -> a prefix is a random donor subsample;
# only the training split is subsampled, val / test keep every donor assigned to them
selected_tech_samples = {
    'train': subset(tech_sample_splits['train'], SUBSAMPLE_FRAC),
    'val': tech_sample_splits['val'],
    'test': tech_sample_splits['test'],
}
# row labels of all cells belonging to the selected tech_samples of each split
splits = {
    split_name: obs.index[obs.tech_sample.isin(samples).to_numpy()].to_numpy()
    for split_name, samples in selected_tech_samples.items()
}

splits

In [None]:
# every pair of splits must be disjoint on cell level (each pair checked exactly once)
for split_a, split_b in [('train', 'val'), ('train', 'test'), ('val', 'test')]:
    shared = np.intersect1d(splits[split_a], splits[split_b])
    assert len(shared) == 0, f'{split_a} and {split_b} share {len(shared)} cells'

In [None]:
# number of cells in each split
for split_name in ('train', 'val', 'test'):
    n_cells = len(obs.loc[splits[split_name], :])
    print(f"{split_name}: {n_cells:,} cells")

In [None]:
# number of distinct cell types in each split
for split_name in ('train', 'val', 'test'):
    n_cell_types = len(np.unique(obs.loc[splits[split_name], 'cell_type']))
    print(f"{split_name}: {n_cell_types} celltypes")

In [None]:
# number of distinct tech_samples (dataset + donor) in each split
for split_name in ('train', 'val', 'test'):
    n_donors = len(np.unique(obs.loc[splits[split_name], 'tech_sample']))
    print(f"{split_name}: {n_donors} donors")

In [None]:
rng = np.random.default_rng(seed=1)

# shuffle the cell order inside every split; the fixed train -> val -> test order keeps the
# consumption of the seeded RNG stream (and therefore the result) reproducible
for split_name in ('train', 'val', 'test'):
    splits[split_name] = rng.permutation(splits[split_name])

splits

### Save data

In [None]:
SAVE_PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15'
if SUBSAMPLE_FRAC < 1.:
    # subsampled stores get their own directory suffixed with the percentage kept
    SAVE_PATH += f'_subsample_{round(SUBSAMPLE_FRAC * 100)}'

# number of rows per chunk of the saved X array
CHUNK_SIZE = 16384

In [None]:
# subset stores only receive the train split;
# val + test can be copied later from the non-subset store
splits_to_save = ['train'] if SUBSAMPLE_FRAC < 1. else ['train', 'val', 'test']


for split, idxs in splits.items():
    if split not in splits_to_save:
        continue

    # out-of-order (fancy) indexing is deliberate: the shuffled idxs break up the data sets.
    # obs has a RangeIndex (reset_index at load time), so the labels in idxs are also X row positions
    X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))
    obs_split = obs.loc[idxs, :]

    save_dir = join(SAVE_PATH, split)
    # no exist_ok: raises FileExistsError instead of writing into an already existing store
    os.makedirs(save_dir)

    parquet_kwargs = {'engine': 'pyarrow', 'compression': 'snappy', 'index': None}
    var.to_parquet(path=join(save_dir, 'var.parquet'), **parquet_kwargs)
    obs_split.to_parquet(path=join(save_dir, 'obs.parquet'), **parquet_kwargs)

    # zarr receives dense blocks; presumably the chunks are scipy sparse matrices coming
    # from AnnData.X (they must provide .toarray()) -- confirm
    X_dense = X_split.map_blocks(lambda block: block.toarray(), dtype='f4')
    da.to_zarr(
        X_dense,
        join(save_dir, 'zarr'),
        component='X',
        compute=True,
        compressor='default',
        order='C',
    )
