In [None]:
import sys
sys.path.append("./")

import scanpy as sc
import pandas as pd
from scipy.sparse import csr_matrix
from scipy.stats import entropy, itemfreq
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture as GMM
from sklearn.metrics import adjusted_rand_score as ARI
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.metrics import silhouette_score
import os
import argparse
from sklearn.preprocessing import LabelEncoder
import numpy as np
import csv
import scanorama
import time


In [None]:
# Per-dataset configuration consumed by the Scanorama loop in this notebook.
#   name          -> stem used for the input file and the results directory
#   batch_key     -> `adata.obs` column that identifies the batch of each cell
#   cell_type_key -> `adata.obs` column that holds the cell-type annotation
#   batches       -> batch labels to integrate, in the order they are processed
DATASETS = {
    "pancreas": dict(
        name="pancreas",
        batch_key="study",
        cell_type_key="cell_type",
        batches=[
            "Pancreas inDrop",
            "Pancreas CelSeq2",
            "Pancreas CelSeq",
            "Pancreas Fluidigm C1",
            "Pancreas SS2",
        ],
    ),
    "brain": dict(
        name="mouse_brain",
        batch_key="study",
        cell_type_key="cell_type",
        batches=[
            "Saunders",
            "Rosenberg",
            "Tabula_muris",
            "Zeisel",
        ],
    ),
}

In [None]:
# Enable scanpy's autosave setting for the whole notebook. The plotting calls
# further down point `sc.settings.figdir` at a per-dataset directory before each
# plot; presumably figures are then written there instead of being shown —
# TODO confirm against the installed scanpy version's settings documentation.
sc.settings.autosave = True

# Scanorama

In [None]:
def _concat_in_order(adatas):
    """Merge AnnData objects into one by successive pairwise concatenation.

    The first element is used as-is and every following element is appended
    with ``AnnData.concatenate``.  The calling code assigns per-cell arrays
    positionally afterwards, so it relies on the merged rows following the
    order of ``adatas``.

    Parameters
    ----------
    adatas : iterable of AnnData
        Objects to merge, in the desired row order.

    Returns
    -------
    AnnData or None
        The merged object; the first element itself when there is only one;
        ``None`` when ``adatas`` is empty.
    """
    merged = None
    for current in adatas:
        merged = current if merged is None else merged.concatenate(current)
    return merged


def _plot_umap(adata_to_plot, stage, data_name, batch_key, cell_type_key):
    """Compute neighbors and UMAP for ``adata_to_plot`` and plot the embedding.

    Parameters
    ----------
    adata_to_plot : AnnData
        Object to embed; it is modified in place by the scanpy calls.
    stage : str
        Label printed to stdout and used as the last component of the figure
        directory (``"before"`` or ``"after"`` in this notebook).
    data_name : str
        Dataset name used to build the figure directory.
    batch_key, cell_type_key : str
        ``obs`` columns used to colour the two UMAP panels.
    """
    print(stage)
    sc.pp.neighbors(adata_to_plot)
    sc.tl.umap(adata_to_plot)
    # The figure directory is switched right before plotting so each dataset
    # and stage ends up in its own folder.
    sc.settings.figdir = f"./results/Scanorama/{data_name}/{stage}"
    sc.pl.umap(adata_to_plot, color=[batch_key, cell_type_key], wspace=.5)


# Run the Scanorama benchmark once per configured dataset: plot the uncorrected
# data, time the correction, then plot and store the corrected embedding.
for data in ["pancreas", "brain"]:
    data_dict = DATASETS[data]
    data_name = data_dict['name']
    batch_key = data_dict['batch_key']
    cell_type_key = data_dict['cell_type_key']

    adata = sc.read(f"./data/{data_name}_normalized.h5ad")

    # Duplicate the cell-type annotation under the fixed column name 'cell_types'.
    adata.obs['cell_types'] = adata.obs[cell_type_key]

    os.makedirs(f"./results/Scanorama/{data_name}/", exist_ok=True)

    adata_list = []
    labels_array = np.array([])
    batch_array = np.array([])

    # Split the dataset into one AnnData per configured batch and collect the
    # matching label / batch vectors in the same order.
    for b in data_dict['batches']:
        batch_adata = adata[adata.obs[batch_key] == b, :]
        if batch_adata.n_obs == 0:
            # A batch name that matches no cells (e.g. a typo in DATASETS)
            # would otherwise hand an empty matrix to Scanorama.
            print(f"warning: no cells with {batch_key} == {b!r} in {data_name}; skipping batch")
            continue

        adata_list.append(batch_adata)
        labels_array = np.concatenate((labels_array, batch_adata.obs[cell_type_key]))
        batch_array = np.concatenate((batch_array, batch_adata.obs[batch_key]))

    if not adata_list:
        raise ValueError(
            f"none of the configured batches {data_dict['batches']} were found "
            f"in column {batch_key!r} of dataset {data_name!r}"
        )

    # Uncorrected reference plot.
    final_adata = _concat_in_order(adata_list)
    _plot_umap(final_adata, "before", data_name, batch_key, cell_type_key)

    # Time only the Scanorama call; perf_counter is monotonic, so the measured
    # duration cannot be distorted by system clock adjustments.
    t1 = time.perf_counter()
    corrected = scanorama.correct_scanpy(adata_list)
    t2 = time.perf_counter()

    final_adata = _concat_in_order(corrected)

    # Re-attach batch and cell-type annotations positionally instead of relying
    # on what correct_scanpy returns.
    # NOTE(review): presumably needed because some scanorama versions do not
    # carry `obs` over to the corrected objects — confirm.
    final_adata.obs[batch_key] = batch_array
    final_adata.obs[cell_type_key] = labels_array

    # Keep only a 20-component PCA of the corrected data as the stored result.
    sc.tl.pca(final_adata, svd_solver="arpack", n_comps=20)
    final_adata = sc.AnnData(X=final_adata.obsm['X_pca'], obs=final_adata.obs)

    _plot_umap(final_adata, "after", data_name, batch_key, cell_type_key)
    final_adata.write(f"./results/Scanorama/{data_name}/result_adata.h5ad")

    # Persist the correction runtime in seconds.
    with open(f'./results/scanorama_time_{data_name}.txt', 'w') as f:
        f.write(str(t2-t1))
    