# Detect barcoded cells

This notebook works on a single tile to make the analysis easier and faster to run.

In [None]:
# imports and chamber selection
%load_ext autoreload
%autoreload 2
import iss_preprocess as iss
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import yaml
from flexiznam.config import PARAMETERS
from pathlib import Path
from itertools import cycle
from matplotlib.animation import FuncAnimation
# Acquisition analysed in this notebook.
data_path = 'becalia_rabies_barseq/BRYC65.1d/chamber_13/'

processed_path = Path(PARAMETERS['data_root']['processed'])
metadata = iss.io.load_metadata(data_path)

# The same set of tiles is used for both 'correction_tiles' and 'barcode_ref_tiles'.
# presumably each tuple is (roi, tile_x, tile_y) -- TODO confirm against iss_preprocess.
reference_tiles = [(1, 5, 8), (1, 5, 9), (1, 4, 8), (1, 4, 9), (2, 4, 9), (2, 3, 9), (2, 2, 9), (2, 2, 8)]

# Start from the package defaults, then override with the acquisition metadata
# and the settings specific to this chamber.
ops = iss.config.DEFAULT_OPS.copy()
for key in ('camera_order', 'genes_rounds', 'barcode_rounds'):
    ops[key] = metadata[key]
ops['use_rois'] = [1, 2, 5, 6]
ops['ref_tile'] = (1, 5, 8)
ops['correction_tiles'] = list(reference_tiles)
ops['barcode_ref_tiles'] = list(reference_tiles)
ops['average_clip_value'] = 2000

## Filter detected barcodes

We will load detected barcodes and filter them by dot product.

In [None]:
# ROI analysed in the following cells (used to build the mask and corner file names).
roi = 5
# NOTE(review): `dot_threshold` and `gaussian_width_um` are not referenced by any
# other cell visible in this notebook (filtering uses `barcode_dot_threshold`);
# confirm whether they are still needed.
dot_threshold = 0.2
gaussian_width_um = 10


In [None]:
# get spots
# Load the merged and aligned spots of every acquisition for the selected ROI,
# all registered to the "genes_round_1_1" reference.
raw_spots = dict()
spot_list = ['genes_round', 'barcode_round', 'hybridisation_1_1', 'hybridisation_2_1']
for prefix in spot_list:
    # Prefixes that already end in '_1' are used as-is for registration; the
    # others get '_1_1' appended.
    reg_prefix = prefix if prefix.endswith('_1') else prefix + '_1_1'
    spot_df = iss.pipeline.stitch.merge_and_align_spots(
        data_path,
        roi=roi,  # use the notebook-level ROI rather than a hard-coded 5
        spots_prefix=prefix,
        reg_prefix=reg_prefix,
        ref_prefix="genes_round_1_1",
    )
    raw_spots[prefix] = spot_df

In [None]:
# filter spots
barcode_dot_threshold = 0.15
omp_score_threshold = 0.1
hyb_score = 0.8
hyb_prefixes = [f'hybridisation_{i + 1}_1' for i in range(2)]

# Score distribution of every spot type, with its threshold as a vertical line.
fig, axes = plt.subplots(2, 2)
fig.set_size_inches(7, 5)
kw = dict(histtype='step', color='k', lw=2)
axes[0, 0].hist(raw_spots['barcode_round'].dot_product_score, bins=np.arange(-0.5, 1.1, 0.05), **kw)
axes[0, 0].axvline(barcode_dot_threshold, color='k')
axes[0, 0].set_xlabel('Barcode dot score')
axes[0, 0].set_ylabel('# barcode rolonies')

axes[0, 1].hist(raw_spots['genes_round'].spot_score, bins=np.arange(0, 1.2, 0.05), **kw)
axes[0, 1].axvline(omp_score_threshold, color='k')
axes[0, 1].set_xlabel('OMP score')
axes[0, 1].set_ylabel('# genes rolonies')

for i, hyb_prefix in enumerate(hyb_prefixes):
    axes[1, i].hist(raw_spots[hyb_prefix].score, bins=np.arange(-0.50, 1.2, 0.05), **kw)
    axes[1, i].axvline(hyb_score, color='k')
    axes[1, i].set_xlabel('Hybridisation score')
    axes[1, i].set_ylabel(f'# hyb {i+1} rolonies')

plt.tight_layout()

# Keep only the spots scoring strictly above their threshold. Copies are stored
# so that `raw_spots` is left untouched.
spots = dict()
ok_barcode = raw_spots['barcode_round'].dot_product_score > barcode_dot_threshold
spots['barcode_round'] = raw_spots['barcode_round'][ok_barcode].copy()
print(f'Keeping {np.sum(ok_barcode)} barcode rolonies out of {len(ok_barcode)}.')
ok_genes = raw_spots['genes_round'].spot_score > omp_score_threshold
spots['genes_round'] = raw_spots['genes_round'][ok_genes].copy()
print(f'Keeping {np.sum(ok_genes)} genes rolonies out of {len(ok_genes)}.')
for i, hyb_prefix in enumerate(hyb_prefixes):
    ok_hyb = raw_spots[hyb_prefix].score > hyb_score
    spots[hyb_prefix] = raw_spots[hyb_prefix][ok_hyb].copy()
    print(f'Keeping {np.sum(ok_hyb)} hybridisation rolonies out of {len(ok_hyb)} for round {i+1}.')

In [None]:
# get masks and expand
from skimage.segmentation import expand_labels

# presumably micrometres per pixel -- TODO confirm
pixel_size = 0.18
masks = np.load(processed_path / data_path / f"masks_{roi}.npy")
# Grow every label by int(5 / pixel_size) pixels; `masks` itself is kept so the
# added ring can be displayed later as `big_mask - masks`.
expansion_pix = int(5 / pixel_size)
big_mask = expand_labels(masks, distance=expansion_pix)


In [None]:
# plot what we have
roi = 5
tile = (4, 9)
s = 1000  # half-width of the displayed window

corners_file = f"genes_round_1_1_roi{roi}_acquisition_tile_corners.npy"
corners = np.load(processed_path / data_path / "reg" / corners_file)

# Centre of the selected tile: mean of the corner array over its last axis.
# NOTE(review): center[0] is used as x and center[1] as y here, whereas other
# cells of this notebook plot corner index 1 as x and index 0 as y -- confirm
# the axis order of the corners file.
center = np.mean(corners, axis=3)[tile[0], tile[1]].astype(int)
half_window = np.array([-s, s], dtype=int)
xlim = center[0] + half_window - 500
ylim = center[1] + half_window
part2plot = (slice(*ylim), slice(*xlim))

plt.figure(figsize=(10, 10))

# Show only the difference between expanded and original masks; zeros are set
# to NaN so they are not drawn.
m = (big_mask[part2plot] - masks[part2plot]).astype(float)
m[m == 0] = np.nan
plt.imshow(m, extent=[*xlim, *ylim[::-1]], cmap='prism', interpolation='None', alpha=0.5)

colors = dict(barcode_round='darkred', genes_round='black', hybridisation_1_1='green', hybridisation_2_1='blue')
for w in spots:
    sp = spots[w]
    in_x = (xlim[0] < sp.x) & (sp.x < xlim[1])
    in_y = (sp.y > ylim[0]) & (sp.y < ylim[1])
    ok = sp[in_x & in_y]
    plt.scatter(ok.x, ok.y, s=2, label=w, color=colors[w])
plt.legend(loc='upper right')



# Find barcodes and genes inside cells

In [None]:
# find which barcode is in which cells
spots_in_cells = dict()
for prefix in spots:
    print(prefix, flush=True)
    spot_df = spots[prefix]
    # Barcode spots are grouped by their 'bases' column, all others by 'gene'.
    if prefix.startswith('barcode'):
        grouping_column = 'bases'
    else:
        grouping_column = 'gene'
    cell_df = iss.segment.cells.count_rolonies(
        big_mask, spot_df, grouping_column=grouping_column
    )
    spots_in_cells[prefix] = cell_df

In [None]:
# clean-up hyb
# Merge the per-cell gene counts with the hybridisation counts; when a gene is
# present in both, the hybridisation column replaces the sequencing one.
fused_df = spots_in_cells['genes_round'].copy()
for i_hyb in (0, 1):
    hyb_df = spots_in_cells[f"hybridisation_{i_hyb + 1}_1"]
    replaced = [gene for gene in hyb_df.columns if gene in fused_df.columns]
    for gene in replaced:
        print(f'Replacing {gene} with hybridisation')
    fused_df = fused_df.drop(columns=replaced).join(hyb_df, how='outer')
# Rows missing from one of the tables come out of the outer join as NaN: count them as 0.
fused_df = fused_df.fillna(0).astype(int)
fused_df.head()

In [None]:
barcode_df = spots_in_cells['barcode_round']
rol_th = 10  # minimum number of barcode rolonies for a cell to count as barcoded

# The first row is excluded everywhere in this cell.
# presumably it is mask label 0, i.e. the background -- TODO confirm.
cell_barcodes = barcode_df.iloc[1:]
rolonies_per_cell = cell_barcodes.sum(axis=1)

fig, axes = plt.subplots(2, 2)
fig.set_size_inches(10, 10)
kw = dict(histtype='step', color='k', lw=2)
# Same histogram twice: linear y axis on the left, log y axis on the right.
for i in range(2):
    axes[0, i].hist(rolonies_per_cell.values, bins=np.arange(-0.5, 50, 1), **kw)
    axes[0, i].set_xlabel('Number of barcode rolonies per cell')
    axes[0, i].set_ylabel('# of cells')
    axes[0, i].axvline(rol_th, color='k')
axes[0, 1].semilogy()

# Drop the first row BEFORE thresholding. Thresholding first and then calling
# `.iloc[1:]` discards the first real cell whenever the first row does not pass
# the threshold.
barcoded_cells = cell_barcodes[rolonies_per_cell > rol_th]
axes[1, 0].scatter(barcoded_cells.sum(axis=1), barcoded_cells.max(axis=1), color='k')
axes[1, 0].set_xlabel('Total number of rolonies')
axes[1, 0].set_ylabel('Most aboundant sequence')
prop_main = barcoded_cells.max(axis=1) / barcoded_cells.sum(axis=1)
axes[1, 1].hist(prop_main, bins=np.arange(0, 1.1, 0.05), **kw)
axes[1, 1].set_xlabel('Proportion of rolonies from main sequence')
axes[1, 1].set_ylabel('# of cells')

In [None]:
# make edit distance plot
import editdistance

# For every barcoded cell, count the rolonies at each edit distance from the
# cell's most abundant sequence (column 0 is the main sequence itself).
code_len = len(barcoded_cells.columns[0])
# Counts are accumulated in an integer array and wrapped in a DataFrame once, so
# the dtype is int rather than whatever row-by-row `.loc` assignment into an
# empty frame produces.
counts = np.zeros((len(barcoded_cells), code_len + 1), dtype=int)
for i_cell, (cell_id, cell) in enumerate(barcoded_cells.iterrows()):
    main = cell.idxmax()
    barcodes = cell[cell != 0]
    for seq, cnt in barcodes.items():
        edit = editdistance.eval(seq, main)
        counts[i_cell, edit] += cnt
distance_df = pd.DataFrame(counts, index=barcoded_cells.index, columns=np.arange(code_len + 1))

In [None]:
# Optional DAPI background for the single-cell plot below; stitching is only
# run when the flag is set.
add_dapi = False
if add_dapi:
    # Use the notebook-level ROI so the image matches the masks and spots.
    dapi_stitched = iss.pipeline.stitch.stitch_registered(data_path, prefix='DAPI_1', roi=roi, channels=0)

In [None]:
import seaborn as sns

# Sort cells by the number of rolonies matching their main sequence (distance 0).
distance_df = distance_df.sort_values(0)
fig, axes = plt.subplots(2, 2)
fig.set_size_inches(10, 10)
im = axes[0, 0].imshow(distance_df.values, aspect='auto', interpolation='None')
cb = fig.colorbar(im, ax=axes[0, 0])
cb.set_label('# rolonies')
axes[0, 0].set_xlabel('Edit distance')
axes[0, 0].set_ylabel('Cell #')

sns.stripplot(data=distance_df, ax=axes[0, 1], color='purple')
axes[0, 1].bar(distance_df.columns - 0.1, distance_df.sum(axis=0) / len(distance_df.index), edgecolor='k', facecolor='None', width=1)
axes[0, 1].set_xlabel('Edit distance')
axes[0, 1].set_ylabel('# rolonies per cell')

# Cells with more than 3 rolonies at a single edit distance >= 3.
# NOTE(review): `double_labeled` is not used by any cell visible in this notebook.
double_labeled = distance_df.loc[:, 3:].max(axis=1) > 3
# Example cell: the one with the most rolonies at edit distance 6 (requires
# barcodes of at least 6 bases, otherwise column 6 does not exist).
double_cell = distance_df.loc[:, 6].idxmax()
double_seq = barcoded_cells.loc[double_cell]
double_seq = double_seq[double_seq != 0].sort_values()[::-1]
axes[1, 0].plot(double_seq.values, 'o', color='k')
axes[1, 0].set_xticks(np.arange(len(double_seq)))
axes[1, 0].set_xticklabels(double_seq.index, rotation=90)
axes[1, 0].set_title(f'Cell {double_cell}')
axes[1, 0].set_ylabel('# rolonies')
axes[1, 0].set_xlabel('Sequence')

# plot the double cell
dc_position = np.where(big_mask == double_cell)
# The limits must be arrays: `list += ndarray` EXTENDS the list (4 elements)
# instead of adding element-wise, which made `slice(*ylim)` raise a TypeError.
ylim = np.array([dc_position[0].min(), dc_position[0].max()])
xlim = np.array([dc_position[1].min(), dc_position[1].max()])
# Pad the bounding box by its own size on each side, clipped to the image so
# that a negative start does not wrap around when slicing.
ylim = np.clip(ylim + np.array([-1, 1]) * (ylim[1] - ylim[0]), 0, big_mask.shape[0])
xlim = np.clip(xlim + np.array([-1, 1]) * (xlim[1] - xlim[0]), 0, big_mask.shape[1])
part2plot = (slice(*ylim), slice(*xlim))
axes[1, 1].contour((big_mask[part2plot] - masks[part2plot]) != 0, extent=[*xlim, *ylim], colors='k', linewidths=0.5)
# Relabel the masks in the window with consecutive integers so neighbouring
# cells get distinct colours; the smallest label is mapped to 0 and hidden.
m = np.array(big_mask[part2plot], copy=True, dtype=float)
vals = np.unique(m)
for iv, v in enumerate(vals):
    m[m == v] = iv
m[m == 0] = np.nan
axes[1, 1].set_title(f'Cell {double_cell}')
if add_dapi:
    vmin, vmax = np.quantile(dapi_stitched[part2plot], [0.6, 0.995])
    axes[1, 1].imshow(dapi_stitched[part2plot], extent=[*xlim, *ylim[::-1]], cmap='viridis', interpolation='None', alpha=1, vmax=vmax, vmin=vmin)
else:
    axes[1, 1].imshow(m, extent=[*xlim, *ylim[::-1]], cmap='tab20', interpolation='None', alpha=1)

# Overlay the barcode spots in the window: sequences seen at least twice get
# their own colour and legend entry, singletons are drawn in black.
sp = spots['barcode_round']
ok = sp[(xlim[0] < sp.x) & (sp.x < xlim[1]) & (sp.y > ylim[0]) & (sp.y < ylim[1])]
seqs = np.unique(ok.bases.values)
for seq in seqs:
    v = ok.bases == seq
    if np.sum(v) < 2:
        scatter_kw = dict(color='k', s=20)
    else:
        scatter_kw = dict(s=20, label=seq)
    axes[1, 1].scatter(ok[v].x, ok[v].y, **scatter_kw)
axes[1, 1].legend(loc='upper right', bbox_to_anchor=(1.1, -0.1), ncol=3)


# Find the cortex

In [None]:

# Heat map of the barcode count table: rows follow `barcode_df.index`, columns
# follow `barcode_df.columns`; the colour scale is capped at 10.
barcode_df = spots_in_cells['barcode_round']
fig, ax = plt.subplots(1, 1, figsize=(50, 7))
img = ax.imshow(barcode_df.values, aspect='auto', interpolation='none', vmax=10, origin='lower')
cb = fig.colorbar(img, ax=ax)
cb.set_label("Rolonie #")
n_rows, n_cols = barcode_df.shape
ax.set_xticks(np.arange(n_cols))
ax.set_xticklabels(barcode_df.columns, rotation=90)
ax.set_yticks(np.arange(n_rows))
ax.set_yticklabels(barcode_df.index)
fig.tight_layout()



In [None]:
# Heat map of the fused gene/hybridisation count table: rows follow
# `fused_df.index`, columns follow `fused_df.columns`; colour scale capped at 10.
fig, ax = plt.subplots(1, 1, figsize=(30, 7))
img = ax.imshow(fused_df.values, aspect='auto', interpolation='none', vmax=10, origin='lower')
cb = fig.colorbar(img, ax=ax)
cb.set_label("Rolonie #")
n_rows, n_cols = fused_df.shape
ax.set_xticks(np.arange(n_cols))
ax.set_xticklabels(fused_df.columns, rotation=90)
ax.set_yticks(np.arange(n_rows))
ax.set_yticklabels(fused_df.index)
fig.tight_layout()

In [None]:
# Rolonies per cell (index labels >= 1) for genes and for barcodes, log y axis.
fig, (ax_genes, ax_barcodes) = plt.subplots(1, 2)
ax_genes.hist(fused_df.loc[1:].sum(axis=1))
ax_genes.semilogy()
ax_genes.set_xlabel("# genes rolonies per cells")
ax_barcodes.hist(barcode_df.loc[1:].sum(axis=1))
ax_barcodes.semilogy()
ax_barcodes.set_xlabel("# barcode rolonies per cells")

In [None]:
# Gene count matrix without its first row, displayed with a colour scale capped at 2.
gene_counts = fused_df.iloc[1:].astype(int)
iss.vis.plot_gene_matrix(gene_counts, cmap="inferno", vmax=2)

# Plot example SST cell

In [None]:
# Example cell, selected by its label in the index of `fused_df`.
cell_id = 23

cell_series = fused_df.loc[cell_id]
n_sst = cell_series.Sst
print(f"Ploting cell {cell_id} with {n_sst} sst rolonies")

In [None]:
# Zoom on the example cell: image built from the per-pixel std across axis 2 of
# the barcode and gene stacks, with mask contours and spots overlaid.
# NOTE(review): `barcoded_mask` and `spots_in_tile` are not defined in any cell
# visible in this notebook, and `barcodes_all_channels`, `genes_all_channels`
# and `genes_spots` are only defined in later cells -- this cell fails on a
# fresh kernel.
mask = np.vstack(np.where(barcoded_mask == cell_id))
# Bounding box of the cell as [[min_row, min_col], [max_row, max_col]] ...
bounding_box = np.vstack([mask.min(axis=1), mask.max(axis=1)]).astype(int)
# ... padded on every side by its largest dimension.
bounding_box += np.array([[-1, -1],[1,1]]) * np.diff(bounding_box, axis=0).max()
part2plot = (slice(*bounding_box[:, 0]), slice(*bounding_box[:, 1]))

data = np.dstack([barcodes_all_channels.std(axis=2), genes_all_channels.std(axis=2)])
# Per-channel display range: 5th to 99th percentile, ignoring NaNs.
lim = np.nanquantile(data, [0.05, 0.99], axis=(0,1))
img = iss.vis.to_rgb(data,
                     colors=[[1,0,0],[0,1,0]], vmin=lim[0], vmax=lim[1])
plt.imshow(img[part2plot])
plt.contour(barcoded_mask[part2plot])
# Spot coordinates are shifted so that the bounding-box corner is the origin.
plt.scatter(spots_in_tile.x - bounding_box[0, 1], spots_in_tile.y - bounding_box[0, 0], s=10, label='Barcodes')
plt.scatter(genes_spots.x - bounding_box[0, 1], genes_spots.y - bounding_box[0, 0], s=10, label='Genes')
plt.xlim([0, np.diff(bounding_box, axis=0)[0,1]])
plt.ylim([np.diff(bounding_box, axis=0)[0, 0], 0])


In [None]:
# Load the two hybridisation rounds for one tile.
# NOTE(review): `tile_coors` is only assigned in a later cell of this notebook,
# so this cell fails on a fresh kernel.
hyb1_all_channels = iss.pipeline.stitch.load_tile_ref_coors(
    data_path=data_path, tile_coors=tile_coors, prefix="hybridisation_1_1"
)
hyb2_all_channels = iss.pipeline.stitch.load_tile_ref_coors(
    data_path=data_path, tile_coors=tile_coors, prefix="hybridisation_2_1"
)


In [None]:
# Guard cell: deliberately stops "Run All" here because the cells below are
# outdated.
raise NotImplementedError("Everything below needs to be changed or deleted")

In [None]:
# Segment cells with cellpose and display the resulting label image.
# NOTE(review): `snippet` is not defined in any cell visible in this notebook.
cells = iss.segment.cells.cellpose_segmentation(
    snippet,
    channels=(0, 0),
    flow_threshold=0.5,
    min_pix=0,
    dilate_pix=0,
    rescale=None,
    model_type="CP",
    use_gpu=False,
    diameter=int(20 / pixel_size),
)
# Copy of the labels for later comparison.
cellpose_cells = np.array(cells)
plt.imshow(cells, cmap="Set2", interpolation="None")
# One unique value is subtracted, presumably the background label -- TODO confirm.
print(f"Found {len(np.unique(cells))-1} cells")


# Debug Opencv version

The same segmentation, now using classic OpenCV operations (threshold, distance transform and watershed).

In [None]:
# OpenCV is used throughout this cell but is not imported anywhere else in the
# notebook.
import cv2

# NOTE(review): `blur` is not defined in any cell visible in this notebook.
# Binarise the blurred image, then find pixels far enough from the background
# to act as seeds.
mask = 255 * (blur > 10).astype("uint8")
kernel = np.ones((5, 5), dtype="uint8") * 255
background = cv2.dilate(mask, kernel, iterations=10)
dst2nonzero = cv2.distanceTransform(mask, distanceType=cv2.DIST_L2, maskSize=5)
is_cell = 255 * (dst2nonzero > 20).astype("uint8")
ret, markers = cv2.connectedComponents(is_cell)
# make the background to 1
markers += 1
# and part to watershed to 0
markers[np.bitwise_xor(background, is_cell).astype(bool)] = 0
# watershed required a rgb image
stack = cv2.normalize(blur, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8U)
stack = cv2.cvtColor(stack, cv2.COLOR_GRAY2BGR)
water = cv2.watershed(stack, markers)
water -= 1  # put the background seed to 0.
water[water < 0] = 0  # put borders into background

fig, axes = plt.subplots(2, 2)
fig.set_size_inches(10, 10)
axes[0, 0].imshow(blur)
axes[0, 0].contour(blur, levels=[5, 10, 20, 40], colors=['orange', 'pink', 'red', 'darkred'])
axes[0, 1].imshow(dst2nonzero)
# The original passed the integer constant `cv2.MORPH_CLOSE` to imshow, which is
# not an image. Show the seeds and the watershed labels computed above instead.
axes[1, 0].imshow(is_cell)
axes[1, 0].set_title("Seeds")
axes[1, 1].imshow(water, cmap="tab20", interpolation="None")
axes[1, 1].set_title("Watershed")

In [None]:
# plot various binary step
# Run the packaged segmentation in debug mode; the result is indexed like a
# dict of intermediate images below.
# NOTE(review): `blur` is not defined in any cell visible in this notebook.
debug = iss.segment.barcodes.segment_spot_image(
    blur, binarise_threshold=5, distance_threshold=10, debug=True
)
fig, axes = plt.subplots(2, 2)
fig.set_size_inches(10, 10)
axes[0, 0].imshow(debug["binary"])
axes[0, 0].set_title("Binarised")
axes[0, 1].imshow(debug["background"])
axes[0, 1].set_title("Background is blue")
axes[1, 0].imshow(debug["distance"])
axes[1, 0].set_title("Distance 2 non-zero")
axes[1, 1].imshow(debug["seeds"])
axes[1, 1].set_title("Cells")

plt.tight_layout()


In [None]:
# Seeds, watershed labels, and the watershed contours over both the seeds and
# the blurred image.
fig, axes = plt.subplots(2, 2)
fig.set_size_inches(10, 10)
axes[0, 0].imshow(debug["initial_labels"], cmap="tab20", interpolation="None")
axes[0, 0].set_title("Markers")
axes[0, 1].imshow(debug["watershed"], cmap="tab20", interpolation="None")
axes[0, 1].set_title("Watershed")
# The key was misspelt "intial_labels", which does not match the
# "initial_labels" key used for the first panel.
axes[1, 0].imshow(debug["initial_labels"], cmap="tab20", interpolation="None")
axes[1, 0].contour(debug["watershed"], colors="darkred")
axes[1, 0].set_title("Cell contours")
# The contour and title were drawn a second time on axes[1, 0]; they belong to
# this last panel.
axes[1, 1].imshow(blur, cmap="tab20", interpolation="None")
axes[1, 1].contour(debug["watershed"], colors="darkred")
axes[1, 1].set_title("Cell contours")


for x in axes.flatten():
    x.axis("off")
plt.tight_layout()
# `water` has its border pixels merged into the background label 0 by the
# OpenCV cell above, so only one unique value (the background) is not a cell.
print(f"Found {len(np.unique(water))-1} cells")
opencv_cells = np.array(water)


"""
background=background,
binary=mask,
seeds=is_cell,
distance=dst2nonzero,
initial_labels=markers,
watershed=water,
"""


# Now watershed

We want to flood outwards from each cell so that its label extends into the
surrounding area, but not too far. We can do that by giving the background its
own label wherever it is far from any cell.

# Overlay to data

In [None]:
# select tile
# Draw every spot together with the outline and (row, col) index of each
# acquisition tile, to choose the tile used for the overlay.
# NOTE(review): `spots` is a dict of DataFrames in the earlier cells of this
# notebook, so `spots.x` / `spots.y` fail as written; this cell presumably
# expects a single spot DataFrame -- confirm which one.
fig, ax = plt.subplots(1, 1)
fig.set_size_inches(10, 7)
corners = np.load(
    processed_path
    / data_path
    / "reg"
    / f"genes_round_1_1_roi{roi}_acquisition_tile_corners.npy"
)
ax.scatter(spots.x, spots.y, s=1, color="red", alpha=0.3)
ax.set_aspect("equal")
for row, corner in enumerate(corners):
    for col, corne in enumerate(corner):
        # Index 0 of the corner array is plotted as y and index 1 as x.
        center = np.nanmean(corne, axis=1)
        ax.plot(corne[1, :], corne[0, :], color="k")
        ax.text(
            center[1],
            center[0],
            s=f"({row}, {col})",
            verticalalignment="center",
            horizontalalignment="center",
        )
ax.invert_yaxis()


In [None]:
# Get raw data for overlay
# Load the genes, barcode and DAPI stacks of one tile, then reduce each to a
# 2D image.
tile_coords = (roi, 3, 9)
genes_all_channels = iss.pipeline.stitch.load_tile(
    data_path=data_path, tile_coordinates=tile_coords, prefix="genes_round_1_1"
)
barcodes_all_channels = iss.pipeline.stitch.load_tile(
    data_path=data_path, tile_coordinates=tile_coords, prefix="barcode_round_1_1"
)
# NOTE(review): the prefix is "dapi_1" here but "DAPI_1" in other cells of this
# notebook -- confirm the correct case.
dapi = iss.pipeline.stitch.load_tile(
    data_path=data_path, tile_coordinates=tile_coords, prefix="dapi_1"
)

ref_corners = np.load(
    processed_path
    / data_path
    / "reg"
    / f"genes_round_1_1_roi{roi}_acquisition_tile_corners.npy"
)
# First element along the last two axes of the DAPI stack.
dapi = dapi[:, :, 0, 0]
# NaN-aware std along axis 2, first element of the last axis.
barcodes = np.nanstd(barcodes_all_channels, axis=2)[..., 0]
genes = np.nanstd(genes_all_channels, axis=2)[..., 0]


In [None]:
# Display the shape of the tile-corner array.
ref_corners.shape


In [None]:
# RGB overlay of DAPI (blue), genes (red) and barcodes (green) for the selected
# tile, with spots drawn in tile coordinates.
# NOTE(review): `snippet` is not defined in any cell visible in this notebook,
# and `cv2` is never imported.
# NOTE(review): the OpenCV cell sets every negative value of `water` to 0, so
# `water == -1` is always False here, and `borders` is not used afterwards.
borders = np.zeros(snippet.shape, dtype="uint8")
borders[water == -1] = 255
borders = cv2.dilate(borders, np.ones((5, 5)))
# Display range of each channel: 10th to 99th percentile, ignoring NaNs.
glim = np.nanquantile(genes, [0.1, 0.99])
blim = np.nanquantile(barcodes, [0.1, 0.99])
dlim = np.nanquantile(dapi, [0.1, 0.99])
img = iss.vis.to_rgb(
    np.dstack((dapi, genes, barcodes)),
    colors=[[0, 0, 1], [1, 0, 0], [0, 1, 0]],
    vmin=np.array([dlim[0], glim[0], blim[0]]),
    vmax=[dlim[1], glim[1], blim[1]],
)

fig, axes = plt.subplots(1, 1)
axes = [axes]
fig.set_size_inches(20, 20)
corner = ref_corners[tile_coords[1], tile_coords[2]]
# NOTE(review): `valid_spot` (spots inside the tile) is computed but the scatter
# below plots all of `spots`; also `spots` is a dict in the earlier cells, so
# `spots.x` fails as written.
valid_spot = spots[
    (spots.x > corner[1, 0])
    & (spots.x < corner[1, 3])
    & (spots.y > corner[0, 0])
    & (spots.y < corner[0, 3])
]
for x in axes:
    # Shift the spot coordinates so that the tile corner is the origin.
    x.scatter(
        spots.x - corner[1, 0], spots.y - corner[0, 0], s=10, color="yellow", alpha=1
    )
    x.imshow(img)
    x.axis("off")

plt.tight_layout()


In [None]:
# basic imread
# Time the package loader on a single f-stack file, converted to float32.
fname = 'barcode_round_1_1_MMStack_5-Pos000_000_fstack.tif'
full_fname = processed_path / data_path / "barcode_round_1_1" / fname
%timeit iss.io.load.load_stack(full_fname).astype('single')

In [None]:
# Same file read directly with tifffile; the first axis is moved to position 2.
from tifffile import imread
%timeit np.moveaxis(imread(full_fname).astype('single'), 0, 2)

Optimising tile loading

In [None]:
from skimage.morphology import binary_dilation

# Parameters read as globals by the two timing functions below.
prefix = "barcode_round"
tile_coors = (5, 0, 0)
nrounds = 1
suffix = "fstack"
# Passed to filter_stack as (r1, r2); r2 also sets the bad-pixel dilation size.
filter_r = (2, 4)


## original version with just processing steps

This is for reference

In [None]:
# origin version with just processing steps
def original_version():
    """Load and process one tile, regenerating the transforms on every call.

    Reads the notebook globals `data_path`, `prefix`, `tile_coors`, `suffix`,
    `nrounds` and `filter_r`. The steps are: load the saved registration
    parameters and the raw stack, build and apply the channel/round transforms,
    apply the illumination correction, zero the NaNs, filter, and divide by the
    saved normalisation factors.

    A dilated bad-pixel mask is computed as well but is not applied to the
    returned stack.

    Returns:
        The processed stack divided by the normalisation factors.
    """
    acquisition_dir = Path(PARAMETERS["data_root"]["processed"]) / data_path
    roi_id, tile_x, tile_y = tile_coors[0], tile_coors[1], tile_coors[2]
    saved = np.load(
        acquisition_dir / "reg" / f"tforms_corrected_{prefix}_{roi_id}_{tile_x}_{tile_y}.npz",
        allow_pickle=True,
    )

    raw = iss.pipeline.load_sequencing_rounds(
        data_path, tile_coors, suffix=suffix, prefix=prefix, nrounds=nrounds
    )
    transform_keys = (
        "angles_within_channels",
        "shifts_within_channels",
        "scales_between_channels",
        "angles_between_channels",
        "shifts_between_channels",
    )
    transforms = iss.pipeline.generate_channel_round_transforms(
        *[saved[key] for key in transform_keys], raw.shape[:2]
    )
    aligned = iss.pipeline.align_channels_and_rounds(raw, transforms)
    corrected = iss.pipeline.apply_illumination_correction(data_path, aligned, prefix)

    # A pixel is bad if it is NaN anywhere along the last two axes.
    is_nan = np.isnan(corrected)
    bad_pixels = np.any(is_nan, axis=(2, 3))
    corrected[is_nan] = 0
    r_inner, r_outer = filter_r[0], filter_r[1]
    filtered = iss.pipeline.filter_stack(corrected, r1=r_inner, r2=r_outer)
    footprint_size = r_outer * 2 + 1
    bad_pixels = binary_dilation(bad_pixels, np.ones((footprint_size, footprint_size)))

    norm_factors = np.load(
        acquisition_dir / f"correction_{prefix}.npz", allow_pickle=True
    )["norm_factors"]
    return filtered / norm_factors[None, None, :, :nrounds]
# Reference timing of the full per-tile processing.
%timeit original_version()


In [None]:
# separate tform version with just processing steps
# Build the channel/round transforms once, outside the timed function.
processed_path = Path(PARAMETERS["data_root"]["processed"])
tforms_fname = "tforms_corrected_{}_{}_{}_{}.npz".format(
    prefix, tile_coors[0], tile_coors[1], tile_coors[2]
)
tforms_path = processed_path / data_path / "reg" / tforms_fname
saved_tforms = np.load(tforms_path, allow_pickle=True)
# Image shape is hard-coded here instead of being read from a loaded stack.
image_shape = (3300, 3296)
tforms = iss.pipeline.generate_channel_round_transforms(
    saved_tforms["angles_within_channels"],
    saved_tforms["shifts_within_channels"],
    saved_tforms["scales_between_channels"],
    saved_tforms["angles_between_channels"],
    saved_tforms["shifts_between_channels"],
    image_shape,
)
def pregenerate_tforms():
    """Process one tile using the transforms already stored in the global `tforms`.

    Same processing steps as `original_version` except that the transforms are
    not regenerated. Reads the notebook globals `data_path`, `tile_coors`,
    `suffix`, `prefix`, `nrounds`, `filter_r`, `processed_path` and `tforms`.
    Nothing is returned: the function is only used for timing.
    """
    raw = iss.pipeline.load_sequencing_rounds(
        data_path, tile_coors, suffix=suffix, prefix=prefix, nrounds=nrounds
    )

    aligned = iss.pipeline.align_channels_and_rounds(raw, tforms)
    corrected = iss.pipeline.apply_illumination_correction(data_path, aligned, prefix)

    # A pixel is bad if it is NaN anywhere along the last two axes.
    is_nan = np.isnan(corrected)
    bad_pixels = np.any(is_nan, axis=(2, 3))
    corrected[is_nan] = 0
    r_inner, r_outer = filter_r[0], filter_r[1]
    filtered = iss.pipeline.filter_stack(corrected, r1=r_inner, r2=r_outer)
    footprint_size = r_outer * 2 + 1
    bad_pixels = binary_dilation(bad_pixels, np.ones((footprint_size, footprint_size)))

    norm_factors = np.load(
        processed_path / data_path / f"correction_{prefix}.npz", allow_pickle=True
    )["norm_factors"]
    normalised = filtered / norm_factors[None, None, :, :nrounds]
# Timing with pre-generated transforms, to compare with original_version().
%timeit pregenerate_tforms()


In [None]:
# Shape of the stack returned by the package loader.
iss.io.load.load_stack(full_fname).astype("single").shape


In [None]:
# Shape obtained with tifffile after moving the first axis to position 2.
np.moveaxis(imread(full_fname).astype("single"), 0, 2).shape


In [None]:
# Stitch and register the DAPI acquisition against the genes reference.
# NOTE(review): the order of the two returned stacks relative to the two
# prefixes passed in is assumed from the variable names -- confirm against
# `stitch_and_register`.
roi = 5
(
    stitched_stack_dapi,
    stitched_stack_genes,
    angle,
    shift,
) = iss.pipeline.stitch_and_register(
    data_path, "genes_round_1_1", "DAPI_1", roi=roi, downsample=5
)


In [None]:
# Same for the barcode acquisition. This overwrites `stitched_stack_genes`,
# `angle` and `shift` assigned by the previous cell.
(
    stitched_stack_barcode,
    stitched_stack_genes,
    angle,
    shift,
) = iss.pipeline.stitch_and_register(
    data_path, "genes_round_1_1", "barcode_round_1_1", roi=roi, downsample=5
)


In [None]:
masks = np.load(processed_path / data_path / f"masks_{roi}.npy")
# Common crop applied to the genes image, the DAPI image and the binarised
# masks, which become the three channels of `im`.
crop = (slice(3000, 10000), slice(12000, 20000))
im = np.stack(
    [stitched_stack_genes[crop], stitched_stack_dapi[crop], masks[crop] > 0],
    axis=2,
)
# Tile-to-tile shifts measured on the reference tile, then used to merge the
# gene spots of the ROI.
shift_right, shift_down, tile_shape = iss.pipeline.register_adjacent_tiles(
    data_path, ref_coors=ops["ref_tile"], prefix="genes_round_1_1"
)
genes_spots = iss.pipeline.merge_roi_spots(
    data_path, shift_right, shift_down, tile_shape, iroi=roi, prefix="genes_round"
)


In [None]:
# Merge the barcode spots of the ROI, then overlay barcode (red) and gene
# (purple) spots on the cropped RGB image.
barcode_spots = iss.pipeline.merge_roi_spots(
    data_path, shift_right, shift_down, tile_shape, iroi=roi, prefix="barcode_round"
)

plt.figure(figsize=(50, 50))
plt.imshow(
    iss.vis.to_rgb(
        im,
        colors=[[1, 0, 0], [0, 0, 1], [0, 1, 0]],
        vmax=[400, 200, 1],
        vmin=np.array([30, 0, 0]),
    )
)
# Spot coordinates are shifted by the crop origin used to build `im`
# (columns start at 12000, rows at 3000).
plt.plot(
    barcode_spots["x"] - 12000, barcode_spots["y"] - 3000, ".r", alpha=1, markersize=10
)
plt.plot(
    genes_spots["x"] - 12000,
    genes_spots["y"] - 3000,
    ".",
    color="purple",
    alpha=1,
    markersize=10,
)
# Show only the top-left 4000 x 4000 pixels of the crop, y axis pointing down.
plt.xlim([0, 4000])
plt.ylim([4000, 0])
plt.axis("off")
