In [None]:
!pip install zarr
!pip install scipy

In [None]:
import os

import dask
import dask.array as da
import dask.dataframe as dd
import pandas as pd
import numpy as np
import pyarrow as pa

from os.path import join
import shutil

In [None]:
from dask.distributed import Client, LocalCluster


# Start a local Dask cluster with 5 workers and connect a client to it.
# Only the worker count is set here; the number of threads per worker is left to
# LocalCluster's default.
cluster = LocalCluster(n_workers=5)  # sizing note: assumes 20 cores on LRZ -> 5 workers with 4 threads each
client = Client(cluster)
# bare expression: shows the client summary as the cell output
client

In [None]:
# Normalization applied to the count matrix before it is written out:
#   'sf-log1p' -> normalize every cell to 10000 counts, then log1p transform
#   'raw'      -> leave the data as it is
NORMALIZATION = 'sf-log1p'

assert NORMALIZATION in ('sf-log1p', 'raw')

In [None]:
from scipy.sparse import csc_matrix, csr_matrix, issparse
from sklearn.utils import sparsefuncs


def sf_normalize(X):
    """Scale every row of a count matrix so that it sums to 10000.

    Parameters
    ----------
    X : numpy.ndarray or scipy.sparse matrix
        2-D matrix; each row is scaled independently.

    Returns
    -------
    Same container type as ``X``
        A scaled copy; the input itself is never modified. Integer and boolean
        inputs are converted to float32, all other dtypes are kept. Rows whose
        sum is zero are multiplied by 10000 (so all-zero rows stay all-zero)
        instead of being divided by zero.
    """
    if X.dtype.kind in 'biu':
        # The scaling below is done in place, and an integer buffer cannot hold
        # the float result (NumPy refuses that cast). ``astype`` already returns
        # a new object, so this also serves as the defensive copy.
        X = X.astype('f4')
    else:
        X = X.copy()

    # Sparse ``sum`` returns an (n, 1) matrix, dense ``sum`` an (n,) array;
    # flatten both to one entry per row.
    counts = np.asarray(X.sum(axis=1)).ravel()
    # avoid zero division error
    counts[counts == 0] = 1
    # normalize to 10000. counts
    scaling_factor = 10000. / counts

    if issparse(X):
        sparsefuncs.inplace_row_scale(X, scaling_factor)
    else:
        np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)

    return X


def sf_log1p_norm(x):
    """Normalize ``x`` with ``sf_normalize``, apply ``log1p`` and cast to float32."""
    normalized = sf_normalize(x)
    logged = np.log1p(normalized)
    return logged.astype('f4')


def preprocess_count_matrix(x, normalization):
    """Apply the requested normalization to a count matrix.

    Parameters
    ----------
    x
        Matrix to preprocess. For 'sf-log1p' it must provide ``map_blocks``
        (``sf_log1p_norm`` is applied block-wise with dtype 'f4').
    normalization : str
        Either 'sf-log1p' or 'raw'.

    Returns
    -------
    The block-wise normalized matrix for 'sf-log1p', ``x`` itself for 'raw'.

    Raises
    ------
    ValueError
        If ``normalization`` is neither 'sf-log1p' nor 'raw'.
    """
    if normalization == 'raw':
        return x
    if normalization == 'sf-log1p':
        return x.map_blocks(sf_log1p_norm, dtype='f4')
    raise ValueError('NORMALIZATION has to be in ["sf-log1p", "raw"]')


@dask.delayed
def convert_to_dataframe(x, start, end):
    """Turn a block of rows into a one-column pandas DataFrame.

    ``x`` is split along its first axis into single-row pieces; every piece is
    squeezed, cast to float32 and stored as one cell of column 'X'. The frame
    is indexed by ``pd.RangeIndex(start, end)``.
    Assumes ``x`` is a 2-D array with ``end - start`` rows -- TODO confirm
    against the caller.
    """
    single_rows = np.vsplit(x, x.shape[0])
    cells = []
    for row in single_rows:
        cells.append(row.squeeze().astype('f4'))
    row_index = pd.RangeIndex(start, end)
    return pd.DataFrame({'X': cells}, index=row_index)


# Training data

In [None]:
# DATA_PATH: root folder the inputs are read from.
# OUT_PATH: output folder; its name carries the selected NORMALIZATION.
DATA_PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15'
OUT_PATH = f'/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_{NORMALIZATION}'

# Called without exist_ok: os.makedirs raises FileExistsError when OUT_PATH is
# already there, so re-running this cell against an existing output folder stops here.
os.makedirs(OUT_PATH)

## Copy var dataframe + norm data

In [None]:
# Take over the var dataframe of the training split unchanged.
var_src = join(DATA_PATH, 'train', 'var.parquet')
var_dst = join(OUT_PATH, 'var.parquet')
shutil.copy(var_src, var_dst);

In [None]:
# Recursively copies 'norm' from DATA_PATH to OUT_PATH through the shell.
# Original note: only run if NORMALIZATION == 'sf-quantile'
# NOTE(review): 'sf-quantile' is not one of the values accepted by the NORMALIZATION
# assert ('sf-log1p', 'raw') and nothing here guards the copy -- confirm whether
# this cell is still needed.
!cp -r {join(DATA_PATH, 'norm')} {join(OUT_PATH, 'norm')}

## Create lookup tables for categorical variables

In [None]:
from pandas import testing as tm

In [None]:
# Load the obs table of every split with a fresh 0..n-1 index.
obs_train, obs_val, obs_test = (
    pd.read_parquet(join(DATA_PATH, split_name, 'obs.parquet')).reset_index(drop=True)
    for split_name in ('train', 'val', 'test')
)

# All splits stacked on top of each other (train, val, test).
obs = pd.concat([obs_train, obs_val, obs_test])

In [None]:
# The val and test splits must expose exactly the columns of the train split, in the same order.
cols_train = obs_train.columns.tolist()
assert all(frame.columns.tolist() == cols_train for frame in (obs_val, obs_test))

In [None]:
# Give every categorical column the same category list in all three splits:
# the categories actually used in the concatenated table.
for col in cols_train:
    if obs[col].dtype.name != 'category':
        continue
    obs[col] = obs[col].cat.remove_unused_categories()
    categories = list(obs[col].cat.categories)
    for frame in (obs_train, obs_val, obs_test):
        frame[col] = pd.Categorical(frame[col], categories, ordered=False)

In [None]:
# Write one lookup table per categorical column.
lookup_path = join(OUT_PATH, 'categorical_lookup')
os.makedirs(lookup_path)

for col in cols_train:
    if obs_train[col].dtype.name != 'category':
        continue

    # One frame per split: row i holds the i-th category in column 'label'.
    cats_train, cats_val, cats_test = (
        pd.Series(dict(enumerate(frame[col].cat.categories))).to_frame().rename(columns={0: 'label'})
        for frame in (obs_train, obs_val, obs_test)
    )

    # all splits have to agree on the table before it is written
    tm.assert_frame_equal(cats_train, cats_val)
    tm.assert_frame_equal(cats_train, cats_test)

    cats_train.to_parquet(join(lookup_path, f'{col}.parquet'), index=True)


In [None]:
# From here on categorical columns only carry their integer codes (int64).
for col in cols_train:
    if obs_train[col].dtype.name != 'category':
        continue
    for frame in (obs_train, obs_val, obs_test):
        frame[col] = frame[col].cat.codes.astype('i8')

In [None]:
# split name -> obs table of that split
obs_dict = dict(train=obs_train, val=obs_val, test=obs_test)

In [None]:
from sklearn.utils.class_weight import compute_class_weight

# Compute 'balanced' class weights for the cell_type labels of the training
# split and store them next to the other outputs.
cell_type_labels = obs_train['cell_type']
class_weights = compute_class_weight(
    'balanced',
    classes=np.unique(cell_type_labels),
    y=cell_type_labels,
)

class_weights_file = join(OUT_PATH, 'class_weights.npy')
with open(class_weights_file, 'wb') as f:
    np.save(f, class_weights)

## Write store

In [None]:
# Rows per dask chunk / dataframe partition, and rows per parquet row group.
CHUNK_SIZE = 32768
ROW_GROUP_SIZE = 1024

# The parquet schema is the same for every split, so it is built once up front.
schema = pa.schema([
    ('X', pa.list_(pa.float32())),
    ('soma_joinid', pa.int64()),
    ('is_primary_data', pa.bool_()),
    ('dataset_id', pa.int64()),
    ('donor_id', pa.int64()),
    ('assay', pa.int64()),
    ('cell_type', pa.int64()),
    ('development_stage', pa.int64()),
    ('disease', pa.int64()),
    ('tissue', pa.int64()),
    ('tissue_general', pa.int64()),
    ('tech_sample', pa.int64()),
    ('idx', pa.int64()),
])


for split in ['train', 'val', 'test']:
    X = preprocess_count_matrix(da.from_zarr(join(DATA_PATH, split, 'zarr'), 'X'), NORMALIZATION)
    obs_ = obs_dict[split]
    # cut off the trailing samples so that all row groups are full
    n_samples = (X.shape[0] // ROW_GROUP_SIZE) * ROW_GROUP_SIZE
    X = X[:n_samples].rechunk((CHUNK_SIZE, -1))
    obs_ = obs_.iloc[:n_samples].copy()
    # add an index column to identify each sample
    obs_['idx'] = np.arange(len(obs_), dtype='i8')
    # row range [start, end) covered by every chunk of X
    end_index = np.cumsum(X.chunks[0]).tolist()
    start_index = [0] + end_index[:-1]
    # divisions for the dask dataframe: first index of every partition,
    # followed by the last index of the final partition
    divisions = start_index + [end_index[-1] - 1]
    ddf = dd.from_delayed(
        [
            convert_to_dataframe(arr, start, end)
            for arr, start, end in zip(X.to_delayed().flatten().tolist(), start_index, end_index)
        ],
        divisions=divisions
    )
    obs_dask = dd.from_pandas(obs_, chunksize=CHUNK_SIZE)
    # both frames must be partitioned identically before they are glued together
    assert np.allclose(ddf.divisions, obs_dask.divisions)
    # Use the documented top-level `dd.concat`; reaching it through the
    # `dd.multi` submodule depends on how dask lays out its modules.
    ddf = dd.concat([ddf, obs_dask], axis=1)

    print(f'{split}: {X.shape[0]} cells')
    ddf.to_parquet(
        join(OUT_PATH, split),
        engine='pyarrow',
        schema=schema,
        write_metadata_file=True,
        row_group_size=ROW_GROUP_SIZE
    )

    # free up memory
    client.restart()
