# Reads assignment for a single FOV

In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
from skimage.filters import threshold_otsu, threshold_yen, gaussian
from skimage.measure import regionprops
from skimage.morphology import binary_dilation, disk
from skimage.segmentation import watershed
import scipy.ndimage as ndi
from anndata import AnnData
from tifffile import imread, imwrite
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# IO paths. `base_path` points at one specific dataset; edit it per experiment.
base_path = '/home/unix/jiahao/wanglab/Data/Analyzed/2023-10-01-Jiahao-Test/mAD_64/02_pp/'
out_path = os.path.join(base_path, 'expr')
# makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates any
# missing parent directories (plain mkdir would raise FileNotFoundError).
os.makedirs(out_path, exist_ok=True)

image_path = os.path.join(base_path, 'images/merged/raw_max')
signal_path = os.path.join(base_path, 'signal')

In [None]:
# Load the "good spots" table for a single field of view.
current_fov_id = 'tile_1'
reads_csv = os.path.join(signal_path, f'{current_fov_id}_goodSpots.csv')
current_reads_df = pd.read_csv(reads_csv)
# Shift every coordinate down by one.
# NOTE(review): presumably converts 1-based coordinates to 0-based array indices — confirm.
coord_cols = ['x', 'y', 'z']
current_reads_df[coord_cols] = current_reads_df[coord_cols] - 1
current_reads_df

In [None]:
# Overlay image that the cell segmentation below is computed on.
overlay = imread(os.path.join(image_path, 'overlay.tif'))
overlay.shape

# DAPI (nucleus) label image.
dapi = imread(os.path.join(image_path, 'dapi_label.tif'))
dapi.shape

# Centroid (row, col) and pixel area of every labeled nucleus.
nucleus_stats = [(prop.centroid, prop.area) for prop in tqdm(regionprops(dapi))]
centroids = np.array([centroid for centroid, _ in nucleus_stats])
areas = np.array([area for _, area in nucleus_stats])
sns.displot(areas, bins=50)

In [None]:
# Candidate nucleus-area bounds (pixels), drawn over the area histogram.
# NOTE(review): these bounds are only visualized; `centroids` is never filtered
# by them before segmentation — confirm whether filtering was intended.
lower_bd = 1000
upper_bd = 15000

fig, ax = plt.subplots()
sns.histplot(areas, ax=ax)
for bound in (lower_bd, upper_bd):
    ax.axvline(bound, c='r')

plt.show()

In [None]:
%%time
# Segmentation: threshold the blurred overlay into a foreground mask, seed one
# watershed marker per DAPI nucleus centroid, and split the mask into cells.

print("Gaussian & Thresholding")
overlay_blurred = gaussian(overlay, 5)
threshold = threshold_otsu(overlay_blurred)
overlay_bw = overlay_blurred > threshold
# Dilate the foreground mask with a radius-10 disk.
overlay_bw = binary_dilation(overlay_bw, footprint=disk(10))

print("Assigning markers")
# regionprops centroids are already 0-based (row, col). Round to the nearest
# pixel and keep only seeds inside the mask. (Indexing with [x-1, y-1] shifted
# every seed one pixel up-left and wrapped row/col 0 to the opposite edge.)
marker_rc = np.rint(np.asarray(centroids, dtype=float).reshape(-1, 2)).astype(int)
rows, cols = marker_rc[:, 0], marker_rc[:, 1]
in_bounds = ((rows >= 0) & (rows < overlay_bw.shape[0])
             & (cols >= 0) & (cols < overlay_bw.shape[1]))
markers = np.zeros(overlay_bw.shape, dtype=bool)
markers[rows[in_bounds], cols[in_bounds]] = True
# Give each connected group of seed pixels its own integer label.
markers = ndi.label(markers)[0]

print("Watershed")
labels = watershed(overlay_bw, markers, mask=overlay_bw, watershed_line=True)
n_cells = len(np.unique(labels[labels > 0]))
print(f"Labeled {n_cells} cells")

plt.figure(figsize=(10,20))
plt.imshow(labels)

print(f"Saving files to {image_path}")
# uint16 holds at most 65535 labels; widen instead of silently wrapping.
label_dtype = np.uint16 if labels.max() <= np.iinfo(np.uint16).max else np.uint32
imwrite(os.path.join(image_path,  "labeled_cells.tif"), labels.astype(label_dtype))

In [None]:
# Figure size in inches, (width, height): 5 inches per 1000 pixels, floored and
# clamped to at least 1 inch so small images don't produce a zero-size figure
# (matplotlib rejects non-positive figure sizes).
figsize = tuple(max(1.0, np.floor(n_px / 1000 * 5)) for n_px in (dapi.shape[1], dapi.shape[0]))
figsize

In [None]:
# Label every segmented cell on the overlay with its enumeration index in
# regionprops(labels), then save the figure as PNG.
t_size = 10
plt.figure(figsize=figsize)
plt.imshow(overlay)
for cell_idx, cell in enumerate(regionprops(labels)):
    row, col = cell.centroid
    plt.plot(col, row, '.', color='red', markersize=4)
    plt.text(col, row, str(cell_idx), fontsize=t_size, color='red')

plt.axis('off')
plt.savefig(os.path.join(image_path, "cell_nums.png"))
plt.clf()
plt.close()

In [None]:
# Draw every read (red) over the binary foreground of the segmentation and save.
fig, ax = plt.subplots(figsize=figsize)
ax.imshow(labels > 0, cmap='gray')
ax.plot(current_reads_df['x'], current_reads_df['y'], '.', color='red', markersize=1)
ax.axis('off')
points_seg_path = os.path.join(image_path, "points_seg.png")
print("Saving points_seg.png")
fig.savefig(points_seg_path)
plt.clf()
plt.close()

In [None]:
# Gene panel table (headerless CSV), columns named Gene and Barcode.
genes_csv = os.path.join(base_path, "genes.csv")
genes_df = pd.read_csv(genes_csv, header=None)
genes_df.columns = pd.Index(['Gene', 'Barcode'])
genes_df

In [None]:
%%time
# Reads assignment to cell (new).
# Look up the segmentation label under every read, build a cells x genes count
# matrix plus per-cell metadata, and write them with a one-line log to
# <out_path>/<current_fov_id>/.
expr_out_path = os.path.join(out_path, current_fov_id)
os.makedirs(expr_out_path, exist_ok=True)

points = current_reads_df.loc[:, ["x", "y"]].values
bases = current_reads_df['gene'].values

# Nearest pixel for each read: x indexes image columns, y indexes rows.
read_cols = np.rint(points[:, 0]).astype(int)
read_rows = np.rint(points[:, 1]).astype(int)
# Reads outside the label image count as unassigned (label 0) instead of
# raising IndexError (too large) or silently wrapping around (negative).
in_image = ((read_rows >= 0) & (read_rows < labels.shape[0])
            & (read_cols >= 0) & (read_cols < labels.shape[1]))
reads_assignment = np.zeros(len(points), dtype=labels.dtype)
reads_assignment[in_image] = labels[read_rows[in_image], read_cols[in_image]]

# One matrix row per segmented cell, in regionprops order; row i <-> seg_labels[i].
cell_props = regionprops(labels)
total_cells = len(cell_props)
areas = [prop.area for prop in cell_props]
seg_labels = [prop.label for prop in cell_props]
cell_locs = np.array([prop.centroid for prop in cell_props]).reshape(-1, 2).astype(int)

genes = genes_df['Gene'].values
cell_by_gene = np.zeros((total_cells, len(genes)))
# Gene name -> matrix column (a repeated name keeps its last index).
gene_seq_to_index = {gene: idx for idx, gene in enumerate(genes)}

print('Counting reads per cell...')
# Vectorized tally instead of scanning all reads once per cell (O(cells x reads)):
# translate each read's label to a matrix row and its gene to a matrix column,
# then accumulate. Background reads (label 0) and genes not in the panel are skipped.
label_to_row = np.full(int(labels.max()) + 1, -1, dtype=np.int64)
label_to_row[np.asarray(seg_labels, dtype=np.int64)] = np.arange(total_cells)
cell_rows = label_to_row[reads_assignment]
gene_cols = pd.Series(bases).map(gene_seq_to_index).to_numpy()
counted = (cell_rows >= 0) & pd.notna(gene_cols)
np.add.at(cell_by_gene, (cell_rows[counted], gene_cols[counted].astype(int)), 1)

# Per-cell metadata, same row order as cell_by_gene. No cells are filtered here.
current_meta = pd.DataFrame({'sample': current_fov_id, 'area': areas,
                             'x': cell_locs[:, 1], 'y': cell_locs[:, 0], 'seg_label': seg_labels})

# Output. "Assigned" = inside a segmented cell AND the gene is in the panel.
n_total_reads = len(bases)
n_assigned = int(cell_by_gene.sum())
assigned_frac = n_assigned / n_total_reads if n_total_reads else 0.0
msg = "{:.2%} [{} out of {}] reads were assigned to {} cells".format(
    assigned_frac, n_assigned, n_total_reads, total_cells)
print(msg)
with open(os.path.join(expr_out_path, "log.txt"), 'w') as f:
    f.write(msg + "\n")
np.savetxt(os.path.join(expr_out_path, "cell_barcode_count.csv"), cell_by_gene.astype(int), delimiter=',', fmt="%d")
cell_barcode_names = pd.DataFrame({'gene': genes})
cell_barcode_names.to_csv(os.path.join(expr_out_path, "cell_barcode_names.csv"), header=False)
current_meta.to_csv(os.path.join(expr_out_path, "meta.csv"))


In [None]:
current_meta.head(n=5)  # preview the first rows of per-cell metadata

In [None]:
# Distribution of `areas` (per-cell pixel areas when the assignment cell has run).
sns.displot(data=areas)

In [None]:
np.median(np.asarray(areas))  # median of `areas`

In [None]:
# Tag each read with the label it fell in (0 = none) and plot assigned vs.
# unassigned reads over the segmentation foreground.
current_reads_df['assignment'] = reads_assignment
current_reads_df['assignment_bw'] = current_reads_df['assignment'].gt(0)
current_reads_df

fig, ax = plt.subplots(figsize=(40, 40))
ax.imshow(labels > 0, cmap='gray')
sns.scatterplot(
    data=current_reads_df, x='x', y='y', hue='assignment_bw',
    size=.001, marker='.', linewidth=0, rasterized=True, ax=ax,
)
ax.axis('off')
points_seg_path = os.path.join(image_path, "points_assignment_no_outline.pdf")
plt.savefig(points_seg_path)

In [None]:
# Round-1 amplicon image with every detected spot drawn in red.
round1_amplicon = imread(os.path.join(image_path, 'round1_max_uint8.tif'))
round1_amplicon.shape

fig, ax = plt.subplots(figsize=(40, 40))
ax.imshow(round1_amplicon, cmap='gray')
sns.scatterplot(data=current_reads_df, x='x', y='y', color='red',
                size=.0001, marker='.', linewidth=0, rasterized=True, ax=ax)
ax.axis('off')
points_seg_path = os.path.join(image_path, "spot_finding.tif")
fig.savefig(points_seg_path, dpi=200)

In [None]:
# Round-1 amplicon image with detected spots colored by gene.
fig, ax = plt.subplots(figsize=(40, 40))
ax.imshow(round1_amplicon, cmap='gray')
sns.scatterplot(data=current_reads_df, x='x', y='y', hue='gene',
                size=.0001, marker='.', linewidth=0, rasterized=True, ax=ax)
ax.axis('off')
points_seg_path = os.path.join(image_path, "spot_finding_gene.tif")
fig.savefig(points_seg_path, dpi=200)

In [None]:
# Reads colored by gene, drawn on overlay_with_label.tif.
overlay_with_label = imread(os.path.join(image_path, 'overlay_with_label.tif'))

fig, ax = plt.subplots(figsize=(40, 40))
ax.imshow(overlay_with_label)
sns.scatterplot(data=current_reads_df, x='x', y='y', hue='gene',
                size=.0001, marker='.', linewidth=0, rasterized=True, ax=ax)
ax.axis('off')
points_seg_path = os.path.join(image_path, "reads_assignment.tif")
fig.savefig(points_seg_path, dpi=200)

In [None]:
# Reads colored by assigned label, with nucleus centroids marked in red, on a
# black canvas.
current_reads_df['assignment'] = current_reads_df['assignment'].astype('category')
# NOTE(review): 9 colors, but there is one hue level per label in `assignment`;
# seaborn may cycle or reject a too-short palette depending on version — confirm.
current_pl = sns.color_palette(['#c9c9c9', '#1f78b4', '#33a02c', '#e31a1c', '#ff7f00', '#6a3d9a', '#a6cee3', '#b2df8a', '#fb9a99'])

fig, ax = plt.subplots(figsize=(40,40))
# Canvas sized from the label image rather than a hard-coded 3072 x 3072, so
# FOVs of any size render with matching extent.
ax.imshow(np.zeros(labels.shape), cmap='gray')
sns.scatterplot(x='x', y='y', data=current_reads_df, size=.0001, marker='.', hue='assignment', ax=ax, rasterized=True, linewidth=0, palette=current_pl)
ax.plot(centroids[:, 1], centroids[:, 0], '.', color='red', markersize=10)
ax.axis('off')
points_seg_path = os.path.join(image_path, "clustermap.tif")
plt.savefig(points_seg_path, dpi=200)

## Check expression pattern

In [None]:
# Reads whose assignment label is non-zero, i.e. that fell inside a segmented cell.
assigned_index = np.flatnonzero(reads_assignment != 0)
assigned_bases = bases[assigned_index]
assigned_points = points[assigned_index, :]

In [None]:
# For each selected gene, plot its assigned reads in red over the overlay and save
# as "<position>.<gene>_<read count>.png" under <expr_out_path>/figures.
selected_genes = ['Gfap', 'Mbp']
expr_figure_out_path = os.path.join(expr_out_path, 'figures')
if not os.path.exists(expr_figure_out_path):
    os.mkdir(expr_figure_out_path)

for position, gene in enumerate(tqdm(selected_genes), start=1):
    gene_points = assigned_points[np.flatnonzero(assigned_bases == gene), :]
    n_reads = gene_points.shape[0]

    plt.figure(figsize=(10, 10))
    plt.imshow(overlay, cmap='gray')
    plt.plot(gene_points[:, 0], gene_points[:, 1], '.', color='red', markersize=.5)
    plt.axis('off')
    plt.savefig(os.path.join(expr_figure_out_path, f"{position}.{gene}_{n_reads}.png"))
    plt.clf()
    plt.close()