In [None]:
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pathlib

import anndata

from ALLCools.mcds import MCDS
from ALLCools.plot import *
from ALLCools.plot.color import level_one_palette
from wmb import *

In [None]:
# Load the L1 integration summary table. The cells below read its
# 'L1_Modality', 'L1_InteGroup', 'L1_tsne_*' and 'L1_umap_*' columns.
summary_path = 'L1_integration_summary.hdf'
total_result = pd.read_hdf(summary_path)

In [None]:
# Show the first five rows as a quick sanity check of the loaded table.
total_result.head(n=5)

In [None]:
# Split the summary table into one frame per modality, using the
# 'L1_Modality' column ('mC' rows vs. 'RNA' rows).
l1_modality = total_result['L1_Modality']
mc_result = total_result.loc[l1_modality == 'mC']
rna_result = total_result.loc[l1_modality == 'RNA']

## Get L4 Aggregation DataFrame

In [None]:
# Fetch the annotation object of each modality; the cells below read the
# 'L4' entry of each one to get cluster labels.
mc_annot, rna_annot = cemba.get_mc_annot(), aibs.get_smart_annot()

In [None]:
# L4 labels of the mC annotation as a pandas object; used as the groupby key.
# Presumably indexed by the same cell ids as mc_result -- TODO confirm.
mc_l4 = mc_annot['L4'].to_pandas()

# Collapse mC cells to one row per L4 label: the first distinct integration
# group seen in the cluster, and the median of every embedding coordinate.
# NOTE(review): `unique()[0]` keeps the first group encountered; if a cluster
# can span several integration groups this is not a majority vote.
coord_columns = ['L1_tsne_0', 'L1_tsne_1', 'L1_umap_0', 'L1_umap_1']
mc_l4_agg = {'L1_InteGroup': lambda labels: labels.unique()[0]}
mc_l4_agg.update({column: 'median' for column in coord_columns})
mc_result_l4 = mc_result.groupby(mc_l4).agg(mc_l4_agg)

In [None]:
# L4 labels of the RNA annotation as a pandas object; used as the groupby key.
# Presumably indexed by the same cell ids as rna_result -- TODO confirm.
rna_l4 = rna_annot['L4'].to_pandas()

# Collapse RNA cells to one row per L4 label: the first distinct integration
# group seen in the cluster, and the median of every embedding coordinate.
# NOTE(review): `unique()[0]` keeps the first group encountered; if a cluster
# can span several integration groups this is not a majority vote.
coord_columns = ['L1_tsne_0', 'L1_tsne_1', 'L1_umap_0', 'L1_umap_1']
rna_l4_agg = {'L1_InteGroup': lambda labels: labels.unique()[0]}
rna_l4_agg.update({column: 'median' for column in coord_columns})
rna_result_l4 = rna_result.groupby(rna_l4).agg(rna_l4_agg)

## Cell-level gene plot

In [None]:
# Cell-level 2x2 figure on a single embedding.
#   top row    : cells colored by integration group (mC left, RNA right)
#   bottom row : cells colored by the value of one gene (mC left, RNA right)
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10), dpi=300)

# Settings shared by all four panels
coord = 'L1_tsne'
gene = 'Gad1'
mc_type = 'CHN'

# Top row: integration groups of each modality
for ax, panel_df, panel_title in (
        (axes[0, 0], mc_result, 'mC Inte Group'),
        (axes[0, 1], rna_result, 'RNA Inte Group'),
):
    categorical_scatter(panel_df, ax=ax, hue='L1_InteGroup', coord_base=coord)
    ax.set(title=panel_title)

# Bottom left: mC value of the gene for the chosen mC type
mc_gene_data = cemba.get_mc_gene_frac(gene, mc_type=mc_type)
ax = axes[1, 0]
continuous_scatter(mc_result, ax=ax, hue=mc_gene_data, coord_base=coord,
                   cmap='viridis', hue_portion=0.8)
ax.set(title=f'mC {gene} {mc_type}')

# Bottom right: RNA value of the same gene, drawn with the reversed colormap
rna_gene_data = aibs.get_smart_gene_data(gene)
ax = axes[1, 1]
continuous_scatter(rna_result, ax=ax, hue=rna_gene_data, coord_base=coord,
                   cmap='viridis_r')
ax.set(title=f'RNA {gene} log1p(CPM)')

## Cluster-level gene plot

In [None]:
# Cluster-level 2x2 figure: same layout as the cell-level figure, but every
# dot is one L4 cluster placed at its median embedding coordinate and sized
# by the number of entries carrying that L4 label.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10), dpi=300)

coord = 'L1_tsne'
gene = 'Gad1'
mc_type = 'CHN'
mc_df = mc_result_l4
rna_df = rna_result_l4

# Per-cluster mean of the gene value and per-label counts, for each modality
mc_gene_data = cemba.get_mc_gene_frac(gene, mc_type=mc_type).groupby(mc_l4).mean()
mc_size = mc_l4.value_counts()

rna_gene_data = aibs.get_smart_gene_data(gene).groupby(rna_l4).mean()
rna_size = rna_l4.value_counts()

# Dot-size and coordinate settings shared by all four panels
dot_kws = dict(sizes=(0.5, 20), size_portion=0.95, coord_base=coord)

# Top row: clusters colored by integration group
ax = axes[0, 0]
categorical_scatter(mc_df, ax=ax, hue='L1_InteGroup', size=mc_size, **dot_kws)
ax.set(title='mC Inte Group')

ax = axes[0, 1]
categorical_scatter(rna_df, ax=ax, hue='L1_InteGroup', size=rna_size, **dot_kws)
ax.set(title='RNA Inte Group')

# Bottom row: clusters colored by the mean gene value
ax = axes[1, 0]
continuous_scatter(mc_df, ax=ax, hue=mc_gene_data, size=mc_size,
                   cmap='viridis', **dot_kws)
ax.set(title=f'mC {gene} {mc_type}')

ax = axes[1, 1]
continuous_scatter(rna_df, ax=ax, hue=rna_gene_data, size=rna_size,
                   cmap='viridis_r', **dot_kws)
ax.set(title=f'RNA {gene} log1p(CPM)')

### Multi-gene comparison

In [None]:
# Genes to plot: the first column of 'genes_to_plot.txt' (file has no header
# row), kept as a pandas Index.
gene_table = pd.read_csv('genes_to_plot.txt', header=None, index_col=0)
genes = gene_table.index

In [None]:
# One InteGroup -> color mapping shared by the mC and RNA heatmaps below.
# It is built from the labels of BOTH modalities: the RNA heatmap looks colors
# up with `inte_group_palette[g]`, which raises KeyError for any group that
# occurs only among RNA cells when the palette is derived from mc_result alone.
# mC labels are listed first so they keep their position in the input.
inte_group_palette = level_one_palette(
    pd.concat([mc_result['L1_InteGroup'], rna_result['L1_InteGroup']]))

In [None]:
# Per-cell mC values for every gene in `genes`, one column per gene.
# NOTE(review): mc_type is left at get_mc_gene_frac's default here, while the
# single-gene cells above pass mc_type='CHN' -- confirm this is intended.
mc_gene_data = pd.DataFrame(
    {gene_name: cemba.get_mc_gene_frac(gene_name) for gene_name in genes})

# Row order for the heatmap: L4 clusters sorted by their integration group
cluster_order_by_inte_group = mc_result_l4['L1_InteGroup'].sort_values().index

# Average within each L4 cluster, then put the rows in that order
mc_l4_gene_data = (
    mc_gene_data
    .groupby(mc_l4)
    .mean()
    .loc[cluster_order_by_inte_group]
    .copy()
)

In [None]:
# Heatmap of the L4-cluster mean mC value of every gene (rows = clusters in
# the order set by the previous cell, columns = genes), with a narrow color
# strip on the left showing each cluster's integration group.
plot_data = mc_l4_gene_data

fig = plt.figure(figsize=(6, 5), dpi=300)
gs = fig.add_gridspec(nrows=5, ncols=6)

# Left strip (first grid column): one colored cell per cluster row
group_ax = fig.add_subplot(gs[:, 0])
row_groups = plot_data.index.map(mc_result_l4['L1_InteGroup'])
inte_group_img = np.array([inte_group_palette[g] for g in row_groups])
# add a length-1 middle axis so imshow draws a single column of colors
inte_group_img = inte_group_img[:, None, :]
group_ax.imshow(inte_group_img, aspect='auto')
# pixel ticks carry no information on a color strip
group_ax.set_xticks([])
group_ax.set_yticks([])
group_ax.set_title('Inte Group', fontsize=6)

# Heatmap (remaining grid columns); color range fixed to [0, 2]
heatmap_ax = fig.add_subplot(gs[:, 1:])
heatmap_ax.imshow(plot_data, aspect='auto', vmin=0, vmax=2)
# label every column with its gene so the figure can be read on its own
heatmap_ax.set_xticks(np.arange(plot_data.shape[1]))
heatmap_ax.set_xticklabels(list(plot_data.columns), rotation=90, fontsize=5)
heatmap_ax.set_yticks([])
# trailing ';' keeps the cell from echoing the return value
heatmap_ax.set_title('mC gene frac, L4 cluster mean', fontsize=8);

In [None]:
# Per-cell RNA values for every gene in `genes`, one column per gene.
rna_gene_data = pd.DataFrame(
    {gene_name: aibs.get_smart_gene_data(gene_name) for gene_name in genes})

# Row order for the heatmap: L4 clusters sorted by their integration group
cluster_order_by_inte_group = rna_result_l4['L1_InteGroup'].sort_values().index

# Average within each L4 cluster, then put the rows in that order
rna_l4_gene_data = (
    rna_gene_data
    .groupby(rna_l4)
    .mean()
    .loc[cluster_order_by_inte_group]
    .copy()
)

In [None]:
# Heatmap of the L4-cluster mean RNA value of every gene (rows = clusters in
# the order set by the previous cell, columns = genes), with a narrow color
# strip on the left showing each cluster's integration group.
plot_data = rna_l4_gene_data

fig = plt.figure(figsize=(6, 5), dpi=300)
gs = fig.add_gridspec(nrows=5, ncols=6)

# Left strip (first grid column): one colored cell per cluster row
group_ax = fig.add_subplot(gs[:, 0])
row_groups = plot_data.index.map(rna_result_l4['L1_InteGroup'])
inte_group_img = np.array([inte_group_palette[g] for g in row_groups])
# add a length-1 middle axis so imshow draws a single column of colors
inte_group_img = inte_group_img[:, None, :]
group_ax.imshow(inte_group_img, aspect='auto')
# pixel ticks carry no information on a color strip
group_ax.set_xticks([])
group_ax.set_yticks([])
group_ax.set_title('Inte Group', fontsize=6)

# Heatmap (remaining grid columns); color range fixed to [0, 2].
# NOTE(review): this range is the same as in the mC heatmap -- confirm it
# suits the RNA values (titled log1p(CPM) in the scatter plots above).
heatmap_ax = fig.add_subplot(gs[:, 1:])
heatmap_ax.imshow(plot_data, aspect='auto', vmin=0, vmax=2)
# label every column with its gene so the figure can be read on its own
heatmap_ax.set_xticks(np.arange(plot_data.shape[1]))
heatmap_ax.set_xticklabels(list(plot_data.columns), rotation=90, fontsize=5)
heatmap_ax.set_yticks([])
# trailing ';' keeps the cell from echoing the return value
heatmap_ax.set_title('RNA log1p(CPM), L4 cluster mean', fontsize=8);