In [None]:
import sys
sys.path.append("./")

import scanpy as sc
import pandas as pd
from scipy.sparse import csr_matrix
from scipy.stats import entropy, itemfreq
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture as GMM
from sklearn.metrics import adjusted_rand_score as ARI
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.metrics import silhouette_score
import os
import argparse
from sklearn.preprocessing import LabelEncoder
import numpy as np
import csv
import scanorama

In [None]:
def clustering_scores(labels, newX, batch_ind):
    """Score an embedding with a silhouette width, NMI and ARI.

    K-means is fitted on ``newX`` with one cluster per distinct value in
    ``labels`` and the resulting assignment is compared against ``labels``.

    Parameters
    ----------
    labels
        Reference labels, one per row of ``newX``. Must expose ``.nunique()``
        (e.g. a pandas Series).
    newX
        Feature matrix the clustering and the silhouette are computed on.
    batch_ind
        Grouping passed to ``silhouette_score``; note that the silhouette is
        computed with respect to this grouping, not ``labels``.

    Returns
    -------
    tuple
        ``(asw_score, nmi_score, ari_score)``.
    """
    n_clusters = labels.nunique()
    kmeans = KMeans(n_clusters=n_clusters, n_init=200)
    predicted = kmeans.fit_predict(newX)

    return (
        silhouette_score(newX, batch_ind),
        NMI(labels, predicted),
        ARI(labels, predicted),
    )

def knn_purity(adata, label_key, n_neighbors=30):
    """Mean k-nearest-neighbour label purity, averaged over label classes.

    For every row of ``adata.X`` the fraction of its ``n_neighbors`` nearest
    neighbours sharing its label is computed; these fractions are averaged
    within each label class and the class means are averaged again, so every
    class contributes equally regardless of its size.

    Parameters
    ----------
    adata
        Object providing ``.X`` (feature matrix) and ``.obs`` (per-row
        annotations).
    label_key
        Column of ``adata.obs`` holding the labels.
    n_neighbors
        Number of neighbours inspected per row.

    Returns
    -------
    float
        Mean of the per-class purity scores.
    """
    encoded = LabelEncoder().fit_transform(adata.obs[label_key].to_numpy())

    # One extra neighbour is requested and the first column dropped; the first
    # hit is presumed to be the query row itself.
    knn = NearestNeighbors(n_neighbors=n_neighbors + 1)
    knn.fit(adata.X)
    neighbor_idx = knn.kneighbors(adata.X, return_distance=False)[:, 1:]
    neighbor_labels = encoded[neighbor_idx]

    # per-cell purity: share of neighbours carrying the cell's own label
    same_label = neighbor_labels == encoded[:, np.newaxis]
    per_cell = same_label.mean(axis=1)

    # per cell-type purity
    per_type = []
    for cell_type in np.unique(encoded):
        per_type.append(np.mean(per_cell[encoded == cell_type]))

    return np.mean(per_type)

def entropy_batch_mixing(latent, labels, n_neighbors=50, n_pools=50, n_samples_per_pool=100):
    """Entropy of batch labels within each row's nearest-neighbour set.

    For every row of ``latent`` the labels of its ``n_neighbors`` nearest
    neighbours are collected and the Shannon entropy of their frequency
    distribution is computed (``scipy.stats.entropy``, natural log). The
    per-row entropies are then summarised by random pooling.

    Parameters
    ----------
    latent
        Feature matrix, one row per observation.
    labels
        Batch label per row of ``latent`` (array-like; a pandas Series is
        accepted and is read positionally).
    n_neighbors
        Number of neighbours inspected per row.
    n_pools
        Number of random pools to average over. With ``n_pools == 1`` the
        plain mean over all rows is returned and no sampling takes place.
    n_samples_per_pool
        Number of rows drawn (with replacement) for each pool.

    Returns
    -------
    float
        Mean entropy; stochastic when ``n_pools != 1`` because pools are drawn
        with ``np.random.choice``.
    """
    # Work on a plain ndarray so that neighbour indices are always interpreted
    # as positions, even when a pandas Series with a non-integer index is given.
    labels = np.asarray(labels)

    def entropy_from_indices(indices):
        # np.unique(..., return_counts=True) yields the same counts that
        # scipy.stats.itemfreq did; itemfreq was removed in SciPy 1.3.
        _, counts = np.unique(indices, return_counts=True)
        return entropy(counts)

    # One extra neighbour is requested and the first column dropped; the first
    # hit is presumed to be the query row itself.
    neighbors = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(latent)
    indices = neighbors.kneighbors(latent, return_distance=False)[:, 1:]
    batch_indices = labels[indices]

    entropies = np.apply_along_axis(entropy_from_indices, axis=1, arr=batch_indices)

    # average n_pools entropy results where each result is an average of n_samples_per_pool random samples.
    if n_pools == 1:
        score = np.mean(entropies)
    else:
        score = np.mean([
            np.mean(entropies[np.random.choice(len(entropies), size=n_samples_per_pool)])
            for _ in range(n_pools)
        ])

    return score



In [None]:
# Per-dataset settings, keyed by a short alias:
#   name          - stem used to build the input and result paths
#   batch_key     - obs column holding the batch/study annotation
#   cell_type_key - obs column holding the cell-type annotation
#   target        - batch values selected for integration
DATASETS = dict(
    pancreas=dict(
        name="pancreas",
        batch_key="study",
        cell_type_key="cell_type",
        target=["Pancreas SS2", "Pancreas CelSeq2"],
    ),
    brain=dict(
        name="mouse_brain",
        batch_key="study",
        cell_type_key="cell_type",
        target=["Tabula_muris", "Zeisel"],
    ),
)

In [None]:
# Let scanpy save every figure it draws automatically; the plotting calls
# further down set `sc.settings.figdir` to choose the output directory.
sc.settings.autosave = True

In [None]:
# Scanorama benchmark. For every dataset, repetition (i) and subsample
# fraction: rebuild the subsampled target batches, plot a UMAP before
# correction, run Scanorama, compute integration metrics on a 10-component PCA
# of the corrected data, plot a UMAP after correction and store the result.
for data in ["brain", "pancreas"]:
    data_dict = DATASETS[data]
    data_name = data_dict['name']
    batch_key = data_dict['batch_key']
    cell_type_key = data_dict['cell_type_key']
    target_batches = data_dict['target']

    adata = sc.read(f"./data/{data_name}_normalized.h5ad")

    adata.obs['cell_types'] = adata.obs[cell_type_key]

    os.makedirs(f"./results/Scanorama/{data_name}/", exist_ok=True)

    for i in range(5):
        # One row per subsample fraction:
        # [fraction, ASW, ARI, NMI, 6 x entropy-of-batch-mixing, 6 x kNN purity]
        scores = []
        for subsample_frac in [0.1, 0.2, 0.4, 0.6, 0.8, 1.0]:
            # Create the per-run output directory up front instead of relying
            # on scanpy's figure saving to create it as a side effect.
            run_dir = f"./results/Scanorama/{data_name}/{i}/{subsample_frac}"
            os.makedirs(run_dir, exist_ok=True)

            # Assemble the subsampled cells of every target batch.
            final_adata = None
            for target in target_batches:
                adata_sampled = adata[adata.obs[batch_key] == target, :]
                # Precomputed row positions (within this batch) to keep.
                keep_idx = np.loadtxt(f'./data/subsample/{data_name}/{target}/{subsample_frac}/{i}.csv', dtype='int32')
                adata_sampled = adata_sampled[keep_idx, :]

                if final_adata is None:
                    final_adata = adata_sampled
                else:
                    final_adata = final_adata.concatenate(adata_sampled)

            # Split by batch for Scanorama and remember labels in the same
            # order, so they can be re-attached after the corrected batches
            # are concatenated again.
            adata_list = []
            labels_array = np.array([])
            batch_array = np.array([])

            for j in final_adata.obs[batch_key].unique():
                in_batch = final_adata.obs[batch_key] == j
                adata_list.append(final_adata[in_batch, :])
                labels_array = np.concatenate((labels_array, final_adata.obs[cell_type_key][in_batch]))
                batch_array = np.concatenate((batch_array, final_adata.obs[batch_key][in_batch]))

            print(f"{subsample_frac}-before")
            sc.pp.neighbors(final_adata)
            sc.tl.umap(final_adata)
            sc.settings.figdir = f"{run_dir}/before"
            sc.pl.umap(final_adata, color=[batch_key, cell_type_key], wspace=.5)

            corrected = scanorama.correct_scanpy(adata_list)
            final_adata = None
            for corrected_adata in corrected:
                if final_adata is None:
                    final_adata = corrected_adata
                else:
                    final_adata = final_adata.concatenate(corrected_adata)

            final_adata.obs[batch_key] = batch_array
            final_adata.obs[cell_type_key] = labels_array

            # All metrics below are computed on the PCA embedding, which
            # replaces X from here on.
            sc.tl.pca(final_adata, svd_solver="arpack", n_comps=10)
            final_adata = sc.AnnData(X=final_adata.obsm['X_pca'], obs=final_adata.obs)

            asw_score, nmi_score, ari_score = clustering_scores(final_adata.obs[cell_type_key], final_adata.X, final_adata.obs[batch_key])

            # Pass the batch labels as a NumPy array so the metric looks them
            # up by position rather than through the obs index.
            batch_labels = final_adata.obs[batch_key].to_numpy()
            ebm_scores = []
            for k in [15, 25, 50, 100, 200, 300]:
                ebm_scores.append(entropy_batch_mixing(final_adata.X, batch_labels, n_neighbors=k))

            knn_scores = []
            for k in [15, 25, 50, 100, 200, 300]:
                knn_scores.append(knn_purity(final_adata, label_key=cell_type_key, n_neighbors=k))

            scores.append([subsample_frac, asw_score, ari_score, nmi_score] + ebm_scores + knn_scores)

            print(f"{subsample_frac}-after")
            sc.pp.neighbors(final_adata)
            sc.tl.umap(final_adata)
            sc.settings.figdir = f"{run_dir}/after"
            sc.pl.umap(final_adata, color=[batch_key, cell_type_key], wspace=.5)
            final_adata.write(f"{run_dir}/result_adata.h5ad")

        scores = np.array(scores)
        np.savetxt(f"./results/Scanorama/{data_name}/{i}.log", X=scores, delimiter=",")