In [21]:
import pandas as pd
from os.path import join
import anndata
from typing import List
import dask.dataframe as dd
import numpy as np
import pickle
import os
import logging
from typing import Tuple
from self_supervision.paths import DATA_DIR

In [None]:
STORE_DIR = os.path.join(DATA_DIR, 'merlin_cxg_2023_05_15_sf-log1p')

In [None]:
def get_count_matrix_and_obs(ddf):
    x = (
        ddf['X']
        .map_partitions(
            lambda xx: pd.DataFrame(np.vstack(xx.tolist())), 
            meta={col: 'f4' for col in range(19331)}
        )
        .to_dask_array(lengths=[1024] * ddf.npartitions)
    )
    obs = ddf[['cell_type']].compute()
    
    return x, obs

In [None]:
def get_count_matrix_and_obs_new(ddf: dd.DataFrame, n_features: int = 19331) -> Tuple[np.ndarray, pd.DataFrame]:
    def transform_partition(partition):
        # Ensure the input is transformed to a list of numpy arrays
        partition_list = [np.array(x) for x in partition]
        transformed = np.vstack(partition_list)
        
        # Debug: check shape consistency within the partition
        print(f"Transformed partition shape: {transformed.shape}")

        return pd.DataFrame(transformed, columns=[f'feature_{i}' for i in range(n_features)])

    x = (
        ddf['X']
        .map_partitions(
            transform_partition,
            meta={f'feature_{i}': 'f4' for i in range(n_features)}
        )
        .to_dask_array(lengths=True)
    )

    obs = ddf[['cell_type', 'tech_sample']].compute()
    
    return x, obs

In [None]:
# Load the train data
split='train'
ddf_split = dd.read_parquet(join(STORE_DIR, split))
x_split_new, obs_split_new = get_count_matrix_and_obs_new(ddf_split)

In [18]:
x_split_new.shape

(15240192, 19331)

In [19]:
obs_split_new.shape

(15240192, 2)

In [None]:
adata = anndata.AnnData(X=x_split_new, obs=obs_split_new)

In [17]:
adata.write_h5ad(DATA_DIR + 'log1p_cellxgene_train_adata.h5ad')

Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed

In [25]:
del adata

In [27]:
hvg_indices = pickle.load(open(DATA_DIR + 'hvg_2000_indices.pickle','rb'))

In [28]:
adata = anndata.AnnData(X=x_split_new[:, hvg_indices], obs=obs_split_new)

In [29]:
adata.write_h5ad(DATA_DIR + 'log1p_cellxgene_hvg_train_adata.h5ad')

Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed

In [30]:
del adata
del x_split_new
del obs_split_new

In [31]:
# Load the val data
split='val'
ddf_split = dd.read_parquet(join(DATA_DIR, split))
x_split, obs_split = get_count_matrix_and_obs_new(ddf_split)

Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed

In [32]:
x_split.shape

(3500032, 19331)

In [33]:
obs_split.shape

(3500032, 2)

In [34]:
adata = anndata.AnnData(X=x_split, obs=obs_split)

In [35]:
adata.write_h5ad(DATA_DIR + 'log1p_cellxgene_val_adata.h5ad')

Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)Transformed partition shape: (4096, 19331)

Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed

In [36]:
del adata

In [37]:
adata = anndata.AnnData(X=x_split[:, hvg_indices], obs=obs_split)

In [38]:
adata.write_h5ad(DATA_DIR + 'log1p_cellxgene_hvg_val_adata.h5ad')

Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed

In [41]:
# Load the val data
split='test'
ddf_split = dd.read_parquet(join(STORE_DIR, split))
x_split, obs_split = get_count_matrix_and_obs_new(ddf_split)

Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed

In [42]:
x_split.shape

(3448832, 19331)

In [43]:
obs_split.shape

(3448832, 2)

In [44]:
adata = anndata.AnnData(X=x_split, obs=obs_split)

In [45]:
adata.write_h5ad(DATA_DIR + 'log1p_cellxgene_test_adata.h5ad')

Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed

In [46]:
del adata

In [47]:
adata = anndata.AnnData(X=x_split[:, hvg_indices], obs=obs_split)

In [48]:
adata.write_h5ad(DATA_DIR + 'log1p_cellxgene_hvg_test_adata.h5ad')

Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (4096, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed partition shape: (7168, 19331)
Transformed