In [1]:
# Define paths
# DATA_PATH: root directory of the raw data; the cells below read the
# per-split sub-directories ('train', 'val', 'test') from it.
DATA_PATH = '/mnt/dssfs02/cxg_census/data'
# SAVE_PATH: root directory for all outputs; the fitted normalization
# artifacts ('norm') and the normalized copy of each split are written here.
SAVE_PATH = '/mnt/dssfs02/cxg_census/data_normed'

# Read train data

In [2]:
import os

import numpy as np
import seaborn as sns
import dask.array as da
import pandas as pd

from scipy.sparse import csc_matrix, csr_matrix, issparse
from os.path import join

import dask
# Use dask's threaded scheduler for the loading / fitting part of the notebook.
# The trailing ';' suppresses the cell output.
dask.config.set(scheduler='threads');

In [3]:
# Open the 'X' component of the training zarr store as a dask array.
X = da.from_zarr(join(DATA_PATH, 'train', 'zarr'), component='X')
# Load the per-cell metadata and replace its index with a fresh 0..n-1 range.
obs = (
    pd.read_parquet(join(DATA_PATH, 'train', 'obs.parquet'))
    .reset_index(drop=True)
)

In [4]:
# Display the dask array summary (shape, chunking, dtype).
X

Unnamed: 0,Array,Chunk
Bytes,1.04 TiB,1.18 GiB
Shape,"(14843199, 19331)","(16384, 19331)"
Dask graph,906 chunks in 2 graph layers,906 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.04 TiB 1.18 GiB Shape (14843199, 19331) (16384, 19331) Dask graph 906 chunks in 2 graph layers Data type float32 numpy.ndarray",19331  14843199,

Unnamed: 0,Array,Chunk
Bytes,1.04 TiB,1.18 GiB
Shape,"(14843199, 19331)","(16384, 19331)"
Dask graph,906 chunks in 2 graph layers,906 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [5]:
# Preview the first rows of the cell metadata table.
obs.head()

Unnamed: 0,soma_joinid,is_primary_data,dataset_id,donor_id,assay,cell_type,development_stage,disease,tissue,tissue_general,tech_sample
0,2597175,True,0ba636a1-4754-4786-a8be-7ab3cf760fd6,HBCA_Donor_18,10x 3' v3,myoepithelial cell of mammary gland,29-year-old human stage,normal,breast,breast,0ba636a1-4754-4786-a8be-7ab3cf760fd6_HBCA_Dono...
1,27253448,True,9ea768a2-87ab-46b6-a73d-c4e915f25af3,TxK4,10x 3' v2,epithelial cell of proximal tubule,72-year-old human stage,normal,cortex of kidney,kidney,9ea768a2-87ab-46b6-a73d-c4e915f25af3_TxK4
2,34231840,True,32b9bdce-2481-4c85-ba1b-6ad5fcea844c,32-10074,10x 3' v3,epithelial cell of proximal tubule,eighth decade human stage,acute kidney failure,kidney,kidney,32b9bdce-2481-4c85-ba1b-6ad5fcea844c_32-10074
3,26796846,True,ed5d841d-6346-47d4-ab2f-7119ad7e3a35,P1,10x 3' v3,"central memory CD8-positive, alpha-beta T cell",unknown,normal,blood,blood,ed5d841d-6346-47d4-ab2f-7119ad7e3a35_P1
4,3378141,True,c2876b1b-06d8-4d96-a56b-5304f815b99a,H21.33.012,10x 3' v3,L2/3-6 intratelencephalic projecting glutamate...,80 year-old and over human stage,dementia,middle temporal gyrus,brain,c2876b1b-06d8-4d96-a56b-5304f815b99a_H21.33.012


# Fit normalization + save model

In [2]:
import sklearn
# Display the installed scikit-learn version; the pipeline fitted below is
# pickled, so this records which version produced it.
sklearn.__version__

'1.2.2'

In [3]:
import numpy as np

from sklearn.preprocessing import Normalizer, QuantileTransformer, StandardScaler, FunctionTransformer
from sklearn.pipeline import Pipeline


from scipy.sparse import csc_matrix, csr_matrix, issparse
from sklearn.utils import sparsefuncs


def sf_normalize(X):
    """Scale every row of ``X`` so that it sums to 10000 (size-factor normalization).

    Parameters
    ----------
    X : numpy.ndarray or scipy sparse matrix (CSR/CSC), shape (n_rows, n_cols)
        Count matrix with one observation per row. The input is never
        modified; a copy is scaled and returned.

    Returns
    -------
    Same container type as ``X``
        Row-scaled copy of ``X``. Floating-point inputs keep their dtype;
        any other dtype (e.g. integer counts) is promoted to float64, because
        the in-place scaling below cannot write float results into an
        integer buffer.

    Notes
    -----
    Rows whose sum is zero are left unchanged (their divisor is set to 1).
    """
    # Work on a floating-point copy: ``astype`` copies by default, and
    # ``copy`` preserves the dtype of inputs that are already floating point.
    if np.issubdtype(X.dtype, np.floating):
        X = X.copy()
    else:
        X = X.astype(np.float64)

    # Row sums as an (n_rows, 1) ndarray; sparse ``sum`` returns np.matrix.
    counts = np.asarray(X.sum(axis=1)).reshape(-1, 1)
    # avoid zero division error
    counts[counts == 0] = 1
    # normalize to 10000. counts
    scaling_factor = 10000. / counts

    if issparse(X):
        sparsefuncs.inplace_row_scale(X, scaling_factor.ravel())
    else:
        np.multiply(X, scaling_factor, out=X)

    return X


In [8]:
# Draw SAMPLE_SIZE distinct row indices of `obs` for fitting the normalization.
# With WEIGHTED_SAMPLING each cell is weighted by 1 / (number of cells sharing
# its `tech_sample`), so every tech_sample group carries the same total
# probability mass regardless of how many cells it contains.
WEIGHTED_SAMPLING = True
SAMPLE_SIZE = 1_000_000


rng = np.random.default_rng(seed=1)


if WEIGHTED_SAMPLING:
    # `map` looks each value up in the counts Series; it yields the same values
    # as `replace` with a dict but avoids the very slow per-key replacement.
    tech_sample_counts = obs.tech_sample.value_counts()
    obs['tech_sample_occurances'] = obs.tech_sample.map(tech_sample_counts).astype(float)
    obs['sampling_prob'] = 1. / obs.tech_sample_occurances
    # Normalize so that the probabilities sum to 1, as `rng.choice` requires.
    obs['sampling_prob'] = obs.sampling_prob / obs.sampling_prob.sum()
    sampling_prob = obs.sampling_prob.to_numpy()
else:
    # p=None makes `rng.choice` sample uniformly.
    sampling_prob = None

idx_subsample = rng.choice(obs.index.to_numpy(), size=SAMPLE_SIZE, replace=False, p=sampling_prob)


idx_subsample

array([ 7614642, 14104737,  2141988, ...,  8085558,   184839, 10901594])

In [9]:
# Sanity check: the indices were drawn with replace=False, so the number of
# unique indices should equal SAMPLE_SIZE.
len(np.unique(idx_subsample))

1000000

In [10]:
# Materialize the sampled rows as one in-memory CSC matrix:
#   1. select the sampled rows in ascending index order,
#   2. convert each dask chunk to a scipy CSC matrix,
#   3. compute the result and make sure it is in CSC format.
X_sparse = (
    X[np.sort(idx_subsample), :]
    .map_blocks(csc_matrix)
    .compute()
    .tocsc()
)

In [11]:
N_QUANTILES = 1000


# Two-step preprocessing pipeline:
#   1. 'sf_normalizer'        - scale every row to a total of 10000 counts.
#   2. 'quantile_transformer' - per-feature quantile transform to a uniform
#                               distribution. `subsample` is set to the number
#                               of sampled rows.
preproc_pipeline = Pipeline(steps=[
    ('sf_normalizer', FunctionTransformer(func=sf_normalize)),
    ('quantile_transformer', QuantileTransformer(
        n_quantiles=N_QUANTILES,
        output_distribution='uniform',
        ignore_implicit_zeros=True,
        subsample=len(idx_subsample),
    )),
])
# Fit on the sampled cells and keep the transformed sample.
X_normed = preproc_pipeline.fit_transform(X_sparse)

In [12]:
# Per-feature (column-wise) mean of the normalized sample, saved below as the
# 'zero_centering' means.
# NOTE(review): presumably X_normed is a scipy sparse matrix here, in which
# case the result has shape (1, n_features) - confirm before relying on shape.
feature_means = np.array(X_normed.mean(axis=0))

In [13]:
from os.path import join
import pickle

# Persist everything needed to re-apply the normalization:
#   <SAVE_PATH>/norm/quantile_transform/{quantiles,references}.npy
#   <SAVE_PATH>/norm/zero_centering/means.npy
#   <SAVE_PATH>/norm/preproc_pipeline/preproc_pipeline.pickle
save_path_norm = join(SAVE_PATH, 'norm')
quantile_dir = join(save_path_norm, 'quantile_transform')
zero_centering_dir = join(save_path_norm, 'zero_centering')
pipeline_dir = join(save_path_norm, 'preproc_pipeline')

# exist_ok=True everywhere so the cell can be re-run without raising
# FileExistsError; makedirs also creates `save_path_norm` itself.
for directory in (quantile_dir, zero_centering_dir, pipeline_dir):
    os.makedirs(directory, exist_ok=True)

# save fit data
quantile_transformer = preproc_pipeline.named_steps['quantile_transformer']
np.save(join(quantile_dir, 'quantiles.npy'), quantile_transformer.quantiles_)
np.save(join(quantile_dir, 'references.npy'), quantile_transformer.references_)
np.save(join(zero_centering_dir, 'means.npy'), feature_means)

# save preproc pipeline (context manager guarantees the file is flushed and closed)
with open(join(pipeline_dir, 'preproc_pipeline.pickle'), 'wb') as f:
    pickle.dump(preproc_pipeline, f)

# Inference + save preprocessed data to disk

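Restart the notebook kernel before running this part so that dask uses the processes scheduler. After restarting, re-run the cells that define the paths and `sf_normalize`: the pickled pipeline refers to `sf_normalize` by name, so it must be defined before the pipeline is loaded.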

In [4]:
from os.path import join
import shutil
from tqdm import tqdm
import os
import pickle

import numpy as np
import dask.array as da

In [5]:
# Load the fitted preprocessing pipeline saved by the fitting section.
# NOTE: the pipeline wraps `sf_normalize`, which pickle stores by reference,
# so that function must already be defined in this session. Only unpickle
# files you created yourself - pickle can execute arbitrary code on load.
with open(join(SAVE_PATH, 'norm', 'preproc_pipeline', 'preproc_pipeline.pickle'), 'rb') as f:
    preproc_pipeline = pickle.load(f)

In [6]:
import dask
# Switch dask to the process-based scheduler with 20 workers for the
# chunk-wise transform below.
dask.config.set(scheduler='processes', num_workers=20)


def preprocess_gene_matrix(x):
    """Apply the fitted ``preproc_pipeline`` to one block and cast it to float32.

    Parameters
    ----------
    x : array-like
        One block of the gene-expression matrix, as handed over by
        ``map_blocks``.

    Returns
    -------
    The transformed block with dtype ``np.float32``.
    """
    return preproc_pipeline.transform(x).astype(np.float32)


# For every split: copy the metadata tables unchanged and write a normalized
# copy of the 'X' matrix to a new zarr store under SAVE_PATH.
for split in ('train', 'val', 'test'):

    path_raw = join(DATA_PATH, split)
    save_path = join(SAVE_PATH, split)
    # Raises if the output directory already exists.
    os.makedirs(save_path)

    # Metadata is not affected by the normalization, so it is copied as-is.
    for file in ('obs.parquet', 'var.parquet'):
        shutil.copy(join(path_raw, file), join(save_path, file))

    # Build the transform graph block by block, then write it out.
    x_raw = da.from_zarr(join(path_raw, 'zarr'), component='X')
    x_normed = x_raw.map_blocks(preprocess_gene_matrix, dtype='f4')
    x_normed.to_zarr(
        join(save_path, 'zarr'),
        component='X',
        compute=True,
        compressor='default',
        order='C',
    )
