# Select bleedthrough parameters

To set up base calling, we need to create bleed-through matrices. This requires 
selecting two thresholds: one for spot detection and one for filtering isolated spots. 
This notebook will help you select the best thresholds for your data.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy

import iss_preprocess as iss
import iss_preprocess.io.load
import iss_preprocess.pipeline.register

# Path of the acquisition, passed to every loader below.
data_path = "becalia_rabies_barseq/BRAC8323.6g/chamber_16"
# Acquisition prefix; the text before its first "_" is the stem of the
# per-acquisition keys looked up in ops (e.g. "<stem>_ref_tiles").
prefix = "genes_round"
ref_tile_index = 0  # which of the reference tiles do we want to use for plots

In [None]:
# Stem of the per-acquisition ops keys: everything before the first "_".
short_pref = prefix.partition("_")[0]
ops = iss.io.load.load_ops(data_path)

# Reference tiles configured for this acquisition; one of them is used for
# all the single-tile plots of this notebook.
ref_tiles = ops[short_pref + "_ref_tiles"]
print(f"{len(ref_tiles)} reference tiles found. Using {ref_tile_index}.")
ref_tile = ref_tiles[ref_tile_index]

In [None]:
print(f"Loading registered data for {ref_tile}")
# Loading options, all read from ops except the illumination correction,
# which is switched off here.
registration_options = dict(
    filter_r=ops["filter_r"],
    prefix=prefix,
    suffix=ops[f"{short_pref}_projection"],
    nrounds=ops[f"{prefix}s"],
    correct_channels=ops[f"{short_pref}_correct_channels"],
    corrected_shifts=ops["corrected_shifts"],
    correct_illumination=False,
)
# Only the first returned value (the image stack) is used.
stack, _ = iss.pipeline.register.load_and_register_sequencing_tile(
    data_path, ref_tile, **registration_options
)
# Reorder axis 2 of the stack according to the argsort of the camera order.
channel_order = np.argsort(ops["camera_order"])
stack = stack[:, :, channel_order, :]

In [None]:
print("Making reference image using STD")
# Collapse the last two axes of the stack (axis 2 is the one reordered by
# camera order, axis 3 the remaining one) into a single 2D image.
reference = stack.std(axis=(2, 3))
reference.shape

## Spot detection

The first step is to detect some spots.

In [None]:
from iss_preprocess.segment.spots import detect_spots

# ops is read again from disk here (presumably so that a threshold edited
# since the previous load is picked up - TODO confirm).
ops = iss.io.load.load_ops(data_path)
detection_threshold_key = f"{short_pref}_detection_threshold"
detection_threshold = ops[detection_threshold_key]
print(f"detection_threshold: {detection_threshold}")
# Detect spots on the STD reference image with the configured threshold.
spots = detect_spots(reference, threshold=detection_threshold)

In [None]:
# Find the place with the highest density of spots, just for plotting
density_radius = 50  # px, radius of the disk in which spots are counted
grid_margin = 200  # px, candidate centers are kept this far from the edges
grid_step = 25  # px, spacing between candidate centers

x, y = spots["x"].values, spots["y"].values
# Create a grid of potential disk centers (x runs along axis 1 of the stack,
# y along axis 0)
x_grid, y_grid = np.meshgrid(
    np.arange(grid_margin, stack.shape[1] - grid_margin, grid_step),
    np.arange(grid_margin, stack.shape[0] - grid_margin, grid_step),
)
# Compute the Euclidean distance from each spot to each potential center;
# the result has shape (n_spots, grid_rows, grid_cols)
distances = np.sqrt((x[:, None, None] - x_grid) ** 2 + (y[:, None, None] - y_grid) ** 2)
# Count the number of spots within `density_radius` px of each potential center
counts = np.sum(distances <= density_radius, axis=0)
# Keep the grid position with the most spots, stored as (x, y)
center = np.unravel_index(counts.argmax(), counts.shape)
center = (x_grid[center], y_grid[center])

In [None]:
from iss_preprocess.vis.utils import plot_matrix_with_colorbar
from iss_preprocess.vis.vis import round_to_rgb

# Half-width (px) of the square window displayed around the densest region.
w = 200
# extent is [[row_start, row_stop], [col_start, col_stop]]; center is (x, y).
extent = [[center[1] - w, center[1] + w], [center[0] - w, center[0] + w]]
(row_start, row_stop), (col_start, col_stop) = extent

# Spots strictly inside the window.
inside_window = (
    (spots.x > col_start)
    & (spots.x < col_stop)
    & (spots.y > row_start)
    & (spots.y < row_stop)
)
valid_spots = spots[inside_window]

# NOTE(review): the intensity limits are computed on index 0 of the last axis
# while round_to_rgb is given 1 as round - confirm both refer to the same round.
limits_source = stack[..., 0]
rgb = round_to_rgb(
    stack,
    1,
    extent=extent,
    channel_colors=([1, 0, 0], [0, 1, 0], [1, 0, 1], [0, 1, 1]),
    vmax=np.percentile(limits_source, 99.99, axis=(0, 1)),
    vmin=np.percentile(limits_source, 0.2, axis=(0, 1)),
)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: RGB rendering. The colorbar is created and then blanked so that
# both panels end up with the same size.
cax, cb = plot_matrix_with_colorbar(rgb, axes[0])
cax.clear()
cax.axis("off")
axes[0].set_title("Raw data, round 1")

# Right panel: STD projection with the detection threshold drawn both as a
# contour on the image and as a line on the colorbar.
ref_part = reference[row_start:row_stop, col_start:col_stop]
cax, cb = plot_matrix_with_colorbar(
    ref_part,
    axes[1],
    cmap="viridis",
    vmax=np.percentile(reference, 99.999),
    vmin=np.percentile(reference, 0.2),
)
axes[1].contour(ref_part, levels=[detection_threshold], colors="r", linewidths=0.5)
cax.axhline(detection_threshold, color="r", lw=2)
axes[1].set_title("Standard deviation projection")

# Overlay the detected spots, in window coordinates, on both panels.
spot_cols = valid_spots.x - col_start
spot_rows = valid_spots.y - row_start
spot_style = dict(s=2, c="k", marker="x")
for ax in axes:
    ax.scatter(spot_cols, spot_rows, **spot_style)
    ax.axis("off")
fig.tight_layout()

## Select isolated spots

We keep only spots that are far enough from their neighbours. This is done by measuring
the fluorescence in an annulus around each spot and keeping the spots for which it is
below a threshold.

In [None]:
from iss_preprocess.coppafish import annulus
from iss_preprocess.segment import detect_isolated_spots

ops = iss.io.load.load_ops(data_path)
isolation_threshold = ops[f"{short_pref}_isolation_threshold"]
print(isolation_threshold)

# Annulus-shaped kernel, normalised to sum to 1 so that the correlation below
# is the average of the reference image over the annulus.
annulus_r = (3, 7)
strel = annulus(*annulus_r)
strel = strel / np.sum(strel)
annulus_image = scipy.ndimage.correlate(reference, strel)

# A spot is isolated when the annulus average at its position is below the
# threshold.
annulus_at_spots = annulus_image[spots["y"], spots["x"]]
isolated = annulus_at_spots < isolation_threshold
isolated_spots = spots[isolated]
non_isolated_spots = spots[~isolated]

In [None]:
# Three panels over the same window (extent): STD projection, annulus image
# with isolated / non-isolated spots, and annulus values at the spot positions.
annulus_part = annulus_image[slice(*extent[0]), slice(*extent[1])]
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
plot_matrix_with_colorbar(
    ref_part,
    axes[0],
    vmax=np.percentile(ref_part, 99.999),
    vmin=np.percentile(ref_part, 0.2),
)
axes[0].set_title("Standard deviation projection")

cax, cb = plot_matrix_with_colorbar(
    annulus_part,
    axes[1],
    vmax=np.percentile(annulus_part, 99.999),
    vmin=np.percentile(annulus_part, 0.05),
)
print(isolation_threshold)
cax.axhline(isolation_threshold, color="r", lw=2)
# Extend the colorbar range if needed so that the threshold line is visible
if isolation_threshold < cax.get_ylim()[0]:
    cax.set_ylim(isolation_threshold, cax.get_ylim()[1])
if isolation_threshold > cax.get_ylim()[1]:
    cax.set_ylim(cax.get_ylim()[0], isolation_threshold)
# Loop names are chosen so that the window half-width `w` is not overwritten
for color, spot_group, group_label in zip(
    ["r", "k"], [non_isolated_spots, isolated_spots], ["non-isolated", "isolated"]
):
    in_window = spot_group[
        (spot_group.x > extent[1][0])
        & (spot_group.x < extent[1][1])
        & (spot_group.y > extent[0][0])
        & (spot_group.y < extent[0][1])
    ]
    axes[1].scatter(
        in_window.x - extent[1][0],
        in_window.y - extent[0][0],
        s=2,
        c=color,
        marker="+",
        label=group_label,
    )
axes[1].legend(
    loc="lower left",
    bbox_to_anchor=(0.0, 0.0),
)
axes[1].set_title("Annulus projection")

# plot annulus values only where spots are: every other pixel stays NaN
valid_values = annulus_image[valid_spots["y"], valid_spots["x"]]
spot_cols = (valid_spots["x"] - extent[1][0]).to_numpy().astype(int)
spot_rows = (valid_spots["y"] - extent[0][0]).to_numpy().astype(int)
img = np.full(annulus_part.shape, np.nan)
img[spot_rows, spot_cols] = annulus_part[spot_rows, spot_cols]
if len(valid_values) > 0:
    vmin, vmax = valid_values.min(), valid_values.max()
else:
    print("No valid values found")
    vmin, vmax = 0, 1
sc = axes[2].scatter(
    valid_spots.x - extent[1][0],
    valid_spots.y - extent[0][0],
    s=2,
    c=valid_values,
    marker="x",
    vmin=vmin,
    vmax=vmax,
)
axes[2].set_title("Annulus projection, spot values")
plot_matrix_with_colorbar(img, axes[2], vmin=vmin, vmax=vmax)

for ax in axes:
    ax.axis("off")

## Extract spots

Now that we have detected isolated spots, we can extract the fluorescence of each spot.
This is done with a given radius around the spot.


In [None]:
# From here on `spots` holds the isolated spots only.
isolation_options = dict(
    detection_threshold=detection_threshold,
    isolation_threshold=isolation_threshold,
)
spots = detect_isolated_spots(reference, **isolation_options)

# The return value is not stored; later cells read spots["trace"], so the
# traces are presumably added to `spots` in place - TODO confirm.
iss.call.extract_spots(spots, stack, ops["spot_extraction_radius"])

In [None]:
# plot the spots in the part of reference image and show the extraction radius
from skimage.morphology import disk

# Top row: STD projection and extracted pixels in a small window around the
# densest region. Bottom row: extracted traces of the spots in that window.
fig = plt.figure(figsize=(7, 6))
axes = [plt.subplot2grid(fig=fig, shape=(2, 2), loc=(0, i)) for i in range(2)]
w = 50
# smallext is [[row_start, row_stop], [col_start, col_stop]]
smallext = [
    [center[1] - w, center[1] + w],
    [center[0] - w, center[0] + w],
]

# Spots strictly inside the window
valid_spots = spots[
    (spots.x > smallext[1][0])
    & (spots.x < smallext[1][1])
    & (spots.y > smallext[0][0])
    & (spots.y < smallext[0][1])
]

ref_part = reference[slice(*smallext[0]), slice(*smallext[1])]
ref_vmax = np.percentile(ref_part, 99.999)
ref_vmin = np.percentile(ref_part, 0.2)

plot_matrix_with_colorbar(ref_part, axes[0], vmax=ref_vmax, vmin=ref_vmin)
axes[0].set_title("Standard deviation projection")
axes[0].scatter(
    valid_spots.x - smallext[1][0],
    valid_spots.y - smallext[0][0],
    s=2,
    c="r",
    marker="x",
)

# Row / column offsets of the footprint pixels relative to its center
spot_footprint = disk(ops["spot_extraction_radius"])
drr, dcc = np.where(spot_footprint)
drr -= spot_footprint.shape[0] // 2
dcc -= spot_footprint.shape[1] // 2

# Image that is NaN everywhere except on the pixels covered by a footprint
img = np.full(ref_part.shape, np.nan)
n_rows, n_cols = img.shape
traces = []

for _, spot in valid_spots.iterrows():
    traces.append(spot.trace)
    rr = drr + spot["y"] - smallext[0][0]
    cc = dcc + spot["x"] - smallext[1][0]
    # Drop footprint pixels falling outside the window instead of clipping
    # them onto its border, which would mark pixels that are not extracted
    inside = (rr >= 0) & (rr < n_rows) & (cc >= 0) & (cc < n_cols)
    rr, cc = rr[inside], cc[inside]
    img[rr, cc] = ref_part[rr, cc]
# Work on a copy: set_bad would otherwise modify the shared viridis colormap
cmap = plt.cm.viridis.copy()
cmap.set_bad("black", 1)
plot_matrix_with_colorbar(
    img,
    axes[1],
    vmax=ref_vmax,
    vmin=ref_vmin,
    cmap=cmap,
    interpolation="none",
)

axes[1].set_title("Extracted pixels")
print(f"Spot extraction radius: {ops['spot_extraction_radius']}.")
for ax in axes:
    ax.axis("off")


axes_t = plt.subplot2grid(fig=fig, shape=(2, 1), loc=(1, 0))
if traces:
    # Stack the per-spot traces on a new last axis, then move that axis first
    traces = np.stack(traces, axis=2)
    traces = np.moveaxis(traces, [0, 1], [1, 2])
    rgb_trace = iss.vis.to_rgb(
        traces,
        colors=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]],
        vmin=[0] * 4,
        vmax=[traces.max()] * 4,
    )
    axes_t.imshow(rgb_trace, aspect="auto", interpolation="None")
else:
    # np.stack cannot stack an empty list
    print("No spots in the selected window, no trace to plot.")
axes_t.set_xlabel("Round")
axes_t.set_ylabel("Rolonie")
axes_t.set_title("Extracted fluorescence")
fig.tight_layout()

# Run on all tiles

We will now load all the spots of the reference tiles

In [None]:
from iss_preprocess.pipeline.sequencing import get_reference_spots

# Spots of the reference tiles; note that the short prefix is passed here,
# not the full acquisition prefix.
all_spots, norm_shifts = get_reference_spots(data_path, prefix=short_pref)
n_reference_spots = len(all_spots)
print(f"Found {n_reference_spots} reference spots.")

## Cluster one round

We will run the clustering part on one single round to see how to set score thresholds

In [None]:
# Round (index along the first axis of the traces) used for this example.
iround = 5
score_thresh = ops[f"{short_pref}_cluster_score_thresh"]
print(f"Filtering cluster with score below {score_thresh}.")
# NOTE(review): this uses `spots` (isolated spots of the plotted tile), not
# `all_spots` loaded above - confirm this is intended.
spot_colors = np.stack(spots["trace"], axis=2)  # round x channels x spots
# Colors of every spot for the selected round: spots x channels
spot_round = spot_colors[iround].T

# The first iteration of scaled k-means is run manually in the next cell so
# that the scores can be inspected.

In [None]:
# Initialise scaled k-means
nch = spot_round.shape[1]
initial_cluster_mean = np.array(ops["initial_cluster_means"])

x = spot_round
# Normalise cluster means and spot colors to unit length so that their dot
# product is a cosine similarity. Spots with zero norm give NaN scores.
norm_cluster_mean = initial_cluster_mean / np.linalg.norm(
    initial_cluster_mean, axis=1, keepdims=True
)
x_norm = x / np.linalg.norm(x, axis=1, keepdims=True)
n_clusters = initial_cluster_mean.shape[0]
n_points = x.shape[0]

# One threshold per cluster: a single value is repeated for every cluster,
# any other input (list, tuple or array) must have one entry per cluster
score_thresh = np.asarray(score_thresh, dtype=float).ravel()
if score_thresh.size == 1:
    score_thresh = np.full(n_clusters, score_thresh[0])
elif score_thresh.size != n_clusters:
    raise ValueError("score_thresh must be length n_clusters")

# and run the first iteration
score = x_norm @ norm_cluster_mean.T
cluster_ind = np.argmax(score, axis=1)  # find best cluster for each point
top_score = score[np.arange(n_points), cluster_ind]
# don't include nan values: push them below every threshold
top_score[np.isnan(top_score)] = score_thresh.min() - 1

# Histogram of the best score of the spots assigned to each cluster
plot_max = False
fig, ax = plt.subplots(1, 1, figsize=(8, 2))
bins = np.arange(np.min(top_score), 1.1, 0.01)
for i in range(n_clusters):
    scores = top_score[cluster_ind == i]
    ax.hist(
        scores,
        histtype="stepfilled",
        alpha=0.5,
        label=f"Cluster {i}",
        bins=bins,
        cumulative=False,
        density=True,
    )
    # max() is undefined for a cluster without any spot
    if plot_max and scores.size:
        ax.axvline(
            scores.max(),
            ymin=0,
            ymax=0.4,
            color=f"C{i}",
        )
# Thresholds are drawn as dashed lines in the color of their cluster
for i, th in enumerate(score_thresh):
    ax.axvline(
        th, ymin=0.6, ymax=1, color=f"C{i}", ls="--", label=f"Threshold {i} - {th}"
    )

ax.legend(loc="upper left")
ax.set_xlim([0, 1.01])

# Get cluster means



In [None]:
score_thresh = ops[f"{short_pref}_cluster_score_thresh"]
print(f"Using {score_thresh} as the score threshold.")

# Use the threshold announced above rather than a hard-coded value
cluster_means, spot_colors, cluster_inds = iss.call.call.get_cluster_means(
    all_spots,
    score_thresh=score_thresh,
    initial_cluster_mean=np.array(ops["initial_cluster_means"]),
)
# cluster_means holds one matrix per round. They are stacked on a new last
# axis; the plots below index the first axis by cluster and label the second
# as channels (presumably clusters x channels x rounds - TODO confirm).
stacked_means = np.stack(cluster_means, axis=2)
nclusters = cluster_means[0].shape[0]
nrounds = len(cluster_means)
n_channels = stacked_means.shape[1]
fig, ax = plt.subplots(
    nrows=1, ncols=nclusters, facecolor="w", label="cluster_means", figsize=(8, 2)
)
# plt.subplots returns a single Axes, not an array, when there is one cluster
ax = np.atleast_1d(ax)
for icluster in range(nclusters):
    plt.sca(ax[icluster])
    plt.imshow(stacked_means[icluster, :, :])
    plt.xlabel("rounds")
    plt.ylabel("channels")
    plt.xticks(np.arange(nrounds), np.arange(1, nrounds + 1, dtype=int))
    plt.yticks(np.arange(n_channels), np.arange(n_channels, dtype=int))
    plt.title(f"Cluster {icluster+1}")

plt.tight_layout()

In [None]:
# Read the saved correction data. The archive is opened in a `with` block so
# that the underlying file is closed once the two arrays have been read.
correction_file = iss.io.get_processed_path(data_path) / f"correction_{prefix}.npz"
with np.load(correction_file) as norm_fact:
    pixel_dist = norm_fact["pixel_dist"]
    norm_factors_raw = norm_fact["norm_factors_raw"]

In [None]:
# One line per row of norm_factors_raw (one row per channel), whatever the
# number of channels is.
for channel, channel_factors in enumerate(norm_factors_raw):
    plt.plot(channel_factors, "-o", label=f"Channel {channel}")
plt.legend()
plt.xlabel("Round")
plt.ylabel("Normalization factor")

In [None]:
# Show the initial cluster means stored in ops, with one tick per row and
# per column for the first four of each.
tick_positions = np.arange(4)
plt.imshow(ops["initial_cluster_means"])
plt.xticks(tick_positions)
plt.yticks(tick_positions)
plt.title("Initial cross-talk matrix")

## Channel correction



In [None]:
from iss_preprocess.image import compute_distribution, filter_stack
from iss_preprocess.io.load import load_ops, load_sequencing_rounds

# Accumulate, for every channel and round, the histogram of filtered pixel
# values over the correction tiles, then turn it into a cumulative
# distribution.
nrounds = None  # None means: read the number of rounds from ops

ops = load_ops(data_path)
nch = len(ops["camera_order"])
acquisition_name = prefix.split("_")[0]
if nrounds is None:
    nrounds = ops[f"{acquisition_name}_rounds"]

max_val = 65535
# pixel_dist[v, channel, round] is the count returned by compute_distribution
# for value v, summed over tiles
pixel_dist = np.zeros((max_val + 1, nch, nrounds))
if prefix == "genes_round":
    projection = ops["genes_projection"]
elif prefix == "barcode_round":
    projection = ops["barcode_projection"]
else:
    raise ValueError("prefix must be 'genes_round' or 'barcode_round'")
corr_tiles = ops.get("correction_tiles", None)
if corr_tiles is None:
    print("No correction tiles specified - using ref tiles")
    corr_tiles = ops[f"{acquisition_name}_ref_tiles"]
    if corr_tiles is None:
        raise ValueError("No ref tiles specified")

for tile in corr_tiles:
    print(f"counting pixel values for roi {tile[0]}, tile {tile[1]}, {tile[2]}")
    # The tile data is kept in its own variable so that the registered `stack`
    # used by the plots above is not overwritten
    try:
        tile_stack = load_sequencing_rounds(
            data_path, tile, suffix=projection, prefix=prefix, nrounds=nrounds
        )
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Tile {tile} not found. Is ops['correction_tiles'] correct?"
        ) from err
    tile_stack = filter_stack(
        tile_stack,
        r1=ops["filter_r"][0],
        r2=ops["filter_r"][1],
    )
    # Negative values are set to 0 before counting
    tile_stack[tile_stack < 0] = 0
    for iround in range(nrounds):
        pixel_dist[:, :, iround] += compute_distribution(
            tile_stack[:, :, :, iround], max_value=max_val
        )

# Normalise by the last entry, i.e. the total count, to get values in [0, 1]
cumulative_pixel_dist = np.cumsum(pixel_dist, axis=0)
cumulative_pixel_dist = cumulative_pixel_dist / cumulative_pixel_dist[-1, :, :]

In [None]:
# Log of the counts for the first 500 entries of the first axis of pixel_dist,
# at index 0 of the second axis; the last axis runs along the image columns.
# Entries with a count of 0 give -inf.
plt.imshow(np.log(pixel_dist[:500, 0, :]), aspect="auto")

In [None]:
# For every (channel, round) pair, the normalisation factor is the first index
# along axis 0 at which the cumulative distribution exceeds the correction
# quantile. argmax over a boolean array returns the first True.
exceeds_quantile = cumulative_pixel_dist > ops["correction_quantile"]
norm_factors_raw = np.argmax(exceeds_quantile, axis=0).astype(float)
norm_factors_raw

In [None]:
# Display the quantile used above to compute the normalisation factors.
ops["correction_quantile"]

In [None]:
# One curve per round for the first channel; the number of rounds is taken
# from the data instead of being hard-coded.
n_plotted_rounds = cumulative_pixel_dist.shape[2]
# get color on the viridis colormap
colors = plt.cm.viridis(np.linspace(0, 1, n_plotted_rounds))
for iround in range(n_plotted_rounds):
    plt.plot(
        np.arange(500),
        cumulative_pixel_dist[:500, 0, iround],
        label=f"Round {iround}",
        color=colors[iround],
    )
    # Mark the point of the curve selected as normalisation factor
    plt.scatter(
        norm_factors_raw[0, iround],
        cumulative_pixel_dist[int(norm_factors_raw[0, iround]), 0, iround],
        color=colors[iround],
    )
plt.legend()
plt.ylim(0.99, 1)
plt.xlabel("Pixel Index")
plt.ylabel("Cumulative Pixel Dist")
plt.title("Cumulative Pixel Distribution for Each Round")

In [None]:
# One line per channel against the round index; the number of rounds is the
# second dimension of norm_factors_raw rather than a hard-coded 7.
plt.plot(np.arange(norm_factors_raw.shape[1]), norm_factors_raw.T, "-o")

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import OneHotEncoder

# Fit log(norm factor) = slope * round + one offset per channel, i.e. the same
# exponential change across rounds for all channels with a per-channel scale.
x_ch = np.repeat(np.arange(nch)[:, np.newaxis], nrounds, axis=1)
x_round = np.repeat(np.arange(nrounds)[np.newaxis, :], nch, axis=0)
# Dense one-hot encoding of the channel index (plain ndarray, not np.matrix)
channels_encoding = (
    OneHotEncoder().fit_transform(x_ch.flatten()[:, np.newaxis]).toarray()
)
# Design matrix: first column is the round index, then one column per channel
x = np.hstack((x_round.flatten()[:, np.newaxis], channels_encoding))

# No intercept: the per-channel one-hot columns already act as offsets
mdl = LinearRegression(fit_intercept=False).fit(
    x, np.log(norm_factors_raw.flatten()[:, np.newaxis])
)
norm_factors_fit = np.exp(mdl.predict(x))
norm_factors_fit = np.reshape(norm_factors_fit, norm_factors_raw.shape)


# Raw factors as solid lines and fitted factors as dashed lines, one color per
# channel
for channel in range(nch):
    plt.plot(
        norm_factors_raw[channel], "-o", label=f"Channel {channel}", color=f"C{channel}"
    )
    # The dashed style is given in the format string only, so that it does not
    # conflict with a separate linestyle keyword
    plt.plot(
        norm_factors_fit[channel],
        "--o",
        label=f"Channel {channel} Fit",
        color=f"C{channel}",
    )
plt.legend()