In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import iss_preprocess as issp
import iss_analysis as issa
from iss_analysis.barcodes import barcodes as bar
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Run on synthetic data

In [None]:
# Generate synthetic data
def generate_synthetic_data(iseed, seed, spread=20 * 3, n_in_pairs=5, row_shift=300):
    """Generate one row of synthetic 2D Gaussian spot blobs and their centres.

    Two families of blobs are drawn:

    * 10 isolated blobs centred on the line ``y = row_shift / 2 + iseed * row_shift``,
      containing 1, 2, ..., 10 spots respectively;
    * 5 pairs of blobs centred on the line ``y = iseed * row_shift`` whose centres
      are 10, 20, 30, 40 and 50 apart in x, each blob holding ``n_in_pairs`` spots.

    Args:
        iseed (int): Row index; every centre is shifted by ``iseed * row_shift``
            in y so that several rows can share one figure.
        seed (int): Seed given to ``np.random.seed`` (resets NumPy's global RNG).
        spread (float): Controls the spacing between centres and is also used as
            the x variance of every blob (the y variance is ``spread * anisotropy``).
        n_in_pairs (int): Number of spots in each blob of the close pairs.
        row_shift (float): Vertical offset between successive rows.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``blobs_df`` has one row per spot with
        columns ``x``, ``y``, ``blob_id`` and ``fix`` (constant 1) and a unique
        0..N-1 index; ``centroid_df`` has one row per blob centre (``x``, ``y``),
        in the same order as ``blob_id``.
    """
    np.random.seed(seed)

    means = []
    anisotropy = []
    ns = []
    # isolated blobs with an increasing number of spots
    for i in range(1, 11):
        means.append([(i - 5) * spread * 2, row_shift / 2 + iseed * row_shift])
        anisotropy.append(1.1)
        ns.append(i)

    # pairs of blobs that get progressively further apart
    distances = [10, 20, 30, 40, 50]
    for i, d in enumerate(distances):
        c = (i - 2) * spread * 4
        means.append([c, 0 + iseed * row_shift])
        means.append([c + d, 0 + iseed * row_shift])
        anisotropy.extend([1, 1])
        ns.extend([n_in_pairs] * 2)

    blobs = []
    for blob_id, (mean, anis, n) in enumerate(zip(means, anisotropy, ns)):
        cov = [[spread, 0], [0, spread * anis]]
        x, y = np.random.multivariate_normal(mean, cov, n).T
        blobs.append(pd.DataFrame({"x": x, "y": y, "blob_id": [blob_id] * n}))
    # ignore_index: without it every blob restarts its index at 0 and the
    # concatenated frame carries duplicated labels
    blobs_df = pd.concat(blobs, ignore_index=True)
    blobs_df["fix"] = 1
    centroid_df = pd.DataFrame(means, columns=["x", "y"])
    return blobs_df, centroid_df

In [None]:
# Run the assignment on synthetic data: one row of blobs per seed, every spot
# carrying the same barcode (column "fix"), then plot the spots coloured by the
# mask they were assigned to (spots with a negative assignment are drawn in black).

verbose = 0
ns_to_try = [1]  # range(1, 6)
# one panel per entry of ns_to_try, stacked vertically
fig, axes = plt.subplots(len(ns_to_try), 1, figsize=(5, 4), squeeze=False)
row_shift = 300
seeds = [434, 32, 214]
pixel_size = 0.2
method = "variational_gmm"

for iseed, seed in enumerate(seeds):
    # the seed is fixed inside generate_synthetic_data
    blobs_df, centroid_df = generate_synthetic_data(
        iseed, seed, n_in_pairs=5, row_shift=row_shift
    )
    blobs_df = blobs_df.reset_index(drop=True)
    centroid_df = centroid_df.reset_index(drop=True)
    for maxn in ns_to_try:
        spot_by_spot_params = dict(
            p=0.8,
            m=0.08,
            background_spot_prior=0.0001,
            spot_distribution_sigma=70 * pixel_size,
            max_iterations=20,
            max_distance_to_mask=600,
            inter_spot_distance_threshold=20,
            max_spot_group_size=maxn,
        )
        var_gmm_params = dict(
            alpha_background=None,
            alpha_cells=0.1,
            log_background_density=-9.5,
            max_iter=1000,
            tol=1e-4,
            sigma=50,
            max_distance_to_mask=300,
        )

        print(f"Seed {iseed}, Maxn {maxn}")
        col = "fix"
        ass = issa.barcodes.assign_barcodes_to_masks(
            spots=blobs_df,
            masks=centroid_df,
            method=method,
            parameters=(
                spot_by_spot_params if method == "spot_by_spot" else var_gmm_params
            ),
            base_column=col,
            verbose=verbose,
        )
        blobs_df[f"assignment_{col}_{maxn}"] = ass

    for i, maxn in enumerate(ns_to_try):
        col = f"assignment_fix_{maxn}"
        bg = blobs_df[col] < 0
        # the grid has len(ns_to_try) rows and one column, so the row index
        # selects the panel (axes[0, i] fails as soon as ns_to_try has 2+ entries)
        panel = axes[i, 0]
        panel.scatter(
            centroid_df.x,
            centroid_df.y,
            c=centroid_df.index % 10,
            cmap="tab10",
            marker="o",
            s=200,
            alpha=0.2,
            edgecolors="k",
        )
        panel.scatter(
            blobs_df.x[~bg],
            blobs_df.y[~bg],
            c=blobs_df[col][~bg] % 10,
            cmap="tab10",
            alpha=0.8,
            s=7,
            vmin=0,
            vmax=9,
        )
        panel.scatter(blobs_df.x[bg], blobs_df.y[bg], color="k", alpha=0.8, s=7)
        if iseed == 0:
            panel.set_aspect("equal")
            panel.set_ylabel(f"Max group move {maxn}")
fig.suptitle("All spot same barcode")
fig.tight_layout()

# Debug probabilistic model for rolony attribution

Run the cells below one at a time to step through the model and see how it works.

In [None]:
# Dataset to debug: flexilims project, mouse, and the name of the
# error-corrected barcode dataset loaded further down.
project = "becalia_rabies_barseq"
error_correction_ds_name = "BRAC8498.3e_error_corrected_barcodes_10"
mouse_name = "BRAC8498.3e"
data_path = f"{project}/{mouse_name}"
# Side effect: creates the "analysis" sub-folder of the processed-data directory.
# Only exist_ok is set, so the parent directory must already exist.
analysis_folder = issp.io.get_processed_path(data_path) / "analysis"
analysis_folder.mkdir(exist_ok=True)

In [None]:
# Select example roi
chamber = "chamber_08"
roi = 7
# rebinds the `verbose` set to 0 in the synthetic-data cell; from here on it
# switches on the progress messages printed by `_log`
verbose = True

## Load example data

In [None]:
import flexiznam as flz
from iss_preprocess.pipeline.segment import get_cell_masks


def _log(message, verbose):
    if verbose:
        print(message)


# All per-chamber loaders take the same "<project>/<mouse>/<chamber>" path.
chamber_data_path = f"{project}/{mouse_name}/{chamber}"

_log("Loading error corrected dataset", verbose)
flm_sess = flz.get_flexilims_session(project_id=project, reuse_token=True)
error_dataset = flz.Dataset.from_flexilims(
    name=error_correction_ds_name,
    flexilims_session=flm_sess,
)

_log("Stitching masks", verbose)
masks = get_cell_masks(data_path=chamber_data_path, roi=roi)

_log("Making cell dataframe", verbose)
mask_df = issp.pipeline.segment.make_cell_dataframe(
    chamber_data_path, roi, masks=masks, atlas_size=None
)

_log("Loading spots", verbose)
# NOTE: unpickling can execute arbitrary code; only load files you trust.
bc = pd.read_pickle(error_dataset.path_full)
# keep only the spots of the selected chamber / roi, as an independent copy
in_selected_roi = (bc["chamber"] == chamber) & (bc["roi"] == roi)
spots = bc.loc[in_selected_roi].copy()

## Select part to analyze

It's a bit slow on the whole dataset, so let's just take a part of it

In [None]:
# Window to analyse, in the coordinates of the spot / mask `x`, `y` columns:
# a 2000 x 2000 square centred on (10000, 7100).
xlim = np.array([-1000,1000]) + 10000
ylim = np.array([-1000,1000]) + 7100
# Barcodes that get their own panel in the diagnostic plots below.
specific_barcodes = [
    "TTTTGGACCTTTAT",
    "CACCATGTAATTAA",
    "GCAAGTAGACTCCA",
    "TCTTTTTTGACGCC",
]

In [None]:
# Pull out every spot of the loaded roi that carries one example barcode.
barcode = "CCTACATCATAATA"
base_column = "corrected_bases"
# NOTE(review): the column is cast to str before comparing — presumably it can
# hold non-string values; confirm.
barcode_df = spots[spots[base_column].astype(str) == barcode]
# (n_spots, 2) array of x, y positions
spot_positions = barcode_df[["x", "y"]].values

In [None]:
# Quick look at all spots of the roi, zoomed on a fixed strip.
plt.figure(figsize=(5, 5))
plt.scatter(spots.x, spots.y, s=1, c="k", alpha=0.1)
plt.xlim(10300, 11300)
plt.ylim(10000, 15000)
# (a stray `plt.angle_spectrum` attribute access used to end this cell; it only
# echoed the function's repr, so it is replaced by an explicit show)
plt.show()

In [None]:
# Plot to select area
def _in_window(df):
    """Boolean mask of the rows of ``df`` whose x / y lie strictly inside xlim / ylim."""
    return (
        (df.x > xlim[0]) & (df.x < xlim[1]) & (df.y > ylim[0]) & (df.y < ylim[1])
    )


def _window_rectangle():
    """New red outline of the xlim / ylim window (one patch per axes)."""
    return plt.Rectangle(
        (xlim[0], ylim[0]), xlim[1] - xlim[0], ylim[1] - ylim[0], fill=False, color="r"
    )


mask_part = mask_df[_in_window(mask_df)]
# explicit copy so that later cells can add columns to spots_part without
# triggering pandas' SettingWithCopyWarning
spots_part = spots[_in_window(spots)].copy()

fig = plt.figure(figsize=(7, 7))

# top left: every spot of the roi, with the selected window outlined
ax0 = fig.add_subplot(2, 2, 1, aspect="equal")
ax0.scatter(spots.x, spots.y, s=1, c="k", alpha=0.1)
ax0.add_patch(_window_rectangle())

# top right: spots inside the window
ax1 = fig.add_subplot(2, 2, 2, aspect="equal")
ax1.scatter(spots_part.x, spots_part.y, s=1, c="k", alpha=0.1)

# bottom left: spots of the barcodes of interest over the whole roi
ax2 = fig.add_subplot(2, 2, 3, aspect="equal")
bc_spot = spots[spots["corrected_bases"].isin(specific_barcodes)]
ax2.scatter(bc_spot.x, bc_spot.y, s=1, c="k", alpha=0.5)
ax2.add_patch(_window_rectangle())

# bottom right: spots of the barcodes of interest inside the window
ax3 = fig.add_subplot(2, 2, 4, aspect="equal")
bc_spot = spots_part[spots_part["corrected_bases"].isin(specific_barcodes)]
ax3.scatter(bc_spot.x, bc_spot.y, s=10, c="k", alpha=0.5)
ax3.set_xlim(xlim)
ax3.set_ylim(ylim)
fig.tight_layout()

plt.show()

In [None]:
verbose = True
from iss_analysis.barcodes import probabilistic_assignment as pa


def run_assignment(params, method="spot_by_spot"):
    """Assign the spots of the selected window (``spots_part``) to ``mask_part``.

    Reads the notebook-level ``spots_part``, ``mask_part``, ``chamber`` and ``roi``.

    Args:
        params (dict): Parameters forwarded to ``pa.assign_barcodes_to_masks``.
            An optional ``"base_column"`` entry (default ``"corrected_bases"``)
            is taken out and passed as the ``base_column`` argument instead.
            The dict itself is left untouched.
        method (str): Assignment method forwarded to ``pa.assign_barcodes_to_masks``.

    Returns:
        tuple: ``(output, mask_assignment)`` where ``output`` is a DataFrame
        indexed like ``spots_part`` with columns ``mask``, ``chamber``, ``roi``
        and ``spot``, and ``mask_assignment`` is the raw value returned by
        ``pa.assign_barcodes_to_masks``.
    """
    # Work on a copy: popping from the caller's dict would silently strip
    # "base_column" from e.g. `default_params` after the first call.
    params = dict(params)
    base_column = params.pop("base_column", "corrected_bases")
    mask_assignment = pa.assign_barcodes_to_masks(
        base_column=base_column,
        spots=spots_part,
        masks=mask_part,
        parameters=params,
        method=method,
    )
    output = pd.DataFrame(
        index=spots_part.index, columns=["mask", "chamber", "roi", "spot"]
    )
    output.loc[spots_part.index, "spot"] = spots_part.index.values
    output.loc[spots_part.index, "mask"] = mask_assignment
    output.loc[spots_part.index, "chamber"] = chamber
    output.loc[spots_part.index, "roi"] = roi
    return output, mask_assignment

## Run on the test set


In [None]:
# Baseline parameter set for the spot-by-spot assignment
default_params = {
    "p": 0.9,
    "m": 0.1,
    "background_spot_prior": 0.0001,
    "spot_distribution_sigma": 50,
    "max_iterations": 100,
    "max_distance_to_mask": 600,
    "inter_spot_distance_threshold": 20,
    "max_spot_group_size": 6,
    "max_total_combinations": 10000,
    "verbose": 1,
    "base_column": "corrected_bases",
}

In [None]:
# Candidate parameter set to compare against the baseline
five_spots_params = {
    "p": 0.8,
    "m": 0.08,
    "background_spot_prior": 0.0001,
    "spot_distribution_sigma": 100,
    "max_iterations": 100,
    "max_distance_to_mask": 600,
    "inter_spot_distance_threshold": 20,
    "max_spot_group_size": 6,
    "max_total_combinations": 10000,
    "verbose": 1,
    "base_column": "corrected_bases",
}

In [None]:
# Float copy of the mask image with 0 replaced by NaN; this is the array passed
# to vis.plot_bc_over_mask below.
# NOTE(review): presumably 0 is the background label, so NaN leaves it undrawn — confirm.
ma = masks.astype(float)
ma[ma == 0] = np.nan

In [None]:
import iss_preprocess.vis.utils as vis

# Compare two parameter sets: one row per set, all barcodes on the left and
# only `specific_barcodes` on the right.
fig, axes = plt.subplots(2, 2, figsize=(18, 15))
param_pair = [
    default_params,
    # `new_params` was never defined in this notebook (NameError on a fresh
    # kernel); the "new parameters" cell above defines `five_spots_params`.
    five_spots_params,
]
for ip, params in enumerate(param_pair):
    # pass a copy so that the assignment cannot alter the reference dicts
    output, mask_assignment = run_assignment(dict(params))
    vis.plot_bc_over_mask(
        axes[ip, 0], ma, spots_part, mask_assignment, xlim, ylim, nc=10
    )
    spec = spots_part.corrected_bases.isin(specific_barcodes)
    vis.plot_bc_over_mask(
        axes[ip, 1],
        ma,
        spots_part[spec],
        mask_assignment[spec],
        xlim,
        ylim,
        nc=10,
        show_bg_barcodes=True,
    )

axes[0, 0].set_ylabel("Initial parameters")
axes[1, 0].set_ylabel("New parameters")
plt.tight_layout()

In [None]:
import matplotlib as mpl

# DOES NOT WORK NOW THAT WE ADDED THE VARIATIONAL GMM METHOD
# Disabled debugging cell (kept for reference). It ran the assignment with
# debug=True on the `specific_barcodes` spots twice — with the first three masks
# of mask_part, then with only the first two — and drew one panel per entry of
# the returned value, titled "Assignment round N".
# NOTE(review): `params` is whatever dict is bound when the cell runs, and the
# call goes through `issa.barcodes.main` — confirm both before re-enabling.
if False:
    spec = spots_part.corrected_bases.isin(specific_barcodes)
    mask_assignment_id = issa.barcodes.main.assign_barcodes_to_masks(
        spots=spots_part[spec], masks=mask_part.iloc[[0, 1, 2]], debug=True, **params
    )
    mask_assignment_id2 = issa.barcodes.main.assign_barcodes_to_masks(
        spots=spots_part[spec], masks=mask_part.iloc[[0, 1]], debug=True, **params
    )
    ms = [mask_assignment_id, mask_assignment_id2]
    fig, axes = plt.subplots(2, 4, figsize=(10, 6))
    colors = mpl.cm.get_cmap("tab10", 10).colors
    for irow, mask_assignment_id in enumerate(ms):
        # positions (in mask_part) of the masks used for this row
        if irow == 0:
            ok = [0, 1, 2]
        else:
            ok = [0, 1]
        for iax, ax in enumerate(axes[irow]):
            # draw and label the candidate masks
            for i in ok:
                m = mask_part.iloc[i]
                ax.scatter(m.x, m.y, c=i, ec="k", s=200, cmap="tab10", vmin=0, vmax=9)
                ax.text(
                    x=m.x,
                    y=m.y,
                    s=f"{i}",
                    color="k",
                    horizontalalignment="center",
                    verticalalignment="center",
                )

            # spots coloured by assigned mask; negative assignment -> black
            c = [colors[i] if i >= 0 else "k" for i in mask_assignment_id[iax]]
            ax.scatter(spots_part[spec].x, spots_part[spec].y, c=c, s=10)
            ax.set_aspect("equal")
            ax.set_xlim(*xlim)
            ax.set_ylim(*ylim)
            ax.set_xticks([])
            ax.set_yticks([])
            if not irow:
                ax.set_title(f"Assignment round {iax+1}")
        if irow:
            axes[irow, 0].set_ylabel("Delete cell 2")
        else:
            axes[irow, 0].set_ylabel("All cells")

        # bare expression inside the loop: has no effect
        mask_assignment_id

# Adapt parameters

## Parameters to set

In [None]:
# these are the parameters we can change:
# (reference values only — the parameter dicts in this notebook spell out their
# own numbers rather than reading these names)
p = 0.9
m = 0.1
background_spot_prior = 0.0001
spot_distribution_sigma = 50

In [None]:
# Generate synthetic data
# Disabled (if False) earlier, inline version of the synthetic-data test. The
# blob generation repeats what generate_synthetic_data does, with 10 spots in
# each blob of the close pairs.
# NOTE(review): the plotting loop reads an "assignment_blob_id" column, but the
# assignment loop only runs for "fix", so that column is never created — this
# would raise a KeyError if re-enabled. `pixel_size` comes from an earlier cell.

if False:
    spread = 20 * 3
    fig, axes = plt.subplots(3, 1, figsize=(15, 15))
    labels = ["True ID", "Barcode per blob", "All same barcode"]
    row_shift = 300
    seeds = [434, 32, 214]
    for iseed, seed in enumerate(seeds):
        # fix the seed
        np.random.seed(seed)
        # generate the blobs

        coords = []
        means = []
        anisotropy = []
        ns = []
        for i in range(1, 11):
            means.append([(i - 5) * spread * 2, row_shift / 2 + iseed * row_shift])
            anisotropy.append(1.1)
            ns.append(i)

        distances = [10, 20, 30, 40, 50]
        for i, d in enumerate(distances):
            c = (i - 2) * spread * 4
            means.append([c, 0 + iseed * row_shift])
            means.append([c + d, 0 + iseed * row_shift])
            anisotropy.extend([1, 1])
            ns.extend([10, 10])

        blobs_df = []
        for mean, anis, n in zip(means, anisotropy, ns):
            cov = [[spread, 0], [0, spread * anis]]
            x, y = np.random.multivariate_normal(mean, cov, n).T
            blobs_df.append(pd.DataFrame({"x": x, "y": y, "blob_id": [len(coords)] * n}))
            coords.append((x, y))
        blobs_df = pd.concat(blobs_df, ignore_index=True)
        blobs_df["fix"] = 1
        centroid_df = pd.DataFrame(means, columns=["x", "y"])

        assignments = {}
        for col in ["fix"]:  # "blob_id",
            ass = issa.barcodes.assign_barcodes_to_masks(
                spots=blobs_df,
                masks=centroid_df,
                parameters=dict(
                    p=0.8,
                    m=0.08,
                    background_spot_prior=0.0001,
                    spot_distribution_sigma=50 * pixel_size,
                    max_iterations=100,
                    inter_spot_distance_threshold=50,
                    max_distance_to_mask=600,
                ),
                base_column=col,
            )
            # NOTE(review): keeps only the first element of the returned value —
            # confirm this still matches the current return type
            ass = ass[0]
            blobs_df[f"assignment_{col}"] = ass

        for i, col in enumerate(["blob_id", "assignment_blob_id", "assignment_fix"]):
            bg = blobs_df[col] < 0
            axes[i].scatter(
                centroid_df.x,
                centroid_df.y,
                c=centroid_df.index % 10,
                cmap="tab10",
                marker="o",
                s=200,
                alpha=0.2,
                edgecolors="k",
            )
            axes[i].scatter(
                blobs_df.x[~bg],
                blobs_df.y[~bg],
                c=blobs_df[col][~bg] % 10,
                cmap="tab10",
                alpha=0.8,
                s=7,
                vmin=0,
                vmax=9,
            )
            axes[i].scatter(blobs_df.x[bg], blobs_df.y[bg], color="k", alpha=0.8, s=7)
            if iseed == 0:
                axes[i].set_aspect("equal")
                axes[i].set_ylabel(labels[i])
                # label each pair with the distance between its two centres
                for i_d, d in enumerate(distances):
                    axes[i].text(
                        (i_d - 2) * spread * 4,
                        (len(seeds) - 1 + 0.15) * row_shift,
                        f"{d} um",
                    )

    fig.suptitle("New parameters")
    fig.tight_layout()

In [None]:
# Spot prior
from iss_analysis.barcodes.probabilistic_assignment import _spot_count_prior

# Left panel: spot count prior for a few (p, m) pairs, with the background prior
# (n * log(background_spot_prior)) as a dashed line.
# Right panel: spot count prior minus the background prior; the dashed zero line
# is labelled "below = background".
fig, ax = plt.subplots(1, 2, figsize=(5, 2.5))
background_spot_prior = 0.00005
nspots = np.arange(10)

log_background_spot_prior = np.log(background_spot_prior)

# dedicated loop names so that the notebook-level `p` and `m` are not overwritten
for p_val, m_val in [(1, 0.1), (0.9, 0.1), (0.8, 0.08)]:
    sp_cnt_prior = _spot_count_prior(nspots, p_val, m_val)
    ax[0].plot(nspots, sp_cnt_prior, label=f"p={p_val}, m={m_val}")
    ax[1].plot(
        nspots,
        sp_cnt_prior - nspots * log_background_spot_prior,
        label=f"p={p_val}, m={m_val}",
        marker=".",
    )

ax[0].plot(
    nspots,
    nspots * log_background_spot_prior,
    label="Background prior",
    color="k",
    linestyle="--",
)
ax[1].axhline(0, color="k", linestyle="--", label="below = background")

ax[0].legend(loc="upper center", bbox_to_anchor=(1.1, 1.15), ncol=4)
for axis in ax:
    axis.set_xticks(range(0, 11, 2))
    axis.set_xlabel("Spot count")
    axis.set_xlim(0, 10)
    axis.spines["top"].set_visible(False)
    axis.spines["right"].set_visible(False)
ax[0].set_ylabel("Likelihood")
ax[1].set_ylabel("Likelihood difference")

fig.subplots_adjust(wspace=0.3)
fig.savefig("spot_prior.svg")

## Test on real data

In [None]:
# Base parameters for the runs on real data; individual entries are overridden
# per run in the next cell.
params = {
    "p": 0.8,
    "m": 0.08,
    "background_spot_prior": 0.0001,
    "spot_distribution_sigma": 50,
    "max_iterations": 100,
    "max_distance_to_mask": 600,
    "inter_spot_distance_threshold": 20,
}

In [None]:
# Each entry overrides part of `params` for one run.
param_list = [
    {"background_spot_prior": 0.0001, "spot_distribution_sigma": 50 * i} for i in [0.5]
]


for param in param_list:
    # base parameters with this run's overrides on top
    kwargs = dict(params, **param)
    print(kwargs)
    output, mask_assignment = run_assignment(kwargs)
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    vis.plot_bc_over_mask(axes[0], ma, spots_part, mask_assignment, xlim, ylim, nc=20)
    spec = spots_part.corrected_bases.isin(specific_barcodes)
    # Use the NaN-masked image `ma` here as well, like every other
    # plot_bc_over_mask call in this notebook (this panel used raw `masks`).
    vis.plot_bc_over_mask(
        axes[1], ma, spots_part[spec], mask_assignment[spec], xlim, ylim, nc=20
    )

In [None]:
import matplotlib
import skimage

# Overlay the assigned spots on the cropped mask image: fill colour = assigned
# mask, edge colour = barcode, black = unassigned (assignment == -1).
plt.figure(figsize=(7, 7))

cmap = "Set3"
ncol = 12
# Keep the cropped label image under its own name: `mask_part` is the DataFrame
# of masks that run_assignment reads, and rebinding it to an image breaks any
# later call of run_assignment.
mask_img = vis.get_stack_part(masks, xlim, ylim)
non_assigned = spots_part[mask_assignment == -1]
# .assign returns a new frame instead of writing into a filtered one
spots_part = spots_part.assign(assignment=mask_assignment)
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9
colors = plt.get_cmap(cmap, ncol).colors
# NOTE(review): 6 colours are requested but only the first 5 are used below — confirm.
col_bar = plt.get_cmap("Accent", 6).colors
mask_img_nan = mask_img.astype(float)
mask_img_nan[mask_img_nan == 0] = np.nan
plt.imshow(mask_img_nan % ncol, cmap=cmap, interpolation="none", alpha=1)

plt.scatter(non_assigned.x - xlim[0], non_assigned.y - ylim[0], c="k", s=5, alpha=0.5)
assigned = spots_part[mask_assignment >= 0]
barcodes = list(assigned.corrected_bases.unique())
print(len(barcodes))
# dict lookup instead of list.index for every spot
barcode_rank = {b: i for i, b in enumerate(barcodes)}
bc_color = [col_bar[barcode_rank[b] % 5] for b in assigned.corrected_bases]
sp_col = [colors[i] for i in (assigned["assignment"].values % ncol).astype(int)]

# `c` already holds explicit colours, so no cmap is passed (it would be ignored)
plt.scatter(
    assigned.x - xlim[0],
    assigned.y - ylim[0],
    s=30,
    ec=bc_color,
    c=sp_col,
    linewidths=2,
    alpha=1,
    zorder=20,
)

for mask_id, sp in assigned.groupby("assignment"):
    # skip masks that are not visible in the cropped image
    if mask_id not in mask_img:
        continue
    cell_mask = mask_img == mask_id
    cell_center = skimage.measure.centroid(cell_mask)
    cy, cx = cell_center
    # optional: draw a line from each spot to the centre of its cell
    if False:
        for sid, s in sp.iterrows():
            plt.plot([s.x - xlim[0], cx], [s.y - ylim[0], cy], color="k", alpha=0.5, zorder=1)


plt.axis("off")
# overview
plt.xlim(700, 1400)
plt.ylim(1400, 800)

In [None]:
# NOTE(review): `cvbn` is not defined anywhere in this notebook, so this cell
# raises NameError. It looks like either stray keystrokes or a deliberate way to
# stop "Run All" before the scratch cells below — confirm, then delete it or
# replace it with an explicit, documented stop.
cvbn

In [None]:
# Scratch: display `sp`, the last group left over from the
# `assigned.groupby("assignment")` loop of the overlay cell.
sp

In [None]:
import skimage

# Scratch cell relying on `cell_center` left over from the overlay cell's loop.
# NOTE(review): when that loop finishes, `cell_center` holds the centroid
# coordinates rather than a 2D image, so centroid / imshow here are probably not
# showing what was intended — confirm. `c` is not used afterwards.
c = skimage.measure.centroid(cell_center)


plt.imshow(cell_center)

# Re-run barcode assignment

In [None]:
# Disabled on purpose: when enabled this runs the assignment for all chambers
# with use_slurm=True and conflicts="overwrite".
# NOTE(review): the pixel size is read from chamber_07 while the example above
# uses chamber_08, and the dataset name repeats the literal stored in
# `error_correction_ds_name` — confirm both before enabling.
if False:
    px_size = issp.io.get_pixel_size(f"{project}/{mouse_name}/chamber_07")
    issa.barcodes.assign_barcode_all_chambers(
        project,
        mouse_name=mouse_name,
        error_correction_ds_name="BRAC8498.3e_error_corrected_barcodes_10",
        base_column="corrected_bases",
        p=0.9,
        m=0.1,
        background_spot_prior=0.0001,
        spot_distribution_sigma=50,
        max_iterations=100,
        # 200 divided by the pixel size
        distance_threshold=200 / px_size,
        use_slurm=True,
        conflicts="overwrite",
    )