# Spatial visualization

In [None]:
# Import Packages

%load_ext autoreload
%autoreload 2

import os
import warnings 
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.filters import threshold_otsu, gaussian
from skimage.morphology import remove_small_objects
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from anndata import AnnData, concat
from tqdm.notebook import tqdm

# Customized packages
from starmap.utilities import *
from starmap.sequencing import *
# from starmap.obj import STARMapDataset, load_data
# import starmap.analyze as anz
# import starmap.viz as viz

import starmap.sc_util as su

# test()

In [None]:
# Get functions 

import colorsys
from random import shuffle

def intervals(parts, start_point, end_point):
    """Return the midpoints of `parts` equal-width bins spanning the given range.

    Args:
        parts: Number of bins to split the range into (must be non-zero).
        start_point: Lower end of the range.
        end_point: Upper end of the range.

    Returns:
        A list of `parts` values, each the centre of one bin, offset by
        `start_point`.
    """
    bin_width = (end_point - start_point) / parts
    midpoints = []
    for k in range(parts):
        left_edge = k * bin_width
        right_edge = (k + 1) * bin_width
        midpoints.append(((left_edge + right_edge) / 2) + start_point)
    return midpoints

## IO

In [None]:
# Set path
# NOTE(review): absolute, machine-specific location; adjust when running elsewhere.
base_path = 'Z:/Data/Analyzed/2021-11-23-Hu-MouseBrain/'
out_path = os.path.join(base_path, 'output')
fig_path = os.path.join(base_path, 'figures')

# Create the output folders if needed. makedirs(exist_ok=True) replaces the
# exists()+mkdir() pair and keeps the cell safe to re-run.
for required_dir in (out_path, fig_path):
    os.makedirs(required_dir, exist_ok=True)

In [None]:
from datetime import datetime
# Today's date as YYYY-MM-DD, used to stamp output folder names
date = f"{datetime.today():%Y-%m-%d}"

In [None]:
# Point scanpy's figure directory at the project figure folder and set its
# figure parameters to TIFF output at 150 dpi
sc.settings.figdir = fig_path
sc.set_figure_params(format='tif', dpi=150)

In [None]:
import matplotlib
# Font type 42 embeds TrueType fonts in PDF output, so text stays editable in vector editors
matplotlib.rcParams['pdf.fonttype'] = 42

## Input

In [None]:
# Load the level-3 RIBOmap AnnData file from the output folder
adata = sc.read_h5ad(os.path.join(out_path, '2022-04-24-Hu-TissueRIBOmap-level3.h5ad'))
adata  # display the object summary

In [None]:
# Two hex colours stored under 'level_1_colors'
# NOTE(review): presumably the palette scanpy uses for the 'level_1' categories, in category order; confirm
adata.uns['level_1_colors'] = np.array(['#eded58', '#356be8'], dtype=object)

In [None]:
# Same two colours under the '<level>_color_list' key, the naming used for 'level_3_color_list' further down
adata.uns['level_1_color_list'] = np.array(['#eded58', '#356be8'], dtype=object)

In [None]:
# Persist the colour entries added to adata.uns above.
# The duplicate datetime import / `date` assignment that used to sit here was
# removed: `date` is already defined in the earlier date cell and is not used
# by this cell.
# NOTE(review): this writes to the same .h5ad file that was loaded above,
# overwriting the input in place; confirm that is intended.
adata.write_h5ad(os.path.join(out_path, '2022-04-24-Hu-TissueRIBOmap-level3.h5ad'))

Original image resolution: 0.09 µm per pixel.

In [None]:
# Physical extent in millimetres: pixel counts are divided by the 0.3 factor,
# multiplied by 0.09 micron per pixel, then converted from micron to mm.
# NOTE(review): 23215 x 18332 is presumably the size of the 0.3x-downscaled image; confirm
scale_factor = 0.3
microns_per_pixel = 0.09
width_px, height_px = 23215, 18332

x = round(width_px / scale_factor * microns_per_pixel / 1000, 5)
y = round(height_px / scale_factor * microns_per_pixel / 1000, 5)
print(f'X: {x} mm - Y: {y} mm')

## Colors

### level-3

In [None]:
# Build the level-3 palette and colormap from the colours stored on the AnnData object
level_3_order = adata.uns['level_3_order']
level_3_colors = adata.uns['level_3_color_list']
level_3_pl = sns.color_palette(level_3_colors)
level_3_cmap = ListedColormap(level_3_pl.as_hex())

# Preview the palette with one tick label per level-3 type
sns.palplot(level_3_pl)
tick_positions = range(len(level_3_order))
plt.xticks(tick_positions, level_3_order, size=10, rotation=45)
plt.tight_layout()
plt.show()

## Spatial map

### scatter plot

In [None]:
# Toggle: when True, the spatial-map cell below also writes its figure to disk
save_as = True

In [None]:
# Read the cell-center table from the RIBOmap folder (first CSV column becomes the index);
# it is plotted as the grey background layer in the spatial map below
clustermap_center_path = os.path.join(base_path, 'RIBOmap', 'cell_center_polished.csv')
cell_center_df = pd.read_csv(clustermap_center_path, index_col=0)

In [None]:
sc.set_figure_params(format='tif', dpi=150)

dot_size = 12
# Keyword arguments shared by both scatter layers
shared_kwargs = dict(x='column', y='row', s=dot_size, legend=False, edgecolor=None)

fig, ax = plt.subplots(figsize=(18, 23))
# Grey background layer: every row of the cell-center table
sns.scatterplot(data=cell_center_df, color='#787878', ax=ax, **shared_kwargs)
# Foreground layer: cells in adata, coloured by level-3 type
sns.scatterplot(data=adata.obs, hue='level_3', palette=level_3_pl, ax=ax, **shared_kwargs)
ax.axis('off')
fig.tight_layout(pad=0)

if save_as:
    fig.savefig(os.path.join(fig_path, 'spatial-map-level-3.tif'))
plt.show()

### label image

In [None]:
# Create label image from clustermap output
from skimage.segmentation import expand_labels

# Load the clustermap result tables: reads and cell centers
reads_file_path = os.path.join(base_path, 'RIBOmap', 'remain_reads_polished.csv')
center_file_path = os.path.join(base_path, 'RIBOmap', 'cell_center_polished.csv')
reads_df = pd.read_csv(reads_file_path, index_col=0)
centers_df = pd.read_csv(center_file_path, index_col=0)

# Rescale the read coordinates by 0.3 and truncate them to integers
for coord_col in ('spot_location_1', 'spot_location_2'):
    reads_df[coord_col] = (reads_df[coord_col] * 0.3).astype(int)

# Image size: largest coordinate + 1 (spot_location_2 gives rows, spot_location_1 gives columns)
row = reads_df['spot_location_2'].max() + 1
col = reads_df['spot_location_1'].max() + 1

In [None]:
# Translate each read's cell barcode into the index label of that cell in centers_df
cell_barcode_dict = {
    barcode: center_idx
    for barcode, center_idx in zip(centers_df['cell_barcode'].values, centers_df.index.values)
}
reads_df['orig_index'] = reads_df['cell_barcode'].map(cell_barcode_dict)

In [None]:
# Create the label image: every pixel hit by a read is set to that read's
# (orig_index + 1); pixels without reads stay 0.
# NOTE(review): reads whose barcode is missing from centers_df would give NaN in 'orig_index'; confirm none occur
label_img = np.zeros([row, col], dtype=np.uint32)
label_img[reads_df['spot_location_2'].values, reads_df['spot_location_1'].values] = reads_df['orig_index'].values + 1
# label_img = expand_labels(label_img, distance=1)

# NOTE(review): `tifffile` is not imported explicitly in the visible cells;
# presumably it comes in through the starmap star-imports. Confirm.
label_img_out_path = os.path.join(base_path, 'RIBOmap', 'labeled_cells_no_extend.tif')
tifffile.imwrite(label_img_out_path, label_img)

### polygon map

In [None]:
# Read the label image back from disk
current_seg_path = os.path.join(base_path, 'RIBOmap', 'labeled_cells_no_extend.tif')
current_img = tifffile.imread(current_seg_path)

# Attach the image to the AnnData object under a per-sample key
current_key = "RIBOmap_morph"
adata.uns[current_key] = {'label_img': current_img}

In [None]:
# Construct polygons from the label image and store the three results next to it
morph_store = adata.uns[current_key]
morph_qhulls, morph_coords, morph_centroids = su.get_qhulls_test(morph_store['label_img'])
morph_store['qhulls'] = morph_qhulls
morph_store['coords'] = morph_coords
morph_store['centroids'] = morph_centroids

#### level-3

In [None]:
# Add data: the obs index cast to integers, and each cell's level-3 category code
# NOTE(review): presumably the helper treats 'good_cells' as the cells to colour and 'colors' as colormap indices; verify in starmap.sc_util
adata.uns[current_key]['good_cells'] = adata.obs.index.astype(int).values
adata.uns[current_key]['colors'] = adata.obs.loc[:, 'level_3'].cat.codes.values

# Plot the polygon map with the level-3 colormap; output_dir is the main figure folder
su.plot_poly_cells_cluster_by_sample(adata, 'RIBOmap', level_3_cmap, show_plaque=False, show_tau=False, linewidth=0.5,
                                        figscale=3, width=10, height=10, bg_color='#787878', save_as_real_size=True,
                                         save=True, show=False, output_dir=fig_path)

#### Glia

In [None]:
# Keep only the cells whose level-1 annotation is 'Glia'
is_glia = adata.obs['level_1'] == 'Glia'
sdata = adata[is_glia, :]
sdata

In [None]:
# Collect the level-3 types present in this subset, in the global level-3 order,
# together with their colours
present_types = sdata.obs['level_3'].unique()
all_types = sdata.uns['level_3_order']
all_colors = sdata.uns['level_3_color_list']
kept_positions = [pos for pos, type_name in enumerate(all_types) if type_name in present_types]

glia_order = [all_types[pos] for pos in kept_positions]
glia_colors = [all_colors[pos] for pos in kept_positions]

# Build and preview the subset palette
glia_pl = sns.color_palette(glia_colors)
glia_cmap = ListedColormap(glia_pl.as_hex())
sns.palplot(glia_pl)
plt.xticks(range(len(glia_order)), glia_order, size=10, rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# Output folder for this subset. It is two levels below fig_path, so makedirs
# is needed: os.mkdir raises FileNotFoundError when the 'Glia' parent is missing.
current_fig_path = os.path.join(fig_path, 'Glia', 'sct')
os.makedirs(current_fig_path, exist_ok=True)

# Add data: the subset's obs index cast to integers, and each cell's level-3 category code
sdata.uns[current_key]['good_cells'] = sdata.obs.index.astype(int).values
sdata.uns[current_key]['colors'] = sdata.obs.loc[:, 'level_3'].cat.codes.values

# Plot the polygon map for the subset with its own colormap
su.plot_poly_cells_cluster_by_sample(sdata, 'RIBOmap', glia_cmap, show_plaque=False, show_tau=False, linewidth=0.5,
                                     figscale=3, width=10, height=10, save_as_real_size=True,
                                     save=True, show=False, output_dir=current_fig_path)

#### Excitatory neuron

In [None]:
# Keep only the cells whose level-2 annotation is 'Excitatory neuron'
is_excitatory = adata.obs['level_2'] == 'Excitatory neuron'
sdata = adata[is_excitatory, :]
sdata

In [None]:
# Collect the level-3 types present in this subset, in the global level-3 order,
# together with their colours
present_types = sdata.obs['level_3'].unique()
all_types = sdata.uns['level_3_order']
all_colors = sdata.uns['level_3_color_list']
kept_positions = [pos for pos, type_name in enumerate(all_types) if type_name in present_types]

ex_order = [all_types[pos] for pos in kept_positions]
ex_colors = [all_colors[pos] for pos in kept_positions]

# Build and preview the subset palette
ex_pl = sns.color_palette(ex_colors)
ex_cmap = ListedColormap(ex_pl.as_hex())
sns.palplot(ex_pl)
plt.xticks(range(len(ex_order)), ex_order, size=10, rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# Output folder for this subset. It is two levels below fig_path, so makedirs
# is needed: os.mkdir raises FileNotFoundError when the parent folder is missing.
current_fig_path = os.path.join(fig_path, 'Excitatory neuron', 'sct')
os.makedirs(current_fig_path, exist_ok=True)

# Add data: the subset's obs index cast to integers, and each cell's level-3 category code
sdata.uns[current_key]['good_cells'] = sdata.obs.index.astype(int).values
sdata.uns[current_key]['colors'] = sdata.obs.loc[:, 'level_3'].cat.codes.values

# Plot the polygon map for the subset with its own colormap
su.plot_poly_cells_cluster_by_sample(sdata, 'RIBOmap', ex_cmap, show_plaque=False, show_tau=False, linewidth=0.5,
                                     figscale=3, width=10, height=10, save_as_real_size=True,
                                     save=True, show=False, output_dir=current_fig_path)

#### Inhibitory neuron

In [None]:
# Keep only the cells whose level-2 annotation is 'Inhibitory neuron'
is_inhibitory = adata.obs['level_2'] == 'Inhibitory neuron'
sdata = adata[is_inhibitory, :]
sdata

In [None]:
# Collect the level-3 types present in this subset, in the global level-3 order,
# together with their colours
present_types = sdata.obs['level_3'].unique()
all_types = sdata.uns['level_3_order']
all_colors = sdata.uns['level_3_color_list']
kept_positions = [pos for pos, type_name in enumerate(all_types) if type_name in present_types]

inh_order = [all_types[pos] for pos in kept_positions]
inh_colors = [all_colors[pos] for pos in kept_positions]

# Build and preview the subset palette
inh_pl = sns.color_palette(inh_colors)
inh_cmap = ListedColormap(inh_pl.as_hex())
sns.palplot(inh_pl)
plt.xticks(range(len(inh_order)), inh_order, size=10, rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# Output folder for this subset. It is two levels below fig_path, so makedirs
# is needed: os.mkdir raises FileNotFoundError when the parent folder is missing.
current_fig_path = os.path.join(fig_path, 'Inhibitory neuron', 'sct')
os.makedirs(current_fig_path, exist_ok=True)

# Add data: the subset's obs index cast to integers, and each cell's level-3 category code
sdata.uns[current_key]['good_cells'] = sdata.obs.index.astype(int).values
sdata.uns[current_key]['colors'] = sdata.obs.loc[:, 'level_3'].cat.codes.values

# Plot the polygon map for the subset with its own colormap
su.plot_poly_cells_cluster_by_sample(sdata, 'RIBOmap', inh_cmap, show_plaque=False, show_tau=False, linewidth=0.5,
                                     figscale=3, width=10, height=10, save_as_real_size=True,
                                     save=True, show=False, output_dir=current_fig_path)

### polygon map with dots

In [None]:
# Load the three RIBOmap tables: remaining reads, background reads, cell centers
ribomap_dir = os.path.join(base_path, 'RIBOmap')
good_reads = pd.read_csv(os.path.join(ribomap_dir, 'remain_reads_polished.csv'), index_col=0)
background_reads = pd.read_csv(os.path.join(ribomap_dir, 'background_reads_polished.csv'), index_col=0)
cell_center_df = pd.read_csv(os.path.join(ribomap_dir, 'cell_center_polished.csv'), index_col=0)

# Both read tables stacked into one
all_reads = pd.concat([good_reads, background_reads])

In [None]:
# Rescale the read coordinates of both tables by 0.3 and truncate them to integers
for reads_table in (good_reads, background_reads):
    for coord_col in ('spot_location_1', 'spot_location_2'):
        reads_table[coord_col] = (reads_table[coord_col] * 0.3).astype(int)

# Image size taken from good_reads only: largest coordinate + 1
row = good_reads['spot_location_2'].max() + 1
col = good_reads['spot_location_1'].max() + 1

In [None]:
# Translate each read's cell barcode into the index label of that cell in cell_center_df
barcodes = cell_center_df['cell_barcode'].values
center_labels = cell_center_df.index.values
cell_barcode_dict = {barcode: label for barcode, label in zip(barcodes, center_labels)}
good_reads['orig_index'] = good_reads['cell_barcode'].map(cell_barcode_dict)

In [None]:
# Cell tables: add 0.3-scaled, integer-truncated copies of the column/row coordinates
for cell_table in (adata.obs, cell_center_df):
    for axis_name in ('column', 'row'):
        cell_table[f'{axis_name}_scaled'] = (cell_table[axis_name] * .3).astype(int)

# Read tables: the spot locations are exposed under the same '<axis>_scaled' names
for reads_table in (good_reads, background_reads):
    reads_table['column_scaled'] = reads_table['spot_location_1']
    reads_table['row_scaled'] = reads_table['spot_location_2']

In [None]:
# Square region of interest: side length and origin corner
box_length = 5000
# alternative origin tried previously: row_start = 18000, col_start = 7500
row_start, col_start = 15000, 8000
row_end, col_end = row_start + box_length, col_start + box_length

dot_size = 7
fig, ax = plt.subplots(figsize=(18, 23))
sns.scatterplot(data=cell_center_df, x='column_scaled', y='row_scaled', color='#787878',
                s=dot_size, legend=False, edgecolor=None, ax=ax)

# Outline the selected region in red on top of the cell centers
roi_outline = patches.Rectangle(
    (col_start, row_start), box_length, box_length,
    linewidth=1, edgecolor='r', facecolor='none',
)
ax.add_patch(roi_outline)
# plt.axis('off')
plt.show()

In [None]:
def _inside_box(table):
    """Return a boolean mask of rows whose scaled coordinates fall inside the selection box."""
    in_columns = table['column_scaled'].isin(range(col_start, col_end))
    in_rows = table['row_scaled'].isin(range(row_start, row_end))
    return in_columns & in_rows


# Cells inside the box (an AnnData view of adata)
adata_logical = _inside_box(adata.obs)
sdata = adata[adata_logical, :]

# Reads inside the box. The slices are copied because later cells add columns
# to them; writing into an un-copied .loc slice is the SettingWithCopy pattern
# and may not behave as intended.
reads_logical = _inside_box(good_reads)
sreads_df = good_reads.loc[reads_logical, :].copy()

bg_reads_logical = _inside_box(background_reads)
bgreads_df = background_reads.loc[bg_reads_logical, :].copy()

In [None]:
# Visual check of the subsets: each layer is (table, colour, marker size)
dot_size = 7
fig, ax = plt.subplots(figsize=(18, 23))

check_layers = [
    (cell_center_df, '#787878', dot_size),  # all cell centers
    (sdata.obs, 'red', dot_size),           # cells selected by the box
    (sreads_df, 'blue', .1),                # selected reads
    (bgreads_df, 'green', .1),              # selected background reads
]
for layer_table, layer_color, layer_size in check_layers:
    sns.scatterplot(x='column_scaled', y='row_scaled', data=layer_table, color=layer_color,
                    s=layer_size, legend=False, edgecolor=None, ax=ax)

plt.show()

In [None]:
# Crop the label image to the selection box and rebuild the polygons from the crop
crop_store = sdata.uns['RIBOmap_morph']
crop_store['label_img_crop'] = crop_store['label_img'][row_start:row_end, col_start:col_end]
crop_qhulls, crop_coords, crop_centroids = su.get_qhulls_test(crop_store['label_img_crop'])
crop_store['qhulls'] = crop_qhulls
crop_store['coords'] = crop_coords
crop_store['centroids'] = crop_centroids

In [None]:
# Express coordinates relative to the crop origin (row_start, col_start), i.e.
# the same origin used to slice label_img_crop. Subtracting each table's own
# minimum instead would shift cells, reads and background reads by different
# amounts whenever their minima differ from the box corner.
sdata.obs['column_crop'] = sdata.obs['column_scaled'] - col_start
sdata.obs['row_crop'] = sdata.obs['row_scaled'] - row_start

# Work on independent copies so the new columns are not written into slices of
# the full read tables
sreads_df = sreads_df.copy()
sreads_df['column_crop'] = sreads_df['column_scaled'] - col_start
sreads_df['row_crop'] = sreads_df['row_scaled'] - row_start

bgreads_df = bgreads_df.copy()
bgreads_df['column_crop'] = bgreads_df['column_scaled'] - col_start
bgreads_df['row_crop'] = bgreads_df['row_scaled'] - row_start

In [None]:
# Label every selected read with the level-3 type of the cell it is assigned to
# (lookup: integer obs index -> level-3 label)
level_3_transfer_dict = dict(zip(sdata.obs.index.astype(int), sdata.obs['level_3']))
sreads_df['level_3'] = sreads_df['orig_index'].map(level_3_transfer_dict)
# Reads whose cell is not in sdata get a placeholder label
sreads_df.loc[sreads_df['level_3'].isna(), 'level_3'] = 'Unknown'
sreads_df['level_3'] = sreads_df['level_3'].astype('category')

# Keep the types that occur among these reads, in the global level-3 order,
# together with their colours
subset_order = []
subset_colors = []
for i, current_cat in enumerate(sdata.uns['level_3_order']):
   #  print(current_cat)
    if current_cat in sreads_df['level_3'].values:
        subset_order.append(current_cat)
        subset_colors.append(sdata.uns['level_3_color_list'][i])
        
# NOTE(review): reorder_categories needs subset_order to contain exactly the
# existing categories; if 'Unknown' was assigned above but is not listed in
# level_3_order this call would raise. Confirm.
sreads_df['level_3'] = sreads_df['level_3'].cat.reorder_categories(subset_order)

# Palette / colormap restricted to the types present, with a preview
subset_pl = sns.color_palette(subset_colors)
subset_cmap = ListedColormap(subset_pl.as_hex())
sns.palplot(subset_pl)

In [None]:
from matplotlib.collections import PatchCollection
# Figure: cell outlines of the cropped region with reads drawn as dots coloured by level-3 type.
sample = 'RIBOmap'
sample_key = f"{sample}_morph"
nissl = sdata.uns[sample_key]['label_img_crop']
hulls = sdata.uns[sample_key]['qhulls']

sdata.uns[sample_key]['good_cells'] = sdata.obs.index.astype(int).values
sdata.uns[sample_key]['colors'] = sdata.obs.loc[:, 'level_3'].cat.codes.values
colors = sdata.uns[sample_key]['colors']
good_cells = sdata.uns[sample_key]['good_cells']

# Figure size follows the cropped image shape (divided by 1000), optionally enlarged by figscale
save_as_real_size = False
figscale = 10
size_factor = 1 if save_as_real_size else figscale
plt.figure(figsize=(nissl.shape[0]/1000 * size_factor, nissl.shape[1]/1000 * size_factor), dpi=100)

# One polygon per hull; empty hulls stay as [] placeholders so list positions are preserved
polys = [[] if h == [] else su.hull_to_polygon(h) for h in hulls]

# Split polygons by whether their list position is among good_cells.
# A set makes each membership test O(1) instead of a scan over the array.
others = []
if good_cells is not None:
    good_cell_set = set(np.asarray(good_cells).tolist())
    others = [poly for i, poly in enumerate(polys) if i not in good_cell_set and poly != []]
    polys = [poly for i, poly in enumerate(polys) if i in good_cell_set]

alpha = 1
linewidth = .5
# `p` is configured but not added to the axes (see the commented add_collection below)
p = PatchCollection(polys, alpha=alpha, cmap=level_3_cmap, edgecolor='k', linewidth=linewidth, zorder=3)

# Remaining polygons are drawn white with a black outline
other_cmap = sns.color_palette(['#ffffff']) 
other_cmap = ListedColormap(other_cmap)
o = PatchCollection(others, alpha=1, cmap=other_cmap, edgecolor='k', linewidth=1, zorder=1)

vmin = None
vmax = None
rescale_colors = False
# Each bound is tested explicitly: `vmin or vmax is not None` parsed as
# `vmin or (vmax is not None)` and so ignored vmin=0.
if vmin is not None or vmax is not None:
    p.set_array(colors)
    p.set_clim(vmin=vmin, vmax=vmax)
elif rescale_colors:
    p.set_array(colors+1)
    p.set_clim(vmin=0, vmax=max(colors+1))
else:
    p.set_array(colors)
    p.set_clim(vmin=0, vmax=len(level_3_cmap.colors))

# The outline collection is configured in every branch. The upper limit is the
# constant 1 (all values are 1), which avoids max() failing on an empty array.
o_colors = np.ones(len(others), dtype=int)
o.set_array(o_colors)
o.set_clim(vmin=0, vmax=1)

# Binary mask of labelled pixels; np.int was removed from NumPy, the builtin int is equivalent.
nissl = (nissl > 0).astype(int)
# Drawn fully transparent (alpha=0): adds the image to the axes without showing it
plt.imshow(nissl.T, cmap=plt.get_cmap('gray_r'), alpha=0, zorder=1)

# plt.gca().add_collection(p)
plt.gca().add_collection(o)

# Reads as dots; x is taken from row_crop and y from column_crop
sns.scatterplot(x='row_crop', y='column_crop', hue='level_3', data=sreads_df, palette=subset_pl, s=3, legend=False, edgecolor=None, alpha=.5)

plt.axis('off')
plt.tight_layout(pad=0)

current_fig_path = f"{fig_path}/sct_RIBOmap_dots.tif"
plt.savefig(current_fig_path, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
from matplotlib.patches import Polygon
# Figure: cell outlines of the cropped region with the selected reads (blue) and background reads (red).
# NOTE: PatchCollection is imported in the previous figure cell, which must have been run first.
sample = 'RIBOmap'
sample_key = f"{sample}_morph"
nissl = sdata.uns[sample_key]['label_img_crop']
hulls = sdata.uns[sample_key]['qhulls']

sdata.uns[sample_key]['good_cells'] = sdata.obs.index.astype(int).values
sdata.uns[sample_key]['colors'] = sdata.obs.loc[:, 'level_3'].cat.codes.values
colors = sdata.uns[sample_key]['colors']
good_cells = sdata.uns[sample_key]['good_cells']

# Figure size follows the cropped image shape (divided by 1000), optionally enlarged by figscale
save_as_real_size = False
figscale = 10
size_factor = 1 if save_as_real_size else figscale
plt.figure(figsize=(nissl.shape[0]/1000 * size_factor, nissl.shape[1]/1000 * size_factor), dpi=100)

# One polygon per hull; empty hulls stay as [] placeholders so list positions are preserved
polys = [[] if h == [] else su.hull_to_polygon(h) for h in hulls]

# Split polygons by whether their list position is among good_cells.
# A set makes each membership test O(1) instead of a scan over the array.
others = []
if good_cells is not None:
    good_cell_set = set(np.asarray(good_cells).tolist())
    others = [poly for i, poly in enumerate(polys) if i not in good_cell_set and poly != []]
    polys = [poly for i, poly in enumerate(polys) if i in good_cell_set]

alpha = 1
linewidth = .5
# `p` is configured but not added to the axes (see the commented add_collection below)
p = PatchCollection(polys, alpha=alpha, cmap=level_3_cmap, edgecolor='k', linewidth=linewidth, zorder=3)

# Remaining polygons are drawn white with a black outline
other_cmap = sns.color_palette(['#ffffff']) 
other_cmap = ListedColormap(other_cmap)
o = PatchCollection(others, alpha=1, cmap=other_cmap, edgecolor='k', facecolor=None, linewidth=1, zorder=1)

vmin = None
vmax = None
rescale_colors = False
# Each bound is tested explicitly: `vmin or vmax is not None` parsed as
# `vmin or (vmax is not None)` and so ignored vmin=0.
if vmin is not None or vmax is not None:
    p.set_array(colors)
    p.set_clim(vmin=vmin, vmax=vmax)
elif rescale_colors:
    p.set_array(colors+1)
    p.set_clim(vmin=0, vmax=max(colors+1))
else:
    p.set_array(colors)
    p.set_clim(vmin=0, vmax=len(level_3_cmap.colors))

# The outline collection is configured in every branch. The upper limit is the
# constant 1 (all values are 1), which avoids max() failing on an empty array.
o_colors = np.ones(len(others), dtype=int)
o.set_array(o_colors)
o.set_clim(vmin=0, vmax=1)

# Binary mask of labelled pixels; np.int was removed from NumPy, the builtin int is equivalent.
nissl = (nissl > 0).astype(int)
# Drawn fully transparent (alpha=0): adds the image to the axes without showing it
plt.imshow(nissl.T, cmap=plt.get_cmap('gray_r'), alpha=0, zorder=1)

# plt.gca().add_collection(p)
plt.gca().add_collection(o)

# Selected reads in blue, background reads in red; x is taken from row_crop and y from column_crop
sns.scatterplot(x='row_crop', y='column_crop', color='b', data=sreads_df, s=3, legend=False, edgecolor=None, alpha=.5)
sns.scatterplot(x='row_crop', y='column_crop', color='r', data=bgreads_df, s=3, legend=False, edgecolor=None, alpha=.5)

# plt.gca().add_collection(o)

# for patch in others:
#     plt.gca().add_patch(Polygon(patch.get_xy(), closed=True, ec='k', lw=1, fill=False))

plt.axis('off')
plt.tight_layout(pad=0)

current_fig_path = f"{fig_path}/sct_RIBOmap_dots_2.tif"
plt.savefig(current_fig_path, bbox_inches='tight', pad_inches=0)
plt.show()

### gene expression plots

In [None]:
n = 50
save_as = True
# Alternative gene selections tried previously:
# plot_genes = adata.var.sort_values('total_counts', ascending=False).head(n).index.to_list()
# plot_genes = plot_genes + ['Gad1', 'Cux2', 'Rorb', 'Pcp4', 'Gad2', 'Sst', 'Gfap']
# plot_genes = ['Sst', 'Gad1', 'Plp1', 'Slc17a7', 'Pcp4', 'Gad2', 'Gfap']
plot_genes = [
    'Gad1', 'Gad2', 'Slc32a1',
    'Slc17a7',
    'Vip', 'Sst', 'Chodl', 'Pvalb',
    'Cux2', 'Rorb', 'Pcp4', 'Nr4a2',
    'Aqp4', 'Aldoc', 'Gfap',
    'Mbp', 'Mobp', 'Plp1', 'Pdgfra',
    'Ctss', 'C1qa', 'Vtn', 'Bsg', 'Ptgds', 'Dcn',
]

sc.set_figure_params(format='tif', dpi=150)

# Dated output folder, created once up front instead of being re-checked per gene
if save_as:
    expr_fig_path = os.path.join(fig_path, f'{date}-expr-norm')
    if not os.path.exists(expr_fig_path):
        os.mkdir(expr_fig_path)

# One spatial map per gene, coloured by the values of the 'norm' layer
for gene in tqdm(plot_genes):
    dot_size = .5
    fig, ax = plt.subplots(figsize=(4, 5))

    # sns.scatterplot(x='column', y='row', data=cell_center_df, color='#fffffa', s=dot_size, legend=False, edgecolor=None, ax=ax)

    current_expr = adata[:, gene].layers['norm'].flatten()
    sns.scatterplot(data=adata.obs, x='column', y='row', hue=current_expr, palette='Reds',
                    s=dot_size, legend=False, edgecolor=None, ax=ax)
    ax.axis('off')
    fig.tight_layout(pad=0)
    if save_as:
        fig.savefig(os.path.join(expr_fig_path, f'{gene}.tif'))
    plt.show()

### gene expression plots w/ STARmap

In [None]:
# Load the STARmap dataset. NOTE: this rebinds `sdata`, which earlier cells used
# for subsets of the RIBOmap object.
sdata = sc.read_h5ad(os.path.join(out_path, '2022-02-10-Hu-TissueSTARmap-mad-filtered.h5ad'))

# Store an independent copy of the counts. Without .copy() the 'raw' layer
# refers to the same array as X, and the normalisation steps below can modify
# that array in place.
sdata.layers['raw'] = sdata.X.copy()
sc.pp.normalize_total(sdata)
sc.pp.log1p(sdata)
# Likewise store a copy of the normalised, log-transformed matrix so that later
# changes to X do not leak into the 'norm' layer
sdata.layers['norm'] = sdata.X.copy()
sdata

In [None]:
n = 50
save_as = False
expr_layer = 'norm'
# Alternative gene selections tried previously:
# plot_genes = adata.var.sort_values('total_counts', ascending=False).head(n).index.to_list()
# plot_genes = plot_genes + ['Gad1', 'Cux2', 'Rorb', 'Pcp4', 'Gad2', 'Sst', 'Gfap']
plot_genes = ['Sst', 'Gad1', 'Plp1', 'Slc17a7', 'Pcp4', 'Gad2', 'Gfap']
sc.set_figure_params(format='tif', dpi=150)

# Output folder, created once up front instead of being re-checked per gene
if save_as:
    expr_fig_path = os.path.join(fig_path, 'expr_combined')
    if not os.path.exists(expr_fig_path):
        os.mkdir(expr_fig_path)

# One two-panel figure per gene: adata (RIBOmap) on the left, sdata (STARmap) on the right
for gene in tqdm(plot_genes):
    dot_size = .5
    fig, axs = plt.subplots(figsize=(8, 5), ncols=2)
    left_ax, right_ax = axs[0], axs[1]

    # Background layer of all cell centers, left panel only
    sns.scatterplot(data=cell_center_df, x='column', y='row', color='#fffffa',
                    s=dot_size, legend=False, edgecolor=None, ax=left_ax)

    ribo_expr = adata[:, gene].layers[expr_layer].flatten()
    sns.scatterplot(data=adata.obs, x='column', y='row', hue=ribo_expr, palette='Reds',
                    s=dot_size, legend='auto', edgecolor=None, ax=left_ax)

    star_expr = sdata[:, gene].layers[expr_layer].flatten()
    sns.scatterplot(data=sdata.obs, x='column', y='row', hue=star_expr, palette='Reds',
                    s=dot_size, legend='auto', edgecolor=None, ax=right_ax)

    left_ax.axis('off')
    right_ax.axis('off')
    fig.tight_layout(pad=0)
    if save_as:
        fig.savefig(os.path.join(expr_fig_path, f'{gene}.tif'))
    plt.show()

In [None]:
n = 50
save_as = True
expr_layer = 'norm'
# Alternative gene selections tried previously:
# plot_genes = adata.var.sort_values('total_counts', ascending=False).head(n).index.to_list()
# plot_genes = plot_genes + ['Gad1', 'Cux2', 'Rorb', 'Pcp4', 'Gad2', 'Sst', 'Gfap']
# plot_genes = ['Sst', 'Gad1', 'Plp1', 'Slc17a7', 'Pcp4', 'Gad2', 'Gfap']
plot_genes = ['Ttc3','Hpca', 'Itpka', 'Pcp4', 'Arfgef1', 'Dcx', 'Map2', 'Nefl', 'Nptx1', 'Prkce', 'Rundc3a', 'Rap1gap', 'Scgn', 'Chat', 'Hcrt', 'Gpr151', 'Th', 'Slc18a3', 'Cbln3', 'Dpp6', 'Tbr1', 'Camkv',
              'Adcyap1', 'Gal', 'Slc17a6', 'Sncg', 'Slc6a2', 'Aldoc', 'Gng2', 'Igfbp5', 'Lrpap1', 'Dact2', 'Arc', 'Emx2', 'Akap5', 'Cck', 'Necab1', 'Calb2', 'Slc30a3', 'Cartpt', 'Adora2a', 'Agrp', 'Avp',
              'Cdh15', 'Cnp', 'Drd1', 'Drd2', 'Gng7', 'Gpr88', 'Homer3', 'Opt', 'Oxt', 'Pax5', 'Gs9', 'Slc1a6', 'Slc6a4', 'Slc6a5', 'Bcl11b', 'Calcr', 'Cnga3', 'Crabp1']
# Keep only genes present in BOTH datasets: each gene is indexed in adata and
# in sdata below, so filtering on adata alone could raise a KeyError for sdata.
plot_genes = [g for g in plot_genes if g in adata.var.index and g in sdata.var.index]
sc.set_figure_params(format='tif', dpi=150)

if save_as:
    expr_fig_path = os.path.join(fig_path, f'{date}-expr-combined-{expr_layer}')
    os.makedirs(expr_fig_path, exist_ok=True)

# (file-name suffix, dataset) pairs; each gene gets one figure per dataset
datasets_to_plot = (('RIBOmap', adata), ('STARmap', sdata))

for gene in tqdm(plot_genes):
    dot_size = .5

    for dataset_name, dataset in datasets_to_plot:
        fig, ax = plt.subplots(figsize=(4, 5))
        # Background layer: the cell-center table is drawn under both datasets, as in the original cell
        sns.scatterplot(x='column', y='row', data=cell_center_df, color='#fffffa', s=dot_size, legend=False, edgecolor=None, ax=ax)

        expr_values = dataset[:, gene].layers[expr_layer].flatten()
        sns.scatterplot(x='column', y='row', data=dataset.obs, hue=expr_values, palette='Reds', s=dot_size, legend=False, edgecolor=None, ax=ax)
        ax.axis('off')
        plt.tight_layout(pad=0)
        if save_as:
            plt.savefig(os.path.join(expr_fig_path, f'{gene}_{dataset_name}.tif'))
        plt.show()

### gene expression plots (all reads)

In [None]:
# load reads 
# Reload both read tables from disk (an earlier cell rescaled good_reads and
# background_reads in place) and stack them into one table
reads_dir = os.path.join(base_path, 'RIBOmap')
good_reads = pd.read_csv(os.path.join(reads_dir, 'remain_reads_polished.csv'), index_col=0)
background_reads = pd.read_csv(os.path.join(reads_dir, 'background_reads_polished.csv'), index_col=0)
all_reads = pd.concat([good_reads, background_reads])

In [None]:
# load gene lists
# Read all seven sheets in a single pass over the workbook: passing a list of
# sheet names makes read_excel return a dict of DataFrames keyed by sheet name,
# instead of opening the same file seven times.
gene_list_path = os.path.join(base_path, 'gene-module', '2022-04-11-Hu-gene-list.xlsx')
gene_sheets = pd.read_excel(gene_list_path, sheet_name=[f'list_{k}' for k in range(1, 8)])

list_1 = gene_sheets['list_1']
list_2 = gene_sheets['list_2']
list_3 = gene_sheets['list_3']
list_4 = gene_sheets['list_4']
list_5 = gene_sheets['list_5']
list_6 = gene_sheets['list_6']
list_7 = gene_sheets['list_7']

In [None]:
# Distinct values of the 'gene' column across all reads, and how many there are
profiled_genes = all_reads['gene'].unique()
print(len(profiled_genes))

In [None]:
# Read count per gene (value_counts sorts by descending count); the first 500 index entries are kept
gene_read_counts = all_reads['gene'].value_counts()
reads_count = pd.DataFrame(gene_read_counts)
reads_count.columns = ['count']
top_500_genes = reads_count.iloc[:500, :].index.to_list()

In [None]:
# plot list_1: for each listed gene among the top 500, draw its reads in red
# over all entries of good_reads in grey
save_as = True
dot_size = 1

if save_as:
    expr_fig_path = os.path.join(fig_path, f'{date}-expr-reads', 'list-1')
    # The target is nested inside a dated folder; os.mkdir fails when that
    # parent is missing, so create the whole chain.
    os.makedirs(expr_fig_path, exist_ok=True)

for current_gene in tqdm(list_1['gene']):
    if current_gene not in top_500_genes:
        continue

    current_reads = all_reads.loc[all_reads['gene'] == current_gene, :]

    fig, ax = plt.subplots(figsize=[40, 50])
    sns.scatterplot(x='spot_location_1', y='spot_location_2', data=good_reads, color='#dbdbdb', s=dot_size, legend=False, edgecolor=None, ax=ax)
    sns.scatterplot(x='spot_location_1', y='spot_location_2', data=current_reads, color='r', s=dot_size, legend=False, edgecolor=None, ax=ax)

    ax.axis('off')
    plt.tight_layout(pad=0)

    if save_as:
        fig.savefig(os.path.join(expr_fig_path, f'{current_gene}_RIBOmap.tif'))
        # Close this specific figure so the large canvases do not accumulate in memory
        plt.close(fig)
    else:
        plt.show()
            


In [None]:
# plot list_2: for each listed gene among the top 500, draw its reads in red
# over all entries of good_reads in grey
save_as = True
dot_size = 1

if save_as:
    expr_fig_path = os.path.join(fig_path, f'{date}-expr-reads', 'list-2')
    # The target is nested inside a dated folder; os.mkdir fails when that
    # parent is missing, so create the whole chain.
    os.makedirs(expr_fig_path, exist_ok=True)

for current_gene in tqdm(list_2['gene']):
    if current_gene not in top_500_genes:
        continue

    current_reads = all_reads.loc[all_reads['gene'] == current_gene, :]

    fig, ax = plt.subplots(figsize=[40, 50])
    sns.scatterplot(x='spot_location_1', y='spot_location_2', data=good_reads, color='#dbdbdb', s=dot_size, legend=False, edgecolor=None, ax=ax)
    sns.scatterplot(x='spot_location_1', y='spot_location_2', data=current_reads, color='r', s=dot_size, legend=False, edgecolor=None, ax=ax)

    ax.axis('off')
    plt.tight_layout(pad=0)

    if save_as:
        fig.savefig(os.path.join(expr_fig_path, f'{current_gene}_RIBOmap.tif'))
        # Close this specific figure so the large canvases do not accumulate in memory
        plt.close(fig)
    else:
        plt.show()
            


In [None]:
# plot list_3: for each listed gene among the top 500, draw its reads in red
# over all entries of good_reads in grey
save_as = True
dot_size = 1

if save_as:
    expr_fig_path = os.path.join(fig_path, f'{date}-expr-reads', 'list-3')
    # The target is nested inside a dated folder; os.mkdir fails when that
    # parent is missing, so create the whole chain.
    os.makedirs(expr_fig_path, exist_ok=True)

for current_gene in tqdm(list_3['gene']):
    if current_gene not in top_500_genes:
        continue

    current_reads = all_reads.loc[all_reads['gene'] == current_gene, :]

    fig, ax = plt.subplots(figsize=[40, 50])
    sns.scatterplot(x='spot_location_1', y='spot_location_2', data=good_reads, color='#dbdbdb', s=dot_size, legend=False, edgecolor=None, ax=ax)
    sns.scatterplot(x='spot_location_1', y='spot_location_2', data=current_reads, color='r', s=dot_size, legend=False, edgecolor=None, ax=ax)

    ax.axis('off')
    plt.tight_layout(pad=0)

    if save_as:
        fig.savefig(os.path.join(expr_fig_path, f'{current_gene}_RIBOmap.tif'))
        # Close this specific figure so the large canvases do not accumulate in memory
        plt.close(fig)
    else:
        plt.show()
            


In [None]:
# plot list_4: for each listed gene among the top 500, draw its reads in red
# over all entries of good_reads in grey
save_as = True
dot_size = 1

if save_as:
    expr_fig_path = os.path.join(fig_path, f'{date}-expr-reads', 'list-4')
    # The target is nested inside a dated folder; os.mkdir fails when that
    # parent is missing, so create the whole chain.
    os.makedirs(expr_fig_path, exist_ok=True)

for current_gene in tqdm(list_4['gene']):
    if current_gene not in top_500_genes:
        continue

    current_reads = all_reads.loc[all_reads['gene'] == current_gene, :]

    fig, ax = plt.subplots(figsize=[40, 50])
    sns.scatterplot(x='spot_location_1', y='spot_location_2', data=good_reads, color='#dbdbdb', s=dot_size, legend=False, edgecolor=None, ax=ax)
    sns.scatterplot(x='spot_location_1', y='spot_location_2', data=current_reads, color='r', s=dot_size, legend=False, edgecolor=None, ax=ax)

    ax.axis('off')
    plt.tight_layout(pad=0)

    if save_as:
        fig.savefig(os.path.join(expr_fig_path, f'{current_gene}_RIBOmap.tif'))
        # Close this specific figure so the large canvases do not accumulate in memory
        plt.close(fig)
    else:
        plt.show()
            


In [None]:
# Overlay the reads of one gene on the cell outlines of the RIBOmap sample:
# reads from background_reads in red, reads from good_reads in blue.
from matplotlib.collections import PatchCollection

current_gene = 'Atp1a1'
# .copy() so the coordinate rescaling below writes to independent frames instead
# of slices of good_reads / background_reads.
current_soma_reads = good_reads.loc[good_reads['gene'] == current_gene, :].copy()
current_non_soma_reads = background_reads.loc[background_reads['gene'] == current_gene, :].copy()

dot_size = .3

sample = 'RIBOmap'
sample_key = f"{sample}_morph"
nissl = adata.uns[sample_key]['label_img']
hulls = adata.uns[sample_key]['qhulls']

colors = []
# Empty list: no cell is treated as "good", so every non-empty hull lands in `others`.
good_cells = []

save_as_real_size = True
figscale = 10
if save_as_real_size:
    plt.figure(figsize=(nissl.shape[0]/1000, nissl.shape[1]/1000), dpi=100)
else:
    plt.figure(figsize=(nissl.shape[0]/1000 * figscale, nissl.shape[1]/1000 * figscale), dpi=100)

# One entry per hull; an empty list is kept as a placeholder for cells without a hull
# so that list positions keep matching the hull positions.
polys = []
for h in hulls:
    if h == []:
        polys.append([])
    else:
        polys.append(su.hull_to_polygon(h))

if good_cells is not None:
    others = [p for i, p in enumerate(polys) if i not in good_cells and p != []]
    polys = [p for i, p in enumerate(polys) if i in good_cells]

# Single-colour colormap: all "other" cells are filled with the same light grey.
other_cmap = ListedColormap(sns.color_palette(['#ededed']))
o = PatchCollection(others, alpha=1, cmap=other_cmap, edgecolor='k', linewidth=0, zorder=1)

o_colors = np.ones(len(others)).astype(int)
o.set_array(o_colors)
o.set_clim(vmin=0, vmax=max(o_colors))

# Builtin int: the np.int alias was removed in NumPy 1.24.
nissl = (nissl > 0).astype(int)
# alpha=0 makes the image invisible; it is drawn only to set the axes extent and orientation.
plt.imshow(nissl.T, cmap=plt.get_cmap('gray_r'), alpha=0, zorder=1)

plt.gca().add_collection(o)

# Rescale read coordinates by 0.3 and truncate to integers.
# NOTE(review): 0.3 presumably matches the downscaling of label_img — confirm.
for reads in (current_soma_reads, current_non_soma_reads):
    for axis_col in ('spot_location_1', 'spot_location_2'):
        reads[axis_col] = (reads[axis_col] * 0.3).astype(int)

# x and y are swapped relative to the raw columns to match the transposed image above.
sns.scatterplot(x='spot_location_2', y='spot_location_1', data=current_non_soma_reads, color='r', s=dot_size, legend=False, edgecolor=None, markers='.')
sns.scatterplot(x='spot_location_2', y='spot_location_1', data=current_soma_reads, color='b', s=dot_size, legend=False, edgecolor=None, markers='.')

plt.axis('off')
plt.tight_layout(pad=0)

current_fig_path = f"{fig_path}/test_{current_gene}.tif"
plt.savefig(current_fig_path, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Plot list_5: for every gene, draw its reads (red) over the light-grey cell
# outlines of the RIBOmap sample and save one image per gene.
from matplotlib.collections import PatchCollection  # local import: do not rely on an earlier cell having run

save_as = True
dot_size = 1

if save_as:
    expr_fig_path = os.path.join(fig_path, '2022-04-12-expr-reads', 'list-5')
    # makedirs instead of mkdir: the parent folder may not exist yet.
    os.makedirs(expr_fig_path, exist_ok=True)

# --- Everything below up to the loop does not depend on the gene, so it is built once. ---
sample = 'RIBOmap'
sample_key = f"{sample}_morph"
nissl = adata.uns[sample_key]['label_img']
hulls = adata.uns[sample_key]['qhulls']

colors = []
# Empty list: no cell is treated as "good", so every non-empty hull lands in `others`.
good_cells = []

# One entry per hull; an empty list is kept as a placeholder for cells without a hull.
polys = []
for h in hulls:
    if h == []:
        polys.append([])
    else:
        polys.append(su.hull_to_polygon(h))

if good_cells is not None:
    others = [p for i, p in enumerate(polys) if i not in good_cells and p != []]
    polys = [p for i, p in enumerate(polys) if i in good_cells]

# Single-colour colormap: all cells are filled with the same light grey.
other_cmap = ListedColormap(sns.color_palette(['#ededed']))

# Builtin int: the np.int alias was removed in NumPy 1.24.
nissl = (nissl > 0).astype(int)

for i, current_gene in enumerate(tqdm(list_5['gene'])):

    # .copy() so the rescaling writes to an independent frame, not a slice of all_reads.
    current_reads = all_reads.loc[all_reads['gene'] == current_gene, :].copy()
    # Rescale read coordinates by 0.3 and truncate to integers.
    # NOTE(review): 0.3 presumably matches the downscaling of label_img — confirm.
    for axis_col in ('spot_location_1', 'spot_location_2'):
        current_reads[axis_col] = (current_reads[axis_col] * 0.3).astype(int)

    plt.figure(figsize=(nissl.shape[0]/1000*3, nissl.shape[1]/1000*3), dpi=150)

    # A collection can only live on one axes, so a new one is created per figure
    # from the shared patch list.
    o = PatchCollection(others, alpha=1, cmap=other_cmap, edgecolor='k', linewidth=0, zorder=0)
    o_colors = np.ones(len(others)).astype(int)
    o.set_array(o_colors)
    o.set_clim(vmin=0, vmax=max(o_colors))

    # alpha=0 makes the image invisible; it is drawn only to set the axes extent and orientation.
    plt.imshow(nissl.T, cmap=plt.get_cmap('gray_r'), alpha=0, zorder=-1)

    plt.gca().add_collection(o)

    # x and y are swapped relative to the raw columns to match the transposed image above.
    sns.scatterplot(x='spot_location_2', y='spot_location_1', data=current_reads, color='r', s=dot_size, legend=False, edgecolor='r', zorder=3, markers='.')

    plt.axis('off')
    plt.tight_layout(pad=0)

    current_fig_path = f"{expr_fig_path}/{current_gene}.tif"
    plt.savefig(current_fig_path, bbox_inches='tight', pad_inches=0)
    plt.close()

In [None]:
# Plot selected genes as dots: good_reads as a faint grey backdrop, the gene's
# reads (from all_reads) in dark purple on top.
save_as = True
dot_size = 2

if save_as:
    expr_fig_path = os.path.join(fig_path, '2022-04-12-expr-reads', 'list-5-dots-dark')
    # makedirs instead of mkdir: the parent folder may not exist yet.
    os.makedirs(expr_fig_path, exist_ok=True)

# Hand-picked subset; swap in list_5['gene'] to render the full list.
genes = ['Shank1', 'Eef2', 'Plk2', 'Rtn4']
for i, current_gene in enumerate(tqdm(genes)):

    current_reads = all_reads.loc[all_reads['gene'] == current_gene, :]

    fig, ax = plt.subplots(figsize=[40,50])

    # Backdrop first, gene reads second, so the gene is drawn on top.
    ax.plot(good_reads.loc[:,'spot_location_1'], good_reads.loc[:,'spot_location_2'], '.', ms=1, alpha=0.5, c='#ededed')
    ax.plot(current_reads.loc[:,'spot_location_1'], current_reads.loc[:,'spot_location_2'], '.', ms=dot_size, alpha=1, c='#2e0d52')

    ax.axis('off')
    plt.tight_layout(pad=0)

    if save_as:
        plt.savefig(os.path.join(expr_fig_path, f'{current_gene}_RIBOmap.tif'))
        # Close each figure explicitly; 40x50 inch canvases add up quickly in a loop.
        plt.close(fig)
    else:
        plt.show()
            


In [None]:
# Plot selected genes as dots in two colours: reads with a cell_barcode other
# than -1 in blue, reads with cell_barcode == -1 in red, over a grey backdrop.
save_as = True
dot_size = 2

if save_as:
    expr_fig_path = os.path.join(fig_path, '2022-04-12-expr-reads', 'list-5-dots-2')
    # makedirs instead of mkdir: the parent folder may not exist yet.
    os.makedirs(expr_fig_path, exist_ok=True)

# Hand-picked subset; swap in list_5['gene'] to render the full list.
genes = ['Shank1', 'Eef2', 'Plk2', 'Rtn4']
for i, current_gene in enumerate(tqdm(genes)):

    current_reads = all_reads.loc[all_reads['gene'] == current_gene, :]
    # cell_barcode == -1 is the marker used here for reads outside any cell.
    soma_reads = current_reads.loc[current_reads['cell_barcode'] != -1, :]
    non_soma_reads = current_reads.loc[current_reads['cell_barcode'] == -1, :]

    fig, ax = plt.subplots(figsize=[40,50])

    # Draw order: backdrop, then in-cell reads, then out-of-cell reads on top.
    ax.plot(good_reads.loc[:,'spot_location_1'], good_reads.loc[:,'spot_location_2'], '.', ms=1, alpha=0.5, c='#ededed')
    ax.plot(soma_reads.loc[:,'spot_location_1'], soma_reads.loc[:,'spot_location_2'], '.', ms=dot_size, alpha=1, c='b')
    ax.plot(non_soma_reads.loc[:,'spot_location_1'], non_soma_reads.loc[:,'spot_location_2'], '.', ms=dot_size, alpha=1, c='r')

    ax.axis('off')
    plt.tight_layout(pad=0)

    if save_as:
        plt.savefig(os.path.join(expr_fig_path, f'{current_gene}_RIBOmap.tif'))
        # Close each figure explicitly; 40x50 inch canvases add up quickly in a loop.
        plt.close(fig)
    else:
        plt.show()
            


In [None]:
# Same two-colour dot plot as for list_5, with larger dots (dot_size = 4) and a
# separate output folder: cell_barcode != -1 in blue, cell_barcode == -1 in red.
save_as = True
dot_size = 4

if save_as:
    expr_fig_path = os.path.join(fig_path, '2022-04-12-expr-reads', 'list-6-dots-4')
    # makedirs instead of mkdir: the parent folder may not exist yet.
    os.makedirs(expr_fig_path, exist_ok=True)

# Hand-picked subset; swap in list_6['gene'] to render the full list.
genes = ['Eef2', 'Plk2', 'Rtn4', 'Shank1']
for i, current_gene in enumerate(tqdm(genes)):

    current_reads = all_reads.loc[all_reads['gene'] == current_gene, :]
    # cell_barcode == -1 is the marker used here for reads outside any cell.
    soma_reads = current_reads.loc[current_reads['cell_barcode'] != -1, :]
    non_soma_reads = current_reads.loc[current_reads['cell_barcode'] == -1, :]

    fig, ax = plt.subplots(figsize=[40,50])

    # Draw order: backdrop, then in-cell reads, then out-of-cell reads on top.
    ax.plot(good_reads.loc[:,'spot_location_1'], good_reads.loc[:,'spot_location_2'], '.', ms=1, alpha=0.5, c='#ededed')
    ax.plot(soma_reads.loc[:,'spot_location_1'], soma_reads.loc[:,'spot_location_2'], '.', ms=dot_size, alpha=1, c='b')
    ax.plot(non_soma_reads.loc[:,'spot_location_1'], non_soma_reads.loc[:,'spot_location_2'], '.', ms=dot_size, alpha=1, c='r')

    ax.axis('off')
    plt.tight_layout(pad=0)

    if save_as:
        plt.savefig(os.path.join(expr_fig_path, f'{current_gene}_RIBOmap.tif'))
        # Close each figure explicitly; 40x50 inch canvases add up quickly in a loop.
        plt.close(fig)
    else:
        plt.show()
            


In [None]:
# Two-colour dot plot for list_7 genes: cell_barcode != -1 in blue,
# cell_barcode == -1 in red, over the grey good_reads backdrop.
save_as = True
dot_size = 4

if save_as:
    expr_fig_path = os.path.join(fig_path, '2022-04-12-expr-reads', 'list-7-dots-4')
    # makedirs instead of mkdir: the parent folder may not exist yet.
    os.makedirs(expr_fig_path, exist_ok=True)

# NOTE(review): the slice skips the first 153 entries of list_7['gene'] — looks like
# a resume point from an interrupted run; reset to the full column for a clean re-run.
for i, current_gene in enumerate(tqdm(list_7['gene'][153:])):

    current_reads = all_reads.loc[all_reads['gene'] == current_gene, :]
    # cell_barcode == -1 is the marker used here for reads outside any cell.
    soma_reads = current_reads.loc[current_reads['cell_barcode'] != -1, :]
    non_soma_reads = current_reads.loc[current_reads['cell_barcode'] == -1, :]

    fig, ax = plt.subplots(figsize=[40,50])

    # Draw order: backdrop, then in-cell reads, then out-of-cell reads on top.
    ax.plot(good_reads.loc[:,'spot_location_1'], good_reads.loc[:,'spot_location_2'], '.', ms=1, alpha=0.5, c='#ededed')
    ax.plot(soma_reads.loc[:,'spot_location_1'], soma_reads.loc[:,'spot_location_2'], '.', ms=dot_size, alpha=1, c='b')
    ax.plot(non_soma_reads.loc[:,'spot_location_1'], non_soma_reads.loc[:,'spot_location_2'], '.', ms=dot_size, alpha=1, c='r')

    ax.axis('off')
    plt.tight_layout(pad=0)

    if save_as:
        plt.savefig(os.path.join(expr_fig_path, f'{current_gene}_RIBOmap.tif'))
        # Close each figure explicitly; 40x50 inch canvases add up quickly in a loop.
        plt.close(fig)
    else:
        plt.show()
            


### reads distribution

In [None]:
# Per-gene read counts for the three read tables, one single-column frame each.
good_reads_df = good_reads['gene'].value_counts().to_frame(name='soma reads')
background_reads_df = background_reads['gene'].value_counts().to_frame(name='non-soma reads')
all_reads_df = all_reads['gene'].value_counts().to_frame(name='all reads')

# Align the three counts on gene; a gene missing from one table gets NaN there.
reads_df = pd.concat([good_reads_df, background_reads_df, all_reads_df], axis=1)

# Share of each gene's reads that fall in the soma / non-soma tables.
for count_col in ('soma reads', 'non-soma reads'):
    reads_df[f'{count_col} percentage'] = reads_df[count_col] / reads_df['all reads']
reads_df

In [None]:
# Flag genes above the 90th percentile of each read-share distribution.
soma_share = reads_df['soma reads percentage']
non_soma_share = reads_df['non-soma reads percentage']

ns_thres = non_soma_share.quantile(0.90) # compare w/ list_1
s_thres = soma_share.quantile(0.90) # compare w/ list_4

reads_df['top soma gene'] = soma_share.gt(s_thres)
reads_df['top non-soma gene'] = non_soma_share.gt(ns_thres)
reads_df

In [None]:
# Genes (index of reads_df) flagged as top soma / top non-soma.
top_soma_genes = reads_df.index[reads_df['top soma gene'].to_numpy()].to_list()
top_non_soma_genes = reads_df.index[reads_df['top non-soma gene'].to_numpy()].to_list()

# Rows of the curated lists whose gene is also flagged by the read-share threshold.
soma_gene_overlap = list_4.loc[list_4['gene'].isin(top_soma_genes), :]
non_soma_gene_overlap = list_1.loc[list_1['gene'].isin(top_non_soma_genes), :]
non_soma_gene_overlap

In [None]:
# Write the read-distribution table and both overlap tables to one workbook,
# one sheet per table.
excel_sheets = {
    'reads-distribution': reads_df,
    'soma-gene-overlap': soma_gene_overlap,
    'non-soma-gene-overlap': non_soma_gene_overlap,
}
excel_path = os.path.join(fig_path, f'{date}-reads-distribution.xlsx')
with pd.ExcelWriter(excel_path) as writer:
    for sheet_name, sheet_frame in excel_sheets.items():
        sheet_frame.to_excel(writer, sheet_name=sheet_name)

In [None]:
# Side-by-side histograms of the soma and non-soma read shares; the red line
# marks the median of each distribution.
sns.set_style('white')
fig, axs = plt.subplots(figsize=(10, 4), ncols=2)

for panel, share_col in zip(axs, ('soma reads percentage', 'non-soma reads percentage')):
    share = reads_df[share_col]
    sns.histplot(share, ax=panel)
    panel.axvline(x=share.median(), color='r', linestyle='-')
    panel.set_title(share_col)

plt.savefig(os.path.join(fig_path, 'reads-distribution.pdf'))
plt.show()

In [None]:
# Gene class: 0 = neither, 1 = top soma gene, 2 = top non-soma gene.
# Labels are applied in this order, so a gene flagged in both ends up as 2.
reads_df['class'] = 0
for class_label, flag_col in ((1, 'top soma gene'), (2, 'top non-soma gene')):
    reads_df.loc[reads_df[flag_col] == True, 'class'] = class_label
reads_df['class'] = reads_df['class'].astype('category')

In [None]:
# One hex colour per gene class, listed in class order (0, 1, 2).
class_hex = {0: '#bfbfbf', 1: '#1d43cf', 2: '#cf1d1d'}
cpl_colors = [class_hex[class_label] for class_label in sorted(class_hex)]

# Same colours as a seaborn palette and as a matplotlib colormap.
cpl = sns.color_palette(cpl_colors)
cmap = ListedColormap(cpl.as_hex())

In [None]:
# Scatter of every gene's non-soma read share against its 'order' value,
# coloured by class, with a handful of genes outlined and labelled.
sns.set_style("ticks")
fig, ax = plt.subplots(figsize=(6,5))

# NOTE(review): this cell reads reads_df['order'] and reads_df['index'], which none of
# the cells visible here create — confirm the cell that adds them runs before this one.
sns.scatterplot(y='non-soma reads percentage', x='order', hue='class', data=reads_df, s=12, edgecolor=None, palette=cpl, legend=False, ax=ax)

# Where each label's text box is placed, in data coordinates (order, share).
label_xytext = {
    'Shank1': (1000, 0.42),
    'Mbp': (1000, 0.40),
    'Gfap': (1000, 0.37),
    'Kif5a': (1000, 0.35),
    'Eef2': (1000, 0.33),
    'Calm1': (1000, 0.30),
    'Rtn4': (4000, 0.14),
    'App': (4000, 0.12),
    'Mal': (4000, 0.07),
}

# Look the anchor point of every labelled gene up in the data instead of pasting
# the printed numbers back into the code, so the arrows stay correct if the
# underlying table changes.
annotate_genes = ['Shank1', 'Eef2', 'Kif5a', 'Calm1', 'Gfap', 'Mbp', 'App', 'Rtn4', 'Mal']
gene_points = {}
for gene in annotate_genes:
    gene_rows = reads_df.loc[reads_df['index'] == gene, :]
    x = gene_rows['order'].values[0]
    y = gene_rows['non-soma reads percentage'].values[0]
    gene_points[gene] = (x, y)
    print(gene, x, y)

# Redraw the labelled genes with a black outline so they stand out.
sns.scatterplot(y='non-soma reads percentage', x='order', hue='class', data=reads_df.loc[reads_df['index'].isin(annotate_genes), :], s=12, edgecolor='k', linewidth=1, palette=cpl, legend=False, ax=ax)

for gene, xytext in label_xytext.items():
    ax.annotate(gene, gene_points[gene], xytext=xytext, size=7,
                bbox=dict(boxstyle="round", alpha=0.1),
                arrowprops=dict(arrowstyle='-', connectionstyle="arc3", facecolor='black', edgecolor='black', lw=1)
               )

plt.savefig(os.path.join(fig_path, 'neuropil_reads.pdf'))

plt.show()

In [None]:
# Interactive version of the non-soma read-share scatter, saved as HTML.
import plotly.express as px

# Columns exposed to the hover box, in the order the template indexes them.
hover_columns = ['index', 'non-soma reads percentage', 'soma reads percentage']
hover_lines = [
    "Gene: %{customdata[0]}",
    "non-soma %: %{customdata[1]}",
    "soma %: %{customdata[2]}",
]

fig = px.scatter(
    reads_df,
    x="order",
    y="non-soma reads percentage",
    color='class',
    color_discrete_sequence=cpl_colors,
    category_orders={"class": [0, 1, 2]},
    custom_data=hover_columns,
)
fig.update_layout(autosize=False, width=800, height=800)
fig.update_traces(hovertemplate="<br>".join(hover_lines))

fig.write_html(os.path.join(fig_path, 'neuropil_reads.html'))
fig.show()

In [None]:
# level 2 markers reads distribution
# Take the first 10 marker rows of every group (the 'Unknown' group excluded) and
# attach each marker's non-soma read share from reads_df.
marker_df = pd.read_csv(os.path.join(fig_path, 'level_2_markers.csv'), index_col=0)
plot_marker_df = marker_df.groupby('group').head(10)
# .copy() so the new column is written to an independent frame, not a slice of marker_df.
plot_marker_df = plot_marker_df.loc[plot_marker_df['group'] != 'Unknown', :].copy()

# Gene -> non-soma share lookup; the first row wins if a gene is listed more than once.
non_soma_share_by_gene = (
    reads_df.drop_duplicates(subset='index')
    .set_index('index')['non-soma reads percentage']
)
# One vectorised lookup instead of a per-gene scan; markers without an entry in
# reads_df get NaN instead of aborting the cell with an IndexError.
plot_marker_df['non-soma reads percentage'] = plot_marker_df['names'].map(non_soma_share_by_gene)
print(plot_marker_df.shape)

markers_without_reads = plot_marker_df.loc[~plot_marker_df['names'].isin(non_soma_share_by_gene.index), 'names']
if len(markers_without_reads) > 0:
    print('markers missing from reads_df:', markers_without_reads.to_list())


In [None]:
# Violin plot of the non-soma read share of the marker genes, one violin per group.
fig, ax = plt.subplots(figsize=(7,3))
ax = sns.violinplot(
    data=plot_marker_df,
    x="group",
    y="non-soma reads percentage",
    palette="tab10",
    inner="box",
    ax=ax,
)
plt.xticks(rotation=45)
# plt.savefig(os.path.join(fig_path, 'violin_ncounts_sample.pdf'))
plt.show()

## Test

### Dendrogram

In [None]:
# Genes whose 'detected' flag in adata.var is True; these are passed as var_names below.
used_genes = adata.var.loc[adata.var['detected'] == True, :].index.to_list()
# Compute the dendrogram over the 'level_3' groups.
# NOTE(review): scanpy's documentation states that use_rep and n_pcs are ignored when
# var_names is given — confirm which of the two inputs is actually intended here.
sc.tl.dendrogram(adata, 'level_3', use_rep='X_pca', optimal_ordering=True, use_raw=True, n_pcs=30, var_names=used_genes)
# Tall, narrow canvas to fit the left-oriented tree.
fig, ax = plt.subplots(figsize=(1,20))
# Switches scanpy's figure format to pdf for this and all later cells.
sc.set_figure_params(format='pdf', dpi=150)
# save=False: the dendrogram is only displayed, not written to disk.
sc.pl.dendrogram(adata, 'level_3', orientation='left', ax=ax, save=False)

### Sankey diagram

In [None]:
def contingency(a, b, unique_a, unique_b):
    """Populate contingency matrix. Rows and columns are not normalized in any way.

    Entry (i, j) is the number of positions where ``a`` equals ``unique_a[i]``
    and ``b`` equals ``unique_b[j]``.

    Args:
        a (np.array): labels
        b (np.array): labels, compared position by position with ``a``
        unique_a (np.array): unique list of labels. Can have more entries than np.unique(a)
        unique_b (np.array): unique list of labels. Can have more entries than np.unique(b)

    Returns:
        pd.DataFrame: contingency matrix of float counts, indexed by ``unique_a``
        with ``unique_b`` as columns. Labels that never occur give all-zero rows/columns.
    """
    # assert a.shape == b.shape
    # Build every label's membership mask once; the original recomputed `a == la`
    # and `b == lb` for each (row, column) pair.
    masks_a = [a == la for la in unique_a]
    masks_b = [b == lb for lb in unique_b]

    C = np.zeros((np.size(unique_a), np.size(unique_b)))
    for i, mask_a in enumerate(masks_a):
        for j, mask_b in enumerate(masks_b):
            C[i, j] = np.sum(np.logical_and(mask_a, mask_b))

    df = pd.DataFrame(C)
    df.index = unique_a
    df.columns = unique_b

    return df

In [None]:
def generate_nodes(df, key1, key2):
    """Turn a contingency table into the node and link lists of a Sankey diagram.

    Args:
        df (pd.DataFrame): contingency table; rows are the source labels,
            columns are the target labels, cells are link weights.
        key1 (str): column name to use for the source labels in the link table.
        key2 (str): column name to use for the target labels in the link table.

    Returns:
        tuple: ``(df_river, all_nodes, source_indices, target_indices)`` where
            ``df_river`` is the long-format link table (columns ``key1``,
            ``key2``, ``'value'``) with zero-weight links removed,
            ``all_nodes`` lists the source labels followed by the target labels
            (each in order of first appearance in ``df_river``), and the two
            index lists give, per link, the position of its source / target
            node in ``all_nodes``.
    """
    # Long format, column-major: all rows of the first column, then the second, ...
    pairs = [(row_label, col_label)
             for col_label in df.columns.to_list()
             for row_label in df.index.to_list()]
    df_river = pd.DataFrame({
        key1: [row_label for row_label, _ in pairs],
        key2: [col_label for _, col_label in pairs],
        'value': [df.loc[row_label, col_label] for row_label, col_label in pairs],
    })
    # Zero-weight links would only clutter the diagram.
    df_river = df_river.loc[df_river['value'] != 0, :]

    source_nodes = df_river[key1].unique().tolist()
    target_nodes = df_river[key2].unique().tolist()
    all_nodes = source_nodes + target_nodes

    # Resolve each side against its own section of all_nodes. Searching the whole
    # list (as all_nodes.index(label) does) returns the *source* node whenever a
    # label exists on both sides, which turns the link into a self-loop.
    source_position = {label: pos for pos, label in enumerate(source_nodes)}
    target_position = {label: len(source_nodes) + pos for pos, label in enumerate(target_nodes)}
    source_indices = [source_position[label] for label in df_river[key1]]
    target_indices = [target_position[label] for label in df_river[key2]]

    return df_river, all_nodes, source_indices, target_indices

In [None]:
# Cell counts for every (level_1, level_2) label pair; the category lists fix
# the row and column order of the table.
level_12_df = contingency(
    a=adata.obs['level_1'].values,
    b=adata.obs['level_2'].values,
    unique_a=adata.obs['level_1'].cat.categories.values,
    unique_b=adata.obs['level_2'].cat.categories.values,
)
level_12_df

In [None]:
# Sankey diagram of how level_1 labels split into level_2 labels.
import plotly.graph_objects as go

df_river, all_nodes, source_indices, target_indices = generate_nodes(level_12_df, 'level_1', 'level_2')

# Colour lookup by label name, one dict per level.
# all_nodes lists labels in order of first appearance in df_river, which need not
# match the category order, so colours must not be assigned by position; separate
# dicts also keep a label that exists on both levels from overwriting itself.
# NOTE(review): assumes each *_color_list is aligned with the level's categories,
# as the combined diagram further down does — confirm.
level_1_color_of = dict(zip(adata.obs['level_1'].cat.categories, adata.uns['level_1_color_list'].tolist()))
level_2_color_of = dict(zip(adata.obs['level_2'].cat.categories, adata.uns['level_2_color_list'].tolist()))

# all_nodes = level_1 nodes followed by level_2 nodes.
n_level_1_nodes = df_river['level_1'].nunique()
node_colors = (
    [level_1_color_of[node] for node in all_nodes[:n_level_1_nodes]]
    + [level_2_color_of[node] for node in all_nodes[n_level_1_nodes:]]
)
# Every link takes the colour of its source node.
edge_colors = [level_1_color_of[node] for node in df_river['level_1']]

# Set to an output path to also write the figure to disk.
save_river = False
fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=20,
            thickness=20,
            line=dict(color="black", width=1.0),
            label=all_nodes,
            color=node_colors,
        ),

        link=dict(
            source=source_indices,
            target=target_indices,
            value=df_river['value'],
            color=edge_colors,
        ))])

fig.update_layout(title_text="level1 to level2",
                  height=600,
                  font_size=10)
if save_river:
    fig.write_image(save_river)
fig.show()

In [None]:
# Cell counts for every (level_2, level_3) label pair; the category lists fix
# the row and column order of the table.
level_23_df = contingency(
    a=adata.obs['level_2'].values,
    b=adata.obs['level_3'].values,
    unique_a=adata.obs['level_2'].cat.categories.values,
    unique_b=adata.obs['level_3'].cat.categories.values,
)
level_23_df

In [None]:
# Sankey diagram of how level_2 labels split into level_3 labels.
import plotly.graph_objects as go

df_river, all_nodes, source_indices, target_indices = generate_nodes(level_23_df, 'level_2', 'level_3')

# Manual node layout; only used if the x=/y= lines in the node dict are re-enabled.
node_x = [0] * 8 + [1] * 57
node_y = [i for i in range(8)] + [i*0.05 for i in range(57)]

# Colour lookup by label name, one dict per level.
# all_nodes lists labels in order of first appearance in df_river, which need not
# match the category order, so colours must not be assigned by position; separate
# dicts also keep a label that exists on both levels from overwriting itself.
# NOTE(review): assumes each *_color_list is aligned with the level's categories,
# as the combined diagram further down does — confirm.
level_2_color_of = dict(zip(adata.obs['level_2'].cat.categories, adata.uns['level_2_color_list'].tolist()))
level_3_color_of = dict(zip(adata.obs['level_3'].cat.categories, adata.uns['level_3_color_list'].tolist()))

# all_nodes = level_2 nodes followed by level_3 nodes.
n_level_2_nodes = df_river['level_2'].nunique()
node_colors = (
    [level_2_color_of[node] for node in all_nodes[:n_level_2_nodes]]
    + [level_3_color_of[node] for node in all_nodes[n_level_2_nodes:]]
)
# Every link takes the colour of its source node.
edge_colors = [level_2_color_of[node] for node in df_river['level_2']]

# Set to an output path to also write the figure to disk.
save_river = False
fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=10,
            thickness=20,
            line=dict(color="black", width=1.0),
            label=all_nodes,
            color=node_colors,
            # x=node_x,
            # y=node_y,
        ),

        link=dict(
            source=source_indices,
            target=target_indices,
            value=df_river['value'],
            color=edge_colors,
        ))])

fig.update_layout(
                  height=3000,
                  font_size=20)
if save_river:
    fig.write_image(save_river)
fig.show()

#### combine

In [None]:
# get tables
# One contingency table per pair of adjacent annotation levels.
level_12_df, level_23_df = (
    contingency(
        adata.obs[parent_level].values,
        adata.obs[child_level].values,
        adata.obs[parent_level].cat.categories.values,
        adata.obs[child_level].cat.categories.values,
    )
    for parent_level, child_level in (('level_1', 'level_2'), ('level_2', 'level_3'))
)

In [None]:
# Level-1 colours stored as a plain list (not an array) so they can be joined
# with other lists using `+`. Alternative palette: ['#db5f57', '#57d3db']
adata.uns['level_1_color_list_new'] = list(('#eded58', '#356be8'))

In [None]:
# get nodes
# Node list for the three-level Sankey: level_1, then level_2, then level_3
# labels, each in category order.

df_river_12, all_nodes_12, source_indices_12, target_indices_12 = generate_nodes(level_12_df, 'level_1', 'level_2')
df_river_23, all_nodes_23, source_indices_23, target_indices_23 = generate_nodes(level_23_df, 'level_2', 'level_3')

level_1_nodes = adata.obs['level_1'].cat.categories.to_list()
level_2_nodes = adata.obs['level_2'].cat.categories.to_list()
level_3_nodes = adata.obs['level_3'].cat.categories.to_list()
all_nodes = level_1_nodes + level_2_nodes + level_3_nodes

# Position of every label inside all_nodes, resolved per level by name.
# The indices returned by generate_nodes refer to its own node lists (labels in
# order of first appearance, unused labels dropped), so they cannot be reused
# against the category-ordered list above, and a fixed "+2" offset only holds
# while level_1 has exactly two categories.
level_2_offset = len(level_1_nodes)
level_3_offset = level_2_offset + len(level_2_nodes)
level_1_position = {label: pos for pos, label in enumerate(level_1_nodes)}
level_2_position = {label: level_2_offset + pos for pos, label in enumerate(level_2_nodes)}
level_3_position = {label: level_3_offset + pos for pos, label in enumerate(level_3_nodes)}

# level_1 -> level_2 links first, then level_2 -> level_3 links.
source_indices = (
    [level_1_position[label] for label in df_river_12['level_1']]
    + [level_2_position[label] for label in df_river_23['level_2']]
)
target_indices = (
    [level_2_position[label] for label in df_river_12['level_2']]
    + [level_3_position[label] for label in df_river_23['level_3']]
)

In [None]:
# get colors and values
# Colours are concatenated in the same level order as all_nodes, so colour k belongs to node k.
all_colors = adata.uns['level_1_color_list_new'] + adata.uns['level_2_color_list'].tolist() + adata.uns['level_3_color_list'].tolist()
if len(all_colors) < len(all_nodes):
    raise ValueError(f'{len(all_nodes)} nodes but only {len(all_colors)} colors')

# Assign node colours by position. Going through a single name -> colour dict would
# give two nodes the same colour whenever a label is used on more than one level.
node_colors = list(all_colors[:len(all_nodes)])

# Every link takes the colour of its source node; one lookup per level for the same reason.
level_1_color_of = dict(zip(adata.obs['level_1'].cat.categories, adata.uns['level_1_color_list_new']))
level_2_color_of = dict(zip(adata.obs['level_2'].cat.categories, adata.uns['level_2_color_list'].tolist()))
edge_colors = (
    [level_1_color_of[node] for node in df_river_12['level_1']]
    + [level_2_color_of[node] for node in df_river_23['level_2']]
)

# Link weights in the same order as the link index lists: level 1->2 first, then 2->3.
link_values = df_river_12['value'].to_list() + df_river_23['value'].to_list()

In [None]:
# Three-level Sankey diagram (level_1 -> level_2 -> level_3), always written to PDF.
import plotly.graph_objects as go
save_river = False

sankey_nodes = dict(
    pad=10,
    thickness=20,
    line=dict(color="black", width=1.0),
    label=all_nodes,
    color=node_colors,
)
# Links keep plotly's default colour; pass color=edge_colors to colour them by source.
sankey_links = dict(
    source=source_indices,
    target=target_indices,
    value=link_values,
)

fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])
fig.update_layout(width=1500, height=2000, font_size=20)

# The save_river flag is not consulted here: the PDF is written on every run.
fig.write_image(os.path.join(fig_path, 'sankey-test.pdf'))
fig.show()

In [None]:
# Load the tuned tile coordinates and drop the rows whose 'tile' value is 0.
coords_df4 = (
    pd.read_csv('Z:/jiahao/Github/RIBOmap/segmentation-stitching/RIBOmap/results/tuned_coords.csv', index_col=0)
    .loc[lambda tiles: tiles['tile'] != 0, :]
)
coords_df4

In [None]:
# Show where each tile sits relative to the cell centres: cells in grey, tile
# positions in black, tile numbers in red.
# The tile coordinates are shifted by 1000 on both axes before plotting.
# NOTE(review): the meaning of the 1000 offset is not visible here — confirm.
tile_col = coords_df4['column'].values + 1000
tile_row = coords_df4['row'].values + 1000
# Build the label list once instead of re-converting the whole column for every tile.
tile_labels = coords_df4['tile'].astype(str).to_list()

fig, ax = plt.subplots(figsize=(18,23))
sns.scatterplot(x='column', y='row', data=cell_center_df, color='#dbdbdb', s=dot_size, legend=False, edgecolor=None, ax=ax)
sns.scatterplot(x=tile_col, y=tile_row, color='black', s=dot_size, legend=False, edgecolor=None, ax=ax)
for label_col, label_row, tile_label in zip(tile_col, tile_row, tile_labels):
    ax.text(label_col, label_row, s=tile_label, size=10, c='red')
ax.set_aspect('equal')
plt.show()

In [None]:
# Reload the tuned tile coordinates (rows with tile == 0 dropped) and keep only
# the tiles inside a rectangular window.
coords_df4 = pd.read_csv('Z:/jiahao/Github/RIBOmap/segmentation-stitching/RIBOmap/results/tuned_coords.csv', index_col=0)
coords_df4 = coords_df4.loc[coords_df4['tile'] != 0, :]

# Window bounds; each range is half-open: min <= value < max.
col_min = 20000
row_min = 45000
col_max = 45000
row_max = 69000

# Plain comparisons instead of isin(range(...)): same result for integer
# coordinates, but non-integer values inside the window are no longer dropped
# and no 25000-element lookup set is built per axis.
corrds_logical = (
    (coords_df4['column'] >= col_min) & (coords_df4['column'] < col_max)
    & (coords_df4['row'] >= row_min) & (coords_df4['row'] < row_max)
)
scoords_df = coords_df4.loc[corrds_logical, :]
scoords_df