# Gene cluster visualization

2023-04-04

In [None]:
import os
import random as rd
import sys

import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
import matplotlib.pyplot as plt
# NOTE: the wildcard import is kept at its original position relative to the
# imports around it, so any names it shadows (or that shadow it) are unchanged.
from starmap.sequencing import *
from natsort import natsorted
from scipy.io import loadmat, savemat
from skimage.filters import threshold_otsu
from skimage.color import label2rgb
from tqdm.notebook import tqdm

from skimage.measure import find_contours
from matplotlib.colors import ListedColormap
from skimage.segmentation import find_boundaries

### Functions

In [None]:
from functools import wraps
from time import time

# Timer
def timer(func):
    """Decorator that prints how long each call to *func* took.

    The elapsed time (seconds, millisecond resolution) is printed after every
    call, including calls that raise. The wrapped function's return value or
    exception is passed through unchanged.
    """
    # perf_counter is monotonic, so the measurement cannot be distorted by
    # system clock adjustments the way a wall-clock difference can.
    from time import perf_counter

    @wraps(func)
    def _time_it(*args, **kwargs):
        start = perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            elapsed = round(perf_counter() - start, 3)
            print(f"Total execution time: {elapsed if elapsed > 0 else 0} s")
    return _time_it


# Load reads and their positions from mat file
def load_reads(fpath, reads_file):
    """Read decoded reads and their coordinates from a MATLAB ``.mat`` file.

    Parameters
    ----------
    fpath : str
        Directory that contains the file.
    reads_file : str
        File name of the ``.mat`` file. It must hold the variables
        ``merged_reads`` and ``merged_points``.

    Returns
    -------
    bases : list of str
        One string per entry of ``merged_reads``.
    coords : numpy.ndarray
        Float array with the shape of ``merged_points``. Output columns 0, 1, 2
        are input columns 1, 0, 2 respectively, each reduced by 1 and rounded
        (presumably converting MATLAB's 1-based coordinates to 0-based ones -
        TODO confirm). Any further columns are left at zero.
    """
    mat = loadmat(os.path.join(fpath, reads_file))
    bases = [str(entry[0][0]) for entry in mat["merged_reads"]]

    raw = mat["merged_points"]
    coords = np.zeros(raw.shape)
    # (destination column, source column): the first two columns are swapped.
    for dst, src in enumerate((1, 0, 2)):
        coords[:, dst] = np.round(raw[:, src] - 1)

    print(f"Number of reads: {len(bases)}")

    return bases, coords


# Load gene table from genes.csv
def load_genes(fpath):
    """Build gene <-> encoded-barcode lookup tables from ``genes.csv``.

    Each non-blank line of ``<fpath>/genes.csv`` is split on commas; the first
    field is the gene name and the second its barcode sequence. The barcode is
    reversed, passed through ``encode_SOLID``, and every returned value is
    incremented by one and concatenated into a digit string.

    Parameters
    ----------
    fpath : str
        Directory that contains ``genes.csv``.

    Returns
    -------
    genes2seq : dict
        Gene name -> encoded barcode string.
    seq2genes : dict
        Encoded barcode string -> gene name.
    """
    genes2seq = {}
    seq2genes = {}
    # utf-8-sig drops a leading byte-order mark if the file has one
    with open(os.path.join(fpath, "genes.csv"), encoding='utf-8-sig') as f:
        for line in f:
            line = line.rstrip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no gene entry.
                continue
            fields = line.split(",")
            gene, barcode = fields[0], fields[1]
            encoded = "".join(str(s + 1) for s in encode_SOLID(barcode[::-1]))
            genes2seq[gene] = encoded
            seq2genes[encoded] = gene
    return genes2seq, seq2genes

## Input path

In [None]:
# IO path
# base_path = 'path/to/2022-03-28-TEMPOmap-images'
base_path = 'K:/2022-03-28-TEMPOmap-images/'

# Result tables and figures each get their own folder directly under base_path;
# both are created on first use.
out_path = os.path.join(base_path, 'output')
fig_path = os.path.join(base_path, 'figures')
for _folder in (out_path, fig_path):
    if not os.path.exists(_folder):
        os.mkdir(_folder)

# Sample directory name -> human-readable label
sample_dict = dict([
    ('20h_labeling', '20h labeling'),
    ('1h_labeling_6h_wash', '1h labeling, 6h wash'),
    ('1h_labeling_4h_wash', '1h labeling, 4h wash'),
    ('1h_labeling_2h_wash', '1h labeling, 2h wash'),
    ('1h_labeling_1h_wash', '1h labeling, 1h wash'),
    ('1h_labeling', '1h labeling'),
])

# Directory names in the order they were declared above
sample_dirs = [*sample_dict]
sample_dirs

## Load gene cluster labels

In [None]:
# Load the per-gene m6A annotation table
m6a_label_file = os.path.join(base_path, 'gene_modules/2022-01-09-m6A-gene-label.xlsx')
m6a_dc1_df = pd.read_excel(m6a_label_file)
m6a_dc1_df.columns = ['gene', 'm6A label']

# Numeric encoding of the label: 'm6A' -> 0, 'non m6A' -> 1, anything else -> 999
m6a_codes = {'m6A': 0, 'non m6A': 1}
m6a_dc1_df['m6A label new'] = (
    m6a_dc1_df['m6A label'].map(m6a_codes).fillna(999).astype('int64')
)

m6a_dc1_df

In [None]:
# Load the gene -> cluster assignment table (two columns, renamed below)
gene_cluster_xlsx = os.path.join(base_path, 'gene_modules/2022-05-16-TEMPOmap_new_gene_group.xlsx')
gene_cluster_df = pd.read_excel(gene_cluster_xlsx)
gene_cluster_df.columns = ['gene', 'cluster']
gene_cluster_df

## Visualization

### Single cell

In [None]:
# plot_single_cell_pattern
def plot_single_cell_pattern(base_path=None, sample=None, orig_index=None,
                             input_df=None, plot_field=None, offset=None, bnd_linewidth=2, spot_size=15,
                             nuclei_color=None, cell_color=None, spot_colors=None,
                             show=True, save=False, save_path=None, file_name=None):
    """Plot the reads of one cell, coloured by a gene annotation, on a square canvas.

    The cell's bounding box is cut out of the 2D label images and centred on an
    ``offset`` x ``offset`` canvas. The cell and nucleus outlines are drawn and
    the reads assigned to the cell (via the 3D segmentation) are scattered on
    top, coloured by ``plot_field``.

    Parameters
    ----------
    base_path : str
        Root directory of the dataset; ``sample`` is a sub-directory of it.
    sample : str
        Sample directory name. It must contain ``2D/dapi_label.tiff``,
        ``2D/overlay_label.tiff``, ``3D/cell.tif`` and
        ``merged_goodPoints_max3d_new.mat``.
    orig_index : int
        Zero-based cell index; the cell's label in the label images is
        ``orig_index + 1``.
    input_df : pandas.DataFrame
        Annotation table merged onto the reads on their shared columns
        (presumably ``gene`` - verify against caller). It must provide the
        ``plot_field`` column.
    plot_field : str
        Column of the merged table used as the scatter hue.
    offset : int
        Side length, in pixels, of the square canvas.
    bnd_linewidth : float
        Line width of the cell and nucleus outlines.
    spot_size : float
        Marker size of the reads.
    nuclei_color, cell_color
        Matplotlib colours of the nucleus and cell outlines.
    spot_colors : list
        Colours handed to ``seaborn.color_palette`` for the hue levels.
    show : bool
        Show the figure; when False the figure is cleared and closed instead.
    save : bool
        Save the figure to ``save_path/file_name`` before showing/closing.
    save_path, file_name : str
        Output directory and file name; required when ``save`` is True.

    Raises
    ------
    ValueError
        If ``save`` is requested without a target, the cell label does not
        exist in the cell label image, or the cell does not fit on the canvas.
    """
    # Fail before the expensive image loading if the output target is missing.
    if save and (save_path is None or file_name is None):
        raise ValueError("save_path and file_name must be provided when save=True")

    # set path
    input_path = os.path.join(base_path, sample)

    # load image
    current_dapi_label = load_label_image(os.path.join(input_path, '2D'), fname='dapi_label.tiff')
    current_cell_label = load_label_image(os.path.join(input_path, '2D'), fname='overlay_label.tiff')
    current_segmentation = load_label_image(os.path.join(input_path, '3D'), fname='cell.tif')

    # Load reads
    bases, points = load_reads(input_path, "merged_goodPoints_max3d_new.mat")
    bases = np.array(bases)
    points = np.array(points).astype(int)

    # find cell
    label_num = orig_index + 1
    cell_region = next(
        (region for region in regionprops(current_cell_label) if region.label == label_num),
        None,
    )
    if cell_region is None:
        raise ValueError(
            f"Cell with orig_index={orig_index} (label {label_num}) not found in sample '{sample}'"
        )

    cell = current_cell_label == label_num
    # nucleus labels restricted to the footprint of this cell
    dapi = np.where(cell, current_dapi_label, 0)

    cell_bbox = cell_region.bbox
    rows = slice(cell_bbox[0], cell_bbox[2])
    cols = slice(cell_bbox[1], cell_bbox[3])
    cell_crop = cell[rows, cols]
    dapi_crop = dapi[rows, cols] > 0

    crop_h, crop_w = cell_crop.shape
    if crop_h > offset or crop_w > offset:
        raise ValueError(
            f"Cell bounding box {cell_crop.shape} does not fit on a {offset}x{offset} canvas"
        )

    # get alignment: place the bounding box so that it is centred on the canvas
    bbox_new_origin = [int(offset / 2 - crop_h / 2), int(offset / 2 - crop_w / 2)]
    target = (
        slice(bbox_new_origin[0], bbox_new_origin[0] + crop_h),
        slice(bbox_new_origin[1], bbox_new_origin[1] + crop_w),
    )

    # get nuclei / cell contour (the nucleus mask overwrites the cell mask in the same window)
    bg = np.zeros([offset, offset], np.uint8)
    bg[target] = cell_crop
    cell_bnd = find_contours(bg > 0, level=.5)

    bg[target] = dapi_crop
    dapi_bnd = find_contours(bg > 0, level=.5)

    # get reads: look up each read's cell label in the 3D segmentation, indexed [z, x, y]
    reads_assignment = current_segmentation[points[:, 2], points[:, 0], points[:, 1]]

    reads_info = pd.DataFrame({'x': points[:, 0], 'y': points[:, 1], 'z': points[:, 2], 'cell_label': reads_assignment})
    reads_info = reads_info.astype(np.int32)
    reads_info['orig_index'] = reads_info['cell_label'] - 1
    reads_info['gene'] = bases
    reads_info = reads_info.loc[reads_info['orig_index'] == orig_index, :]
    reads_info = pd.merge(reads_info, input_df)

    # get reads shift: move reads from image coordinates into canvas coordinates
    shifts = [bbox_new_origin[0] - cell_bbox[0], bbox_new_origin[1] - cell_bbox[1]]
    reads_info['x'] = reads_info['x'] + shifts[0]
    reads_info['y'] = reads_info['y'] + shifts[1]

    # set color
    cpl = sns.color_palette(spot_colors)

    bg = np.ones([offset, offset], np.uint8) * 255

    plt.subplots(figsize=(10, 10))
    plt.imshow(bg, cmap='Greys')
    # contours come back as (row, col) pairs, so the column is the plot's x
    for contour in dapi_bnd:
        plt.plot(contour[:, 1], contour[:, 0], linewidth=bnd_linewidth, c=nuclei_color)

    for contour in cell_bnd:
        plt.plot(contour[:, 1], contour[:, 0], linewidth=bnd_linewidth, c=cell_color)

    # the 'x' column is a row coordinate (see the shift above), hence the swap
    sns.scatterplot(x="y", y="x", hue=plot_field,
                    linewidth=0, s=spot_size, palette=cpl,
                    data=reads_info)
    plt.axis('off')

    if save:
        plt.savefig(os.path.join(save_path, file_name))

    if show:
        plt.show()
    else:
        plt.clf()
        plt.close()

In [None]:
# Parameters for one example cell
sample = '1h_labeling'
orig_index = 2171
input_df = gene_cluster_df
plot_field = 'cluster'
offset = 750

# Black outlines for nucleus and cell
# (white alternative: np.array((255, 255, 255)) / 255)
nuclei_color = np.array((0, 0, 0))
cell_color = np.array((0, 0, 0))

# Four-colour palettes tried earlier:
#   ['#A895E6', '#22E6A2', '#E6350B', '#EBDB34']
#   ['#F8766D', '#7CAE02', '#00BFC4', '#C77CFF']  # original
#   ['#E6350B', '#32b30e', '#00BFC4', '#ad40ff']  # bright
spot_colors = ['#A895E6', '#22E6A2', '#E6350B', '#EBDB34', '#d340e3'] # original 5 colors

file_name = f'white_bg_{sample}_{orig_index}_original.pdf'

# Render without displaying and write the figure into fig_path
single_cell_kwargs = dict(
    base_path=base_path,
    sample=sample,
    orig_index=orig_index,
    input_df=input_df,
    plot_field=plot_field,
    offset=offset,
    bnd_linewidth=2,
    spot_size=20,
    nuclei_color=nuclei_color,
    cell_color=cell_color,
    spot_colors=spot_colors,
    show=False,
    save=True,
    save_path=fig_path,
    file_name=file_name,
)
plot_single_cell_pattern(**single_cell_kwargs)


In [None]:
# Hand-picked cells (orig_index values) to export for each sample
cell_dict = {
    '20h_labeling': [55, 1262, 963, 222],
    '1h_labeling': [2169, 2178, 1098, 1461],
    '1h_labeling_2h_wash': [1347, 1369, 1669, 1273],
    '1h_labeling_4h_wash': [1694, 1917, 1755, 700],
    '1h_labeling_6h_wash': [231, 120, 1600, 1130],
}

input_df = gene_cluster_df
plot_field = 'cluster'
offset = 750

# Black outlines for nucleus and cell
# (white alternative: np.array((255, 255, 255)) / 255)
nuclei_color = np.array((0, 0, 0))
cell_color = np.array((0, 0, 0))

# Four-colour palettes tried earlier:
#   ['#A895E6', '#22E6A2', '#E6350B', '#EBDB34']
#   ['#F8766D', '#7CAE02', '#00BFC4', '#C77CFF']  # original
#   ['#E6350B', '#32b30e', '#00BFC4', '#ad40ff']  # bright
spot_colors = ['#F8766D', '#A3A500', '#00BF7D', '#00B0F6', '#E76BF3'] # original 5 colors

# All figures of this cell go into one sub-folder of fig_path
save_path = os.path.join(fig_path, 'cluster-sc')
if not os.path.exists(save_path):
    os.mkdir(save_path)

# Settings shared by every cell; only sample, cell and file name vary per call
shared_plot_kwargs = dict(
    base_path=base_path,
    input_df=input_df,
    plot_field=plot_field,
    offset=offset,
    bnd_linewidth=2,
    spot_size=20,
    nuclei_color=nuclei_color,
    cell_color=cell_color,
    spot_colors=spot_colors,
    show=False,
    save=True,
    save_path=save_path,
)

for current_sample, selected_cells in tqdm(cell_dict.items()):
    for current_cell in selected_cells:
        file_name = f'white_bg_{current_sample}_{current_cell}_original.pdf'
        plot_single_cell_pattern(sample=current_sample,
                                 orig_index=current_cell,
                                 file_name=file_name,
                                 **shared_plot_kwargs)


### Batch

In [None]:
# plot_single_cell_pattern_batch
def plot_single_cell_pattern_batch(base_path=None, sample=None, orig_index_list=None,
                                   input_df=None, plot_field=None, offset=None, bnd_linewidth=2, spot_size=15,
                                   nuclei_color=None, cell_color=None, spot_colors=None,
                                   show=True, save=False, save_path=None):
    """Plot the reads of several cells of one sample, one figure per cell.

    The label images and reads of ``sample`` are loaded once. For every entry
    of ``orig_index_list`` the cell's bounding box is centred on an
    ``offset`` x ``offset`` canvas, the cell and nucleus outlines are drawn and
    the reads assigned to the cell (via the 3D segmentation) are scattered on
    top, coloured by ``plot_field``.

    Parameters
    ----------
    base_path : str
        Root directory of the dataset; ``sample`` is a sub-directory of it.
    sample : str
        Sample directory name. It must contain ``2D/dapi_label.tiff``,
        ``2D/overlay_label.tiff``, ``3D/cell.tif`` and
        ``merged_goodPoints_max3d_new.mat``.
    orig_index_list : iterable of int
        Zero-based cell indices; each cell's label in the label images is
        ``orig_index + 1``.
    input_df : pandas.DataFrame
        Annotation table merged onto the reads on their shared columns
        (presumably ``gene`` - verify against caller). It must provide the
        ``plot_field`` column.
    plot_field : str
        Column of the merged table used as the scatter hue.
    offset : int
        Side length, in pixels, of the square canvas.
    bnd_linewidth : float
        Line width of the cell and nucleus outlines.
    spot_size : float
        Marker size of the reads.
    nuclei_color, cell_color
        Matplotlib colours of the nucleus and cell outlines.
    spot_colors : list
        Colours handed to ``seaborn.color_palette`` for the hue levels.
    show : bool
        Show each figure; when False the figure is cleared and closed instead.
    save : bool
        Save each figure as ``save_path/Cell-<orig_index>.pdf``.
    save_path : str
        Output directory; required when ``save`` is True.

    Raises
    ------
    ValueError
        If ``save`` is requested without ``save_path``, a cell label does not
        exist in the cell label image, or a cell does not fit on the canvas.
    """
    # Fail before the expensive image loading if the output target is missing.
    if save and save_path is None:
        raise ValueError("save_path must be provided when save=True")

    # set path
    input_path = os.path.join(base_path, sample)

    # load image
    current_dapi_label = load_label_image(os.path.join(input_path, '2D'), fname='dapi_label.tiff')
    current_cell_label = load_label_image(os.path.join(input_path, '2D'), fname='overlay_label.tiff')
    current_segmentation = load_label_image(os.path.join(input_path, '3D'), fname='cell.tif')

    # Load reads
    bases, points = load_reads(input_path, "merged_goodPoints_max3d_new.mat")
    bases = np.array(bases)
    points = np.array(points).astype(int)

    # Everything below up to the loop is identical for all cells, so it is
    # computed once instead of once per cell.
    # Look up each read's cell label in the 3D segmentation, indexed [z, x, y].
    reads_assignment = current_segmentation[points[:, 2], points[:, 0], points[:, 1]]
    all_reads = pd.DataFrame({'x': points[:, 0], 'y': points[:, 1], 'z': points[:, 2], 'cell_label': reads_assignment})
    all_reads = all_reads.astype(np.int32)
    all_reads['orig_index'] = all_reads['cell_label'] - 1
    all_reads['gene'] = bases

    regions_by_label = {region.label: region for region in regionprops(current_cell_label)}

    # set color
    cpl = sns.color_palette(spot_colors)

    for orig_index in tqdm(orig_index_list):
        # find cell
        label_num = orig_index + 1
        cell_region = regions_by_label.get(label_num)
        if cell_region is None:
            raise ValueError(
                f"Cell with orig_index={orig_index} (label {label_num}) not found in sample '{sample}'"
            )

        cell = current_cell_label == label_num
        # nucleus labels restricted to the footprint of this cell
        dapi = np.where(cell, current_dapi_label, 0)

        cell_bbox = cell_region.bbox
        rows = slice(cell_bbox[0], cell_bbox[2])
        cols = slice(cell_bbox[1], cell_bbox[3])
        cell_crop = cell[rows, cols]
        dapi_crop = dapi[rows, cols] > 0

        crop_h, crop_w = cell_crop.shape
        if crop_h > offset or crop_w > offset:
            raise ValueError(
                f"Cell {orig_index}: bounding box {cell_crop.shape} does not fit on a {offset}x{offset} canvas"
            )

        # get alignment: place the bounding box so that it is centred on the canvas
        bbox_new_origin = [int(offset / 2 - crop_h / 2), int(offset / 2 - crop_w / 2)]
        target = (
            slice(bbox_new_origin[0], bbox_new_origin[0] + crop_h),
            slice(bbox_new_origin[1], bbox_new_origin[1] + crop_w),
        )

        # get nuclei / cell contour (the nucleus mask overwrites the cell mask in the same window)
        bg = np.zeros([offset, offset], np.uint8)
        bg[target] = cell_crop
        cell_bnd = find_contours(bg > 0, level=.5)

        bg[target] = dapi_crop
        dapi_bnd = find_contours(bg > 0, level=.5)

        # get reads of this cell and attach the annotation
        reads_info = all_reads.loc[all_reads['orig_index'] == orig_index, :]
        reads_info = pd.merge(reads_info, input_df)

        # get reads shift: move reads from image coordinates into canvas coordinates
        shifts = [bbox_new_origin[0] - cell_bbox[0], bbox_new_origin[1] - cell_bbox[1]]
        reads_info['x'] = reads_info['x'] + shifts[0]
        reads_info['y'] = reads_info['y'] + shifts[1]

        bg = np.zeros([offset, offset], np.uint8)
        plt.subplots(figsize=(10, 10))
        plt.imshow(bg, cmap='Greys')
        # contours come back as (row, col) pairs, so the column is the plot's x
        for contour in dapi_bnd:
            plt.plot(contour[:, 1], contour[:, 0], linewidth=bnd_linewidth, c=nuclei_color)

        for contour in cell_bnd:
            plt.plot(contour[:, 1], contour[:, 0], linewidth=bnd_linewidth, c=cell_color)

        # the 'x' column is a row coordinate (see the shift above), hence the swap
        sns.scatterplot(x="y", y="x", hue=plot_field,
                        linewidth=0, s=spot_size, palette=cpl,
                        data=reads_info)
        plt.axis('off')

        if save:
            plt.savefig(os.path.join(save_path, f"Cell-{orig_index}.pdf"))

        if show:
            plt.show()
        else:
            plt.clf()
            plt.close()

#### Output

In [None]:
# import adata (cell cycle)
adata = sc.read_h5ad(os.path.join(base_path, 'output', '2023-04-05-Rena-EU-starmap-cc.h5ad'))

# basic parameters
input_df = gene_cluster_df
plot_field = 'cluster'
offset = 750
nuclei_color = np.array((0, 0, 0))
cell_color = np.array((0, 0, 0))
spot_colors = ['#F8766D', '#A3A500', '#00BF7D', '#00B0F6', '#E76BF3'] # original 5 colors
cell_to_select = 50  # maximum number of cells plotted per sample

# set output path
# NOTE: this rebinds `out_path`, which the IO-path cell set to <base_path>/output.
out_path = os.path.join(fig_path, '2022-05-22-cluster-sc')
os.makedirs(out_path, exist_ok=True)

# iterate sample
for i, sample in enumerate(tqdm(sample_dirs)):
    print(f"Plotting: {sample}")

    # get orig_index_list: cells of this sample whose reference phase is G1
    orig_index_list = adata.obs.loc[(adata.obs['sample'] == sample) & (adata.obs['phase_ref'] == 'G1'), 'orig_index'].to_list()

    if not orig_index_list:
        print(f"No G1 cells found for {sample}, skipping")
        continue

    # random sample; rd.sample raises ValueError when asked for more items than
    # are available, so the request is capped at the number of G1 cells.
    # The draw is unseeded, so the selection differs between runs.
    n_select = min(cell_to_select, len(orig_index_list))
    if n_select < cell_to_select:
        print(f"Only {n_select} G1 cells available for {sample} (requested {cell_to_select})")
    orig_index_list = rd.sample(orig_index_list, n_select)

    save_path = os.path.join(out_path, sample)
    os.makedirs(save_path, exist_ok=True)

    plot_single_cell_pattern_batch(base_path=base_path,
                                   sample=sample,
                                   orig_index_list=orig_index_list,
                                   input_df=input_df,
                                   plot_field=plot_field,
                                   offset=offset,
                                   bnd_linewidth=2,
                                   spot_size=20,
                                   nuclei_color=nuclei_color,
                                   cell_color=cell_color,
                                   spot_colors=spot_colors,
                                   show=False,
                                   save=True,
                                   save_path=save_path)

### Sample-wise

In [None]:
# Iterate through each sample dir
# For every sample: draw all cell-assigned reads of each gene cluster on top of
# the cell / nucleus boundary image, once per cluster and once with all
# clusters combined, and save the results as TIFF files.
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')

color_dict = {1:'#A895E6', 2:'#22E6A2', 3:'#E6350B', 4:'#EBDB34', 5:'#d340e3'}
plot_field = 'cluster'

linewidth = .1  # not used in this cell

# Load genes (does not depend on the sample, so it is done once)
genes2seqs, seqs2genes = load_genes(base_path)

# Clusters to draw; 999 is skipped (presumably a placeholder for unlabelled genes - verify)
cluster_ids = [c for c in sorted(gene_cluster_df[plot_field].unique()) if c != 999]

for current_dir in sample_dirs:

    print(f"Current sample: {current_dir}")
    input_path = os.path.join(base_path, current_dir)

    # Load reads
    bases, points = load_reads(input_path, "merged_goodPoints_max3d_new.mat")
    bases = np.array(bases)
    points = np.array(points).astype(int)

    structure_dict = {'whole_cell': 'cell.tif'}

    # os.path.join keeps the path valid whether or not base_path ends with a separator
    fig_out_path = os.path.join(base_path, 'figures', f'{date}_{plot_field}', current_dir)
    os.makedirs(fig_out_path, exist_ok=True)

    # Load raw image
    current_dapi_label = load_label_image(os.path.join(input_path, '2D'), fname='dapi_label.tiff')
    current_cell_label = load_label_image(os.path.join(input_path, '2D'), fname='overlay_label.tiff')

    cell_bnd = find_boundaries(current_cell_label)
    dapi_bnd = find_boundaries(current_dapi_label)

    # Boundary label image: 1 = cell boundary, 2 = nucleus boundary (nucleus wins on overlap)
    cell_bnd = cell_bnd.astype(np.uint8)
    cell_bnd[dapi_bnd] = 2
    # NOTE(review): the other colours here are on a 0-1 scale; (0,0,255) looks like
    # it relies on clipping at display time to become blue - confirm.
    bg = label2rgb(cell_bnd, colors=[(0,0,0), (0,0,255)], bg_label=0, bg_color=(1,1,1))

    # matplotlib's figsize is (width, height): the image's columns give the
    # width and its rows the height; at dpi=1000 this is one pixel per dot.
    fig_size = (bg.shape[1] / 1000, bg.shape[0] / 1000)

    for current_structure in structure_dict.keys():
        print(f"====Processing: {current_structure}====")

        # Load segmentation
        current_seg = load_label_image(os.path.join(input_path, '3D'), fname=structure_dict[current_structure])

        # Look up each read's label in the 3D segmentation, indexed [z, x, y]
        reads_assignment = current_seg[points[:, 2], points[:, 0], points[:, 1]]

        reads_info = pd.DataFrame({'x':points[:, 0], 'y':points[:, 1], 'z':points[:, 2], 'cell_label':reads_assignment})
        reads_info = reads_info.astype(np.int32)
        reads_info['orig_index'] = reads_info['cell_label'] - 1
        reads_info['gene'] = bases

    # The plots below use reads_info of the last structure processed above
    # (structure_dict has a single entry).
    # remove unassigned reads, then split the rest by cluster once for both plot passes
    assigned_reads = reads_info.loc[reads_info['cell_label'] != 0, :]
    reads_by_cluster = {}
    for current_cluster in cluster_ids:
        current_gene_list = gene_cluster_df.loc[gene_cluster_df[plot_field] == current_cluster, 'gene'].to_list()
        reads_by_cluster[current_cluster] = assigned_reads.loc[assigned_reads['gene'].isin(current_gene_list), :]

    # One figure per cluster
    for current_cluster in cluster_ids:
        print(f"Current field: {plot_field}, Current cluster: {current_cluster}")
        current_reads_df = reads_by_cluster[current_cluster]

        plt.figure(figsize=fig_size, dpi=1000)
        plt.imshow(bg)
        # the 'x' column indexes image rows, so it goes on the plot's vertical axis
        plt.plot(current_reads_df['y'], current_reads_df['x'], '.', color=color_dict[current_cluster], markersize=.3, markeredgewidth=0.0)
        plt.axis('off')
        plt.tight_layout(pad=0)
        # plt.show()
        current_fig_path = os.path.join(fig_out_path, f"cluster_{current_cluster}.tiff")
        plt.savefig(current_fig_path, dpi=1000, bbox_inches='tight', pad_inches=0)
        plt.clf()
        plt.close()

    # One figure with all clusters overlaid
    plt.figure(figsize=fig_size, dpi=1000)
    plt.imshow(bg)

    for current_cluster in cluster_ids:
        current_reads_df = reads_by_cluster[current_cluster]
        plt.plot(current_reads_df['y'], current_reads_df['x'], '.', color=color_dict[current_cluster], markersize=.3, markeredgewidth=0.0)
    plt.axis('off')
    plt.tight_layout(pad=0)
    # plt.show()
    current_fig_path = os.path.join(fig_out_path, "cluster_all.tiff")
    plt.savefig(current_fig_path, dpi=1000, bbox_inches='tight', pad_inches=0)
    plt.clf()
    plt.close()