In [None]:
from ALLCools.mcds import MCDS
from ALLCools.plot import *
from ALLCools.integration import confusion_matrix_clustering

from wmb import cemba, aibs, broad, brain

import pandas as pd
import numpy as np
import anndata
import matplotlib.pyplot as plt
import seaborn as sns
import pathlib
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import wmb

In [None]:
# # Parameters
# dataset = "AIBS_TENX"

# if dataset == 'AIBS_SMART':
#     m3c_annot = aibs.get_smart_annot()
# elif dataset == 'AIBS_TENX':
#     m3c_annot = aibs.get_tenx_annot()
# else:
#     m3c_annot = broad.get_tenx_annot()

In [None]:
# Load the CEMBA annotation objects for both modalities; get_pre_data pulls
# the per-cell L1 labels out of these two objects.
mc_annot, m3c_annot = cemba.get_mc_annot(), cemba.get_m3c_annot()


In [None]:
def get_pre_data(integroup, category_key):
    """Load one integration group's merged AnnData and split it by modality.

    Reads ``../{category_key}/ALL/{integroup}/final_with_coords.h5ad``, splits
    the cells into m3C and mC subsets by ``obs['Modality']``, attaches an
    ``'L1_annot'`` obs column from the notebook-level ``mc_annot`` /
    ``m3c_annot`` objects, and attaches the per-cell integration-group labels
    from ``m3c_integration_group.csv.gz`` / ``mc_integration_group.csv.gz``
    as the ``f'{category_key}_InteGroup'`` obs column.

    Parameters
    ----------
    integroup : str
        Name of the integration-group directory under ``../{category_key}/ALL``.
    category_key : str
        Clustering level (e.g. ``'L4'``). It selects the top-level directory and
        names the integration-group obs column.

    Returns
    -------
    tuple
        ``(m3c_adata, mc_adata)``: independent copies of the m3C and mC subsets.
    """
    group_dir = pathlib.Path(f'../{category_key}/ALL/{integroup}')
    inte_group_col = f'{category_key}_InteGroup'

    adata_merge = anndata.read_h5ad(group_dir / 'final_with_coords.h5ad')
    # Compute the modality column once and reuse it for both masks.
    modality = adata_merge.obs['Modality']
    m3c_adata = adata_merge[modality == 'm3C'].copy()
    mc_adata = adata_merge[modality == 'mC'].copy()

    # The two annotation objects name their L1 column differently ('L1_annot'
    # vs 'L1'). Assigning a pandas Series into obs aligns on the cell index;
    # this assumes .to_pandas() yields a Series indexed by cell id (TODO confirm).
    mc_adata.obs['L1_annot'] = mc_annot['L1_annot'].to_pandas()
    m3c_adata.obs['L1_annot'] = m3c_annot['L1'].to_pandas()

    for adata, csv_name in (
        (m3c_adata, 'm3c_integration_group.csv.gz'),
        (mc_adata, 'mc_integration_group.csv.gz'),
    ):
        # squeeze('columns') always gives a Series for a single-column file.
        # A bare .squeeze() would collapse a one-row file to a scalar, which
        # Index.map cannot use.
        labels = pd.read_csv(group_dir / csv_name, index_col='cell').squeeze('columns')
        adata.obs[inte_group_col] = adata.obs.index.map(labels)

    return m3c_adata, mc_adata
    
        

In [None]:
# Display the CEMBA dissection-region palette (also used by plot_clustering).
brain.get_dissection_region_palette(region_type='CEMBA')

In [None]:
def plot_clustering(category_key, m3c_data=None, mc_data=None):
    """Draw a 3x2 t-SNE comparison of m3C (left) and mC (right) cells.

    Rows show, in order: the integration-group label, the L3 cluster and the
    dissection region. Each row uses a single palette for both columns, so a
    category is drawn in the same color on the m3C and mC sides.

    Parameters
    ----------
    category_key : str
        Clustering level. The ``f'{category_key}_InteGroup'`` obs column is
        plotted in the first row.
    m3c_data, mc_data : anndata.AnnData, optional
        The m3C and mC AnnData objects with ``'tsne'`` coordinates. When
        omitted, they fall back to the notebook-level ``m3c_adata`` /
        ``mc_adata`` variables (the original behaviour).

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure. It is also the current pyplot figure.
    """
    from ALLCools.plot.color import level_one_palette

    # Backward-compatible fallback to the notebook globals.
    if m3c_data is None:
        m3c_data = m3c_adata
    if mc_data is None:
        mc_data = mc_adata

    inte_group_col = f'{category_key}_InteGroup'

    def _shared_palette(col):
        # Build one palette over the union of both modalities' categories so
        # the left and right panels agree on colors.
        return level_one_palette(
            pd.concat([m3c_data.obs[col], mc_data.obs[col]]),
            palette='tab20',
        )

    # (hue column, palette, title label) for each row of panels.
    rows = [
        (inte_group_col, _shared_palette(inte_group_col), 'Inte. Group'),
        ('L3', _shared_palette('L3'), 'L3'),
        ('DissectionRegion',
         wmb.brain.get_dissection_region_palette(region_type='CEMBA'),
         'DissectionRegion'),
    ]
    # (data, title prefix) for each column of panels.
    columns = [(m3c_data, 'm3C'), (mc_data, 'mC')]

    fig, axes = plt.subplots(figsize=(10, 15), ncols=2, nrows=3, dpi=200)
    for row_axes, (hue, palette, label) in zip(axes, rows):
        for ax, (data, modality) in zip(row_axes, columns):
            categorical_scatter(ax=ax,
                                data=data,
                                coord_base='tsne',
                                hue=hue,
                                text_anno=hue,
                                palette=palette,
                                max_points=None)
            ax.set(title=f'{modality} {label}')
    return fig

# plot

In [None]:
# Clustering level to review. Set this to another level directory under ../
# to export the integration-group plots for that level instead.
category_key = 'L4'

In [None]:
# Names of the integration-group directories for the chosen level. Path.name
# is portable, unlike splitting on '/'. Sorting makes the PDF page order
# stable, because Path.glob order depends on the filesystem.
integroups = sorted(
    group_path.name
    for group_path in pathlib.Path(f'../{category_key}/ALL').glob('InteGroup*')
)

In [None]:
# Write one PDF page per integration group.
# NOTE: the output name keeps the original 'clusers' spelling so existing
# outputs and consumers still match.
with PdfPages(f'{category_key}_clusers.pdf') as pdf:
    for integroup in integroups:
        # plot_clustering reads the notebook-level m3c_adata / mc_adata, so
        # they are rebound here for every group.
        m3c_adata, mc_adata = get_pre_data(integroup, category_key)
        plot_clustering(category_key)
        fig = plt.gcf()  # the figure plot_clustering just created
        try:
            pdf.savefig(fig)
        finally:
            # Close this specific figure, even if saving fails, so open
            # figures do not accumulate across groups.
            plt.close(fig)