In [1]:
import anndata
import pandas as pd
import numpy as np
import scanpy as sp
import zarr
import os

In [2]:
dat_dir = '/bigstore/GeneralStorage/fangming/projects/dredfish/data/'
output = os.path.join(dat_dir, 'rna', 'scrna_ss_ctxhippo_a_exon_count_matrix_v4.zarr')
print(output)

/bigstore/GeneralStorage/fangming/projects/dredfish/data/rna/scrna_ss_ctxhippo_a_exon_count_matrix_v4.zarr


In [3]:
%%time
f = '/bigstore/GeneralStorage/fangming/projects/dredfish/data/rna/scrna_ss_ctxhippo_a_exon_count_matrix_v2.h5ad'
adata = anndata.read(f, backed=None)
adata

CPU times: user 1.31 s, sys: 3.84 s, total: 5.15 s
Wall time: 5.15 s


AnnData object with n_obs × n_vars = 73347 × 45768
    obs: 'donor_sex_id', 'donor_sex_label', 'donor_sex_color', 'region_id', 'region_label', 'region_color', 'platform_label', 'cluster_order', 'cluster_label', 'cluster_color', 'subclass_order', 'subclass_label', 'subclass_color', 'neighborhood_id', 'neighborhood_label', 'neighborhood_color', 'class_order', 'class_label', 'class_color', 'exp_component_name', 'external_donor_name_label', 'full_genotype_label', 'facs_population_plan_label', 'injection_roi_label', 'injection_materials_label', 'injection_method_label', 'injection_type_label', 'full_genotype_id', 'full_genotype_color', 'external_donor_name_id', 'external_donor_name_color', 'facs_population_plan_id', 'facs_population_plan_color', 'injection_materials_id', 'injection_materials_color', 'injection_method_id', 'injection_method_color', 'injection_roi_id', 'injection_roi_color', 'injection_type_id', 'injection_type_color', 'cell_type_accession_label', 'cell_type_alias_label', 'ce

In [4]:
%%time
chunksize = 10
cells = adata.obs.index.values
genes = adata.var.index.values
ncells = len(cells)
ngenes = len(genes)

# create a zarr file (dir)
z1 = zarr.open(
          output, 
          mode='w', 
    )
z1

CPU times: user 5.08 ms, sys: 63 µs, total: 5.14 ms
Wall time: 169 ms


<zarr.hierarchy.Group '/'>

In [5]:
%%time
# create count matrix
z1.create_dataset('counts', 
                  shape=(ncells, ngenes), 
                  chunks=(chunksize, None), 
                  dtype='i4',
                 )

z1['counts'][:] = np.array(adata.X.toarray())

CPU times: user 1min 7s, sys: 1min 56s, total: 3min 3s
Wall time: 3min 13s


In [6]:
# genes
z1['genes'] = genes.astype(str)

In [7]:
# create labels
z1.create_dataset(
    'l1_code',
    shape=(ncells,), 
    chunks=(chunksize,), 
    dtype='i4',
)
a = adata.obs['class_label']
z1['l1_code'] = a.cat.codes.values
z1['l1_cat'] = a.cat.categories.values.astype(str)

z1.create_dataset(
    'l2_code',
    shape=(ncells,), 
    chunks=(chunksize,), 
    dtype='i4',
)
a = adata.obs['neighborhood_label']
z1['l2_code'] = a.cat.codes.values
z1['l2_cat'] = a.cat.categories.values.astype(str)

z1.create_dataset(
    'l3_code',
    shape=(ncells,), 
    chunks=(chunksize,), 
    dtype='i4',
)
a = adata.obs['subclass_label']
z1['l3_code'] = a.cat.codes.values
z1['l3_cat'] = a.cat.categories.values.astype(str)

z1.create_dataset(
    'l5_code',
    shape=(ncells,), 
    chunks=(chunksize,), 
    dtype='i4',
)
a = adata.obs['cluster_label']
z1['l5_code'] = a.cat.codes.values
z1['l5_cat'] = a.cat.categories.values.astype(str)

In [8]:
%%time
y = z1['counts'].oindex[975,:]
y, y.shape

CPU times: user 1.62 ms, sys: 2.68 ms, total: 4.3 ms
Wall time: 3.5 ms


(array([  0,   0, 242, ...,   0,   0,   0], dtype=int32), (45768,))

In [9]:
%%time
testidx = np.random.choice(ncells, 100)

CPU times: user 181 µs, sys: 300 µs, total: 481 µs
Wall time: 452 µs


In [10]:
%%time
z1['counts'].oindex[testidx,:].shape

CPU times: user 1.62 s, sys: 80 ms, total: 1.7 s
Wall time: 1.71 s


(100, 45768)

In [11]:
%%time
z1['counts'].oindex[testidx[0],:].shape

CPU times: user 4.12 ms, sys: 367 µs, total: 4.48 ms
Wall time: 3.21 ms


(45768,)

In [12]:
%%time
z1['l3_code'].oindex[testidx[0]].shape

CPU times: user 0 ns, sys: 2.11 ms, total: 2.11 ms
Wall time: 2.17 ms


()

# split train and test

In [13]:
def split_train_test(zarr_file, keys_copy, keys_split, frac=0.9, random_seed=None):
    """randomly select frac vs 1-frac samples into training and test (validation) set.
    Save them as separate zarr files
    """
    assert frac <= 1 and frac >= 0
    # the original zarr file
    z = zarr.open(zarr_file, 'r')
    size = len(z['counts'])
    
    path_train = zarr_file.replace('.zarr', '_train.zarr')
    path_test = zarr_file.replace('.zarr', '_test.zarr')
    print(f"{zarr_file} -> \n{path_train} and \n{path_test}\n")
    if random_seed: np.random.seed(random_seed)
    cond_train = np.random.rand(size) < frac
    ntrain = cond_train.sum()
    ntest = (~cond_train).sum()
    print(f"{size}, {ntrain} ({ntrain/size:.3f}), {ntest} ({ntest/size:.3f})")
    
    z_train = zarr.open(path_train, mode='w')
    z_test = zarr.open(path_test, mode='w')
    for key in keys_copy:
        z_train[key] = z[key]
        z_test[key] = z[key]

    for key in keys_split:
        # train
        if z[key].ndim == 1:
            chunksize = (10,)
            dat = z[key].oindex[cond_train]
        elif z[key].ndim == 2:
            chunksize = (10,None)
            dat = z[key].oindex[cond_train,:]
        else:
            raise ValueError('unimplemented')
            
        z_train.create_dataset(key, shape=dat.shape, chunks=chunksize)
        z_train[key][:] = dat
        
        # test
        if z[key].ndim == 1:
            chunksize = (10,)
            dat = z[key].oindex[~cond_train]
        elif z[key].ndim == 2:
            chunksize = (10,None)
            dat = z[key].oindex[~cond_train,:]
        else:
            raise ValueError('unimplemented')
            
        z_test.create_dataset(key, shape=dat.shape, chunks=chunksize)
        z_test[key][:] = dat
        
    return 

In [14]:
z = zarr.open(output, mode='r')
z, list(z.keys())

(<zarr.hierarchy.Group '/' read-only>,
 ['counts',
  'genes',
  'l1_cat',
  'l1_code',
  'l2_cat',
  'l2_code',
  'l3_cat',
  'l3_code',
  'l5_cat',
  'l5_code'])

In [15]:
#
keys_copy = ['genes', 'l1_cat', 'l2_cat', 'l3_cat', 'l5_cat']
keys_split = [key for key in z.keys() if key not in keys_copy]
keys_copy, keys_split


(['genes', 'l1_cat', 'l2_cat', 'l3_cat', 'l5_cat'],
 ['counts', 'l1_code', 'l2_code', 'l3_code', 'l5_code'])

In [16]:
split_train_test(output, keys_copy, keys_split, frac=0.9)

/bigstore/GeneralStorage/fangming/projects/dredfish/data/rna/scrna_ss_ctxhippo_a_exon_count_matrix_v4.zarr -> 
/bigstore/GeneralStorage/fangming/projects/dredfish/data/rna/scrna_ss_ctxhippo_a_exon_count_matrix_v4_train.zarr and 
/bigstore/GeneralStorage/fangming/projects/dredfish/data/rna/scrna_ss_ctxhippo_a_exon_count_matrix_v4_test.zarr

73347, 66017 (0.900), 7330 (0.100)


# test if we can load train and val as well as the original

In [17]:
f_org = '/bigstore/GeneralStorage/fangming/projects/dredfish/data/rna/scrna_ss_ctxhippo_a_exon_count_matrix_v4.zarr'
f_trn = '/bigstore/GeneralStorage/fangming/projects/dredfish/data/rna/scrna_ss_ctxhippo_a_exon_count_matrix_v4_train.zarr'
f_tst = '/bigstore/GeneralStorage/fangming/projects/dredfish/data/rna/scrna_ss_ctxhippo_a_exon_count_matrix_v4_test.zarr'


In [18]:
z_trn = zarr.open(f_trn, mode='r')
z_tst = zarr.open(f_tst, mode='r')
z_trn, list(z_trn.keys()), z_tst, list(z_tst.keys())

(<zarr.hierarchy.Group '/' read-only>,
 ['counts',
  'genes',
  'l1_cat',
  'l1_code',
  'l2_cat',
  'l2_code',
  'l3_cat',
  'l3_code',
  'l5_cat',
  'l5_code'],
 <zarr.hierarchy.Group '/' read-only>,
 ['counts',
  'genes',
  'l1_cat',
  'l1_code',
  'l2_cat',
  'l2_code',
  'l3_cat',
  'l3_code',
  'l5_cat',
  'l5_code'])

In [19]:
z_trn['genes'][:]

array(['0610005C13Rik', '0610006L08Rik', '0610007P14Rik', ..., 'n-R5s144',
       'n-R5s146', 'n-R5s149'], dtype='<U28')

In [20]:
%%time
y = z_trn['counts'].oindex[975,:]
y, y.shape

CPU times: user 0 ns, sys: 9.34 ms, total: 9.34 ms
Wall time: 5.53 ms


(array([  5.,   0., 668., ...,   0.,   0.,   0.]), (45768,))

In [21]:
%%time
y = z_tst['counts'].oindex[975,:]
y, y.shape

CPU times: user 3.64 ms, sys: 2.04 ms, total: 5.68 ms
Wall time: 4.12 ms


(array([  0.,   0., 144., ...,   0.,   0.,   0.]), (45768,))

# check

In [22]:
z_org = zarr.open(f_org, mode='r')
z_trn = zarr.open(f_trn, mode='r')
z_tst = zarr.open(f_tst, mode='r')
z_trn, list(z_trn.keys()), z_tst, list(z_tst.keys())

(<zarr.hierarchy.Group '/' read-only>,
 ['counts',
  'genes',
  'l1_cat',
  'l1_code',
  'l2_cat',
  'l2_code',
  'l3_cat',
  'l3_code',
  'l5_cat',
  'l5_code'],
 <zarr.hierarchy.Group '/' read-only>,
 ['counts',
  'genes',
  'l1_cat',
  'l1_code',
  'l2_cat',
  'l2_code',
  'l3_cat',
  'l3_code',
  'l5_cat',
  'l5_code'])