In [2]:
import os
from pathlib import Path
from tqdm import tqdm

import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
plt.rcParams.update({
    "pgf.texsystem": "xelatex",
    'font.family': 'serif',
    'text.usetex': False,
    'pgf.rcfonts': False,
    'figure.dpi': 300,
})

from skimage.io import imread
from skimage.io import imsave

In [3]:
# Root folder of the processed data and the identifier of the run analysed here.
BASE_DIR = Path(r"F:\spatial_data\processed")
RUN_ID = '20240731_TNBC_BZ02_CA2_TCR_immune_marker'

# Sub-folders of this run that the notebook reads from.
src_dir = BASE_DIR.joinpath(f'{RUN_ID}_processed')
stc_dir, read_dir, seg_dir = (src_dir / sub for sub in ('stitched', 'readout', 'segmented'))

# Output folder for the per-cell crops. Only this last level is created, so a
# wrong RUN_ID fails here instead of silently creating a new directory tree.
check_dir = seg_dir.joinpath('check_cell')
check_dir.mkdir(exist_ok=True)

1. Extract the Actb spots produced by the two approaches.
2. Find the subtracted spots (the difference between the two approaches) and collect some of their coordinates.
3. Locate examples at those coordinates in the raw (stitched) images.

# Load data

In [4]:
# Load data of intensity and mapped genes
# intensity = pd.read_csv(read_dir/'intensity_deduplicated.csv', index_col=0)
# print(len(intensity))
# intensity.head()

# mapped_genes = pd.read_csv(read_dir / 'mapped_genes.csv', index_col=0)
# print(len(mapped_genes))
# mapped_genes.head()

In [5]:
# Read the spot table (first CSV column becomes the index), report how many
# rows it has, and show the first few of them.
mapped_genes = pd.read_csv(read_dir.joinpath('mapped_genes.csv'), index_col=0)
print(mapped_genes.shape[0])
mapped_genes.head()

2055145


Unnamed: 0,Y,X,Gene
1778,5838,8287,EPCAM
1779,7532,6467,EPCAM
1780,13767,12485,EPCAM
1781,4422,10847,EPCAM
1783,12724,13546,EPCAM


In [6]:
# Centroid table: the CSV has no header row, so its two columns are named here.
# The names must be a flat list -- a nested list ([['Y', 'X']]) makes pandas
# build a MultiIndex, after which row['X'] no longer returns a plain scalar.
dapi_centroids = pd.read_csv(seg_dir / 'dapi_centroids.csv', header=None)
dapi_centroids.columns = ['Y', 'X']

# Per-cell tables stored next to the centroids; the first CSV column is the index.
expression_matrix = pd.read_csv(seg_dir / 'expression_matrix.csv', index_col=0)
clonotype_proportion = pd.read_csv(seg_dir / 'clonotype_proportion.csv', index_col=0)

# Intensity analysis

In [7]:
# # crosstalk elimination
# intensity['B'] = intensity['B'] - intensity['G'] * 0.25
# intensity['B'] = np.maximum(intensity['B'], 0)

# # Scale
# intensity['Scaled_R'] = intensity['R']
# intensity['Scaled_Ye'] = intensity['Ye']
# intensity['Scaled_G'] = intensity['G'] * 2.5
# intensity['Scaled_B'] = intensity['B']

# # threshold by intensity
# intensity['sum'] = intensity['Scaled_R'] + intensity['Scaled_Ye'] + intensity['Scaled_B']
# intensity['sum_G'] = intensity['Scaled_R'] + intensity['Scaled_Ye'] + intensity['Scaled_B'] + intensity['Scaled_G']

In [8]:
# plt.figure(figsize=(20,10))
# plt.hist(intensity['sum'], bins=200, range=[0, 10000])
# plt.show()

# SNR analysis

In [9]:
# intensity.loc[intensity['snr']>20, 'snr'] = 20
# plt.figure(figsize=(20,10))
# plt.hist(intensity.loc[:, 'snr'], bins=1000)
# plt.show()

In [10]:
# plt.figure(figsize=(20,10))
# plt.hist(intensity['sum_G'], bins=200, range=[100,5000])
# plt.show()

In [11]:
# data = intensity[intensity['sum']>750]
# data = data[data['label']>15]

# label_counts = data['label'].value_counts()
# label_percentages = label_counts / label_counts.sum()

# plt.figure(figsize=(8, 8))  # set the figure size
# plt.pie(label_percentages, labels=label_counts.index, autopct='%1.1f%%', startangle=90)
# plt.axis('equal')
# plt.legend()
# plt.show()

# Save raw image crops and spots around sampled cells

In [12]:
channels = ['cy5','cy3','FAM','TxRed','DAPI']
cut_size = 200          # side length (pixels) of the square window around each centroid
sample_per_type = 20    # maximum number of cells drawn per clonotype
sample_seed = 0         # fixed seed: a re-run selects the same cells instead of adding new files

# Draw up to `sample_per_type` cells of every clonotype and look up their centroids.
cells = {}
cell_type_list = sorted(clonotype_proportion['Determined_Clonotype'].unique().tolist())
for cell_type in cell_type_list:
    celltype_df = clonotype_proportion[clonotype_proportion['Determined_Clonotype'] == cell_type]
    celltype_index = celltype_df.sample(n=min(len(celltype_df), sample_per_type),
                                        random_state=sample_seed).index
    # assumes the index of clonotype_proportion matches the row labels of dapi_centroids -- TODO confirm
    cells[cell_type] = dapi_centroids.loc[celltype_index]

half = cut_size // 2
for chn in channels:
    # Each stitched channel image is read once and cropped for every sampled cell.
    im = imread(stc_dir / f'cyc_1_{chn}.tif')
    height, width = im.shape[0], im.shape[1]
    for cell_index, celltype in enumerate(cells, start=1):
        for index, centroid in cells[celltype].iterrows():
            x_cen = int(centroid['X'])
            y_cen = int(centroid['Y'])
            # Window clamped to the image, so crops near a border are smaller than cut_size.
            x_st = max(0, x_cen - half)
            y_st = max(0, y_cen - half)
            x_ed = min(width, x_cen + half)
            y_ed = min(height, y_cen + half)

            # Spots inside the window, written once per cell (skipped for the later channels).
            # NOTE: a CSV left over from an earlier run is kept as-is; delete it to regenerate.
            spots_path = check_dir / f'{cell_index}_{celltype}_{index}_crop.csv'
            if not spots_path.exists():
                # Half-open window [start, end), the same pixels the image slice below covers.
                in_window = ((mapped_genes['Y'] >= y_st) & (mapped_genes['Y'] < y_ed)
                             & (mapped_genes['X'] >= x_st) & (mapped_genes['X'] < x_ed))
                spots = mapped_genes.loc[in_window].copy()
                # Shift to crop-local coordinates.
                spots['X'] = spots['X'] - x_st
                spots['Y'] = spots['Y'] - y_st
                spots.to_csv(spots_path, index=False)

            im_cut = im[y_st:y_ed, x_st:x_ed]
            imsave(check_dir / f'{cell_index}_{celltype}_{index}_{chn}_crop.tif', im_cut, check_contrast=False)

In [84]:
# Dark figure theme: black backgrounds, white axes/ticks/text, gray grid.
plt.rcParams.update({
    'axes.facecolor': 'black',
    'figure.facecolor': 'black',
    'axes.edgecolor': 'white',
    'axes.labelcolor': 'white',
    'xtick.color': 'white',
    'ytick.color': 'white',
    'grid.color': 'gray',
    'text.color': 'white',
})

def extract_center_and_diagonal(channel, center_size=11, position='sequence'):
    """Return one diagonal of the central window of a 2-D image.

    The window is ``center_size`` x ``center_size`` pixels and is centred
    independently along rows and columns, so non-square input is handled.

    Parameters
    ----------
    channel : array-like
        Image with at least two dimensions (rows, columns).
    center_size : int, default 11
        Side length of the central window.
    position : {'sequence', 'nonsequence'}, default 'sequence'
        'sequence' returns the main diagonal (top-left to bottom-right);
        'nonsequence' returns the diagonal of the left-right flipped window
        (top-right to bottom-left).

    Returns
    -------
    numpy.ndarray
        The requested diagonal of the central window.

    Raises
    ------
    ValueError
        If ``position`` is not one of the two supported values.
    """
    if position not in ('sequence', 'nonsequence'):
        raise ValueError(
            f"position must be 'sequence' or 'nonsequence', got {position!r}")

    channel = np.asarray(channel)
    # Separate offsets per axis; using the row offset for columns would shift
    # the window off-centre whenever the image is not square.
    row_start = (channel.shape[0] - center_size) // 2
    col_start = (channel.shape[1] - center_size) // 2
    center_region = channel[row_start:row_start + center_size,
                            col_start:col_start + center_size]

    if position == 'nonsequence':
        center_region = np.fliplr(center_region)
    return np.diagonal(center_region)

def one_cell(red, yellow, blue, green, dapi, spots, 
             spot_type_colors, axes, index, center_size=11):
    """Draw one cell crop with its spots, once per colour group.

    The four signal channels are merged into an RGB composite, DAPI is added
    as a half-intensity gray layer, and for every entry of
    ``spot_type_colors`` three axes are filled: the full crop with spots, the
    central ``center_size`` x ``center_size`` zoom with spots, and a legend.

    Parameters
    ----------
    red, yellow, blue, green, dapi : array-like
        2-D crops of identical shape, one per channel.
    spots : pandas.DataFrame
        Spots in crop-local pixel coordinates; columns 'X', 'Y' and 'Gene' are read.
    spot_type_colors : sequence of dict
        One ``{gene: colour}`` mapping per colour group.
    axes : indexable of matplotlib axes
        Must provide ``3 * len(spot_type_colors)`` axes; group ``i`` uses
        ``axes[3*i]``, ``axes[3*i + 1]`` and ``axes[3*i + 2]``.
    index : object
        Used as the title of the full-crop panels.
    center_size : int, default 11
        Side length of the zoomed central window.

    Returns
    -------
    None
        The function only draws on the supplied axes.
    """
    def _centered(channel, gain=1.0):
        # Subtract the crop mean (result is floating point, so unsigned input
        # cannot wrap around), apply the gain, and store as float32.
        return ((channel - np.mean(channel)) * gain).astype(np.float32)

    red = _centered(red)
    yellow = _centered(yellow)
    blue = _centered(blue)
    green = _centered(green, gain=2.5)

    # Yellow has no display primary of its own: half of it goes to red, half to green.
    red_effective = np.clip(red + 0.5 * yellow, 0, 65535)
    green_effective = np.clip(green + 0.5 * yellow, 0, 65535)
    blue_effective = np.clip(blue, 0, 65535)
    # 3000 is the intensity mapped to full brightness in the composite.
    rgb_image = np.stack((red_effective, green_effective, blue_effective), axis=-1) / 3000
    rgb_image = np.clip(rgb_image, 0, 1)

    # Min-max normalise DAPI. A flat crop has zero range; show it as black
    # instead of dividing by zero and filling the figure with NaN.
    dapi = np.asarray(dapi, dtype=np.float64)
    dapi_min = dapi.min()
    dapi_range = dapi.max() - dapi_min
    if dapi_range > 0:
        dapi_normalized = (dapi - dapi_min) / dapi_range
    else:
        dapi_normalized = np.zeros_like(dapi)
    dapi_overlay = np.stack((dapi_normalized,) * 3, axis=-1) * 0.5
    rgb_image = np.clip(rgb_image + dapi_overlay, 0, 1)

    # Central window, centred per axis so non-square crops stay correct.
    row_start = (rgb_image.shape[0] - center_size) // 2
    col_start = (rgb_image.shape[1] - center_size) // 2
    row_end = row_start + center_size
    col_end = col_start + center_size
    center_region = rgb_image[row_start:row_end, col_start:col_end]

    # Spots inside the central window: half-open [start, end), matching the
    # pixels that the slice above selects.
    in_center = ((spots['X'] >= col_start) & (spots['X'] < col_end)
                 & (spots['Y'] >= row_start) & (spots['Y'] < row_end))
    center_spots = spots[in_center]

    marker_style = dict(markersize=6, markerfacecolor='none', markeredgewidth=3)

    def spots_on_img(ax_im, ax_imsub, ax_legend, spot_type_color):
        # Full crop on the first axis, central zoom on the second.
        ax_im.imshow(rgb_image)
        ax_im.set_title(index)
        ax_imsub.imshow(center_region)
        ax_imsub.set_title(f'Center {center_size}x{center_size} Region')
        for gene, color in spot_type_color.items():
            gene_spots = spots[spots['Gene'] == gene]
            ax_im.plot(gene_spots['X'], gene_spots['Y'], 'x',
                       color=color, label=gene, **marker_style)
            gene_center = center_spots[center_spots['Gene'] == gene]
            # Shift into the coordinate frame of the zoomed window.
            ax_imsub.plot(gene_center['X'] - col_start, gene_center['Y'] - row_start, 'x',
                          color=color, label=gene, **marker_style)
        # The third axis carries only the legend, built from the zoom panel's entries.
        handles, labels = ax_imsub.get_legend_handles_labels()
        ax_legend.legend(handles, labels, loc='center left', bbox_to_anchor=(-0.5, 0.5))
        ax_legend.axis('off')

    for i, spot_type_color in enumerate(spot_type_colors):
        spots_on_img(ax_im=axes[i*3], ax_imsub=axes[i*3+1], ax_legend=axes[i*3+2],
                     spot_type_color=spot_type_color)
    

In [85]:
# Gene groups and a hex colour per gene. The key order of each dictionary is
# the order in which the genes are later paired with palette colours.
immune_marker = {
    'CD3D': '#1f77b4',
    'CD4': '#ff7f0e',
    'CD8A': '#2ca02c',
    'FOXP3': '#d62728',
    'PDCD1': '#9467bd',
    'CTLA4': '#8c564b',
    'CXCL13': '#e377c2',
    'CSF3R': '#7f7f7f',
    'GZMB': '#bcbd22',
    'GZMA': '#17becf',
    'GZMK': '#393b79',
    'ZNF683': '#5254a3',
    'NCAM1': '#6b6ecf',
    'CD79A': '#8c6d31',
}

non_immune = {
    'KRT7': '#9c9ede',
    'KRT19': '#637939',
    'EPCAM': '#8ca252',
    'PECAM1': '#b5cf6b',
    'ACTA2': '#cedb9c',
}

TCR = {
    'clonotype1_TRB': '#bd9e39',
    'clonotype2_TRB': '#e7ba52',
    'clonotype3_TRB': '#e7cb94',
    'clonotype4_TRB': '#843c39',
    'clonotype5_TRB': '#ad494a',
    'clonotype1_TRA': '#d6616b',
    'clonotype2_TRA-1': '#e7969c',
    'clonotype2_TRA-2': '#7b4173',
    'clonotype3_TRA': '#a55194',
    'clonotype4_TRA': '#ce6dbd',
    'clonotype5_TRA': '#de9ed6',
}

In [86]:
# seaborn is already imported as `sns` in the import cell at the top of the notebook.

# One shared 20-colour palette.
tab20_colors = sns.color_palette("tab20", 20)

# Pair the genes of each group, in dictionary order, with the palette colours.
# Every group starts again at the first palette colour, so colours repeat
# across groups. Only the keys of the hex dictionaries are used here; their
# hex values are not carried over into these mappings.
immune_marker_colors = dict(zip(immune_marker, tab20_colors))
non_immune_colors = dict(zip(non_immune, tab20_colors))
TCR_colors = dict(zip(TCR, tab20_colors))


In [87]:
# Preview a 10-colour tab20 palette as the cell output.
sns.color_palette(palette='tab20', n_colors=10)

In [88]:
color_dicts = [immune_marker_colors, non_immune_colors, TCR_colors]
# Three panels per colour group: full crop, central zoom, legend.
width_ratios = [1, 1, 0.8] * len(color_dicts)
# Display colour -> channel name used in the crop file names.
channel_of = {'red': 'cy5', 'yellow': 'TxRed', 'blue': 'FAM', 'green': 'cy3', 'dapi': 'DAPI'}

for cell_index, cell_type in tqdm(enumerate(cell_type_list, start=1), total=len(cell_type_list), desc='Drawing_spots'):
    # Crop files are named '<cell_index>_<cell_type>_<index>_<channel>_crop.tif'.
    # Matching the full prefix ignores files left by runs with other cell types.
    prefix = f'{cell_index}_{cell_type}_'
    index_set = {name.split('_')[-3] for name in os.listdir(check_dir)
                 if name.startswith(prefix) and name.endswith('.tif')}
    # Shorter strings first, then lexicographic: numeric order for plain
    # integer ids, and a stable row order in every case.
    index_list = sorted(index_set, key=lambda s: (len(s), s))
    if not index_list:
        # Nothing was cropped for this cell type; plt.subplots cannot build a 0-row grid.
        continue

    points = {}
    for index in index_list:
        point = {key: imread(check_dir / f'{prefix}{index}_{chn}_crop.tif')
                 for key, chn in channel_of.items()}
        point['spots'] = pd.read_csv(check_dir / f'{prefix}{index}_crop.csv')
        points[index] = point

    # squeeze=False keeps `axes` two-dimensional even when there is a single row.
    fig, axes = plt.subplots(ncols=len(width_ratios), nrows=len(points),
                             figsize=(4*sum(width_ratios), 4*len(points)),
                             gridspec_kw={'width_ratios': width_ratios}, squeeze=False)
    for row, (index, point) in enumerate(points.items()):
        one_cell(point['red'], point['yellow'], point['blue'], point['green'], point['dapi'],
                 point['spots'], spot_type_colors=color_dicts, axes=axes[row], index=index,
                 center_size=100)
    fig.tight_layout()

    fig.savefig(check_dir / f'.overall_{cell_index}.png', bbox_inches='tight', dpi=300)
    # Close each figure so memory does not grow across cell types.
    plt.close(fig)

Drawing_spots: 100%|██████████| 5/5 [03:07<00:00, 37.51s/it]


# legacy

In [18]:
# merged_df = pd.merge(spot, intensity, on=['Y', 'X'], how='inner')
# merged_df
# spot['intensity'] = intensity['sum_G']
# random = merged_df[merged_df['Gene'] == 'PRISM_1'].sample(n=1)
# #random = random[random['sum_G']>4000].sample(n=1)
# random
# Gene_list = ['PRISM_48','PRISM_54']
# x_start = 23167
# y_start = 31839
# x_width = 30
# y_width = 30
# channels = {'cy5','cy3','FAM','TxRed'}

# for chn in channels:
#     for i in range(1,2):
#         im = imread(stc_dir/f'cyc_{i}_{chn}.tif')
#         im = im[y_start-15:y_start+15,x_start-15:x_start+15]
#         imsave(check_dir/f'cyc_{i}_{chn}_crop.tif',im)


# seq_old = pd.read_csv(read_old_dir/'ref_checked.csv')
# seq_new = pd.read_csv(read_new_dir/'ref_checked.csv')
# seq_old[(seq_old['X']==52600)&(seq_old['Y']==25565)]
# seq_new[(seq_new['X']==52600)&(seq_new['Y']==25565)]
# Actb_new


In [19]:
# import scanpy as sc
# from scipy.signal import argrelextrema
# from scipy.signal import find_peaks
# import matplotlib.gridspec as gridspec
# import pickle

# # Preprocessing
# def adata_filter(adata, min_genes, min_counts, max_counts, min_cells):
#     sc.pp.filter_cells(adata, min_genes=min_genes)
#     sc.pp.filter_cells(adata, min_counts=min_counts)
#     sc.pp.filter_cells(adata, max_counts=max_counts)
#     sc.pp.filter_genes(adata, min_cells=min_cells)
#     return adata


# def QC_plot(adata, hue, min_counts='nan', max_counts='nan', min_genes='nan', min_cells='nan'):
#     g = sns.JointGrid(
#         data=adata.obs,
#         #x="total_counts",
#         #y="n_genes_by_counts",
#         height=5,
#         ratio=2,
#         hue=hue,
#     )
    
#     g.plot_joint(sns.scatterplot, s=40, alpha=0.3)
#     g.plot_marginals(sns.kdeplot)
#     g.set_axis_labels("total_counts", "n_genes_by_counts", fontsize=16)
#     g.fig.set_figwidth(6)
#     g.fig.set_figheight(6)
#     g.fig.suptitle("QC_by_{}, cell_num={}, gene_num={}\n\
#                    min_counts={}, max_counts={}, min_genes={}, min_cells={}\
#                    \n\n\n\n\n".format(hue,len(adata),len(adata.var.index),min_counts, max_counts, min_genes, min_cells))
#     plt.show()


# def general_preprocess(adata, min_genes=2, min_counts=5, max_counts=200, min_cells=3, auto_filter=False, hue='dataset'):
#     # Calculate QC metrics
#     sc.pp.calculate_qc_metrics(adata, percent_top=None, inplace=True)

#     fig = plt.figure(figsize=(12, 4))  # You can adjust the overall figure size here
#     gs = gridspec.GridSpec(1, 3)
#     # Plot top 20 most expressed genes
#     ax1 = fig.add_subplot(gs[0, 0])
#     sc.pl.highest_expr_genes(adata, n_top=10, ax=ax1, show=False)
#     # distribution of cell counts
#     ax2 = fig.add_subplot(gs[0, 1:3])
#     counts = adata.obs.total_counts
#     sns.histplot(counts, stat='count', ax=ax2,
#                 bins=150, edgecolor='white', linewidth=0.5, alpha=1, 
#                 kde=True, line_kws=dict(color='black', alpha=0.7, linewidth=1.5, label='KDE'), kde_kws={'bw_adjust': 1},
#                 )
#     y = ax2.get_lines()[0].get_ydata()
#     maxima = [float(_/len(y)*(max(counts)-min(counts))+min(counts)) for _ in argrelextrema(-np.array(y), np.less)[0]]
#     print(f'maxima: {maxima}')
#     plt.tight_layout()
#     plt.show()
#     plt.close(fig=fig)

    
    
#     # plot origin and filtered in a combined figure
#     fig = plt.figure(figsize=(12, 6))
#     gs = gridspec.GridSpec(3, 6)
#     categories = adata.obs[hue].unique()

#     ax_1_scatter = fig.add_subplot(gs[1:3, 0:2])
#     sns.scatterplot(x=adata.obs.total_counts, y=adata.obs.n_genes_by_counts,
#                     hue=adata.obs[hue], ax=ax_1_scatter, )
    
#     ax_1_count = fig.add_subplot(gs[0:1, 0:2])
#     for category in categories:
#         subset = adata[adata.obs[hue] == category]
#         sns.kdeplot(subset.obs.total_counts, ax=ax_1_count)
#     ax_1_count.xaxis.set_visible(False)
#     ax_1_count.yaxis.set_visible(False)
#     ax_1_count.grid(False)

#     ax_1_gene = fig.add_subplot(gs[1:3, 2:3])
#     for category in categories:
#         subset = adata[adata.obs[hue] == category]
#         sns.kdeplot(y=subset.obs.n_genes_by_counts, ax=ax_1_gene)
#     ax_1_gene.xaxis.set_visible(False)
#     ax_1_gene.yaxis.set_visible(False)
#     ax_1_gene.grid(False)

#     ax_1_count.set_title("QC_by_{}, cell_num={}, gene_num={}\n\
#                          min_counts={}, max_counts={}, \n\
#                          min_genes={}, min_cells={}\
#                          ".format(hue, len(adata), len(adata.var.index), 'nan', 'nan', 'nan', 'nan'))

#     origin_cell_num = len(adata)
#     min_counts = int(maxima[0]) if auto_filter else min_counts
#     max_counts = int(np.percentile(counts, 99.9)) if auto_filter else max_counts
#     adata = adata_filter(adata, min_genes, min_counts, max_counts, min_cells)
#     filtered_cell_num = len(adata)

#     ax_2_scatter = fig.add_subplot(gs[1:3, 3:5])
#     sns.scatterplot(x=adata.obs.total_counts, y=adata.obs.n_genes_by_counts,
#                     hue=adata.obs[hue], ax=ax_2_scatter, )
    
#     ax_2_count = fig.add_subplot(gs[0:1, 3:5])
#     categories = adata.obs[hue].unique()
#     for category in categories:
#         subset = adata[adata.obs[hue] == category]
#         sns.kdeplot(subset.obs.total_counts, ax=ax_2_count)
#     ax_2_count.xaxis.set_visible(False)
#     ax_2_count.yaxis.set_visible(False)
#     ax_2_count.grid(False)

#     ax_2_gene = fig.add_subplot(gs[1:3, 5:6])
#     for category in categories:
#         subset = adata[adata.obs[hue] == category]
#         sns.kdeplot(y=subset.obs.n_genes_by_counts, ax=ax_2_gene)
#     ax_2_gene.xaxis.set_visible(False)
#     ax_2_gene.yaxis.set_visible(False)
#     ax_2_gene.grid(False)

#     ax_2_count.set_title("QC_by_{}, cell_num={}, gene_num={}\n\
#                          min_counts={}, max_counts={}, \n\
#                          min_genes={}, min_cells={}\
#                          ".format(hue, len(adata), len(adata.var.index), min_counts, max_counts, min_genes, min_cells))
    
#     plt.tight_layout()
    
#     plt.show()
#     plt.close(fig=fig)

#     # plot origin

#     # QC_plot(adata, hue='dataset')

#     # # plot filtered
#     # min_counts = int(maxima[0]) if auto_filter else min_counts
#     # max_counts = int(np.percentile(counts, 99.9)) if auto_filter else max_counts
#     # adata = adata_filter(adata, min_genes, min_counts, max_counts, min_cells)
#     # filtered_cell_num = len(adata)
#     # QC_plot(adata, hue='dataset', min_genes=min_genes, min_counts=min_counts, max_counts=max_counts, min_cells=min_cells)
#     return adata, origin_cell_num, filtered_cell_num


# # g = sns.JointGrid(
# #     data=adata.obs,
# #     x="total_counts",
# #     y="n_genes_by_counts",
# #     height=5,
# #     ratio=2,
# #     hue=hue,
# # )

# # g.plot_joint(sns.scatterplot, s=40, alpha=0.3)
# # g.plot_marginals(sns.kdeplot)
# # g.set_axis_labels("total_counts", "n_genes_by_counts", fontsize=16)
# # g.fig.set_figwidth(6)
# # g.fig.set_figheight(6)
# # g.fig.suptitle("QC_by_{}, cell_num={}, gene_num={}\n\
# #                 min_counts={}, max_counts={}, min_genes={}, min_cells={}\
# #                 \n\n\n\n\n".format(hue,len(adata),len(adata.var.index),min_counts, max_counts, min_genes, min_cells))
# # plt.show()


# def preprocess_of_UMAP(adata):
#     # Normalization scaling
#     sc.pp.normalize_total(adata)
#     #sc.pp.log1p(adata)
#     # Scale data to unit variance and zero mean
#     #sc.pp.regress_out(adata, ["total_counts"])
#     sc.pp.scale(adata)
#     return adata


# def save_pos_on_UMAP(adata, out_dir):
#     try:
#         adata_coor = pd.DataFrame(
#             adata.obsm["X_umap"], columns=["Coor_X", "Coor_Y"], index=adata.obs.index
#         )
#         df = pd.concat(
#             [
#                 adata_coor["Coor_X"],
#                 adata_coor["Coor_Y"],
#                 pd.DataFrame(adata.obs.index),
#                 adata.obs.leiden,
#             ],
#             axis=1,
#         )
#         df.to_csv(out_dir)
#     except KeyError:
#         print('X_umap not found, please perform umap first.')


# def save_cell_cluster(
#     adata,
#     out_path,
#     st_point,
#     cell_num,
#     name="leiden",
# ):
#     raw_clu = dict(adata.obs[name])
#     cluster = dict()
#     for cell_num in raw_clu.keys():
#         cluster[cell_num] = -1

#     for cell in raw_clu.keys():
#         cluster[int(cell) - st_point] = int(raw_clu[cell])

#     with open(out_path, "wb") as handle:
#         pickle.dump(cluster, handle)


# def UMAP_genes_plot(adata, FOI='', size=0.1, save=False, out_path='./', datatype='direct', dataset=[],gene_list=['Slc17a7','Gad1','Gad2','Snap25']):
#     n_pcs = len(adata.uns['pca']['variance'])
#     n_neighbors = adata.uns['neighbors']['params']['n_neighbors']
#     resolution = adata.uns['leiden']['params']['resolution']
#     # Plot Gene distribution
#     ncols = int(-(-len(gene_list)**(1/2)//1))
#     nrows = -(-len(gene_list)//ncols)
#     fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4, nrows*4))
#     for pos, gene_name in enumerate(gene_list):
#         sc.pl.umap(
#             adata[adata.obs.dataset.isin(dataset)],
#             size=size,
#             color=gene_name,
#             legend_fontweight=100,
#             legend_fontsize=20,
#             ax=ax[pos // ncols][pos % ncols],
#             show=False,
#             vmax=5,
#             vmin=0,
#             cmap='inferno_r'
#         )
#         ax[pos // ncols][pos % ncols].set_xticklabels("")
#         ax[pos // ncols][pos % ncols].set_yticklabels("")

#     fig.suptitle(
#         "{}\nexp:{}\nUMAP:{}\n".format(
#         f"{FOI}_{datatype}_{dataset}_UMAP",
#         f"cell_num={len(adata)}",
#         f"n_neighbors={n_neighbors}, n_pcs={n_pcs}, resolution={resolution}"),
#         fontsize=20,
#     )
#     plt.tight_layout()

#     if save:
#         plt.savefig(f"{out_path}/{FOI}_{datatype}_UMAP_genes.png")
#         plt.close()
#     else:
#         plt.show()


# def UMAP_leiden_plot(adata, FOI='', color='leiden', save=False, out_path='./',dpi=300, datatype='direct', DOI=['PRISM3D'], legend_loc='on data', palette=False, size=1):
#     n_pcs = len(adata.uns['pca']['variance'])
#     n_neighbors = adata.uns['neighbors']['params']['n_neighbors']
#     resolution = adata.uns['leiden']['params']['resolution']
#     # Plot Cluster
#     fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
    
#     if palette:
#         sc.pl.umap(
#             adata[adata.obs.dataset.isin(DOI)],
#             size=size, color=color, palette=palette, 
#             legend_loc=legend_loc, legend_fontsize=7,
#             ax=ax[0], show=False,
#         )

#         sc.pl.umap(
#             adata,
#             size=size, color="dataset",
#             legend_fontweight=100, legend_fontsize=20,
#             ax=ax[1], show=False,
#         )
#     else:
#         sc.pl.umap(
#             adata[adata.obs.dataset.isin(DOI)],
#             size=size, color=color,
#             legend_loc=legend_loc, legend_fontsize=7,
#             ax=ax[0], show=False,
#         )

#         sc.pl.umap(
#             adata,
#             size=size, color="dataset",
#             legend_fontweight=100, legend_fontsize=20,
#             ax=ax[1], show=False,
#         )
#     fig.suptitle(
#         "{}\nexp:{}\nUMAP:{}\n".format(
#         f"{FOI}_{datatype}_cluster",
#         f"cell_num={len(adata)}",
#         f"n_neighbors={n_neighbors}, n_pcs={n_pcs}, resolution={resolution}"),
#         fontsize=20,
#     )
#     plt.tight_layout()
#     if save:
#         if out_path.endswith('.png'):    
#             plt.savefig(f"{out_path}", bbox_inches = 'tight', dpi=dpi)
#         else:
#             plt.savefig(f"{out_path}", bbox_inches = 'tight')
#         plt.close()
#     else:
#         plt.show()


# def leiden_QC_plot(adata, color='leiden'):
#     # cluster QC
#     g = sns.JointGrid(
#         data=adata.obs,
#         x="total_counts",
#         y="n_genes_by_counts",
#         height=5,
#         ratio=2,
#         hue=color,
#     )
#     g.plot_joint(sns.scatterplot, s=40, alpha=0.3)
#     g.plot_marginals(sns.kdeplot)
#     g.set_axis_labels("total_counts", "n_genes_by_counts", fontsize=8)
#     g.fig.set_figwidth(3)
#     g.fig.set_figheight(3)
#     plt.show()

In [20]:
# def threshold_in_cluster(adata, marker_gene=[], thre_gene=['AFP','GPC3','ACTA2','PECAM1'], type_name=[], cluster_dict={}):
# # for cluster_gene in cluster_to_filter:
#     thre_min = [True] * len(marker_gene) + [False] * len(thre_gene)
#     gene_list = marker_gene + thre_gene
#     minima_dict = {}
#     for _ in gene_list:
#         minima_dict[_] = ''

#     cluster_list_temp=[]
#     for _ in cluster_dict.keys():
#         for name in type_name:
#             if name in _:
#                 cluster_list_temp += [str(_) for _ in cluster_dict[_]]

#     cluster = adata[adata.obs.leiden.isin(cluster_list_temp)]

#     fig, ax = plt.subplots(nrows=1,ncols=len(thre_min),figsize=(24, 4))
#     for i, gene in enumerate(gene_list):
#         a = [float(_) for _ in cluster[:, gene].X]
#         sns.histplot(a, bins=20, stat='density', alpha= 1, kde=True,
#                     edgecolor='white', linewidth=0.5,
#                     log=True, 
#                     ax=ax[i],
#                     line_kws=dict(color='black', alpha=0.7, linewidth=1.5, label='KDE'))
#         ax[i].get_lines()[0].set_color('red') # edit line color due to bug in sns v 0.11.0
#         ax[i].set_xlabel(gene)

#         y = ax[i].get_lines()[0].get_ydata()
#         minima_dict[gene] = [float(_/len(y)*(max(a)-min(a))+min(a)) for _ in argrelextrema(np.array(y), np.less)[0]]
#         # print(f'{gene}_minima: {minima_dict[gene]}')
#         fig.subplots_adjust(hspace=0.4)
#         fig.subplots_adjust(wspace=0.4)
#         fig.suptitle(f'distribution of cluster, marker gene={marker_gene}')
#     plt.show()

#     cluster.obs['tmp_leiden'] = ['-1']*len(cluster)
#     for _, gene in enumerate(gene_list):
#         minima = minima_dict[gene]
#         while True:
#             if len(minima) == 0:
#                 minima = [0]
#                 break
#             if minima[0] > 1 and gene != 'CPA3':
#                 minima[0] = 0
#                 break
#             if minima[0] < -1 and gene != 'CPA3':
#                 minima.pop(0)
#                 continue
#             break
        
#         print(f'{gene}_thre: {minima[0]}')

#         if thre_min[_]:
#             tmp = cluster[cluster[:, gene].X > minima[0]]
#             cluster.obs['tmp_leiden'][tmp.obs.index] = ['1']*len(tmp)
#         else:
#             tmp = cluster[cluster[:, gene].X > minima[0]]
#             cluster.obs['tmp_leiden'][tmp.obs.index] = ['-1']*len(tmp)
    
#     tmp = cluster[cluster.obs['tmp_leiden']=='-1']
#     adata.obs['tmp_leiden'][tmp.obs.index] = ['-2']*len(tmp)

#     cell_to_plot = len(cluster[cluster.obs['tmp_leiden']=='1'])
#     print(f'marker_gene={marker_gene}, {cell_to_plot} cells of {len(cluster)} cells left\n')
#     return adata


# def collect_liver(combine_adata_st, tissue_obs='tissue', in_out_leiden='tmp_leiden'):
#     other_cluster = combine_adata_st[combine_adata_st.obs[in_out_leiden]=='-2']
#     liver = other_cluster[other_cluster.obs[tissue_obs] == "liver"]
#     combine_adata_st.obs[in_out_leiden] = list(combine_adata_st.obs[in_out_leiden])
#     combine_adata_st.obs[in_out_leiden][liver.obs.index] = ["-1"] * len(liver)
#     return combine_adata_st

In [21]:
# matrix_1 = pd.read_csv(read_dir/'matrix_Hdc+_3.csv')
# matrix_1 = matrix_1.transpose()
# matrix_1.columns = matrix_1.iloc[0]
# matrix_1 = matrix_1[1:]
# matrix_1

In [22]:
# # load expression matrix
# adata = sc.AnnData(matrix_1)
# adata.var.index = adata.var.index.str.upper()
# adata.obs['dataset'] = ["PRISM3D"] * len(adata)
# adata.raw = adata

In [23]:
# # preprocess of UMAP
# adata = preprocess_of_UMAP(adata)

# # compute pca
# sc.tl.pca(adata)
# sc.pl.pca_variance_ratio(adata, log=False)

In [24]:
# # select the num of pc
# n_pcs=20
# sc.tl.pca(adata, n_comps=n_pcs)  

In [25]:
# # Run UMAP
# sc.pp.neighbors(adata, n_neighbors=50, n_pcs=n_pcs)
# sc.tl.umap(adata)

In [26]:
# # Run Leiden cluster
# sc.tl.leiden(adata, resolution=1)

In [27]:
# a = [len(adata[adata.obs.leiden == _]) for _ in adata.obs.leiden.unique()]
# fig, ax = plt.subplots(figsize=(7,3))
# sns.histplot(a, bins=30, stat='count', alpha=1, kde=True,
#             edgecolor='white', linewidth=0.5,
#             # log=True, 
#             ax=ax,
#             line_kws=dict(color='black', alpha=0.7, linewidth=1.5, label='KDE'),
#             # binrange=[0,100]
#             )
# plt.show()

# adata_thre = adata[adata.obs.leiden.isin([_ for _ in adata.obs.leiden.unique() if len(adata[adata.obs.leiden == _]) > 100])]

In [28]:
# UMAP_leiden_plot(adata, FOI=os.path.split(src_dir)[-1], color='leiden', save=False, out_path='', datatype='direct',size=100)
# QC_plot(adata, hue='leiden')

In [29]:
# UMAP_genes_plot(adata, FOI=os.path.split(src_dir)[-1], save=False, datatype='direct', dataset=["PRISM3D"], size=40, gene_list = ['SNAP25','SLC17A7','GAD1','GAD2','SLC1A3','AQP4','CX3CR1','PLP1','HDC','DRD1','SLC18A2','FOXJ1','SST','VIP','PVALB','LAMP5','NPY'])


In [30]:
# adata.var_names

In [31]:
# adata.var_names[adata.var_names=='SNAP25']