In [None]:
from ALLCools.mcds import MCDS
from ALLCools.plot import *
from ALLCools.integration import confusion_matrix_clustering

from wmb import *

import pandas as pd
import numpy as np
import anndata
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Notebook parameters.
# Dataset label; not referenced by any of the visible cells.
dataset = 'AIBS_SMART'
# Cluster level: names the overlap-score file and the obs column mapped to integration groups.
category_key = 'L2'
# obs column used to color and label the top-row overview scatter plots.
plot_key = 'L2'

In [None]:
# Full annotation objects for each modality. `cemba` / `cemba_atac` are not defined in this
# notebook; presumably they come from `from wmb import *` (confirm).
# Later cells read `<annot>[category_key].to_pandas()` to label all cells of each modality.
mc_annot = cemba.get_mc_annot()

atac_annot = cemba_atac.get_atac_annot()

In [None]:
# Load the merged AnnData from the working directory. Later cells split it on
# obs['Modality'] ('ATAC' / 'mC') and plot with coord_base='tsne'.
adata_merge = anndata.read_h5ad('final_with_coords.h5ad')

In [None]:
# Display the merged AnnData repr as a quick sanity check of what was loaded.
adata_merge

In [None]:
# One AnnData per modality. `.copy()` gives each subset its own object rather than a view
# of adata_merge, so the obs columns added later stay local to that modality.
mc_adata = adata_merge[adata_merge.obs['Modality'].eq('mC')].copy()
atac_adata = adata_merge[adata_merge.obs['Modality'].eq('ATAC')].copy()

In [None]:
# Stand-alone copies of the obs table for each modality (metadata only, no matrix).
atac_meta = adata_merge.obs.loc[adata_merge.obs['Modality'] == 'ATAC'].copy()
mc_meta = adata_merge.obs.loc[adata_merge.obs['Modality'] == 'mC'].copy()

## Determine integration group
At ATAC L2 resolution, assign each ATAC L2 cluster to its most probable mC clusters, chosen so that their summed probability is > 0.95.

In [None]:
# Load the overlap-score table for the chosen cluster level; the file name is built from
# category_key (e.g. 'L2.overlap_score.hdf') and resolved against the working directory.
confusion_matrix = pd.read_hdf(f'{category_key}.overlap_score.hdf')

In [None]:
# Group the rows and columns of the overlap-score table into integration groups.
# NOTE: `confusion_matrix` is overwritten with the function's third return value, so
# re-running this cell without re-reading the HDF file feeds the previous output back in.
(row_group,
 col_group,
 confusion_matrix,
 diag_score) = confusion_matrix_clustering(confusion_matrix,
                                           min_value=0,
                                           max_value=0.9)

print(f'Diagonal Score: {diag_score:.2f}')

In [None]:
# Heatmap of the matrix returned by the clustering step; color scale limited to [0, 0.5].
fig, ax = plt.subplots(dpi=300, figsize=(8, 7))
sns.heatmap(data=confusion_matrix, vmin=0, vmax=0.5, ax=ax)

In [None]:
# Label each ATAC cell with an integration group by looking its cluster up in row_group,
# then show how many cells fall into each group.
atac_adata.obs['InteGroup'] = atac_adata.obs.loc[:, category_key].map(row_group)
atac_adata.obs.loc[:, 'InteGroup'].value_counts()

In [None]:
# Label each mC cell with an integration group by looking its cluster up in col_group,
# then show how many cells fall into each group.
mc_adata.obs['InteGroup'] = mc_adata.obs.loc[:, category_key].map(col_group)
mc_adata.obs.loc[:, 'InteGroup'].value_counts()

## Manual Adjustment

### Merge Integration Group

In [None]:
# Optional manual merge table: {old_group: new_group}.
# Groups that are not listed keep their current label.
inte_group_map = {
    # if need to merge integration group, add k:v here
    # 1: 0,
}

# Apply the same relabelling to both modalities so the group ids stay comparable.
atac_adata.obs['InteGroup'] = atac_adata.obs['InteGroup'].map(
    lambda group: inte_group_map.get(group, group))
mc_adata.obs['InteGroup'] = mc_adata.obs['InteGroup'].map(
    lambda group: inte_group_map.get(group, group))

### Plot integration groups

In [None]:
from ALLCools.plot.color import level_one_palette

# Build one palette from the integration groups of both modalities together; it is passed
# to the ATAC and mC integration-group panels so both draw from the same palette object.
inte_group_palette = level_one_palette(
    pd.concat([
        atac_adata.obs['InteGroup'],
        mc_adata.obs['InteGroup'],
    ]),
    palette='tab20',
)

In [None]:
# 2x2 overview: one column per modality (ATAC left, mC right).
# Top row: cells colored by `plot_key`; bottom row: cells colored by integration group.
# Fix: the mC bottom panel used to be titled 'ATAC Inte. Group'; each title is now built
# from the modality that is actually drawn in that column.
fig, axes = plt.subplots(figsize=(10, 10), ncols=2, nrows=2, dpi=300)

for col, (modality, panel_adata) in enumerate(
        [('ATAC', atac_adata), ('mC', mc_adata)]):
    # Top row: intra-modality clusters.
    ax = axes[0, col]
    categorical_scatter(ax=ax,
                        data=panel_adata,
                        coord_base='tsne',
                        palette='tab20',
                        hue=plot_key,
                        text_anno=plot_key,
                        max_points=None)
    ax.set(title=f'{modality} {plot_key}')

    # Bottom row: integration groups, drawn with the palette shared by both modalities.
    ax = axes[1, col]
    categorical_scatter(ax=ax,
                        data=panel_adata,
                        coord_base='tsne',
                        hue='InteGroup',
                        text_anno='InteGroup',
                        palette=inte_group_palette,
                        max_points=None)
    ax.set(title=f'{modality} Inte. Group')

## Save Integration Group

In [None]:
# Extend the integration groups from the cells in mc_adata to every cell in mc_annot,
# using the intra-dataset cluster (category_key) as the link.
counts = mc_adata.obs.groupby(category_key)['InteGroup'].value_counts()

# The index of `counts` holds (cluster, integration group) pairs; keep only pairs that
# actually occur. If a cluster appears more than once, the last pair wins.
mc_cluster_to_inte_group = {
    cluster: group
    for cluster, group in counts.loc[counts.gt(0)].index
}

# Cells whose cluster has no integration group are dropped before the cast to int.
mc_cell_inte_group = (
    mc_annot[category_key]
    .to_pandas()
    .map(mc_cluster_to_inte_group)
    .dropna()
    .astype(int)
)

mc_cell_inte_group.to_csv('mc_integration_group.csv.gz')
mc_cell_inte_group.value_counts()

In [None]:
# Extend the integration groups from the cells in atac_adata to every cell in atac_annot,
# using the intra-dataset cluster (category_key) as the link.
counts = atac_adata.obs.groupby(category_key)['InteGroup'].value_counts()

# The index of `counts` holds (cluster, integration group) pairs; keep only pairs that
# actually occur. If a cluster appears more than once, the last pair wins.
atac_cluster_to_inte_group = {
    cluster: group
    for cluster, group in counts.loc[counts.gt(0)].index
}

# Cells whose cluster has no integration group are dropped before the cast to int.
atac_cell_inte_group = (
    atac_annot[category_key]
    .to_pandas()
    .map(atac_cluster_to_inte_group)
    .dropna()
    .astype(int)
)

atac_cell_inte_group.to_csv('atac_integration_group.csv.gz')
atac_cell_inte_group.value_counts()

## Plot Individual Group

1. Is there any mC cluster that does not match any ATAC cluster?
2. Are the clusters separated differently between mC and ATAC?

In [None]:
def plot_single_group(group):
    """Draw a 2x2 figure highlighting one integration group in both modalities.

    Columns are modalities (ATAC left, mC right). The top row colors cells by
    ``L1_annot``; the bottom row colors cells red when their ``InteGroup``
    equals ``group`` and light grey otherwise.

    Reads the notebook-level ``atac_adata`` and ``mc_adata`` objects.

    Parameters
    ----------
    group
        Integration group label, compared with ``==`` against ``obs['InteGroup']``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the four panels. The caller is responsible for
        saving or closing it.
    """
    fig, axes = plt.subplots(figsize=(10, 10), ncols=2, nrows=2, dpi=300)

    for col, (modality, adata) in enumerate(
            (('ATAC', atac_adata), ('mC', mc_adata))):
        # Top row: L1 annotation as an orientation reference.
        ax = axes[0, col]
        categorical_scatter(ax=ax,
                            data=adata,
                            coord_base='tsne',
                            hue='L1_annot',
                            text_anno='L1_annot',
                            max_points=None)
        ax.set(title=f'{modality} L1 Annot')

        # Bottom row: boolean hue marks the cells that belong to the requested group.
        ax = axes[1, col]
        categorical_scatter(ax=ax,
                            data=adata,
                            coord_base='tsne',
                            hue=adata.obs['InteGroup'] == group,
                            palette={
                                True: 'red',
                                False: 'lightgrey'
                            },
                            text_anno='InteGroup',
                            max_points=None)
        # Fix: the mC panel used to be titled 'ATAC Inte. Group'.
        ax.set(title=f'{modality} Inte. Group')
    return fig

In [None]:
# import matplotlib.backends.backend_pdf
# 
# with matplotlib.backends.backend_pdf.PdfPages("integration_groups.pdf") as pdf:
#     for group in atac_adata.obs['InteGroup'].unique():
#         fig = plot_single_group(group)
#         pdf.savefig(fig)