# Spatial domain identification and UMAP visualization

In [None]:
import numpy as np
import anndata as ad
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import networkx as nx
from umap.umap_ import UMAP
from sklearn.mixture import GaussianMixture

from matplotlib.lines import Line2D
import matplotlib as mpl
# Font type 42 makes the PDF/PS backends embed TrueType fonts instead of the
# default Type 3 glyphs, so text in exported figures stays editable.
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

# NOTE: this installs a process-wide "ignore" filter, so every warning raised
# by any library in the cells below is suppressed.
import warnings
warnings.filterwarnings("ignore")

In [None]:
def match_cluster_labels(true_labels, est_labels):
    """Relabel estimated clusters so that they line up with reference labels.

    A complete bipartite graph is built between the sorted unique reference
    categories and the sorted unique estimated categories.  Each edge carries
    the *negated* number of observations shared by the two categories, so the
    minimum-weight full matching is the one-to-one assignment with the largest
    total overlap.

    Parameters
    ----------
    true_labels : iterable
        Reference label of every observation.
    est_labels : iterable
        Estimated cluster label of every observation, in the same order as
        ``true_labels``.

    Returns
    -------
    numpy.ndarray
        One integer per observation.  An estimated cluster that was matched
        is replaced by the 0-based position of its partner in the sorted
        reference categories.  When there are more estimated clusters than
        reference categories, the clusters left without a partner receive
        ``len(reference categories) + k``, where ``k`` is their rank among the
        unmatched clusters (in sorted order).
    """
    true_labels_arr = np.array(list(true_labels))
    est_labels_arr = np.array(list(est_labels))

    org_cat = list(np.sort(list(pd.unique(true_labels))))
    est_cat = list(np.sort(list(pd.unique(est_labels))))

    # Node ids: reference category i -> i + 1, estimated category j -> -(j + 1).
    # The sign keeps the two sides of the bipartite graph disjoint.
    B = nx.Graph()
    B.add_nodes_from(range(1, len(org_cat) + 1), bipartite=0)
    B.add_nodes_from((-j - 1 for j in range(len(est_cat))), bipartite=1)

    for i, org in enumerate(org_cat):
        in_org = true_labels_arr == org  # invariant for the inner loop
        for j, est in enumerate(est_cat):
            overlap = np.sum(in_org * (est_labels_arr == est))
            # Negated so that minimising the weight maximises the overlap.
            B.add_edge(i + 1, -j - 1, weight=-overlap)

    match = nx.algorithms.bipartite.matching.minimum_weight_full_matching(B)

    # Resolve every estimated category to its output code once, instead of
    # scanning ``est_cat`` with ``list.index`` for each observation.
    relabel = {}
    n_unmatched = 0
    for j, est in enumerate(est_cat):
        node = -j - 1
        if node in match:
            relabel[est] = match[node] - 1
        else:
            # Surplus estimated cluster: number it after the reference ones.
            relabel[est] = len(org_cat) + n_unmatched
            n_unmatched += 1

    return np.array([relabel[c] for c in est_labels_arr])


def _plot_umap_scatter(x, y, colors, size, title, frame_color=None, save_path=None):
    """Draw one 5x5-inch embedding scatter figure, optionally framed and saved."""
    plt.figure(figsize=(5, 5))
    plt.scatter(x, y, s=size, c=colors)
    if frame_color:
        # Style the spines of this figure only.  Going through ``plt.rc``
        # would overwrite the global rcParams and leak the frame style into
        # every figure created afterwards.
        for spine in plt.gca().spines.values():
            spine.set_edgecolor(frame_color)
            spine.set_linewidth(2)
    plt.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                    labelleft=False, labelbottom=False, grid_alpha=0)
    plt.title(title, fontsize=14)
    if save_path is not None:
        plt.savefig(save_path)


def plot_DLPFC(rna_list, adata_concat, ground_truth_key, matched_clusters_key, model, group_idx, cluster_to_color_map,
               matched_to_color_map, cluster_orders, slice_name_list, cls_list, sp_embedding,
               save_root=None, frame_color=None, file_format='pdf', save=False, plot=False):
    """Plot spatial clustering results and embedding scatter plots for one group.

    Four figures are created:

    1. a 2 x 4 grid with, for every slice, the spatial scatter coloured by the
       ground-truth annotation (top row) and by the matched clusters (bottom
       row), with one legend per row attached to the fourth column;
    2. the 2-D embedding coloured by slice;
    3. the 2-D embedding coloured by the ground-truth annotation;
    4. the 2-D embedding coloured by the matched clusters.

    Parameters
    ----------
    rna_list : list
        Per-slice data objects exposing ``.obs[...]`` and
        ``.obsm['spatial']``.  The grid has four columns, so at most four
        slices can be drawn.
    adata_concat
        Data object for all slices together; ``.shape[0]`` is used as the
        number of spots and ``.obs`` must contain ``'slice_name'`` plus the
        two key columns below.
    ground_truth_key : str
        ``obs`` column with the annotation; cast to ``str`` before colours
        are looked up in ``cluster_to_color_map``.
    matched_clusters_key : str
        ``obs`` column with the cluster labels; looked up as-is in
        ``matched_to_color_map``.
    model : str
        Name used in figure titles and output file names.
    group_idx : int
        Index into the hard-coded sample names ``['A', 'B', 'C']``; also part
        of the output file names.
    cluster_to_color_map, matched_to_color_map : dict
        Label -> colour mappings for the annotation and the clusters.
    cluster_orders : iterable
        Cluster labels shown in the cluster legend; entries are labelled by
        their position (0, 1, ...) and coloured via ``matched_to_color_map``.
    slice_name_list : list of str
        Slice names, used for panel titles and the per-slice colours (four
        colours are available).
    cls_list : iterable
        Annotation labels shown in the ground-truth legend.
    sp_embedding
        Array indexed as ``sp_embedding[rows, 0]`` / ``[rows, 1]`` with one
        row per spot of ``adata_concat``.
    save_root : str, optional
        Directory prefix for saved figures; must be given when ``save`` is
        true.
    frame_color : optional
        If truthy, the embedding figures get a 2-pt frame in this colour.
    file_format : str
        Extension of the saved files.
    save : bool
        Whether to write the figures to disk.
    plot : bool
        Whether to call ``plt.show()`` at the end.
    """
    samples = ['A', 'B', 'C']
    sample_name = samples[group_idx]

    # ---- figure 1: spatial panels, ground truth above cluster results ----
    fig, axs = plt.subplots(2, 4, figsize=(15, 7))
    fig.suptitle(f'{model} Clustering Results (Sample {sample_name})', fontsize=16)
    for i, rna in enumerate(rna_list):
        coords = rna.obsm['spatial']
        rows = (
            (rna.obs[ground_truth_key].astype('str').map(cluster_to_color_map), 'Ground Truth'),
            (rna.obs[matched_clusters_key].map(matched_to_color_map), 'Cluster Results'),
        )
        for row, (mapped_colors, caption) in enumerate(rows):
            ax = axs[row, i]
            ax.scatter(coords[:, 0], coords[:, 1], linewidth=0.5, s=30,
                       marker=".", color=list(mapped_colors), alpha=0.9)
            ax.set_title(f'{slice_name_list[i]} ({caption})', size=12)
            ax.invert_yaxis()
            ax.axis('off')

    legend_handles_1 = [
        Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=cluster_to_color_map[cluster],
               label=cluster) for cluster in cls_list
    ]
    axs[0, 3].legend(
        handles=legend_handles_1,
        fontsize=8, title='Spot-types', title_fontsize=10, bbox_to_anchor=(1, 1.15))
    legend_handles_2 = [
        Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=matched_to_color_map[order],
               label=f'{i}') for i, order in enumerate(cluster_orders)
    ]
    axs[1, 3].legend(
        handles=legend_handles_2,
        fontsize=8, title='Clusters', title_fontsize=10, bbox_to_anchor=(1, 1.1))
    # Leave room on the right for the legends anchored outside the last column.
    fig.subplots_adjust(left=0.05, top=None, bottom=None, right=0.85)
    if save:
        save_path = save_root + f'/{model}_group{group_idx}_clustering_results.{file_format}'
        fig.savefig(save_path, dpi=500)

    # ---- figures 2-4: the same embedding under three colourings ----
    n_spots = adata_concat.shape[0]
    size = 10000 / n_spots  # marker size shrinks as the number of spots grows
    order = np.arange(n_spots)
    umap_x = sp_embedding[order, 0]
    umap_y = sp_embedding[order, 1]

    colors_for_slices = [[0.2298057, 0.29871797, 0.75368315],
                         [0.70567316, 0.01555616, 0.15023281],
                         [0.2298057, 0.70567316, 0.15023281],
                         [0.5830223, 0.59200322, 0.12993134]]
    slice_cmap = {name: colors_for_slices[i] for i, name in enumerate(slice_name_list)}

    obs = adata_concat.obs
    umap_panels = [
        (obs['slice_name'].astype('str').map(slice_cmap),
         f'Slices ({model}/Sample {sample_name})',
         'slices_umap'),
        (obs[ground_truth_key].astype('str').map(cluster_to_color_map),
         f'Annotated Spot-types ({model}/Sample {sample_name})',
         'annotated_clusters_umap'),
        (obs[matched_clusters_key].map(matched_to_color_map),
         f'Identified Clusters ({model}/Sample {sample_name})',
         'identified_clusters_umap'),
    ]
    for mapped_colors, title, stem in umap_panels:
        save_path = None
        if save:
            save_path = save_root + f"/{model}_group{group_idx}_{stem}.{file_format}"
        _plot_umap_scatter(umap_x, umap_y, list(mapped_colors), size, title,
                           frame_color=frame_color, save_path=save_path)

    if plot:
        plt.show()

In [None]:
# Where results are read from / written to, and whether figures are saved.
save_dir = '../../results/DLPFC_Maynard2021/'
save = True

# DLPFC
data_dir = '../../data/STdata/10xVisium/DLPFC_Maynard2021/'

# Three groups of four slices each; `samples` holds their display names.
sample_group_list = [
    ['151507', '151508', '151509', '151510'],
    ['151669', '151670', '151671', '151672'],
    ['151673', '151674', '151675', '151676'],
]
samples = ['A', 'B', 'C']

# Annotation labels, and the number of mixture components fitted per group.
cls_list = ['Layer1', 'Layer2', 'Layer3', 'Layer4', 'Layer5', 'Layer6', 'WM']
num_clusters_list = [7, 5, 7]

file_format = 'pdf'

# Annotation label -> colour: the k-th label takes the k-th palette colour.
layer_to_color_map = {label: sns.color_palette()[k] for k, label in enumerate(cls_list)}
# Matched cluster id (1..7) -> colour: id k takes the (k-1)-th palette colour,
# i.e. the same colour as the k-th annotation label.
matched_to_color_map = {k: sns.color_palette()[k - 1] for k in range(1, 8)}

In [None]:
# Shared 2-D UMAP reducer; the fixed random_state seeds its random number
# generator so repeated runs can be compared.
reducer = UMAP(
    # neighbourhood graph
    n_neighbors=30,
    metric="correlation",
    metric_kwds=None,
    angular_rp_forest=False,
    set_op_mix_ratio=1.0,
    local_connectivity=1,
    # layout of the embedding
    n_components=2,
    min_dist=0.3,
    spread=1.0,
    a=None,
    b=None,
    # optimisation
    n_epochs=None,
    learning_rate=1.0,
    repulsion_strength=1,
    negative_sample_rate=5,
    random_state=1234,
    verbose=False,
)

In [None]:
# For every slice group: load the slices and their manual annotation, cluster
# the stored INSTINCT embedding with a Gaussian mixture, align the clusters to
# the annotation, compute a 2-D UMAP and draw the figures.
for idx in range(len(sample_group_list)):

    slice_name_list = sample_group_list[idx]
    slice_index_list = list(range(len(slice_name_list)))

    rna_list = []
    for sample in slice_name_list:
        adata = sc.read_visium(path=data_dir + f'{sample}/',
                               count_file=sample + '_filtered_feature_bc_matrix.h5')
        adata.var_names_make_unique()

        # read the annotation
        Ann_df = pd.read_csv(data_dir + f'{sample}/meta_data.csv', sep=',', index_col=0)

        if not all(Ann_df.index.isin(adata.obs_names)):
            raise ValueError("Some rows in the annotation file are not present in the adata.obs_names")

        adata.obs['image_row'] = Ann_df.loc[adata.obs_names, 'imagerow']
        adata.obs['image_col'] = Ann_df.loc[adata.obs_names, 'imagecol']
        adata.obs['Manual_Annotation'] = Ann_df.loc[adata.obs_names, 'ManualAnnotation']

        # Suffix the names with the slice id so they stay unique after concatenation.
        adata.obs_names = [x + '_' + sample for x in adata.obs_names]
        rna_list.append(adata)

    # concatenation
    adata_concat = ad.concat(rna_list, label="slice_name", keys=slice_name_list)

    # Cluster the precomputed embedding (one row per spot of the concatenated
    # object, *before* unannotated spots are removed).
    embed = pd.read_csv(save_dir + f'/INSTINCT_embed_group{idx}.csv', header=None).values
    adata_concat.obsm['latent'] = embed
    gm = GaussianMixture(n_components=num_clusters_list[idx], covariance_type='tied', random_state=1234)
    y = gm.fit_predict(adata_concat.obsm['latent'], y=None)
    adata_concat.obs["gm_clusters"] = pd.Series(y, index=adata_concat.obs.index, dtype='category')

    # Keep annotated spots only.  Subsetting returns a view, so copy
    # explicitly before new ``obs`` columns are written below rather than
    # relying on an implicit copy-on-write.
    adata_concat = adata_concat[~adata_concat.obs['Manual_Annotation'].isna(), :].copy()

    # spots_count[i]:spots_count[i + 1] is the row range of slice i inside
    # adata_concat (same slice order, same filter applied to both).
    spots_count = [0]
    for k in range(len(rna_list)):
        rna_list[k] = rna_list[k][~rna_list[k].obs['Manual_Annotation'].isna(), :].copy()
        spots_count.append(spots_count[-1] + rna_list[k].shape[0])

    # match_cluster_labels returns 0-based codes; shift them onto the 1..7
    # keys of matched_to_color_map.
    # NOTE(review): group index 1 is fitted with 5 components and its labels
    # start at 3 -- presumably so its clusters reuse the colours of
    # Layer3..WM; confirm against the annotation of that group.
    label_offset = 3 if idx == 1 else 1
    adata_concat.obs['matched_clusters'] = label_offset + match_cluster_labels(
        adata_concat.obs['Manual_Annotation'], adata_concat.obs["gm_clusters"])
    my_clusters = np.unique(adata_concat.obs['matched_clusters'].to_numpy())

    # Hand the matched labels back to the individual slices.
    for i in range(len(rna_list)):
        rna_list[i].obs['matched_clusters'] = (
            adata_concat.obs['matched_clusters'].iloc[spots_count[i]:spots_count[i + 1]].to_numpy())

    sp_embedding = reducer.fit_transform(adata_concat.obsm['latent'])

    # plot clustering results
    plot_DLPFC(rna_list, adata_concat, 'Manual_Annotation', 'matched_clusters', 'INSTINCT', idx, layer_to_color_map,
               matched_to_color_map, my_clusters, slice_name_list, cls_list, sp_embedding,
               save_root=save_dir, frame_color='darkviolet', file_format=file_format,
               save=save, plot=True)