In [None]:
!pip install celltypist

In [None]:
from os.path import join

import anndata
import scanpy as sc
import numpy as np
import pandas as pd
import dask.dataframe as dd
import dask.array as da

from scipy.sparse import csr_matrix

# Get subset training data

In [None]:
def get_count_matrix_and_obs(ddf, n_genes=19331, rows_per_partition=1024):
    """Split a dask dataframe of cells into an expression matrix and labels.

    Parameters
    ----------
    ddf : dask dataframe
        Must provide an ``'X'`` column whose entries are one vector per
        cell (they are stacked row-wise with ``np.vstack``) and a
        ``'cell_type'`` column.
    n_genes : int, default 19331
        Number of ``'f4'`` columns declared in the ``meta`` of the stacked
        matrix. Presumably the length of each ``'X'`` vector -- this is not
        checked here, so it must match the data.
    rows_per_partition : int, default 1024
        Row count declared for *every* partition when building the dask
        array. It is taken on trust and not verified against the data; a
        partition with a different number of rows (e.g. a shorter last
        one) would make the declared chunk sizes wrong.

    Returns
    -------
    x : dask array
        Lazy matrix with one row per cell and ``n_genes`` columns.
    obs : pandas.DataFrame
        The ``'cell_type'`` column, fully computed into memory.
    """
    def _stack_rows(partition):
        # Each entry of the partition is one cell's vector; stack them so
        # every cell becomes a row of a regular 2-D frame.
        return pd.DataFrame(np.vstack(partition.tolist()))

    x = (
        ddf['X']
        .map_partitions(
            _stack_rows,
            meta={col: 'f4' for col in range(n_genes)}
        )
        # Chunk lengths are supplied explicitly instead of being computed.
        .to_dask_array(lengths=[rows_per_partition] * ddf.npartitions)
    )
    obs = ddf[['cell_type']].compute()

    return x, obs

In [None]:
# Root location of the dataset; later cells read the 'train' parquet data
# and the 'var.parquet' file from underneath it.
PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'

In [None]:
# Gene-level annotations are read eagerly with pandas.
var = pd.read_parquet(join(PATH, 'var.parquet'))

# The training split is opened lazily with dask, then separated into the
# expression matrix (still lazy) and the in-memory cell type labels.
ddf = dd.read_parquet(join(PATH, 'train'), split_row_groups=True)
x, obs = get_count_matrix_and_obs(ddf)

In [None]:
start = 0
subsample_size = 1_500_000

# The data is already shuffled and already normalized, so a contiguous
# window of rows is used directly as the training subsample.
rows = slice(start, start + subsample_size)

adata_train = anndata.AnnData(
    # each dask block is converted to CSR before being materialized
    X=x[rows].map_blocks(csr_matrix).compute(),
    obs=obs.iloc[rows],
    var=var.set_index('feature_name'),
)

adata_train

# Fit celltypist model

In [None]:
import celltypist

In [None]:
# Training options, collected in one place so they are easy to scan and tweak.
train_options = dict(
    labels='cell_type',
    n_jobs=20,
    feature_selection=True,
    use_SGD=True,
    mini_batch=True,
    batch_number=1500,
    with_mean=False,
    random_state=1,
)

# Fit the model on the subsampled AnnData object.
new_model = celltypist.train(adata_train, **train_options)

In [None]:
# The file name embeds the number of cells used for training.
model_path = f'/mnt/dssfs02/tb_logs/cxg_2023_05_15_celltypist/model_{subsample_size}_cells_run1.pkl'
new_model.write(model_path)