# Select basecalling parameters

Calling bases from fluorescence

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

import iss_preprocess as iss
from iss_preprocess.pipeline.sequencing import load_spot_sign_image
from iss_preprocess.vis import plot_matrix_with_colorbar

# Dataset to inspect: the ops and the processed-data folder are both looked up from it
data_path = "becalia_rabies_barseq/BRAC8501.6a/chamber_07"
# Index into ops["barcode_ref_tiles"]
ref_tile_index = 0  # which of the reference tiles do we want to use for plots

In [None]:
# Load one reference tile to use as example data
ops = iss.io.load.load_ops(data_path)
processed_path = iss.io.get_processed_path(data_path)
cluster_means = np.load(processed_path / "barcode_cluster_means.npy")
tile_coors = ops["barcode_ref_tiles"][ref_tile_index]

# Every loading option below comes straight from the ops, except the file prefix
# and the illumination correction, which are fixed here
loading_kwargs = dict(
    filter_r=ops["filter_r"],
    prefix="barcode_round",
    suffix=ops["barcode_projection"],
    nrounds=ops["barcode_rounds"],
    correct_channels=ops["barcode_correct_channels"],
    corrected_shifts=ops["corrected_shifts"],
    correct_illumination=True,
)
stack, bad_pixels = iss.pipeline.load_and_register_sequencing_tile(
    data_path, tile_coors, **loading_kwargs
)

# Reorder axis 2 (presumably the channel axis -- TODO confirm) by camera order
channel_order = np.argsort(ops["camera_order"])
stack = stack[:, :, channel_order, :]
# Blank the pixels flagged as bad across the two trailing axes
stack[bad_pixels, :, :] = 0
print(stack.shape)

# Spot detection threshold

Spot detection works by thresholding the average image across rounds and channels using
`ops["barcode_detection_threshold_basecalling"]`

In [None]:
# Detection threshold configured in the ops
threshold = ops["barcode_detection_threshold_basecalling"]
# Average the stack over its two trailing axes to obtain a single 2D image
im = stack.mean(axis=(2, 3))
spots = iss.segment.spots.detect_spots(im, threshold=threshold)

In [None]:
# Find the place with the highest density of spots, just for plotting
density_radius = 50  # px, radius of the disk in which spots are counted
half_window = 100  # px, half-width of the zoomed-in view
grid_margin = 200  # px, distance kept between candidate centers and the image border
grid_step = 25  # px, spacing between candidate centers
detection_threshold = ops["barcode_detection_threshold_basecalling"]

x, y = spots["x"].values, spots["y"].values
# Create a grid of potential disk centers
x_grid, y_grid = np.meshgrid(
    np.arange(grid_margin, stack.shape[1] - grid_margin, grid_step),
    np.arange(grid_margin, stack.shape[0] - grid_margin, grid_step),
)
# Compute the Euclidean distance from each spot to each potential center
distances = np.sqrt((x[:, None, None] - x_grid) ** 2 + (y[:, None, None] - y_grid) ** 2)
# Count the number of spots within `density_radius` of each potential center
counts = np.sum(distances <= density_radius, axis=0)
# Grid position with the most spots, converted to (x, y) pixel coordinates
best_index = np.unravel_index(counts.argmax(), counts.shape)
center = (x_grid[best_index], y_grid[best_index])

fig, axes = plt.subplots(1, 2, figsize=(10, 10))
vmin = np.percentile(im, 0.1)
vmax = np.percentile(im, 99.99)

# Left: whole tile, with the colour scale capped just above the detection threshold
cax, cb = plot_matrix_with_colorbar(
    im, ax=axes[0], vmin=vmin, vmax=detection_threshold * 1.1
)
cax.axhline(detection_threshold, color="r", lw=2)
axes[0].set_title("Mean image")

# Right: zoom on the densest region with the detected spots overlaid
cax, cb = plot_matrix_with_colorbar(
    im[
        center[1] - half_window : center[1] + half_window,
        center[0] - half_window : center[0] + half_window,
    ],
    ax=axes[1],
    vmin=vmin,
    vmax=vmax,
    interpolation="none",
)
valid_spots = spots[
    (spots.x > center[0] - half_window)
    & (spots.x < center[0] + half_window)
    & (spots.y > center[1] - half_window)
    & (spots.y < center[1] + half_window)
]
# Shift spot coordinates into the frame of the cropped image
axes[1].scatter(
    valid_spots["x"] - center[0] + half_window,
    valid_spots["y"] - center[1] + half_window,
    c="r",
    s=0.5,
)
cax.axhline(detection_threshold, color="r", lw=2)

# Use `ax` as loop variable so the spot coordinates in `x` are not overwritten
for ax in axes:
    ax.axis("off")

## [Optional] Set spot score threshold

Spots can be further filtered with a spot score, which indicates how closely the
filtered fluorescence around each spot matches the average spot sign image.

First one must calculate this "average spot sign image". This is done by thresholding
the average image of isolated rolonies. The threshold, `ops["spot_shape_threshold"]`,
defines how narrow the spot image is.

In [None]:
# Threshold configured in the ops, and a higher one shown only for comparison
shape_threshold = ops["spot_shape_threshold"]
higher_threshold = 0.4

raw_spot_image = load_spot_sign_image(
    data_path, shape_threshold, return_raw_image=True
)
spot_sign_image = load_spot_sign_image(data_path, shape_threshold)
higher_th_sign_image = load_spot_sign_image(data_path, higher_threshold)
# Half-size of the sign image, used to centre the axes of the sign-image panels on 0
mid = spot_sign_image.shape[0] / 2

fig, axes = plt.subplots(1, 3, figsize=(10, 4))

# Raw average spot image with the +/- threshold contours; both thresholds are
# also marked on the colourbar
cax, cb = plot_matrix_with_colorbar(raw_spot_image, ax=axes[0])
axes[0].contour(
    raw_spot_image, levels=np.array([-1, 1]) * shape_threshold, colors="r"
)
cax.axhline(shape_threshold, color="r")
cax.axhline(higher_threshold, color="purple")
axes[0].set_title("Average spot image")

# One panel per sign image, each labelled with the threshold that produced it
sign_images = (
    (spot_sign_image, shape_threshold),
    (higher_th_sign_image, higher_threshold),
)
for ax, (sign_image, sign_threshold) in zip(axes[1:], sign_images):
    plot_matrix_with_colorbar(
        sign_image,
        ax=ax,
        cmap="RdBu_r",
        vmin=-1,
        vmax=1,
        extent=[-mid, mid, -mid, mid],
    )
    ax.set_title(f"Spot sign image\n(threshold = {sign_threshold})")

for ax in axes:
    ax.axis("off")

Then the spot score is calculated by counting the number of pixels that have
the same sign (positive or negative) as the spot sign image around the rolony.

There are more negative pixels than positive ones. To give more weight to positive
pixels, you can set `rho=ops["barcode_spot_rho"]` to be greater than 1.

In [None]:
# Weight given to positive pixels relative to negative ones
rho = ops["barcode_spot_rho"]

# Binary masks of the positive and negative parts of the spot sign image
sign = np.sign(spot_sign_image)
pos_filter = (sign == 1).astype(float)
neg_filter = (sign == -1).astype(float)
# Number of pixels in each mask, i.e. the largest count each filter can return
pos_max = np.sum(sign == 1)
neg_max = np.sum(sign == -1)

# For every pixel, count the surrounding pixels whose sign matches the mask
filt_pos = cv2.filter2D(
    (im > 0).astype(float), -1, pos_filter, borderType=cv2.BORDER_REPLICATE
)
filt_neg = cv2.filter2D(
    (im < 0).astype(float), -1, neg_filter, borderType=cv2.BORDER_REPLICATE
)
# Counts at the detected spot positions (row = y, column = x)
pos_pixels = filt_pos[spots["y"], spots["x"]]
neg_pixels = filt_neg[spots["y"], spots["x"]]

# Weighted fraction of matching pixels, for the whole image ...
score_image = (filt_neg + filt_pos * rho) / (neg_max + pos_max * rho)
# NOTE: despite its name, this image uses rho / 5, not 2 * rho. The name is kept
# because the plotting cell refers to it.
score_image_doublerho = (filt_neg + filt_pos * rho / 5) / (
    neg_max + pos_max * rho / 5
)
# ... and at the detected spots only
spot_score = (neg_pixels + pos_pixels * rho) / (neg_max + pos_max * rho)

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(9, 6))
axes = axes.ravel()

# Panel 0: mean image with the detected spots and the detection threshold
im_low, im_high = np.percentile(im, 0.01), np.percentile(im, 99.99)
cax, cb = plot_matrix_with_colorbar(im, ax=axes[0], vmin=im_low, vmax=im_high)
cax.axhline(ops["barcode_detection_threshold_basecalling"], color="r", lw=2)
axes[0].scatter(spots["x"], spots["y"], c="k", s=1)
axes[0].set_title("Mean image")

# Panels 1 and 2: filter outputs, each divided by its number of mask pixels
pos_fraction = filt_pos / pos_max
plot_matrix_with_colorbar(pos_fraction, ax=axes[1], cmap="Reds")
axes[1].set_title("Positive filter")

neg_fraction = filt_neg / neg_max
plot_matrix_with_colorbar(neg_fraction, ax=axes[2], cmap="Blues")
axes[2].set_title("Negative filter")

# Panels 3 to 5 share one colour range, taken from the score image
mi, ma = np.percentile(score_image, [10, 99.99])
shared_range = dict(vmax=ma, vmin=mi)

plot_matrix_with_colorbar(score_image, ax=axes[3], **shared_range)
axes[3].set_title(f"Spot score image\n(rho={rho})")

plot_matrix_with_colorbar(score_image_doublerho, ax=axes[4], **shared_range)
axes[4].set_title(f"Spot score image\n(rho={rho/5})")

# Panel 5: all-NaN background image, with the per-spot scores drawn as coloured dots
blank_image = np.zeros_like(im) + np.nan
cax, cb = plot_matrix_with_colorbar(blank_image, ax=axes[5], **shared_range)
axes[5].scatter(spots["x"], spots["y"], c=spot_score, s=10, **shared_range)
axes[5].set_title(f"Spot score\n(rho={rho})")

# Same zoom for every panel, centred on the densest region
zoom_x = (center[0] - 100, center[0] + 100)
zoom_y = (center[1] + 100, center[1] - 100)  # larger y first, as in the original
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.set_facecolor("k")
    ax.set_xlim(*zoom_x)
    ax.set_ylim(*zoom_y)
fig.tight_layout()

## Base calling

This is done by loading the bleed-through matrices previously computed:

In [None]:
# Bleed-through matrices saved in the processed-data folder
cluster_means = np.load(processed_path / "barcode_cluster_means.npy")

# One panel per cluster, showing channels (rows) against rounds (columns)
fig, axes = plt.subplots(1, 4, figsize=(10, 5))
for cl, ax in enumerate(axes):
    ax.imshow(cluster_means[:, cl].T)
    ax.set_title(f"Cluster {cl}")
    ax.set_xlabel("Round")
_ = axes[0].set_ylabel("Channel")

For each round, we look at the relevant column of the bleed-through matrix and normalise
the N-channel-long vector for each cluster. We then take the fluorescence of each rolony
for this round and also normalise it.

Finally we take the dot product and find the max to identify the best cluster

In [None]:
# Extract the fluorescence trace of every spot (stored in spots["trace"])
iss.call.extract_spots(spots, stack, ops["spot_extraction_radius"])
# Stack the per-spot traces along a new last axis, so spots are on axis 2
x = np.stack(spots["trace"], axis=2)
cluster_inds = []
iround = 0

# Load the bleed-through matrices once and keep both the raw and the normalised
# version for this round (one row per cluster, unit norm after normalisation)
cluster_means = np.load(processed_path / "barcode_cluster_means.npy")
raw_round_means = cluster_means[iround]
this_round_means = raw_round_means / np.linalg.norm(
    raw_round_means, axis=1, keepdims=True
)

# Fluorescence for this round with one row per spot, normalised the same way.
# A spot with an all-zero trace gives NaNs here.
round_traces = x[iround, :, :].T
x_norm = round_traces / np.linalg.norm(round_traces, axis=1, keepdims=True)

# Dot product of every spot with every cluster; the best cluster has the highest score
score = x_norm @ this_round_means.T
cluster_ind = np.argmax(score, axis=1)
cluster_inds.append(cluster_ind)

fig, means_axes = plt.subplots(1, 2, figsize=(5, 2))
# Left: raw cluster means. They are not guaranteed to lie in [0, 1], so only the
# lower colour limit is fixed.
means_axes[0].imshow(raw_round_means, vmin=0)
means_axes[0].set_xlabel("Channel")
means_axes[0].set_ylabel("Cluster")
means_axes[0].set_title(f"Cluster means round {iround}")
# Right: the same matrix after normalising each cluster to unit norm
means_axes[1].imshow(this_round_means, vmin=0, vmax=1)
means_axes[1].set_xlabel("Channel")
means_axes[1].set_ylabel("Cluster")
means_axes[1].set_title("Normalized")

# for plotting, limit spots
window = 100
is_valid = (
    (spots.x > center[0] - window)
    & (spots.x < center[0] + window)
    & (spots.y > center[1] - window)
    & (spots.y < center[1] + window)
)

# Each panel has one column per selected spot
fig, axes = plt.subplots(3, 1, figsize=(7, 5), sharex=True, sharey=True)
plot_matrix_with_colorbar(round_traces[is_valid].T, axes[0], aspect="auto", vmin=0)
plot_matrix_with_colorbar(x_norm[is_valid].T, axes[1], aspect="auto", vmin=0)
plot_matrix_with_colorbar(score[is_valid].T, axes[2], aspect="auto", vmin=0)
# Mark the selected cluster of every spot on the score panel
axes[2].scatter(
    np.arange(len(cluster_ind[is_valid])), cluster_ind[is_valid], color="k", s=1
)
axes[2].set_xlabel("Rolonie #")
for ax in axes:
    ax.set_ylabel("Channel")
axes[0].set_title(f"Extracted fluorescence round {iround}")
axes[1].set_title("Normalised fluorescence")
axes[2].set_title("Cluster score")

# The last panel shows clusters, not channels, on its y axis
_ = axes[2].set_ylabel("Cluster")