In [None]:

%load_ext autoreload
%matplotlib inline
%autoreload 2
%config Completer.use_jedi = True
import matplotlib as mpl
from matplotlib import pyplot as plt
mpl.rc("figure", dpi=100)
from warnings import simplefilter
# Silence warnings for the whole notebook. Note: without a `category` argument
# this ignores EVERY warning class, not only FutureWarning.
simplefilter(action='ignore')
import scalp
from scalp.output import draw
import lmz
import numpy as np

In [None]:
# Load the scib benchmark datasets (maxdatasets / maxcells presumably cap the
# number of datasets and the cells kept per dataset) and pick one of them for
# the single-dataset demo cells below.
datasets = scalp.data.loaddata_scib(
    scalp.test_config.scib_datapath,
    maxdatasets=10,
    maxcells=1000,
)
dataset = datasets[1]

# scanorama

In [None]:
import scanpy as sc

def scanorama(dataset, dim=2):
    """Integrate the batches with Scanorama and embed the result with UMAP.

    Runs ``scalp.pca.pca`` then ``scalp.mnn.scanorama``, removes any existing
    ``obsm['umap']`` entry from every batch, computes a ``dim``-dimensional
    UMAP from the ``'scanorama'`` representation (stored as ``'umap'``) and
    returns ``scalp.transform.stack`` of the batches.

    dataset: iterable of AnnData-like batches (each has ``.obsm``).
    dim: number of UMAP output dimensions.
    """
    dataset = scalp.pca.pca(dataset)
    dataset = scalp.mnn.scanorama(dataset)
    # drop a previously computed UMAP (presumably so a stale one is not reused)
    for ds in dataset:
        ds.obsm.pop('umap', None)
    scalp.umapwrap.adatas_umap(dataset, label='umap', from_obsm='scanorama', dim=dim)
    return scalp.transform.stack(dataset)


st = scanorama(dataset)
sc.pl.umap(st, color=['batch', 'label'])

# UMAP only (PCA + UMAP, no batch correction)

In [None]:
import scanpy as sc
def umaponly(dataset, dim=2):
    """Baseline without batch correction: PCA followed by UMAP.

    Runs ``scalp.pca.pca``, removes any existing ``obsm['umap']`` entry,
    computes a ``dim``-dimensional UMAP from the ``'pca40'`` representation
    (stored as ``'umap'``) and returns ``scalp.transform.stack`` of the batches.

    dataset: iterable of AnnData-like batches (each has ``.obsm``).
    dim: number of UMAP output dimensions.
    """
    dataset = scalp.pca.pca(dataset)
    # drop a previously computed UMAP before recomputing it
    for ds in dataset:
        ds.obsm.pop('umap', None)
    scalp.umapwrap.adatas_umap(dataset, label='umap', from_obsm='pca40', dim=dim)
    return scalp.transform.stack(dataset)


stack = umaponly(dataset)
sc.pl.umap(stack, color=['batch', 'label'])

In [None]:
import scanpy as sc
def bbknn(dataset, dim=2):
    """Integrate with BBKNN (``scalp.mnn.bbknnwrap``) after PCA.

    Unlike the other wrappers in this notebook there is no
    ``scalp.umapwrap.adatas_umap`` call here; ``dim`` is forwarded to
    ``bbknnwrap``, which presumably computes the embedding itself.
    Returns ``scalp.transform.stack`` of the batches.

    dataset: iterable of AnnData-like batches (each has ``.obsm``).
    dim: embedding dimensionality passed to ``bbknnwrap``.
    """
    dataset = scalp.pca.pca(dataset)
    # drop a previously computed UMAP before bbknnwrap runs
    for ds in dataset:
        ds.obsm.pop('umap', None)
    dataset = scalp.mnn.bbknnwrap(dataset, dim=dim)
    return scalp.transform.stack(dataset)


stack = bbknn(dataset)
sc.pl.umap(stack, color=['batch', 'label'])

# Combat 

In [None]:
import scanpy as sc
def combat(dataset, dim=2):
    """Batch-correct with ComBat (``scalp.mnn.combat``) and embed with UMAP.

    No PCA step is run here: any existing ``obsm['umap']`` is removed from the
    input batches, ComBat is applied, and a ``dim``-dimensional UMAP is computed
    from the ``'combat'`` representation (stored as ``'umap'``).
    Returns ``scalp.transform.stack`` of the batches.

    dataset: iterable of AnnData-like batches (each has ``.obsm``).
    dim: number of UMAP output dimensions.
    """
    # note: this mutates the batches passed in, before combat runs
    for ds in dataset:
        ds.obsm.pop('umap', None)
    dataset = scalp.mnn.combat(dataset)
    scalp.umapwrap.adatas_umap(dataset, label='umap', from_obsm='combat', dim=dim)
    return scalp.transform.stack(dataset)


stack = combat(dataset)
sc.pl.umap(stack, color=['batch', 'label'])

# SCALP

In [None]:
import scanpy as sc
def Scalp(dataset, dim = 2):
    """Integrate via SCALP's graph construction and lay the graph out with UMAP.

    Builds the joint graph with ``scalp.mkgraph`` using the fixed parameters
    below, embeds it into ``dim`` dimensions with
    ``scalp.umapwrap.graph_umap`` (stored under ``label='umap'``) and returns
    ``scalp.transform.stack`` of the batches.
    """
    graph_params = dict(
        intra_neigh=8,
        intra_neighbors_mutual=False,
        inter_neigh=2,
        add_tree=True,
        copy_lsa_neighbors=False,
        inter_outlier_threshold=0.95,
        inter_outlier_probabilistic_removal=False,
    )
    dataset, graph = scalp.mkgraph(dataset, **graph_params)
    scalp.umapwrap.graph_umap(dataset, graph, label='umap', n_components=dim)
    return scalp.transform.stack(dataset)


stack = Scalp(dataset)
sc.pl.umap(stack, color=['batch', 'label'])

In [None]:
import ubergauss.tools as ut
from lmz import *

# Embedding dimensionality used for the benchmark runs (the demo cells above
# use each method's default of 2).
EMBED_DIM = 8

# Method callables used by `run`. `funcs` is kept as an alias for backward
# compatibility, but later cells rebind `funcs` to the method *names*, so `run`
# reads `method_funcs` to stay re-runnable after those cells.
method_funcs = [scanorama, umaponly, bbknn, combat, Scalp]
funcs = method_funcs

# Every (method index, dataset index) combination.
fuid = Range(method_funcs)
dataid = Range(datasets)
tasks = [(f, d) for f in fuid for d in dataid]


def run(fd, dim=EMBED_DIM):
    """Run one benchmark task and return the method's stacked result.

    fd: ``(method_index, dataset_index)`` pair, as produced in ``tasks``.
    dim: embedding dimensionality passed to the method (default ``EMBED_DIM``).
    """
    method_idx, data_idx = fd
    method = method_funcs[method_idx]
    return method(datasets[data_idx], dim)


mydata = ut.xxmap(run, tasks)

In [None]:
from scib.metrics import metrics
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import balanced_accuracy_score
def score(dataset):
    """Compute the selected scib integration metrics for one stacked result.

    The same object is passed as both the unintegrated and the integrated
    AnnData. The embedding is read from ``obsm['umap']`` when present,
    otherwise ``obsm['X_umap']``. Returns entry ``0`` of ``dict(...)`` of
    scib's result (presumably its single column: metric name -> value).

    Metric reference:
    https://scib.readthedocs.io/en/latest/api.html#biological-conservation-metrics
    """
    embed = 'umap' if 'umap' in dataset.obsm else 'X_umap'
    enabled_metrics = dict(
        isolated_labels_asw_=True,
        silhouette_=True,
        hvg_score_=True,
        graph_conn_=True,
        pcr_=True,
        isolated_labels_f1_=True,
        trajectory_=False,
        nmi_=True,
        ari_=True,
    )
    # named metric_table (not `sc`) to avoid shadowing the scanpy alias
    metric_table = metrics(dataset, dataset, 'batch', 'label', embed=embed, **enabled_metrics)
    return dict(metric_table)[0]
    
def score_lin(dataset, random_state=0):
    """Label separability of the embedding, measured with a linear classifier.

    Fits an ``SGDClassifier`` to predict ``obs['label']`` from the embedding
    (``obsm['umap']``, falling back to ``obsm['X_umap']``) and returns its
    accuracy on the same cells it was trained on (i.e. training accuracy, not
    a held-out estimate).

    random_state: seed for ``SGDClassifier`` so scores are reproducible;
        pass ``None`` for an unseeded fit.
    """
    # imported locally: the notebook only imports balanced_accuracy_score at
    # module level, so a bare `accuracy_score` would raise NameError
    from sklearn.metrics import accuracy_score

    y = dataset.obs['label']
    X = dataset.obsm['umap'] if 'umap' in dataset.obsm else dataset.obsm['X_umap']
    prediction = SGDClassifier(random_state=random_state).fit(X, y).predict(X)
    return accuracy_score(y, prediction)
    
def score_lin_batch(dataset, random_state=0):
    """Batch-mixing score: how poorly batch can be predicted within each label.

    For every label, fits an ``SGDClassifier`` that predicts ``obs['batch']``
    from the embedding (``obsm['umap']``, falling back to ``obsm['X_umap']``)
    of that label's cells. It then computes chance-adjusted balanced accuracy
    on those same cells. Labels that occur in fewer than two batches are
    skipped.

    Returns the unweighted mean of ``1 - adjusted_balanced_accuracy`` over the
    remaining labels. Higher means batches are harder to tell apart. The
    result is NaN if no label spans two batches.

    random_state: seed for ``SGDClassifier`` so scores are reproducible;
        pass ``None`` for an unseeded fit.
    """
    def acc(label):
        # batch-prediction score restricted to the cells of one label
        instances = dataset.obs['label'] == label
        tmp_dataset = dataset[instances]
        y = tmp_dataset.obs['batch']
        if len(np.unique(y)) < 2:
            return np.nan  # single batch: nothing to separate
        X = tmp_dataset.obsm['umap'] if 'umap' in tmp_dataset.obsm else tmp_dataset.obsm['X_umap']
        prediction = SGDClassifier(random_state=random_state).fit(X, y).predict(X)
        return balanced_accuracy_score(y, prediction, adjusted=True)

    # TODO: weight labels by their cell counts instead of a plain mean
    return np.nanmean([1 - acc(l) for l in np.unique(dataset.obs['label'])])


# Score every result in `mydata`. Later cells zip these with `tasks`, which
# assumes ut.xxmap preserves input order.
scoredics_lb = ut.xxmap(score_lin_batch, mydata)
scoredics_scib = ut.xxmap(score, mydata)
scoredics_l = ut.xxmap(score_lin, mydata)

# Linear-probe scores: `score_lin` (label accuracy) vs. `score_lin_batch` (batch mixing), two axes for a Pareto comparison

In [None]:
import pandas as pd
import seaborn as sns
# Linear-probe results (score_lin -> target 'label', score_lin_batch -> target
# 'batch'); this table does not use the scib scores.
# NOTE: this rebinds `funcs` from the method callables to their names.
funcs = 'scanorama, umaponly, bbknn, combat, Scalp'.split(', ')

results = [
    {"method": funcs[f], 'score': s, 'dataset': d, 'target': target}
    for target, per_task in (('label', scoredics_l), ('batch', scoredics_lb))
    for s, (f, d) in zip(per_task, tasks)
]
df = pd.DataFrame(results)
sns.barplot(data=df, y='score', x='method', errorbar='sd', hue='target')
plt.show()
ours = df.pivot_table(index='method', columns='target', values='score')

In [None]:
import lmz
# this is for SCIB scoring 
def doit(e):
    """Clean one scib result: drop NaN metrics and the ``hvg_overlap`` entry.

    ``e`` is anything ``dict()`` accepts (presumably the per-run output of
    ``score``, metric name -> value). Returns a new plain dict with the
    original key order preserved.
    """
    return {
        metric: value
        for metric, value in dict(e).items()
        if not np.isnan(value) and metric != 'hvg_overlap'
    }
scoredicts = lmz.Map(doit, scoredics_scib)

# Long format: one row per (method, dataset, scib metric).
# NOTE: rebinds `funcs` to the method names (same list as the linear-probe cell).
funcs = 'scanorama, umaponly, bbknn, combat, Scalp'.split(', ')
results = [
    {"method": funcs[f], 'score': ss, 'dataset': d, 'metric': scrmeth}
    for s, (f, d) in zip(scoredicts, tasks)
    for scrmeth, ss in s.items()
]

df = pd.DataFrame(results)
sns.barplot(data=df, y='score', x='method', errorbar='sd', hue='metric')
# place the metric legend outside the axes, to the right
plt.legend(loc='right', bbox_to_anchor=(1.85, 0.5), ncol=1)
plt.show()

In [None]:

def split_scib_scores(dicts):
    """Average scib metrics into (bio-conservation, batch-correction) per run.

    ``dicts`` is an iterable of ``{metric_name: value}`` mappings, one per task.
    Metrics named in the batch-correction set are averaged into the batch
    score; every other metric counts towards bio-conservation.

    Returns ``[bio_scores, batch_scores]``, two lists aligned with ``dicts``.
    A side with no metrics for a task yields NaN (``np.mean`` of an empty list).
    """
    # scib batch-correction metric keys. kBET and iLISI are included so they
    # land on the batch side (not bio-conservation) if they are ever enabled.
    batch_metrics = {'PCR_batch', 'ASW_label/batch', 'graph_conn', 'kBET', 'iLISI'}

    def split(d):
        batch = np.mean([v for k, v in d.items() if k in batch_metrics])
        bio = np.mean([v for k, v in d.items() if k not in batch_metrics])
        return bio, batch

    per_task = [split(d) for d in dicts]
    # transpose [(bio, batch), ...] -> [[bio...], [batch...]]
    return [list(column) for column in zip(*per_task)]
  

# Per-task scib averages in long format (one row per task and target).
scr_l, scr_b = split_scib_scores(scoredicts)
results = [
    {"method": funcs[f], 'score': s, 'dataset': d, 'target': target}
    for target, per_task in (('bioconservation_scib_avg', scr_l), ('batch_scib_avg', scr_b))
    for s, (f, d) in zip(per_task, tasks)
]
df = pd.DataFrame(results)
sns.barplot(data=df, y='score', x='method', errorbar='sd', hue='target')
plt.show()
theirs = df.pivot_table(index='method', columns='target', values='score')

In [None]:
from ubergauss.optimization import pareto_scores
# `df` here is the scib-average table (targets bioconservation_scib_avg /
# batch_scib_avg) from the cell above, not the linear-probe table.
pareto_scores(df)

In [None]:
# raw per-task scib-average records (as last assigned in the split_scib_scores cell)
results

In [None]:
# Correlate the linear-probe scores (`ours`) with the scib averages (`theirs`)
# across methods. The two pivot tables are aligned on their method index
# explicitly instead of relying on identical row order.
_paired = ours.join(theirs, how='inner')
(np.corrcoef(_paired['batch'], _paired['batch_scib_avg'])[0, 1],
 np.corrcoef(_paired['label'], _paired['bioconservation_scib_avg'])[0, 1])