In [1]:
import numpy as np
import scanpy as sc
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import rcParams
from matplotlib.axes import Axes

In [None]:
# Hex colours used to highlight individual clusters. plot_cluster_mixtures looks a colour up
# by the integer cluster id, so the order of the entries matters.
cluster_palette = [
    "#1f77b4", "#ff7f0e", "#279e68", "#d62728", "#aa40fc", "#8c564b",
    "#e377c2", "#b5bd61", "#17becf", "#aec7e8", "#ffbb78", "#98df8a", "#ff9896",
    "#c5b0d5", "#c49c94", "#f7b6d2", "#dbdb8d", "#9edae5", "#ad494a", "#8c6d31",
    "#b4d2b1", "#568f8b", "#1d4a60", "#cd7e59", "#ddb247", "#d15252",
    "#264653", "#2a9d8f", "#e9c46a", "#f4a261", "#e76f51", "#ef476f",
    "#ffd166", "#06d6a0", "#118ab2", "#073b4c", "#fbf8cc", "#fde4cf",
    "#ffcfd2", "#f1c0e8", "#cfbaf0", "#a3c4f3", "#90dbf4", "#8eecf5",
    "#8359A3", "#5e503f", "#33CC99", "#F2C649", "#B94E48", "#0095B7",
    "#FF681F", "#e0aaff", "#FED85D", "#0a0908", "#C32148", "#98f5e1",
    "#000000", "#FFFF00", "#1CE6FF", "#FF34FF", "#FF4A46", "#008941", "#006FA6", "#A30059",
    "#FFDBE5", "#7A4900", "#0000A6", "#63FFAC", "#B79762", "#004D43", "#8FB0FF", "#997D87",
    "#5A0007", "#809693", "#FEFFE6", "#1B4400", "#4FC601", "#3B5DFF", "#4A3B53", "#FF2F80",
    "#61615A", "#BA0900", "#6B7900", "#00C2A0", "#FFAA92", "#FF90C9", "#B903AA", "#D16100",
    "#DDEFFF", "#000035", "#7B4F4B", "#A1C299", "#300018", "#0AA6D8", "#013349", "#00846F",
    "#372101", "#FFB500", "#C2FFED", "#A079BF", "#CC0744", "#C0B9B2", "#C2FF99", "#001E09",
    "#00489C", "#6F0062", "#0CBD66", "#EEC3FF", "#456D75", "#B77B68", "#7A87A1", "#788D66",
    "#885578", "#FAD09F", "#FF8A9A", "#D157A0", "#BEC459", "#456648", "#0086ED", "#886F4C",
    "#34362D", "#B4A8BD", "#00A6AA", "#452C2C", "#636375", "#A3C8C9", "#FF913F", "#938A81",
    "#575329", "#00FECF", "#B05B6F", "#8CD0FF", "#3B9700", "#04F757", "#C8A1A1", "#1E6E00",
    "#7900D7", "#A77500", "#6367A9", "#A05837", "#6B002C", "#772600", "#D790FF", "#9B9700",
    "#549E79", "#FFF69F", "#201625", "#72418F", "#BC23FF", "#99ADC0", "#3A2465", "#922329",
    "#5B4534", "#FDE8DC", "#404E55", "#0089A3", "#CB7E98", "#A4E804", "#324E72", "#6A3A4C",
    "#00B7FF", "#004DFF", "#00FFFF", "#826400", "#580041", "#FF00FF", "#00FF00", "#C500FF",
    "#B4FFD7", "#FFCA00", "#969600", "#B4A2FF", "#C20078", "#0000C1", "#FF8B00", "#FFC8FF",
    "#666666", "#FF0000", "#CCCCCC", "#009E8F", "#D7A870", "#8200FF", "#960000", "#BBFF00",
    "#FFFF00", "#006F00",
]

In [None]:
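# Alternative input configuration, kept for reference; it is inactive while commented out.
# To use it, uncomment these three assignments and comment out the ones in the next cell.
# communities_path = '/home/ubuntu/results/mouse_testes/wt3d2rgb_r0.25_ws150_en1.0_sct1.0_dwr80_mcc1.0/Diabetes2_ct/Diabetes2_ct.csv'
# annotation = 'annotation'
# adata_path = '/goofys/Samples/slide_seq/mouse_testis/diabetes/Diabetes2_ct.h5ad'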

In [None]:
# Input configuration used by the final cell of this notebook.
# communities_path: CSV loaded with pd.read_csv; its 'tissue_sliding_window' column is copied into adata.obs.
communities_path = '/home/ubuntu/cell-communities/E16.5_E1S3_cell_bin_whole_brain_noborderct.csv'
# annotation: name of the adata.obs column that is passed to the functions below as the cell type annotation.
annotation = 'sim anno'
# adata_path: file loaded with sc.read.
adata_path = '/goofys/Samples/Stereo_seq/E16.5_E1S3_cell_bin_whole_brain_noborderct.h5ad'

In [None]:
 def calculate_cell_mixture_stats(adata, annotation, unique_cell_type, active_cell_type,
                                  clustering_labels='tissue_sliding_window'):
     """
     Compute the cell type composition of every cluster (community) and store it on ``adata``.

     Two tables are written to ``adata.uns``:

     * ``'cell mixtures'`` - absolute cell counts, one row per cluster and one column per cell type.
     * ``'cell mixtures stats'`` - the same table with the counts replaced by whole-number
       percentages of the cluster (truncated, not rounded), extended with a ``total_counts``
       column (cells per cluster), a ``perc_of_all_cells`` column (share of all counted cells,
       one decimal) and a final ``total_cells`` row holding the absolute totals.

     Cells whose cluster label is ``'unknown'`` are ignored. Cell types and clusters that end up
     with zero counted cells are dropped from both tables. Row labels are the cluster ids sorted
     numerically and stored as strings.

     Parameters:
         - adata: object exposing ``obs`` (pandas.DataFrame) and ``uns`` (dict-like), e.g. AnnData
         - annotation (str): ``adata.obs`` column holding the cell type of every cell
         - unique_cell_type (list): cell types to count; must contain every entry of ``active_cell_type``
         - active_cell_type (list): cell types that are kept in the resulting tables
         - clustering_labels (str): ``adata.obs`` column holding the cluster label of every cell;
             every label other than 'unknown' must be convertible to int

     Returns:
         None. The results are stored in ``adata.uns``.

     Raises:
         KeyError: if ``active_cell_type`` contains a type that is missing from ``unique_cell_type``.
     """
     communities = adata.obs[[clustering_labels, annotation]]
     # .copy() yields an independent frame, so the category clean-up below neither writes
     # through to adata.obs nor triggers pandas' SettingWithCopyWarning
     communities = communities[communities[clustering_labels] != 'unknown'].copy()
     labels = communities[clustering_labels]
     if isinstance(labels.dtype, pd.CategoricalDtype) and 'unknown' in labels.cat.categories:
         # otherwise groupby would still produce an (empty) 'unknown' group
         communities[clustering_labels] = labels.cat.remove_categories('unknown')

     # count the cells of every cell type inside every cluster
     stats_table = {}
     for label, cluster_data in communities.groupby(clustering_labels, observed=False):
         counts = {ct: int((cluster_data[annotation] == ct).sum()) for ct in unique_cell_type}
         # keep only the cell types that were not excluded
         stats_table[label] = {ct: counts[ct] for ct in active_cell_type}

     stats = pd.DataFrame(stats_table).T
     stats.columns.name = "cell types"

     # sort clusters numerically (not lexicographically), then expose the ids as strings
     stats.index = stats.index.astype(int)
     stats = stats.sort_index()
     stats.index = stats.index.astype(str)

     # drop cell types that are absent from every cluster and clusters without any counted cell
     populated_cell_types = stats.sum(axis=0) != 0
     populated_clusters = stats.sum(axis=1) != 0
     stats = stats.loc[populated_clusters, populated_cell_types].copy()

     # save absolute cell mixtures
     adata.uns['cell mixtures'] = stats.copy()

     # add column with total cell count per cluster
     stats['total_counts'] = stats.sum(axis=1).astype(int)

     # add row with total counts per cell type (and the grand total under 'total_counts')
     totals_row = pd.DataFrame(
         {column: [int(stats[column].sum())] for column in stats.columns},
         index=['total_cells'],
     )
     stats = pd.concat([stats, totals_row])

     # turn the per-cluster counts into percentages of the cluster; the last row and the
     # last column hold totals and are left untouched
     cluster_sizes = stats['total_counts'].iloc[:-1]
     stats.iloc[:-1, :-1] = stats.iloc[:-1, :-1].div(cluster_sizes, axis=0).mul(100).astype(int)

     # add column with percentage of all counted cells that belong to a cluster;
     # .iloc is required because the row labels are strings, not positions
     all_cells = stats['total_counts'].iloc[-1]
     stats['perc_of_all_cells'] = np.around(stats['total_counts'] / all_cells * 100, decimals=1)

     # save cell mixture statistics
     adata.uns['cell mixtures stats'] = stats


In [None]:
def set_figure_params(
        dpi: int,
        facecolor: str,
):
    """
    Set the background colour and the resolution of subsequently created matplotlib figures.

    Parameters:
        - dpi (int): value stored in rcParams['figure.dpi']
        - facecolor (str): colour stored as both the figure and the axes face colour
    """
    overrides = (
        ('figure.facecolor', facecolor),
        ('axes.facecolor', facecolor),
        ("figure.dpi", dpi),
    )
    for option, value in overrides:
        rcParams[option] = value

In [None]:
def plot_spatial(
        adata,
        annotation,
        ax: Axes,
        spot_size: float,
        palette=None,
        title: str = ""
):
    """
    Draw the cells of a sample at their spatial coordinates, coloured by an obs column.

    The coordinates are read from the first two columns of adata.obsm['spatial'].
    Tick labels, tick marks and the bottom/left spines are removed and the aspect
    ratio is fixed to "equal".

    Parameters:
        - adata (AnnData): Annotated data object which represents the sample
        - annotation (str): adata.obs column used for colouring (seaborn ``hue``)
        - ax (Axes): Axes object the scatter plot is drawn into
        - spot_size (float): Size of the dot that represents a cell; half of this value
                is handed to seaborn as the marker size ``s``
        - palette (dict): Mapping between annotation categories and colors
        - title (str): Title of the axes

    """
    coordinates = adata.obsm['spatial']
    marker_size = spot_size * 0.5

    ax = sns.scatterplot(
        x=coordinates[:, 0],
        y=coordinates[:, 1],
        hue=annotation,
        data=adata.obs,
        palette=palette,
        s=marker_size,
        marker='.',
        linewidth=0,
        ax=ax,
    )

    # the absolute coordinates carry no information for the reader, so hide every axis decoration
    ax.set(yticklabels=[], xticklabels=[], title=title)
    ax.tick_params(bottom=False, left=False)
    ax.set_aspect("equal")
    sns.despine(ax=ax, left=True, bottom=True)

In [None]:
def plot_cluster_mixtures(adata, annotation, annotation_palette, dpi=200, min_cluster_size=200, min_perc_to_show=4, spot_size=4, cluster_index=None, clustering_labels='tissue_sliding_window'):
    """
    Plot the cell mixture of each cluster (community) and save one PNG per cluster.

    Every figure has two panels: on the left all cells are drawn and only the cell types
    whose share of the cluster exceeds ``min_perc_to_show`` keep their colour; on the right
    only the cells of the cluster itself are coloured. Everything else is light grey.

    The mixtures are read from ``adata.uns['cell mixtures stats']``. Figures are written to
    the current working directory as ``cmixtures_c<cluster>.png``.

    Parameters:
        - adata (AnnData): sample with ``uns['cell mixtures stats']`` and ``obsm['spatial']``
        - annotation (str): adata.obs column holding the cell type of every cell
        - annotation_palette (dict): mapping from cell type to colour
        - dpi (int): resolution of the figures
        - min_cluster_size (int): only clusters with more cells than this are plotted
        - min_perc_to_show (float): only cell types above this percentage are highlighted
        - spot_size (float): size of the dot that represents a cell
        - cluster_index (int): if given, only the cluster at this row position of the
            stats table is considered
        - clustering_labels (str): adata.obs column holding the cluster label of every cell

    """
    background = '#dcdcdc'
    set_figure_params(dpi=dpi, facecolor='white')
    stats = adata.uns['cell mixtures stats']

    # keep only the per-cluster percentages (strip the summary row and columns)
    mixtures = stats.drop(labels=['total_counts', 'perc_of_all_cells'], axis=1)
    mixtures = mixtures.drop(labels='total_cells', axis=0)

    # The palette has to cover every value that is present in the plotted column, including
    # clusters that were dropped from the stats table, otherwise seaborn rejects it.
    palette_keys = set(mixtures.index) | {'unknown'} | set(adata.obs[clustering_labels].unique())

    for position, (cluster_label, percentages) in enumerate(mixtures.iterrows()):
        if cluster_index is not None and cluster_index != position:
            continue
        # only display clusters with more than min_cluster_size cells
        if not stats.loc[cluster_label, 'total_counts'] > min_cluster_size:
            continue

        # sort cell types by their abundance in the cluster
        ct_perc = percentages.sort_values(ascending=False)
        # only cell types which have more than min_perc_to_show abundance will be shown
        ct_show = set(ct_perc.index[ct_perc > min_perc_to_show])
        ct_palette = {
            ct: (annotation_palette[ct] if ct in ct_show else background)
            for ct in annotation_palette
        }

        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 8))
        fig.subplots_adjust(wspace=0.35)

        # left panel: highlighted cell types
        plot_spatial(adata, annotation=annotation, palette=ct_palette, spot_size=spot_size, ax=ax[0])
        handles, labels = ax[0].get_legend_handles_labels()
        shown_entries = [(handle, label) for handle, label in zip(handles, labels) if label in ct_show]
        if shown_entries:
            handles, labels = zip(*shown_entries)
            labels = [f'{ctype} ({ct_perc[ctype]}%)' for ctype in labels]
            ax[0].legend(handles=handles, labels=labels, bbox_to_anchor=(1.0, 0.5), loc='center left', frameon=False, fontsize=8)
        elif ax[0].get_legend() is not None:
            # no cell type is abundant enough: drop the automatic legend instead of failing
            ax[0].get_legend().remove()

        # right panel: highlighted cluster; the modulo keeps ids beyond the palette length valid
        highlighted_cluster_label = f"{cluster_label}"
        highlight_color = cluster_palette[int(cluster_label) % len(cluster_palette)]
        cl_palette = {
            key: (highlight_color if str(key) == highlighted_cluster_label else background)
            for key in palette_keys
        }
        plot_spatial(adata, annotation=clustering_labels, palette=cl_palette, spot_size=spot_size, ax=ax[1])

        # replace the automatic legend with one that lists the highlighted cluster only
        handles, labels = ax[1].get_legend_handles_labels()
        filtered_handles = [h for h, label in zip(handles, labels) if highlighted_cluster_label == label]
        filtered_labels = [highlighted_cluster_label] * len(filtered_handles)
        ax[1].legend(handles=filtered_handles, labels=filtered_labels, bbox_to_anchor=(1.0, 0.5), loc='center left', frameon=False, fontsize=8)

        fig.savefig(f'cmixtures_c{cluster_label}.png', bbox_inches='tight')
        # close this figure explicitly so that figures do not pile up in memory
        plt.close(fig)

In [None]:
adata = sc.read(adata_path)
communities = pd.read_csv(communities_path)

# The community labels are attached by position, so the CSV has to hold exactly one row per cell.
if len(communities) != len(adata.obs):
    raise ValueError(
        f'{communities_path} has {len(communities)} rows, but {adata_path} has {len(adata.obs)} cells'
    )
adata.obs['tissue_sliding_window'] = communities['tissue_sliding_window'].values
adata.obs['tissue_sliding_window'] = adata.obs['tissue_sliding_window'].astype('category')

# scanpy keeps adata.uns['<key>_colors'] aligned with the categories of adata.obs['<key>'],
# so pair the colours with the categories themselves rather than with the sorted observed values
# (the two differ when a category is unused or when the categories are not in sorted order).
cell_type_categories = adata.obs[annotation].cat.categories
cell_type_colors = adata.uns[f'{annotation}_colors']
if len(cell_type_colors) < len(cell_type_categories):
    raise ValueError(
        f"adata.uns['{annotation}_colors'] has {len(cell_type_colors)} colours "
        f"for {len(cell_type_categories)} categories"
    )
annotation_palette = dict(zip(cell_type_categories, cell_type_colors))

unique_cell_type = list(adata.obs[annotation].unique())
# 'Ery' cells are excluded from the mixture statistics
active_cell_type = [ct for ct in unique_cell_type if ct != 'Ery']

calculate_cell_mixture_stats(adata, annotation, unique_cell_type, active_cell_type)

plot_cluster_mixtures(adata, annotation, annotation_palette=annotation_palette, dpi=300)