In [None]:
from ALLCools.mcds import MCDS
from ALLCools.plot import *
from ALLCools.integration import confusion_matrix_clustering

from wmb import cemba, aibs, brain

import pandas as pd
import numpy as np
import anndata
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Cluster level used to define integration groups, and the coarser level
# shown for context in the overview plots.
category_key = 'L2'
plot_key = 'L1'

# Values of the merged object's `Modality` column that identify the
# reference and query datasets.
ref_label = 'mC'
query_label = 'm3C'

In [None]:
# Annotation tables for the mC and m3C datasets; both are later indexed by
# `category_key` to propagate integration groups to every cell.
mc_annot, m3c_annot = cemba.get_mc_annot(), cemba.get_m3c_annot()

In [None]:
# Preview the m3C annotation table.
m3c_annot

In [None]:
# Merged object holding both modalities (split by `Modality` below) together
# with the t-SNE coordinates used by the plots in this notebook.
merged_h5ad_path = 'final_with_coords.h5ad'
adata_merge = anndata.read_h5ad(merged_h5ad_path)

In [None]:
# Inspect the merged AnnData summary.
adata_merge

In [None]:
# Split the merged object by modality into independent (copied) AnnData objects.
is_query = adata_merge.obs['Modality'] == query_label
is_ref = adata_merge.obs['Modality'] == ref_label
m3c_adata = adata_merge[is_query].copy()
mc_adata = adata_merge[is_ref].copy()

In [None]:
# Plain row-filtered copies of the merged obs table, one per modality.
modality = adata_merge.obs['Modality']
m3c_meta = adata_merge.obs.loc[modality == query_label].copy()
mc_meta = adata_merge.obs.loc[modality == ref_label].copy()

## Determine integration group
At the m3C L2 resolution, assign each m3C L2 cluster to its most probable mC clusters, such that the summed probability of the assigned clusters is > 0.95.

In [None]:
# Cluster-by-cluster overlap scores at the `category_key` level
# (rows are later treated as query clusters, columns as reference clusters).
overlap_path = f'{category_key}.overlap_score.hdf'
confusion_matrix = pd.read_hdf(overlap_path)

In [None]:
# Group query (row) and reference (column) clusters into shared integration
# groups; `row_group` / `col_group` map cluster -> group (used in the cells below).
# NOTE: `confusion_matrix` is overwritten with the returned matrix, so re-running
# this cell without first re-running the load cell re-processes its own output.
row_group, col_group, confusion_matrix, diag_score = confusion_matrix_clustering(
    confusion_matrix,
    min_value=0,
    max_value=0.9,
)

In [None]:
# Heatmap of the clustered overlap matrix. Rows are query clusters (mapped via
# `row_group` onto the query data) and columns are reference clusters (`col_group`).
fig, ax = plt.subplots(figsize=(8, 7), dpi=300)
sns.heatmap(confusion_matrix, ax=ax, vmin=0, vmax=0.5)
# Label the axes so the figure stands alone; trailing `;` hides the repr output.
ax.set(title=f'{category_key} overlap score',
       xlabel=f'{ref_label} {category_key}',
       ylabel=f'{query_label} {category_key}');

In [None]:
# Give every query cell the integration group of its cluster, then show group sizes.
query_clusters = m3c_adata.obs[category_key]
m3c_adata.obs['InteGroup'] = query_clusters.map(row_group)
m3c_adata.obs['InteGroup'].value_counts()

In [None]:
# Give every reference cell the integration group of its cluster, then show group sizes.
ref_clusters = mc_adata.obs[category_key]
mc_adata.obs['InteGroup'] = ref_clusters.map(col_group)
mc_adata.obs['InteGroup'].value_counts()

## Manual Adjustment

### Merge Integration Group

In [None]:
# Optional manual merges of integration groups, as {old_group: new_group}.
# Groups not listed are kept unchanged.
inte_group_map = {
    # e.g. `1: 0,` merges group 1 into group 0
}

# Apply the same relabeling to both datasets so their groups stay comparable.
for _adata in (m3c_adata, mc_adata):
    _adata.obs['InteGroup'] = _adata.obs['InteGroup'].map(
        lambda group: inte_group_map.get(group, group))

### Plot integration groups

In [None]:
from ALLCools.plot.color import level_one_palette

# Build one palette over the groups of both datasets so that a given
# integration group gets the same color in the query and reference panels.
all_inte_groups = pd.concat([m3c_adata.obs['InteGroup'], mc_adata.obs['InteGroup']])
inte_group_palette = level_one_palette(all_inte_groups, palette='tab20')

In [None]:
# 2x2 overview. Columns: query (left) and reference (right) datasets.
# Rows: `plot_key` annotation (top) and integration groups (bottom).
# One loop replaces four copy-pasted calls; panel drawing order is unchanged.
fig, axes = plt.subplots(figsize=(10, 10), ncols=2, nrows=2, dpi=300)

for col, (panel_label, panel_adata) in enumerate([(query_label, m3c_adata),
                                                  (ref_label, mc_adata)]):
    ax = axes[0, col]
    categorical_scatter(ax=ax,
                        data=panel_adata,
                        coord_base='tsne',
                        palette='tab20',
                        hue=plot_key,
                        text_anno=plot_key,
                        max_points=None)
    ax.set(title=f'{panel_label} {plot_key}')

    ax = axes[1, col]
    # Shared palette keeps each integration group the same color in both columns.
    categorical_scatter(ax=ax,
                        data=panel_adata,
                        coord_base='tsne',
                        hue='InteGroup',
                        text_anno='InteGroup',
                        palette=inte_group_palette,
                        max_points=None)
    ax.set(title=f'{panel_label} Inte. Group')

## Save Integration Group

In [None]:
# Map integration groups to all mC cells via their intra-dataset cluster.
# `counts > 0` drops zero-count (cluster, group) combinations that a
# categorical groupby can produce.
counts = mc_adata.obs.groupby(category_key)['InteGroup'].value_counts()
cluster_group_pairs = counts[counts > 0].index

# Each cluster must belong to exactly one integration group. Otherwise the dict
# below would silently keep whichever pair comes last, which is the
# least-frequent group, because value_counts sorts by descending count.
cluster_ids = cluster_group_pairs.get_level_values(0)
ambiguous_clusters = cluster_ids[cluster_ids.duplicated()].unique()
assert len(ambiguous_clusters) == 0, (
    f'{ref_label} clusters assigned to multiple integration groups: '
    f'{list(ambiguous_clusters)}')

mc_cluster_to_inte_group = {
    mc: inte_group
    for mc, inte_group in cluster_group_pairs
}
# Cells whose cluster has no integration group are dropped.
mc_cell_inte_group = mc_annot[category_key].to_pandas().map(
    mc_cluster_to_inte_group).dropna().astype(int)

mc_cell_inte_group.to_csv('mc_integration_group.csv.gz')
mc_cell_inte_group.value_counts()

In [None]:
# Map integration groups to all m3C cells via their intra-dataset cluster.
# `counts > 0` drops zero-count (cluster, group) combinations that a
# categorical groupby can produce.
counts = m3c_adata.obs.groupby(category_key)['InteGroup'].value_counts()
cluster_group_pairs = counts[counts > 0].index

# Each cluster must belong to exactly one integration group. Otherwise the dict
# below would silently keep whichever pair comes last, which is the
# least-frequent group, because value_counts sorts by descending count.
cluster_ids = cluster_group_pairs.get_level_values(0)
ambiguous_clusters = cluster_ids[cluster_ids.duplicated()].unique()
assert len(ambiguous_clusters) == 0, (
    f'{query_label} clusters assigned to multiple integration groups: '
    f'{list(ambiguous_clusters)}')

m3c_cluster_to_inte_group = {
    m3c: inte_group
    for m3c, inte_group in cluster_group_pairs
}
# Cells whose cluster has no integration group are dropped.
m3c_cell_inte_group = m3c_annot[category_key].to_pandas().map(
    m3c_cluster_to_inte_group).dropna().astype(int)

m3c_cell_inte_group.to_csv('m3c_integration_group.csv.gz')
m3c_cell_inte_group.value_counts()

## Plot Individual Group

1. Is there any mC cluster that does not match any m3C cluster?
2. Does the separation between clusters differ between mC and m3C?

In [None]:
def plot_single_group(group, annot_key='L1_annot'):
    """Plot a 2x2 t-SNE overview highlighting one integration group in both datasets.

    Left column shows the query dataset (``m3c_adata``) and the right column the
    reference dataset (``mc_adata``). The top row colors cells by ``annot_key``.
    The bottom row marks cells whose ``InteGroup`` equals ``group`` in red and
    all other cells in light grey.

    Reads the notebook-level ``m3c_adata``, ``mc_adata``, ``query_label`` and
    ``ref_label``.

    Parameters
    ----------
    group
        Integration group to highlight.
    annot_key
        ``obs`` column used for the context panels in the top row.

    Returns
    -------
    matplotlib.figure.Figure
        The figure. The caller is responsible for saving and closing it.
    """
    fig, axes = plt.subplots(figsize=(10, 10), ncols=2, nrows=2, dpi=300)

    # Same panel order as before: (0,0), (1,0), (0,1), (1,1).
    for col, panel_adata, panel_label in [(0, m3c_adata, query_label),
                                          (1, mc_adata, ref_label)]:
        ax = axes[0, col]
        categorical_scatter(ax=ax,
                            data=panel_adata,
                            coord_base='tsne',
                            hue=annot_key,
                            text_anno=annot_key,
                            max_points=None)
        ax.set(title=f'{panel_label} {annot_key}')

        ax = axes[1, col]
        categorical_scatter(ax=ax,
                            data=panel_adata,
                            coord_base='tsne',
                            hue=panel_adata.obs['InteGroup'] == group,
                            palette={
                                True: 'red',
                                False: 'lightgrey'
                            },
                            text_anno='InteGroup',
                            max_points=None)
        # Include the group id so pages of a multi-group PDF are distinguishable,
        # and label each column with its own modality (the mC panel was
        # previously mis-titled 'm3c').
        ax.set(title=f'{panel_label} Inte. Group {group}')
    return fig

In [None]:
# Optional export: one PDF page per integration group. Off by default because it
# renders one high-dpi figure per group.
save_group_pdf = False

if save_group_pdf:
    from matplotlib.backends.backend_pdf import PdfPages

    # Include groups present in either dataset (mC-only groups matter for the
    # "unmatched cluster" question above), and skip cells with no group (NaN).
    all_groups = sorted(set(m3c_adata.obs['InteGroup'].dropna())
                        | set(mc_adata.obs['InteGroup'].dropna()))
    with PdfPages('integration_groups.pdf') as pdf:
        for group in all_groups:
            fig = plot_single_group(group)
            pdf.savefig(fig)
            plt.close(fig)  # avoid accumulating many open figures in the kernel