# Integrating all six slices of the spatial ATAC mouse embryo (ME) dataset

In [None]:
import os
import csv
import torch
import numpy as np
import pandas as pd
import anndata as ad

from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture

import INSTINCT

import warnings
warnings.filterwarnings("ignore")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import fowlkes_mallows_score
from sklearn.metrics.cluster import homogeneity_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import completeness_score
import sklearn
import sklearn.neighbors
import networkx as nx

import scib
import scanpy as sc

### Load raw data
The peaks have already been merged by the original study, so there is no need to merge them again.

In [None]:
# Mouse embryo spatial ATAC data: input directory and output directory for all results.
data_dir = '../../data/spCASdata/MouseEmbryo_Llorens-Bobadilla2023/spATAC/'
save_dir = '../../results/MouseEmbryo_Llorens-Bobadilla2023/all/'

slice_name_list = ['E12_5-S1', 'E12_5-S2', 'E13_5-S1', 'E13_5-S2', 'E15_5-S1', 'E15_5-S2']
slice_index_list = list(range(len(slice_name_list)))

# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(save_dir, exist_ok=True)

# Load every slice and suffix its obs_names with the slice name so that the
# names stay unique once the slices are concatenated.
cas_list = [ad.read_h5ad(data_dir + sample + '.h5ad') for sample in slice_name_list]
for sample, cas in zip(slice_name_list, cas_list):
    cas.obs_names = [f'{obs_name}_{sample}' for obs_name in cas.obs_names]

# Concatenate all slices; obs['slice_name'] records the slice each spot came from.
adata_concat = ad.concat(cas_list, label="slice_name", keys=slice_name_list)

### Data preprocessing
Since the data matrices are already fragment count matrices, we set `use_fragment_count=False`.

In [None]:
# Preprocess the CAS data. Peaks were merged by the original study and the
# matrices already hold fragment counts, hence use_fragment_count=False.
# preprocess_CAS's return value is not used: it presumably updates cas_list /
# adata_concat in place (both are saved right after) — TODO confirm.
print('Start preprocessing')
INSTINCT.preprocess_CAS(
    cas_list,
    adata_concat,
    use_fragment_count=False,
    min_cells_rate=0.003,
)
print('Done!')
print(adata_concat)

In [None]:
# Persist the preprocessed concatenation and each filtered slice, then reload
# everything from disk so the downstream cells work from the saved files.
concat_path = save_dir + "preprocessed_concat.h5ad"
adata_concat.write_h5ad(concat_path)
for sample, slice_adata in zip(slice_name_list, cas_list):
    slice_adata.write_h5ad(save_dir + f"filtered_{sample}.h5ad")

cas_list = []
for sample in slice_name_list:
    cas_list.append(ad.read_h5ad(save_dir + f"filtered_{sample}.h5ad"))
# origin_concat keeps the un-integrated features (used later as the batch PCR reference).
origin_concat = ad.concat(cas_list, label="slice_idx", keys=slice_index_list)
adata_concat = ad.read_h5ad(concat_path)

### Perform PCA

In [None]:
# Project the preprocessed matrix onto 100 principal components and cache them.
# NOTE(review): .toarray() assumes adata_concat.X is a scipy sparse matrix.
print('Applying PCA to reduce the feature dimension to 100 ...')
pca = PCA(n_components=100, random_state=1234)
pca_path = save_dir + 'input_matrix.npy'
np.save(pca_path, pca.fit_transform(adata_concat.X.toarray()))
print('Done !')

# Load the cached components back and attach them as the PCA representation.
input_matrix = np.load(pca_path)
adata_concat.obsm['X_pca'] = input_matrix

### Create neighbor graph

In [None]:
# Compute the spatial neighbor graph for the slices (results are kept on the
# passed objects; presumably consumed by INSTINCT_Model below — TODO confirm).
INSTINCT.create_neighbor_graph(
    cas_list,
    adata_concat,
)

### Run model

In [None]:
# Build the INSTINCT model on the slices and their concatenation, train it with
# loss reporting (report_interval=100), then evaluate on cas_list. The next cell
# reads obsm['INSTINCT_latent'] / obsm['INSTINCT_latent_noise'] from the slices,
# so eval() is expected to write those embeddings into cas_list.
INSTINCT_model = INSTINCT.INSTINCT_Model(cas_list, adata_concat, device=device)
INSTINCT_model.train(report_interval=100, report_loss=True)
INSTINCT_model.eval(cas_list)

In [None]:
result = ad.concat(cas_list, label="slice_idx", keys=slice_index_list)

# Dump both learned representations as headerless CSV files, one row per spot
# (row order follows the concatenation above).
for csv_name, obsm_key in (('INSTINCT_embed.csv', 'INSTINCT_latent'),
                           ('INSTINCT_noise_embed.csv', 'INSTINCT_latent_noise')):
    with open(save_dir + csv_name, 'w', newline='') as handle:
        csv.writer(handle).writerows(result.obsm[obsm_key])

### Clustering 

In [None]:
def match_cluster_labels(true_labels, est_labels):
    """Relabel estimated clusters so they best agree with the ground-truth labels.

    A complete bipartite graph is built between the sorted ground-truth
    categories and the sorted estimated categories. Each edge is weighted by
    minus the number of items the pair shares, and a minimum-weight full
    matching is solved, i.e. a one-to-one assignment maximising total overlap.

    Parameters
    ----------
    true_labels : iterable
        Ground-truth label of every item.
    est_labels : iterable
        Estimated cluster of every item, in the same order as ``true_labels``.

    Returns
    -------
    numpy.ndarray
        One integer per item. A matched estimated cluster is encoded as the
        index of its ground-truth partner in the sorted ground-truth
        categories. When there are more estimated clusters than ground-truth
        categories, each unmatched cluster gets ``n_truth + k``, where ``k`` is
        its position among the unmatched clusters (in sorted order).
    """
    truth = np.array(list(true_labels))
    estimate = np.array(list(est_labels))

    truth_cats = list(np.sort(list(pd.unique(true_labels))))
    est_cats = list(np.sort(list(pd.unique(est_labels))))
    n_truth = len(truth_cats)

    # Ground-truth category i is node i + 1 (positive) and estimated cluster j is
    # node -(j + 1) (negative), so the two sides of the graph never collide.
    est_node = {cat: -(j + 1) for j, cat in enumerate(est_cats)}

    graph = nx.Graph()
    graph.add_nodes_from(range(1, n_truth + 1), bipartite=0)
    graph.add_nodes_from([-(j + 1) for j in range(len(est_cats))], bipartite=1)

    for i, t_cat in enumerate(truth_cats):
        in_truth = truth == t_cat
        for j, e_cat in enumerate(est_cats):
            overlap = np.sum(in_truth * (estimate == e_cat))
            # Negated so the *minimum*-weight matching maximises the overlap.
            graph.add_edge(i + 1, -(j + 1), weight=-overlap)

    # The returned dict maps matched nodes in both directions.
    matching = nx.algorithms.bipartite.matching.minimum_weight_full_matching(graph)

    if n_truth >= len(est_cats):
        # Every estimated cluster has a ground-truth partner.
        return np.array([matching[est_node[c]] - 1 for c in estimate])

    unmatched = [c for c in est_cats if est_node[c] not in matching]
    relabeled = []
    for c in estimate:
        node = est_node[c]
        if node in matching:
            relabeled.append(matching[node] - 1)
        else:
            relabeled.append(n_truth + unmatched.index(c))
    return np.array(relabeled)

In [None]:
# Cluster the INSTINCT embedding with an 11-component tied-covariance GMM and
# map the GMM components onto the ground-truth domains.
gm = GaussianMixture(n_components=11, covariance_type='tied', random_state=1234)
y = gm.fit_predict(result.obsm['INSTINCT_latent'], y=None)
obs = result.obs
obs["gm_clusters"] = pd.Series(y, index=obs.index, dtype='category')
matched = match_cluster_labels(obs['clusters'], obs["gm_clusters"])
obs['matched_clusters'] = pd.Series(matched, index=obs.index, dtype='category')

### Evaluation

In [None]:
def cluster_metrics(target, pred):
    """Score a clustering against reference labels.

    Computes, in order, ARI, AMI, NMI, FMI, completeness and homogeneity,
    prints them (3 decimals) and returns them.

    Parameters
    ----------
    target : array-like
        Reference labels.
    pred : array-like
        Predicted labels, same length and order as ``target``.

    Returns
    -------
    tuple
        ``(ari, ami, nmi, fmi, comp, homo)``.
    """
    y_true = np.array(target)
    y_pred = np.array(pred)

    scorers = (
        adjusted_rand_score,
        adjusted_mutual_info_score,
        normalized_mutual_info_score,
        fowlkes_mallows_score,
        completeness_score,
        homogeneity_score,
    )
    scores = tuple(scorer(y_true, y_pred) for scorer in scorers)
    print('ARI: %.3f, AMI: %.3f, NMI: %.3f, FMI: %.3f, Comp: %.3f, Homo: %.3f' % scores)

    return scores


def mean_average_precision(x: np.ndarray, y: np.ndarray, k: int=30, **kwargs) -> float:
    r"""
    Mean average precision of label agreement among nearest neighbors.

    For every point, the ``k + 1`` nearest neighbors in ``x`` are queried
    (capped at the number of points). The first neighbor is dropped, assuming
    it is the point itself. Neighbors sharing the point's label count as hits,
    and the average precision over that ranked list is computed. The result is
    the mean over all points.

    Parameters
    ----------
    x
        Coordinates
    y
        Cell_type/Layer labels
    k
        k neighbors
    **kwargs
        Additional keyword arguments are passed to
        :class:`sklearn.neighbors.NearestNeighbors`

    Returns
    -------
    map
        Mean average precision
    """
    labels = np.array(y)
    n_query = min(labels.shape[0], k + 1)
    finder = sklearn.neighbors.NearestNeighbors(n_neighbors=n_query, **kwargs).fit(x)
    neighbor_idx = finder.kneighbors(x, return_distance=False)
    # Skip column 0 (the query point) and compare neighbor labels to the query's label.
    hits = np.equal(labels[neighbor_idx[:, 1:]], labels[:, None])

    def _row_average_precision(row: np.ndarray) -> float:
        # A row without a single hit contributes zero precision.
        if not np.any(row):
            return 0.0
        precision_at_rank = np.cumsum(row) / np.arange(1, row.size + 1)
        return precision_at_rank[row].mean().item()

    per_point = np.apply_along_axis(_row_average_precision, 1, hits)
    return per_point.mean().item()


def rep_metrics(adata, origin_concat, use_rep, label_key, batch_key, k_map=30):
    """Compute representation-quality and batch-mixing metrics for an embedding.

    Parameters
    ----------
    adata : AnnData
        Integrated data with the embedding in ``obsm[use_rep]`` and the label and
        batch columns in ``obs``. Modified in place: both ``obs`` columns are cast
        to string categoricals and ``sc.pp.neighbors`` is run on ``use_rep``.
    origin_concat : AnnData
        Un-integrated concatenation, passed to ``scib.me.pcr_comparison`` as the
        reference. Modified in place: its ``X`` is cast to float.
    use_rep : str
        Key of the embedding in ``adata.obsm``.
    label_key : str
        ``adata.obs`` column holding the ground-truth labels.
    batch_key : str
        ``adata.obs`` column holding the batch (slice) identifiers.
    k_map : int, default 30
        Number of neighbors used by ``mean_average_precision``.

    Returns
    -------
    tuple or None
        ``(mAP, cell-type ASW, batch ASW, batch PCR, kBET, graph connectivity)``.
        Returns ``None`` after printing which keys are missing when any required
        key is absent.
    """
    missing = [f"obs[{key!r}]" for key in (label_key, batch_key) if key not in adata.obs]
    if use_rep not in adata.obsm:
        missing.append(f"obsm[{use_rep!r}]")
    if missing:
        # Name the absent keys; a bare "KeyError" gave no clue what to fix.
        print("KeyError: missing " + ", ".join(missing))
        return None

    adata.obs[label_key] = adata.obs[label_key].astype(str).astype("category")
    adata.obs[batch_key] = adata.obs[batch_key].astype(str).astype("category")
    origin_concat.X = origin_concat.X.astype(float)
    sc.pp.neighbors(adata, use_rep=use_rep)

    MAP = mean_average_precision(adata.obsm[use_rep].copy(), adata.obs[label_key], k=k_map)
    cell_type_ASW = scib.me.silhouette(adata, label_key=label_key, embed=use_rep)
    # g_iLISI = scib.me.ilisi_graph(adata, batch_key=batch_key, type_="embed", use_rep=use_rep)
    batch_ASW = scib.me.silhouette_batch(adata, batch_key=batch_key, label_key=label_key, embed=use_rep, verbose=False)
    batch_PCR = scib.me.pcr_comparison(origin_concat, adata, covariate=batch_key, embed=use_rep)
    kBET = scib.me.kBET(adata, batch_key=batch_key, label_key=label_key, type_='embed', embed=use_rep)
    g_conn = scib.me.graph_connectivity(adata, label_key=label_key)
    print('mAP: %.3f, Cell type ASW: %.3f, Batch ASW: %.3f, Batch PCR: %.3f, kBET: %.3f, Graph connectivity: %.3f' %
          (MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn))

    return MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn

In [None]:
# Clustering agreement between the ground-truth domains and the matched GMM clusters.
ari, ami, nmi, fmi, comp, homo = cluster_metrics(result.obs['clusters'],
                                                 result.obs['matched_clusters'].tolist())
# Embedding / batch-mixing metrics. The mAP score is stored as `map_score` rather
# than `map` so the builtin map() is not shadowed for the rest of the session.
map_score, c_asw, b_asw, b_pcr, kbet, g_conn = rep_metrics(result, origin_concat, use_rep='INSTINCT_latent',
                                                           label_key='clusters', batch_key='slice_idx')

### Spatial domain identification and UMAP visualization

In [None]:
import numpy as np
import anndata as ad
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from umap.umap_ import UMAP
from sklearn.mixture import GaussianMixture

from matplotlib.lines import Line2D
import matplotlib as mpl
# Use fonttype 42 (TrueType embedding) for PDF and PostScript output instead of
# the Type 3 default, so exported text stays editable in vector editors.
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

In [None]:
save_dir = '../../results/MouseEmbryo_Llorens-Bobadilla2023/all/'
save = True

# Ground-truth domain names (as expected in obs['clusters']), their display
# labels, and their plot colors; the three lists are index-aligned.
cluster_list = ['Forebrain', 'Midbrain', 'Hindbrain', 'Periventricular', 'Meningeal_PNS_1', 'Meningeal_PNS_2',
                'Internal', 'Facial_bone', 'Muscle_heart', 'Limb', 'Liver']

label_list = ['Forebrain', 'Midbrain', 'Hindbrain', 'Periventricular', 'Meningeal/PNS_1', 'Meningeal/PNS_2',
              'Internal', 'Facial/bone', 'Muscle/heart', 'Limb', 'Liver']

color_list = ['royalblue', 'dodgerblue', 'deepskyblue', 'forestgreen', 'yellowgreen', 'y',
              'grey', 'crimson', 'deeppink', 'orchid', 'orange']

# match_cluster_labels() encodes a matched cluster as the index of its partner in
# the *sorted* ground-truth categories. The index is derived from cluster_list
# instead of being hand-maintained, so the mapping stays correct if the list is
# edited. For the list above it yields [1, 8, 2, 10, 6, 7, 3, 0, 9, 4, 5].
# Assumes obs['clusters'] contains exactly these names.
sorted_cluster_list = sorted(cluster_list)
order_list = [sorted_cluster_list.index(cluster) for cluster in cluster_list]

cluster_to_color_map = dict(zip(cluster_list, color_list))
order_to_cluster_map = dict(zip(order_list, cluster_list))

reducer = UMAP(n_neighbors=30, n_components=2, metric="correlation", n_epochs=None, learning_rate=1.0,
               min_dist=0.3, spread=1.0, set_op_mix_ratio=1.0, local_connectivity=1, repulsion_strength=1,
               negative_sample_rate=5, a=None, b=None, random_state=1234, metric_kwds=None,
               angular_rp_forest=False, verbose=False)

slice_name_list = ['E12_5-S1', 'E12_5-S2', 'E13_5-S1', 'E13_5-S2', 'E15_5-S1', 'E15_5-S2']
cas_list = [ad.read_h5ad(save_dir + f"filtered_{sample}.h5ad") for sample in slice_name_list]
adata_concat = ad.concat(cas_list, label='slice_name', keys=slice_name_list)

# Rows spots_count[i]:spots_count[i + 1] of adata_concat belong to slice i.
spots_count = [0]
for sample in cas_list:
    spots_count.append(spots_count[-1] + sample.shape[0])

# Assumes the CSV rows follow the same slice/spot order as adata_concat.
embed = pd.read_csv(save_dir + 'INSTINCT_embed.csv', header=None).values
adata_concat.obsm['latent'] = embed

gm = GaussianMixture(n_components=len(cluster_list), covariance_type='tied', random_state=1234)
y = gm.fit_predict(adata_concat.obsm['latent'], y=None)
adata_concat.obs["gm_clusters"] = pd.Series(y, index=adata_concat.obs.index, dtype='category')
adata_concat.obs['matched_clusters'] = pd.Series(match_cluster_labels(
    adata_concat.obs['clusters'], adata_concat.obs["gm_clusters"]),
    index=adata_concat.obs.index, dtype='category')
my_clusters = np.sort(list(set(adata_concat.obs['matched_clusters'])))
matched_colors = [cluster_to_color_map[order_to_cluster_map[order]] for order in my_clusters]
matched_to_color_map = dict(zip(my_clusters, matched_colors))

for i, cas in enumerate(cas_list):
    # Explicit positional slice. The column assignment then aligns on obs_names,
    # which match because adata_concat was concatenated from these same slices.
    cas.obs['matched_clusters'] = adata_concat.obs['matched_clusters'].iloc[spots_count[i]:spots_count[i + 1]]

sp_embedding = reducer.fit_transform(adata_concat.obsm['latent'])

In [None]:
def _umap_scatter(sp_embedding, colors, size, title, frame_color=None,
                  legend_handles=None, legend_title=None, save_path=None):
    """Draw one 5x5-inch UMAP scatter figure and optionally save it to ``save_path``.

    The optional frame styling (axes edge color ``frame_color``, linewidth 2) is
    applied through a temporary rc context, so it does not leak into later figures.
    """
    frame_rc = {'axes.edgecolor': frame_color, 'axes.linewidth': 2} if frame_color else {}
    with plt.rc_context(frame_rc):
        plt.figure(figsize=(5, 5))
        plt.scatter(sp_embedding[:, 0], sp_embedding[:, 1], s=size, c=colors)
        plt.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                        labelleft=False, labelbottom=False, grid_alpha=0)
        if legend_handles is not None:
            plt.legend(handles=legend_handles, fontsize=8, title=legend_title, title_fontsize=10,
                       loc='upper left')
        plt.title(title, fontsize=16)
        if save_path is not None:
            plt.savefig(save_path)


def plot_mouseembryo_6(cas_list, adata_concat, ground_truth_key, matched_clusters_key, model,
                       cluster_to_color_map, matched_to_color_map, cluster_orders,
                       slice_name_list, sp_embedding,
                       save_root=None, frame_color=None, save=False, plot=False):
    """Plot spatial cluster maps for six slices plus three UMAP views.

    Figures produced:
    1. A 2x3 grid with one spatial scatter per slice, colored by matched cluster.
    2. UMAP colored by slice.
    3. UMAP colored by ground-truth annotation.
    4. UMAP colored by matched cluster.

    Parameters
    ----------
    cas_list : list of AnnData
        One AnnData per slice, with ``obsm['spatial']`` and ``obs[matched_clusters_key]``.
    adata_concat : AnnData
        Concatenation of the slices, with ``obs['slice_name']``, ``obs[ground_truth_key]``
        and ``obs[matched_clusters_key]``.
    ground_truth_key, matched_clusters_key : str
        ``obs`` columns holding the annotation and the matched clusters.
    model : str
        Model name used in titles and output file names.
    cluster_to_color_map : dict
        Ground-truth label -> color.
    matched_to_color_map : dict
        Matched cluster id -> color.
    cluster_orders : iterable
        Matched cluster ids in legend order; legend entries are labeled 0..n-1.
    slice_name_list : list of str
        Slice names, aligned with ``cas_list``.
    sp_embedding : numpy.ndarray
        2-D UMAP coordinates whose rows are aligned with ``adata_concat``.
    save_root : str, optional
        Directory for the PDFs; required when ``save`` is True.
    frame_color : str, optional
        Edge color of the UMAP axes frames.
    save : bool
        Save each figure as a PDF under ``save_root``.
    plot : bool
        Call ``plt.show()`` at the end. Figures are created either way.

    Raises
    ------
    ValueError
        If ``save`` is True but ``save_root`` is None.
    """
    if save and save_root is None:
        raise ValueError('save_root must be provided when save=True')

    # Figure 1: spatial cluster maps, slices filling the 2x3 grid column by column.
    fig, axs = plt.subplots(2, 3, figsize=(10, 6))
    fig.suptitle(f'{model} Clustering Results', fontsize=16)
    for i, cas in enumerate(cas_list):
        slice_name = slice_name_list[i]
        ax = axs[i % 2, i // 2]
        size = 20 if slice_name in ('E12_5-S1', 'E12_5-S2') else 15
        if slice_name == 'E15_5-S1':
            # Flip both axes for this slice only (presumably to match the
            # orientation of the other slices — TODO confirm).
            ax.invert_xaxis()
            ax.invert_yaxis()
        cluster_colors = list(cas.obs[matched_clusters_key].map(matched_to_color_map))
        ax.scatter(cas.obsm['spatial'][:, 1], cas.obsm['spatial'][:, 0],
                   linewidth=0.5, s=size, marker=".", color=cluster_colors, alpha=0.9)
        ax.set_title(f'{slice_name} (Cluster Results)', size=12)
        ax.axis('off')

    cluster_handles = [
        Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=matched_to_color_map[order],
               label=f'{i}') for i, order in enumerate(cluster_orders)
    ]
    axs[0, 2].legend(
        handles=cluster_handles,
        fontsize=8, title='Clusters', title_fontsize=10, bbox_to_anchor=(1, 1))
    fig.subplots_adjust(left=0.05, top=None, bottom=None, right=0.85)
    if save:
        fig.savefig(save_root + f'/{model}_clustering_results.pdf')

    # Figures 2-4: UMAP views; marker size shrinks as the number of spots grows.
    umap_size = 10000 / adata_concat.shape[0]

    colors_for_slices = ['deeppink', 'hotpink', 'darkgoldenrod', 'goldenrod', 'c', 'cyan']
    slice_cmap = dict(zip(slice_name_list, colors_for_slices))
    slice_handles = [
        Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=slice_cmap[name],
               label=name)
        for name in slice_name_list
    ]
    _umap_scatter(sp_embedding, list(adata_concat.obs['slice_name'].astype('str').map(slice_cmap)),
                  umap_size, f'Slices ({model})', frame_color=frame_color,
                  legend_handles=slice_handles, legend_title='Slices',
                  save_path=save_root + f"/{model}_slices_umap.pdf" if save else None)

    _umap_scatter(sp_embedding,
                  list(adata_concat.obs[ground_truth_key].astype('str').map(cluster_to_color_map)),
                  umap_size, f'Annotated Spot-types ({model})', frame_color=frame_color,
                  save_path=save_root + f"/{model}_annotated_clusters_umap.pdf" if save else None)

    _umap_scatter(sp_embedding, list(adata_concat.obs[matched_clusters_key].map(matched_to_color_map)),
                  umap_size, f'Identified Clusters ({model})', frame_color=frame_color,
                  save_path=save_root + f"/{model}_identified_clusters_umap.pdf" if save else None)

    if plot:
        plt.show()

In [None]:
# Draw (and, when save is True, export) the spatial maps and UMAP views for INSTINCT.
plot_mouseembryo_6(
    cas_list,
    adata_concat,
    ground_truth_key='clusters',
    matched_clusters_key='matched_clusters',
    model='INSTINCT',
    cluster_to_color_map=cluster_to_color_map,
    matched_to_color_map=matched_to_color_map,
    cluster_orders=my_clusters,
    slice_name_list=slice_name_list,
    sp_embedding=sp_embedding,
    save_root=save_dir,
    frame_color='darkviolet',
    save=save,
    plot=True,
)