In [2]:
import numpy as np
import os
import scanpy as sc
import sklearn
import warnings
import faiss
import argparse
import scgpt as scg
import pandas as pd
import biomart

from self_supervision.paths import BASE_DIR, DATA_DIR

# Filter out FutureWarnings
warnings.filterwarnings("ignore", category=FutureWarning)



In [3]:
def install_faiss():
    try:
        import faiss
        faiss_imported = True
    except ImportError:
        faiss_imported = False
        print(
            "faiss not installed! We highly recommend installing it for fast similarity search."
        )
        print("To install it, see https://github.com/facebookresearch/faiss/wiki/Installing-Faiss")

    warnings.filterwarnings("ignore", category=ResourceWarning)
    return faiss_imported


In [4]:
def l2_sim(a, b):
    sims = -np.linalg.norm(a - b, axis=1)
    return sims

def get_similar_vectors(vector, ref, top_k=10):
    sims = l2_sim(vector, ref)
    top_k_idx = np.argsort(sims)[::-1][:top_k]
    return top_k_idx, sims[top_k_idx]


In [36]:
def kNN(ref_embed_adata, test_embed_adata, test_adata, cell_type_key, k=10):
    import numpy as np
    import sklearn.metrics
    import faiss
    import anndata as ad

    ref_cell_embeddings = ref_embed_adata.X
    test_embed = test_embed_adata.X
    faiss_imported = install_faiss()

    if faiss_imported:
        index = faiss.IndexFlatL2(ref_cell_embeddings.shape[1])
        index.add(ref_cell_embeddings)
        distances, labels = index.search(test_embed, k)
    else:
        labels = []
        for i in range(test_embed.shape[0]):
            idx, _ = get_similar_vectors(test_embed[i][np.newaxis, ...], ref_cell_embeddings, k)
            labels.append(idx)

    preds = []
    for idx in labels:
        pred = ref_embed_adata.obs[cell_type_key][idx].value_counts()
        preds.append(pred.index[0])
    
    gt = test_adata.obs[cell_type_key].to_numpy()

    acc = sklearn.metrics.accuracy_score(gt, preds)
    f1 = sklearn.metrics.f1_score(gt, preds, average='micro')
    f1_macro = sklearn.metrics.f1_score(gt, preds, average='macro')

    per_cell_type_acc = {}
    for cell_type in test_adata.obs[cell_type_key].unique():
        idx = test_adata.obs[cell_type_key] == cell_type
        
        gt_subset = np.array(gt)[idx]
        preds_subset = np.array(preds)[idx]

        acc = sklearn.metrics.accuracy_score(gt_subset, preds_subset)
        per_cell_type_acc[cell_type] = acc

    return acc, f1, f1_macro, per_cell_type_acc


In [6]:
val_adata_hlca = sc.read_h5ad(os.path.join(DATA_DIR, "cellxgene_val_dataset_HLCA_adata.h5ad")
val_adata_pbmc = sc.read_h5ad(os.path.join(DATA_DIR, "cellxgene_val_dataset_PBMC_adata.h5ad")
val_adata_tabula_sapiens = sc.read_h5ad(os.path.join(DATA_DIR, "cellxgene_val_dataset_TabulaSapiens_adata.h5ad")

test_adata_hlca = sc.read_h5ad(os.path.join(DATA_DIR, "cellxgene_test_dataset_HLCA_adata.h5ad")
test_adata_pbmc = sc.read_h5ad(os.path.join(DATA_DIR, "cellxgene_test_dataset_PBMC_adata.h5ad")
test_adata_tabula_sapiens = sc.read_h5ad(os.path.join(DATA_DIR, "cellxgene_test_dataset_TabulaSapiens_adata.h5ad")

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


In [3]:
# Function to save Ensembl IDs to CSV
def save_ensembl_ids(adata, filename):
    if "ensembl_id" in adata.var.columns:
        ensembl_ids = adata.var["ensembl_id"].tolist()
        df = pd.DataFrame(ensembl_ids, columns=['ensembl_id'])
        df.to_csv(filename, index=False)
    else:
        print(f"ensembl_id column not found in {filename}")

# Save Ensembl IDs for each AnnData object
# save_ensembl_ids(val_adata_hlca, 'val_adata_hlca_ensembl_ids.csv')
# save_ensembl_ids(val_adata_pbmc, 'val_adata_pbmc_ensembl_ids.csv')
# save_ensembl_ids(val_adata_tabula_sapiens, 'val_adata_tabula_sapiens_ensembl_ids.csv')

# save_ensembl_ids(test_adata_hlca, 'test_adata_hlca_ensembl_ids.csv')
# save_ensembl_ids(test_adata_pbmc, 'test_adata_pbmc_ensembl_ids.csv')
# save_ensembl_ids(test_adata_tabula_sapiens, 'test_adata_tabula_sapiens_ensembl_ids.csv')

In [None]:
# Function to save Ensembl IDs to CSV, full adata

val_adata_full = sc.read_h5ad(os.path.join(DATA_DIR, "log1p_cellxgene_val_adata.h5ad"))
save_ensembl_ids(val_adata_full, 'val_adata_full_ensembl_ids.csv')
del val_adata_full
test_adata_full = sc.read_h5ad(os.path.join(DATA_DIR, "log1p_cellxgene_test_adata.h5ad"))
save_ensembl_ids(test_adata_full, 'test_adata_full_ensembl_ids.csv')
del test_adata_full

  utils.warn_names_duplicates("obs")


ensembl_id column not found in val_adata_full_ensembl_ids.csv


In [46]:
# Define the paths to your processed CSV files
mapping_files = {
    "val_hlca": "processed_mygene_val_hlca_ensembl_to_genesymbol.csv",
    "val_pbmc": "processed_mygene_val_pbmc_ensembl_to_genesymbol.csv",
    "val_tabulasapiens": "processed_mygene_val_tabulasapiens_ensembl_to_genesymbol.csv",
    "test_hlca": "processed_mygene_test_hlca_ensembl_to_genesymbol.csv",
    "test_pbmc": "processed_mygene_test_pbmc_ensembl_to_genesymbol.csv",
    "test_tabula_sapiens": "processed_mygene_test_tabulasapiens_ensembl_to_genesymbol.csv"
}

def load_mapping(file_path):
    df = pd.read_csv(file_path)
    return dict(zip(df['ensembl_id'], df['gene_symbol']))

In [47]:
def reference_mapping(model_dir, adata_dir, dataset_name, cell_type_key, gene_col):
    adata_paths = {
        "train": os.path.join(adata_dir, f"cellxgene_train_dataset_{dataset_name}_adata.h5ad"),
        "val": os.path.join(adata_dir, f"cellxgene_val_dataset_{dataset_name}_adata.h5ad"),
        "test": os.path.join(adata_dir, f"cellxgene_test_dataset_{dataset_name}_adata.h5ad"),
    }

    val_adata = sc.read_h5ad(adata_paths["val"])
    test_adata = sc.read_h5ad(adata_paths["test"])

    # Make observation names unique
    val_adata.obs_names_make_unique()
    test_adata.obs_names_make_unique()

    # Determine the appropriate mapping file to use
    mapping_key = f"val_{dataset_name.lower()}" if "val" in adata_paths["val"] else f"test_{dataset_name.lower()}"
    mapping_file = mapping_files[mapping_key]

    # Load the mapping
    mapping = load_mapping(mapping_file)

    if gene_col not in val_adata.var.columns:
        if "ensembl_id" in val_adata.var.columns:
            print(f"Converting Ensembl IDs to gene symbols using {mapping_file}...")
            ensembl_ids = val_adata.var["ensembl_id"].tolist()
            gene_symbols = [mapping.get(ensembl_id, '') for ensembl_id in ensembl_ids]
            val_adata.var[gene_col] = gene_symbols
            print(f"Successfully converted Ensembl IDs to gene symbols and added '{gene_col}' column.")
        else:
            raise ValueError(f"gene_col '{gene_col}' not found in val_adata.var.columns")

        if "ensembl_id" in test_adata.var.columns:
            print(f"Converting Ensembl IDs to gene symbols using {mapping_file}...")
            ensembl_ids = test_adata.var["ensembl_id"].tolist()
            gene_symbols = [mapping.get(ensembl_id, '') for ensembl_id in ensembl_ids]
            test_adata.var[gene_col] = gene_symbols
            print(f"Successfully converted Ensembl IDs to gene symbols and added '{gene_col}' column.")
        else:
            raise ValueError(f"gene_col '{gene_col}' not found in test_adata.var.columns")

    ref_embed_adata = scg.tasks.embed_data(
        val_adata,
        model_dir,
        gene_col=gene_col,
        obs_to_save=cell_type_key,
        batch_size=64,
        return_new_adata=True,
    )

    sc.pp.neighbors(ref_embed_adata, use_rep="X")
    sc.tl.umap(ref_embed_adata)

    test_embed_adata = scg.tasks.embed_data(
        test_adata,
        model_dir,
        gene_col=gene_col,
        obs_to_save=cell_type_key,
        batch_size=64,
        return_new_adata=True,
    )

    return kNN(ref_embed_adata, test_embed_adata, test_adata, cell_type_key)

In [48]:
args = {
    "model_dir": os.path.join(BASE_DIR, "scGPT_human"),
    "adata_dir": DATA_DIR,
    "dataset_name": None,
    "cell_type_key": "cell_type",
    "gene_col": "feature_name",
}

datasets = ["TabulaSapiens", "PBMC"] if args["dataset_name"] is None else [args["dataset_name"]]

for dataset in datasets:
    print(f"Running Reference Mapping for {dataset}")
    acc, f1, f1_macro, per_cell_type_acc = reference_mapping(
        args["model_dir"], args["adata_dir"], dataset, args["cell_type_key"], args["gene_col"]
    )
    print(f"Dataset: {dataset}")
    print(f"Accuracy: {acc}")
    print(f"Micro F1 Score: {f1}")
    print(f"Macro F1 Score: {f1_macro}")
    print("Done!")


Running Reference Mapping for TabulaSapiens


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


Converting Ensembl IDs to gene symbols using processed_mygene_val_tabulasapiens_ensembl_to_genesymbol.csv...
Successfully converted Ensembl IDs to gene symbols and added 'feature_name' column.
Converting Ensembl IDs to gene symbols using processed_mygene_val_tabulasapiens_ensembl_to_genesymbol.csv...
Successfully converted Ensembl IDs to gene symbols and added 'feature_name' column.
scGPT - INFO - match 19067/19331 genes in vocabulary of size 60697.


Embedding cells: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 858/858 [04:20<00:00,  3.29it/s]


scGPT - INFO - match 19067/19331 genes in vocabulary of size 60697.


Embedding cells: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 901/901 [04:28<00:00,  3.36it/s]


Dataset: TabulaSapiens
Accuracy: 0.6666666666666666
Micro F1 Score: 0.7666793946126075
Macro F1 Score: 0.5324009979708663
Done!
Running Reference Mapping for PBMC


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


Converting Ensembl IDs to gene symbols using processed_mygene_val_pbmc_ensembl_to_genesymbol.csv...
Successfully converted Ensembl IDs to gene symbols and added 'feature_name' column.
Converting Ensembl IDs to gene symbols using processed_mygene_val_pbmc_ensembl_to_genesymbol.csv...
Successfully converted Ensembl IDs to gene symbols and added 'feature_name' column.
scGPT - INFO - match 19067/19331 genes in vocabulary of size 60697.


Embedding cells: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 528/528 [02:38<00:00,  3.34it/s]


scGPT - INFO - match 19067/19331 genes in vocabulary of size 60697.


Embedding cells: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 2965/2965 [14:38<00:00,  3.38it/s]


Dataset: PBMC
Accuracy: 0.0
Micro F1 Score: 0.433588397731824
Macro F1 Score: 0.2814081086390673
Done!
