In [1]:
# Define paths
# DATA_PATH is the raw dataset root that the splits are read from;
# SAVE_PATH is the output root for the fitted normalization and normalized splits.
# The defaults are the original cluster locations; set the environment variables
# MERLIN_DATA_PATH / MERLIN_SAVE_PATH to run elsewhere without editing the notebook.
import os

DATA_PATH = os.environ.get('MERLIN_DATA_PATH', '/lustre/scratch/users/felix.fischer/merlin_cxg_simple')
SAVE_PATH = os.environ.get('MERLIN_SAVE_PATH', '/lustre/scratch/users/felix.fischer/merlin_cxg_simple_norm')

# Read train data

In [2]:
import os

import numpy as np
import seaborn as sns
import dask.array as da
import pandas as pd

from scipy.sparse import csc_matrix, csr_matrix, issparse
from os.path import join

import dask
# Use dask's threaded scheduler for the fitting part of the notebook
# (trailing ';' suppresses the cell output).
dask.config.set({'scheduler': 'threads'});

In [3]:
train_dir = join(DATA_PATH, 'train')

# Open the training matrix (zarr component 'X') as a lazy dask array.
X = da.from_zarr(join(train_dir, 'zarr'), component='X')

# Load the per-cell metadata and reset its index to 0..n-1; the index values
# are later used as row positions into X when drawing the subsample.
obs = pd.read_parquet(join(train_dir, 'obs.parquet'))
obs = obs.reset_index(drop=True)

In [4]:
# Display the dask array summary (total size, shape, chunking, dtype).
X

Unnamed: 0,Array,Chunk
Bytes,478.27 GiB,75.61 MiB
Shape,"(6632427, 19357)","(1024, 19357)"
Count,6478 Tasks,6477 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 478.27 GiB 75.61 MiB Shape (6632427, 19357) (1024, 19357) Count 6478 Tasks 6477 Chunks Type float32 numpy.ndarray",19357  6632427,

Unnamed: 0,Array,Chunk
Bytes,478.27 GiB,75.61 MiB
Shape,"(6632427, 19357)","(1024, 19357)"
Count,6478 Tasks,6477 Chunks
Type,float32,numpy.ndarray


In [5]:
# Preview the first rows of the per-cell metadata table.
obs.head()

Unnamed: 0,id,assay_sc,tech_sample,cell_type,cell_type_ontology_term_id,disease,development_stage,organ
0,1ba3d0cb-18e2-470c-92be-70ad3fa820fb,10x 5' v2,1ba3d0cb-18e2-470c-92be-70ad3fa820fb10x 5' v2_...,"CD8-positive, alpha-beta T cell",CL:0000625,COVID-19,49-year-old human stage,blood
1,ea762ec0-f8f2-4f16-ba53-6bf7464b879b,10x 3' v3,ea762ec0-f8f2-4f16-ba53-6bf7464b879b10x 3' v3_...,kidney connecting tubule epithelial cell,CL:1000768,normal,seventh decade human stage,cortex of kidney
2,1ba3d0cb-18e2-470c-92be-70ad3fa820fb,10x 5' v2,1ba3d0cb-18e2-470c-92be-70ad3fa820fb10x 5' v2_...,"CD8-positive, alpha-beta T cell",CL:0000625,COVID-19,35-year-old human stage,blood
3,48f8ad54-091a-41be-ac40-1ad6f13e6ca3,10x 3' v3,48f8ad54-091a-41be-ac40-1ad6f13e6ca310x 3' v3_...,monocyte,CL:0000576,malignant ovarian serous tumor,57-year-old human stage,right ovary
4,c0660294-0bef-4274-9525-6cc901975ea3,10x 3' v3,c0660294-0bef-4274-9525-6cc901975ea310x 3' v3_...,L2/3-6 intratelencephalic projecting glutamate...,CL:4023040,dementia,human adult stage,middle temporal gyrus


# Fit normalization + save model

In [5]:
# Record the scikit-learn version used for fitting: the fitted pipeline is
# pickled further down and should be unpickled with a matching version.
import sklearn
sklearn.__version__

'1.2.1'

In [4]:
import numpy as np

from sklearn.preprocessing import Normalizer, QuantileTransformer, StandardScaler, FunctionTransformer
from sklearn.pipeline import Pipeline


from scipy.sparse import csc_matrix, csr_matrix, issparse
from sklearn.utils import sparsefuncs


def sf_normalize(X):
    """Rescale every row of ``X`` so that its entries sum to 10,000.

    Rows whose sum is zero are left as all zeros instead of triggering a
    division by zero. The input is never modified; a scaled copy is returned.

    Parameters
    ----------
    X : numpy.ndarray or scipy.sparse matrix, 2-D
        Matrix to normalize, one observation per row.

    Returns
    -------
    Same container type as ``X``
        Floating-point copy of ``X`` with scaled rows. Floating-point inputs
        keep their dtype; any other dtype (e.g. integer counts) is promoted to
        float64, because the scaling is applied in place on the copy and the
        scaled values cannot be stored in an integer array.
    """
    if np.issubdtype(X.dtype, np.floating):
        X = X.copy()
    else:
        # astype already returns a new object, so no additional copy is needed
        X = X.astype(np.float64)

    # sparse row sums come back as an (n, 1) matrix -> flatten to 1-D
    counts = np.asarray(X.sum(axis=1)).ravel()
    # avoid zero division error
    counts[counts == 0.] = 1.
    # normalize to 10000. counts
    scaling_factor = 10000. / counts

    if issparse(X):
        sparsefuncs.inplace_row_scale(X, scaling_factor)
    else:
        np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)

    return X


In [9]:
WEIGHTED_SAMPLING = True
SAMPLE_SIZE = 1_000_000


# fixed seed so that the drawn subsample is reproducible
rng = np.random.default_rng(seed=1)


if WEIGHTED_SAMPLING:
    # Weight each cell by 1 / (number of cells sharing its tech_sample), so every
    # tech_sample group carries the same total sampling probability.
    # Series.map with the value_counts Series is a single vectorized lookup; the
    # previous .replace(<dict of all counts>) yields the same values but is far
    # slower on a large frame.
    obs['tech_sample_occurances'] = obs.tech_sample.map(obs.tech_sample.value_counts()).astype(float)
    obs['sampling_prob'] = 1. / obs.tech_sample_occurances
    # normalize so that the probabilities sum to 1, as required by rng.choice
    obs['sampling_prob'] = obs.sampling_prob / obs.sampling_prob.sum()
    idx_subsample = rng.choice(obs.index.to_numpy(), size=SAMPLE_SIZE, replace=False, p=obs.sampling_prob.to_numpy())
else:
    # uniform sampling without replacement
    idx_subsample = rng.choice(obs.index.to_numpy(), size=SAMPLE_SIZE, replace=False)


idx_subsample

array([3393313, 6300848,  959415, ..., 3812005, 1585100, 3691064])

In [10]:
# Sanity check: the number of distinct sampled indices should equal SAMPLE_SIZE,
# since sampling is done without replacement.
len(np.unique(idx_subsample))

1000000

In [11]:
# Select the sampled rows in ascending order, convert each dask block to a
# sparse matrix, load the result into memory and make sure it ends up as CSC.
sorted_rows = np.sort(idx_subsample)
X_sparse = (
    X[sorted_rows, :]
    .map_blocks(csc_matrix)
    .compute()
    .tocsc()
)

In [12]:
N_QUANTILES = 1000


# Step 1: rescale every row with sf_normalize.
row_scaler = FunctionTransformer(sf_normalize)

# Step 2: map every feature onto a uniform distribution using N_QUANTILES
# quantiles; `subsample` is set to the number of sampled rows.
quantile_mapper = QuantileTransformer(
    n_quantiles=N_QUANTILES,
    output_distribution='uniform',
    ignore_implicit_zeros=True,
    subsample=len(idx_subsample),
)

preproc_pipeline = Pipeline([
    ('sf_normalizer', row_scaler),
    ('quantile_transformer', quantile_mapper),
])

# Fit both steps on the sparse subsample and keep the transformed matrix.
X_normed = preproc_pipeline.fit_transform(X_sparse)

In [13]:
# Mean of each column (feature) of the normalized subsample, stored as a plain
# ndarray (presumably of shape (1, n_features) for a sparse X_normed - confirm).
column_means = X_normed.mean(axis=0)
feature_means = np.array(column_means)

In [14]:
from os.path import join
import pickle

save_path_norm = join(SAVE_PATH, 'norm')
quantile_dir = join(save_path_norm, 'quantile_transform')
centering_dir = join(save_path_norm, 'zero_centering')
pipeline_dir = join(save_path_norm, 'preproc_pipeline')

# exist_ok=True on every directory so the cell can be re-run without raising
# FileExistsError; makedirs also creates save_path_norm as the common parent.
for out_dir in (quantile_dir, centering_dir, pipeline_dir):
    os.makedirs(out_dir, exist_ok=True)

# save fit data
# look the step up by name instead of by position in the pipeline
fitted_qt = preproc_pipeline.named_steps['quantile_transformer']
np.save(join(quantile_dir, 'quantiles.npy'), fitted_qt.quantiles_)
np.save(join(quantile_dir, 'references.npy'), fitted_qt.references_)
np.save(join(centering_dir, 'means.npy'), feature_means)

# save preproc pipeline
# the context manager guarantees the pickle is flushed and the handle closed
with open(join(pipeline_dir, 'preproc_pipeline.pickle'), 'wb') as pipeline_file:
    pickle.dump(preproc_pipeline, pipeline_file)

# Inference + save preprocessed data to disk

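Restart the notebook kernel before running this part so that dask can use the `processes` scheduler. After the restart, re-run the cells that define `DATA_PATH`, `SAVE_PATH` and `sf_normalize`: the pickled pipeline refers to `sf_normalize` by name and cannot be loaded without it.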

In [2]:
from os.path import join
import shutil
from tqdm import tqdm
import os
import pickle

import numpy as np
import dask.array as da

In [6]:
# Load the fitted preprocessing pipeline written by the fitting section.
# NOTE: the pipeline wraps `sf_normalize` in a FunctionTransformer and pickle
# stores functions by reference, so `sf_normalize` must be defined in this
# kernel before this cell runs.
# NOTE: pickle.load can execute arbitrary code - only load files you created yourself.
with open(join(SAVE_PATH, 'norm', 'preproc_pipeline', 'preproc_pipeline.pickle'), 'rb') as pipeline_file:
    preproc_pipeline = pickle.load(pipeline_file)

In [7]:
import dask
# Switch dask to the multiprocessing scheduler with 12 worker processes.
dask.config.set({'scheduler': 'processes', 'num_workers': 12})


def preprocess_gene_matrix(x):
    """Run one block through the fitted ``preproc_pipeline`` and cast it to float32.

    Used as the per-block function for ``map_blocks``; ``preproc_pipeline`` is
    read from the notebook's global namespace.
    """
    return preproc_pipeline.transform(x).astype(np.float32)


# For every split: copy the metadata tables unchanged and write the normalized
# matrix to a new zarr store under SAVE_PATH.
for split in ('train', 'val', 'test'):
    src_dir = join(DATA_PATH, split)
    dst_dir = join(SAVE_PATH, split)
    # no exist_ok here: raises FileExistsError if the split directory already exists
    os.makedirs(dst_dir)

    for table in ('obs.parquet', 'var.parquet'):
        shutil.copy(join(src_dir, table), join(dst_dir, table))

    # lazily apply the pipeline block by block; the work runs inside to_zarr
    X_raw = da.from_zarr(join(src_dir, 'zarr'), component='X')
    X_preprocessed = X_raw.map_blocks(preprocess_gene_matrix, dtype='f4')
    da.to_zarr(
        X_preprocessed,
        join(dst_dir, 'zarr'),
        component='X',
        compute=True,
        compressor='default',
        order='C'
    )
