In [1]:
import os

# Root of the input store that the rest of the notebook reads from, and the
# root of the parquet store that it writes to.
DATA_PATH = '/lustre/scratch/users/felix.fischer/merlin_cxg_simple_norm'
OUT_PATH = '/lustre/scratch/users/felix.fischer/merlin_cxg_simple_norm_parquet'

# exist_ok=True keeps this cell re-runnable (Restart Kernel -> Run All) instead
# of raising FileExistsError once the output directory has been created.
os.makedirs(OUT_PATH, exist_ok=True)

In [2]:
import dask
import dask.array as da
import dask.dataframe as dd
import pandas as pd
import numpy as np
import pyarrow as pa

from os.path import join
import shutil

In [3]:
from multiprocessing.pool import ThreadPool
import dask

# Run dask graphs on a pool of 4 worker processes for the rest of the notebook.
dask.config.set({'scheduler': 'processes', 'num_workers': 4});

# Copy var dataframe + norm data

In [4]:
# The gene (var) table is taken over unchanged from the train split.
var_src = join(DATA_PATH, 'train', 'var.parquet')
var_dst = join(OUT_PATH, 'var.parquet')
shutil.copy(var_src, var_dst);

In [5]:
# Copy the 'norm' directory with the standard library instead of shelling out:
# no unquoted path interpolation into a shell command, and re-running the cell
# refreshes OUT_PATH/norm in place (dirs_exist_ok) rather than nesting a second
# copy at OUT_PATH/norm/norm, which is what `cp -r` does once the target exists.
# symlinks=True copies symbolic links as links rather than following them.
shutil.copytree(
    join(DATA_PATH, 'norm'),
    join(OUT_PATH, 'norm'),
    symlinks=True,
    dirs_exist_ok=True,
);

# Create lookup tables for categorical variables

In [14]:
from pandas import testing as tm

In [7]:
# Load the per-cell metadata (obs) of every split, each with a fresh 0..n-1 index.
obs_train, obs_val, obs_test = (
    pd.read_parquet(join(DATA_PATH, split_name, 'obs.parquet')).reset_index(drop=True)
    for split_name in ('train', 'val', 'test')
)

# Stacked view over all splits; used to derive one shared set of categories.
obs = pd.concat([obs_train, obs_val, obs_test])

In [8]:
cols_train = obs_train.columns.tolist()
# Every split must expose exactly the same columns, in the same order.
for other_obs in (obs_val, obs_test):
    assert cols_train == other_obs.columns.tolist()

In [9]:
# Give every split the same category list per column (the categories that are
# actually used anywhere in train/val/test), so that the integer codes derived
# from the categoricals mean the same thing in all three splits.
for col in cols_train:
    obs[col] = obs[col].cat.remove_unused_categories()
    categories = list(obs[col].cat.categories)
    for split_obs in (obs_train, obs_val, obs_test):
        split_obs[col] = pd.Categorical(split_obs[col], categories, ordered=False)

In [10]:
# One lookup table per obs column, mapping integer code -> category label.
lookup_path = join(OUT_PATH, 'categorical_lookup')
# exist_ok=True so that re-running the cell does not fail on the existing folder.
os.makedirs(lookup_path, exist_ok=True)

for col in cols_train:
    cats_train = pd.Series(dict(enumerate(obs_train[col].cat.categories))).to_frame().rename(columns={0: 'label'})
    cats_val = pd.Series(dict(enumerate(obs_val[col].cat.categories))).to_frame().rename(columns={0: 'label'})
    cats_test = pd.Series(dict(enumerate(obs_test[col].cat.categories))).to_frame().rename(columns={0: 'label'})

    # The code -> label mapping has to be identical in all three splits.
    # (The second check previously compared train with val again, so the test
    # split was never verified.)
    tm.assert_frame_equal(cats_train, cats_val)
    tm.assert_frame_equal(cats_train, cats_test)

    # The index holds the integer code, the 'label' column the category name.
    cats_train.to_parquet(join(lookup_path, f'{col}.parquet'), index=True)


In [11]:
# From here on each obs column holds the int64 category code instead of the
# categorical value. Not idempotent: the columns lose their .cat accessor, so
# the cell cannot be run a second time on the same frames.
for split_obs in (obs_train, obs_val, obs_test):
    for col in cols_train:
        split_obs[col] = split_obs[col].cat.codes.astype('i8')

In [12]:
# Split name -> integer-encoded obs frame of that split.
obs_dict = dict(train=obs_train, val=obs_val, test=obs_test)

In [13]:
from sklearn.utils.class_weight import compute_class_weight

# Balanced class weights for the cell_type labels, computed on the training
# split only and stored next to the parquet store.
# NOTE(review): `classes` contains only the codes that occur in train, so
# class_weights[i] lines up with label code i only if every cell_type category
# is present in the training split - confirm.
cell_type_train = obs_train['cell_type']
class_weights = compute_class_weight(
    'balanced',
    classes=np.unique(cell_type_train),
    y=cell_type_train,
)

np.save(join(OUT_PATH, 'class_weights.npy'), class_weights)

# Write store

In [15]:
# Row-chunk size X is rechunked to before it is turned into a dask dataframe
# (32768 = 32 * ROW_GROUP_SIZE).
CHUNK_SIZE = 32768
# Row group size passed to the parquet writer; each split is also truncated to
# a whole multiple of this value.
ROW_GROUP_SIZE = 1024


# Convert every split into a parquet dataset written to OUT_PATH/<split>.
for split in ['train', 'val', 'test']:
    # Expression matrix of the split: component 'X' of the split's zarr store.
    X = da.from_zarr(join(DATA_PATH, split, 'zarr'), 'X')
    obs_ = obs_dict[split]
    
    # cut off the trailing samples so that the number of samples is a whole
    # multiple of ROW_GROUP_SIZE
    n_samples = X.shape[0]
    n_samples = (n_samples // ROW_GROUP_SIZE) * ROW_GROUP_SIZE
    X = X[:n_samples]
    # .copy() so that adding the 'idx' column below does not touch obs_dict[split]
    obs_ = obs_.iloc[:n_samples].copy()
    print(f'{split}: {X.shape[0]} cells')
    # add an index column to identify each sample
    obs_['idx'] = np.arange(len(obs_), dtype='i8')

    # Pack each row of X into a single float32 array, giving a dataframe with
    # one object column 'X' (one array per sample).
    ddf_X = (
        X
        .rechunk((CHUNK_SIZE, -1))
        .to_dask_dataframe()
        .map_partitions(
            lambda df: df.apply(
                lambda row: np.array(row.tolist()).astype('f4'), axis=1
            ).to_frame().rename(columns={0: 'X'}),
            # declares the output of the lambda to dask: one object column 'X'
            meta={'X': 'object'}
        )
    )
    # Put the obs columns (integer codes + 'idx') next to the 'X' column.
    ddf = dd.multi.concat([ddf_X, obs_], axis=1)

    # Explicit arrow schema: 'X' is a list of float32, every other column int64.
    # NOTE(review): the column names are assumed to match the columns of
    # obs.parquet plus 'idx' - verify if the obs schema changes.
    schema = pa.schema([
        ('X', pa.list_(pa.float32())),
        ('id', pa.int64()),  
        ('assay_sc', pa.int64()),
        ('tech_sample', pa.int64()),
        ('cell_type', pa.int64()),
        ('cell_type_ontology_term_id', pa.int64()),
        ('disease', pa.int64()),
        ('development_stage', pa.int64()),
        ('organ', pa.int64()),
        ('idx', pa.int64()),
    ])
    ddf.to_parquet(
        join(OUT_PATH, split), 
        engine='pyarrow',
        schema=schema,
        write_metadata_file=True,
        row_group_size=ROW_GROUP_SIZE
    )


train: 6631424 cells
val: 1533952 cells
test: 1403904 cells
