# Detect barcoded cells

This notebook works on a single tile to make iteration easier and faster.

In [None]:
# imports and chamber selection
%load_ext autoreload
%autoreload 2
from itertools import cycle
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from matplotlib.animation import FuncAnimation

import iss_preprocess as iss
from flexiznam.config import PARAMETERS
data_path = "becalia_rabies_barseq/BRYC65.1d/chamber_13/"

processed_path = Path(PARAMETERS["data_root"]["processed"])
metadata = iss.io.load_metadata(data_path)

# Tiles shared by the illumination correction and the barcode reference;
# presumably (roi, row, col) like `tile_coors` below. Each option gets its own list copy.
reference_tiles = [
    (1, 5, 8), (1, 5, 9), (1, 4, 8), (1, 4, 9),
    (2, 4, 9), (2, 3, 9), (2, 2, 9), (2, 2, 8),
]
# Start from the package defaults and override the chamber-specific options.
ops = iss.config.DEFAULT_OPS.copy()
ops.update(
    camera_order=metadata["camera_order"],
    genes_rounds=metadata["genes_rounds"],
    barcode_rounds=metadata["barcode_rounds"],
    use_rois=[1, 2, 5, 6],
    ref_tile=(1, 5, 8),
    correction_tiles=list(reference_tiles),
    barcode_ref_tiles=list(reference_tiles),
    average_clip_value=2000,
)

## Filter detected barcodes

We load the detected barcode spots and keep only those whose dot-product score exceeds `dot_threshold`.

In [None]:
# Analysis parameters:
#   roi               -- ROI whose barcode spots are analysed
#   dot_threshold     -- spots with dot_product_score above this are kept
#   gaussian_width_um -- width of the Gaussian used to blur spots, in microns
roi, dot_threshold, gaussian_width_um = 5, 0.2, 10


In [None]:
# The histogram only helps to choose `dot_threshold`. The filtered `spots` and
# `pixel_size` defined in this cell ARE used by the rest of the notebook.
all_spots = pd.read_pickle(
    processed_path / data_path / f"barcode_round_spots_{roi}.pkl"
)
print(
    f"{len(all_spots)} spots with {len(all_spots.bases.unique())} distincts barcodes."
)
fig, ax = plt.subplots(1, 1)
fig.set_size_inches((5, 2))
ax.axvline(dot_threshold, color="black")
ax.hist(all_spots.dot_product_score, bins=np.arange(-0.5, 1, 0.01), histtype="step")
ax.set_xlabel("Dot product score")
_ = ax.set_ylabel("# spots")
spots = all_spots[all_spots.dot_product_score > dot_threshold]
print(f"{len(spots)} spots with {len(spots.bases.unique())} distincts barcodes.")

# Make a 1D Gaussian kernel to illustrate the blur applied to the spot image.
acq_data = iss.io.load_single_acq_metdata(data_path, prefix="barcode_round_1_1")
pixel_size = acq_data["FrameKey-0-0-0"]["PixelSizeUm"]
gaussian_width = int(gaussian_width_um / pixel_size)  # sigma, in pixels
kernel_size = gaussian_width * 8
kernel_size += 1 - kernel_size % 2  # kernel shape must be odd
kernel = cv2.getGaussianKernel(kernel_size, sigma=gaussian_width)
# Scale the peak to 1 so an isolated spot of value 1 keeps a peak of 1 after convolution.
kernel /= kernel.max()
kernel = kernel.astype(float)

fig = plt.figure(figsize=(5, 1))
ax = fig.add_subplot(1, 1, 1)
# The middle sample of an odd-sized kernel is index kernel_size // 2. Centre the axis
# there so the peak sits at 0 um (`kernel_size / 2` shifted it by half a pixel).
ax.plot((np.arange(kernel_size) - kernel_size // 2) * pixel_size, kernel)
ax.set_ylim([0, 1.05])
_ = ax.set_xlabel("Distance (um)")


## Select which tile we will use

In [None]:
# Choose the tile to analyse as (row, col) within the ROI and highlight it on
# the map of all acquisition tiles, together with the filtered barcode spots.
tile = (3, 9)

fig, ax = plt.subplots(1, 1)
fig.set_size_inches(10, 7)

corners_fname = f"genes_round_1_1_roi{roi}_acquisition_tile_corners.npy"
corners = np.load(processed_path / data_path / "reg" / corners_fname)
ax.scatter(spots.x, spots.y, s=1, color="darkred", alpha=0.3)
ax.set_aspect("equal")

selected_style = dict(color="purple", zorder=10, lw=2)
other_style = dict(color="k")
for row, row_corners in enumerate(corners):
    for col, tile_corners in enumerate(row_corners):
        # Repeat the first corner so the outline is closed.
        outline = np.hstack([tile_corners, tile_corners[:, [0]]])
        is_selected = row == tile[0] and col == tile[1]
        ax.plot(outline[1, :], outline[0, :], **(selected_style if is_selected else other_style))
        mid = np.nanmean(tile_corners, axis=1)
        ax.text(
            mid[1],
            mid[0],
            s=f"({row}, {col})",
            verticalalignment="center",
            horizontalalignment="center",
        )
ax.invert_yaxis()


In [None]:
def crop_spots_to_tile(spots_df, tile_corners):
    """Return the spots lying inside one tile, shifted to tile-local coordinates.

    Args:
        spots_df: DataFrame with ``x`` and ``y`` columns in stitched-ROI pixels.
        tile_corners: array indexed as ``tile_corners[axis, corner]``. Axis 0 is y
            and axis 1 is x. Corner 0 is used as the lower bound and corner 2 as
            the upper bound (exclusive, minus one pixel).

    Returns:
        A new DataFrame (the input is left untouched) whose ``x``/``y`` are
        relative to corner 0 of the tile.
    """
    y_start, x_start = tile_corners[0, 0], tile_corners[1, 0]
    y_stop, x_stop = tile_corners[0, 2], tile_corners[1, 2]
    inside = (
        (spots_df.x > x_start)
        & (spots_df.x < x_stop - 1)
        & (spots_df.y > y_start)
        & (spots_df.y < y_stop - 1)
    )
    # Explicit copy: the shift below must never touch the caller's frame.
    cropped = spots_df.loc[inside].copy()
    cropped["x"] -= x_start
    cropped["y"] -= y_start
    return cropped


corner = corners[tile[0], tile[1]]
tile_coors = (roi, tile[0], tile[1])

# Barcode spots of the selected tile, in tile pixel coordinates.
spots_in_tile = crop_spots_to_tile(spots, corner)

# Gene spots are saved per tile. No shift is applied (presumably already tile-local).
genes_spots = pd.read_pickle(
    processed_path
    / data_path
    / "spots"
    / f"genes_round_spots_{roi}_{tile[0]}_{tile[1]}.pkl"
)

# Hybridisation spots are saved per ROI: pool all hybridisation rounds, then crop.
hyb_spots = pd.concat(
    [
        pd.read_pickle(processed_path / data_path / f"{hyb}_spots_{roi}.pkl")
        for hyb in metadata["hybridisation"]
    ],
    axis=0,
    ignore_index=True,
)
hyb_spots = crop_spots_to_tile(hyb_spots, corner)

genes_all_channels = iss.pipeline.stitch.load_tile_ref_coors(
    data_path=data_path, tile_coors=tile_coors, prefix="genes_round_1_1"
)
barcodes_all_channels = iss.pipeline.stitch.load_tile_ref_coors(
    data_path=data_path, tile_coors=tile_coors, prefix="barcode_round_1_1"
)


In [None]:
# Turn the tile's barcode spots into a smooth spot-density "image" on the same
# grid as the registered genes data.
print("Convolving")
output_shape = genes_all_channels.shape[:2]
spot_sigma_px = int(gaussian_width_um / pixel_size)  # Gaussian width in pixels
blur = iss.segment.spots.make_spot_image(
    spots_in_tile,
    gaussian_width=spot_sigma_px,
    dtype="single",
    output_shape=output_shape,
)


In [None]:
# Left: spot-density image. Right: the same image with values below 2 set to 0.
# Barcode spots are overlaid on both panels.
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(1, 2, 1)
ax.imshow(blur, vmin=0, vmax=40)
snippet = np.where(blur < 2, 0, blur)
ax1 = fig.add_subplot(1, 2, 2)
ax1.imshow(snippet, vmin=0)

for panel in (ax, ax1):
    panel.scatter(spots_in_tile.x, spots_in_tile.y, s=1, color="red", alpha=0.3)
    panel.axis("off")
plt.tight_layout()


In [None]:
# Segment the spot-density image into barcoded cells (OpenCV-based helper).
barcoded_mask = iss.segment.barcodes.segment_spot_image(
    blur,
    binarise_threshold=10,
    distance_threshold=3,
)
plt.imshow(barcoded_mask, interpolation="None", cmap="tab20")


# Find barcodes and genes inside cells

In [None]:
# Count barcode, gene and hybridisation rolonies falling inside each segmented cell.
barcode_df = iss.segment.cells.count_rolonies(
    barcoded_mask, spots_in_tile, grouping_column="bases"
)
genes_df = iss.segment.cells.count_rolonies(
    barcoded_mask, genes_spots, grouping_column="gene"
)
hyb_df = iss.segment.cells.count_rolonies(
    barcoded_mask, hyb_spots, grouping_column="gene"
)
# Genes also read out by hybridisation use the hybridisation counts, so drop
# them from the sequencing counts (this also avoids column clashes in the join).
shared_genes = [gene for gene in hyb_df.columns if gene in genes_df.columns]
genes_df = genes_df.drop(columns=shared_genes)
# Cells missing from one table get a count of 0.
fused_df = hyb_df.join(genes_df, how="outer").fillna(0).astype(int)


In [None]:

# Cells (rows) x barcodes (columns) rolony counts, saturated at 10.
fig, ax = plt.subplots(figsize=(50, 7))
img = ax.imshow(
    barcode_df.to_numpy(), aspect="auto", interpolation="none", vmax=10, origin="lower"
)
cb = fig.colorbar(img, ax=ax)
cb.set_label("Rolonie #")
n_rows, n_cols = barcode_df.shape
ax.set_xticks(np.arange(n_cols))
ax.set_yticks(np.arange(n_rows))
ax.set_yticklabels(barcode_df.index)
ax.set_xticklabels(barcode_df.columns, rotation=90)
plt.tight_layout()



In [None]:
# Cells (rows) x sequenced genes (columns) rolony counts, saturated at 10.
fig, ax = plt.subplots(figsize=(30, 7))
img = ax.imshow(
    genes_df.to_numpy(), aspect="auto", interpolation="none", vmax=10, origin="lower"
)
cb = fig.colorbar(img, ax=ax)
cb.set_label("Rolonie #")
n_rows, n_cols = genes_df.shape
ax.set_xticks(np.arange(n_cols))
ax.set_yticks(np.arange(n_rows))
ax.set_yticklabels(genes_df.index)
ax.set_xticklabels(genes_df.columns, rotation=90)
plt.tight_layout()

In [None]:
# Distribution of rolony counts per cell. Rows are taken from label 1 onward
# (label 0 is presumably the background / unassigned bucket).
fig, (genes_ax, barcode_ax) = plt.subplots(1, 2)
genes_ax.hist(genes_df.loc[1:].sum(axis=1))
genes_ax.set_xlabel("# genes rolonies per cells")
barcode_ax.hist(barcode_df.loc[1:].sum(axis=1))
barcode_ax.set_xlabel("# barcode rolonies per cells")

In [None]:
# Fused hybridisation + sequencing counts, skipping the first row
# (presumably label 0, i.e. background).
cell_counts = fused_df.iloc[1:].astype(int)
iss.vis.plot_gene_matrix(cell_counts, vmax=2, cmap="inferno")

# Plot example SST cell

In [None]:
# Label of the cell to inspect (row label of fused_df / value in barcoded_mask).
cell_id = 23
cell_series = fused_df.xs(cell_id)
print(f"Ploting cell {cell_id} with {cell_series.Sst} sst rolonies")

In [None]:
# Zoom on `cell_id`. The bounding box is padded on every side by the cell's largest
# extent. Barcode (red) and gene (green) signal are shown, with rolonies overlaid.
mask = np.vstack(np.where(barcoded_mask == cell_id))  # row 0: y indices, row 1: x indices
if mask.size == 0:
    raise ValueError(f"Cell {cell_id} is not present in barcoded_mask")
bounding_box = np.vstack([mask.min(axis=1), mask.max(axis=1)]).astype(int)
bounding_box += np.array([[-1, -1], [1, 1]]) * np.diff(bounding_box, axis=0).max()
# Keep the box inside the image. A negative start would make the slice wrap around
# (or come back empty) for cells near the tile edge.
bounding_box = np.clip(bounding_box, 0, barcoded_mask.shape)
part2plot = (slice(*bounding_box[:, 0]), slice(*bounding_box[:, 1]))

data = np.dstack([barcodes_all_channels.std(axis=2), genes_all_channels.std(axis=2)])
lim = np.nanquantile(data, [0.05, 0.99], axis=(0, 1))
img = iss.vis.to_rgb(data, colors=[[1, 0, 0], [0, 1, 0]], vmin=lim[0], vmax=lim[1])
plt.imshow(img[part2plot])
plt.contour(barcoded_mask[part2plot])
plt.scatter(spots_in_tile.x - bounding_box[0, 1], spots_in_tile.y - bounding_box[0, 0], s=10, label='Barcodes')
plt.scatter(genes_spots.x - bounding_box[0, 1], genes_spots.y - bounding_box[0, 0], s=10, label='Genes')
plt.xlim([0, np.diff(bounding_box, axis=0)[0, 1]])
plt.ylim([np.diff(bounding_box, axis=0)[0, 0], 0])


In [None]:
# Registered images of both hybridisation rounds for the selected tile.
hyb1_all_channels, hyb2_all_channels = (
    iss.pipeline.stitch.load_tile_ref_coors(
        data_path=data_path, tile_coors=tile_coors, prefix=hyb_prefix
    )
    for hyb_prefix in ("hybridisation_1_1", "hybridisation_2_1")
)


In [None]:
# Deliberate stop so "Run All" does not run the exploratory leftovers below.
raise NotImplementedError(
    "Everything below needs to be changed or deleted"
)

In [None]:
# Alternative segmentation: run cellpose (CPU) on the thresholded spot image,
# with an expected cell diameter of 20 um converted to pixels.
cells = iss.segment.cells.cellpose_segmentation(
    snippet,
    model_type="CP",
    use_gpu=False,
    diameter=int(20 / pixel_size),
    channels=(0, 0),
    flow_threshold=0.5,
    rescale=None,
    min_pix=0,
    dilate_pix=0,
)
cellpose_cells = np.array(cells)
plt.imshow(cells, interpolation="None", cmap="Set2")
print(f"Found {len(np.unique(cells))-1} cells")


# Debug OpenCV version

Now using classic OpenCV.

In [None]:
# Step-by-step prototype of the OpenCV marker-based watershed on the spot image.
mask = 255 * (blur > 10).astype("uint8")
# Own name, so the Gaussian `kernel` from the parameter cell is not overwritten.
dilation_kernel = np.ones((5, 5), dtype="uint8") * 255
background = cv2.dilate(mask, dilation_kernel, iterations=10)
# distanceTransform gives each pixel's distance to the nearest zero pixel. Pixels
# deeper than 20 px inside the binary mask become cell seeds.
dst2nonzero = cv2.distanceTransform(mask, distanceType=cv2.DIST_L2, maskSize=5)
is_cell = 255 * (dst2nonzero > 20).astype("uint8")
ret, markers = cv2.connectedComponents(is_cell)
# make the background 1 (seeds become 2..N+1)
markers += 1
# and mark the band between seeds and dilated background as 0 (to be flooded)
markers[np.bitwise_xor(background, is_cell).astype(bool)] = 0
# watershed requires an 8-bit, 3-channel image
stack = cv2.normalize(blur, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8U)
stack = cv2.cvtColor(stack, cv2.COLOR_GRAY2BGR)
water = cv2.watershed(stack, markers)
water -= 1  # put the background seed to 0.
water[water < 0] = 0  # put borders into background

fig, axes = plt.subplots(2, 2)
fig.set_size_inches(10, 10)
axes[0, 0].imshow(blur)
axes[0, 0].contour(blur, levels=[5, 10, 20, 40], colors=['orange', 'pink', 'red', 'darkred'])
axes[0, 1].imshow(dst2nonzero)
# Previously `imshow(cv2.MORPH_CLOSE)`: that is an int flag, not an image, and raises.
axes[1, 0].imshow(water, cmap="tab20", interpolation="None")

In [None]:
# plot various binary step
debug = iss.segment.barcodes.segment_spot_image(
    blur, binarise_threshold=5, distance_threshold=10, debug=True
)
fig, axes = plt.subplots(2, 2)
fig.set_size_inches(10, 10)
axes[0, 0].imshow(debug["binary"])
axes[0, 0].set_title("Binarised")
axes[0, 1].imshow(debug["background"])
axes[0, 1].set_title("Background is blue")
axes[1, 0].imshow(debug["distance"])
axes[1, 0].set_title("Distance 2 non-zero")
axes[1, 1].imshow(debug["seeds"])
axes[1, 1].set_title("Cells")

plt.tight_layout()


In [None]:
# Label images from segment_spot_image(debug=True): seeds, watershed result,
# and the watershed contours drawn over the seeds and over the spot image.
fig, axes = plt.subplots(2, 2)
fig.set_size_inches(10, 10)
axes[0, 0].imshow(debug["initial_labels"], cmap="tab20", interpolation="None")
axes[0, 0].set_title("Markers")
axes[0, 1].imshow(debug["watershed"], cmap="tab20", interpolation="None")
axes[0, 1].set_title("Watershed")
# key was misspelt "intial_labels" (KeyError)
axes[1, 0].imshow(debug["initial_labels"], cmap="tab20", interpolation="None")
axes[1, 0].contour(debug["watershed"], colors="darkred")
axes[1, 0].set_title("Cell contours")
# Fourth panel: contours on the continuous spot image. Previously the contour/title
# were drawn on axes[1, 0] again (copy-paste), and a qualitative cmap was used.
axes[1, 1].imshow(blur, interpolation="None")
axes[1, 1].contour(debug["watershed"], colors="darkred")
axes[1, 1].set_title("Cell contours on spot image")

for x in axes.flatten():
    x.axis("off")
plt.tight_layout()
# `water` comes from the OpenCV prototype cell, which maps background and borders
# to 0. Count the positive labels (the old `len(unique) - 2` undercounted by one).
print(f"Found {np.unique(water[water > 0]).size} cells")
opencv_cells = np.array(water)

# Correspondence between debug keys and the prototype variables
# (TODO confirm against segment_spot_image):
#   background=background,
#   binary=mask,
#   seeds=is_cell,
#   distance=dst2nonzero,
#   initial_labels=markers,
#   watershed=water,


# Now watershed

We want to flood outward from each cell seed, but not too far. We do that by
assigning a background label to pixels that are far from any cell, so the watershed cannot grow into them.

# Overlay on data

In [None]:
# Map of all acquisition tiles of the ROI (labelled (row, col)) with the filtered
# barcode spots, used to pick the tile for the overlay below.
fig, ax = plt.subplots(1, 1)
fig.set_size_inches(10, 7)
corners = np.load(
    processed_path
    / data_path
    / "reg"
    / f"genes_round_1_1_roi{roi}_acquisition_tile_corners.npy"
)
ax.scatter(spots.x, spots.y, s=1, color="red", alpha=0.3)
ax.set_aspect("equal")
for row, corner in enumerate(corners):
    for col, corne in enumerate(corner):
        center = np.nanmean(corne, axis=1)
        # Repeat the first corner so the outline is closed (4 edges instead of 3).
        square = np.hstack([corne, corne[:, [0]]])
        ax.plot(square[1, :], square[0, :], color="k")
        ax.text(
            center[1],
            center[0],
            s=f"({row}, {col})",
            verticalalignment="center",
            horizontalalignment="center",
        )
ax.invert_yaxis()


In [None]:
# Get raw data for overlay. Use the tile selected at the top of the notebook
# instead of a hard-coded (3, 9), so changing `tile` keeps everything consistent.
tile_coords = (roi, tile[0], tile[1])
genes_all_channels = iss.pipeline.stitch.load_tile(
    data_path=data_path, tile_coordinates=tile_coords, prefix="genes_round_1_1"
)
barcodes_all_channels = iss.pipeline.stitch.load_tile(
    data_path=data_path, tile_coordinates=tile_coords, prefix="barcode_round_1_1"
)
dapi = iss.pipeline.stitch.load_tile(
    data_path=data_path, tile_coordinates=tile_coords, prefix="dapi_1"
)

ref_corners = np.load(
    processed_path
    / data_path
    / "reg"
    / f"genes_round_1_1_roi{roi}_acquisition_tile_corners.npy"
)
# Keep the first entry of the last two axes (presumably channel and round).
dapi = dapi[:, :, 0, 0]
# Std over axis 2 (presumably channels), first entry of the remaining last axis.
barcodes = np.nanstd(barcodes_all_channels, axis=2)[..., 0]
genes = np.nanstd(genes_all_channels, axis=2)[..., 0]


In [None]:
# Quick check of the tile-corner array dimensions.
np.shape(ref_corners)


In [None]:
# Overlay DAPI (blue), genes (red) and barcodes (green) for the tile, with the
# barcode spots that fall inside the tile drawn in yellow.
# NOTE(review): the OpenCV prototype cell maps watershed borders (-1) to 0, so
# `water == -1` is never true and `borders` stays all zeros. It is not used below.
borders = np.zeros(snippet.shape, dtype="uint8")
borders[water == -1] = 255
borders = cv2.dilate(borders, np.ones((5, 5)))
glim = np.nanquantile(genes, [0.1, 0.99])
blim = np.nanquantile(barcodes, [0.1, 0.99])
dlim = np.nanquantile(dapi, [0.1, 0.99])
img = iss.vis.to_rgb(
    np.dstack((dapi, genes, barcodes)),
    colors=[[0, 0, 1], [1, 0, 0], [0, 1, 0]],
    vmin=np.array([dlim[0], glim[0], blim[0]]),
    vmax=[dlim[1], glim[1], blim[1]],
)

fig, axes = plt.subplots(1, 1)
axes = [axes]
fig.set_size_inches(20, 20)
corner = ref_corners[tile_coords[1], tile_coords[2]]
# Bound the tile with corner 0 and the diagonally opposite corner 2 (as the
# tile-cropping cell does). Corner 3 is adjacent to corner 0 in the outline order,
# so for an axis-aligned tile one of the two ranges collapsed to nothing.
valid_spot = spots[
    (spots.x > corner[1, 0])
    & (spots.x < corner[1, 2])
    & (spots.y > corner[0, 0])
    & (spots.y < corner[0, 2])
]
for x in axes:
    x.scatter(
        valid_spot.x - corner[1, 0],
        valid_spot.y - corner[0, 0],
        s=10,
        color="yellow",
        alpha=1,
    )
    x.imshow(img)
    x.axis("off")

plt.tight_layout()


In [None]:
# basic imread
fname = 'barcode_round_1_1_MMStack_5-Pos000_000_fstack.tif'
full_fname = processed_path / data_path / "barcode_round_1_1" / fname
%timeit iss.io.load.load_stack(full_fname).astype('single')

In [None]:
from tifffile import imread
%timeit np.moveaxis(imread(full_fname).astype('single'), 0, 2)

## Optimising tile loading

In [None]:
from skimage.morphology import binary_dilation

# Which acquisition and tile the loading benchmarks below process.
prefix, suffix = "barcode_round", "fstack"
tile_coors = (5, 0, 0)  # presumably (roi, row, col), as elsewhere in this notebook
nrounds = 1
filter_r = (2, 4)  # (r1, r2) passed to filter_stack


## Original version with just the processing steps

This is kept as a timing reference.

In [None]:
# origin version with just processing steps
# Timing baseline: everything, including building the channel/round transforms,
# happens inside the timed function.
def original_version():
    """Load and preprocess one tile the original way (timing baseline).

    Reads the notebook globals ``data_path``, ``prefix``, ``tile_coors``,
    ``suffix``, ``nrounds`` and ``filter_r``.

    Steps: load the saved ``tforms_corrected_*`` parameters, load ``nrounds``
    rounds of the tile, build and apply the channel/round transforms, apply the
    illumination correction, zero NaNs, filter with radii ``filter_r`` and divide
    by the saved ``norm_factors``.

    Returns:
        The filtered, normalised stack. ``bad_pixels`` (NaN pixels dilated by a
        ``2 * filter_r[1] + 1`` square) is computed but not returned; presumably
        it is kept so the timing covers that step too.
    """
    processed_path = Path(PARAMETERS["data_root"]["processed"])
    tforms_fname = f"tforms_corrected_{prefix}_{tile_coors[0]}_{tile_coors[1]}_{tile_coors[2]}.npz"
    tforms_path = processed_path / data_path / "reg" / tforms_fname
    # `tforms` is first the loaded npz file, then rebound to the generated transforms.
    tforms = np.load(tforms_path, allow_pickle=True)

    stack = iss.pipeline.load_sequencing_rounds(
        data_path, tile_coors, suffix=suffix, prefix=prefix, nrounds=nrounds
    )
    tforms = iss.pipeline.generate_channel_round_transforms(
        tforms["angles_within_channels"],
        tforms["shifts_within_channels"],
        tforms["scales_between_channels"],
        tforms["angles_between_channels"],
        tforms["shifts_between_channels"],
        stack.shape[:2],
    )
    stack = iss.pipeline.align_channels_and_rounds(stack, tforms)
    stack = iss.pipeline.apply_illumination_correction(data_path, stack, prefix)
    # Record pixels that are NaN in any entry of axes 2 and 3, then zero the NaNs
    # before filtering.
    bad_pixels = np.any(np.isnan(stack), axis=(2, 3))
    stack[np.isnan(stack)] = 0
    stack = iss.pipeline.filter_stack(stack, r1=filter_r[0], r2=filter_r[1])
    mask = np.ones((filter_r[1] * 2 + 1, filter_r[1] * 2 + 1))
    bad_pixels = binary_dilation(bad_pixels, mask)

    correction_path = processed_path / data_path / f"correction_{prefix}.npz"
    norm_factors = np.load(correction_path, allow_pickle=True)["norm_factors"]
    # norm_factors is 2-D: its first axis broadcasts against stack axis 2 and its
    # second axis (sliced to nrounds) against stack axis 3.
    stack = stack / norm_factors[np.newaxis, np.newaxis, :, :nrounds]
    return stack
%timeit original_version()


In [None]:
# separate tform version with just processing steps
processed_path = Path(PARAMETERS["data_root"]["processed"])
tforms_fname = f"tforms_corrected_{prefix}_{tile_coors[0]}_{tile_coors[1]}_{tile_coors[2]}.npz"
tforms_path = processed_path / data_path / "reg" / tforms_fname
tforms = np.load(tforms_path, allow_pickle=True)
image_shape = (3300, 3296)
tforms = iss.pipeline.generate_channel_round_transforms(
    tforms["angles_within_channels"],
    tforms["shifts_within_channels"],
    tforms["scales_between_channels"],
    tforms["angles_between_channels"],
    tforms["shifts_between_channels"],
    image_shape,
)
def pregenerate_tforms():
    stack = iss.pipeline.load_sequencing_rounds(
        data_path, tile_coors, suffix=suffix, prefix=prefix, nrounds=nrounds
    )

    stack = iss.pipeline.align_channels_and_rounds(stack, tforms)
    stack = iss.pipeline.apply_illumination_correction(data_path, stack, prefix)
    bad_pixels = np.any(np.isnan(stack), axis=(2, 3))
    stack[np.isnan(stack)] = 0
    stack = iss.pipeline.filter_stack(stack, r1=filter_r[0], r2=filter_r[1])
    mask = np.ones((filter_r[1] * 2 + 1, filter_r[1] * 2 + 1))
    bad_pixels = binary_dilation(bad_pixels, mask)

    correction_path = processed_path / data_path / f"correction_{prefix}.npz"
    norm_factors = np.load(correction_path, allow_pickle=True)["norm_factors"]
    stack = stack / norm_factors[np.newaxis, np.newaxis, :, :nrounds]
%timeit pregenerate_tforms()


In [None]:
# Shape returned by the project loader (compare with the tifffile version below).
np.shape(iss.io.load.load_stack(full_fname).astype("single"))


In [None]:
# Shape after reading with tifffile and moving axis 0 to position 2.
np.shape(np.moveaxis(imread(full_fname).astype("single"), 0, 2))


In [None]:
roi = 5
# Stitch the ROI and register DAPI_1 against genes_round_1_1 (downsample=5).
stitched_stack_dapi, stitched_stack_genes, angle, shift = iss.pipeline.stitch_and_register(
    data_path, "genes_round_1_1", "DAPI_1", roi=roi, downsample=5
)


In [None]:
# Stitch the ROI and register barcode_round_1_1 against genes_round_1_1 (downsample=5).
stitched_stack_barcode, stitched_stack_genes, angle, shift = iss.pipeline.stitch_and_register(
    data_path, "genes_round_1_1", "barcode_round_1_1", roi=roi, downsample=5
)


In [None]:
# Stack genes, DAPI and the cell-mask outline over one crop of the stitched ROI,
# then merge the per-tile gene spots into ROI coordinates.
masks = np.load(processed_path / data_path / f"masks_{roi}.npy")
# Crop window (rows, cols) in stitched-ROI pixels, defined once for all three layers.
# The plotting cell offsets spots by this origin (rows 3000, cols 12000).
crop = (slice(3000, 10000), slice(12000, 20000))
im = np.stack(
    [
        stitched_stack_genes[crop],
        stitched_stack_dapi[crop],
        masks[crop] > 0,
    ],
    axis=2,
)
shift_right, shift_down, tile_shape = iss.pipeline.register_adjacent_tiles(
    data_path, ref_coors=ops["ref_tile"], prefix="genes_round_1_1"
)
genes_spots = iss.pipeline.merge_roi_spots(
    data_path, shift_right, shift_down, tile_shape, iroi=roi, prefix="genes_round"
)


In [None]:
# Merge barcode spots into ROI coordinates, then draw barcode (red dots) and gene
# (purple dots) spots over the genes/DAPI/mask crop.
barcode_spots = iss.pipeline.merge_roi_spots(
    data_path, shift_right, shift_down, tile_shape, iroi=roi, prefix="barcode_round"
)

# Origin of the crop built in the previous cell (cols 12000, rows 3000).
x_offset, y_offset = 12000, 3000
rgb = iss.vis.to_rgb(
    im,
    colors=[[1, 0, 0], [0, 0, 1], [0, 1, 0]],
    vmax=[400, 200, 1],
    vmin=np.array([30, 0, 0]),
)
plt.figure(figsize=(50, 50))
plt.imshow(rgb)
plt.plot(
    barcode_spots["x"] - x_offset,
    barcode_spots["y"] - y_offset,
    ".r",
    alpha=1,
    markersize=10,
)
plt.plot(
    genes_spots["x"] - x_offset,
    genes_spots["y"] - y_offset,
    ".",
    color="purple",
    alpha=1,
    markersize=10,
)
plt.xlim([0, 4000])
plt.ylim([4000, 0])
plt.axis("off")
