# Integrating samples from the DLPFC dataset
The human dorsolateral prefrontal cortex (DLPFC) dataset is an SRT dataset containing three sets of slices; each set consists of four vertically adjacent slices from a single donor.  
In this case study, we demonstrate that INSTINCT can integrate multiple SRT samples.

In [None]:
import os
import csv
import torch
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc

from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture

import INSTINCT

import warnings
warnings.filterwarnings("ignore")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import fowlkes_mallows_score
from sklearn.metrics.cluster import homogeneity_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import completeness_score
import sklearn
import sklearn.neighbors
import networkx as nx
import scib

In [None]:
def match_cluster_labels(true_labels, est_labels):
    """Relabel estimated clusters so that they best agree with the ground truth.

    The contingency table between the sorted unique ground-truth categories and
    the sorted unique estimated clusters is built. A maximum-weight assignment
    is then solved on it (Hungarian algorithm), so that the total number of
    spots shared by matched (category, cluster) pairs is maximised.

    Parameters
    ----------
    true_labels
        Iterable of ground-truth labels, one per spot.
    est_labels
        Iterable of estimated cluster labels, same length as ``true_labels``.

    Returns
    -------
    numpy.ndarray
        Integer array with one entry per spot. A spot whose estimated cluster
        was matched to the ``i``-th sorted ground-truth category gets ``i``.
        When there are more estimated clusters than ground-truth categories,
        the unmatched clusters get ``n_categories, n_categories + 1, ...`` in
        their sorted order.
    """
    from scipy.optimize import linear_sum_assignment

    true_labels_arr = np.array(list(true_labels))
    est_labels_arr = np.array(list(est_labels))

    # Sorted unique categories plus, for every spot, the position of its label
    # in that sorted list (avoids a per-spot list.index search).
    org_cat, true_codes = np.unique(true_labels_arr, return_inverse=True)
    est_cat, est_codes = np.unique(est_labels_arr, return_inverse=True)
    n_org, n_est = len(org_cat), len(est_cat)

    # contingency[i, j] = number of spots with ground truth org_cat[i] and
    # estimated cluster est_cat[j].
    contingency = np.zeros((n_org, n_est), dtype=np.int64)
    np.add.at(contingency, (true_codes.ravel(), est_codes.ravel()), 1)

    # Maximum-overlap one-to-one matching; on a rectangular table only
    # min(n_org, n_est) pairs are matched.
    rows, cols = linear_sum_assignment(contingency, maximize=True)

    est_to_out = np.empty(n_est, dtype=np.int64)
    est_to_out[cols] = rows
    # Clusters left without a partner get fresh ids after the true categories,
    # in sorted cluster order (np.setdiff1d returns a sorted result).
    unmatched = np.setdiff1d(np.arange(n_est), cols)
    est_to_out[unmatched] = n_org + np.arange(len(unmatched))

    return est_to_out[est_codes.ravel()]


def cluster_metrics(target, pred):
    """Compare a predicted clustering with reference labels.

    Six external clustering scores are computed with scikit-learn:
    ARI, AMI, NMI, FMI, completeness and homogeneity. They are printed on a
    single line and returned.

    Parameters
    ----------
    target
        Reference labels, one per sample.
    pred
        Predicted cluster labels, one per sample.

    Returns
    -------
    tuple
        ``(ari, ami, nmi, fmi, comp, homo)``.
    """
    labels_true = np.array(target)
    labels_pred = np.array(pred)

    # Order matters: it must match the format string below and the returned tuple.
    scorers = (
        adjusted_rand_score,
        adjusted_mutual_info_score,
        normalized_mutual_info_score,
        fowlkes_mallows_score,
        completeness_score,
        homogeneity_score,
    )
    scores = tuple(scorer(labels_true, labels_pred) for scorer in scorers)

    print('ARI: %.3f, AMI: %.3f, NMI: %.3f, FMI: %.3f, Comp: %.3f, Homo: %.3f' % scores)
    return scores


def mean_average_precision(x: np.ndarray, y: np.ndarray, k: int = 30, **kwargs) -> float:
    r"""
    Mean average precision of label agreement among nearest neighbors.

    For every sample, its nearest neighbors in ``x`` are ranked by distance,
    with the sample itself excluded. The average precision of "neighbor shares
    the sample's label" over that ranking is computed, and the mean over all
    samples is returned.

    Parameters
    ----------
    x
        Coordinates, one row per sample.
    y
        Cell_type/Layer labels, one per row of ``x``.
    k
        Number of neighbors (capped at ``n_samples - 1``).
    **kwargs
        Additional keyword arguments are passed to
        :class:`sklearn.neighbors.NearestNeighbors`

    Returns
    -------
    map
        Mean average precision; ``0.0`` when no neighbors are available
        (fewer than two samples, or ``k < 1``).
    """
    y = np.asarray(y)
    n_samples = y.shape[0]
    n_neighbors = min(n_samples - 1, k)
    if n_neighbors < 1:
        return 0.0

    knn = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors, **kwargs).fit(x)
    # kneighbors() without X queries the fitted points and drops each point
    # from its own neighbor list. The old kneighbors(x)[:, 1:] only assumed
    # that the point itself was ranked first, which fails for duplicate
    # coordinates.
    nni = knn.kneighbors(return_distance=False)
    match = y[nni] == y[:, np.newaxis]  # (n_samples, n_neighbors) boolean

    # Per-row AP = mean of precision@r over the ranks r where the label matches.
    hits = match.sum(axis=1)
    precision_at_rank = np.cumsum(match, axis=1) / np.arange(1, n_neighbors + 1)
    average_precision = np.divide(
        (precision_at_rank * match).sum(axis=1),
        hits,
        out=np.zeros(n_samples, dtype=float),
        where=hits > 0,
    )
    return average_precision.mean().item()


def rep_metrics(adata, origin_concat, use_rep, label_key, batch_key, k_map=30):
    """Score an integrated embedding for label conservation and batch mixing.

    Metrics: neighbor mAP on ``adata.obsm[use_rep]``, cell-type ASW, batch ASW,
    batch PCR (against ``origin_concat``), kBET and graph connectivity. Apart
    from mAP, all of them come from ``scib.me``.

    Side effects: ``adata.obs[label_key]`` and ``adata.obs[batch_key]`` are
    recast to string categories, ``origin_concat.X`` is cast to float, and
    ``sc.pp.neighbors`` is run on ``adata`` in place.

    Parameters
    ----------
    adata
        AnnData holding the embedding in ``obsm[use_rep]`` and the label and
        batch columns in ``obs``.
    origin_concat
        Unintegrated data passed to ``scib.me.pcr_comparison`` as the reference.
    use_rep
        Key of the embedding in ``adata.obsm``.
    label_key, batch_key
        Column names in ``adata.obs``.
    k_map
        Number of neighbors used for mAP.

    Returns
    -------
    tuple or None
        ``(MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn)``. If any
        required key is missing, the missing keys are printed and ``None`` is
        returned.
    """
    missing = [f"adata.obs[{key!r}]" for key in (label_key, batch_key) if key not in adata.obs]
    if use_rep not in adata.obsm:
        missing.append(f"adata.obsm[{use_rep!r}]")
    if missing:
        print("KeyError: missing " + ", ".join(missing))
        return None

    adata.obs[label_key] = adata.obs[label_key].astype(str).astype("category")
    adata.obs[batch_key] = adata.obs[batch_key].astype(str).astype("category")
    origin_concat.X = origin_concat.X.astype(float)
    # Builds the neighbor graph on the embedding; presumably needed by the graph-based
    # scib metrics below (e.g. graph_connectivity) — confirm against scib docs.
    sc.pp.neighbors(adata, use_rep=use_rep)

    MAP = mean_average_precision(adata.obsm[use_rep].copy(), adata.obs[label_key], k=k_map)
    cell_type_ASW = scib.me.silhouette(adata, label_key=label_key, embed=use_rep)
    # g_iLISI = scib.me.ilisi_graph(adata, batch_key=batch_key, type_="embed", use_rep=use_rep)
    batch_ASW = scib.me.silhouette_batch(adata, batch_key=batch_key, label_key=label_key, embed=use_rep, verbose=False)
    batch_PCR = scib.me.pcr_comparison(origin_concat, adata, covariate=batch_key, embed=use_rep)
    kBET = scib.me.kBET(adata, batch_key=batch_key, label_key=label_key, type_='embed', embed=use_rep)
    g_conn = scib.me.graph_connectivity(adata, label_key=label_key)
    print('mAP: %.3f, Cell type ASW: %.3f, Batch ASW: %.3f, Batch PCR: %.3f, kBET: %.3f, Graph connectivity: %.3f' %
          (MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn))

    return MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn

### Run model
To preprocess the SRT data, we use `INSTINCT.preprocess_SRT()`.

In [None]:
# DLPFC
data_dir = '../../data/STdata/10xVisium/DLPFC_Maynard2021/'
sample_group_list = [['151507', '151508', '151509', '151510'],
                     ['151669', '151670', '151671', '151672'],
                     ['151673', '151674', '151675', '151676']]
# Number of Gaussian-mixture components per group (index-aligned with sample_group_list).
n_cluster_list = [7, 5, 7]

save_dir = '../../results/DLPFC_Maynard2021/'
os.makedirs(save_dir, exist_ok=True)

for idx, slice_name_list in enumerate(sample_group_list):

    # load data
    slice_index_list = list(range(len(slice_name_list)))

    rna_list = []
    for sample in slice_name_list:
        adata = sc.read_visium(path=data_dir + f'{sample}/', count_file=sample + '_filtered_feature_bc_matrix.h5')
        adata.var_names_make_unique()

        # read the annotation
        Ann_df = pd.read_csv(data_dir + f'{sample}/meta_data.csv', sep=',', index_col=0)

        # The .loc lookups below index the annotation table by every spot barcode,
        # so every barcode in adata must have an annotation row. Extra annotation
        # rows are harmless.
        missing_spots = ~adata.obs_names.isin(Ann_df.index)
        if missing_spots.any():
            raise ValueError(f"{int(missing_spots.sum())} spots of sample {sample} are not present "
                             "in the annotation file")

        adata.obs['image_row'] = Ann_df.loc[adata.obs_names, 'imagerow']
        adata.obs['image_col'] = Ann_df.loc[adata.obs_names, 'imagecol']
        adata.obs['Manual_Annotation'] = Ann_df.loc[adata.obs_names, 'ManualAnnotation']

        # Suffix barcodes with the sample id so they stay unique after concatenation.
        adata.obs_names = [x + '_' + sample for x in adata.obs_names]
        rna_list.append(adata)

    # concatenation
    adata_concat = ad.concat(rna_list, label="slice_name", keys=slice_name_list)

    # preprocess SRT data
    print('Start preprocessing')
    rna_list, adata_concat = INSTINCT.preprocess_SRT(rna_list, adata_concat, n_top_genes=5000)
    print(adata_concat.shape)
    print('Done!')

    # Preprocessed data without the integrated embedding; the batch-PCR metric uses it as the reference.
    origin_concat = ad.concat(rna_list, label="slice_name", keys=slice_index_list)

    print('Applying PCA to reduce the feature dimension to 100 ...')
    pca = PCA(n_components=100, random_state=1234)
    # NOTE(review): preprocess_SRT may return sparse or dense X — densify only when needed.
    expr_matrix = adata_concat.X
    input_matrix = pca.fit_transform(expr_matrix.toarray() if hasattr(expr_matrix, 'toarray')
                                     else np.asarray(expr_matrix))
    np.save(save_dir + f'input_matrix_group{idx}.npy', input_matrix)
    print('Done !')

    input_matrix = np.load(save_dir + f'input_matrix_group{idx}.npy')
    adata_concat.obsm['X_pca'] = input_matrix

    # calculate the spatial graph
    INSTINCT.create_neighbor_graph(rna_list, adata_concat)

    # Cumulative spot offsets: slice i occupies rows spots_count[i]:spots_count[i + 1]
    # of any concatenation of rna_list in list order.
    spots_count = [0]
    for slice_adata in rna_list:
        spots_count.append(spots_count[-1] + slice_adata.shape[0])

    INSTINCT_model = INSTINCT.INSTINCT_Model(rna_list, adata_concat, device=device)

    INSTINCT_model.train(report_loss=True, report_interval=100)

    # Expected to write obsm['INSTINCT_latent'] / ['INSTINCT_latent_noise'] into each slice
    # (they are read from `result` below) — confirm against INSTINCT docs.
    INSTINCT_model.eval(rna_list)

    result = ad.concat(rna_list, label="slice_name", keys=slice_index_list)

    with open(save_dir + f'INSTINCT_embed_group{idx}.csv', 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(result.obsm['INSTINCT_latent'])

    with open(save_dir + f'INSTINCT_noise_embed_group{idx}.csv', 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(result.obsm['INSTINCT_latent_noise'])

    gm = GaussianMixture(n_components=n_cluster_list[idx], covariance_type='tied', random_state=1234)
    y = gm.fit_predict(result.obsm['INSTINCT_latent'], y=None)
    result.obs["gm_clusters"] = pd.Series(y, index=result.obs.index, dtype='category')
    result.obs['matched_clusters'] = pd.Series(match_cluster_labels(result.obs['Manual_Annotation'],
                                                                    result.obs["gm_clusters"]),
                                               index=result.obs.index, dtype='category')

    ari, ami, nmi, fmi, comp, homo = cluster_metrics(result.obs['Manual_Annotation'],
                                                     result.obs['matched_clusters'].tolist())
    # `mAP` rather than `map`, so the builtin map() is not shadowed for later cells.
    mAP, c_asw, b_asw, b_pcr, kbet, g_conn = rep_metrics(result, origin_concat, use_rep='INSTINCT_latent',
                                                         label_key='Manual_Annotation', batch_key='slice_name')