In [None]:
import matplotlib.pyplot as plt

from ALLCools.clustering import tsne
from ALLCools.plot import *
from ALLCools.integration import SeuratIntegration
from wmb import brain

import scanpy as sc
import anndata
from harmonypy import run_harmony

## Parameters

In [None]:
# Name of the query dataset. It is reused below as the dissection-region
# palette key for the query (RNA) cells via `query_region_type`.
dataset = 'AIBS_SMART'

In [None]:
# Palette keys handed to brain.get_dissection_region_palette inside plot():
# reference (mC) cells are coloured with the CEMBA region palette, query
# (RNA) cells with the palette named after the query dataset.
ref_region_type, query_region_type = 'CEMBA', dataset

## Load

In [None]:
# Load the combined AnnData holding both modalities. Downstream cells expect
# obs['Modality'] ('mC' / 'RNA'), obs['DissectionRegion'], the cell-type
# columns `<plot_key>` / `<plot_key>_transfer`, and obsm['X_pca_integrate'].
adata_merge = anndata.read_h5ad('final.h5ad')
adata_merge

## Harmony correction of the integrated embedding

In [None]:
# Correct the integrated PCA embedding for modality with Harmony:
# obs['Modality'] is the covariate whose effect is removed. `nclust` sets the
# number of Harmony clusters and `max_iter_harmony` caps the outer iterations.
ho = run_harmony(data_mat=adata_merge.obsm['X_pca_integrate'],
                 meta_data=adata_merge.obs,
                 nclust=50,
                 vars_use=['Modality'], 
                 max_iter_harmony=30)

In [None]:
# harmonypy keeps the corrected embedding as components x cells; transpose it
# so the first axis is cells, as AnnData requires for obsm entries.
adata_merge.obsm['X_harmony'] = ho.Z_corr.T

### t-SNE

In [None]:
# t-SNE (ALLCools helper) on the Harmony-corrected embedding; the result is
# plotted below with coord_base='tsne' (presumably stored as obsm['X_tsne']).
tsne(adata_merge, obsm='X_harmony')

### UMAP

In [None]:
adata_merge.obsm['X_pca'] = adata_merge.obsm['X_harmony']

sc.pp.neighbors(adata_merge)

In [None]:
min_dist = max(0.1, 1 - adata_merge.shape[0] / 60000)
sc.tl.umap(adata_merge, min_dist=min_dist)
del adata_merge.obsm['X_pca']

### Clustering

In [None]:
# Leiden co-clustering on the joint kNN graph built above (higher resolution
# gives more clusters). Labels land in obs['leiden'] and are shown in the
# "Co-cluster" panels of plot().
sc.tl.leiden(adata_merge, resolution=0.3)

## Plot

In [None]:
def _overlay_scatter(ax, background, foreground, coord_base, **foreground_kws):
    """Draw ``background`` cells in light grey, then ``foreground`` cells on top.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    background, foreground : anndata.AnnData
        Cell subsets to draw; ``background`` is drawn first, without a hue.
    coord_base : str
        Coordinate name forwarded to ``categorical_scatter`` (e.g. ``'umap'``).
    **foreground_kws
        Extra keyword arguments for the foreground ``categorical_scatter``
        call (``hue``, ``palette``, ``text_anno``, ...).
    """
    categorical_scatter(data=background,
                        coord_base=coord_base,
                        max_points=None,
                        hue=None,
                        scatter_kws=dict(color='lightgrey'),
                        ax=ax)
    # Show every foreground cell, matching the grey background; otherwise the
    # library's default point cap could down-sample only the foreground.
    foreground_kws.setdefault('max_points', None)
    categorical_scatter(data=foreground,
                        ax=ax,
                        coord_base=coord_base,
                        **foreground_kws)


def plot(coord_base, cell_type_key=None):
    """Plot the co-embedding as a 2 x 3 grid of overlay scatter panels.

    Top row: reference (``Modality == 'mC'``) cells coloured over the query
    (``Modality == 'RNA'``) cells drawn in grey. Bottom row: the reverse.
    Columns, left to right: joint ``leiden`` co-clusters, cell-type labels
    (``<key>`` for the reference, ``<key>_transfer`` for the query), and
    ``DissectionRegion`` with the region palettes for ``ref_region_type`` /
    ``query_region_type``.

    Reads the notebook globals ``adata_merge``, ``ref_region_type`` and
    ``query_region_type``.

    Parameters
    ----------
    coord_base : str
        Embedding to plot, e.g. ``'umap'`` or ``'tsne'``.
    cell_type_key : str, optional
        ``obs`` column with the reference cell-type labels. Defaults to the
        global ``plot_key``.
    """
    if cell_type_key is None:
        # NOTE(review): `plot_key` is never assigned in this notebook; it is
        # presumably injected as a run parameter — confirm, or pass
        # cell_type_key explicitly.
        cell_type_key = plot_key

    _, axes = plt.subplots(nrows=2,
                           ncols=3,
                           figsize=(12, 8),
                           dpi=300,
                           constrained_layout=True)

    mc_data = adata_merge[adata_merge.obs['Modality'] == 'mC']
    rna_data = adata_merge[adata_merge.obs['Modality'] == 'RNA']
    ref_palette = brain.get_dissection_region_palette(ref_region_type)
    query_palette = brain.get_dissection_region_palette(query_region_type)

    # (title, grey background, coloured foreground, foreground kwargs),
    # listed in row-major order to match axes.flat.
    panels = [
        ('Ref Co-cluster', rna_data, mc_data,
         dict(hue='leiden', text_anno='leiden', palette='tab20')),
        ('Ref CellType', rna_data, mc_data,
         dict(hue=f'{cell_type_key}', palette='tab20')),
        ('Ref Region', rna_data, mc_data,
         dict(hue='DissectionRegion', palette=ref_palette)),
        ('Query Co-cluster', mc_data, rna_data,
         dict(hue='leiden', text_anno='leiden', palette='tab20')),
        ('Query CellType Transfer', mc_data, rna_data,
         dict(hue=f'{cell_type_key}_transfer', palette='tab20')),
        ('Query Region', mc_data, rna_data,
         dict(hue='DissectionRegion', palette=query_palette)),
    ]
    for ax, (title, background, foreground, fg_kws) in zip(axes.flat, panels):
        _overlay_scatter(ax, background, foreground, coord_base, **fg_kws)
        ax.set_title(title, fontsize=15)

In [None]:
# Overview panels on the UMAP coordinates.
plot('umap')

In [None]:
# Same panels on the t-SNE coordinates.
plot('tsne')

## Save

In [None]:
# Persist the merged object together with the embeddings and Leiden labels
# added in the cells above.
adata_merge.write_h5ad('final_with_coords.h5ad')

In [None]:
adata_merge