# Select spot detection thresholds

To set up base calling, we need to create bleed-through matrices. This requires
choosing two thresholds: one for spot detection and one for filtering isolated spots.
This notebook will help you select the best thresholds for your data.

In [None]:
import scipy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import iss_preprocess as iss

# Acquisition to inspect; passed to the ops loader and the tile loader below.
data_path = "becalia_rabies_barseq/BRAC8501.6a/chamber_07"
# Acquisition prefix; the part before the first "_" is used to build ops keys.
prefix = "barcode_round"
ref_tile_index = 1  # which of the reference tiles do we want to use for plots

In [None]:
# Load the acquisition options and choose the reference tile used for the plots.
ops = iss.io.load.load_ops(data_path)

# ops keys are built from the part of the prefix before the first underscore
short_pref = prefix.partition("_")[0]
ref_tiles_key = f"{short_pref}_ref_tiles"
ref_tiles = ops[ref_tiles_key]
print(f"{len(ref_tiles)} reference tiles found. Using {ref_tile_index}.")
ref_tile = ref_tiles[ref_tile_index]

In [None]:
# Load the reference tile through the pipeline loader; every loader option
# other than `correct_illumination` is read from ops.
print(f"Loading registered data for {ref_tile}")
load_kwargs = dict(
    filter_r=ops["filter_r"],
    prefix=prefix,
    suffix=ops[f"{short_pref}_projection"],
    nrounds=ops[f"{prefix}s"],
    correct_channels=ops[f"{short_pref}_correct_channels"],
    corrected_shifts=ops["corrected_shifts"],
    correct_illumination=False,
)
# The second return value is not needed here
stack, _ = iss.pipeline.load_and_register_sequencing_tile(
    data_path, ref_tile, **load_kwargs
)

# Reorder axis 2 (presumably the channel axis - TODO confirm) by the argsort
# of the camera order stored in ops
channel_order = np.argsort(ops["camera_order"])
stack = stack[:, :, channel_order, :]

In [None]:
# Collapse the last two axes of the stack into one 2D image by taking the
# standard deviation across them; this image is used for spot detection below.
print("Making reference image using STD")
reference = stack.std(axis=(2, 3))
reference.shape

## Spot detection

The first step is to detect some spots.

In [None]:
from iss_preprocess.segment.spots import detect_spots

# Detection threshold read from ops. Keep it a plain scalar (not a one-element
# tuple): the same value is reused further down as a contour level and as the
# position of a horizontal line on the colorbar.
detection_threshold = ops[f"{short_pref}_detection_threshold"]
spots = detect_spots(reference, threshold=detection_threshold)

In [None]:
# Find the place with the highest density of spots, just for plotting
x, y = spots["x"].values, spots["y"].values

# Margin (px) kept between candidate centers and the image border; it matches
# the 200 px half-width of the window plotted around the center afterwards.
grid_margin = 200
# Spacing (px) between candidate centers
grid_step = 25
# A spot counts towards a candidate center when it is at most this far away (px)
density_radius = 50

# Create a grid of potential disk centers
x_grid, y_grid = np.meshgrid(
    np.arange(grid_margin, stack.shape[1] - grid_margin, grid_step),
    np.arange(grid_margin, stack.shape[0] - grid_margin, grid_step),
)
# Compute the Euclidean distance from each spot to each potential center.
# The result has one entry per (spot, grid row, grid column).
distances = np.sqrt(
    (x[:, None, None] - x_grid) ** 2 + (y[:, None, None] - y_grid) ** 2
)
# Count the number of spots within `density_radius` of each potential center
counts = np.sum(distances <= density_radius, axis=0)
center = np.unravel_index(counts.argmax(), counts.shape)
# `center` is stored as (x, y)
center = (x_grid[center], y_grid[center])

In [None]:
from iss_preprocess.vis import round_to_rgb, plot_matrix_with_colorbar

# Half-width of the square window shown around the densest region
w = 200
center_x, center_y = center
# extent is [[row_start, row_stop], [col_start, col_stop]]
extent = [[center_y - w, center_y + w], [center_x - w, center_x + w]]
(row_start, row_stop), (col_start, col_stop) = extent

# Keep the spots lying strictly inside the window
inside_window = (
    (spots.x > col_start)
    & (spots.x < col_stop)
    & (spots.y > row_start)
    & (spots.y < row_stop)
)
valid_spots = spots[inside_window]

# NOTE(review): the colour limits are computed on index 0 of the last axis
# while index 1 is the one passed to round_to_rgb - confirm this is intended.
limits_source = stack[..., 0]
rgb = round_to_rgb(
    stack,
    1,
    extent=extent,
    channel_colors=([1, 0, 0], [0, 1, 0], [1, 0, 1], [0, 1, 1]),
    vmax=np.percentile(limits_source, 99.99, axis=(0, 1)),
    vmin=np.percentile(limits_source, 0.2, axis=(0, 1)),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
raw_ax, std_ax = axes

# Left panel: draw a colorbar and blank it so both panels end up the same size
cax, cb = plot_matrix_with_colorbar(rgb, raw_ax)
cax.clear()
cax.axis('off')
raw_ax.set_title("Raw data, round 1")

# Right panel: reference image with the detection threshold drawn as a contour
# on the image and as a horizontal line on the colorbar
ref_part = reference[row_start:row_stop, col_start:col_stop]
cax, cb = plot_matrix_with_colorbar(
    ref_part,
    std_ax,
    cmap="viridis",
    vmax=np.percentile(reference, 99.999),
    vmin=np.percentile(reference, 0.2),
)
std_ax.contour(ref_part, levels=[detection_threshold], colors="r", linewidths=0.5)
cax.axhline(detection_threshold, color="r", lw=2)
std_ax.set_title("Standard deviation projection")

# Overlay the detected spots, in window coordinates, on both panels
spot_cols = valid_spots.x - col_start
spot_rows = valid_spots.y - row_start
for ax in axes:
    ax.scatter(spot_cols, spot_rows, s=2, c="k", marker="x")
    ax.axis("off")
fig.tight_layout()

## Select isolated spots

We keep only spots with a minimum distance to the next spot. This is done by measuring
the fluorescence in an annulus around each spot and thresholding it.

In [None]:
from iss_preprocess.segment import detect_isolated_spots
from iss_preprocess.coppafish import annulus

isolation_threshold = ops[f"{short_pref}_isolation_threshold"]

# Build the annulus structuring element from its two radii and scale it to
# unit sum, so the correlation below is a weighted average of the reference
# image over that element.
annulus_r = (3, 7)
strel = annulus(*annulus_r)
strel_total = np.sum(strel)
strel = strel / strel_total
annulus_image = scipy.ndimage.correlate(reference, strel)

# A spot is isolated when the annulus image at its position is below threshold
annulus_at_spots = annulus_image[spots["y"], spots["x"]]
isolated = annulus_at_spots < isolation_threshold
isolated_spots, non_isolated_spots = spots[isolated], spots[~isolated]

In [None]:
# Three panels over the same window as the detection overview:
#   0: reference (STD) image
#   1: annulus image with isolated / non-isolated spots overlaid
#   2: annulus value at each detected spot only
annulus_part = annulus_image[slice(*extent[0]), slice(*extent[1])]
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
plot_matrix_with_colorbar(
    ref_part,
    axes[0],
    vmax=np.percentile(ref_part, 99.999),
    vmin=np.percentile(ref_part, 0.2),
)
axes[0].set_title("Standard deviation projection")

cax, cb = plot_matrix_with_colorbar(
    annulus_part,
    axes[1],
    vmax=np.percentile(annulus_part, 99.999),
    vmin=np.percentile(annulus_part, 0.05),
)
print(isolation_threshold)
cax.axhline(isolation_threshold, color="r", lw=2)
# Stretch the colorbar axis if needed so the threshold line stays visible
if isolation_threshold < cax.get_ylim()[0]:
    cax.set_ylim(isolation_threshold, cax.get_ylim()[1])
if isolation_threshold > cax.get_ylim()[1]:
    cax.set_ylim(cax.get_ylim()[0], isolation_threshold)

# Overlay both groups of spots. The loop names are deliberately distinct from
# `w`, the window half-width defined with `extent`.
for color, group, label in zip(
    ["r", "k"], [non_isolated_spots, isolated_spots], ["non-isolated", "isolated"]
):
    in_window = group[
        (group.x > extent[1][0])
        & (group.x < extent[1][1])
        & (group.y > extent[0][0])
        & (group.y < extent[0][1])
    ]
    axes[1].scatter(
        in_window.x - extent[1][0],
        in_window.y - extent[0][0],
        s=2,
        c=color,
        marker="+",
        label=label,
    )
axes[1].legend(
    loc="lower left",
    bbox_to_anchor=(0.0, 0.0),
)
axes[1].set_title("Annulus projection")

# plot annulus values only where spots are
img = np.zeros_like(annulus_part) + np.nan
valid_values = annulus_image[valid_spots["y"], valid_spots["x"]]
# Window coordinates of the spots; `valid_spots` lie strictly inside `extent`
# so these indices are always within `img`
px = (valid_spots["x"].values - extent[1][0]).astype(int)
py = (valid_spots["y"].values - extent[0][0]).astype(int)
img[py, px] = annulus_part[py, px]
vmin, vmax = valid_values.min(), valid_values.max()
axes[2].scatter(
    valid_spots.x - extent[1][0],
    valid_spots.y - extent[0][0],
    s=2,
    c=valid_values,
    marker="x",
    vmin=vmin,
    vmax=vmax,
)
axes[2].set_title("Annulus projection, spot values")
plot_matrix_with_colorbar(img, axes[2], vmin=vmin, vmax=vmax)

for ax in axes:
    ax.axis("off")

## Extract spots

Now that we have detected isolated spots, we can extract the fluorescence of each spot.
This is done with a given radius around the spot.


In [None]:
# Detect spots again, this time with both thresholds applied, replacing the
# `spots` table used so far.
threshold_kwargs = {
    "detection_threshold": detection_threshold,
    "isolation_threshold": isolation_threshold,
}
spots = detect_isolated_spots(reference, **threshold_kwargs)

# The return value is not assigned; the next cell reads `spot.trace`, so this
# call presumably adds the traces to `spots` in place - confirm.
extraction_radius = ops["spot_extraction_radius"]
iss.call.extract_spots(spots, stack, extraction_radius)

In [None]:
# plot the spots in the part of reference image and show the extraction radius
import copy

from skimage.morphology import disk

fig = plt.figure(figsize=(7, 6))
axes = [plt.subplot2grid(fig=fig, shape=(2, 2), loc=(0, i)) for i in range(2)]

# Small window around the densest region: [[row range], [column range]]
w = 50
smallext = [
    [center[1] - w, center[1] + w],
    [center[0] - w, center[0] + w],
]

valid_spots = spots[
    (spots.x > smallext[1][0])
    & (spots.x < smallext[1][1])
    & (spots.y > smallext[0][0])
    & (spots.y < smallext[0][1])
]

ref_part = reference[slice(*smallext[0]), slice(*smallext[1])]

plot_matrix_with_colorbar(
    ref_part,
    axes[0],
    vmax=np.percentile(ref_part, 99.999),
    vmin=np.percentile(ref_part, 0.2),
)
axes[0].set_title("Standard deviation projection")
axes[0].scatter(
    valid_spots.x - smallext[1][0],
    valid_spots.y - smallext[0][0],
    s=2,
    c="r",
    marker="x",
)

# Row/column offsets of the disk footprint relative to its center pixel
spot_footprint = disk(ops["spot_extraction_radius"])
drr, dcc = np.where(spot_footprint)
drr -= spot_footprint.shape[0] // 2
dcc -= spot_footprint.shape[1] // 2

img = np.zeros_like(ref_part) + np.nan
n_rows, n_cols = ref_part.shape
traces = []
for _, spot in spots.iterrows():
    if (
        spot.x < smallext[1][0]
        or spot.x >= smallext[1][1]
        or spot.y < smallext[0][0]
        or spot.y >= smallext[0][1]
    ):
        continue
    traces.append(spot.trace)
    # Footprint pixels in window coordinates
    rr = (drr + spot["y"] - smallext[0][0]).astype(int)
    cc = (dcc + spot["x"] - smallext[1][0]).astype(int)
    # Drop the pixels falling outside the window. `img` has the shape of
    # `ref_part`, not of the full stack, so the footprint of a spot close to
    # the window edge would otherwise index out of bounds.
    inside = (rr >= 0) & (rr < n_rows) & (cc >= 0) & (cc < n_cols)
    rr, cc = rr[inside], cc[inside]
    img[rr, cc] = ref_part[rr, cc]

# Work on a copy so that the shared viridis colormap is left untouched
cmap = copy.copy(plt.cm.viridis)
cmap.set_bad("black", 1)
plot_matrix_with_colorbar(
    img,
    axes[1],
    vmax=np.percentile(ref_part, 99.999),
    vmin=np.percentile(ref_part, 0.2),
    cmap=cmap,
    interpolation="none",
)

axes[1].set_title("Extracted pixels")
print(f"Spot extraction radius: {ops['spot_extraction_radius']}.")
for ax in axes:
    ax.axis("off")

axes_t = plt.subplot2grid(fig=fig, shape=(2, 1), loc=(1, 0))
if traces:
    # Stack the traces with the spot index first, one image row per spot.
    # Assumes each trace is (rounds, channels) with 4 channels - TODO confirm.
    traces = np.stack(traces, axis=2)
    traces = np.moveaxis(traces, [0, 1], [1, 2])
    rgb_trace = iss.vis.to_rgb(
        traces,
        colors=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]],
        vmin=[0] * 4,
        vmax=[traces.max()] * 4,
    )
    axes_t.imshow(rgb_trace, aspect='auto', interpolation='None')
    axes_t.set_xlabel('Round')
    axes_t.set_ylabel('Rolonie')
    axes_t.set_title('Extracted fluorescence')
else:
    # np.stack cannot handle an empty list; nothing to show in that case
    print("No spots in the plotted window, skipping the trace plot.")
    axes_t.axis("off")
fig.tight_layout()