In [24]:
from os.path import join

import anndata
import numpy as np
import pandas as pd
import dask.dataframe as dd

In [25]:
# Root directory of the preprocessed data store (lookup tables, train/test splits, var table).
DATA_PATH = "/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p"

## Retrieve UCE embeddings for the train and test data

In [None]:
!pip install -U cellxgene-census

In [3]:
import cellxgene_census
from cellxgene_census.experimental import get_embedding

In [4]:
# Lookup tables (index -> `label` column) used to decode the categorical obs columns,
# plus the cell-type hierarchy matrix.
cell_type_mapping, dataset_id_mapping, donor_id_mapping, tech_sample_mapping = (
    pd.read_parquet(join(DATA_PATH, lookup_path))
    for lookup_path in (
        'categorical_lookup/cell_type.parquet',
        'categorical_lookup/dataset_id.parquet',
        'categorical_lookup/donor_id.parquet',
        'categorical_lookup/tech_sample.parquet',
    )
)
cell_type_hierarchy = np.load(join(DATA_PATH, 'cell_type_hierarchy/child_matrix.npy'))

In [5]:
# Train/test splits: the full frames stay lazy dask DataFrames, while only the four
# categorical obs columns are computed into pandas.
obs_columns = ['cell_type', 'dataset_id', 'donor_id', 'tech_sample']
ddf_train, ddf_test = (
    dd.read_parquet(join(DATA_PATH, split), split_row_groups=True)
    for split in ('train', 'test')
)
obs_train, obs_test = (
    dd.read_parquet(join(DATA_PATH, split), columns=obs_columns).compute()
    for split in ('train', 'test')
)
var = pd.read_parquet(join(DATA_PATH, 'var.parquet'))

In [6]:
# Replace the stored values of each obs column of the test split (presumably integer codes,
# i.e. the lookup tables' index) by their labels, then store the column as a pandas categorical.
for column, lookup in (
    ('tech_sample', tech_sample_mapping),
    ('cell_type', cell_type_mapping),
    ('dataset_id', dataset_id_mapping),
    ('donor_id', donor_id_mapping),
):
    obs_test[column] = obs_test[column].replace(lookup.label.to_dict()).astype('category')

In [7]:
# Replace the stored values of each obs column of the train split (presumably integer codes,
# i.e. the lookup tables' index) by their labels, then store the column as a pandas categorical.
for column, lookup in (
    ('tech_sample', tech_sample_mapping),
    ('cell_type', cell_type_mapping),
    ('dataset_id', dataset_id_mapping),
    ('donor_id', donor_id_mapping),
):
    obs_train[column] = obs_train[column].replace(lookup.label.to_dict()).astype('category')

In [8]:
# 10x assays to keep from the census. The list's repr is interpolated verbatim into the
# SOMA `value_filter` of the obs query, so it must stay a list of plain strings.
PROTOCOLS = [
    "10x 5' v2", "10x 3' v3", "10x 3' v2", "10x 5' v1", "10x 3' v1",
    "10x 3' transcription profiling", "10x 5' transcription profiling",
]

# obs columns fetched from the census.
COLUMN_NAMES = [
    "soma_joinid", "is_primary_data", "dataset_id", "donor_id", "assay",
    "cell_type", "development_stage", "disease", "tissue", "tissue_general",
]

In [9]:
# Handle to the pinned 2023-12-15 CELLxGENE Census release.
census = cellxgene_census.open_soma(census_version='2023-12-15')

In [10]:
# Human primary-data cells from the assays in PROTOCOLS, restricted to COLUMN_NAMES and
# materialised as a pandas DataFrame.
obs = census["census_data"]["homo_sapiens"].obs.read(
    column_names=COLUMN_NAMES,
    value_filter=f"is_primary_data == True and assay in {PROTOCOLS}",
).concat().to_pandas()

In [29]:
# Test tech_samples (dataset_id + '_' + donor_id) that the census query did not return.
# The key is built inline from obs' own columns: `obs['tech_sample']` is only created in a
# later cell, so referencing it here fails on a fresh kernel (Restart & Run All).
set(obs_test.tech_sample.unique()) - set((obs.dataset_id + '_' + obs.donor_id).unique())

{'0c774045-26a7-40f8-9b07-6742d3c771c0_ac03',
 '0c774045-26a7-40f8-9b07-6742d3c771c0_ac04',
 '0c774045-26a7-40f8-9b07-6742d3c771c0_ac07',
 '0c774045-26a7-40f8-9b07-6742d3c771c0_ac09',
 '0c774045-26a7-40f8-9b07-6742d3c771c0_ac15',
 '0c774045-26a7-40f8-9b07-6742d3c771c0_ac18',
 '0c774045-26a7-40f8-9b07-6742d3c771c0_ac21',
 '87ce26ed-e5d1-44b4-81cc-cc5b709a169f_CV-056',
 '87ce26ed-e5d1-44b4-81cc-cc5b709a169f_SK-008',
 '87ce26ed-e5d1-44b4-81cc-cc5b709a169f_SK-014',
 '9ff99bf8-2524-4ab5-ab6e-4bc218e4a449_pooled [sanes2022_Hu0216,sanes2022_Hu0218]',
 '9ff99bf8-2524-4ab5-ab6e-4bc218e4a449_sanes2022_Hu0216',
 'b252b015-b488-4d5c-b16e-968c13e48a2c_SPECTRUM-OV-014',
 'b252b015-b488-4d5c-b16e-968c13e48a2c_SPECTRUM-OV-053',
 'b252b015-b488-4d5c-b16e-968c13e48a2c_SPECTRUM-OV-067',
 'b252b015-b488-4d5c-b16e-968c13e48a2c_SPECTRUM-OV-075',
 'b252b015-b488-4d5c-b16e-968c13e48a2c_SPECTRUM-OV-082',
 'b252b015-b488-4d5c-b16e-968c13e48a2c_SPECTRUM-OV-083',
 'b252b015-b488-4d5c-b16e-968c13e48a2c_SPECTRUM-OV

In [33]:
# Number of distinct tech_samples in the test split.
obs_test['tech_sample'].nunique()

758

In [31]:
# Number of test tech_samples that also appear in the census query. The key is built inline
# because `obs['tech_sample']` is only created in a later cell (fails on a fresh kernel otherwise).
len(set(obs_test.tech_sample.unique()) & set((obs.dataset_id + '_' + obs.donor_id).unique()))

736

In [11]:
# Per-cell tech_sample key (dataset_id + '_' + donor_id); the census cells are filtered
# against the train/test tech_sample labels using this column.
obs['tech_sample'] = pd.Categorical(obs['dataset_id'] + '_' + obs['donor_id'])

In [12]:
# Census cells whose tech_sample AND cell type both occur in the respective split.
# Boolean `Series.isin` masks avoid building a `DataFrame.query` string from the repr of
# hundreds of labels (some contain brackets, e.g. "..._pooled [sanes2022_Hu0216,...]")
# that pandas would then have to re-parse.
obs_subset_test = obs[
    obs.tech_sample.isin(obs_test.tech_sample.unique().tolist())
    & obs.cell_type.isin(obs_test.cell_type.unique().tolist())
]
obs_subset_train = obs[
    obs.tech_sample.isin(obs_train.tech_sample.unique().tolist())
    & obs.cell_type.isin(obs_train.cell_type.unique().tolist())
]

In [13]:
# Size of the random training subsample; a fixed seed keeps the subsample reproducible.
N_TRAIN_CELLS = 1_500_000
rng = np.random.default_rng(seed=1)
# Sampled without replacement; numpy raises ValueError if fewer than N_TRAIN_CELLS cells match.
soma_ids_train = rng.choice(obs_subset_train.soma_joinid.to_numpy(), N_TRAIN_CELLS, replace=False)
# All matching test cells are kept.
soma_ids_test = obs_subset_test.soma_joinid.to_numpy()

In [None]:
# Location of the contributed embeddings for Census release 2023-12-15.
embedding_uri = "s3://cellxgene-contrib-public/contrib/cell-census/soma/2023-12-15/CxG-contrib-2"
# `obs` is already materialised in pandas and `get_embedding` only takes the release string,
# so the Census handle is no longer needed: close it rather than opening (and leaking) a second one.
census.close()

In [None]:
# Fetch embeddings for the sampled train cells, then for all test cells.
embeddings_train, embeddings_test = (
    get_embedding("2023-12-15", embedding_uri, soma_ids)
    for soma_ids in (soma_ids_train, soma_ids_test)
)

In [None]:
# Persist the raw embedding matrices together with the soma_joinids they were fetched for.
for out_path, array in (
    ('/mnt/dssfs02/tb_logs/UCE/train_embeddings.npy', embeddings_train),
    ('/mnt/dssfs02/tb_logs/UCE/train_soma_ids.npy', soma_ids_train),
    ('/mnt/dssfs02/tb_logs/UCE/test_embeddings.npy', embeddings_test),
    ('/mnt/dssfs02/tb_logs/UCE/test_soma_ids.npy', soma_ids_test),
):
    np.save(out_path, array)

In [None]:
# Wrap each split's embeddings in an AnnData whose obs are the census metadata re-indexed by
# soma_joinid and ordered like the ids passed to get_embedding (assumes the embedding rows
# follow that order — TODO confirm), then write it to disk.
for emb, split_obs, ids, out_path in (
    (embeddings_train, obs_subset_train, soma_ids_train, '/mnt/dssfs02/tb_logs/UCE/train.h5ad'),
    (embeddings_test, obs_subset_test, soma_ids_test, '/mnt/dssfs02/tb_logs/UCE/test.h5ad'),
):
    anndata.AnnData(X=emb, obs=split_obs.set_index('soma_joinid').loc[ids]).write(out_path)



## Fit a linear classifier on top of the UCE embeddings

In [1]:
# Root directory of the preprocessed data store (redefined: this section runs in a fresh kernel).
DATA_PATH = "/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p"

In [2]:
from os.path import join

import anndata
import numpy as np
import pandas as pd

from cuml.linear_model import LogisticRegression
from tqdm import tqdm

In [3]:
# Training embeddings + census metadata written by the retrieval section.
adata_train = anndata.read_h5ad(filename='/mnt/dssfs02/tb_logs/UCE/train.h5ad')

In [4]:
cell_type_mapping = pd.read_parquet(join(DATA_PATH, 'categorical_lookup/cell_type.parquet'))
# label -> index: the inverse of the lookup table's index -> `label` column.
inverse_mapping = dict(zip(cell_type_mapping.label, cell_type_mapping.index))
cell_type_hierarchy = np.load(join(DATA_PATH, 'cell_type_hierarchy/child_matrix.npy'))

In [5]:
# Features are the embedding matrix itself (not a copy); targets are cell-type labels
# mapped back to their lookup-table index.
X, y = (
    adata_train.X,
    adata_train.obs['cell_type'].replace(inverse_mapping).to_numpy(),
)

In [6]:
# Some embedding entries are NaN: zero them in place. X aliases adata_train.X, so the
# AnnData is modified as well.
np.copyto(X, 0., where=np.isnan(X))

In [7]:
# Fit four identically configured, class-balanced logistic regressions on the same data;
# their test scores are summarised later as mean±std.
clf_list = [LogisticRegression(class_weight='balanced').fit(X, y) for _ in tqdm(range(4))]


 25%|██▌       | 1/4 [05:07<15:23, 307.90s/it]

[W] [19:42:44.072709] L-BFGS stopped, because the line search failed to advance (step delta = 0.000000)


 50%|█████     | 2/4 [09:16<09:06, 273.02s/it]

[W] [19:46:52.677486] L-BFGS line search failed (code 3); stopping at the last valid step


 75%|███████▌  | 3/4 [10:51<03:11, 191.65s/it]

[W] [19:48:27.502374] L-BFGS line search failed (code 3); stopping at the last valid step


100%|██████████| 4/4 [13:01<00:00, 195.39s/it]

[W] [19:50:37.742261] L-BFGS line search failed (code 3); stopping at the last valid step





## Evaluate on test data

In [8]:
from sklearn.metrics import classification_report
from statistics import mean, stdev

from utils import correct_labels

In [9]:
# Test embeddings + census metadata written by the retrieval section.
adata_test = anndata.read_h5ad(filename='/mnt/dssfs02/tb_logs/UCE/test.h5ad')

In [10]:
# Test features (the embedding matrix itself) and targets encoded with the same
# label -> index mapping as the training labels.
x_test, y_test = (
    adata_test.X,
    adata_test.obs['cell_type'].replace(inverse_mapping).to_numpy(),
)

In [11]:
# Zero NaN embedding entries in place, exactly as for the training matrix.
np.copyto(x_test, 0., where=np.isnan(x_test))

In [12]:
# Score every fitted classifier on the test split. Predictions go through `correct_labels`
# together with the cell-type hierarchy before the per-class report is computed
# (presumably counting hierarchy-consistent predictions as correct — see utils.correct_labels).
clf_reports = []
for model in clf_list:
    raw_preds = model.predict(x_test)
    corrected_preds = correct_labels(y_test, raw_preds, cell_type_hierarchy)
    report = classification_report(y_test, corrected_preds, output_dict=True, zero_division=0.)
    clf_reports.append(pd.DataFrame(report).T)


In [13]:
# Macro-averaged F1 of each UCE-embedding classifier on the test split, summarised as mean±std.
# (Renamed from the copy-pasted `f1_scores_scgpt`: these are UCE scores.)
f1_scores_uce = [report.loc['macro avg', 'f1-score'] for report in clf_reports]
print(f'{mean(f1_scores_uce):.4f}±{stdev(f1_scores_uce):.4f}')

0.7611±0.0018
