# Plot Cluster DMGs

In [None]:
import yaml
import pandas as pd
import numpy as np
import anndata
import matplotlib.pyplot as plt
import seaborn as sns

from ALLCools.plot import *
from wmb import cemba, brain

In [None]:
# Read the notebook parameters from YAML and expose every top-level key as a
# notebook variable (later cells use names such as `dataset` and
# `cluster_col`, presumably supplied by this file).
with open('config/07.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Inject into the module namespace explicitly: writing through locals() only
# works because this cell runs at module scope, where locals() is globals().
globals().update(config)

print('Notebook configs:')
for _k, _v in config.items():
    print(f'{_k} = {_v}')

## Load

### Clustering results

In [None]:
# Pick the per-cell mapping-metric table that matches the configured dataset.
if dataset == 'mC':
    cell_meta = cemba.get_mc_mapping_metric()
else:
    cell_meta = cemba.get_m3c_mapping_metric()

adata = anndata.read_h5ad('adata.with_coords.h5ad')

# Copy every metadata column onto adata.obs.
# DataFrame.iteritems() was deprecated in pandas 1.5 and removed in 2.0;
# items() yields the same (column name, column) pairs on every version.
# NOTE(review): assumes cell_meta is a DataFrame indexed by cell id - confirm.
for col, data in cell_meta.items():
    adata.obs[col] = data

# Log2-transformed plate-normalized coverage, plotted further down.
adata.obs['log2PlateNormCov'] = np.log2(adata.obs['PlateNormCov'])

In [None]:
# Display the AnnData summary as the cell output.
adata

### Brain Region

In [None]:
# The region naming scheme follows the dataset: 'CEMBA' for mC,
# 'CEMBA_3C' for anything else.
region_type = 'CEMBA' if dataset == 'mC' else 'CEMBA_3C'
_region_kws = dict(region_type=region_type)

# Color palettes for the three levels of brain-region labels.
major_region_palette = brain.get_major_region_palette(**_region_kws)
sub_region_palette = brain.get_sub_region_palette(**_region_kws)
dissection_region_palette = brain.get_dissection_region_palette(**_region_kws)

# Replace the stored 'DissectionRegion' values through the CEMBA-id mapping
# first, because the two coarser labels are derived from the mapped column.
_cemba_id_to_dissection = brain.map_cemba_id_to_dissection_region(**_region_kws)
adata.obs['DissectionRegion'] = adata.obs['DissectionRegion'].map(
    _cemba_id_to_dissection)

_dissection_to_major = brain.map_dissection_region_to_major_region(**_region_kws)
adata.obs['MajorRegion'] = adata.obs['DissectionRegion'].map(
    _dissection_to_major)

_dissection_to_sub = brain.map_dissection_region_to_sub_region(**_region_kws)
adata.obs['SubRegion'] = adata.obs['DissectionRegion'].map(_dissection_to_sub)

## Cell Meta

In [None]:
# mCH fraction per cell on both embeddings: t-SNE on the left, UMAP on the
# right, each labelled with the cluster names.
fig, axes = plt.subplots(figsize=(8, 4), ncols=2, dpi=300)

# Arguments shared by both panels; only the axis and embedding differ.
_mch_kws = dict(data=adata,
                hue='mCHFrac',
                hue_norm=(0., 0.06),
                text_anno=cluster_col,
                max_points=None)

ax = axes[0]
continuous_scatter(ax=ax, coord_base='tsne', **_mch_kws)
ax = axes[1]
continuous_scatter(ax=ax, coord_base='umap', **_mch_kws)

In [None]:
# mCG fraction per cell on both embeddings: t-SNE on the left, UMAP on the
# right, each labelled with the cluster names.
fig, axes = plt.subplots(figsize=(8, 4), ncols=2, dpi=300)

# Arguments shared by both panels; only the axis and embedding differ.
_mcg_kws = dict(data=adata,
                hue='mCGFrac',
                hue_norm=(0.7, 0.85),
                text_anno=cluster_col,
                max_points=None)

ax = axes[0]
continuous_scatter(ax=ax, coord_base='tsne', **_mcg_kws)
ax = axes[1]
continuous_scatter(ax=ax, coord_base='umap', **_mcg_kws)

In [None]:
# log2 plate-normalized coverage on both embeddings, drawn with the
# 'coolwarm' colormap over a color range of (-1, 1).
fig, axes = plt.subplots(figsize=(8, 4), ncols=2, dpi=300)

# Arguments shared by both panels; only the axis and embedding differ.
_cov_kws = dict(data=adata,
                hue='log2PlateNormCov',
                hue_norm=(-1, 1),
                cmap='coolwarm',
                text_anno=cluster_col,
                max_points=None)

ax = axes[0]
continuous_scatter(ax=ax, coord_base='tsne', **_cov_kws)
ax = axes[1]
continuous_scatter(ax=ax, coord_base='umap', **_cov_kws)

In [None]:
# Cells colored by cluster on both embeddings (t-SNE left, UMAP right).
# Cluster text labels are intentionally not drawn in this figure.
fig, axes = plt.subplots(figsize=(8, 4), dpi=300, ncols=2)

# Arguments shared by both panels; only the axis and embedding differ.
_cluster_kws = dict(data=adata,
                    hue=cluster_col,
                    axis_format=None,
                    max_points=None)

ax = axes[0]
categorical_scatter(ax=ax, coord_base='tsne', **_cluster_kws)

ax = axes[1]
categorical_scatter(ax=ax, coord_base='umap', **_cluster_kws)

## Brain Region

In [None]:
# Cells colored by dissection region on both embeddings (t-SNE left,
# UMAP right), using the dissection-region palette built earlier.
fig, axes = plt.subplots(figsize=(8, 4), dpi=300, ncols=2)

# Arguments shared by both panels; only the axis and embedding differ.
_dissection_kws = dict(data=adata,
                       hue='DissectionRegion',
                       palette=dissection_region_palette,
                       axis_format=None,
                       max_points=None)

ax = axes[0]
categorical_scatter(ax=ax, coord_base='tsne', **_dissection_kws)

ax = axes[1]
categorical_scatter(ax=ax, coord_base='umap', **_dissection_kws)

In [None]:
# One t-SNE panel per major region: all cells in light gray as background,
# with the cells of that region highlighted on top.

# nunique() ignores NaN, matching groupby (which drops NaN keys), so the
# number of panels equals the number of groups actually drawn.
n_plots = adata.obs['MajorRegion'].nunique()
ncols = 4
# Ceiling division by ncols (previously hard-coded as 4, which would give the
# wrong number of rows as soon as ncols is changed).
nrows = -(-n_plots // ncols)

fig, axes = plt.subplots(figsize=(4 * ncols, 4 * nrows),
                         ncols=ncols,
                         nrows=nrows,
                         dpi=300)

# observed=True skips empty groups for unused categories when the column is
# categorical, so zip() cannot pair a panel with an empty group and drop a
# populated one.
for ax, (major_region, sub_df) in zip(
        axes.ravel()[:n_plots],
        adata.obs.groupby('MajorRegion', observed=True)):
    # Background layer: every cell, with cluster labels.
    categorical_scatter(ax=ax,
                        data=adata,
                        text_anno=cluster_col,
                        coord_base='tsne',
                        max_points=None,
                        scatter_kws=dict(color='lightgray'))
    # Foreground layer: only the cells belonging to this major region.
    categorical_scatter(ax=ax,
                        data=adata[adata.obs_names.isin(sub_df.index), :],
                        hue='MajorRegion',
                        palette=major_region_palette,
                        coord_base='tsne',
                        max_points=None)
    # The title shows the region name followed by its cell count.
    ax.set(title=f'{major_region} {sub_df.shape[0]}')

# Hide the unused panels in the last row of the grid.
for ax in axes.ravel()[n_plots:]:
    ax.axis('off')

## Annotation

In [None]:
# RS1 paper annotation
# RS1 paper annotation (going by the accessor name, the Liu et al. 2021 mC
# cell metadata); its 'MajorType' column is copied onto adata.obs below.
paper_anno = cemba.get_liu_2021_mc_metadata()

In [None]:
# Load the major-type palette from a header-less two-column CSV: the first
# column becomes the index and the remaining column the value, giving a
# {key: value} dict (presumably major type -> color).
# read_csv(squeeze=True) was deprecated in pandas 1.4 and removed in 2.0;
# squeezing the single-column frame afterwards gives the same Series.
major_type_palette = pd.read_csv(
    '/home/hanliu/project/mouse_rostral_brain/metadata/palette/major_type.palette.csv',
    index_col=0,
    header=None).squeeze('columns').to_dict()
# Color used for the 'NA' label.
major_type_palette['NA'] = 'lightgray'

# Attach the paper's major-type label to each cell.
# NOTE(review): cells absent from paper_anno presumably end up as NaN here
# rather than the string 'NA' - confirm.
adata.obs['MajorType'] = paper_anno['MajorType']

In [None]:
# Compare the paper annotation (left) with the current clustering (right),
# both on the t-SNE embedding.
fig, axes = plt.subplots(figsize=(8, 4), dpi=500, ncols=2)

# Settings common to every layer drawn in this figure.
_tsne_kws = dict(coord_base='tsne', max_points=None)

# Left panel, layer 1: all cells in light gray as background.
ax = axes[0]
_ = categorical_scatter(ax=ax,
                        data=adata,
                        hue=None,
                        scatter_kws=dict(color='lightgray'),
                        **_tsne_kws)

# Left panel, layer 2: cells whose major type is not the string 'NA',
# colored and labelled by major type.
# NOTE(review): rows with a missing (NaN) MajorType also pass this filter -
# confirm that is intended.
_is_annotated = adata.obs['MajorType'] != 'NA'
_ = categorical_scatter(ax=ax,
                        data=adata[_is_annotated],
                        hue='MajorType',
                        text_anno='MajorType',
                        palette=major_type_palette,
                        show_legend=False,
                        **_tsne_kws)

# Right panel: the same embedding colored and labelled by cluster.
ax = axes[1]
_ = categorical_scatter(ax=ax,
                        data=adata,
                        hue=cluster_col,
                        text_anno=cluster_col,
                        show_legend=False,
                        **_tsne_kws)

In [None]:
# IPython shell escape: create (or update the timestamp of) a file named
# 'finish' in the working directory - presumably a completion marker for
# whatever pipeline runs this notebook.
!touch finish