# Visualization

2021-04-15

In [None]:
# Import Packages

%load_ext autoreload
%autoreload 2

import os
import warnings 
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
from skimage.filters import threshold_otsu, gaussian
from skimage.morphology import remove_small_objects
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from statannot import add_stat_annotation
from anndata import AnnData

# Customized packages
from starmap.utilities import *
from starmap.sequencing import *
from starmap.obj import STARMapDataset, load_data
# import starmap.analyze as anz
# import starmap.viz as viz
import starmap.sc_util as su

# test()

In [None]:
# Get functions 

import colorsys
from random import shuffle

def intervals(parts, start_point, end_point):
    """Return the midpoints of `parts` equal-width sub-intervals of a range.

    The range [start_point, end_point] is cut into `parts` pieces of equal
    length and the centre of each piece is returned, in increasing piece order.
    Raises ZeroDivisionError when `parts` is 0.
    """
    step = (end_point - start_point) / parts
    midpoints = []
    for k in range(parts):
        left_edge = k * step
        right_edge = (k + 1) * step
        midpoints.append(((left_edge + right_edge) / 2) + start_point)
    return midpoints

def change_width(ax, new_value):
    """Give every patch on `ax` the width `new_value` and recenter it.

    Each patch's x position is shifted by half of the width difference so the
    patch stays centred where it was before the resize.
    """
    for bar in ax.patches:
        # Half of the width that is being removed (negative when the bar grows).
        half_delta = (bar.get_width() - new_value) * .5
        bar.set_width(new_value)
        bar.set_x(bar.get_x() + half_delta)

## Input

In [None]:
# Root directory of the analysed dataset. The per-sample segmentation and mask
# images loaded further down are resolved relative to this folder.
base_path = 'Z:/Data/Analyzed/2021-03-20-mAD-64-genes-reads-assignment/'

In [None]:
# Load adata
# The file lives under `base_path`; deriving the path from it keeps this cell in
# sync with the image-loading cell when the dataset is moved.
adata_path = os.path.join(base_path, 'output', '2021-04-22-starmap-mAD-64-genes-scaled.h5ad')
adata = sc.read_h5ad(adata_path)
adata

In [None]:
# Load data and store the information to uns (use scaled version to save computational time)
# For every sample, read the segmentation label image plus the plaque, tau and
# Gfap images from `<base_path>/<sample>/scaled/` and keep them in
# adata.uns['<sample>_morph'].
# NOTE(review): `tifffile` is not imported explicitly in this notebook; presumably it
# is made available by one of the `from starmap... import *` lines — confirm.
for sample in sorted(adata.obs['sample'].unique()):
    print(sample)
    
    # Load segmentation
    current_seg_path = os.path.join(base_path, sample, 'scaled/labeled_cells.tif')
    current_img = tifffile.imread(current_seg_path)
    
    # Load plaque image
    current_plaque_path = os.path.join(base_path, sample, 'scaled/plaque.tif')
    current_plaque = tifffile.imread(current_plaque_path)
    # uniform manual threshold: smooth with a Gaussian (sigma=2), then binarise at 30.
    # Builtin `float` replaces the `np.float` alias, which was removed in NumPy 1.24;
    # both denote the same float64 dtype.
    current_plaque = gaussian(current_plaque.astype(float), 2) > 30
    # current_plaque = remove_small_objects(current_plaque, min_size=64, connectivity=4)
    
    # Load tau image
    current_tau_path = os.path.join(base_path, sample, 'scaled/tau_mask.tif')
    current_tau = tifffile.imread(current_tau_path)

    # Load Gfap image
    current_gfap_path = os.path.join(base_path, sample, 'scaled/Gfap_mask.tif')
    current_gfap = tifffile.imread(current_gfap_path)

    
    # Store the images to adata object
    current_key = f"{sample}_morph"
    adata.uns[current_key] = {}
    adata.uns[current_key]['label_img'] = current_img
    adata.uns[current_key]['plaque'] = current_plaque
    adata.uns[current_key]['tau'] = current_tau
    adata.uns[current_key]['Gfap'] = current_gfap

In [None]:
# Restore convex hull and top-level info
# For each sample, derive the hull/coordinate/centroid data from the stored label
# image and record which cells belong to the sample and their top-level codes.
for sample in sorted(adata.obs['sample'].unique()):
    print(sample)

    current_key = f"{sample}_morph"
    current_morph = adata.uns[current_key]

    # su.get_qhulls returns three values, stored under 'qhulls', 'coords', 'centroids'
    current_qhulls, current_coords, current_centroids = su.get_qhulls(current_morph['label_img'])
    current_morph['qhulls'] = current_qhulls
    current_morph['coords'] = current_coords
    current_morph['centroids'] = current_centroids

    current_index = adata.obs['sample'] == sample
    sample_obs = adata.obs.loc[current_index]
    current_morph['good_cells'] = sample_obs['orig_index'].astype(int).values
    current_morph['colors'] = sample_obs['top_level'].cat.codes.values

    # add tau positive info
    # tau_threshold = 30
    # adata.uns[current_key]['tau_index'] = (adata.obs['tau'] > tau_threshold) & (current_index)


In [None]:
# Figure parameter: resolution (dots per inch) matplotlib uses when saving figures
plt.rcParams['savefig.dpi'] = 100

## Top-level

In [None]:
%%time
# Dimensionality reduction pipeline: PCA -> neighborhood graph -> UMAP.
# The steps are order-dependent: the graph is built with n_pcs principal
# components and UMAP is computed from that graph.

# Run PCA
sc.tl.pca(adata, svd_solver='arpack')

# Plot explained variance 
sc.pl.pca_variance_ratio(adata, log=False)

# Computing the neighborhood graph
# n_neighbors and n_pcs are also written to ./figures/log.txt by the UMAP cell.
n_neighbors = 50
n_pcs = 30
# NOTE(review): cluster_resolution is not used in the visible cells — confirm it is still needed.
cluster_resolution = 1

sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs)

# Run UMAP
sc.tl.umap(adata)

### UMAP

In [None]:
# Plot one UMAP per metadata column (saved with the suffix `_<column>`)
for meta_key in ('sample', 'batch', 'group', 'time'):
    sc.pl.umap(adata, color=meta_key, save=f'_{meta_key}')

# Plot sample-wise UMAP: highlight one sample at a time
sample_names = sorted(adata.obs['sample'].unique())
for sample in sample_names:
    print(sample)
    sc.pl.umap(adata, color='sample', groups=sample, save=f'_sample_{sample}')

# Plot group-wise UMAP: highlight one group at a time
group_names = sorted(adata.obs['group'].unique())
for group in group_names:
    print(group)
    sc.pl.umap(adata, color='group', groups=group, save=f'_group_{group}')

# Save log: record the graph parameters next to the figures
log_text = f"Number of neighbor: {n_neighbors}\nNumber of PC: {n_pcs}"
with open('./figures/log.txt', 'w') as f:
    f.write(log_text)

In [None]:
# Check color legend
# Build the top-level palette by looking up each ordered cell type in the RGB
# dictionary, then draw it as a labelled strip and save it.
legend_labels = adata.uns['top_level_order_64']
legend_colors = adata.uns['top_rgb_dict_64']

top_cpl = sns.color_palette([legend_colors[label] for label in legend_labels])
top_cmap = ListedColormap(top_cpl.as_hex())

sns.palplot(top_cpl, size=3)
plt.xticks(range(len(legend_labels)), legend_labels, size=10, rotation=45)
plt.tight_layout()
plt.savefig('./figures/color_legend_top.png')
plt.show()

In [None]:
# Plot UMAP with cluster labels w/ new color (legend at the side)
sc.pl.umap(adata, color='top_level', palette=top_cpl, frameon=False, save='_legend_side')

# Plot UMAP with cluster labels w/ new color (labels drawn on the data)
sc.pl.umap(adata, color='top_level', palette=top_cpl, frameon=False,
           legend_loc='on data', legend_fontsize=8, legend_fontoutline=2, save=True)

# Plot sample wise, then group wise, UMAP with top-level labels:
# all cells are drawn first as an uncoloured background, then the cells of the
# current sample/group are overlaid in colour on the same axes.
point_size = 120000 / adata.n_obs
for obs_key in ('sample', 'group'):
    for level in sorted(adata.obs[obs_key].unique()):
        print(level)
        ax = sc.pl.umap(adata, show=False, size=point_size)
        sc.pl.umap(adata[adata.obs[obs_key] == level], color='top_level', frameon=True,
                   ax=ax, size=point_size, palette=top_cpl, save=f'_{level}')

### Marker related 

In [None]:
# Find gene markers for each cluster
# Wilcoxon rank-sum test of each top-level cell type; the plots below read the
# ranking this call stores on `adata`.
sc.tl.rank_genes_groups(adata, 'top_level', method='wilcoxon')

# # Plot logFC heatmap
# sc.pl.rank_genes_groups_heatmap(adata, n_genes=5, groupby='top_level', min_logfoldchange=1, use_raw=False, swap_axes=True, 
#                                 vmin=-5, vmax=5, cmap='bwr', show_gene_labels=True, values_to_plot='logfoldchanges',
#                                 dendrogram=False, figsize=(30, 15), save='_logFC')

# Plot z-score heatmap
# Top 5 genes per group, gene labels shown, colour scale clipped to [-5, 5].
sc.pl.rank_genes_groups_heatmap(adata, n_genes=5, groupby='top_level', min_logfoldchange=1, use_raw=False, swap_axes=True, 
                                vmin=-5, vmax=5, cmap='bwr', show_gene_labels=True,
                                dendrogram=False, figsize=(30, 15), save='_zscore')

# Plot z-score heatmap big
# Same as above with 15 genes per group and gene labels hidden.
sc.pl.rank_genes_groups_heatmap(adata, n_genes=15, groupby='top_level', min_logfoldchange=1, use_raw=False, swap_axes=True, 
                                vmin=-5, vmax=5, cmap='bwr', show_gene_labels=False,
                                dendrogram=False, figsize=(30, 15), save='_zscore_big')

# Plot logFC dotplot
# NOTE(review): unlike the heatmap suffixes, this `save` value has no leading underscore — confirm the resulting file name is intended.
sc.pl.rank_genes_groups_dotplot(adata, n_genes=5, groupby='top_level', values_to_plot='logfoldchanges', min_logfoldchange=1, 
                                vmax=5, vmin=-5, cmap='bwr', save='logFC')

# Plot expression violin plot
sc.pl.rank_genes_groups_stacked_violin(adata, n_genes=5, groupby='top_level', min_logfoldchange=1, 
                                       cmap='viridis_r', save='top')

# Plot expression violin plot
# Same plot with the dendrogram explicitly disabled.
sc.pl.rank_genes_groups_stacked_violin(adata, n_genes=5, groupby='top_level', min_logfoldchange=1, 
                                       cmap='viridis_r', dendrogram=False, save='top_noden')


# # Print markers 
# markers = []
# temp = pd.DataFrame(adata.uns['rank_genes_groups']['names']).head(5)
# for i in range(temp.shape[1]):
#     curr_col = temp.iloc[:, i].to_list()
#     markers = markers + curr_col
#     print(curr_col)
    
# print(markers)
# plt.figure(figsize=(20,10))
# su.plot_heatmap_with_labels(adata, markers, 'leiden', use_labels=top_level_order,
#                             cmap=cluster_cmap, show_axis=True, font_size=10)
# plt.savefig('./figures/heatmap_top_v2.pdf')

### Composition bar plot

In [None]:
# Composition Barplot
# One row per sample; each row shows the number of cells of every top-level
# cell type, with the count printed on top of its bar.
sample_categories = adata.obs['sample'].cat.categories
n_cat = sample_categories.shape[0]
# NOTE(review): bar order comes from uns['top_level_order'] while `top_cpl` is built
# from uns['top_level_order_64'] in the colour-legend cell — confirm both list the
# same cell types in the same order, otherwise bar colours will not match the legend.
bar_order = list(adata.uns['top_level_order'])

# squeeze=False keeps the axes indexable even when there is a single sample
fig, ax = plt.subplots(n_cat, 1, figsize=(10,10), squeeze=False)
ax = ax[:, 0]
fig.tight_layout()

for i, sample in enumerate(sample_categories):

    curr_cells = (adata.obs['sample'] == sample)
    temp = adata[curr_cells, :]

    # Build the count table from the Series' own index/values: the column names
    # produced by pd.DataFrame(value_counts()) differ between pandas 1.x and 2.x.
    type_counts = temp.obs['top_level'].value_counts()
    type_counts = type_counts[type_counts != 0]
    cell_dist = pd.DataFrame({'top_level': type_counts.index.astype(object),
                              'counts': type_counts.to_numpy()})

    # Fail loudly on cell types the ordering does not know about ...
    unknown_types = set(cell_dist['top_level']) - set(bar_order)
    if unknown_types:
        raise ValueError("top_level values missing from adata.uns['top_level_order']: "
                         f"{sorted(map(str, unknown_types))}")
    # ... but tolerate types absent from this sample: keeping every ordered
    # category fixes each type's bar position (and palette colour) across samples.
    cell_dist['top_level'] = pd.Categorical(cell_dist['top_level'], categories=bar_order)

    g = sns.barplot(x='top_level', y='counts', data=cell_dist, palette=top_cpl, ax=ax[i])
    # Annotate only the bars that exist; x position is the category's slot in bar_order
    for cell_type, count in zip(cell_dist['top_level'], cell_dist['counts']):
        g.text(bar_order.index(cell_type), count, count, color='black', ha="center")

    for spine in ax[i].spines.values():
        spine.set_visible(False)

    ax[i].set_ylabel(sample, rotation=0, labelpad=50)
    # ax[i].set(ylim=(0, 1000))
    # Only the bottom row shows the x axis and its tick labels
    if i == n_cat-1:
        ax[i].get_xaxis().set_visible(True)
        ax[i].tick_params(top=False, bottom=True, left=False, right=False,
                          labeltop=False, labelleft=False, labelright=False, labelbottom=True)
    else:
        ax[i].get_xaxis().set_visible(False)
        ax[i].tick_params(top=False, bottom=False, left=False, right=False,
                          labeltop=False, labelleft=False, labelright=False, labelbottom=False)

fig.suptitle('Top-level cell type count', y=1.03)
plt.savefig('./figures/top_level_count_barplot.pdf', bbox_inches='tight')
plt.show()

### Spatial cell type map

In [None]:
# Spatial cell type map
# Draw every sample's cell polygons coloured by top-level type, with the plaque
# and tau overlays enabled; figures are saved instead of shown.
map_kwargs = dict(show_plaque=True, show_tau=True, linewidth=0.5,
                  figscale=3, width=10, height=10, save=True, show=False)
for sample in sorted(adata.obs['sample'].unique()):
    print(sample)
    su.plot_poly_cells_cluster_by_sample(adata, sample, top_cmap, **map_kwargs)

In [None]:
# Spatial cell type map
# Same map as the previous cell, with save_as_real_size=True passed to the plotter.
# NOTE(review): both cells use save=True — confirm the outputs do not overwrite each other.
real_size_kwargs = dict(show_plaque=True, show_tau=True, save_as_real_size=True,
                        linewidth=0.5, figscale=3, width=10, height=10,
                        save=True, show=False)
for sample in sorted(adata.obs['sample'].unique()):
    print(sample)
    su.plot_poly_cells_cluster_by_sample(adata, sample, top_cmap, **real_size_kwargs)

## Others

### Spatial map with Gfap 

In [None]:
# Single-colour palette/colormap for the Gfap spatial maps below: with only one
# colour in the colormap, the cell polygons are not distinguished by type.
# Note that `cmap` is rebound to viridis by the gene-expression cell further down.
cpl = sns.color_palette(['#0045db'], as_cmap=True)
cmap = ListedColormap(cpl)
sns.palplot(cpl)

In [None]:
# Spatial cell type 
# Draw every sample with the single-colour `cmap` plus plaque, tau and Gfap overlays.
# Uses `adata` directly: `sdata` is only assigned in the later gene-expression
# cell, so referencing it here fails on a fresh top-to-bottom run.
# NOTE(review): confirm that no subset of cells was intended for this figure.
for sample in sorted(adata.obs['sample'].unique()):
    print(sample)
    
    current_key = f"{sample}_morph"
    current_index = adata.obs['sample'] == sample
    # change to new color 
    adata.uns[current_key]['colors'] = adata.obs.loc[current_index, 'top_level'].cat.codes.values
    adata.uns[current_key]['good_cells'] = adata.obs.loc[current_index, 'orig_index'].astype(int).values
    
    su.plot_poly_cells_cluster_by_sample(adata, sample, cmap, show_plaque=True, show_tau=True,
                                         save_as_real_size=True, linewidth=0.5, show_gfap=True,
                                        figscale=3, width=10, height=10, save='Gfap', show=False)

In [None]:
# Spatial cell type (only plot 9919)
# Gfap overlay only (plaque and tau disabled), thinner outlines, written to
# `<base_path>/figures`.
# Uses `adata` directly: `sdata` is only assigned in the later gene-expression
# cell, so referencing it here fails on a fresh top-to-bottom run.
# NOTE(review): confirm that no subset of cells was intended for this figure.
for sample in ['AD_mouse9919']:
    print(sample)
    
    current_key = f"{sample}_morph"
    current_index = adata.obs['sample'] == sample
    # change to new color 
    adata.uns[current_key]['colors'] = adata.obs.loc[current_index, 'top_level'].cat.codes.values
    adata.uns[current_key]['good_cells'] = adata.obs.loc[current_index, 'orig_index'].astype(int).values
    
    su.plot_poly_cells_cluster_by_sample(adata, sample, cmap, show_plaque=False, show_tau=False,
                                         save_as_real_size=True, linewidth=0.1, show_gfap=True,
                                        figscale=3, width=10, height=10, 
                                         save='only_Gfap_01', show=False, output_dir=os.path.join(base_path, 'figures'))

### Spatial map of gene expression

In [None]:
# Spatial map of one gene's expression, drawn per sample.
# Gene to plot; also used to build the `save` suffix below.
curr_gene = 'CPLX1'
# Optional restriction to a single top-level cell type (None = all cells).
# subset_type = 'Oligo'
subset_type = None

# Rebinds `cmap` (previously the single-colour Gfap colormap) to viridis.
cmap = sns.color_palette('viridis', as_cmap=True)
for sample in sorted(adata.obs['sample'].unique()):
    print(sample)
    current_key = f"{sample}_morph"
    
    if subset_type is not None:
        current_index = adata.obs['top_level'] == subset_type
        # NOTE(review): this is an AnnData subset; the uns write below may act on a
        # view/copy rather than on `adata` — confirm the intended behaviour.
        sdata = adata[current_index, :]
    else:
        # No subsetting: `sdata` is the same object as `adata`.
        sdata = adata

    # Restrict the plotted cells to those of the current sample (within the subset).
    current_index = sdata.obs['sample'] == sample
    sdata.uns[current_key]['good_cells'] = sdata.obs.loc[current_index, 'orig_index'].astype(int).values
        
    su.plot_poly_cells_expr_by_sample(sdata, sample, curr_gene, cmap, use_raw=False,
                                      figscale=30, width=10, height=10,
                                      show_plaque=True, show_tau=True, show_tau_cells=False,
                                      show=False, save=f'{curr_gene}_z')