# Analysis & Visualization

2023-05-05

In [None]:
# Import Packages

%load_ext autoreload
%autoreload 2

import os
import warnings 
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.filters import threshold_otsu, gaussian
from skimage.morphology import remove_small_objects
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from anndata import AnnData, concat
from tqdm.notebook import tqdm
from statannotations.Annotator import Annotator

# Customized packages
from starmap.utilities import *
from starmap.sequencing import *
# from starmap.obj import STARMapDataset, load_data
# import starmap.analyze as anz
# import starmap.viz as viz

import starmap.sc_util as su

# test()

In [None]:
# Today's date as YYYY-MM-DD, used as a prefix for exported files
from datetime import datetime
date = f"{datetime.today():%Y-%m-%d}"

In [None]:
# Inline figure resolution; scanpy's figdir is set once fig_path is defined
sc.set_figure_params(dpi=300)

In [None]:
import matplotlib
# Embed TrueType (Type 42) fonts so text in exported PDFs stays editable
matplotlib.rcParams.update({'pdf.fonttype': 42})

## Set path

In [None]:
# Set path
base_path = 'path/to/dataset/folder'

input_path = os.path.join(base_path, 'input')

# makedirs(exist_ok=True) also creates missing parent folders and cannot race
# with a concurrent creator, unlike an exists()/mkdir() pair.
out_path = os.path.join(base_path, 'output')
os.makedirs(out_path, exist_ok=True)

fig_path = os.path.join(base_path, 'figures')
os.makedirs(fig_path, exist_ok=True)

# Route scanpy's `save=` outputs into the figures folder
sc.settings.figdir = fig_path

## Input data

In [None]:
# AnnData object holding the RIBOmap sections only
rdata_file = os.path.join(out_path, '2023-05-05-Brain-RIBOmap.h5ad')
rdata = sc.read_h5ad(rdata_file)
rdata

In [None]:
# AnnData object holding every section (both protocols, all replicates)
cdata_file = os.path.join(out_path, '2023-05-05-Brain-combined.h5ad')
cdata = sc.read_h5ad(cdata_file)
cdata

## Matrix for gene clustering

### Counts normalized to 1e4 per cell

In [None]:
# Average normalized expression per `current_level` type, computed separately
# for the STARmap and RIBOmap cells of replicate 2 (region 'other' excluded).
current_level = 'level_3'
sdata = cdata[(cdata.obs['replicate'] == 'rep2') & (cdata.obs['region'] != 'other'), :].copy()
# Work on a copy of the raw counts: normalize_total may rescale X in place,
# and X would otherwise be the very same array object as layers['raw'].
sdata.X = sdata.layers['raw'].copy()
sc.pp.normalize_total(sdata, target_sum=1e4)

# Dense cells x genes table (X may be stored as a sparse matrix)
norm_X = sdata.X.toarray() if hasattr(sdata.X, 'toarray') else np.asarray(sdata.X)
current_df = pd.DataFrame(norm_X, index=sdata.obs.index, columns=sdata.var.index)
current_df['protocol'] = sdata.obs['protocol'].values
current_df[current_level] = sdata.obs[current_level].values

star_df = current_df.loc[current_df['protocol'] == 'STARmap', :]
ribo_df = current_df.loc[current_df['protocol'] == 'RIBOmap', :]


def _mean_by_group(df, group_col):
    """Return the per-group mean of the gene columns plus an 'All_ctype' row.

    Only gene columns are averaged: leaving the string 'protocol' column in
    makes ``mean`` raise ``TypeError`` on pandas >= 2.0 (older pandas
    silently dropped it as a nuisance column).
    """
    genes = df.drop(columns=['protocol', group_col])
    res = genes.groupby(df[group_col]).mean()
    # Plain string index so the extra 'All_ctype' label can be appended even
    # when the grouping column is categorical.
    res.index = res.index.astype(str)
    res.loc['All_ctype', :] = genes.mean(axis=0)
    return res


star_res_df = _mean_by_group(star_df, current_level)
ribo_res_df = _mean_by_group(ribo_df, current_level)

In [None]:
# One sheet per protocol with the per-cell-type averages
excel_file = os.path.join(fig_path, f'{date}-averaged-gene-expression-{current_level}-1e4.xlsx')
with pd.ExcelWriter(excel_file, mode='w') as writer:
    for sheet, res_df in (('RIBOmap', ribo_res_df), ('STARmap', star_res_df)):
        res_df.to_excel(writer, sheet_name=sheet)

## UMAP

In [None]:
# Seaborn palettes built from the per-level colors stored in cdata.uns
level_2_pl, level_3_pl = (
    sns.color_palette(cdata.uns['level_2_color']),
    sns.color_palette(cdata.uns['level_3_color']),
)

### level-2

In [None]:
# Figures saved through scanpy: 300 dpi, written into fig_path
sc.set_figure_params(dpi_save=300)
sc.settings.figdir = fig_path

In [None]:
# Large level-2 UMAP of the RIBOmap cells, no legend (PNG)
fig, ax = plt.subplots(figsize=(9, 7))
sc.pl.umap(
    rdata,
    color='level_2',
    legend_loc=None,
    frameon=False,
    ax=ax,
    size=3,
    title='',
    palette=level_2_pl,
    save='_level_2_no_legend_large.png',
)

### level-3

In [None]:
# Level-3 UMAPs: with legend (PDF), without legend (PDF and PNG), and with the
# numeric category codes written on the embedding.
sc.pl.umap(rdata, color='level_3', legend_loc='right margin',
           legend_fontsize=12, legend_fontoutline=2, frameon=False,
           title='', palette=level_3_pl, save='_level_3.pdf')

for suffix in ('_level_3_no_legend.pdf', '_level_3_no_legend.png'):
    sc.pl.umap(rdata, color='level_3', legend_loc=None, frameon=False,
               title='', palette=level_3_pl, save=suffix)

# Integer code of each level-3 type, kept categorical so it is colored discretely
rdata.obs['level_3_code'] = rdata.obs['level_3'].cat.codes.astype('category')
sc.pl.umap(rdata, color='level_3_code', legend_loc='on data',
           legend_fontsize=8, legend_fontoutline=2, frameon=False,
           title='', palette=level_3_pl, save='_level_3_code.pdf')

In [None]:
# Large level-3 UMAP of the RIBOmap cells, no legend (PNG)
fig, ax = plt.subplots(figsize=(9, 7))
sc.pl.umap(
    rdata,
    color='level_3',
    legend_loc=None,
    frameon=False,
    ax=ax,
    size=3,
    title='',
    palette=level_3_pl,
    save='_level_3_no_legend_large.png',
)

In [None]:
# Large level-3 UMAP of the STARmap rep2 cells, no legend (PNG).
# Filter on the replicate as well, so the content matches the '-rep2' file
# name even if cdata holds STARmap cells of other replicates.
# NOTE(review): a view can drop unused level_3 categories, which would shift
# the positional level_3_pl colors — confirm every type is present here.
pdata = cdata[cdata.obs['protocol-replicate'] == 'STARmap-rep2', ]
fig, ax = plt.subplots(figsize=(9, 7))
sc.pl.umap(pdata, color='level_3', legend_loc=None, frameon=False, ax=ax, size=3,
           title='', palette=level_3_pl, save='_level_3_no_legend_large_STARmap-rep2.png')

### other

In [None]:
# Large UMAP of the RIBOmap cells colored by replicate, no legend (PNG).
# NOTE(review): the saved name says 'protocol' but the coloring is by
# 'replicate' — confirm which was intended.
fig, ax = plt.subplots(figsize=(9, 7))
sc.pl.umap(
    rdata,
    color='replicate',
    legend_loc=None,
    frameon=False,
    ax=ax,
    size=3,
    title='',
    save='_protocol_no_legend_large.png',
)

## Spatial map

### polygon map

In [None]:
# Root folder of the processed images (segmentation masks, protein stains)
temp_path = 'path/to/processed/data/folder'
# Per-section label used below to select the cells of each sample
cdata.obs['sample'] = cdata.obs['protocol-replicate']

In [None]:
# RIBOmap rep2: read the labeled segmentation mask and derive the cell polygons
current_seg_path = os.path.join(temp_path, 'segmentation-mask-for-visualization', 'labeled_cells_RIBOmap_rep2.tif')
current_img = tifffile.imread(current_seg_path)

current_key = "RIBOmap-rep2_morph"
cdata.uns[current_key] = {'label_img': current_img}

# Hulls, coordinates and centroids returned by the starmap helper
morph = cdata.uns[current_key]
morph['qhulls'], morph['coords'], morph['centroids'] = su.get_qhulls_test(morph['label_img'])

In [None]:
# STARmap rep2: read the labeled segmentation mask and derive the cell polygons
current_seg_path = os.path.join(temp_path, 'segmentation-mask-for-visualization', 'labeled_cells_STARmap_rep2.tif')
current_img = tifffile.imread(current_seg_path)

current_key = "STARmap-rep2_morph"
cdata.uns[current_key] = {'label_img': current_img}

# Hulls, coordinates and centroids returned by the starmap helper
morph = cdata.uns[current_key]
morph['qhulls'], morph['coords'], morph['centroids'] = su.get_qhulls_test(morph['label_img'])

In [None]:
# RIBOmap rep1: read the labeled segmentation mask and derive the cell polygons
current_seg_path = os.path.join(temp_path, 'segmentation-mask-for-visualization', 'labeled_cells_RIBOmap_rep1.tif')
current_img = tifffile.imread(current_seg_path)

current_key = "RIBOmap-rep1_morph"
cdata.uns[current_key] = {'label_img': current_img}

# Hulls, coordinates and centroids returned by the starmap helper
morph = cdata.uns[current_key]
morph['qhulls'], morph['coords'], morph['centroids'] = su.get_qhulls_test(morph['label_img'])

#### level-2

In [None]:
# Preview the level-2 palette with the type names as tick labels
level_2_pl = sns.color_palette(cdata.uns['level_2_color'])
level_2_cmap = ListedColormap(level_2_pl.as_hex())
level_2_order = cdata.uns['level_2_order']

sns.palplot(level_2_pl, size=3)
plt.xticks(range(len(level_2_order)), level_2_order, size=10, rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# RIBOmap rep2 polygons colored by level-2 type
current_key = "RIBOmap-rep2_morph"
sample_obs = cdata.obs.loc[cdata.obs['sample'] == 'RIBOmap-rep2']
cdata.uns[current_key]['good_cells'] = sample_obs['orig_index'].astype(int).values
cdata.uns[current_key]['colors'] = sample_obs['level_2'].cat.codes.values

su.plot_poly_cells_cluster_by_sample(
    cdata, 'RIBOmap-rep2', level_2_cmap,
    show_plaque=False, show_tau=False, linewidth=0.4,
    figscale=3, width=10, height=10, bg_color='#a8a8a8',
    save_as_real_size=True, save=True, show=False, output_dir=fig_path,
)

In [None]:
# STARmap rep2 polygons colored by level-2 type
current_key = "STARmap-rep2_morph"
sample_obs = cdata.obs.loc[cdata.obs['sample'] == 'STARmap-rep2']
cdata.uns[current_key]['good_cells'] = sample_obs['orig_index'].astype(int).values
cdata.uns[current_key]['colors'] = sample_obs['level_2'].cat.codes.values

su.plot_poly_cells_cluster_by_sample(
    cdata, 'STARmap-rep2', level_2_cmap,
    show_plaque=False, show_tau=False, linewidth=0.4,
    figscale=3, width=10, height=10, bg_color='#a8a8a8',
    save_as_real_size=True, save=True, show=False, output_dir=fig_path,
)

In [None]:
# RIBOmap rep1 polygons colored by level-2 type
current_key = "RIBOmap-rep1_morph"
sample_obs = cdata.obs.loc[cdata.obs['sample'] == 'RIBOmap-rep1']
cdata.uns[current_key]['good_cells'] = sample_obs['orig_index'].astype(int).values
cdata.uns[current_key]['colors'] = sample_obs['level_2'].cat.codes.values

su.plot_poly_cells_cluster_by_sample(
    cdata, 'RIBOmap-rep1', level_2_cmap,
    show_plaque=False, show_tau=False, linewidth=0.4,
    figscale=3, width=10, height=10, bg_color='#a8a8a8',
    save_as_real_size=True, save=True, show=False, output_dir=fig_path,
)

#### level-3

In [None]:
# Preview the level-3 palette with the type names as tick labels
level_3_pl = sns.color_palette(cdata.uns['level_3_color'])
level_3_cmap = ListedColormap(level_3_pl.as_hex())
level_3_order = cdata.uns['level_3_order']

sns.palplot(level_3_pl, size=3)
plt.xticks(range(len(level_3_order)), level_3_order, size=10, rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# RIBOmap rep2 polygons colored by level-3 type.
# NOTE(review): same sample and output_dir as the level-2 map; confirm the
# helper's file name differs so this does not overwrite that figure.
current_key = "RIBOmap-rep2_morph"
sample_obs = cdata.obs.loc[cdata.obs['sample'] == 'RIBOmap-rep2']
cdata.uns[current_key]['good_cells'] = sample_obs['orig_index'].astype(int).values
cdata.uns[current_key]['colors'] = sample_obs['level_3'].cat.codes.values

su.plot_poly_cells_cluster_by_sample(
    cdata, 'RIBOmap-rep2', level_3_cmap,
    show_plaque=False, show_tau=False, linewidth=0.4,
    figscale=3, width=10, height=10, bg_color='#a8a8a8',
    save_as_real_size=True, save=True, show=False, output_dir=fig_path,
)

In [None]:
# STARmap rep2 polygons colored by level-3 type.
# NOTE(review): same sample and output_dir as the level-2 map; confirm the
# helper's file name differs so this does not overwrite that figure.
current_key = "STARmap-rep2_morph"
sample_obs = cdata.obs.loc[cdata.obs['sample'] == 'STARmap-rep2']
cdata.uns[current_key]['good_cells'] = sample_obs['orig_index'].astype(int).values
cdata.uns[current_key]['colors'] = sample_obs['level_3'].cat.codes.values

su.plot_poly_cells_cluster_by_sample(
    cdata, 'STARmap-rep2', level_3_cmap,
    show_plaque=False, show_tau=False, linewidth=0.4,
    figscale=3, width=10, height=10, bg_color='#a8a8a8',
    save_as_real_size=True, save=True, show=False, output_dir=fig_path,
)

In [None]:
# RIBOmap rep1 polygons colored by level-3 type (thinner outlines: linewidth 0.1).
# NOTE(review): same sample and output_dir as the level-2 map; confirm the
# helper's file name differs so this does not overwrite that figure.
current_key = "RIBOmap-rep1_morph"
sample_obs = cdata.obs.loc[cdata.obs['sample'] == 'RIBOmap-rep1']
cdata.uns[current_key]['good_cells'] = sample_obs['orig_index'].astype(int).values
cdata.uns[current_key]['colors'] = sample_obs['level_3'].cat.codes.values

su.plot_poly_cells_cluster_by_sample(
    cdata, 'RIBOmap-rep1', level_3_cmap,
    show_plaque=False, show_tau=False, linewidth=0.1,
    figscale=3, width=10, height=10, bg_color='#a8a8a8',
    save_as_real_size=True, save=True, show=False, output_dir=fig_path,
)

#### level-3 subtypes within each level-2 class

In [None]:
# First batch of level-2 classes to split into level-3 subtypes
# (temp_list is reassigned by the next cell when run in order)
temp_list = [
    'Telencephalon projecting neurons',
    'Cholinergic, monoaminergic and peptidergic neurons',
    'Di/Mesencephalon neurons',
    'Astrocyte',
    'Oligodendrocyte',
    'Vascular cells',
]

# Global level-3 colors, aligned with level_3_order by position
level_3_color = cdata.uns['level_3_color']

In [None]:
# Second batch of level-2 classes, plus a preview of an alternative palette
temp_list = [
    'Telencephalon interneurons',
    'Oligodendrocytes precursor cell',
    'Microglia',
    'Astroependymal cells',
    'Perivascular macrophages',
]
temp_colors = [
    '#1f78b4', '#33a02c', '#e31a1c', '#ff7f00', '#6a3d9a',
    '#a6cee3', '#b2df8a', '#fb9a99', '#fdbf6f',
]

sns.palplot(sns.color_palette(temp_colors), size=3)

In [None]:
# For each selected level-2 class, draw the RIBOmap-rep1 polygons of that class
# colored by level-3 subtype, reusing each subtype's global level-3 color.
for current_type in temp_list:
    print(current_type)

    pdata = cdata[(cdata.obs['level_2'] == current_type) & (cdata.obs['protocol-replicate'] == 'RIBOmap-rep1'), ]

    # Subtypes present in this class, in global level-3 order, with their colors
    present_types = set(pdata.obs['level_3'].unique())
    current_order = [t for t in level_3_order if t in present_types]
    current_colors = [c for t, c in zip(level_3_order, level_3_color) if t in present_types]
    print(current_colors)

    current_pl = sns.color_palette(current_colors)
    current_cmap = ListedColormap(current_pl.as_hex())

    # '/' would be read as a path separator (e.g. 'Di/Mesencephalon neurons')
    safe_type = current_type.replace('/', '_')
    current_fig_path = os.path.join(fig_path, f'{safe_type}-sct-level3-color')
    os.makedirs(current_fig_path, exist_ok=True)

    current_key = "RIBOmap-rep1_morph"
    sample_obs = pdata.obs.loc[pdata.obs['sample'] == 'RIBOmap-rep1', :]
    pdata.uns[current_key]['good_cells'] = sample_obs['orig_index'].astype(int).values
    # Encode subtypes against current_order explicitly so code k always selects
    # current_colors[k] in current_cmap, independent of how the level_3
    # categories of this slice happen to be ordered or pruned.
    pdata.uns[current_key]['colors'] = pd.Categorical(sample_obs['level_3'], categories=current_order).codes

    su.plot_poly_cells_cluster_by_sample(
        pdata, 'RIBOmap-rep1', current_cmap,
        show_plaque=False, show_tau=False, linewidth=0.4,
        figscale=3, width=10, height=10, bg_color='#ebebeb',
        save_as_real_size=True, save=True, show=False, output_dir=current_fig_path,
    )

### With protein stains (NeuN / Gfap)

In [None]:
# Available level-2 codes
cdata.obs['level_2_code'].dtype.categories

In [None]:
# Segmented rep2 cells of the neuronal classes and astrocytes
keep_codes = ['TEPN', 'INH', 'CHO_PEP', 'DE_MEN', 'AC']
pdata = cdata[
    (cdata.obs['seg_label'] != 0)
    & (cdata.obs['replicate'] == 'rep2')
    & (cdata.obs['level_2_code'].isin(keep_codes))
]
pdata

#### RIBOmap

In [None]:
# RIBOmap rep2 protein section: label mask plus NeuN and Gfap stains
vis_dir = os.path.join(temp_path, 'protein-images/visualization')
current_seg_path = os.path.join(vis_dir, 'RIBO_labeled_cells.tif')
current_neun_path = os.path.join(vis_dir, 'RIBO_NeuN.tif')
current_gfap_path = os.path.join(vis_dir, 'RIBO_Gfap.tif')

current_img = tifffile.imread(current_seg_path)
current_neun = tifffile.imread(current_neun_path)
current_gfap = tifffile.imread(current_gfap_path)

# NeuN is stored under 'tau', presumably so the helper's show_tau overlay draws it
current_key = "RIBOmap-rep2_morph"
pdata.uns[current_key] = {'label_img': current_img, 'tau': current_neun, 'Gfap': current_gfap}

# Hulls, coordinates and centroids returned by the starmap helper
morph = pdata.uns[current_key]
morph['qhulls'], morph['coords'], morph['centroids'] = su.get_qhulls_test(morph['label_img'])

In [None]:
# Legend preview of the two-class (old) coloring: Neuron yellow, Astrocyte lilac
sns.reset_orig()
temp_order, temp_colors = ['Neuron', 'Astrocyte'], ['#e8f00c', '#cb99f7']
temp_pl = sns.color_palette(temp_colors)
temp_cmap = ListedColormap(temp_pl.as_hex())

sns.palplot(temp_pl, size=3)
plt.xticks(range(len(temp_order)), temp_order, size=10, rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# Two-class plotting label: astrocytes (level-2 code 'AC') vs. neurons (the rest).
# Explicit categories keep the codes fixed to temp_order (Neuron=0,
# Astrocyte=1) even if one class is absent, where reorder_categories raised.
pdata.obs['temp_label'] = pd.Categorical(
    np.where(pdata.obs['level_2_code'] == 'AC', 'Astrocyte', 'Neuron'),
    categories=temp_order,
)

In [None]:
# Output folder for the polygon maps with protein overlays (parents created too)
current_fig_path = os.path.join(fig_path, 'sct-with-protein')
os.makedirs(current_fig_path, exist_ok=True)

In [None]:
# Store, for every RIBOmap-rep2 cell, the position of its segmentation label
# among the labels regionprops finds in the mask (obs['good_cells']).
# NOTE(review): presumably this position indexes the polygon lists built by
# su.get_qhulls_test from the same mask — confirm.
ribo_seg_labels = np.array([region.label for region in regionprops(current_img)])
label_to_pos = {int(label): pos for pos, label in enumerate(ribo_seg_labels)}

is_sample = pdata.obs['protocol-replicate'] == 'RIBOmap-rep2'
current_seg_labels = pdata.obs.loc[is_sample, 'seg_label'].values.astype(int)

# One dict lookup per cell instead of scanning all labels with np.argwhere;
# a KeyError means a cell's label is absent from the mask.
current_good_cells = np.array([label_to_pos[int(label)] for label in current_seg_labels], dtype=int)
pdata.obs.loc[is_sample, 'good_cells'] = current_good_cells

# Reorder the whole AnnData by good_cells. Assigning a re-sorted frame to
# pdata.obs only relabels rows and leaves X/layers in the old order,
# silently misaligning expression and metadata.
sort_order = np.argsort(pdata.obs['good_cells'].values, kind='stable')
pdata = pdata[sort_order].copy()

In [None]:
# Polygon indices and two-class color codes for the RIBOmap rep2 section
current_key = "RIBOmap-rep2_morph"
sample_obs = pdata.obs[pdata.obs['sample'] == 'RIBOmap-rep2']
pdata.uns[current_key]['good_cells'] = sample_obs['good_cells'].values
pdata.uns[current_key]['colors'] = sample_obs['temp_label'].cat.codes.values

In [None]:
# Polygon map with the stored image overlays enabled (NeuN was stored under 'tau')
su.plot_poly_cells_cluster_by_sample_test(
    pdata, 'RIBOmap-rep2', temp_cmap,
    show_plaque=False, show_tau=True, show_gfap=True, linewidth=0.1,
    figscale=3, width=10, height=10, bg_color='#d7d7d7',
    save_as_real_size=True, save=True, show=False, output_dir=current_fig_path,
)

#### STARmap

In [None]:
# STARmap rep2 protein section: label mask plus NeuN and Gfap stains
vis_dir = os.path.join(temp_path, 'protein-images/visualization')
current_seg_path = os.path.join(vis_dir, 'STAR_labeled_cells.tif')
current_neun_path = os.path.join(vis_dir, 'STAR_NeuN.tif')
current_gfap_path = os.path.join(vis_dir, 'STAR_Gfap.tif')

current_img = tifffile.imread(current_seg_path)
current_neun = tifffile.imread(current_neun_path)
current_gfap = tifffile.imread(current_gfap_path)

# NeuN is stored under 'tau', presumably so the helper's show_tau overlay draws it
current_key = "STARmap-rep2_morph"
pdata.uns[current_key] = {'label_img': current_img, 'tau': current_neun, 'Gfap': current_gfap}

# Hulls, coordinates and centroids returned by the starmap helper
morph = pdata.uns[current_key]
morph['qhulls'], morph['coords'], morph['centroids'] = su.get_qhulls_test(morph['label_img'])

In [None]:
# Legend preview of the two-class (old) coloring: Neuron yellow, Astrocyte lilac
sns.reset_orig()
temp_order, temp_colors = ['Neuron', 'Astrocyte'], ['#e8f00c', '#cb99f7']
temp_pl = sns.color_palette(temp_colors)
temp_cmap = ListedColormap(temp_pl.as_hex())

sns.palplot(temp_pl, size=3)
plt.xticks(range(len(temp_order)), temp_order, size=10, rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# Two-class plotting label: astrocytes (level-2 code 'AC') vs. neurons (the rest).
# Explicit categories keep the codes fixed to temp_order (Neuron=0,
# Astrocyte=1) even if one class is absent, where reorder_categories raised.
pdata.obs['temp_label'] = pd.Categorical(
    np.where(pdata.obs['level_2_code'] == 'AC', 'Astrocyte', 'Neuron'),
    categories=temp_order,
)

In [None]:
# Output folder for the polygon maps with protein overlays (parents created too)
current_fig_path = os.path.join(fig_path, 'sct-with-protein')
os.makedirs(current_fig_path, exist_ok=True)

In [None]:
# Store, for every STARmap-rep2 cell, the position of its segmentation label
# among the labels regionprops finds in the mask (obs['good_cells']).
# NOTE(review): presumably this position indexes the polygon lists built by
# su.get_qhulls_test from the same mask — confirm.
star_seg_labels = np.array([region.label for region in regionprops(current_img)])
label_to_pos = {int(label): pos for pos, label in enumerate(star_seg_labels)}

is_sample = pdata.obs['protocol-replicate'] == 'STARmap-rep2'
current_seg_labels = pdata.obs.loc[is_sample, 'seg_label'].values.astype(int)

# One dict lookup per cell instead of scanning all labels with np.argwhere;
# a KeyError means a cell's label is absent from the mask.
current_good_cells = np.array([label_to_pos[int(label)] for label in current_seg_labels], dtype=int)
pdata.obs.loc[is_sample, 'good_cells'] = current_good_cells

# Reorder the whole AnnData by good_cells. Assigning a re-sorted frame to
# pdata.obs only relabels rows and leaves X/layers in the old order,
# silently misaligning expression and metadata.
sort_order = np.argsort(pdata.obs['good_cells'].values, kind='stable')
pdata = pdata[sort_order].copy()

In [None]:
# Polygon indices and two-class color codes for the STARmap rep2 section
current_key = "STARmap-rep2_morph"
sample_obs = pdata.obs[pdata.obs['sample'] == 'STARmap-rep2']
pdata.uns[current_key]['good_cells'] = sample_obs['good_cells'].values
pdata.uns[current_key]['colors'] = sample_obs['temp_label'].cat.codes.values

In [None]:
# Polygon map with the stored image overlays enabled (NeuN was stored under 'tau')
su.plot_poly_cells_cluster_by_sample_test(
    pdata, 'STARmap-rep2', temp_cmap,
    show_plaque=False, show_tau=True, show_gfap=True, linewidth=0.1,
    figscale=3, width=10, height=10, bg_color='#d7d7d7',
    save_as_real_size=True, save=True, show=False, output_dir=current_fig_path,
)

### Oligo & OPC (as dots)

In [None]:
# Oligodendrocyte-lineage cells: oligodendrocytes and their precursors
oligo_classes = ['Oligodendrocyte', 'Oligodendrocytes precursor cell']
sdata = cdata[cdata.obs['level_2'].isin(oligo_classes), :]
sdata

#### single type

In [None]:
# Palette for OPC / OLG1 / OLG2 (green, orange, purple)
sub_colors = ['#00A651', '#FBB040', '#92278F']
sub_pl = sns.color_palette(sub_colors)
sub_cmap = ListedColormap(sub_pl.as_hex())
sns.palplot(sub_pl)

In [None]:
# One spatial dot map per oligodendrocyte-lineage subtype on RIBOmap rep2:
# every cell of the section in light grey, cells of the subtype on top.
sns.set(rc={'figure.facecolor': 'white', 'axes.facecolor': 'white'})
colors = ['#00A651', '#FBB040', '#92278F']
current_sample = 'RIBOmap-rep2'
size = 450

# The background cells are the same for every subtype: select them once
current_df = cdata.obs.loc[cdata.obs['protocol-replicate'] == current_sample, :]
os.makedirs(fig_path, exist_ok=True)

for current_type, color in zip(['OPC', 'OLG1', 'OLG2'], colors):
    current_oo_df = sdata.obs.loc[(sdata.obs['protocol-replicate'] == current_sample) & (sdata.obs['level_3'] == current_type), :].copy()

    fig, ax = plt.subplots(figsize=(48, 60))

    # Background: all cells of the section
    sns.scatterplot(x='column', y='row', color='#f0f0f0', data=current_df, s=size, ax=ax)

    # Foreground: cells of the current subtype
    sns.scatterplot(x='column', y='row', color=color, data=current_oo_df,
                    marker='o', s=size, alpha=1, linewidth=0, legend=False, ax=ax)

    # Image convention: row 0 at the top
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    plt.tight_layout()
    fig.savefig(os.path.join(fig_path, f'{current_type and current_sample}-{current_type}-s{size}.png'), dpi=300)
    # The (very large) figure is only needed as the saved PNG: free it
    plt.close(fig)

# Reset once after the loop: resetting inside it switched the style after the
# first map, so the three maps were not drawn alike.
sns.reset_orig()

#### single sample

In [None]:
# Palette for OPC / OLG1 / OLG2 (green, orange, purple)
sub_colors = ['#00A651', '#FBB040', '#92278F']
sub_pl = sns.color_palette(sub_colors)
sub_cmap = ListedColormap(sub_pl.as_hex())
sns.palplot(sub_pl)

In [None]:
# Back to matplotlib's original rcParams (undo any seaborn styling)
sns.reset_orig()

In [None]:
# One spatial dot map of RIBOmap rep2 with OPC / OLG1 / OLG2 colored by subtype
sns.set(rc={'figure.facecolor': 'white', 'axes.facecolor': 'white'})

current_sample = 'RIBOmap-rep2'
subtype_order = ['OPC', 'OLG1', 'OLG2']
# Same marker for every subtype: they are told apart by color only
markers = dict.fromkeys(subtype_order, 'o')

current_df = cdata.obs.loc[cdata.obs['protocol-replicate'] == current_sample, :]

current_oo_df = sdata.obs.loc[sdata.obs['protocol-replicate'] == current_sample, :].copy()
current_oo_df['level_3'] = current_oo_df['level_3'].astype(object)
current_oo_df = current_oo_df.loc[current_oo_df['level_3'].isin(subtype_order), :]
current_oo_df['level_3'] = current_oo_df['level_3'].astype('category').cat.reorder_categories(subtype_order)

size = 300
fig, ax = plt.subplots(figsize=(48, 60))

# Light-grey background: every cell of the section
sns.scatterplot(x='column', y='row', color='#f0f0f0', data=current_df, s=size, ax=ax)

# Lineage cells on top, colored by subtype
sns.scatterplot(
    x='column', y='row', hue='level_3', style='level_3',
    palette=sub_pl, markers=markers,
    data=current_oo_df,
    s=size, alpha=1, linewidth=0, legend=False,
    ax=ax,
)

# Image convention: row 0 at the top; no axes
ax.invert_yaxis()
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)

plt.tight_layout()

current_out_path = fig_path
if not os.path.exists(current_out_path):
    os.mkdir(current_out_path)
plt.savefig(os.path.join(current_out_path, f'{current_sample}-s{size}.png'), dpi=300)

sns.reset_orig()

#### with gene list

In [None]:
# Gene list for the oligodendrocyte cell types (per the spreadsheet name)
gene_file = os.path.join(base_path, "other-datasets", "Oligo_cell_type_6_14_15_20221217.xlsx")
gene_df = pd.read_excel(gene_file)
gene_df

In [None]:
# Per-gene spatial maps of the oligodendrocyte lineage in one section: every
# cell in light grey, OPC/OLG1/OLG2 cells on top, colored by the gene's
# 'scaled' expression and shaped by subtype.
sns.set(rc={'figure.facecolor': 'white', 'axes.facecolor': 'white'})

# NOTE(review): the sections loaded above are RIBOmap-rep1/rep2 and
# STARmap-rep2; confirm 'STARmap-rep3' exists in cdata.
current_sample = 'STARmap-rep3'
# Keys must match the level_3 labels used elsewhere ('OPC', 'OLG1', 'OLG2');
# with 'Oligo1'/'Oligo2' every OLG cell was filtered out and
# reorder_categories raised.
markers = {'OPC': '^', 'OLG1': 's', 'OLG2': 'p'}
gene_map_dir = os.path.join(fig_path, '2022-12-18-OO-gene-spatial-map')
os.makedirs(gene_map_dir, exist_ok=True)

# sdata was sliced from cdata, so cdata holds both the background cells and the
# expression of the selected cells (`adata` is only defined further down).
current_df = cdata.obs.loc[cdata.obs['protocol-replicate'] == current_sample, :]
current_oo_df = sdata.obs.loc[sdata.obs['protocol-replicate'] == current_sample, :].copy()
current_oo_df['level_3'] = current_oo_df['level_3'].astype(object)
current_oo_df = current_oo_df.loc[current_oo_df['level_3'].isin(markers.keys()), :]
current_oo_df['level_3'] = current_oo_df['level_3'].astype('category')
current_oo_df['level_3'] = current_oo_df['level_3'].cat.reorder_categories(list(markers))

current_adata = cdata[current_oo_df.index, :]

# NOTE(review): `cmap` is not defined in this notebook; presumably it comes from
# `from starmap.utilities import *` — confirm ('Spectral_r' was the
# commented-out alternative).
for gene in tqdm(gene_df['level_3'].to_list()):
    # Indexing a missing gene raises KeyError and would abort the whole loop
    if gene not in current_adata.var_names:
        print(f'{gene}: not in the gene panel, skipped')
        continue

    current_vector = current_adata[:, gene].layers['scaled'].flatten()
    vmin, vmax = current_vector.min(), current_vector.max()

    fig, ax = plt.subplots(figsize=(48, 60))

    # Background: all cells of the section
    sns.scatterplot(x='column', y='row', color='#f0f0f0', data=current_df, s=90, ax=ax)

    # Foreground: lineage cells colored by expression. Seaborn scales a numeric
    # hue through hue_norm (vmin/vmax are not seaborn parameters).
    sns.scatterplot(
        x='column', y='row', hue=current_vector,
        palette=cmap,
        data=current_oo_df, style='level_3',
        markers=markers,
        hue_norm=(vmin, vmax),
        s=90, alpha=1, linewidth=0, legend=False,
        ax=ax,
    )

    # Image convention: row 0 at the top
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    plt.tight_layout()
    fig.savefig(os.path.join(gene_map_dir, f'{current_sample}-{gene}.png'), dpi=300)
    # plt.clf() alone kept every figure registered with pyplot; close it instead
    plt.close(fig)

# Reset once after the loop so every map uses the seaborn style set above
sns.reset_orig()

In [None]:
# Per-gene spatial maps of the oligodendrocyte lineage in one section: every
# cell in light grey, OPC/OLG1/OLG2 cells on top, colored by the gene's
# 'scaled' expression and shaped by subtype.
sns.set(rc={'figure.facecolor': 'white', 'axes.facecolor': 'white'})

# NOTE(review): the sections loaded above are RIBOmap-rep1/rep2 and
# STARmap-rep2; confirm 'RIBOmap-rep3' exists in cdata.
current_sample = 'RIBOmap-rep3'
# Keys must match the level_3 labels used elsewhere ('OPC', 'OLG1', 'OLG2');
# with 'Oligo1'/'Oligo2' every OLG cell was filtered out and
# reorder_categories raised.
markers = {'OPC': '^', 'OLG1': 's', 'OLG2': 'p'}
gene_map_dir = os.path.join(fig_path, '2022-12-18-OO-gene-spatial-map')
os.makedirs(gene_map_dir, exist_ok=True)

# sdata was sliced from cdata, so cdata holds both the background cells and the
# expression of the selected cells (`adata` is only defined further down).
current_df = cdata.obs.loc[cdata.obs['protocol-replicate'] == current_sample, :]
current_oo_df = sdata.obs.loc[sdata.obs['protocol-replicate'] == current_sample, :].copy()
current_oo_df['level_3'] = current_oo_df['level_3'].astype(object)
current_oo_df = current_oo_df.loc[current_oo_df['level_3'].isin(markers.keys()), :]
current_oo_df['level_3'] = current_oo_df['level_3'].astype('category')
current_oo_df['level_3'] = current_oo_df['level_3'].cat.reorder_categories(list(markers))

current_adata = cdata[current_oo_df.index, :]

# NOTE(review): `cmap` is not defined in this notebook; presumably it comes from
# `from starmap.utilities import *` — confirm ('Spectral_r' was the
# commented-out alternative).
for gene in tqdm(gene_df['level_3'].to_list()):
    # Indexing a missing gene raises KeyError and would abort the whole loop
    if gene not in current_adata.var_names:
        print(f'{gene}: not in the gene panel, skipped')
        continue

    current_vector = current_adata[:, gene].layers['scaled'].flatten()
    vmin, vmax = current_vector.min(), current_vector.max()

    fig, ax = plt.subplots(figsize=(48, 60))

    # Background: all cells of the section
    sns.scatterplot(x='column', y='row', color='#f0f0f0', data=current_df, s=90, ax=ax)

    # Foreground: lineage cells colored by expression. Seaborn scales a numeric
    # hue through hue_norm (vmin/vmax are not seaborn parameters).
    sns.scatterplot(
        x='column', y='row', hue=current_vector,
        palette=cmap,
        data=current_oo_df, style='level_3',
        markers=markers,
        hue_norm=(vmin, vmax),
        s=90, alpha=1, linewidth=0, legend=False,
        ax=ax,
    )

    # Image convention: row 0 at the top
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    plt.tight_layout()
    fig.savefig(os.path.join(gene_map_dir, f'{current_sample}-{gene}.png'), dpi=300)
    # plt.clf() alone kept every figure registered with pyplot; close it instead
    plt.close(fig)

# Reset once after the loop so every map uses the seaborn style set above
sns.reset_orig()

## Gene markers

### level-2

In [None]:
# Drop the 'Mix' cells. .copy() makes rdata a real AnnData, so the uns/layers
# edits below don't each trigger an implicit view-to-copy conversion.
rdata = rdata[rdata.obs['level_2_code'] != 'Mix', :].copy()

In [None]:
# Set uns['log1p']['base'] explicitly (None = natural log, matching the
# np.log1p used below). NOTE(review): presumably restores the entry lost when
# the .h5ad was written/read, which rank_genes_groups looks up — confirm.
rdata.uns['log1p']['base'] = None

In [None]:
# Log-transform the raw counts into their own layer, then total-normalize it.
# NOTE(review): the usual order is normalize_total -> log1p; here the log is
# taken first. Kept as-is to reproduce existing results — confirm intent.
rdata.layers['log_raw'] = np.log1p(rdata.layers['raw'])
sc.pp.normalize_total(rdata, layer='log_raw')

# Wilcoxon markers per level-2 code, scoring every gene on 'log_raw'
sc.tl.rank_genes_groups(
    rdata,
    'level_2_code',
    method='wilcoxon',
    layer='log_raw',
    use_raw=False,
    pts=True,
    n_genes=rdata.shape[1],
)

# Filter markers by fold change and in-/out-of-group expression fractions
sc.tl.filter_rank_genes_groups(
    rdata,
    min_fold_change=.1,
    min_in_group_fraction=0.2,
    max_out_group_fraction=0.8,
)

In [None]:
# Curated marker genes for each level-2 code (dot/matrix plot rows)
marker_genes_dict = {
    'TEPN': ['Slc17a7', 'Atp1a1', 'Ppp3r1', 'Nrgn', 'Mapk1'],
    'INH': ['Gad1', 'Gad2', 'Sst', 'Pvalb', 'Slc32a1'],
    'CHO_PEP': ['Resp18', 'Scg2', 'Hap1', 'Pnmal2', 'Ly6h'],
    'DE_MEN': ['Pcp4', 'Prkcd', 'Synpo2', 'Plekhg1', 'Ntng1'],
    'AC': ['Aldoc', 'Gja1', 'Clu', 'Ttyh1', 'Mt2'],
    'OLG': ['Mbp', 'Mal', 'Fth1', 'Aplp1', 'Plp1'],
    'VAS': ['Bsg', 'Myh9', 'Flt1', 'Vtn', 'B2m'],
    'CHOR_EPEN': ['Ttr', 'Enpp2', 'Rarres2', 'Cd24a', 'Prelp'],
    'PVM': ['Cyp2f2', 'Galm', 'H2-Aa', 'Cd74'],
    'MLG': ['Csf1r', 'Hexb', 'Ctss', 'C1qb', 'C1qa'],
    'OPC': ['Pdgfra', 'Cacng4', 'Cspg5', 'Kcnip3'],
}

In [None]:
# Dot plot of log fold changes for the curated level-2 markers (genes as rows)
sc.pl.rank_genes_groups_dotplot(
    rdata,
    key='rank_genes_groups',
    var_names=marker_genes_dict,
    values_to_plot='logfoldchanges',
    min_logfoldchange=1,
    vmin=-5,
    vmax=5,
    cmap='bwr',
    dendrogram=False,
    swap_axes=True,
    save='level_2',
)

In [None]:
# All replicate-2 cells (every protocol) of the combined object; subsetting returns a view of cdata.
is_rep2 = cdata.obs['replicate'] == 'rep2'
rep2_adata = cdata[is_rep2, :]
rep2_adata

In [None]:
# RIBOmap cells of rep2, excluding the 'Mix' cluster. 'Mix' is filtered on `level_2_code`, the same
# column used to drop it from rdata and to group the marker test below; `.copy()` gives a standalone
# AnnData so the uns/layer edits in the next cells do not act on a view of a view.
adata = rep2_adata[(rep2_adata.obs['protocol'] == 'RIBOmap') & (rep2_adata.obs['level_2_code'] != 'Mix'), :].copy()

In [None]:
# Reset the log base stored in uns['log1p'] (the entry scanpy's log1p writes) to None.
# NOTE(review): presumably a workaround so rank_genes_groups treats the data as natural-log — confirm.
adata.uns['log1p']['base'] = None

In [None]:
# Log-transform the raw counts into a separate 'log_raw' layer, then total-count normalize that layer.
# NOTE(review): log1p is applied *before* normalize_total (the usual order is normalize, then log);
# confirm this is intended.
adata.layers['log_raw'] = np.log1p(adata.layers['raw'])
sc.pp.normalize_total(adata, layer='log_raw')

# Wilcoxon test per level_2_code cluster of the rep2 RIBOmap cells, scoring every gene.
wilcoxon_kwargs = dict(method='wilcoxon', layer='log_raw', pts=True, use_raw=False)
sc.tl.rank_genes_groups(adata, 'level_2_code', n_genes=adata.shape[1], **wilcoxon_kwargs)

# Filter the ranked genes by fold change and by in-group / out-of-group expression fractions.
sc.tl.filter_rank_genes_groups(
    adata, min_in_group_fraction=0.2, max_out_group_fraction=0.8, min_fold_change=.1
)

In [None]:
# Matrix plot (not a dot plot) of log fold changes for the level-2 markers in rep2 RIBOmap cells,
# with the cluster groups shown in the marker-dict order.
level_2_panel_order = list(marker_genes_dict)
sc.pl.rank_genes_groups_matrixplot(
    adata,
    key='rank_genes_groups',
    var_names=marker_genes_dict,
    groupby='level_2_code',
    categories_order=level_2_panel_order,
    values_to_plot='logfoldchanges',
    min_logfoldchange=1,
    vmin=-5,
    vmax=5,
    cmap='bwr',
    dendrogram=False,
    swap_axes=False,
    save='rep2_ribo_level_2',
)

### level-3

In [None]:
# Wilcoxon marker test per level_3 cluster on the 'log_raw' layer, scoring every gene.
sc.tl.rank_genes_groups(
    rdata,
    'level_3',
    method='wilcoxon',
    layer='log_raw',
    pts=True,
    use_raw=False,
    n_genes=rdata.shape[1],
)

# Filter the ranked genes by fold change and by in-group / out-of-group expression fractions.
sc.tl.filter_rank_genes_groups(
    rdata, min_in_group_fraction=0.2, max_out_group_fraction=0.8, min_fold_change=.1
)

In [None]:
# Curated marker genes per level_3 cluster (all cell types, not only astrocytes).
# Key order is the group order used by the level-3 marker plots.
marker_genes_dict = {
    # Telencephalon glutamatergic neurons (hippocampal / cortical-amygdala)
    'TEGLU COA': ['Synpr', 'Nr2f2'],
    'TEGLU CA1': ['Tmsb4x', 'Ppp3r1'],
    'TEGLU CA2': ['Sv2b'],
    'TEGLU CA3': ['Chgb', 'Nell2'],
    # Telencephalon glutamatergic neurons (cortical layers and others)
    'TEGLU L2_3': ['Slc17a7', 'Cplx2'],
    'TEGLU L4_5': ['Rgs4', 'Dkk3'],
    'TEGLU L5': ['Snca'],
    'TEGLU L6': ['Sncb', 'Nr4a2'],
    'TEGLU L6a': ['Pcp4'],
    'TEGLU Mix': ['Prkcb', 'Diras2'],
    'TEGLU PIR': ['Lmo3'],
    'DGGRC': ['Rbfox1', 'Ppp3ca'],
    'MSN': ['Phactr1', 'Ppp1r1b', 'Rasd2'],
    # Inhibitory neurons
    'INH Pvalb1': ['Pvalb'],
    'INH Pvalb2': ['Gad2'],
    'INH Sst': ['Sst', 'Gad1', 'Npy'],
    # Cholinergic / peptidergic
    'DECHO': ['Gabbr1', 'Cadps2'],
    'PEP1': ['Resp18'],
    'PEP2': ['Gap43'],
    'PEP3': ['Scg2', 'Pnmal2', 'Dlk1'],
    'DEGLU1': ['Stmn1', 'Tubb5'],
    'DEGLU2': ['Synpo2', 'Plekhg1'],
    # Astrocytes
    'AC1': ['Mfge8'],
    'AC2': ['Apoe'],
    'AC3': ['Prnp'],
    'AC4': ['Gfap', 'Clu', 'Tspan7'],
    # Oligodendrocytes
    'OLG1': ['Mal', 'Bin1', 'Tubb4a'],
    'OLG2': ['Mbp'],
    # Vascular
    'Peri_VEC1': ['Bsg'],
    'Peri_VEC2': ['Epas1'],
    'VLMC': ['Ptgds', 'Gjb2'],
    'VSMC': ['Myl9', 'Myh11'],
    # Choroid plexus, ependymal, perivascular macrophages
    'CHOR': ['Ttr', 'Enpp2'],
    'EPEN': ['Rarres2', 'Tppp3'],
    'PVM1': ['Cyp2f2', 'Galm'],
    'PVM2': ['H2-Aa', 'Cd74'],
    # Microglia and OPCs
    'MLG': ['Hexb', 'Ctss'],
    'OPC': ['Pdgfra', 'Cacng4'],
}

In [None]:
# Dot plot of level_3 log fold changes for the curated markers, groups in marker-dict order;
# swap_axes=True puts the genes on the y-axis.
level_3_panel_order = list(marker_genes_dict)
sc.pl.rank_genes_groups_dotplot(
    rdata,
    key='rank_genes_groups',
    var_names=marker_genes_dict,
    categories_order=level_3_panel_order,
    values_to_plot='logfoldchanges',
    min_logfoldchange=1,
    vmin=-5,
    vmax=5,
    cmap='bwr',
    dendrogram=False,
    swap_axes=True,
    save='level_3',
)

In [None]:
# Matrix plot (not a dot plot) of the same level_3 log fold changes, groups in marker-dict order.
level_3_panel_order = list(marker_genes_dict)
sc.pl.rank_genes_groups_matrixplot(
    rdata,
    key='rank_genes_groups',
    var_names=marker_genes_dict,
    categories_order=level_3_panel_order,
    values_to_plot='logfoldchanges',
    min_logfoldchange=1,
    vmin=-5,
    vmax=5,
    cmap='bwr',
    dendrogram=False,
    swap_axes=True,
    save='level_3',
)

## Other Analysis

### Gfap/NeuN signal comparison

In [None]:
# Label every cell for the Neuron vs Astro comparison: neuronal level_2 codes -> 'Neuron',
# 'AC' -> 'Astro', everything else -> 'NA'.
neuron_codes = ['TEPN', 'INH', 'CHO_PEP', 'DE_MEN']
level_2_codes = cdata.obs['level_2_code']
comparison = pd.Series('NA', index=cdata.obs.index)
comparison.loc[level_2_codes.isin(neuron_codes)] = 'Neuron'
comparison.loc[level_2_codes == 'AC'] = 'Astro'
cdata.obs['comparison_label'] = comparison

In [None]:
current_sample = 'STARmap-rep2'
current_protein = 'NeuN'

# Cells of the chosen sample that carry a Neuron/Astro comparison label.
cell_obs = cdata.obs
keep_cells = (cell_obs['protocol-replicate'] == current_sample) & (cell_obs['comparison_label'] != 'NA')
current_df = cell_obs.loc[keep_cells, :].copy()
# Rebuild the categorical so only labels present in this subset remain as categories.
current_df['comparison_label'] = current_df['comparison_label'].astype(object).astype('category')
# Drop cells whose `<protein>_pixel` value is zero.
has_signal = current_df[f'{current_protein}_pixel'] != 0
current_df = current_df.loc[has_signal, :]

In [None]:
# box plot
# Box plot (outliers hidden) of the normalized protein signal per comparison group, annotated with
# a two-sided independent-samples t-test between Neuron and Astro.
y_col = f"{current_protein}_pixel_norm"

fig, ax = plt.subplots(figsize=(5, 5))
ax = sns.boxplot(x="comparison_label", y=y_col, data=current_df, showfliers=False, width=.4, ax=ax)

plt.xticks(rotation=45)

# `plot` must name the plot type actually drawn (a box plot), since statannotations uses it to
# position the annotation bars; it previously said 'violinplot'.
annot = Annotator(ax, [('Neuron', 'Astro')], plot='boxplot', data=current_df, x='comparison_label', y=y_col)
annot.configure(test='t-test_ind', text_format='star', loc='outside', verbose=2)
annot.apply_test(alternative='two-sided').annotate()

plt.savefig(os.path.join(fig_path, f'boxplot_{current_sample}_{current_protein}.pdf'))
plt.show()

### Oligodendrocyte & OPC diffusion map

In [None]:
# Standalone copy of the RIBOmap rep2 cells for the oligodendrocyte-lineage analysis.
ribo_rep2 = cdata.obs['protocol-replicate'] == 'RIBOmap-rep2'
pdata_ribo = cdata[ribo_rep2, :].copy()
pdata_ribo

#### RIBOmap

In [None]:
# Redo preprocessing
# Rebuild the expression layers of pdata_ribo from its raw counts:
#   raw -> normalize_total + log1p ('norm', also stored as .raw) -> scale ('scaled')
#   -> regress out total_counts ('corrected').
pdata_ribo.X = pdata_ribo.layers['raw'].copy()
# Clear all existing layers; 'raw' is re-created from X on the next line.
del pdata_ribo.layers

pdata_ribo.layers['raw'] = pdata_ribo.X.copy()

# # Normalization scaling
sc.pp.normalize_total(pdata_ribo)
sc.pp.log1p(pdata_ribo)

pdata_ribo.layers['norm'] = pdata_ribo.X.copy()
# Snapshot the log-normalized matrix as .raw (taken before scaling and regression).
pdata_ribo.raw = pdata_ribo

# # sc.pp.highly_variable_genes(sdata, min_mean=0.01, max_mean=3, min_disp=0.5)
# # sc.pl.highly_variable_genes(sdata)

# # Scale data to unit variance and zero mean
sc.pp.scale(pdata_ribo)
pdata_ribo.layers['scaled'] = pdata_ribo.X.copy()

# Batch correction
# NOTE(review): this regresses out the per-cell 'total_counts' covariate rather than a batch label,
# and it runs *after* scaling, so 'corrected' is generally no longer unit-variance per gene; the usual
# scanpy order is regress_out before scale — confirm the ordering is intended.
sc.pp.regress_out(pdata_ribo, 'total_counts')
pdata_ribo.layers['corrected'] = pdata_ribo.X.copy()

In [None]:
# Subset and Run PCA # 1
# Keep only oligodendrocytes and OPCs. `.copy()` makes the subset a standalone AnnData, so
# assigning .X below does not implicitly actualize a view.
oo_level_2 = ['Oligodendrocyte', 'Oligodendrocytes precursor cell']
pdata_ribo = pdata_ribo[pdata_ribo.obs['level_2'].isin(oo_level_2), :].copy()

# PCA on the total_counts-regressed matrix ('corrected' layer), using all genes.
pdata_ribo.X = pdata_ribo.layers['corrected'].copy()
sc.tl.pca(pdata_ribo, svd_solver='full', use_highly_variable=False, zero_center=True)

In [None]:
# Neighborhood graph on the first n_pcs principal components, then a diffusion map with the
# same number of components.
n_neighbors, n_pcs = 50, 5
sc.pp.neighbors(pdata_ribo, n_pcs=n_pcs, n_neighbors=n_neighbors)
sc.tl.diffmap(pdata_ribo, n_comps=n_pcs)

In [None]:
# Preview diffusion components 2 and 3 colored by level_2, level_3 and Plp1.
# `components` takes the comma-separated string '2, 3' (the original parentheses around it
# did not form a tuple).
for color_key in ('level_2', 'level_3', 'Plp1'):
    sc.pl.diffmap(pdata_ribo, color=color_key, components=['2, 3'])

In [None]:
# Export the diffusion-map coordinates (one row per cell) as CSV into the figure folder.
diffmap_csv = f'{fig_path}/embedding_ribomap_oo_diffmap.csv'
np.savetxt(diffmap_csv, pdata_ribo.obsm['X_diffmap'], delimiter=",")

In [None]:
# Persist the oligodendrocyte/OPC subset, including its PCA, neighbors and diffusion map.
oo_h5ad_file = os.path.join(out_path, "2023-05-04-RIBOmap-oo-diffmap.h5ad")
pdata_ribo.write_h5ad(oo_h5ad_file)

In [None]:
# Fixed colors for the three oligodendrocyte-lineage level_3 types. level_3 is reordered to
# OPC, OLG1, OLG2, presumably so the palette pairs with the categories in that order.
oo_colors = ['#00A651', '#FBB040', '#92278F']
sub_pl = sns.color_palette(oo_colors)
sub_cmap = ListedColormap(sub_pl.as_hex())
level_3_col = pdata_ribo.obs['level_3']
pdata_ribo.obs['level_3'] = level_3_col.cat.reorder_categories(['OPC', 'OLG1', 'OLG2'])
sns.palplot(sub_pl)

In [None]:
# sc.settings.figdir = fig_path
# Saved figures use 300 dpi; vector_friendly=False keeps scatter points as vector elements
# (not rasterized) in PDF/SVG exports.
sc.set_figure_params(dpi_save=300, vector_friendly=False)

In [None]:
# Save diffusion-map plots (components 2 and 3, not UMAP) colored by level_3 with the custom
# palette: one PDF with a legend, plus legend-free PDF and PNG versions.
oo_diffmap_kwargs = dict(color='level_3', components=['2, 3'], frameon=False, title='', palette=sub_pl)

sc.pl.diffmap(pdata_ribo, legend_loc='right margin', legend_fontsize=12, legend_fontoutline=2,
              save='_ribo_OO.pdf', **oo_diffmap_kwargs)

for no_legend_file in ('_ribo_OO_no_legend.pdf', '_ribo_OO_no_legend.png'):
    sc.pl.diffmap(pdata_ribo, legend_loc=None, save=no_legend_file, **oo_diffmap_kwargs)

In [None]:
# gene on diffmap
# Save one diffusion-map plot (components 2 and 3) per gene, colored by its expression.
genes = ['Pdgfra', 'Vcan', 'Bmp4', 'Klk6', 'Pcdh15', 'Mobp', 'Plp1', 'Mbp']

for gene in genes:
    fig, ax = plt.subplots(figsize=(7, 5))
    # Gene expression is a continuous color key, so the colormap goes through `color_map`;
    # `palette` only applies to categorical keys and was silently ignored here.
    sc.pl.diffmap(pdata_ribo, color=gene, legend_loc=None, frameon=False, ax=ax, components=['2, 3'],
                  title='', color_map='viridis', save=f'_{gene}.png')

In [None]:
# Inspect the diffusion-map embedding shape (cells x components).
pdata_ribo.obsm['X_diffmap'].shape

### Oligo type composition (region)

In [None]:
# rep2 cells with an annotated region (drop 'NA' and 'other'). region is re-categorized so its
# categories are only the regions present in this subset.
in_rep2 = cdata.obs['replicate'] == 'rep2'
has_region = ~cdata.obs['region'].isin(['NA', 'other'])
current_df = cdata.obs.loc[in_rep2 & has_region, :].copy()
current_df['region'] = current_df['region'].astype(object).astype('category')
current_df

In [None]:
# Keep RIBOmap oligodendrocyte-lineage cells (OLG1, OLG2, OPC); level_3 is re-categorized so its
# categories are only the types present.
is_ribo_oo = (current_df['protocol'] == 'RIBOmap') & current_df['level_3'].isin(['OLG1', 'OLG2', 'OPC'])
current_df = current_df.loc[is_ribo_oo, :]
current_df['level_3'] = current_df['level_3'].astype(object).astype('category')

In [None]:
# Cell counts per (level_3 type, region): rows are level_3 types, columns are regions.
# Grouping on both categorical columns with observed=False keeps every category combination
# (zero counts included), matching the old value_counts/pivot table. That version pivoted on an
# implicit 'level_1' column that pandas >= 2.0 no longer creates.
count_df = current_df.groupby(['level_3', 'region'], observed=False).size().unstack('region')
# For per-type fractions instead of counts: count_df.div(count_df.sum(axis=1), axis=0)
count_df

In [None]:
# Save the per-type / per-region cell-count table next to the figures.
count_df.to_csv(os.path.join(fig_path, 'cluster_freq_ribomap_oo.csv'))

### Sankey diagram

In [None]:
def contingency(a, b, unique_a, unique_b):
    """Populate a contingency matrix of co-occurring labels. Rows and columns are not normalized.

    Labels in ``a``/``b`` that are not listed in ``unique_a``/``unique_b`` are ignored; listed
    labels that never occur get zero counts.

    Args:
        a (array-like): labels, one per observation (np.array, pd.Categorical, list, ...).
        b (array-like): labels, one per observation; must have the same shape as ``a``.
        unique_a (array-like): row labels. Can have more entries than np.unique(a).
        unique_b (array-like): column labels. Can have more entries than np.unique(b).

    Returns:
        pd.DataFrame: float counts with index ``unique_a`` and columns ``unique_b``; entry (i, j)
        is the number of observations with ``a == unique_a[i]`` and ``b == unique_b[j]``.

    Raises:
        ValueError: if ``a`` and ``b`` do not have the same shape.
    """
    # Object dtype makes `==` elementwise for any label type (str, int, Categorical values).
    a = np.asarray(a, dtype=object)
    b = np.asarray(b, dtype=object)
    if a.shape != b.shape:
        raise ValueError(f'a and b must have the same shape, got {a.shape} and {b.shape}')
    a, b = a.ravel(), b.ravel()

    # One 0/1 indicator row per listed label: ind_a[i, k] == 1 iff a[k] == unique_a[i].
    ind_a = np.zeros((len(unique_a), a.size))
    for i, la in enumerate(unique_a):
        ind_a[i] = a == la
    ind_b = np.zeros((len(unique_b), b.size))
    for j, lb in enumerate(unique_b):
        ind_b[j] = b == lb

    # C[i, j] = sum_k ind_a[i, k] * ind_b[j, k]. A single matrix product replaces the
    # len(unique_a) * len(unique_b) full passes over the data that a nested loop needs.
    C = ind_a @ ind_b.T
    return pd.DataFrame(C, index=unique_a, columns=unique_b)

In [None]:
def generate_nodes(df, key1, key2):
    """Turn a flow table into the node and link lists of a two-column Sankey diagram.

    Args:
        df (pd.DataFrame): rows are source labels, columns are target labels, and each cell is the
            flow from that source to that target (e.g. the output of ``contingency``).
        key1 (str): column name for the source labels in the returned link table.
        key2 (str): column name for the target labels in the returned link table.

    Returns:
        tuple:
            df_river (pd.DataFrame): one row per non-zero (source, target) pair with columns
                ``key1``, ``key2`` and ``'value'``, ordered target-major (every source for the
                first target, then the second target, ...).
            all_nodes (list): unique source labels (first-appearance order) followed by unique
                target labels (first-appearance order).
            source_indices (list[int]): position in ``all_nodes`` of each link's source.
            target_indices (list[int]): position in ``all_nodes`` of each link's target.
    """
    pairs = [(src, tgt) for tgt in df.columns.to_list() for src in df.index.to_list()]
    df_river = pd.DataFrame({
        key1: [src for src, _ in pairs],
        key2: [tgt for _, tgt in pairs],
        'value': [df.loc[src, tgt] for src, tgt in pairs],
    })
    df_river = df_river.loc[df_river['value'] != 0, :]

    sources = df_river[key1].unique().tolist()
    targets = df_river[key2].unique().tolist()
    all_nodes = sources + targets

    # Targets are numbered after all sources, so a label that occurs at both levels maps to its own
    # target node instead of the first matching source (which `all_nodes.index` would return).
    # The dict lookups also avoid a linear list scan per link.
    source_pos = {label: i for i, label in enumerate(sources)}
    target_pos = {label: len(sources) + i for i, label in enumerate(targets)}
    source_indices = [source_pos[label] for label in df_river[key1]]
    target_indices = [target_pos[label] for label in df_river[key2]]

    return df_river, all_nodes, source_indices, target_indices

In [None]:
# Cell counts for every (level_1, level_2) category pair in rdata.
rdata_obs = rdata.obs
level_12_df = contingency(
    rdata_obs['level_1'].values,
    rdata_obs['level_2'].values,
    rdata_obs['level_1'].cat.categories.values,
    rdata_obs['level_2'].cat.categories.values,
)
level_12_df

In [None]:
df_river, all_nodes, source_indices, target_indices = generate_nodes(level_12_df, 'level_1', 'level_2')

# Node colors are taken by position from the level_1 then level_2 color lists.
# NOTE(review): assumes `all_nodes` enumerates the level_1 then level_2 categories in the same order
# as those color lists — confirm.
level_12_colors = list(rdata.uns['level_1_color']) + list(rdata.uns['level_2_color'])
node_colors_mappings = {node: level_12_colors[i] for i, node in enumerate(all_nodes)}
node_colors = [node_colors_mappings[node] for node in all_nodes]
# Plotly expects one color per link, not per node; use uniform light grey.
edge_colors = ['#ebebeb'] * len(df_river)

import plotly.graph_objects as go
save_river = False
fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=20,
            thickness=20,
            line=dict(color="black", width=1.0),
            label=all_nodes,
            color=node_colors,
        ),

        link=dict(
            source=source_indices,
            target=target_indices,
            value=df_river['value'],
            color=edge_colors,
        ))])

fig.update_layout(title_text="level1 to level2",
                  height=600,
                  font_size=10)
if save_river:
    fig.write_image(save_river)
fig.show()

In [None]:
# Cell counts for every (level_2, level_3) category pair in rdata.
rdata_obs = rdata.obs
level_23_df = contingency(
    rdata_obs['level_2'].values,
    rdata_obs['level_3'].values,
    rdata_obs['level_2'].cat.categories.values,
    rdata_obs['level_3'].cat.categories.values,
)
level_23_df

In [None]:
df_river, all_nodes, source_indices, target_indices = generate_nodes(level_23_df, 'level_2', 'level_3')

# Manual node positions (12 level_2 nodes in column 0, 39 level_3 nodes in column 1); only used
# if the commented-out `x`/`y` node arguments below are re-enabled.
node_x = [0] * 12 + [1] * 39
node_y = [i for i in range(12)] + [i*0.05 for i in range(39)]

# Node colors are taken by position from the level_2 then level_3 color lists.
# NOTE(review): assumes `all_nodes` enumerates the level_2 then level_3 categories in the same order
# as those color lists — confirm.
level_23_colors = list(rdata.uns['level_2_color']) + list(rdata.uns['level_3_color'])
node_colors_mappings = {node: level_23_colors[i] for i, node in enumerate(all_nodes)}
node_colors = [node_colors_mappings[node] for node in all_nodes]
# Plotly expects one color per link, not per node; use uniform light grey.
edge_colors = ['#ebebeb'] * len(df_river)

import plotly.graph_objects as go
save_river = False
fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=10,
            thickness=20,
            line=dict(color="black", width=1.0),
            label=all_nodes,
            color=node_colors,
            # x=node_x, 
            # y=node_y,
        ),

        link=dict(
            source=source_indices,
            target=target_indices,
            value=df_river['value'],
            color=edge_colors,
        ))])

fig.update_layout(
                  height=3000,
                  font_size=20)
if save_river:
    fig.write_image(save_river)
fig.show()

#### Combined Sankey (level 1 → level 2 → level 3)

In [None]:
# get tables
# Cell-count tables between consecutive annotation levels, with rows/columns in the order stored
# in rdata.uns['<level>_order'].
def _level_table(upper, lower):
    return contingency(
        rdata.obs[upper].values,
        rdata.obs[lower].values,
        list(rdata.uns[f'{upper}_order']),
        list(rdata.uns[f'{lower}_order']),
    )


level_12_df = _level_table('level_1', 'level_2')
level_23_df = _level_table('level_2', 'level_3')

In [None]:
# get nodes
# Per-level link tables; df_river_12 / df_river_23 supply the links and values of the combined diagram.
df_river_12, all_nodes_12, source_indices_12, target_indices_12 = generate_nodes(level_12_df, 'level_1', 'level_2')
df_river_23, all_nodes_23, source_indices_23, target_indices_23 = generate_nodes(level_23_df, 'level_2', 'level_3')

level_1_nodes = list(rdata.uns['level_1_order'])
level_2_nodes = list(rdata.uns['level_2_order'])
level_3_nodes = list(rdata.uns['level_3_order'])
all_nodes = level_1_nodes + level_2_nodes + level_3_nodes

# Position of each label in the combined node list, computed per level. This replaces the
# hard-coded "+3" offset, which assumed exactly 3 level_1 nodes and that generate_nodes' node order
# matches the uns order. Per-level maps also keep labels that repeat across levels distinct.
level_1_pos = {label: i for i, label in enumerate(level_1_nodes)}
level_2_pos = {label: len(level_1_nodes) + i for i, label in enumerate(level_2_nodes)}
level_3_pos = {label: len(level_1_nodes) + len(level_2_nodes) + i for i, label in enumerate(level_3_nodes)}

source_indices = ([level_1_pos[label] for label in df_river_12['level_1']]
                  + [level_2_pos[label] for label in df_river_23['level_2']])
target_indices = ([level_2_pos[label] for label in df_river_12['level_2']]
                  + [level_3_pos[label] for label in df_river_23['level_3']])

In [None]:
# get colors and values
all_colors = list(rdata.uns['level_1_color']) + list(rdata.uns['level_2_color']) + list(rdata.uns['level_3_color'])
# Pair colors with nodes by position rather than by label, so a label that occurs at two levels
# keeps a distinct color for each occurrence. Raises IndexError if there are fewer colors than nodes.
node_colors = [all_colors[i] for i in range(len(all_nodes))]
node_colors_mappings = dict(zip(all_nodes, node_colors))
# One uniform light-grey color per link (level 1->2 links followed by level 2->3 links).
edge_colors = ['#ebebeb'] * (len(df_river_12) + len(df_river_23))

link_values = df_river_12['value'].to_list() + df_river_23['value'].to_list()

In [None]:
import plotly.graph_objects as go
save_river = False

# Combined level_1 -> level_2 -> level_3 Sankey diagram.
node_spec = dict(
    pad=30,
    thickness=20,
    line=dict(color="black", width=1.0),
    label=all_nodes,
    color=node_colors,
)
link_spec = dict(
    source=source_indices,
    target=target_indices,
    value=link_values,
    color=edge_colors,
)
fig = go.Figure(data=[go.Sankey(node=node_spec, link=link_spec)])
fig.update_layout(width=1500, height=6000, font_size=20)

# The PDF is always written; `save_river` is not consulted here.
fig.write_image(os.path.join(fig_path, 'sankey-test.pdf'))
fig.show()