# Reads assignemnt 3D

In [1]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from tifffile import imread, imwrite
from skimage.measure import regionprops
from tqdm.notebook import tqdm
from anndata import AnnData


## Load the tile config

In [2]:
base_path = '/home/unix/jiahao/wanglab/Data/Analyzed/2024-03-12-Mingrui-PFC/'
image_path = os.path.join(base_path, 'images')
signal_path = os.path.join(base_path, 'signal')
output_path = os.path.join(base_path, 'output')
expr_path = os.path.join(base_path, 'expr')
if not os.path.exists(expr_path):
    os.mkdir(expr_path)

# morph_path = os.path.join(image_path, 'morph')
# if not os.path.exists(morph_path):
#     os.mkdir(morph_path)

In [166]:
current_sample = 'sample4'
current_tile_xml = os.path.join(image_path, "fused/3D", current_sample, 'DAPI', 'dataset.xml')

In [167]:
setup_df = pd.read_xml(current_tile_xml, xpath=".//ViewRegistration")
setup_max = setup_df['setup'].max()
setup_max

62

In [168]:
setup_list = []
for i in range(setup_max+1):
    setup_list.append([i] * 3)
setup_list = np.array(setup_list).flatten()
setup_list

array([ 0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,
        5,  6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10, 10, 11,
       11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16,
       17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21, 21, 21, 22, 22,
       22, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28,
       28, 28, 29, 29, 29, 30, 30, 30, 31, 31, 31, 32, 32, 32, 33, 33, 33,
       34, 34, 34, 35, 35, 35, 36, 36, 36, 37, 37, 37, 38, 38, 38, 39, 39,
       39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45,
       45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50,
       51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56,
       56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62,
       62, 62])

In [169]:
transform_df = pd.read_xml(current_tile_xml, xpath=".//ViewTransform")
transform_df['setup'] = setup_list
transform_df = transform_df.pivot(index='setup', columns='Name', values='affine')
transform_df = transform_df.loc[:, ["Stitching Transform", "Translation to Regular Grid"]]
transform_df['x_st'] = transform_df["Stitching Transform"].str.split(' ').apply(lambda x: x[3]).astype(float)
transform_df['y_st'] = transform_df["Stitching Transform"].str.split(' ').apply(lambda x: x[7]).astype(float)
transform_df['z_st'] = transform_df["Stitching Transform"].str.split(' ').apply(lambda x: x[11]).astype(float)

transform_df['x_trg'] = transform_df["Translation to Regular Grid"].str.split(' ').apply(lambda x: x[3]).astype(float)
transform_df['y_trg'] = transform_df["Translation to Regular Grid"].str.split(' ').apply(lambda x: x[7]).astype(float)
transform_df['z_trg'] = transform_df["Translation to Regular Grid"].str.split(' ').apply(lambda x: x[11]).astype(float)

transform_df['x'] = transform_df['x_st'] + transform_df['x_trg']
transform_df['y'] = transform_df['y_st'] + transform_df['y_trg']
transform_df['z'] = transform_df['z_st'] + transform_df['z_trg']

transform_df['x'] = transform_df['x'].astype(int)
transform_df['y'] = transform_df['y'].astype(int)
transform_df['z'] = transform_df['z'].astype(int)

transform_df['x'] = transform_df['x'] + np.abs(transform_df['x'].min())
transform_df['y'] = transform_df['y'] + np.abs(transform_df['y'].min())
transform_df['z'] = transform_df['z'] + np.abs(transform_df['z'].min())
transform_df['fov_index'] = transform_df.index

transform_df = transform_df.loc[:, ['x', 'y', 'z', 'fov_index']]

In [170]:
transform_df

Name,x,y,z,fov_index
setup,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,4,2,15,0
1,11,1350,15,1
2,10,2699,14,2
3,9,4055,12,3
4,9,5409,11,4
...,...,...,...,...
58,10720,2739,4,58
59,10720,4093,4,59
60,10723,5446,3,60
61,10724,6800,3,61


In [171]:
grid_file = os.path.join(image_path, "fused/3D", current_sample, 'grid.csv')
grid_df = pd.read_csv(grid_file, index_col=0)
grid_df

Unnamed: 0,col,row,id,grid
0,0,0,0,tile_0_0
1,0,1,190,tile_1_0
2,0,2,191,tile_2_0
3,0,3,192,tile_3_0
4,0,4,193,tile_4_0
...,...,...,...,...
58,8,2,242,tile_2_8
59,8,3,243,tile_3_8
60,8,4,244,tile_4_8
61,8,5,245,tile_5_8


In [172]:
tile_config_df = pd.concat([transform_df, grid_df], axis=1)
tile_config_df

Unnamed: 0,x,y,z,fov_index,col,row,id,grid
0,4,2,15,0,0,0,0,tile_0_0
1,11,1350,15,1,0,1,190,tile_1_0
2,10,2699,14,2,0,2,191,tile_2_0
3,9,4055,12,3,0,3,192,tile_3_0
4,9,5409,11,4,0,4,193,tile_4_0
...,...,...,...,...,...,...,...,...
58,10720,2739,4,58,8,2,242,tile_2_8
59,10720,4093,4,59,8,3,243,tile_3_8
60,10723,5446,3,60,8,4,244,tile_4_8
61,10724,6800,3,61,8,5,245,tile_5_8


In [173]:
import plotly.express as px
fig = px.scatter_3d(tile_config_df, x='x', y='y', z='z', color='id')
fig.update_traces(marker_size = 5)
fig.update_scenes(zaxis_autorange="reversed")
# fig.update_scenes(yaxis_autorange="reversed")
fig.update_scenes(xaxis_autorange="reversed")
fig.show()
fig.write_html(os.path.join(output_path, f'tile_config_{current_sample}.html'))

In [174]:
img_dim = [1496, 1496, 59]
start_x_list = []
start_y_list = []
end_x_list = []
end_y_list = []

for i in tqdm(range(tile_config_df.shape[0])):
# for i in tqdm(range(7)):
    current_record = tile_config_df.iloc[i]
    current_id = current_record['id']

    if current_id == 0:
        start_x_list.append(0)
        start_y_list.append(0)
        end_x_list.append(0)
        end_y_list.append(0)
    else:
        print(f"Processing tile {current_id}")
        current_row = current_record['row']
        current_col = current_record['col']
        current_x = current_record['x']
        current_y = current_record['y']

        left_tile = f"tile_{current_row}_{current_col - 1}"
        right_tile = f"tile_{current_row}_{current_col + 1}"
        up_tile = f"tile_{current_row - 1}_{current_col}"
        down_tile = f"tile_{current_row + 1}_{current_col}"

        if left_tile in tile_config_df.grid.values and tile_config_df.loc[tile_config_df.grid == left_tile, 'id'].values != 0:
            left_x = tile_config_df.loc[tile_config_df.grid == left_tile, 'x'].values
            me_start_x = int((left_x + img_dim[0] - current_x)/2 + 0.5) + current_x
        else:
            me_start_x = current_x

        if up_tile in tile_config_df.grid.values and tile_config_df.loc[tile_config_df.grid == up_tile, 'id'].values != 0:
            left_y = tile_config_df.loc[tile_config_df.grid == up_tile, 'y'].values
            me_start_y = int((left_y + img_dim[1] - current_y)/2 + 0.5) + current_y
        else:
            me_start_y = current_y

        if right_tile in tile_config_df.grid.values and tile_config_df.loc[tile_config_df.grid == right_tile, 'id'].values != 0:
            right_x = tile_config_df.loc[tile_config_df.grid == right_tile, 'x'].values
            me_end_x = int((current_x + img_dim[0] - right_x)/2 + 0.5 + right_x)
        else:
            me_end_x =  current_x + img_dim[0]
            
        if down_tile in tile_config_df.grid.values and tile_config_df.loc[tile_config_df.grid == down_tile, 'id'].values != 0:
            right_y = tile_config_df.loc[tile_config_df.grid == down_tile, 'y'].values
            me_end_y = int((current_y + img_dim[0] - right_y)/2 + 0.5 + right_y)
        else:
            me_end_y = current_y + img_dim[0]

        current_start_point = [me_start_x, me_start_y]
        current_end_point = [me_end_x, me_end_y]
        start_x_list.append(me_start_x)
        start_y_list.append(me_start_y)
        end_x_list.append(me_end_x)
        end_y_list.append(me_end_y)


  0%|          | 0/63 [00:00<?, ?it/s]

Processing tile 190
Processing tile 191
Processing tile 192
Processing tile 193
Processing tile 194
Processing tile 195
Processing tile 196
Processing tile 197
Processing tile 198
Processing tile 199
Processing tile 200
Processing tile 201
Processing tile 202
Processing tile 203
Processing tile 204
Processing tile 205
Processing tile 206
Processing tile 207
Processing tile 208
Processing tile 209
Processing tile 210
Processing tile 211
Processing tile 212
Processing tile 213
Processing tile 214
Processing tile 215
Processing tile 216
Processing tile 217
Processing tile 218
Processing tile 219
Processing tile 220
Processing tile 221
Processing tile 222
Processing tile 223
Processing tile 224
Processing tile 225
Processing tile 226
Processing tile 227
Processing tile 228
Processing tile 229
Processing tile 230
Processing tile 231
Processing tile 232
Processing tile 233
Processing tile 234
Processing tile 235
Processing tile 236
Processing tile 237
Processing tile 238
Processing tile 239


In [175]:
tile_config_df['start_x'] = start_x_list
tile_config_df['start_y'] = start_y_list
tile_config_df['end_x'] = end_x_list    
tile_config_df['end_y'] = end_y_list


In [176]:
tile_config_df['start_x_norm'] = tile_config_df['start_x'] - tile_config_df['x']
tile_config_df['start_y_norm'] = tile_config_df['start_y'] - tile_config_df['y']
tile_config_df['end_x_norm'] = tile_config_df['end_x'] - tile_config_df['x']
tile_config_df['end_y_norm'] = tile_config_df['end_y'] - tile_config_df['y']


In [177]:
tile_config_df.loc[tile_config_df.id == 0, 'start_x_norm'] = 0
tile_config_df.loc[tile_config_df.id == 0, 'start_y_norm'] = 0
tile_config_df.loc[tile_config_df.id == 0, 'end_x_norm'] = 0
tile_config_df.loc[tile_config_df.id == 0, 'end_y_norm'] = 0

In [178]:
tile_config_df.head(50)

Unnamed: 0,x,y,z,fov_index,col,row,id,grid,start_x,start_y,end_x,end_y,start_x_norm,start_y_norm,end_x_norm,end_y_norm
0,4,2,15,0,0,0,0,tile_0_0,0,0,0,0,0,0,0,0
1,11,1350,15,1,0,1,190,tile_1_0,11,1350,1427,2773,0,0,1416,1423
2,10,2699,14,2,0,2,191,tile_2_0,10,2773,1428,4125,0,74,1418,1426
3,9,4055,12,3,0,3,192,tile_3_0,9,4125,1427,5480,0,70,1418,1425
4,9,5409,11,4,0,4,193,tile_4_0,9,5480,1427,6905,0,71,1418,1496
5,2,6764,8,5,0,5,0,tile_5_0,0,0,0,0,0,0,0,0
6,0,8113,6,6,0,6,0,tile_6_0,0,0,0,0,0,0,0,0
7,1347,0,14,7,1,0,194,tile_0_1,1347,0,2766,1423,0,0,1419,1423
8,1347,1350,15,8,1,1,195,tile_1_1,1427,1423,2766,2775,80,73,1419,1425
9,1349,2703,12,9,1,2,196,tile_2_1,1428,2775,2767,4130,79,72,1418,1427


In [179]:
tile_config_df.to_csv(os.path.join(output_path, f'tile_config_{current_sample}.csv'))

In [46]:
for i in range(tile_config_df.shape[0]):
    current_record = tile_config_df.iloc[i]
    current_id = current_record['id']
    
    if current_id == 0:
        continue
    else:
        print(current_id)
        current_fov_id = f"Position{current_id:03}"

        # Create path 
        current_morph_path = os.path.join(morph_path, current_fov_id)
        if not os.path.exists(current_morph_path):
            os.mkdir(current_morph_path)

        current_expr_path = os.path.join(expr_path, current_fov_id)
        if not os.path.exists(current_expr_path):
            os.mkdir(current_expr_path)

        # Load images
        current_gray_img = imread(os.path.join(image_path, "flamingo", 'DAPI', f"{current_fov_id}.tif"))
        current_gray_max = np.max((current_gray_img), axis=0)
        current_label_img = imread(os.path.join(image_path, "flamingo", 'stardist_segmentation', f"{current_fov_id}.tif"))
        current_label_max = np.max((current_label_img > 0), axis=0)
        current_seg_coverage = (current_label_img > 0).sum() / (current_gray_img > 40).sum() * 100
        print(current_seg_coverage)

        # Load signal
        reads_df = pd.read_csv(os.path.join(signal_path, f'{current_fov_id}_goodSpots.csv'))
        reads_df['x'] = reads_df['x'] - 1
        reads_df['y'] = reads_df['y'] - 1
        reads_df['z'] = reads_df['z'] - 1
        reads_df['global_x'] = reads_df['x'] + current_record['x']
        reads_df['global_y'] = reads_df['y'] + current_record['y']
        reads_df['global_z'] = reads_df['z'] + current_record['z']

        # Load genes.csv
        genes_df = pd.read_csv(os.path.join(base_path, "genes.csv"), header=None)
        genes_df.columns = ['gene', 'barcode']

        # Reads assignment to cell
        points = reads_df.loc[:, ["x", "y", "z"]].values
        bases = reads_df['gene'].values
        reads_assignment = current_label_img[points[:, 2], points[:, 1], points[:, 0]]
        reads_df['seg_label'] = reads_assignment
        
        cell_locs = []
        total_cells = len(np.unique(current_label_img)) - 1
        areas = []
        seg_labels = []

        genes = genes_df['gene'].values
        cell_by_gene = np.zeros((total_cells, len(genes)))
        gene_seq_to_index = {}  # map from sequence to index into matrix

        for i, k in enumerate(genes):
            gene_seq_to_index[k] = i
            
        # Iterate through cells
        print('Iterate cells...')
        for i, region in enumerate(tqdm(regionprops(current_label_img, current_gray_img))):
            areas.append(region.area)
            cell_locs.append(region.centroid)
            seg_labels.append(region.label)
            current_cell_label = region.image
            current_cell_image = region.image_intensity
            imwrite(os.path.join(current_morph_path, f"mask_{region.label}.tif"), current_cell_label)
            imwrite(os.path.join(current_morph_path, f"img_{region.label}.tif"), current_cell_image)

            assigned_reads = bases[np.argwhere(reads_assignment == region.label).flatten()]
            for j in assigned_reads:
                if j in gene_seq_to_index:
                    cell_by_gene[i, gene_seq_to_index[j]] += 1
            
        cell_locs = np.array(cell_locs).astype(int)
        global_cell_locs = cell_locs + np.array([current_record['z'], current_record['y'], current_record['x']])
        current_meta = pd.DataFrame({'sample': current_sample, 'fov_id': current_fov_id, 'volume': areas, 'fov_x': cell_locs[:, 2], 'fov_y': cell_locs[:, 1], 'fov_z': cell_locs[:, 0], 'seg_label': seg_labels,
                                    'global_x': global_cell_locs[:, 2], 'global_y': global_cell_locs[:, 1], 'global_z': global_cell_locs[:, 0]})
        cell_barcode_names = pd.DataFrame({'gene': genes})
        cell_barcode_names.index = cell_barcode_names['gene']

        # Create scanpy object
        adata = AnnData(X=cell_by_gene, obs=current_meta, var=cell_barcode_names)

        # Filter cells based on location 
        adata = adata[adata.obs['fov_x'].isin(range(current_record['start_x_norm'], current_record['end_x_norm'])), ]
        adata = adata[adata.obs['fov_y'].isin(range(current_record['start_y_norm'], current_record['end_y_norm'])), ]
        adata.obs = adata.obs.reset_index(drop=True)
        
        # Visualize the data
        # cell centers on the segmentation
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(current_label_max, cmap='gray')
        rect = patches.Rectangle((current_record.start_x_norm, current_record.start_y_norm), 
                                current_record.end_x_norm - current_record.start_x_norm, 
                                current_record.end_y_norm - current_record.start_y_norm,
                                linewidth=.5, edgecolor='y', facecolor='none')
        ax.add_patch(rect)

        ax.plot(current_meta.fov_x, current_meta.fov_y, 'k.', markersize=1, )
        ax.plot(adata.obs.fov_x, adata.obs.fov_y, 'r.', markersize=2, ) 
        plt.savefig(os.path.join(current_expr_path, f"cell_centers_on_label.png"))
        plt.clf()
        plt.close()
        # plt.show()

        # cell centers on the dapi
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(current_gray_max, cmap='gray')
        rect = patches.Rectangle((current_record.start_x_norm, current_record.start_y_norm), 
                                current_record.end_x_norm - current_record.start_x_norm, 
                                current_record.end_y_norm - current_record.start_y_norm,
                                linewidth=.5, edgecolor='y', facecolor='none')
        ax.add_patch(rect)

        ax.plot(current_meta.fov_x, current_meta.fov_y, 'k.', markersize=1, )
        ax.plot(adata.obs.fov_x, adata.obs.fov_y, 'r.', markersize=2, ) 
        plt.savefig(os.path.join(current_expr_path, f"cell_centers_on_dapi.png"))
        plt.clf()
        plt.close()
        # plt.show()

        # reads on the segmentation
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(current_label_max, cmap='gray')
        rect = patches.Rectangle((current_record.start_x_norm, current_record.start_y_norm), 
                                current_record.end_x_norm - current_record.start_x_norm, 
                                current_record.end_y_norm - current_record.start_y_norm,
                                linewidth=.5, edgecolor='y', facecolor='none')
        ax.add_patch(rect)

        ax.plot(reads_df.x, reads_df.y, 'r.', markersize=1, )
        plt.savefig(os.path.join(current_expr_path, f"reads_on_label.png"))
        plt.clf()
        plt.close()
        # plt.show()

        # Output
        # log
        with open(os.path.join(current_expr_path, "log.txt"), 'w') as f:
            msg = "{:.2%} percent [{} out of {}] reads were assigned to {} cells\n".format(cell_by_gene.sum()/len(bases), cell_by_gene.sum(), len(bases), total_cells)
            f.write(msg)
            f.write(f"segmentation coverage: {current_seg_coverage:.2f}%")

        # adata
        adata.write(os.path.join(current_expr_path, "raw.h5ad"))
        
        # reads assignment
        reads_df.to_csv(os.path.join(current_expr_path, "reads_assignment.csv"), index=False)

        break


        

707


NameError: name 'morph_path' is not defined

In [132]:
with open("/home/unix/jiahao/wanglab/jiahao/Github/starfinder/test/Jiakun/fov.txt", "w") as f:
    for i in range(1, 1524 + 1):
        f.write(f"{str(i)}\n")