# Integrating the four S1 slices from MISAR-seq MB dataset

Before proceeding, ensure that INSTINCT is installed.  
You can follow the instructions provided on [GitHub](https://github.com/yyLIU12138/INSTINCT) for installation.

In [None]:
import os
import csv
import torch
import numpy as np
import pandas as pd
import anndata as ad
import networkx as nx

from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture

import warnings
# Silence every warning for the rest of the session.
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import INSTINCT

### Load the raw data

In [None]:
# Directory holding the preprocessed per-slice ATAC files, and the names of
# the slices to integrate (the order fixes the slice order everywhere below).
data_dir = '../../data/spMOdata/EpiTran_MouseBrain_Jiang2023/preprocessed/'
slice_name_list = ['E11_0-S1', 'E13_5-S1', 'E15_5-S1', 'E18_5-S1']
slice_index_list = list(range(len(slice_name_list)))

# Directory receiving every intermediate and final result.
# exist_ok=True replaces the check-then-create pattern, which can race with
# another process creating the same directory.
save_dir = '../../results/MouseBrain_Jiang2023/'
os.makedirs(save_dir, exist_ok=True)

# Read each slice and drop the 'insertion' entry from .obsm when present.
cas_dict = {}
for sample in slice_name_list:
    sample_data = ad.read_h5ad(data_dir + sample + '_atac.h5ad')

    if 'insertion' in sample_data.obsm:
        del sample_data.obsm['insertion']

    cas_dict[sample] = sample_data

# Same objects as in cas_dict, ordered like slice_name_list.
cas_list = [cas_dict[sample] for sample in slice_name_list]

### Merge the peaks
Initially, almost no peaks are shared among the four slices.  
Therefore, we need to merge the four peak sets into a common merged peak set, so that the four slices have identical peak sets.

In [None]:
# Align the peak sets so that every slice ends up with the same merged peaks.
cas_list = INSTINCT.peak_sets_alignment(cas_list)

# Write each merged slice next to the raw input files.
for idx, adata in enumerate(cas_list):
    merged_path = f'{data_dir}merged_{slice_name_list[idx]}_atac.h5ad'
    adata.write_h5ad(merged_path)

In [None]:
# Reload the merged slices written by the peak-merging step.
cas_list = [ad.read_h5ad(f'{data_dir}merged_{sample}_atac.h5ad') for sample in slice_name_list]

# Suffix every spot name with the name of the slice it belongs to.
for j, slice_name in enumerate(slice_name_list):
    cas_list[j].obs_names = [x + '_' + slice_name for x in cas_list[j].obs_names]

# Stack the slices; the 'slice_name' column of .obs is filled from `keys`.
adata_concat = ad.concat(cas_list, label="slice_name", keys=slice_name_list)
print(adata_concat.shape)

### Data preprocessing
The concatenated data matrix is a read count matrix, so we set use_fragment_count=True to transform it into a fragment count matrix.  

Please note that the data samples in cas_list are only filtered, not transformed.

In [None]:
# Preprocess the CAS data; read counts are converted to fragment counts
# (use_fragment_count=True).
print('Start preprocessing')
INSTINCT.preprocess_CAS(cas_list, adata_concat, use_fragment_count=True, min_cells_rate=0.03)
print(adata_concat.shape)
print('Done!')

# Cache the preprocessed concatenation and the filtered per-slice data.
concat_path = save_dir + "preprocessed_concat_atac.h5ad"
adata_concat.write_h5ad(concat_path)
for i, slice_name in enumerate(slice_name_list):
    cas_list[i].write_h5ad(save_dir + f"filtered_merged_{slice_name}_atac.h5ad")

# Reload everything from the cache.
cas_list = [ad.read_h5ad(save_dir + f"filtered_merged_{sample}_atac.h5ad") for sample in slice_name_list]
# Concatenation of the filtered slices, used later by batch_correction_metrics.
origin_concat = ad.concat(cas_list, label="slice_name", keys=slice_name_list)
adata_concat = ad.read_h5ad(concat_path)


### Perform PCA
We perform PCA and use the first 100 principal components as input.  
The PCA results should be added to .obsm['X_pca'] of the concatenated data, or you can choose your own key and set the 'input_mat_key' of INSTINCT_Model to that key.

In [None]:
# Cache file for the PCA-reduced input matrix.
pca_path = save_dir + 'input_matrix_atac.npy'

print('Applying PCA to reduce the feature dimension to 100 ...')
pca = PCA(n_components=100, random_state=1234)
# X is converted with .toarray() before being handed to PCA.
input_matrix = pca.fit_transform(adata_concat.X.toarray())
np.save(pca_path, input_matrix)
print('Done !')

# Reload from the cache and attach the result under the 'X_pca' key.
input_matrix = np.load(pca_path)
adata_concat.obsm['X_pca'] = input_matrix

### Create neighbor graph
The neighbor graph for each slice is added to .uns['graph_list'] of the concatenated data.

In [None]:
# Calculate the spatial neighbor graph of every slice; per the notebook text
# above, the graphs are added to adata_concat.uns['graph_list'].
INSTINCT.create_neighbor_graph(cas_list, adata_concat)

### Run the model
The latent representation matrices, which contain the representations of the spots, are stored in .obsm['INSTINCT_latent'] of the corresponding slices.

In [None]:
# Build the model from the per-slice data and the concatenated data,
# placed on the device selected at the top of the notebook.
INSTINCT_model = INSTINCT.INSTINCT_Model(cas_list, adata_concat, device=device)

# Train with loss reporting enabled (report_interval=100).
INSTINCT_model.train(report_loss=True, report_interval=100)

# Evaluate on the slices; per the notebook text above, this stores the latent
# representations in .obsm['INSTINCT_latent'] of each slice.
INSTINCT_model.eval(cas_list)

In [None]:
# Gather the per-slice results (including the latent embeddings) in one object.
result = ad.concat(cas_list, label="slice_name", keys=slice_name_list)

# Dump each embedding matrix to its own CSV file, one row per spot.
embedding_exports = (
    ('INSTINCT_embed.csv', 'INSTINCT_latent'),
    ('INSTINCT_noise_embed.csv', 'INSTINCT_latent_noise'),
)
for csv_name, obsm_key in embedding_exports:
    with open(save_dir + csv_name, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(result.obsm[obsm_key])

### Clustering

In [None]:
def match_cluster_labels(true_labels, est_labels):
    """Relabel estimated clusters so that they agree with reference labels.

    A complete bipartite graph is built between the sorted unique reference
    categories and the sorted unique estimated categories.  Each edge is
    weighted by the negated number of samples shared by the pair, so the
    minimum-weight full matching maximizes the total overlap.

    Parameters
    ----------
    true_labels
        Reference label of every sample (1-D sequence).
    est_labels
        Estimated cluster label of every sample, in the same order.

    Returns
    -------
    np.ndarray
        For every sample, the index (into the sorted reference categories) of
        the category its cluster was matched to.  When there are more
        estimated than reference categories, the clusters left unmatched
        receive ``len(org_cat)``, ``len(org_cat) + 1``, ... in sorted cluster
        order.
    """
    true_labels_arr = np.array(list(true_labels))
    est_labels_arr = np.array(list(est_labels))

    org_cat = list(np.sort(list(pd.unique(true_labels))))
    est_cat = list(np.sort(list(pd.unique(est_labels))))

    # Reference categories are nodes 1..n, estimated categories are -1..-m,
    # so the two sides can never collide.
    B = nx.Graph()
    B.add_nodes_from([i + 1 for i in range(len(org_cat))], bipartite=0)
    B.add_nodes_from([-j - 1 for j in range(len(est_cat))], bipartite=1)

    # Build each membership mask once instead of once per (i, j) pair.
    org_masks = [true_labels_arr == category for category in org_cat]
    est_masks = [est_labels_arr == category for category in est_cat]

    for i, org_mask in enumerate(org_masks):
        for j, est_mask in enumerate(est_masks):
            weight = np.sum(org_mask * est_mask)
            # Negated overlap: minimizing the weight maximizes the overlap.
            B.add_edge(i + 1, -j - 1, weight=-weight)

    match = nx.algorithms.bipartite.matching.minimum_weight_full_matching(B)

    # Resolve every estimated category once, so the per-sample step below is
    # a dictionary lookup rather than repeated list.index() scans.
    n_org = len(org_cat)
    n_unmatched = 0
    new_label = {}
    for j, category in enumerate(est_cat):
        node = -j - 1
        if node in match:
            new_label[category] = match[node] - 1
        else:
            # Only reachable when there are more estimated than reference
            # categories; leftovers are numbered after the reference ones.
            new_label[category] = n_org + n_unmatched
            n_unmatched += 1

    return np.array([new_label[c] for c in est_labels_arr])

In [None]:
# Cluster the latent space with a 16-component, tied-covariance Gaussian mixture.
gm = GaussianMixture(n_components=16, covariance_type='tied', random_state=1234)
y = gm.fit_predict(result.obsm['INSTINCT_latent'], y=None)
result.obs["gm_clusters"] = pd.Series(y, index=result.obs.index, dtype='category')

# Relabel the mixture components so that they line up with the annotation.
matched_labels = match_cluster_labels(result.obs['Annotation_for_Combined'], result.obs["gm_clusters"])
result.obs['matched_clusters'] = pd.Series(matched_labels, index=result.obs.index, dtype='category')

### Evaluation

In [None]:
import numpy as np
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import fowlkes_mallows_score
from sklearn.metrics.cluster import homogeneity_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import completeness_score
import sklearn
import sklearn.neighbors
import pandas as pd
import networkx as nx

import scib
import scanpy as sc


def cluster_metrics(target, pred):
    """Compute and print six clustering agreement scores.

    Parameters
    ----------
    target
        Reference labels.
    pred
        Predicted cluster labels, in the same order as ``target``.

    Returns
    -------
    tuple
        ``(ari, ami, nmi, fmi, comp, homo)``.
    """
    target = np.array(target)
    pred = np.array(pred)

    # Listed in the order in which the scores are printed and returned.
    scorers = (
        adjusted_rand_score,
        adjusted_mutual_info_score,
        normalized_mutual_info_score,
        fowlkes_mallows_score,
        completeness_score,
        homogeneity_score,
    )
    scores = tuple(scorer(target, pred) for scorer in scorers)
    print('ARI: %.3f, AMI: %.3f, NMI: %.3f, FMI: %.3f, Comp: %.3f, Homo: %.3f' % scores)

    return scores


def mean_average_precision(x: np.ndarray, y: np.ndarray, k: int=30, **kwargs) -> float:
    r"""
    Mean average precision of the labels among each point's nearest neighbors

    Parameters
    ----------
    x
        Coordinates
    y
        Cell_type/Layer labels
    k
        Number of neighbors considered per point
    **kwargs
        Additional keyword arguments are passed to
        :class:`sklearn.neighbors.NearestNeighbors`

    Returns
    -------
    map
        Mean average precision
    """

    def _average_precision(match: np.ndarray) -> float:
        # Without any hit the precision is defined as zero.
        if not np.any(match):
            return 0.0
        ranks = np.arange(match.size) + 1
        precision_at_rank = np.cumsum(match) / ranks
        # Average the precision over the ranks at which a hit occurs.
        return precision_at_rank[match].mean().item()

    labels = np.array(y)
    # One extra neighbor is requested because the first column is dropped below.
    n_neighbors = min(labels.shape[0], k + 1)
    knn = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors, **kwargs).fit(x)
    neighbor_idx = knn.kneighbors(x, return_distance=False)
    # Column 0 is skipped (presumably the query point itself).
    same_label = np.equal(labels[neighbor_idx[:, 1:]], np.expand_dims(labels, 1))

    return np.apply_along_axis(_average_precision, 1, same_label).mean().item()


def bio_conservation_metrics(adata, use_rep, label_key, batch_key, k_map=30, threshold=1):
    """Compute and print four biological-conservation scores for an embedding.

    The ``label_key`` and ``batch_key`` columns of ``adata.obs`` are converted
    in place to string categories before scoring.

    Parameters
    ----------
    adata
        Annotated data holding the embedding in ``.obsm[use_rep]``.
    use_rep
        Key of the embedding in ``adata.obsm``.
    label_key
        Column of ``adata.obs`` with the reference labels.
    batch_key
        Column of ``adata.obs`` with the batch (slice) labels.
    k_map
        Number of neighbors used for the mean average precision.
    threshold
        Passed as ``iso_threshold`` to the isolated-label metrics.

    Returns
    -------
    tuple or None
        ``(MAP, cell_type_ASW, isolated_asw, isolated_f1)``, or ``None``
        (after printing "KeyError") when one of the keys is missing.
    """
    keys_present = label_key in adata.obs and batch_key in adata.obs and use_rep in adata.obsm
    if not keys_present:
        print("KeyError")
        return None

    for key in (label_key, batch_key):
        adata.obs[key] = adata.obs[key].astype(str).astype("category")

    # Arguments shared by both isolated-label metrics.
    isolated_kwargs = dict(batch_key=batch_key, label_key=label_key, embed=use_rep,
                           iso_threshold=threshold)

    MAP = mean_average_precision(adata.obsm[use_rep].copy(), adata.obs[label_key], k=k_map)
    cell_type_ASW = scib.me.silhouette(adata, label_key=label_key, embed=use_rep)
    isolated_asw = scib.me.isolated_labels_asw(adata, **isolated_kwargs)
    isolated_f1 = scib.me.isolated_labels_f1(adata, **isolated_kwargs)

    scores = (MAP, cell_type_ASW, isolated_asw, isolated_f1)
    print('mAP: %.3f, Cell type ASW: %.3f, Isolated label ASW: %.3f, Isolated label F1: %.3f' %
          scores)

    return scores


def batch_correction_metrics(adata, origin_concat, use_rep, label_key, batch_key):
    """Compute and print four batch-correction scores for an embedding.

    Side effects: the ``label_key`` and ``batch_key`` columns of ``adata.obs``
    are converted in place to string categories, ``origin_concat.X`` is cast
    to float, and ``sc.pp.neighbors`` is run on ``adata``.

    Parameters
    ----------
    adata
        Annotated data holding the embedding in ``.obsm[use_rep]``.
    origin_concat
        Data passed to ``pcr_comparison`` as the reference to compare against.
    use_rep
        Key of the embedding in ``adata.obsm``.
    label_key
        Column of ``adata.obs`` with the reference labels.
    batch_key
        Column of ``adata.obs`` with the batch (slice) labels.

    Returns
    -------
    tuple or None
        ``(batch_ASW, batch_PCR, kBET, g_conn)``, or ``None`` (after printing
        "KeyError") when one of the keys is missing.
    """
    keys_present = label_key in adata.obs and batch_key in adata.obs and use_rep in adata.obsm
    if not keys_present:
        print("KeyError")
        return None

    for key in (label_key, batch_key):
        adata.obs[key] = adata.obs[key].astype(str).astype("category")
    origin_concat.X = origin_concat.X.astype(float)
    # Neighbor graph on the embedding, computed before the metrics below.
    sc.pp.neighbors(adata, use_rep=use_rep, random_state=1234)

    batch_ASW = scib.me.silhouette_batch(adata, batch_key=batch_key, label_key=label_key,
                                         embed=use_rep, verbose=False)
    batch_PCR = scib.me.pcr_comparison(origin_concat, adata, covariate=batch_key, embed=use_rep)
    kBET = scib.me.kBET(adata, batch_key=batch_key, label_key=label_key, type_='embed',
                        embed=use_rep)
    g_conn = scib.me.graph_connectivity(adata, label_key=label_key)

    scores = (batch_ASW, batch_PCR, kBET, g_conn)
    print('Batch ASW: %.3f, Batch PCR: %.3f, kBET: %.3f, Graph connectivity: %.3f' %
          scores)
    return scores

In [None]:
# Agreement between the annotation and the matched clusters.
ari, ami, nmi, fmi, comp, homo = cluster_metrics(result.obs['Annotation_for_Combined'],
                                                 result.obs['matched_clusters'].tolist())
# Biological conservation of the latent space. The first score is bound to
# `mean_ap` rather than `map` so that the builtin map() is not shadowed for
# the rest of the session.
mean_ap, c_asw, i_asw, i_f1 = bio_conservation_metrics(result, use_rep='INSTINCT_latent',
                                                       label_key='Annotation_for_Combined',
                                                       batch_key='slice_name')
# Batch correction, with the filtered concatenation passed as `origin_concat`.
b_asw, b_pcr, kbet, g_conn = batch_correction_metrics(result, origin_concat, use_rep='INSTINCT_latent',
                                                      label_key='Annotation_for_Combined',
                                                      batch_key='slice_name')

### Spatial domain identification and UMAP visualization

In [None]:
# Annotated spot types; colors_for_clusters and order_for_clusters below are
# paired with this list position by position.
cls_list = ['Primary_brain_1', 'Primary_brain_2', 'Midbrain', 'Diencephalon_and_hindbrain', 'Basal_plate_of_hindbrain',
            'Subpallium_1', 'Subpallium_2', 'Cartilage_1', 'Cartilage_2', 'Cartilage_3', 'Cartilage_4',
            'Mesenchyme', 'Muscle', 'Thalamus', 'DPallm', 'DPallv']

colors_for_clusters = ['red', 'tomato', 'chocolate', 'orange', 'goldenrod',
                       'b', 'royalblue', 'g', 'limegreen', 'lime', 'springgreen',
                       'deepskyblue', 'pink', 'fuchsia', 'yellowgreen', 'olivedrab']

# Matched-cluster index assigned to each entry of cls_list.
order_for_clusters = [11, 12, 9, 7, 0, 13, 14, 1, 2, 3, 4, 8, 10, 15, 5, 6]

cluster_to_color_map = dict(zip(cls_list, colors_for_clusters))
order_to_cluster_map = dict(zip(order_for_clusters, cls_list))

from umap.umap_ import UMAP
umap_settings = dict(
    n_neighbors=30, n_components=2, metric="correlation", n_epochs=None, learning_rate=1.0,
    min_dist=0.3, spread=1.0, set_op_mix_ratio=1.0, local_connectivity=1, repulsion_strength=1,
    negative_sample_rate=5, a=None, b=None, random_state=1234, metric_kwds=None,
    angular_rp_forest=False, verbose=False,
)
reducer = UMAP(**umap_settings)

# One mixture component per annotated spot type.
gm = GaussianMixture(n_components=len(cls_list), covariance_type='tied', random_state=1234)
y = gm.fit_predict(result.obsm['INSTINCT_latent'], y=None)
result.obs["gm_clusters"] = pd.Series(y, index=result.obs.index, dtype='category')
matched_labels = match_cluster_labels(result.obs['Annotation_for_Combined'], result.obs["gm_clusters"])
result.obs['matched_clusters'] = pd.Series(matched_labels, index=result.obs.index, dtype='category')

# Colour every matched cluster like the spot type it was matched to.
my_clusters = np.sort(list(set(result.obs['matched_clusters'])))
matched_colors = [cluster_to_color_map[order_to_cluster_map[order]] for order in my_clusters]
matched_to_color_map = dict(zip(my_clusters, matched_colors))

# Cumulative spot counts: slice i occupies rows spots_count[i]:spots_count[i + 1].
spots_count = [0]
for sample in cas_list:
    spots_count.append(spots_count[-1] + sample.shape[0])

# Copy the matched labels back onto each individual slice.
for i, sample in enumerate(cas_list):
    start, stop = spots_count[i], spots_count[i + 1]
    sample.obs['matched_clusters'] = result.obs['matched_clusters'][start:stop]

sp_embedding = reducer.fit_transform(result.obsm['INSTINCT_latent'])

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.lines import Line2D
import matplotlib as mpl
# Font type 42 for PDF and PostScript output (TrueType according to the
# matplotlib rcParams documentation, instead of the default Type 3).
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42


def plot_mousebrain(cas_list, adata_concat, ground_truth_key, matched_clusters_key, model,
                    cluster_to_color_map, matched_to_color_map, cluster_orders, slice_name_list, cls_list,
                    sp_embedding, save_root=None, frame_color=None, save=False, plot=False):
    """Draw the spatial clustering overview and three UMAP scatter plots.

    Figure 1 is a fixed 2 x 4 grid: the top row shows each slice coloured by
    its ground-truth annotation, the bottom row by its matched cluster.
    Figures 2-4 show the 2-D embedding coloured by slice, by ground-truth
    annotation and by matched cluster.  The grid and the slice palette are
    hard-coded for four slices.

    Parameters
    ----------
    cas_list
        Per-slice data; each needs ``.obsm['spatial']`` and both label columns.
    adata_concat
        Concatenated data; needs ``.obs['slice_name']`` and both label columns.
    ground_truth_key, matched_clusters_key
        Names of the annotation and matched-cluster columns in ``.obs``.
    model
        Model name used in titles and output file names.
    cluster_to_color_map, matched_to_color_map
        Colour of every annotation label / matched cluster.
    cluster_orders
        Matched clusters listed in the cluster legend.
    slice_name_list
        Slice names, in the same order as ``cas_list``.
    cls_list
        Annotation labels listed in the spot-type legend.
    sp_embedding
        2-D embedding with one row per spot of ``adata_concat``.
    save_root
        Output directory; required when ``save`` is true.
    frame_color
        If given, applied as the axes edge colour through ``plt.rc`` (a global
        matplotlib setting) before each UMAP scatter.
    save
        Write every figure as a PDF below ``save_root``.
    plot
        Call ``plt.show()`` at the end.
    """

    def _legend_marker(face_color, label):
        # Round legend handle filled with the given colour.
        return Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=face_color,
                      label=label)

    def _spatial_panel(ax, coords, point_colors, title):
        # One slice drawn at its spatial coordinates, y axis flipped, no frame.
        ax.scatter(coords[:, 0], coords[:, 1], linewidth=0.5, s=30,
                   marker=".", color=point_colors, alpha=0.9)
        ax.set_title(title, size=12)
        ax.invert_yaxis()
        ax.axis('off')

    def _umap_figure(point_colors, title, file_name):
        # New 5 x 5 figure with the embedding; uses spot_order and size,
        # which are assigned further down, before the first call.
        plt.figure(figsize=(5, 5))
        if frame_color:
            plt.rc('axes', edgecolor=frame_color, linewidth=2)
        plt.scatter(sp_embedding[spot_order, 0], sp_embedding[spot_order, 1], s=size, c=point_colors)
        plt.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                        labelleft=False, labelbottom=False, grid_alpha=0)
        plt.title(title, fontsize=16)
        if save:
            plt.savefig(save_root + file_name)

    # ---- Figure 1: ground truth (top row) vs. matched clusters (bottom row)
    fig, axs = plt.subplots(2, 4, figsize=(15, 7))
    fig.suptitle(f'{model} Clustering Results', fontsize=16)
    for i, sample in enumerate(cas_list):
        real_colors = list(sample.obs[ground_truth_key].astype('str').map(cluster_to_color_map))
        _spatial_panel(axs[0, i], sample.obsm['spatial'], real_colors,
                       f'{slice_name_list[i]} (Ground Truth)')

        cluster_colors = list(sample.obs[matched_clusters_key].map(matched_to_color_map))
        _spatial_panel(axs[1, i], sample.obsm['spatial'], cluster_colors,
                       f'{slice_name_list[i]} (Cluster Results)')

    # Both legends are anchored to the right-most column.
    spot_type_handles = [_legend_marker(cluster_to_color_map[cluster], cluster) for cluster in cls_list]
    axs[0, 3].legend(
        handles=spot_type_handles,
        fontsize=8, title='Spot-types', title_fontsize=10, bbox_to_anchor=(1, 1.15))
    cluster_handles = [_legend_marker(matched_to_color_map[cluster_id], f'{i}')
                       for i, cluster_id in enumerate(cluster_orders)]
    axs[1, 3].legend(
        handles=cluster_handles,
        fontsize=8, title='Clusters', title_fontsize=10, bbox_to_anchor=(1, 1.1))
    plt.gcf().subplots_adjust(left=0.05, top=None, bottom=None, right=0.85)
    if save:
        plt.savefig(save_root + f'/{model}_clustering_results.pdf', dpi=500)

    # ---- Figures 2-4: embedding coloured by slice, annotation and cluster
    n_spots = adata_concat.shape[0]
    size = 10000 / n_spots  # marker size shrinks as the number of spots grows
    spot_order = np.arange(n_spots)
    colors_for_slices = [[0.2298057, 0.29871797, 0.75368315],
                         [0.70567316, 0.01555616, 0.15023281],
                         [0.2298057, 0.70567316, 0.15023281],
                         [0.5830223, 0.59200322, 0.12993134]]
    slice_cmap = {name: colors_for_slices[i] for i, name in enumerate(slice_name_list)}

    slice_colors = list(adata_concat.obs['slice_name'].astype('str').map(slice_cmap))
    _umap_figure(slice_colors, f'Slices ({model})', f"/{model}_slices_umap.pdf")

    annotation_colors = list(adata_concat.obs[ground_truth_key].astype('str').map(cluster_to_color_map))
    _umap_figure(annotation_colors, f'Annotated Spot-types ({model})',
                 f"/{model}_annotated_clusters_umap.pdf")

    identified_colors = list(adata_concat.obs[matched_clusters_key].map(matched_to_color_map))
    _umap_figure(identified_colors, f'Identified Clusters ({model})',
                 f"/{model}_identified_clusters_umap.pdf")

    if plot:
        plt.show()

In [None]:
# Show the figures interactively without writing any PDF to save_dir.
save = False
plot = True
plot_mousebrain(cas_list, result, 'Annotation_for_Combined', 'matched_clusters', 'INSTINCT', cluster_to_color_map,
                matched_to_color_map, my_clusters, slice_name_list, cls_list, sp_embedding, save_root=save_dir,
                frame_color='darkviolet', save=save, plot=plot)