In [None]:
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import pickle
import pandas as pd
%matplotlib inline
import seaborn as sns

In [None]:
"""
Utility functions and classes for cross species
analysis

@yhr91
"""

from sklearn.metrics import euclidean_distances
import matplotlib.pyplot as plt
import matplotlib
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_distances
from scipy.stats import spearmanr
import plotly.express as px
import numpy as np
import pandas as pd
import warnings
import scanpy as sc
from sklearn.metrics import adjusted_mutual_info_score, rand_score
from collections import Counter
from scipy.stats import mode
import operator

# --------

class cross_species_acc():
    """
    Class for calculating cross species accuracy metrics.

    On construction one centre is computed per (species, label) pair: the
    centroid, or with ``medoid=True`` the cell nearest that centroid. Then,
    for every label present in all species, the target-species centre
    nearest (under ``metric``) to the base-species centre is found.

    Attributes set by the constructor:
    - cross_species_label_matches: number of labels whose nearest target
      centre carries the same label
    - cross_species_label_matches_names: those labels
    - cross_species_label_matches_names_all: (label, predicted label) pairs
    - cross_species_label_dist: summed base-to-target distance between
      same-label centres
    - cross_species_label_norm_dist: the same distances, each divided by the
      base cluster's largest pairwise euclidean distance
    """

    # Maps the ``space`` argument to the ``.obsm`` key holding that embedding.
    _SPACE_KEYS = {
        'umap': 'X_umap',
        'samap': 'X_umap_samap',
        'scanorama': 'X_scanorama',
        'harmony': 'X_harmony',
    }

    def __init__(self, adata, base_species='human', 
                 target_species='mouse', label_col='CL_class_coarse', 
                 metric='cosine', medoid=False, space='raw'):
        """
        Parameters
        ----------
        adata : AnnData with a 'species' column and ``label_col`` in ``.obs``.
        base_species, target_species : values of ``adata.obs['species']``.
        label_col : ``.obs`` column holding the cluster labels.
        metric : 'cosine' or 'euclidean' distance between centres.
        medoid : if True use the cell nearest each centroid instead of the centroid.
        space : 'raw' (``adata.X``) or one of 'umap', 'samap', 'scanorama', 'harmony'.
        """
        self.adata = adata
        self.base_species = base_species
        self.target_species = target_species
        self.label_col = label_col
        self.metric = metric
        self.medoid = medoid
        self.space = space

        # Calculate accuracy metrics
        self.calc_cross_species_label_matches()

    def _get_data(self, subset):
        """Return the feature matrix of ``subset`` for ``self.space`` as a dense array."""
        if self.space == 'raw':
            X = subset.X
            # .X may be scipy-sparse or already dense.
            return X.toarray() if hasattr(X, 'toarray') else np.asarray(X)
        return np.asarray(subset.obsm[self._SPACE_KEYS[self.space]])

    def _distance_fn(self):
        """Pairwise distance function selected by ``self.metric``."""
        if self.metric == 'cosine':
            return cosine_distances
        if self.metric == 'euclidean':
            return euclidean_distances
        raise ValueError(f"Unsupported metric {self.metric!r}; use 'cosine' or 'euclidean'")

    def find_all_species_centres(self):
        """
        Finds all species-specific centres given an AnnData object.

        Returns a DataFrame indexed by label with one column per species
        (centre vectors) plus a 'size' column: the largest pairwise euclidean
        distance inside the base-species cluster. Labels missing from any
        column are dropped.
        """
        if self.space != 'raw' and self.space not in self._SPACE_KEYS:
            raise ValueError(f"Unsupported space {self.space!r}")

        centres = {'size': {}}
        for species in self.adata.obs['species'].unique():
            centres[species] = {}
            species_set = self.adata[self.adata.obs['species'] == species]

            for label in species_set.obs[self.label_col].unique():
                subset = species_set[species_set.obs[self.label_col] == label]
                if len(subset) < 1:
                    continue
                subset_data = self._get_data(subset)

                centroid = np.mean(subset_data, 0)
                centres[species][label] = (self.get_medoid(subset_data, centroid)
                                           if self.medoid else centroid)

                # Cluster diameter, used to normalise distances
                if species == self.base_species:
                    centres['size'][label] = np.max(euclidean_distances(subset_data))

        return pd.DataFrame(centres).dropna()

    def calc_cross_species_label_matches(self):
        """
        Compute cross-species label matches and distances and store them on self.

        - matches: number of cluster centres in base species whose nearest
          target-species centre has the same label
        - dist: distance between each base-species centre and the target
          centre with the same label, summed over labels
        - norm_dist: as ``dist`` but each term divided by the base cluster size

        TODO: This is not generalized to more than 2 species
        """
        # Suppress warnings only for this computation; the caller's filters
        # are restored on exit.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            centres = self.find_all_species_centres()
            dist_fn = self._distance_fn()
            target_centres = np.vstack(centres.loc[:, self.target_species].values)

            dist = 0
            norm_dist = 0
            matches_names = []
            matches_names_all = []
            for idx, ctype in enumerate(centres.index):
                base = centres.loc[ctype, self.base_species]
                distances = dist_fn(np.atleast_2d(base), target_centres)[0]

                pred_match = np.argmin(distances)
                if pred_match == idx:
                    matches_names.append(ctype)

                matches_names_all.append((ctype, centres.index[pred_match]))
                dist += distances[idx]
                # 'size' is 0 for single-cell base clusters, which makes this
                # term non-finite.
                norm_dist += distances[idx] / centres.loc[ctype, 'size']

        self.cross_species_label_dist = dist
        self.cross_species_label_norm_dist = norm_dist
        self.cross_species_label_matches = len(matches_names)
        self.cross_species_label_matches_names = matches_names
        self.cross_species_label_matches_names_all = matches_names_all

    def get_medoid(self, data, centroid):
        """
        Return the row of ``data`` closest (euclidean) to ``centroid``.

        Distances are measured directly from the centroid, so a row that
        coincides with the centroid is returned correctly.
        """
        dists = np.linalg.norm(np.asarray(data) - np.asarray(centroid), axis=1)
        return data[np.argmin(dists)]
        
        
# --------


class embedding_CL_comparison():
    """
    Class for comparing embedding with cell ontology.

    For every label in ``label_col``, all labels are ranked twice:
    - by distance between cluster centres in the embedding
    - by cell-ontology similarity of their ``CL_ID_col`` terms (read from
      ``CL_SIMILARITY_PATH``)

    The two rankings are then compared per label, using Spearman correlation
    (``spearman_corr``) and top-k overlap (``hits_at_k``).
    """

    # CSV of pairwise cell-ontology similarities, indexed by CL ID on both axes.
    CL_SIMILARITY_PATH = '/dfs/project/cross-species/data/lung/shared/CL_similarity_RW.csv'

    def __init__(self, adata, label_col='CL_class_coarse', CL_ID_col='CL_ID_coarse',
                 metric='cosine', features='raw'):
        self.adata = adata
        self.label_col = label_col
        self.CL_ID_col = CL_ID_col
        self.metric = metric
        self.features = features
        self.labels = self.adata.obs[self.label_col].unique()
        self.centres = []
        self.centres_ranked = []
        self.CL_centres_ranked = []

        # Suppress warnings only while computing; caller's filters are restored.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Get centres, nns and ranks
            self.get_centre_ranks()
            self.get_CL_ranks()

            # Calculate metrics
            self.spearman_corr = {}
            self.hits_at_k = {}
            for id_ in self.labels:
                CL_rank = self.CL_centres_ranked[id_]
                emb_rank = self.centres_ranked[id_]
                self.spearman_corr[id_] = spearmanr(CL_rank, emb_rank)[0]
                self.hits_at_k[id_] = self.get_hits_topk(CL_rank, emb_rank)

    def find_centre(self, cluster, medioid=False):
        """
        Find cluster centre: either the centroid (row mean) or the medioid
        (the row with the smallest summed euclidean distance to all rows).
        """
        if medioid:
            medioid_idx = np.argmin(euclidean_distances(cluster).sum(0))
            row = cluster[medioid_idx]
            # Sparse rows are densified; dense rows are returned as-is.
            return row.toarray() if hasattr(row, 'toarray') else row
        return np.mean(cluster, 0)

    def get_outlier_idx(self, CL_centres, centres, k=10):
        """Positions where either ranking holds a value below ``k``."""
        return [i for i, (cl_val, emb_val) in enumerate(zip(CL_centres, centres))
                if cl_val < k or emb_val < k]

    def get_hits_topk(self, CL_centres, centres, k=10):
        """
        Get numbers of matches within top k
        """
        return len(set(CL_centres[:k]).intersection(set(centres[:k])))

    def get_centre_ranks(self):
        """
        Get nn ranks for cluster centres: for each label, the argsort of its
        distances to every label's centre.
        """
        if self.features != 'raw':
            raise ValueError(f"Unsupported features {self.features!r}; only 'raw' is implemented")
        for cell_type in self.labels:
            self.centres.append(self.find_centre(
                self.adata[self.adata.obs[self.label_col] == cell_type].X))
        self.centres = np.vstack(self.centres)

        if self.metric == 'euclidean':
            centres_dist = euclidean_distances(self.centres)
        elif self.metric == 'cosine':
            centres_dist = cosine_distances(self.centres)
        else:
            raise ValueError(f"Unsupported metric {self.metric!r}; use 'cosine' or 'euclidean'")

        self.centres_ranked = dict(zip(self.labels, np.argsort(centres_dist)))

    def get_CL_ranks(self):
        """
        Get nn ranks for cell ontology cluster centres: for each label, the
        CL terms ordered by decreasing similarity to its own CL term.
        """
        all_CL_distances = pd.read_csv(self.CL_SIMILARITY_PATH, index_col=0)

        CL_ids = self.adata.obs[self.CL_ID_col].unique()
        CL_sim_matrix = all_CL_distances.loc[CL_ids, CL_ids]

        # label -> CL ID (last occurrence wins), then inverted to CL ID -> label
        ID_dict = dict(zip(self.adata.obs[self.label_col], self.adata.obs[self.CL_ID_col]))
        inv_ID_dict = {v: k for k, v in ID_dict.items()}

        self.CL_centres_ranked = {inv_ID_dict[k]: v for k, v in
                                  zip(CL_ids, np.argsort(-CL_sim_matrix.values))}

    def plot_rank_scatter(self, outlier=False):
        """
        Create rank scatter plot between embedding nn and CL nn, one panel
        per label. With ``outlier=True`` only positions flagged by
        ``get_outlier_idx`` are plotted.
        """
        n_labels = len(self.labels)
        n_cols = 5
        # Keep the 5x5 grid as the minimum, but add rows so no label is dropped.
        n_rows = max(5, -(-n_labels // n_cols))
        fig, axs = plt.subplots(n_rows, n_cols, sharex=True, sharey=True,
                                figsize=[20, 3 * n_rows])
        # Ranks run from 0 to n_labels - 1; draw the y = x reference line over that span.
        diag_end = max(n_labels - 1, 0)

        for it, id_ in enumerate(self.labels):
            ax = axs[it // n_cols, it % n_cols]
            CL_rank = self.CL_centres_ranked[id_]
            emb_rank = self.centres_ranked[id_]
            if outlier:
                outlier_idx = self.get_outlier_idx(CL_rank, emb_rank)
                CL_rank, emb_rank = CL_rank[outlier_idx], emb_rank[outlier_idx]
            ax.scatter(CL_rank, emb_rank)
            ax.set_title(id_)
            ax.plot([0, diag_end], [0, diag_end], 'k')

    def _plot_metric_barh(self, values, xlabel, title, xlim):
        """Horizontal bar chart of a {label: value} dict, sorted by value."""
        plot_df = (pd.DataFrame.from_dict(values, orient='index')
                   .rename(columns={0: 'Value'})
                   .sort_values('Value'))

        plt.figure(figsize=[8, 12])
        plt.barh(plot_df.index, plot_df['Value'])
        plt.ylabel('Cell Type')
        plt.xlabel(xlabel)
        plt.title(title)
        plt.xlim(xlim)

    def plot_hits_at_k(self):
        """Bar chart of hits@k per label."""
        self._plot_metric_barh(self.hits_at_k, 'Hits @ k',
                               'Hits @ k (Embedding space compared to Cell Ontology)',
                               [0, 10])

    def plot_spearman(self):
        """Bar chart of Spearman correlation per label."""
        self._plot_metric_barh(self.spearman_corr, 'Spearman Correlation',
                               'Spearman Correlation (Embedding space compared to Cell Ontology)',
                               [-1, 1])
          
        
# --------

## KNN analysis per cell
## TODO integrate these functions into cross_species_acc class

def get_knn_label(cell_names, adata, col):
    """
    Most common ``col`` label among the cells named in ``cell_names``.

    The first entry of pandas' ``value_counts`` ordering is returned. Ties
    therefore follow that ordering and are not deliberately randomised.
    """
    neighbour_labels = adata[cell_names].obs[col]
    label_counts = neighbour_labels.value_counts()
    return label_counts.index[0]


def _k_smallest_idx(values, k):
    """Indices of the k smallest entries of 1-D ``values`` (unordered); all indices if k >= len."""
    if k >= len(values):
        return np.arange(len(values))
    return np.argpartition(values, k)[:k]


def cross_species_knn_all(adata, k=1, species='human', space='raw',
                          col = 'cell_type', metric='euclidean',
                          verbose = False, consider_same_species=False):
    """
    Runs cross species k nearest neighbour search on every cell of ``species``.

    Parameters
    ----------
    adata : AnnData with 'species' and ``col`` columns in ``.obs``.
    k : number of neighbours.
    species : query species; only its cells are used as queries.
    space : 'raw' (``adata.X``) or one of 'umap', 'samap', 'scanorama', 'harmony'.
    col : ``.obs`` label column.
    metric : 'euclidean' or 'cosine'.
    verbose : unused.
    consider_same_species : if False, candidate neighbours are all cells of
        other species. If True, candidates are all cells except those sharing
        both species and label with the query cell.

    Returns
    -------
    list of (own label, most common neighbour label as str), one per query cell.
    ``adata`` is not modified.
    """
    space_keys = {'umap': 'X_umap', 'samap': 'X_umap_samap',
                  'scanorama': 'X_scanorama', 'harmony': 'X_harmony'}
    if space == 'raw':
        X = adata.X
    elif space in space_keys:
        X = adata.obsm[space_keys[space]]
    else:
        raise ValueError(f"Unsupported space {space!r}")

    # Slow step: full pairwise distance matrix
    if metric == 'euclidean':
        dist_mat = euclidean_distances(X)
    elif metric == 'cosine':
        dist_mat = cosine_distances(X)
    else:
        raise ValueError(f"Unsupported metric {metric!r}; use 'euclidean' or 'cosine'")

    labels = adata.obs[col]
    species_idx = np.where(adata.obs['species'] == species)[0]

    if consider_same_species:
        # Combined species/label key; computed locally rather than stored in adata.obs.
        temp_label = (adata.obs['species'].astype(str) + '_' + labels.astype(str)).to_numpy()
        nns = []
        for idx in species_idx:
            possible_nbrs = np.where(temp_label != temp_label[idx])[0]
            row = dist_mat[idx, possible_nbrs]
            nns.append(possible_nbrs[_k_smallest_idx(row, k)])
    else:
        nonspecies_idx = np.where(adata.obs['species'] != species)[0]
        reduced_dist_mat = dist_mat[np.ix_(species_idx, nonspecies_idx)]
        nns = [nonspecies_idx[_k_smallest_idx(row, k)] for row in reduced_dist_mat]

    # Positional lookups via .iloc / numpy (obs is indexed by cell names).
    label_strs = labels.astype('str').to_numpy()
    return [(labels.iloc[x], Mode(label_strs[y])) for x, y in zip(species_idx, nns)]


def cluster_knn(cluster_knn_df, label):
    """
    Cross-species KNN label distribution for a single source cluster.

    ``cluster_knn_df`` has one row per cell, with columns 'Source_Cell' (the
    cell's own label) and 'Cross_Species_KNN' (its majority cross-species
    neighbour label). The result has columns 'Cross_Species_KNN_Label',
    'Score' and 'Source_Cluster'. 'Score' is the fraction of the cluster's
    cells with that neighbour label, and rows are most frequent first.
    """
    pair_counts = cluster_knn_df.loc[cluster_knn_df['Source_Cell'] == label].value_counts()
    knn_labels = [pair[1] for pair in pair_counts.index]
    counts = pair_counts.to_numpy()
    result = pd.DataFrame({
        'Cross_Species_KNN_Label': knn_labels,
        'Score': counts / counts.sum(),
        'Source_Cluster': pair_counts.index[0][0],
    })
    return result


def cluster_knn_all(all_nbrs):
    """
    Apply ``cluster_knn`` to every source cluster.

    ``all_nbrs`` is a sequence of (own label, majority cross-species
    neighbour label) pairs. Returns a list of per-cluster DataFrames, in
    order of first appearance of each source label.
    """
    knn_df = pd.DataFrame(all_nbrs).rename(
        columns={0: 'Source_Cell', 1: 'Cross_Species_KNN'})

    per_cluster = []
    for source_label in knn_df['Source_Cell'].unique():
        per_cluster.append(cluster_knn(knn_df, source_label))
    return per_cluster


def plot_cluster_knn_bar(df, source='human', other='mouse', ax=None, title=None):
    """
    Creates a stacked horizontal bar plot showing, for each source cluster,
    how its cells split across cross-species KNN labels.

    Parameters
    ----------
    df : DataFrame with 'Source_Cluster', 'Cross_Species_KNN_Label' and
        'Score' columns (e.g. concatenated ``cluster_knn`` outputs).
    source : text for the y-axis label.
    other : not used by this function.
    ax : Axes to draw on; if None a new 6x17 figure is created.
    title : optional axes title.

    Each cluster gets one bar row, and its segments are stacked in decreasing
    Score order. The first segment is red. When that segment's Score is
    >= 0.5, its KNN label is written on the bar: white if it equals the
    cluster's own label, black otherwise. Returns None.
    """
    if ax is None:
        fig, (ax) = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(6, 17), frameon=False)
    # Running bar length per cluster (left edge of the next segment).
    bars = defaultdict(int)
    # Number of segments drawn so far per cluster; used to index colors_list.
    colors = defaultdict(int)
    # y position of the most recently started cluster row.
    tick = -1; 
    tick_pos = {}
    # Segment colours by position within a cluster. Index 0 is never used
    # (segment counts start at 1), and only index 1 is 'r'. Segments 29 and
    # later are white; a cluster with more than 78 segments would raise IndexError.
    colors_list = ['','r','g','b','y','k','purple','g','b','y','k','g','b','y',
                                                       'k','g','b','y','k',
                                                       'k','g','b','y','k',
                                                       'k','g','b','y','k'] + ["w"]*50
    # Descending sort: clusters in descending order, largest Score first within each.
    df = df.sort_values(['Source_Cluster', 'Score'], ascending=False)

    for i in df.iterrows():
        # Text colour for the label annotation; white marks a same-label match.
        color = 'k'
        x = i[1]['Source_Cluster']
        y = i[1]['Score']        

        if i[1]['Cross_Species_KNN_Label'] == i[1]['Source_Cluster']:
            color= 'w'
        left = bars[x]
        bars[x] = bars[x] + y
        colors[x] = colors[x] + 1
        if colors_list[colors[x]] == 'r':    
            # First segment of this cluster: open a new bar row.
            tick = tick+1
            tick_pos[x] = tick
            if bars[x]>=0.50:
                ax.text(0.1,tick-0.2,i[1]['Cross_Species_KNN_Label'], color=color)
        ax.barh(tick_pos[x], y, left=left, color=colors_list[colors[x]], alpha=0.5)
        
    keys = sorted(tick_pos.keys())
    vals = [tick_pos[k] for k in keys]
    ax.set_yticks(vals, keys)
    ax.set_ylabel(source)
    # NOTE(review): Scores appear to be fractions in [0, 1] (see cluster_knn), not percentages — confirm the axis label.
    ax.set_xlabel('Percentage of cells with cross-species KNN class')
    if title is not None:
        ax.set_title(title)
    
    return 

## ---------------------------------
## Alignment scores
## -----------------------------------

def alignment_score(fname, col='cell_type', space='raw', k=1,
                    species='human', consider_same_species=False):
    """
    Per-cluster cross-species KNN label distribution for one query species.

    Reads the AnnData at ``fname``. When ``space == 'umap'``, PCA, neighbours
    and UMAP are computed first. Cosine cross-species KNN is then run for
    cells of ``species``. The result has one row per (Source_Cluster,
    Cross_Species_KNN_Label), with 'Score' giving the fraction of that
    cluster's cells assigned the label.
    """
    adata = sc.read_h5ad(fname)
    if space == 'umap':
        # The UMAP embedding is recomputed from a fresh PCA / kNN graph.
        sc.pp.pca(adata, n_comps=50)
        sc.pp.neighbors(adata, n_neighbors=15)
        sc.tl.umap(adata)

    nbrs = cross_species_knn_all(adata, k=k, species=species, space=space, col=col,
                                 metric='cosine',
                                 consider_same_species=consider_same_species)
    per_cluster = cluster_knn_all(nbrs)
    return pd.concat(per_cluster).reset_index(drop=True)

def compare_matches(alignments, true_alignments):
    """
    Keep only alignment rows whose (Source_Cluster, Cross_Species_KNN_Label)
    pair appears in ``true_alignments`` (inner join on both columns).
    """
    join_keys = ['Source_Cluster', 'Cross_Species_KNN_Label']
    return pd.merge(alignments, true_alignments, how='inner', on=join_keys)

def score_matches(alignments, true_alignments, thresh=0.5, ret_matches=False):
    """
    Correct cluster alignments whose Score exceeds ``thresh``.

    Returns the matching rows when ``ret_matches`` is true, otherwise the
    number of such rows.
    """
    matched = compare_matches(alignments, true_alignments)
    above = matched['Score'] > thresh
    return matched[above] if ret_matches else sum(above)

def create_comparison_plot_df(all_cluster_nbrs1, all_cluster_nbrs2, true_map):
    """
    Join the correct-label rows of two methods' cluster alignments on
    Source_Cluster. Overlapping columns get '_x' (method 1) and '_y'
    (method 2) suffixes.
    """
    scored = [compare_matches(nbrs, true_map)
              for nbrs in (all_cluster_nbrs1, all_cluster_nbrs2)]
    return scored[0].merge(scored[1], on='Source_Cluster')

def get_comparison_plot(plot_df, bars=2, labels_=['Method1', 'Method2']):
    """
    Grouped bar chart of per-cluster scores for one or two methods.

    ``plot_df`` needs 'Source_Cluster' and 'Score_x' columns, plus 'Score_y'
    when ``bars == 2``. ``labels_`` supplies the legend entries. Draws into a
    new 20x5 figure and returns None.
    """
    plt.figure(figsize=[20, 5])
    ax = plt.gca()

    cluster_names = plot_df['Source_Cluster'].values
    positions = np.arange(len(cluster_names))
    bar_width = 0.35

    ax.bar(positions - bar_width / 2, plot_df['Score_x'].values, bar_width, label=labels_[0])
    if bars == 2:
        ax.bar(positions + bar_width / 2, plot_df['Score_y'].values, bar_width, label=labels_[1])

    ax.set_ylabel('% of nearest neighbors of\n correct cross species label')
    ax.set_title('Cross species cell type alignment')
    plt.xticks(positions, cluster_names, rotation='vertical')
    ax.set_xticklabels(cluster_names)
    ax.legend()

    
def get_cell_alignment(align_df, true_df, count_df):
    """
    Fraction of source cells whose cross-species KNN label is a correct one.

    Parameters
    ----------
    align_df : DataFrame with 'Source_Cluster', 'Cross_Species_KNN_Label'
        and 'Score' (fraction of the cluster's cells with that KNN label).
    true_df : DataFrame of correct (Source_Cluster, Cross_Species_KNN_Label) pairs.
    count_df : DataFrame with 'Source_Cluster' and 'count' (cells per cluster).

    Returns
    -------
    float : sum over correct pairs of Score * count, divided by the total
        number of cells in ``count_df``. Clusters with no correct pair
        contribute 0. Each cluster's cells enter the denominator once, even
        when it has several correct labels. The input frames are not modified.
    """
    key_cols = ['Source_Cluster', 'Cross_Species_KNN_Label']
    # Compare keys as strings (e.g. int vs str cluster ids), on copies.
    align = align_df.astype({c: 'str' for c in key_cols})
    truth = true_df.astype({c: 'str' for c in key_cols})
    counts = count_df.astype({'Source_Cluster': 'str'})

    df = align.merge(truth, on=key_cols)
    df = df.merge(counts, on='Source_Cluster', how='outer').fillna(0)
    return (df['Score'] * df['count']).sum() / counts['count'].sum()


def _cluster_count_df(labels):
    """Cells per cluster as a DataFrame with columns ['Source_Cluster', 'count']."""
    # rename_axis + reset_index(name=...) gives the same column names on
    # pandas 1.x and 2.x (value_counts changed its index/result naming in 2.0).
    return labels.value_counts().rename_axis('Source_Cluster').reset_index(name='count')


def get_alignment_metrics(fname, out_label = 'labels2', orig_label='CL_class_coarse',
                          space='raw', species=['human','mouse'], k=1,
                 true_labels_path=None, ret_matches = False, consider_same_species=False):
    """
    Function for computing evaluation metrics for embedding
    
    Outputs:
    - species1_nn: Number of cross-species label matches (species 1)
    - species2_nn: Number of cross-species label matches (species 2)
    - union_nn: Number of cross-species label matches in either species 
    - mutual_nn: Number of cross-species label matches in both species 
    - cell_score1: Percentage of cells in species 1 with cross species nn of correct label
    - cell_score2: Percentage of cells in species 2 with cross species nn of correct label
    - cell_score_combine: Percentage of cells in both species with cross species nn of correct label
    - centroid_matches_species1: Number of species 1 centroids that are nn with correct species 2 centroid
    - centroid_matches_species2: Number of species 2 centroids that are nn with correct species 1 centroid
    - centroid_matches_union: Union of centroid lists
    - medoid_matches_species1: Number of species 1 medoids that are nn with correct species 2 medoid
    - medoid_matches_species2: Number of species 2 medoids that are nn with correct species 1 medoid
    - medoid_matches_union: Union of medoid lists

    Returns the ``results`` dict, or ``(results, matches, alignments)`` when
    ``ret_matches`` is true. Returns None (after printing an error) when no
    true-label file is known for ``orig_label``.
    """
    
    # TODO: This is a very ugly function that needs to be made into a class alongwith
    # the functions above it

    # Cross-species-only KNN label distributions, one per query species.
    alignments = []
    print('Finding nns for species 1')
    alignments.append(alignment_score(fname, out_label, space=space, species=species[0], k=k,
                                      consider_same_species=consider_same_species))
    print('Finding nns for species 2')
    alignments.append(alignment_score(fname, out_label, space=space, species=species[1], k=k,
                                      consider_same_species=consider_same_species))

    # Get true labels
    if true_labels_path is None:
        if orig_label == 'CL_class_coarse':
            true_labels_path = '/dfs/project/cross-species/data/lung/shared/true_CL_class_coarse.csv'
        elif orig_label == 'cell_type':
            true_labels_path = '/dfs/project/cross-species/data/lung/shared/true_cell_type.csv'
        else:
            print("ERROR: True labels unavailable for this column!, Please set manually")
            return

    true_labels = pd.read_csv(true_labels_path, index_col=0)
    # First column of the true-label table whose name contains each species name.
    cols = [[c for c in true_labels.columns if sp in c][0] for sp in species[:2]]
    true_dfs = [
        true_labels.rename(columns={cols[0]: 'Source_Cluster', cols[1]: 'Cross_Species_KNN_Label'}),
        true_labels.rename(columns={cols[1]: 'Source_Cluster', cols[0]: 'Cross_Species_KNN_Label'}),
    ]

    results = {}

    # Cluster-level matches: correct label pairs held by > 50% of a cluster's cells.
    matches = [score_matches(alignments[i], true_dfs[i], thresh=0.5, ret_matches=True)
               for i in range(2)]
    results['species1_nn'] = len(matches[0])
    results['species2_nn'] = len(matches[1])

    # A species-1 match (a -> b) corresponds to a species-2 match (b -> a).
    def _merge_matches(how):
        return matches[0].merge(matches[1],
                                left_on=['Source_Cluster', 'Cross_Species_KNN_Label'],
                                right_on=['Cross_Species_KNN_Label', 'Source_Cluster'],
                                how=how)

    results['union_nn'] = len(_merge_matches('outer'))
    results['mutual_nn'] = len(_merge_matches('inner'))

    # Per-cell alignment scores, combined by each species' share of cells.
    adata = sc.read_h5ad(fname)
    adata = adata[adata.obs['species'].isin(species)]
    ratio1 = sum(adata.obs['species'] == species[0]) / len(adata)
    ratio2 = sum(adata.obs['species'] == species[1]) / len(adata)

    count_dfs = [_cluster_count_df(adata[adata.obs['species'] == sp].obs[out_label])
                 for sp in species[:2]]
    cell_scores = [get_cell_alignment(alignments[i], true_dfs[i], count_dfs[i])
                   for i in range(2)]
    results['cell_score1'] = cell_scores[0]
    results['cell_score2'] = cell_scores[1]
    results['cell_score_combine'] = ratio1 * cell_scores[0] + ratio2 * cell_scores[1]

    # Centre-level nearest-neighbour matches (centroid, then medoid).
    for centre, flag in [('centroid', False), ('medoid', True)]:
        c_nn1 = cross_species_acc(adata, base_species=species[0], target_species=species[1],
                                  label_col=out_label, medoid=flag, space=space)
        c_nn2 = cross_species_acc(adata, base_species=species[1], target_species=species[0],
                                  label_col=out_label, medoid=flag, space=space)

        results[centre + '_matches_species1'] = c_nn1.cross_species_label_matches
        results[centre + '_matches_species2'] = c_nn2.cross_species_label_matches
        results[centre + '_matches_union'] = len(
            set(c_nn1.cross_species_label_matches_names).union(
                set(c_nn2.cross_species_label_matches_names)))

    if ret_matches:
        return (results, matches, alignments)
    return results
    
    
def get_louvain_metrics(fname, label='cell_type',
                        resolutions=(10, 5, 2, 1, 0.8, 0.5, 0.4, 0.2, 0.1, 0.01, 0.001, 0.0001),
                        true_clusters_path='/dfs/project/cross-species/data/lung/shared/true_cell_type_clusters.csv'):
    """
    Compare Louvain clusterings of a joint embedding with ground-truth
    cross-species cell-type clusters, over a sweep of resolutions.

    Parameters
    ----------
    fname : path to an .h5ad file with 'species' and ``label`` columns in ``.obs``.
    label : ``.obs`` column with cell-type labels. If its first value contains
        no '_', the species name is appended ('<label>_<species>') before
        joining with the 'cell_type' column of the ground-truth table.
    resolutions : Louvain resolutions to evaluate.
    true_clusters_path : CSV with 'cell_type' and 'cluster' columns.

    Returns
    -------
    dict mapping resolution -> {'RI': rand score, 'AMI': adjusted mutual information}
    between the ground-truth 'cluster' and the Louvain cluster of each joined row.
    """
    adata = sc.read_h5ad(fname)

    # Louvain needs a neighbour graph. Failures are tolerated (presumably a
    # usable graph is already stored in the file — TODO confirm), but surfaced.
    try:
        sc.pp.pca(adata, n_comps=50)
        sc.pp.neighbors(adata, n_neighbors=15)
    except Exception as err:
        warnings.warn(f'PCA/neighbors recomputation failed: {err!r}')

    # Make labels species-specific; presumably the ground-truth table uses
    # '<cell type>_<species>' names — TODO confirm against the CSV.
    if '_' not in str(adata.obs[label].iloc[0]):
        adata.obs[label] = adata.obs[label].astype(str).str.cat(
            adata.obs['species'].astype(str), sep='_')

    # Loop-invariant: read the ground truth once.
    true_clusters_all = pd.read_csv(true_clusters_path, index_col=0)

    metrics = {}
    for resolution in resolutions:
        print('Calculating for resolution: ', str(resolution))
        sc.tl.louvain(adata, resolution=resolution)

        true_clusters = true_clusters_all.merge(adata.obs, left_on='cell_type', right_on=label)
        predicted = true_clusters['louvain'].astype('int')
        metrics[resolution] = {
            'RI': rand_score(true_clusters['cluster'], predicted),
            'AMI': adjusted_mutual_info_score(true_clusters['cluster'], predicted),
        }

    return metrics


# Maria's cell type reannotation function
def reannotate(adata, source='human', target='mouse', label='cell_type',
               resolutions=(2, 1, 0.8, 0.6, 0.4, 0.2, 0.1)):
    """
    Maria's cell type reannotation: transfer ``label`` from ``source`` cells
    to ``target`` cells through shared Louvain clusters.

    Louvain is run at each resolution in turn, finest first. For every
    cluster containing both species, each target cell in it that is not yet
    annotated receives the most common source label of that cluster. An
    assignment from a finer resolution is therefore kept over coarser ones.

    Source cells keep their own label. The result is written to
    ``adata.obs['reannotated_' + source]``, with 'None' for cells that never
    shared a cluster with source cells. ``adata.obs['louvain']`` is left
    holding the last resolution's clusters.
    """
    # Accumulates across resolutions so the "not yet annotated" check is meaningful.
    reannotated = {}
    for resolution in resolutions:
        sc.tl.louvain(adata, resolution)

        for cluster_id in set(adata.obs['louvain']):
            current_cluster = adata[adata.obs['louvain'] == cluster_id]
            species_present = set(current_cluster.obs['species'])
            if source not in species_present or target not in species_present:
                continue

            cluster_source = current_cluster[current_cluster.obs['species'] == source]
            cluster_target = current_cluster[current_cluster.obs['species'] == target]
            major_cell_type = Counter(cluster_source.obs[label]).most_common(1)[0][0]
            for cell in cluster_target.obs_names:
                reannotated.setdefault(cell, major_cell_type)

    # Source cells always keep their original label.
    adata_source = adata[adata.obs['species'] == source]
    reannotated.update(zip(adata_source.obs_names, adata_source.obs[label]))
    adata.obs['reannotated_' + source] = [reannotated.get(cell, 'None')
                                          for cell in adata.obs_names]
    
def get_reannotation_metrics(fname, label='cell_type', source='human', target='mouse'):
    """
    Score label transfer from ``source`` to ``target`` cells.

    Loads the AnnData at ``fname``, builds a kNN graph and runs
    ``reannotate``. Every ``target`` cell is then joined to the rows of the
    lung ``true_cell_type.csv`` mapping whose '<source>_cell_type' equals the
    cell's transferred label. Returns the fraction of joined rows whose
    '<target>_cell_type' equals the cell's own ``label``.

    If the first value of ``label`` contains '_' (presumably a
    species-suffixed label), the 'labels2' column is used instead.
    """
    # NOTE: the ground-truth mapping below is specific to the lung dataset.
    adata = sc.read_h5ad(fname)
    sc.pp.neighbors(adata)

    if '_' in adata.obs[label].values[0]:
        label = 'labels2'

    reannotate(adata, source=source, target=target, label=label)
    target_cells = adata[adata.obs['species'] == target]

    true_labels_path = '/dfs/project/cross-species/data/lung/shared/true_cell_type.csv'
    true_labels = pd.read_csv(true_labels_path, index_col=0)
    results_df = true_labels.merge(target_cells.obs, left_on=source + '_cell_type',
                                   right_on='reannotated_' + source)
    return np.mean(results_df[target + '_cell_type'] == results_df[label])

## ---------------------------------
## General helper functions
## -----------------------------------

def plotly_scatter(adata, embed = 'X_umap', label= 'cell_type', 
                   hover_cols = ['cell_type', 'species']):
    """
    Interactive plotly scatter of the first two components of
    ``adata.obsm[embed]``, coloured and hover-named by ``label``, with
    ``hover_cols`` from ``adata.obs`` shown on hover.
    """
    points = pd.DataFrame(adata.obsm[embed])
    for column in [label, *hover_cols]:
        points[column] = adata.obs[column].values

    px.scatter(points, x=0, y=1,
               hover_name=label,
               color=label,
               hover_data=hover_cols).show()
    
def Mode(arr):
    """
    Return the most common value in `arr` (ties resolve to the smallest value).

    Wrapper around `scipy.stats.mode`. It works both on SciPy < 1.11, where the
    mode is returned as an array with the reduced axis kept, and on SciPy >= 1.11,
    where `keepdims=False` is the default and a 1-D input yields a scalar. The old
    `mode(arr)[0][0]` raised IndexError on the newer versions. For 2-D input, this
    returns the mode of the first column, as before.
    """
    return np.ravel(mode(arr)[0])[0]


# Figure 1

### 1A Random Matrices

In [None]:
# Random 3x8 matrix shown as a decorative heatmap (seeded for reproducibility).
np.random.seed(2)
y = np.log(np.random.rand(3, 8) ** 1.1 + 1)
# plt.subplots returns (figure, axes); draw on that axes explicitly.
fig, ax = plt.subplots(figsize=(8, 3))
sns.heatmap(y, cmap="Blues", cbar=False, ax=ax)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
fig.savefig("figures/1a_mat1.svg")

In [None]:
# Random 4x9 matrix shown as a decorative heatmap (seeded for reproducibility).
np.random.seed(4)
y = np.log(np.random.rand(4, 9) ** 1.5 + 1)
# plt.subplots returns (figure, axes); draw on that axes explicitly.
fig, ax = plt.subplots(figsize=(9, 4))
sns.heatmap(y, cmap="Oranges", cbar=False, ax=ax)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
fig.savefig("figures/1a_mat2.svg")

In [None]:
# Random 2x7 matrix shown as a decorative heatmap (seeded for reproducibility).
np.random.seed(14)
y = np.log(np.random.rand(2, 7) ** 1.75 +1)
# plt.subplots returns (figure, axes); draw on that axes explicitly.
fig, ax = plt.subplots(figsize=(7, 2))
sns.heatmap(y, cmap="Greens", cbar=False, ax=ax)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
fig.savefig("figures/1a_mat3.svg")

# Figure 2 Performance

# Frog / Zebrafish Embryogenesis

In [None]:
centroids_path = '/dfs/project/cross-species/yanay/data/centroids_seeds/metric_results/test256_data_frog_zebrafish_org_lasso_fz_run_l1_0.0_pe_1.0_ESM2_seed_5.h5ad'
samap_path = "/dfs/project/cross-species/yanay/data/final_embeds/samap_fz_noct.h5ad"
samap_labels_path = '/dfs/project/cross-species/yanay/data/final_embeds/samap_fz_ct.h5ad'
scvi_path = "/dfs/project/cross-species/yanay/data/final_embeds/scvi_fz.h5ad"
scanorama_path = "/dfs/project/cross-species/yanay/data/scanorama/fz_seeds/seed_0.h5ad"
harmony_path = "/dfs/project/cross-species/yanay/data/harmony/fz_seeds/seed_0.h5ad"

In [None]:
centroids_ad = sc.read(centroids_path)
samap_ad = sc.read(samap_path)
scvi_ad = sc.read(scvi_path)
scanorama_ad = sc.read(scanorama_path)
harmony_ad = sc.read(harmony_path)
samap_labels_ad = sc.read(samap_labels_path)

In [None]:
display(centroids_ad, samap_ad, scvi_ad, scanorama_ad, harmony_ad)

In [None]:
# Process the UMAPs for each dataset.
# We need to do UMAP since that is what SAMAP does only, so need to be fair

In [None]:
if "X_umap" not in centroids_ad.obsm:
    sc.pp.neighbors(centroids_ad, use_rep="X")
    sc.tl.umap(centroids_ad, random_state=0)
    centroids_ad.write(centroids_path)

In [None]:
if "X_umap" not in scvi_ad.obsm:
    sc.pp.neighbors(scvi_ad, use_rep="X")
    sc.tl.umap(scvi_ad, random_state=0)
    scvi_ad.write(scvi_path)    

In [None]:
if "X_umap" not in scanorama_ad.obsm:
    sc.pp.neighbors(scanorama_ad, use_rep="X")
    sc.tl.umap(scanorama_ad, random_state=0)
    scanorama_ad.write(scanorama_path)

In [None]:
if "X_umap" not in harmony_ad.obsm:
    sc.pp.neighbors(harmony_ad, use_rep="X")
    sc.tl.umap(harmony_ad, random_state=0)
    harmony_ad.write(harmony_path)

In [None]:
samap_ad.obsm["X_umap"] = samap_ad.X
samap_labels_ad.obsm["X_umap"] = samap_labels_ad.X

## 2.A Performance barchart

In [None]:
# fz paths
# Label-transfer score CSVs per model for frog/zebrafish. Each file is read for
# its "Logistic Regression" (accuracy) and "Label" (transfer direction) columns.
# The insertion order below is the left-to-right model order in the Figure 2A plots.
# Commented-out alternatives (same directory unless noted):
#   SATURN:               lasso_fz_scores_scores.csv, fz_30_seeds_scores.csv
#   "Our Model (ESM2)":   fz_esm2_scores.csv
#   "Our Model (ProtXL)": fz_protxl_scores.csv
#   "SAMap":              fz_samap_noct_scores_scores.csv
#   "Random":             fz_rand_scores.csv (relative path)
_fz_scores_dir = "/dfs/project/cross-species/yanay/data/scoring_csvs/"
fz_models_to_paths = dict([
    ("SATURN\n(Our Model)", _fz_scores_dir + "no_lasso_rank_esm2_scores.csv"),
    ("SAMap \n(Weakly supervised)", _fz_scores_dir + "fz_samap_scores_ct_scores.csv"),
    ("Harmony", _fz_scores_dir + "harmony_fz_scores.csv"),
    ("scVI", _fz_scores_dir + "fz_scvi_scores_scores.csv"),
    ("Scanorama", _fz_scores_dir + "scanorama_fz_scores.csv"),
])


### Figure 2A

In [None]:
width = 8 / 6 * len(fz_models_to_paths)
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(width, 3), frameon=False, dpi=300)
# NOTE(review): 0.93 is a hard-coded reference accuracy; its origin is not recorded here.
ax.axhline(y=0.93,  linewidth=2, color='r')

# Build one long table of all scores: x = model, y = logistic-regression
# transfer accuracy, hue = transfer direction (the CSV's "Label" column).
score_tables = []
for model, path in fz_models_to_paths.items():
    # Skip models without a scores file. (`path is not ""` tested object
    # identity rather than string content and warns on Python >= 3.8.)
    if not path:
        continue
    scores_csv = pd.read_csv(path)  # read once; both columns come from this file
    score_tables.append(pd.DataFrame({
        "x": model,
        "y": scores_csv["Logistic Regression"].to_numpy(),
        "hue": scores_csv["Label"].to_numpy(),
    }))
df = pd.concat(score_tables, ignore_index=True)
xs, ys, hues = df["x"].to_numpy(), df["y"].to_numpy(), df["hue"].to_numpy()

sns.boxplot(y=ys, x=xs, ax=ax, hue=hues);
ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=10);
ax.set_ylim(0, 1);
ax.tick_params(axis='both', which='minor', labelsize=7);
ax.set_title("Model Transfer Accuracy (Frog/Zebrafish)");


def _fz_direction_summary(direction):
    """
    Compute the per-model centre and spread of the scores for one transfer direction.

    Rows of `df` whose "hue" equals `direction.lower()` are grouped by model.
    Returns (median, sample standard deviation). Each is a DataFrame indexed by
    model with a single "y" column. Both directions use the median so the two
    panels are comparable.
    """
    direction_scores = df[df["hue"] == direction.lower()].drop(columns="hue")
    by_model = direction_scores.groupby(["x"])
    return by_model.agg("median"), by_model.std(ddof=1)


def _plot_fz_direction(M, SE, title, fname=None):
    """
    Draw the errorbar plot of per-model centre `M` ± spread `SE` for one direction.

    Models appear in `fz_models_to_paths` order. The y-axis is [0, 1], with a
    dashed reference line at 0.93. The figure is saved to `fname` when given,
    then shown.
    """
    fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(width, 3), frameon=False, dpi=300)
    ax.axhline(y=0.93,  linewidth=2, color='r', linestyle="--")
    ax.set_xticks(ticks=np.arange(len(M.index)), labels=M.index)
    ax.set_xticklabels(fz_models_to_paths.keys(), rotation=0, fontsize=10)
    ax.grid(False)
    for x in fz_models_to_paths.keys():
        ax.errorbar(x=x, y=M.loc[x].values.flatten(), yerr=SE.loc[x].values.flatten(),
                    fmt='o', ls="none", label=x)
    fig.legend().remove()
    ax.set_ylim(0, 1)
    ax.set_title("Model Label Transfer Accuracy " + f"({title})")
    if fname is not None:
        fig.savefig(fname)
    plt.show()


# M* = median score per model; SE* = sample standard deviation (not a standard error).
title = "Zebrafish to Frog"
M1, SE1 = _fz_direction_summary(title)
_plot_fz_direction(M1, SE1, title, fname="figures/2a.svg")

title = "Frog to Zebrafish"
M2, SE2 = _fz_direction_summary(title)
_plot_fz_direction(M2, SE2, title)

In [None]:
M1

In [None]:
M2

In [None]:
b = M1.loc["SAMap \n(Weakly supervised)"]
a = M1.loc["SATURN\n(Our Model)"]
print(f"{int((a-b)/b * 100)}% improvement")

## Figure 2B Stacked Bars

In [None]:
cen_al_f = alignment_score(centroids_path,  col="labels2", species="frog")
sam_al_f = alignment_score(samap_labels_path,  col="cell_type", species="frog")

In [None]:
# Most common cell types
labels_by_rank = pd.DataFrame(centroids_ad.obs["labels2"].value_counts()).reset_index()
labels_by_rank

In [None]:
def plot_cluster_knn_bar(df, source='human', other='mouse', ax=None, title=None):
    """
    Create a stacked bar plot showing the majority k nearest neighbours per cluster.

    Draws one horizontal stacked bar per source cluster. Each row of `df` adds a
    segment of length `Score` to its cluster's bar: green if `Correct` is truthy,
    white otherwise (alpha 0.5). Rows are ordered by ascending `labels2`, then
    correct rows first, then descending `Score`. Clusters get y positions in the
    order they are first reached in that ordering. Y tick labels are the cluster
    names, sorted alphabetically.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'Source_Cluster', 'Score', 'Correct' and 'labels2'.
    source : str
        Y-axis label.
    other : str
        Unused; kept for backward compatibility.
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a new (6, 17) figure is created.
    title : str, optional
        Axes title; left unchanged when None.

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(6, 17), frameon=False)

    bars = defaultdict(int)  # running length of each cluster's bar
    tick_pos = {}            # cluster -> y position
    df = df.sort_values(['labels2', 'Correct', 'Score'], ascending=[True, False, False])

    for cluster, score, correct in df[['Source_Cluster', 'Score', 'Correct']].itertuples(index=False):
        # A cluster gets the next free row the first time it is seen. (A fixed
        # lookup list previously raised IndexError for clusters with >101 rows.)
        if cluster not in tick_pos:
            tick_pos[cluster] = len(tick_pos)
        left = bars[cluster]
        bars[cluster] = left + score
        ax.barh(tick_pos[cluster], score, left=left,
                color='green' if correct else 'white', alpha=0.5, edgecolor="none")

    keys = sorted(tick_pos)
    ax.set_yticks([tick_pos[k] for k in keys], keys)
    ax.set_ylabel(source)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim(0, 1)
    return ax

In [None]:
ct_map = pd.read_csv("/dfs/project/cross-species/yanay/fz_true_ct.csv").set_index("Unnamed: 0").reset_index(drop=True)
ct_map.head()

In [None]:
# Most common cell types
labels_by_rank = pd.DataFrame(centroids_ad.obs["labels2"].value_counts()).reset_index()
labels_by_rank_zf = pd.DataFrame(centroids_ad[centroids_ad.obs["species"] == "zebrafish"].obs["labels2"].value_counts()).reset_index()
labels_by_rank = labels_by_rank.merge(labels_by_rank_zf.drop(columns="labels2"), left_on="index", right_on="index", how="inner")

In [None]:
labels_by_rank

In [None]:
# NOTE(review): this cell previously restored the tables from `c1`..`c4`, which
# are never defined in this notebook (hidden kernel state), so it failed on a
# fresh kernel. `cen_al_f` / `sam_al_f` already come from the alignment-score
# cell earlier in this section. The zebrafish-source tables are computed here
# with the same call. TODO confirm this matches what c3/c4 held.
cen_al_z = alignment_score(centroids_path, col="labels2", species="zebrafish")
sam_al_z = alignment_score(samap_labels_path, col="cell_type", species="zebrafish")


In [None]:
width = 9
fig, (row1, row2) = plt.subplots(2, 2, sharex=True, sharey=False, figsize=(width, 5), frameon=False)
# Top row: most common cell types; bottom row: least common.
# Left column: SATURN; right column: SAMap.
ax1, ax2 = row1
ax3, ax4 = row2

# For these least-common frog clusters, a KNN label equal to a listed name also
# counts as correct, in addition to the mapped zebrafish cell type.
least_common_modifiers = {
                          "Olfactory placode":["Olfactory placode"],
                          "Germline":["Germline"],
                          "Hatching gland":["Hatching gland"]
                         }


def _with_correctness(alignment_df, modifiers=None):
    """
    Join an alignment table with the ground truth and mark each row as correct.

    `alignment_df` is merged with `ct_map` on Source_Cluster == frog_cell_type.
    A row is correct when its Cross_Species_KNN_Label equals the mapped
    zebrafish_cell_type or, when `modifiers` is given, is listed in
    `modifiers[Source_Cluster]`. The result is then joined with
    `labels_by_rank` (frog_cell_type == index), which adds the cell-type
    count as "labels2".
    """
    merged = alignment_df.merge(ct_map, left_on="Source_Cluster", right_on="frog_cell_type")
    predicted = merged["Cross_Species_KNN_Label"]
    correct = predicted == merged["zebrafish_cell_type"]
    if modifiers:
        also_accepted = [
            pred in modifiers.get(cluster, [])
            for pred, cluster in zip(predicted, merged["Source_Cluster"])
        ]
        correct = correct | pd.Series(also_accepted, index=merged.index, dtype=bool)
    merged["Correct"] = correct
    return merged.merge(labels_by_rank, left_on="frog_cell_type", right_on="index")


# Most common: cell types with more than 10,000 cells.
sam_most = _with_correctness(sam_al_f)
cen_most = _with_correctness(cen_al_f)
cen_most = cen_most[cen_most["labels2"] > 10000].sort_values(["labels2", "Source_Cluster"], ascending=False)
sam_most = sam_most[sam_most["labels2"] > 10000].sort_values(["labels2", "Source_Cluster"], ascending=False)

plot_cluster_knn_bar(sam_most, source="frog", ax=ax2, title="SAMap (Weakly supervised)")
plot_cluster_knn_bar(cen_most, source="frog", ax=ax1, title="SATURN (Our Model)")

ax1.set(ylabel=None)
ax1.grid(False)
ax2.grid(False)
ax2.get_yaxis().set_visible(False)
ax2.get_xaxis().set_visible(False)
ax1.get_xaxis().set_visible(False)

# Least common: cell types with fewer than 500 cells.
sam_least = _with_correctness(sam_al_f, least_common_modifiers)
cen_least = _with_correctness(cen_al_f, least_common_modifiers)
cen_least = cen_least[cen_least["labels2"] < 500].sort_values(["labels2", "Source_Cluster"], ascending=False)
sam_least = sam_least[sam_least["labels2"] < 500].sort_values(["labels2", "Source_Cluster"], ascending=False)

plot_cluster_knn_bar(sam_least, source="frog", ax=ax4, title="")
plot_cluster_knn_bar(cen_least, source="frog", ax=ax3, title="")

ax3.set(ylabel=None)
ax3.grid(False)
ax4.grid(False)
ax4.get_yaxis().set_visible(False);

ax3.set(title=None);
ax4.set(title=None);


fig.tight_layout()

plt.savefig("figures/2b.svg")

### Check Germline Cells

In [None]:
cen_al_f[cen_al_f["Source_Cluster"] == "Germline"]

In [None]:
sam_al_f[sam_al_f["Source_Cluster"] == "Germline"]

## Figure 2D

In [None]:
sc._settings.settings._vector_friendly=True

In [None]:
fig, (axB, axA) = plt.subplots(2, 5, sharex=False, sharey=False, figsize=(13.3, 4), frameon=False, dpi=300)

J = 0
# Our method

# SATURN
ax1 = axA[J]
fig.tight_layout()
sc.pl.umap(centroids_ad, color="species", title="", ax=ax1, show=False)
ax1.get_legend().remove();
ax1.set(xlabel=None, ylabel=None);
ax1.title.set_size(18);

ax2 = axB[J]
sc.pl.umap(centroids_ad, color="labels2", title="SATURN", ax=ax2, show=False)
ax2.get_legend().remove();
ax2.set(xlabel=None, ylabel=None);
ax2.title.set_size(18);
fig.tight_layout()

J += 1


# SAMAP Label Share
ax1 = axA[J]
fig.tight_layout()
sc.pl.umap(samap_labels_ad, color="species", title="", ax=ax1, show=False)
ax1.get_legend().remove();
ax1.set(xlabel=None, ylabel=None);
ax1.title.set_size(18);

ax2 = axB[J]
sc.pl.umap(samap_labels_ad, color="cell_type", title="SAMap\n(Weakly supervised)", ax=ax2, show=False)
ax2.get_legend().remove();
ax2.set(xlabel=None, ylabel=None);
ax2.title.set_size(18);
fig.tight_layout()

J += 1

# Harmony
ax1 = axA[J]
fig.tight_layout()
sc.pl.umap(harmony_ad, color="species", title="", ax=ax1, show=False)
ax1.get_legend().remove();
ax1.set(xlabel=None, ylabel=None);
ax1.title.set_size(18);

ax2 = axB[J]
sc.pl.umap(harmony_ad, color="cell_type", title="Harmony", ax=ax2, show=False)
ax2.get_legend().remove();
ax2.set(xlabel=None, ylabel=None);
ax2.title.set_size(18);
fig.tight_layout()

J += 1

# SCVI
ax1 = axA[J]
fig.tight_layout()
sc.pl.umap(scvi_ad, color="species", title="", ax=ax1, show=False)
ax1.get_legend().remove();
ax1.set(xlabel=None, ylabel=None);
ax1.title.set_size(18);

ax2 = axB[J]
sc.pl.umap(scvi_ad, color="cell_type", title="scVI", ax=ax2, show=False)
ax2.get_legend().remove();
ax2.set(xlabel=None, ylabel=None);
ax2.title.set_size(18);
fig.tight_layout()

J += 1


# Scanorama
ax1 = axA[J]
fig.tight_layout()
sc.pl.umap(scanorama_ad, color="species", title="", ax=ax1, show=False)
ax1.get_legend().remove();
ax1.set(xlabel=None, ylabel=None);
ax1.title.set_size(18);

ax2 = axB[J]
sc.pl.umap(scanorama_ad, color="cell_type", title="Scanorama", ax=ax2, show=False)
ax2.get_legend().remove();
ax2.set(xlabel=None, ylabel=None);
ax2.title.set_size(18);
fig.tight_layout()

J += 1


plt.savefig("figures/2d.svg")

In [None]:
samap_ad

In [None]:
plotly_scatter(samap_labels_ad, label="cell_type", hover_cols=["cell_type", "species"])

In [None]:
plotly_scatter(centroids_ad, label="labels2", hover_cols=["labels2", "species"])

In [None]:
(centroids_ad.obs["labels"].str.contains("frog_Germline")).sum()

In [None]:
(centroids_ad.obs["labels"].str.contains("zebrafish_Germline")).sum()

# Figure 2C: Frog/Zebrafish UMAP and Insets

In [None]:
sc.set_figure_params(figsize=(4, 4))

# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

In [None]:
centroids_ad

In [None]:
top10 = list(centroids_ad.obs["labels2"].value_counts()[:10].index)

labels_hidden_dict = {a:"_" +a for  a in centroids_ad.obs["labels2"].unique()}
for l in top10:
    labels_hidden_dict[l] = l

mapped_labels = [labels_hidden_dict[l] for l in centroids_ad.obs["labels2"]]
centroids_ad.obs["hidden_labels"] = mapped_labels

spaced_labels = [l.replace(" ", "\n").title() for l in centroids_ad.obs["labels2"]]
centroids_ad.obs["spaced_labels"] = spaced_labels

In [None]:
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(4, 5), frameon=False, dpi=300)

# Cluster names left visible on the map, converted to the same form as the
# "spaced_labels" column (one word per line, title-cased). "" is kept in the list as before.
_shown_names = ['Blastula', '', 'Neuroectoderm', 'Non-neural ectoderm', 'Epidermal progenitor',
                'Tailbud', 'Intermediate mesoderm', 'Eye primordium', 'Placodal area',
                'Hindbrain', 'Neural crest', 'Neuron', 'Forebrain/midbrain',
                'Skeletal muscle', 'Blood', 'Optic', 'Pluripotent',
                'Macrophage', 'Spemann Organizer', 'Goblet cell', 'Presomitic mesoderm']
to_show = [name.replace(' ', '\n').title() for name in _shown_names]
# Small clusters are replaced by numbers matching the 2c_1 / 2c_2 / 2c_3 insets.
_callouts = {'Spemann\nOrganizer': "1", 'Goblet\nCell': "2", 'Macrophage': "3"}

ax = sc.pl.umap(centroids_ad, color="spaced_labels", show=False, alpha=1, legend_loc='on data', ax=ax, legend_fontsize=7, legend_fontoutline=2)
ax.set(xlabel=None, ylabel=None, title=None)
for label_text in ax.texts:
    current = label_text.get_text()
    if current not in to_show:
        label_text.set_visible(False)
    if current in _callouts:
        label_text.set_text(_callouts[current])
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)
plt.savefig("figures/2c_main.svg")    
plt.show()


In [None]:
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(4, 5), frameon=False, dpi=300)

ax = sc.pl.umap(centroids_ad, color="species", show=False, alpha=1, ax=ax, legend_loc=None)
ax.set(xlabel=None, ylabel=None);
ax.set(title=None)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
plt.savefig("figures/2c_spec.svg")
plt.show()
xlims = ax.get_xlim()
ylims = ax.get_ylim()

In [None]:
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(4, 4), frameon=False, dpi=300)

ax = sc.pl.umap(centroids_ad, color="species", show=False, alpha=1, ax=ax, legend_loc=None, groups="zebrafish")
ax.set(xlabel=None, ylabel=None);
ax.set(title=None)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
plt.savefig("figures/2c_zebrafish.svg")
plt.show()
xlims = ax.get_xlim()
ylims = ax.get_ylim()

In [None]:
mod = 1.4

In [None]:
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(4/mod, 4/mod), frameon=False, dpi=300)

zoom = 6
x_offset = -10
y_offset = 9
xlim = (np.array(xlims) - min(xlims))/zoom + x_offset
ylim = (np.array(ylims) - min(ylims))/zoom + y_offset


umap_limited = centroids_ad[(centroids_ad.obsm["X_umap"][:, 0] > xlim[0]) & (centroids_ad.obsm["X_umap"][:, 0] < xlim[1]) \
& (centroids_ad.obsm["X_umap"][:, 1] > ylim[0]) & (centroids_ad.obsm["X_umap"][:, 1] > ylim[0])]
ax = sc.pl.umap(umap_limited, color="labels2", show=False, title="Inset 1", alpha=1, ax=ax, legend_fontsize=6, legend_fontoutline=2, groups=["Macrophage", "Myeloid progenitors", "Blood"])

for a in ax.get_legend().texts:
    a._text = a._text.title()
    if a._text == "Na":
        a._text = "Other"
plt.xlim(*xlim);
plt.ylim(*ylim);
ax.set(title=None, xlabel=None, ylabel=None);
plt.savefig("figures/2c_3.svg")

In [None]:
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(4/mod, 4/mod), frameon=False, dpi=300)

zoom = 12
x_offset = 5.25
y_offset = 0
xlim = (np.array(xlims) - min(xlims))/zoom + x_offset
ylim = (np.array(ylims) - min(ylims))/zoom + y_offset


umap_limited = centroids_ad[(centroids_ad.obsm["X_umap"][:, 0] > xlim[0]) & (centroids_ad.obsm["X_umap"][:, 0] < xlim[1]) \
& (centroids_ad.obsm["X_umap"][:, 1] > ylim[0]) & (centroids_ad.obsm["X_umap"][:, 1] > ylim[0])]
ax = sc.pl.umap(umap_limited, color="labels2", show=False, title="Inset 1", alpha=1, ax=ax, legend_fontsize=6, legend_fontoutline=2, groups=["Ionocyte", "Goblet cell",
                                                                                                                                            "Periderm", "Rare epidermal subtypes",
                                                                                                                                            "Epidermal progenitor"])


for a in ax.get_legend().texts:
    a._text = a._text.title()
    if a._text == "Na":
        a._text = "Other"
plt.xlim(*xlim);
plt.ylim(*ylim);
ax.set(title=None, xlabel=None, ylabel=None);
plt.savefig("figures/2c_2.svg")

#plt.show()
#fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(6, 4), frameon=False, dpi=300)



#ax = sc.pl.umap(umap_limited, color="species", show=False, title="Inset 1", alpha=1, legend_loc='on data', ax=ax, legend_fontsize=6, legend_fontoutline=2)

In [None]:
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(4/mod, 4/mod), frameon=False, dpi=300)

zoom = 5
x_offset = 8.25
y_offset = 5.5
xlim = (np.array(xlims) - min(xlims))/zoom + x_offset
ylim = (np.array(ylims) - min(ylims))/zoom + y_offset


umap_limited = centroids_ad[(centroids_ad.obsm["X_umap"][:, 0] > xlim[0]) & (centroids_ad.obsm["X_umap"][:, 0] < xlim[1]) \
& (centroids_ad.obsm["X_umap"][:, 1] > ylim[0]) & (centroids_ad.obsm["X_umap"][:, 1] > ylim[0])]
ax = sc.pl.umap(umap_limited, color="labels2", show=False, title="Inset 1", alpha=1, ax=ax, legend_fontsize=6, legend_fontoutline=2, groups=["Spemann organizer", "Involuting marginal zone",
                                                                                                                                            "Dorsal organizer", 
                                                                                                                                            "Endoderm"])


for a in ax.get_legend().texts:
    a._text = a._text.title()
    if a._text == "Na":
        a._text = "Other"
plt.xlim(*xlim);
plt.ylim(*ylim);
ax.set(title=None, xlabel=None, ylabel=None);
plt.savefig("figures/2c_1.svg")
#plt.show()
#fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(6, 4), frameon=False, dpi=300)



#ax = sc.pl.umap(umap_limited, color="species", show=False, title="Inset 1", alpha=1, legend_loc='on data', ax=ax, legend_fontsize=6, legend_fontoutline=2)

In [None]:
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(4/mod, 4/mod), frameon=False, dpi=300)

#zoom = 5
#x_offset = 8.25
#y_offset = 5.5
#xlim = (np.array(xlims) - min(xlims))/zoom + x_offset
#ylim = (np.array(ylims) - min(ylims))/zoom + y_offset


#map_limited = centroids_ad[(centroids_ad.obsm["X_umap"][:, 0] > xlim[0]) & (centroids_ad.obsm["X_umap"][:, 0] < xlim[1]) \
#& (centroids_ad.obsm["X_umap"][:, 1] > ylim[0]) & (centroids_ad.obsm["X_umap"][:, 1] > ylim[0])]
ax = sc.pl.umap(centroids_ad, color="species", show=False, title="Inset 4", alpha=1, ax=ax, legend_fontsize=6, legend_fontoutline=2)
#plt.ylim(*ylim);
ax.set(title=None, xlabel=None, ylabel=None);
plt.savefig("figures/2c_4.svg")
#plt.show()
#fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(6, 4), frameon=False, dpi=300)



#ax = sc.pl.umap(umap_limited, color="species", show=False, title="Inset 1", alpha=1, legend_loc='on data', ax=ax, legend_fontsize=6, legend_fontoutline=2)

# Biohub Poster Figure 2 Excerpts 

In [None]:
width = 15 # inches
height = 4

# Poster version of Figure 2D: one column per method. The top row (axB) is coloured
# by cell type and titled with the method name; the bottom row (axA) is coloured by species.
# An unsupervised SAMap column (samap_ad, "cell_type", "SAMap") was previously
# kept here commented out; add it to the list to include it.
fig_2d_poster_panels = [
    (centroids_ad, "labels2", "SATURN"),
    (samap_labels_ad, "cell_type", "SAMap\n(Weakly supervised)"),
    (harmony_ad, "cell_type", "Harmony"),
    (scvi_ad, "cell_type", "scVI"),
    (scanorama_ad, "cell_type", "Scanorama"),
]

fig, (axB, axA) = plt.subplots(2, len(fig_2d_poster_panels), sharex=False, sharey=False, figsize=(width, height), frameon=False, dpi=300)

for J, (panel_adata, label_col, method_title) in enumerate(fig_2d_poster_panels):
    # Species panel first, then the cell-type panel.
    for panel_ax, color_key, panel_title in ((axA[J], "species", ""),
                                             (axB[J], label_col, method_title)):
        sc.pl.umap(panel_adata, color=color_key, title=panel_title, ax=panel_ax, show=False)
        panel_ax.get_legend().remove()
        panel_ax.set(xlabel=None, ylabel=None)
        panel_ax.title.set_size(18)

# Lay out once, after every panel has been drawn.
fig.tight_layout()

plt.savefig("figures/2d_poster.png")

In [None]:
width = 7.5
height = 7.5
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(width, height), frameon=False, dpi=300)

# Cluster names left visible on the map, in "spaced_labels" form. "" is kept as before.
_poster_shown_names = ['Blastula', '', 'Neuroectoderm', 'Non-neural ectoderm', 'Epidermal progenitor',
                       'Tailbud', 'Intermediate mesoderm', 'Eye primordium', 'Placodal area',
                       'Hindbrain', 'Neural crest', 'Neuron', 'Forebrain/midbrain',
                       'Skeletal muscle', 'Blood', 'Optic', 'Pluripotent',
                       'Macrophage', 'Spemann Organizer', 'Goblet cell', 'Presomitic mesoderm']
to_show = [name.replace(' ', '\n').title() for name in _poster_shown_names]
# These three cluster labels are blanked out on the poster version.
_poster_blanked = {'Spemann\nOrganizer', 'Goblet\nCell', 'Macrophage'}

ax = sc.pl.umap(centroids_ad, color="spaced_labels", show=False, alpha=1, legend_loc='on data', ax=ax, legend_fontsize=7, legend_fontoutline=2)
ax.set(xlabel=None, ylabel=None, title=None)
for label_text in ax.texts:
    current = label_text.get_text()
    if current not in to_show:
        label_text.set_visible(False)
    if current in _poster_blanked:
        label_text.set_text("")
    label_text.set_fontsize(12)
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)
plt.savefig("figures/2c_main_poster.svg")    
plt.show()


In [None]:
width = 7.5 / 1.75
height = 4 / 1.75
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(width, height), frameon=False, dpi=300)
# NOTE(review): 0.93 is a hard-coded reference accuracy; its origin is not recorded here.
ax.axhline(y=0.93,  linewidth=2, color='r')

# One long table of all scores: x = model, y = logistic-regression accuracy,
# hue = transfer direction (the CSV's "Label" column).
score_tables = []
for model, path in fz_models_to_paths.items():
    # Skip models without a scores file (`path is not ""` compared identity, not content).
    if not path:
        continue
    scores_csv = pd.read_csv(path)  # read once for both columns
    score_tables.append(pd.DataFrame({
        "x": model,
        "y": scores_csv["Logistic Regression"].to_numpy(),
        "hue": scores_csv["Label"].to_numpy(),
    }))
df = pd.concat(score_tables, ignore_index=True)
xs, ys, hues = df["x"].to_numpy(), df["y"].to_numpy(), df["hue"].to_numpy()

sns.boxplot(y=ys, x=xs, ax=ax, hue=hues);
ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=10);
ax.set_ylim(0, 1);
ax.tick_params(axis='both', which='minor', labelsize=7);
ax.set_title("Model Transfer Accuracy (Frog/Zebrafish)");


def _fz_direction_summary_poster(direction):
    """
    Compute the per-model (median, sample std) of scores for one transfer direction.

    Filters `df` to rows whose "hue" equals `direction.lower()`. Both returned
    DataFrames are indexed by model and have a single "y" column. Both
    directions use the median so the two panels are comparable.
    """
    direction_scores = df[df["hue"] == direction.lower()].drop(columns="hue")
    by_model = direction_scores.groupby(["x"])
    return by_model.agg("median"), by_model.std(ddof=1)


def _plot_fz_direction_poster(M, SE, fname=None, **errorbar_style):
    """
    Draw the poster-sized errorbar plot of per-model centre `M` ± spread `SE`.

    Tick and series labels use only the first line of each model name.
    `errorbar_style` (e.g. lw, ms) is forwarded to `Axes.errorbar`. The
    figure is saved to `fname` when given, then shown.
    """
    fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(width, height), frameon=False, dpi=300)
    ax.axhline(y=0.93,  linewidth=2, color='r', linestyle="--")
    ax.set_xticks(ticks=np.arange(len(M.index)), labels=M.index)
    ax.set_xticklabels([x.split("\n")[0] for x in fz_models_to_paths.keys()], rotation=0, fontsize=10)
    ax.grid(False)
    for x in fz_models_to_paths.keys():
        ax.errorbar(x=x, y=M.loc[x].values.flatten(), yerr=SE.loc[x].values.flatten(),
                    fmt='o', ls="none", label=x.split("\n")[0], **errorbar_style)
    fig.legend().remove()
    ax.set_ylim(0, 1)
    if fname is not None:
        fig.savefig(fname)
    plt.show()


# M* = median score per model; SE* = sample standard deviation (not a standard error).
title = "Zebrafish to Frog"
M1, SE1 = _fz_direction_summary_poster(title)
_plot_fz_direction_poster(M1, SE1, fname="figures/2a_poster.png", lw=4, ms=8)

title = "Frog to Zebrafish"
M2, SE2 = _fz_direction_summary_poster(title)
_plot_fz_direction_poster(M2, SE2, lw=3)

# Figure 3

## Figure 3B

Example subset 1: macrophages and myeloid progenitors

In [None]:
subset_1_types = ["Macrophage", "Myeloid progenitors"]

In [None]:
subset_1_ad = centroids_ad[centroids_ad.obs["labels2"].isin(subset_1_types)]
sc.pl.umap(subset_1_ad, color="labels2")
sc.pl.umap(subset_1_ad, color="species")

In [None]:
ancs_ad = sc.AnnData(centroids_ad.obsm["ancs"])
ancs_ad.obs = centroids_ad.obs
ancs_ad.obs["subset"] = centroids_ad.obs["labels2"].isin(subset_1_types).astype(str)

ancs_ad.obsm["X_umap"] = centroids_ad.obsm["X_umap"]
ancs_ad.X = np.log(ancs_ad.X+1)

In [None]:
sc.tl.dendrogram(centroids_ad,groupby='labels2', use_rep='X')
ancs_ad.uns["dendrogram_labels2"] = centroids_ad.uns["dendrogram_labels2"]

In [None]:
with open(centroids_path.replace(".h5ad", "_genes_to_centroids.pkl"),'rb') as f:
    genes_to_scores = pickle.load(f)

In [None]:
sc.tl.rank_genes_groups(ancs_ad, 'subset', groups=["True"], method='wilcoxon')

In [None]:
best_gene_cluster_1 = sc.get.rank_genes_groups_df(ancs_ad, group="True")["names"][0]

In [None]:
def get_scores(anc_idx, gene_scores=None):
    '''
    Given the index of a centroid, return the scores by gene for that centroid.

    Parameters
    ----------
    anc_idx : int
        Position to read from each gene's weights (callers pass a macrogene id).
    gene_scores : mapping, optional
        gene -> indexable weights. Defaults to the notebook-level
        ``genes_to_scores`` (original behaviour); pass it explicitly to avoid
        depending on global state.

    Returns
    -------
    dict
        gene -> weight at ``anc_idx``, in the mapping's iteration order.
    '''
    if gene_scores is None:
        gene_scores = genes_to_scores
    return {gene: embs[anc_idx] for gene, embs in gene_scores.items()}

In [None]:
# Show the top-ranked macrogene for subset 1.
best_gene_cluster_1

In [None]:
# Full ranking for subset 1, split into enriched (score > 0, in ranked order)
# and depleted (score < 0, most negative first) macrogenes.
de_df = sc.get.rank_genes_groups_df(ancs_ad, group=["True"])
is_up = de_df["scores"] > 0
is_down = de_df["scores"] < 0
de_df_pos = de_df[is_up]
de_df_neg = de_df[is_down].sort_values("scores")

In [None]:
# Top 10 macrogenes enriched in subset 1.
de_df_pos.head(10)

In [None]:
# For each of the 6 most enriched macrogenes, list its 20 highest-weighted genes.
for de in de_df_pos.head(6)["names"]:
    weight_table = pd.DataFrame(list(get_scores(int(de)).items()), columns=["gene", "weight"])
    max_gene = weight_table.sort_values("weight", ascending=False)
    print(de)
    display(max_gene.head(20))

In [None]:
# Cell types shown in the Figure 3B dot plot: subset 1 plus contrasting types.
subset_1_groups = ["Macrophage", "Myeloid progenitors", "Blood", "Periderm",
                   "Endothelial", "Pluripotent"]

# .copy() materialises the AnnData views so the .uns write below does not act on a view.
subset_1_ad = ancs_ad[ancs_ad.obs["labels2"].isin(subset_1_groups)].copy()
subset_1_ad_embs = centroids_ad[centroids_ad.obs["labels2"].isin(subset_1_groups)].copy()

# Order groups by a dendrogram computed on the centroid embeddings.
sc.tl.dendrogram(subset_1_ad_embs, groupby='labels2', use_rep='X')
subset_1_ad.uns["dendrogram_labels2"] = subset_1_ad_embs.uns["dendrogram_labels2"]

# Top-6 enriched macrogenes, excluding macrogene "1302" (reason not recorded here).
markers = [str(a) for a in de_df_pos.head(6)["names"] if a != "1302"]
# Hand-written short gene names for the markers above, in the same order.
marker_labels = ["Arhgdi", "Cebp", "Ptp", "Cybb", "Lcp"]
assert len(markers) == len(marker_labels), (
    f"expected {len(marker_labels)} markers to match the tick labels, got {markers}"
)

fig, ax = plt.subplots(figsize=(6,3), dpi=300)
dp = sc.pl.dotplot(subset_1_ad, markers, groupby='labels2',
                   dendrogram=True, show=False, cmap="viridis",
                   ax=ax, vmin=0.0, vmax=0.5, dot_min=0, dot_max=1,
                  )

dp["mainplot_ax"].set_xticklabels(marker_labels, rotation = 0, ha="center", fontsize=8)
dp["mainplot_ax"].set_yticklabels(dp["mainplot_ax"].get_yticklabels(), fontsize=8)
plt.savefig("figures/3b_dot.svg")

## Figure 3C
### Subset 2: ionocytes

In [None]:
# Cell types (labels2) highlighted as example subset 2.
subset_2_types = ["Ionocyte"]

In [None]:
# Restrict the centroids to subset 2 and show them coloured by cell type, then by species.
ionocyte_mask = centroids_ad.obs["labels2"].isin(subset_2_types)
subset_2_ad = centroids_ad[ionocyte_mask]
for colour_key in ("labels2", "species"):
    sc.pl.umap(subset_2_ad, color=colour_key)

In [None]:
# Inspect the subset-2 centroid AnnData.
subset_2_ad

In [None]:
# Rebuild the macrogene AnnData, now flagging subset 2 in obs["subset"].
ancs_ad = sc.AnnData(centroids_ad.obsm["ancs"])
# Copy obs: assigning the same DataFrame object would make the "subset"
# column added below appear in centroids_ad.obs as well.
ancs_ad.obs = centroids_ad.obs.copy()
# "True"/"False" strings so the flag can be used as a rank_genes_groups group.
ancs_ad.obs["subset"] = centroids_ad.obs["labels2"].isin(subset_2_types).astype(str)

ancs_ad.obsm["X_umap"] = centroids_ad.obsm["X_umap"]
# log(1 + x) transform of the macrogene scores.
ancs_ad.X = np.log(ancs_ad.X+1)

In [None]:
# Recompute the labels2 dendrogram on the centroid embeddings and attach it
# to the rebuilt ancs_ad.
sc.tl.dendrogram(centroids_ad, groupby='labels2', use_rep='X')
ancs_ad.uns["dendrogram_labels2"] = centroids_ad.uns["dendrogram_labels2"]

In [None]:
# Reload the gene -> weights mapping (same file as the Figure 3B section).
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open(centroids_path.replace(".h5ad", "_genes_to_centroids.pkl"),'rb') as f:
    genes_to_scores = pickle.load(f)

In [None]:
# Wilcoxon ranking of macrogenes for subset 2 (subset == "True").
sc.tl.rank_genes_groups(ancs_ad, 'subset', groups=["True"], method='wilcoxon')

In [None]:
# Name of the top-ranked macrogene for subset 2.
best_gene_cluster_2 = sc.get.rank_genes_groups_df(ancs_ad, group="True")["names"][0]

In [None]:
def get_scores(anc_idx, gene_scores=None):
    '''
    Given the index of a centroid, return the scores by gene for that centroid.

    Parameters
    ----------
    anc_idx : int
        Position to read from each gene's weights (callers pass a macrogene id).
    gene_scores : mapping, optional
        gene -> indexable weights. Defaults to the notebook-level
        ``genes_to_scores`` (original behaviour).

    Returns
    -------
    dict
        gene -> weight at ``anc_idx``, in the mapping's iteration order.
    '''
    if gene_scores is None:
        gene_scores = genes_to_scores
    return {gene: embs[anc_idx] for gene, embs in gene_scores.items()}

In [None]:
# Show the top-ranked macrogene for subset 2.
best_gene_cluster_2

In [None]:
# Ranking for subset 2, split into enriched and (most-negative-first) depleted macrogenes.
de_df = sc.get.rank_genes_groups_df(ancs_ad, group=["True"])
scores_col = de_df["scores"]
de_df_pos = de_df.loc[scores_col > 0]
de_df_neg = de_df.loc[scores_col < 0].sort_values("scores")

In [None]:
# Top 10 macrogenes enriched in subset 2.
de_df_pos.head(10)

In [None]:
# For each of the 6 most enriched macrogenes, list its 10 highest-weighted genes.
for de in de_df_pos.head(6)["names"]:
    print(de)
    gene_weights = pd.DataFrame(list(get_scores(int(de)).items()), columns=["gene", "weight"])
    max_gene = gene_weights.sort_values("weight", ascending=False)
    display(max_gene.head(10))

In [None]:
# Cell types shown in the Figure 3C dot plot: ionocytes plus related types.
ionocyte_groups = ["Ionocyte", "Small secretory cells", "Rare epidermal subtypes", "Goblet cell",
                   "Endothelial", "Secretory epidermal"]

# .copy() materialises the AnnData views so the .uns write below does not act on a view.
subset_2_ad = ancs_ad[ancs_ad.obs["labels2"].isin(ionocyte_groups)].copy()
subset_2_ad_embs = centroids_ad[centroids_ad.obs["labels2"].isin(ionocyte_groups)].copy()

# Order groups by a dendrogram computed on the centroid embeddings.
sc.tl.dendrogram(subset_2_ad_embs, groupby='labels2', use_rep='X')
subset_2_ad.uns["dendrogram_labels2"] = subset_2_ad_embs.uns["dendrogram_labels2"]

# Top-6 enriched macrogenes minus "1302".
# NOTE(review): the "1302" exclusion mirrors the subset-1 cell; confirm it is needed here.
markers = [str(a) for a in de_df_pos.head(6)["names"] if a != "1302"]
# Hand-written short gene names for the markers above, in the same order.
marker_labels = ["Foxi", "Dmr2", "Cldn", "Ubp", "Atp6v0"]
assert len(markers) == len(marker_labels), (
    f"expected {len(marker_labels)} markers to match the tick labels, got {markers}"
)

fig, ax = plt.subplots(figsize=(6,3), dpi=300)
dp = sc.pl.dotplot(subset_2_ad, markers, groupby='labels2', dendrogram=True, show=False, cmap="viridis", ax=ax, vmin=0.0, vmax=0.5, dot_min=0, dot_max=1)

dp["mainplot_ax"].set_xticklabels(marker_labels, rotation = 0, ha="center", fontsize=8)
dp["mainplot_ax"].set_yticklabels(dp["mainplot_ax"].get_yticklabels(), fontsize=8)
plt.savefig("figures/3c_dot.svg")

## Figure 3D

### Now compare these cell types between species

In [None]:
# Re-embed only the subset-2 centroids so their own structure (e.g. by species) is visible.
# .copy(): PCA/neighbors/UMAP write results into the object, which must not be a view.
subset_2_ad = centroids_ad[centroids_ad.obs["labels2"].isin(subset_2_types)].copy()
sc.pp.pca(subset_2_ad)
sc.pp.neighbors(subset_2_ad)
sc.tl.umap(subset_2_ad)
sc.pl.umap(subset_2_ad, color="labels2")
sc.pl.umap(subset_2_ad, color="species")

In [None]:
# Macrogene scores for the same subset-2 centroids, laid out on the subset-specific UMAP.
# .copy() so assigning obsm does not modify a view.
subset_2_ancs_ad = ancs_ad[ancs_ad.obs["labels2"].isin(subset_2_types)].copy()
subset_2_ancs_ad.obsm["X_umap"] = subset_2_ad.obsm["X_umap"]

In [None]:
# Wilcoxon ranking of macrogenes by species within the subset-2 centroids.
sc.tl.rank_genes_groups(subset_2_ancs_ad, 'species', method='wilcoxon')

In [None]:
# Name of the top-ranked macrogene for the zebrafish group.
best_gene_cluster = sc.get.rank_genes_groups_df(subset_2_ancs_ad, group="zebrafish")["names"][0]

In [None]:
def get_scores(anc_idx, gene_scores=None):
    '''
    Given the index of a centroid, return the scores by gene for that centroid.

    Parameters
    ----------
    anc_idx : int
        Position to read from each gene's weights (callers pass a macrogene id).
    gene_scores : mapping, optional
        gene -> indexable weights. Defaults to the notebook-level
        ``genes_to_scores`` (original behaviour).

    Returns
    -------
    dict
        gene -> weight at ``anc_idx``, in the mapping's iteration order.
    '''
    if gene_scores is None:
        gene_scores = genes_to_scores
    return {gene: embs[anc_idx] for gene, embs in gene_scores.items()}

In [None]:
# Zebrafish-group ranking, split into enriched and (most-negative-first) depleted macrogenes.
de_df = sc.get.rank_genes_groups_df(subset_2_ancs_ad, group=["zebrafish"])
de_df_pos, de_df_neg = (
    de_df[de_df["scores"] > 0],
    de_df[de_df["scores"] < 0].sort_values("scores"),
)

In [None]:
# For the 5 macrogenes most enriched in zebrafish, list their 7 highest-weighted genes.
for de in de_df_pos.head(5)["names"]:
    weights = get_scores(int(de))
    unsorted = pd.DataFrame(list(weights.items()), columns=["gene", "weight"])
    max_gene = unsorted.sort_values("weight", ascending=False)
    print(de)
    display(max_gene.head(7))

In [None]:
# Frog-group ranking. Macrogenes scoring > 0 here are higher in frog; they are
# stored as de_df_neg (the "other side" of the zebrafish table) for the dot plot below.
# NOTE(review): this reading assumes only frog and zebrafish are present - confirm.
de_df = sc.get.rank_genes_groups_df(subset_2_ancs_ad, group=["frog"])
frog_up = de_df["scores"] > 0
de_df_neg = de_df[frog_up]

In [None]:
# Dot plot of the top 20 macrogenes from each side (zebrafish, then frog) across all centroids.
up_in_zebrafish = de_df_pos["names"].head(20).astype(str).tolist()
up_in_frog = de_df_neg["names"].head(20).astype(str).tolist()
markers = up_in_zebrafish + up_in_frog
sc.pl.dotplot(ancs_ad, markers, groupby='species', dendrogram=True)

In [None]:
# Hand-picked macrogenes for Figure 3D. Earlier candidate sets were
# ["1090", "463", "157", "1160", "1627", "1064"] and ["1372", "13", "1861"].
fair_markers = ["613", "501", "148"]

# Same markers by species: first over all cell types, then ionocytes only.
for panel_data, panel_title in ((ancs_ad, "All Cell Types"), (subset_2_ancs_ad, "Ionocytes Only")):
    print(panel_title)
    sc.pl.dotplot(panel_data, fair_markers, groupby='species', dendrogram=True, title=panel_title)
sc.pl.umap(subset_2_ancs_ad, color=fair_markers)
sc.pl.umap(subset_2_ancs_ad, color="species", title="Ionocytes Only")

In [None]:
# Top 7 genes by weight for each hand-picked macrogene.
for de in fair_markers:
    weights = get_scores(int(de))
    max_gene = pd.DataFrame(
        {"gene": list(weights.keys()), "weight": list(weights.values())}
    ).sort_values("weight", ascending=False)
    print(de)
    display(max_gene.head(7))

## Make Figure 3D

In [None]:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(5,3.5))

# Hand-picked macrogenes (same set as the exploratory cell above).
fair_markers = ["613", "501", "148"]

# One UMAP panel per macrogene, coloured by its log(1 + x) score; colour bars suppressed.
for marker, panel_ax in zip(fair_markers, (ax1, ax2, ax3)):
    sc.pl.umap(subset_2_ancs_ad, color=marker, title=f"Macrogene: {marker}",
               ax=panel_ax, show=False, colorbar_loc=None, vmin=0)

# Fourth panel: the same ionocyte centroids coloured by species (legend hidden).
sc.pl.umap(subset_2_ancs_ad, color="species", ax=ax4, show=False, title="Species", legend_loc=None)

for panel_ax in (ax1, ax2, ax3, ax4):
    panel_ax.set(xlabel=None, ylabel=None)

plt.savefig("figures/3d.svg")

In [None]:
# Re-list the top 7 genes for each Figure 3D macrogene.
for de in fair_markers:
    ranked = pd.DataFrame(list(get_scores(int(de)).items()), columns=["gene", "weight"])
    max_gene = ranked.sort_values("weight", ascending=False)
    print(de)
    display(max_gene.head(7))

# Supplementary Figures

# Supplementary Figure 1

## Supplementary Figure 1A: Bar Graphs

In [None]:
def plot_cluster_knn_bar(df, source='human', other='mouse', ax=None, title=None):
    """
    Creates a stacked horizontal bar plot of cross-species KNN scores per cluster.

    One bar is drawn per ``Source_Cluster``. Its segments are that cluster's
    rows, stacked left to right after sorting by (``labels2`` ascending,
    ``Correct`` descending, ``Score`` descending). Rows with a truthy
    ``Correct`` are drawn green, the others white.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'Source_Cluster', 'Score', 'Correct' and 'labels2'.
    source : str
        Y-axis label.
    other : str
        Unused; kept for backward compatibility.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when None.
    title : str, optional
        Axes title; only set when not None.

    Returns
    -------
    None
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(6, 17), frameon=False)
    bar_ends = defaultdict(int)    # running right edge of each cluster's bar
    row_counts = defaultdict(int)  # rows seen so far for each cluster
    tick = -1
    tick_pos = {}
    df = df.sort_values(['labels2', 'Correct', 'Score'], ascending=[True, False, False])

    for x, y, correct in zip(df['Source_Cluster'], df['Score'], df['Correct']):
        color = 'green' if correct else "white"
        left = bar_ends[x]
        bar_ends[x] = left + y
        row_counts[x] += 1
        # A cluster's first row allocates the next y position for its bar;
        # works for clusters of any size.
        if row_counts[x] == 1:
            tick += 1
            tick_pos[x] = tick
        ax.barh(tick_pos[x], y, left=left, color=color, alpha=0.5, edgecolor="none")

    keys = sorted(tick_pos)
    ax.set_yticks([tick_pos[k] for k in keys], keys)
    ax.set_ylabel(source)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim(0, 1)
    return

In [None]:
# Cross-species alignment tables for the centroid embedding ("labels2") and SAMap
# ("cell_type"), with frog and then zebrafish as source. c1..c4 keep untouched
# copies so the plotting cell can start from them on every re-run.
cen_al_f, sam_al_f, cen_al_z, sam_al_z = [
    alignment_score(path, col=col, species=species)
    for species in ("frog", "zebrafish")
    for path, col in ((centroids_path, "labels2"), (samap_labels_path, "cell_type"))
]
c1, c2, c3, c4 = (frame.copy() for frame in (cen_al_f, sam_al_f, cen_al_z, sam_al_z))

In [None]:
# Frog <-> zebrafish cell-type correspondence table; the CSV's first
# ("Unnamed: 0") column is a saved index and is discarded.
ct_map = (
    pd.read_csv("/dfs/project/cross-species/yanay/fz_true_ct.csv")
    .drop(columns="Unnamed: 0")
)
ct_map.head()

In [None]:
# Most common cell types per species, as a frame with columns
# ["index", "labels2"] = [cell type, count], most frequent first.
# The columns are named explicitly because value_counts().reset_index()
# produces ("labels2", "count") on pandas >= 2.0, which would break the
# right_on="index" merges in the Supplementary Figure 1A cell.
def _label_counts(species):
    """Cell-type counts (labels2) among the centroids of one species."""
    counts = centroids_ad[centroids_ad.obs["species"] == species].obs["labels2"].value_counts()
    return counts.rename_axis("index").reset_index(name="labels2")

labels_by_rank_fg = _label_counts("frog")
labels_by_rank_zf = _label_counts("zebrafish")

In [None]:
# Start from the untouched alignment tables so this cell can be re-run.
cen_al_f, sam_al_f, cen_al_z, sam_al_z = c1.copy(), c2.copy(), c3.copy(), c4.copy()
width = 9
height = 16
fig, (row1, row2) = plt.subplots(2, 2, sharex=True, sharey=False, figsize=(width, height), frameon=False)
# Left column: centroid embedding; right column: SAMap. Top row: frog source; bottom: zebrafish.
ax1, ax2 = row1
ax3, ax4 = row2

# Source clusters for which a KNN label equal to one of the listed names also counts as correct.
least_common_modifiers = {
    "Olfactory placode": ["Olfactory placode"],
    "Germline": ["Germline"],
    "Hatching gland": ["Hatching gland"],
}


def _annotate_correct(al, source_col, target_col):
    """Join ct_map onto an alignment table and flag rows whose KNN label is correct."""
    merged = al.merge(ct_map, left_on="Source_Cluster", right_on=source_col)
    merged["Correct"] = [
        (pred == res) or pred in least_common_modifiers.get(src, [])
        for pred, src, res in zip(
            merged["Cross_Species_KNN_Label"], merged["Source_Cluster"], merged[target_col]
        )
    ]
    return merged


def _order_by_rank(al, source_col, ranks):
    """Join cell-type counts (column "labels2") and sort descending by count, then cluster."""
    merged = al.merge(ranks, left_on=source_col, right_on="index")
    return merged.sort_values(["labels2", "Source_Cluster"], ascending=False)


def _tidy_pair(ax_left, ax_right):
    """Shared cosmetics for one row: no y label, grid or title; right panel hides its y axis."""
    ax_left.set(ylabel=None)
    ax_left.grid(False)
    ax_right.grid(False)
    ax_right.get_yaxis().set_visible(False)
    ax_left.set(title=None)
    ax_right.set(title=None)


# NOTE(review): frog clusters are ranked with the zebrafish counts (labels_by_rank_zf)
# and zebrafish clusters with the frog counts (labels_by_rank_fg); confirm this is intended.

# Frog as source
sam_al_f = _order_by_rank(
    _annotate_correct(sam_al_f, "frog_cell_type", "zebrafish_cell_type"),
    "frog_cell_type", labels_by_rank_zf)
cen_al_f = _order_by_rank(
    _annotate_correct(cen_al_f, "frog_cell_type", "zebrafish_cell_type"),
    "frog_cell_type", labels_by_rank_zf)

plot_cluster_knn_bar(sam_al_f, source="frog", ax=ax2, title="")
plot_cluster_knn_bar(cen_al_f, source="frog", ax=ax1, title="")
_tidy_pair(ax1, ax2)

# Zebrafish as source
sam_al_z = _order_by_rank(
    _annotate_correct(sam_al_z, "zebrafish_cell_type", "frog_cell_type"),
    "zebrafish_cell_type", labels_by_rank_fg)
cen_al_z = _order_by_rank(
    _annotate_correct(cen_al_z, "zebrafish_cell_type", "frog_cell_type"),
    "zebrafish_cell_type", labels_by_rank_fg)

plot_cluster_knn_bar(sam_al_z, source="zebrafish", ax=ax4, title="")
plot_cluster_knn_bar(cen_al_z, source="zebrafish", ax=ax3, title="")
_tidy_pair(ax3, ax4)


fig.tight_layout()

plt.savefig("figures/supp1_bar.svg")

## Supplementary Figure 1B: Label-Transfer Accuracy by Model (box and error-bar plots)

In [None]:
# Directory holding the per-model label-transfer score CSVs.
SCORING_CSV_DIR = "/dfs/project/cross-species/yanay/data/scoring_csvs/"

# Model display name -> scores CSV for Supplementary Figure 1B.
# Insertion order is the plotting order.
fz_models_to_paths_sup = {
    "ESM2": SCORING_CSV_DIR + "no_lasso_rank_esm2_scores.csv",
    "ESM1b": SCORING_CSV_DIR + "no_lasso_rank_scores.csv",
    "ProtXL": SCORING_CSV_DIR + "protxl_new_loss_scores.csv",
    "ESM2 Pretrain": SCORING_CSV_DIR + "no_lasso_rank_esm2_pretrain_scores.csv",
    "SAMAP \n(Weakly Supervised)": SCORING_CSV_DIR + "fz_samap_scores_ct_scores.csv",
    "SAMAP \n(Unsupervised)": SCORING_CSV_DIR + "fz_samap_noct_scores_scores.csv",
    "scVI": SCORING_CSV_DIR + "fz_scvi_scores_scores.csv",
}

In [None]:
# Supplementary Figure 1B: logistic-regression label-transfer accuracy per model,
# first as box plots split by direction, then per-direction median +/- SD error bars.
width = len(fz_models_to_paths_sup) * 1.25
fig, ax = plt.subplots(sharex=False, sharey=False, figsize=(width, 3), frameon=False, dpi=300)

y_parts, x_parts, hue_parts = [], [], []
for model, path in fz_models_to_paths_sup.items():
    # Skip placeholder entries (an equality/truthiness test; `is not ""` compared identity).
    if not path:
        continue
    scores_df = pd.read_csv(path)  # read once; both columns come from the same file
    scores = scores_df["Logistic Regression"].values
    y_parts.extend(scores)
    x_parts.extend([model] * len(scores))
    hue_parts.extend(scores_df["Label"].values)
    print(f"{path}: Max: {max(scores)}")

ys = np.array(y_parts, dtype=float)
xs = np.array(x_parts)
hues = np.array(hue_parts, dtype=object)

sns.boxplot(y=ys, x=xs, ax=ax, hue=hues)
ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=10)
ax.tick_params(axis='both', which='minor', labelsize=7)
ax.set_title("Label Transfer Accuracy")
plt.savefig("figures/supp1b_box.svg")

df = pd.DataFrame({"x": xs, "y": ys, "hue": hues})


def _plot_direction(centres, spreads, plot_title):
    """Error-bar plot for one transfer direction: per-model centre +/- spread."""
    dir_fig, dir_ax = plt.subplots(sharex=False, sharey=False, figsize=(width, 3), frameon=False, dpi=300)
    dir_ax.set_xticks(ticks=np.arange(len(centres.index)), labels=centres.index)
    dir_ax.set_xticklabels(fz_models_to_paths_sup.keys(), rotation=0, fontsize=10)
    dir_ax.grid(False)
    for model_name in fz_models_to_paths_sup.keys():
        dir_ax.errorbar(x=model_name, y=centres.loc[model_name].values.flatten(),
                        yerr=spreads.loc[model_name].values.flatten(),
                        fmt='o', ls="none", label=model_name)
    dir_ax.set_title(plot_title)
    plt.show()


# Centre = median per model; spread = standard deviation (ddof=1) across runs
# (kept under the historical SE* names).
title = "Zebrafish to Frog"
df_1 = df[df["hue"] == title.lower()].drop(columns="hue")
M1 = df_1.groupby(["x"]).agg("median")
SE1 = df_1.groupby(["x"]).std(ddof=1)
_plot_direction(M1, SE1, "Accuracy of Transferring Labels from Zebrafish to Frog")

title = "Frog to Zebrafish"
df_2 = df[df["hue"] == title.lower()].drop(columns="hue")
# Median, matching M1 and the "Median F to Z" summary column (previously the mean).
M2 = df_2.groupby(["x"]).agg("median")
SE2 = df_2.groupby(["x"]).std(ddof=1)
_plot_direction(M2, SE2, "Accuracy of Transferring Labels from Frog to Zebrafish")

In [None]:
# Summary per model: aggregated accuracy (M1, M2) and standard deviation across runs
# (SE1, SE2 are computed with std, ddof=1 - hence "SD", not "SE").
pd.concat([M1, M2, SE1, SE2], axis=1).set_axis(["Median Z to F", "Median F to Z", "SD Z to F", "SD F to Z"], axis=1)

In [None]:
# Per-model aggregated accuracy, frog -> zebrafish.
M2

In [None]:
# Per-model standard deviation (ddof=1) of accuracy, zebrafish -> frog.
SE1

In [None]:
# Per-model standard deviation (ddof=1) of accuracy, frog -> zebrafish.
SE2

In [None]:
# Best single-run accuracy per model, frog -> zebrafish first, then zebrafish -> frog.
for direction_df in (df_2, df_1):
    per_model_max = direction_df.groupby(["x"]).agg("max")
    display(per_model_max.sort_values("y"))