In [1]:
EXPERIMENT_INFO = {
    'pancreas': {
        'file_name': 'pancreas.h5ad',
        'condition_key': 'study',
        'cell_type_key': ['cell_type'],
        'reference': [
            "inDrop1", 
            "inDrop2", 
            "inDrop3", 
            "inDrop4", 
            "fluidigmc1", 
            "smartseq2", 
            "smarter"
        ],
        'query': [
            'celseq',
            'celseq2'
        ],
    },
    'pbmc': {
        'file_name': 'pbmc.h5ad',
        'condition_key': 'condition',
        'cell_type_key': ['cell_type'],
        'reference': [
            "Oetjen", 
            "10X", 
            "Sun", 
        ],
        'query': [
            'Freytag',
        ],
    },
    'brain': {
        'file_name': 'brain.h5ad',
        'condition_key': 'study',
        'cell_type_key': ['cell_type'],
        'reference': [
            'Rosenberg',
            'Saunders',
        ],
        'query': [
            'Zeisel',
            'Tabula_muris'
        ],
    },
    'scvelo': {
        'file_name': 'scvelo.h5ad',
        'condition_key': 'study',
        'cell_type_key': ['cell_type'],
        'reference': [
            '12.5',
            '13.5',
        ],
        'query': [
            '14.5',
            '15.5'
        ],
    },
    'lung': {
        'file_name': 'lung.h5ad',
        'condition_key': 'study',
        'cell_type_key': ['cell_type'],
        'reference': [
            'Dropseq_transplant',
            '10x_Biopsy',
        ],
        'query': [
            '10x_Transplant',
        ],
    },
    'tumor': {
        'file_name': 'tumor.h5ad',
        'condition_key': 'study',
        'cell_type_key': ['cell_type'],
        'reference': [
            'breast', 
            'colorectal', 
            'liver2', 
            'liver1', 
            'lung1', 
            'lung2', 
            'multiple', 
            'ovary',
            'pancreas', 
            'skin'
        ],
        'query': [
            'melanoma1',
            'melanoma2',
            'uveal melanoma'
        ],
    },
    'lung_h_sub': {
        'file_name': 'adata_lung_subsampled.h5ad',
        'condition_key': 'study',
        'cell_type_key': ['ann_level_1', 'ann_level_2'],
        'reference': [
            "Stanford_Krasnow_bioRxivTravaglini", 
            "Misharin_new"    
        ],
        'query': [
            "Vanderbilt_Kropski_bioRxivHabermann_vand", 
            "Sanger_Teichmann_2019VieiraBraga"
        ],
    },
    }


In [2]:
import numpy as np

def label_encoder(adata, encoder=None, condition_key='condition'):
    """Encode labels of Annotated `adata` matrix.
       Parameters
       ----------
       adata: : `~anndata.AnnData`
            Annotated data matrix.
       encoder: Dict or None
            dictionary of encoded labels. if `None`, will create one.
       condition_key: String
            column name of conditions in `adata.obs` data frame.

       Returns
       -------
       labels: `~numpy.ndarray`
            Array of encoded labels
       label_encoder: Dict
            dictionary with labels and encoded labels as key, value pairs.
    """
    unique_conditions = list(np.unique(adata.obs[condition_key]))
    if encoder is None:
        encoder = {k: v for k, v in zip(sorted(unique_conditions), np.arange(len(unique_conditions)))}

    labels = np.zeros(adata.shape[0])
    if not set(unique_conditions).issubset(set(encoder.keys())):
        print("Warning: Labels in adata is not a subset of label-encoder!")
        for data_cond in unique_conditions:
            if data_cond not in encoder.keys():
                labels[adata.obs[condition_key] == data_cond] = -1

    for condition, label in encoder.items():
        labels[adata.obs[condition_key] == condition] = label
    return labels.reshape(-1, 1), encoder

In [3]:
import torch
import os
from anndata import AnnData
from benchmarks.mars.args_parser import get_parser
from benchmarks.mars.model.mars import MARS
from benchmarks.mars.model.experiment_dataset import ExperimentDataset
import warnings

warnings.filterwarnings('ignore')

import scanpy as sc
import numpy as np
import pandas as pd
import scarches as sca
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from scarches.dataset.trvae.data_handling import remove_sparsity
from lataq_reproduce.exp_dict import EXPERIMENT_INFO
from lataq_reproduce.utils import label_encoder
from lataq.metrics.metrics import metrics
import time

data = 'pancreas'

DATA_DIR = '/storage/groups/ml01/workspace/carlo.dedonno/lataq_reproduce/data'
RES_PATH = (
    f'/storage/groups/ml01/workspace/carlo.dedonno/'
    f'lataq_reproduce/results/mars/{data}'
)
EXP_PARAMS = EXPERIMENT_INFO[data]
FILE_NAME = EXP_PARAMS['file_name']

def celltype_to_numeric(adata, obs_key):
    """Adds ground truth clusters data."""
    annotations = list(adata.obs[obs_key])
    annotations_set = sorted(set(annotations))

    mapping = {a: idx for idx, a in enumerate(annotations_set)}

    truth_labels = [mapping[a] for a in annotations]
    adata.obs['truth_labels'] = pd.Categorical(values=truth_labels)

    return adata, mapping

params, unknown = get_parser().parse_known_args()
params.cuda = True
params.pretrain_batch = 128
print('PARAMS:', params)
if torch.cuda.is_available() and not params.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
device = 'cuda:0' if torch.cuda.is_available() and params.cuda else 'cpu'
params.device = device

# LOADING DATA
adata = sc.read(f'{DATA_DIR}/{FILE_NAME}')
condition_key = EXP_PARAMS['condition_key']
cell_type_key = EXP_PARAMS['cell_type_key']
reference = EXP_PARAMS['reference']
query = EXP_PARAMS['query']

if issparse(adata.X):
    adata.X = adata.X.A
# Create Int Mapping for celltypes
adata, celltype_id_map = celltype_to_numeric(adata, cell_type_key[0])
cell_type_name_map = {v: k for k, v in celltype_id_map.items()}

# Preprocess data
sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
sc.pp.log1p(adata)
sc.pp.scale(adata, max_value=10, zero_center=True)

# Make labeled Datasets for Mars
annotated = []
labels = []
batches = []
for batch in reference:
    labeled_adata = adata[adata.obs.study.isin([batch])].copy()
    y_labeled = np.array(labeled_adata.obs['truth_labels'], dtype=np.int64)
    annotated.append(ExperimentDataset(labeled_adata.X,
                                       labeled_adata.obs_names,
                                       labeled_adata.var_names,
                                       batch,
                                       y_labeled
                                       ))
    labels += labeled_adata.obs[cell_type_key[0]].tolist()
    batches += labeled_adata.obs[condition_key].tolist()

# Make Unlabeled Datasets for Mars
unlabeled_adata = adata[adata.obs.study.isin(query)].copy()
y_unlabeled = np.array(unlabeled_adata.obs['truth_labels'], dtype=np.int64)
unannnotated = ExperimentDataset(
    unlabeled_adata.X,
    unlabeled_adata.obs_names,
    unlabeled_adata.var_names,
    'query',
    y_unlabeled
)
labels += unlabeled_adata.obs[cell_type_key[0]].tolist()
batches += unlabeled_adata.obs[condition_key].tolist()
n_clusters = len(np.unique(unannnotated.y))

# Make pretrain Dataset
pretrain = ExperimentDataset(
    adata.X,
    adata.obs_names,
    adata.var_names,
    'Pretrain'
)
logging.info('Data loaded succesfully')

# TRAINING REFERENCE MODEL
mars = MARS(
    n_clusters,
    params,
    annotated,
    unannnotated,
    pretrain,
    hid_dim_1=1000,
    hid_dim_2=100
)
ref_time = time.time()
adata, landmarks, _ = mars.train(evaluation_mode=True)
ref_time = time.time() - ref_time
# save ref time


# TODO: CHECK FROM HERE....
names = mars.name_cell_types(adata, landmarks, cell_type_name_map)
print(names)
unproc_labels = adata.obs['truth_labels'].tolist()
unproc_pred = adata.obs['MARS_labels'].tolist()

predictions = []
for count, label in enumerate(unproc_pred):
    if not isinstance(label, int):
        predictions.append(cell_type_name_map[unproc_labels[count]])
    elif len(names[label]) == 1:
        predictions.append(names[label][-1])
    else:
        predictions.append(names[label][-1][0])

labels_after = []
for count, label in enumerate(unproc_labels):
    labels_after.append(cell_type_name_map[label])

#report = classification_report(
#        y_true=np.array(labels_after)[adata.obs['experiment'] == 'query'],
#        y_pred=np.array(predictions)[adata.obs['experiment'] == 'query'],
#        labels=np.array(unlabeled_adata.obs[cell_type_key[0]].unique().tolist()),
#        #output_dict=True,
#)
report = pd.DataFrame(
    classification_report(
        y_true=np.array(labels_after)[adata.obs['experiment'] == 'query'],
        y_pred=np.array(predictions)[adata.obs['experiment'] == 'query'],
        labels=np.array(unlabeled_adata.obs[cell_type_key[0]].unique().tolist()),
        output_dict=True,
    )
).transpose()

report_full = pd.DataFrame(
    classification_report(
        y_true=np.array(labels_after),
        y_pred=np.array(predictions),
        output_dict=True
    )
).transpose().add_prefix('full_')

adata_latent = AnnData(adata.obsm['MARS_embedding'])
adata_latent.obs['celltype'] = labels_after
adata_latent.obs['predictions'] = predictions
adata_latent.obs['batch'] = batches
adata_latent.write_h5ad(f'{RES_PATH}/adata_latent_full.h5ad')

sc.pp.neighbors(adata_latent)
sc.tl.leiden(adata_latent)
sc.tl.umap(adata_latent)
sc.pl.umap(
    adata_latent,
    color=['batch'],
    frameon=False,
    wspace=0.6,
    show=False
)
plt.savefig(
    f'{RES_PATH}/full_umap_batch.png',
    bbox_inches='tight'
)
plt.close()
sc.pl.umap(
    adata_latent,
    color=['celltype'],
    frameon=False,
    wspace=0.6,
    show=False
)
plt.savefig(
    f'{RES_PATH}/full_umap_ct.png',
    bbox_inches='tight'
)
plt.close()
sc.pl.umap(
    adata_latent,
    color=['predictions'],
    frameon=False,
    wspace=0.6,
    show=False
)
plt.savefig(
    f'{RES_PATH}/full_umap_pred.png',
    bbox_inches='tight'
)
plt.close()

conditions, _ = label_encoder(adata, condition_key=condition_key)
labels, _ = label_encoder(adata, condition_key=cell_type_key[0])
adata.obs['batch'] = conditions.squeeze(axis=1)
adata.obs['celltype'] = labels.squeeze(axis=1)
conditions, _ = label_encoder(adata_latent, condition_key='batch')
labels, _ = label_encoder(adata_latent, condition_key='celltype')
adata_latent.obs['batch'] = conditions.squeeze(axis=1)
adata_latent.obs['celltype'] = labels.squeeze(axis=1)

scores = metrics(
    adata,
    adata_latent,
    'batch',
    'celltype',
    nmi_=False,
    ari_=False,
    silhouette_=False,
    pcr_=True,
    graph_conn_=True,
    isolated_labels_=False,
    hvg_score_=False,
    knn_=True,
    ebm_=True,
)

scores = scores.T
scores = scores[[  # 'NMI_cluster/label',
    # 'ARI_cluster/label',
    # 'ASW_label',
    # 'ASW_label/batch',
    'PCR_batch',
    # 'isolated_label_F1',
    # 'isolated_label_silhouette',
    'graph_conn',
    'ebm',
    'knn',
]]



In a future version of Scanpy, `scanpy.api` will be removed.
Simply use `import scanpy as sc` and `import scanpy.external as sce` instead.

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, positive=False):
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  max_n_alphas=1000, n_jobs=None, eps=np.finfo(np.float)

PARAMS: Namespace(cuda=True, epochs=30, epochs_pretrain=25, learning_rate=0.001, lr_scheduler_gamma=0.5, lr_scheduler_step=20, manual_seed=3, model_file='trained_models/source.pt', pretrain=True, pretrain_batch=128)


NameError: name 'issparse' is not defined

In [None]:
type(report)

In [None]:
report