# Look at barcode diversity in real life

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import iss_preprocess as issp
from iss_analysis.barcodes import barcodes as bar
from iss_analysis.barcodes.diagnostics import plot_gmm_clusters, plot_error_along_sequence
import matplotlib.pyplot as plt
import numpy as np

# Get data

## Get all the barcodes and filter with GMM

In [None]:
# Acquisition analysed throughout this notebook
data_path = "becalia_rabies_barseq/BRAC8498.3e"

# Figures and cached results are written to <processed path>/analysis
processed_root = issp.io.get_processed_path(data_path)
analysis_folder = processed_root / "analysis"
analysis_folder.mkdir(exist_ok=True)


In [None]:

# Each entry is forwarded to get_barcodes as "<key>_threshold" and is also
# handed to the diagnostic plot below.
thresholds = {
    "mean_intensity": 0.01,
    "dot_product_score": 0.2,
    "mean_score": 0.5,
}
threshold_kwargs = {f"{name}_threshold": value for name, value in thresholds.items()}
barcode_spots, gmm, all_barcode_spots = bar.get_barcodes(
    acquisition_folder=data_path, **threshold_kwargs
)

# Save the GMM cluster diagnostic with the other analysis outputs
pg = plot_gmm_clusters(all_barcode_spots, gmm, thresholds=thresholds)
fig = pg.figure
fig.savefig(analysis_folder / "gmm_clusters.png")


## Plot proportion of barcodes in each cluster

In [None]:

# Count the spots of every (chamber, roi) pair in each GMM cluster, then plot
# the absolute counts (top) and the per-ROI fractions (bottom).
n_labels = 3  # this notebook treats gmm_label as taking the values 0, 1 and 2
ch_r_gp = all_barcode_spots.groupby(['chamber', 'roi'])
n_rois = len(ch_r_gp)
n_per_lab = np.zeros((n_rois, n_labels))
roi_names = []
for i, (ch_r, df) in enumerate(ch_r_gp):
    # reindex so that a ROI without any spot in one cluster contributes 0;
    # without it the shorter Series cannot be assigned to the row (or, with a
    # single label present, is silently broadcast to all three columns)
    n_per_lab[i] = df.groupby('gmm_label').size().reindex(range(n_labels), fill_value=0)
    # collect the tick label in the same iteration so labels and rows cannot drift apart
    roi_names.append(f"{ch_r[0]}_{ch_r[1]}")

x_pos = np.arange(n_rois)
fig = plt.figure(figsize=(20, 10))
ax = fig.add_subplot(2, 1, 1)
ax.plot(x_pos, n_per_lab, 'o-')
ax.set_xticks(x_pos)
ax.set_xticklabels([])
ax.set_ylabel('Number of spots')
ax = fig.add_subplot(2, 1, 2)
ax.plot(x_pos, n_per_lab / np.sum(n_per_lab, axis=1, keepdims=True), 'o-')
ax.set_xticks(x_pos)
ax.set_ylabel('Fraction of spots')
_ = ax.set_xticklabels(roi_names, rotation=90)
fig.tight_layout()

## Plot one example chamber

In [None]:
chamber = 8
roi = 8

import matplotlib.pyplot as plt

# Spots of the selected chamber/ROI, one colour per GMM cluster, on a black canvas
in_chamber = all_barcode_spots["chamber"] == f"chamber_{chamber:02}"
in_roi = all_barcode_spots["roi"] == roi
df = all_barcode_spots[in_chamber & in_roi]
fig = plt.figure(figsize=(20, 10), facecolor='black')
ax = fig.add_subplot(1, 1, 1, aspect='equal', facecolor='black')
for i in range(3):
    cluster = df[df["gmm_label"] == i]
    ax.scatter(cluster["y"], cluster["x"], label=f"cluster {i}", alpha=0.1, s=1)
plt.legend()
_ = ax.axis('off')

# Error correction

## Estimate error rate along sequence

In [None]:
# Estimate the error per round from `barcode_spots` (edit distance 1), reading
# the 'sequence' and 'bases' columns; the result is plotted in the next cell.
errors_along_seq = bar.error_per_round(spot_df=barcode_spots, edit_distance=1, spot_count_threshold=30, sequence_column='sequence', filter_column='bases')

In [None]:
# Plot the per-round error estimate held in `errors_along_seq`.
figs = plot_error_along_sequence(errors_along_seq, nrows=2, plot_matrix=False, marker='o')

## Manually filter out missing data

In [None]:
# The values below are the INDEX of the round (0-based), NOT the round number
# written in the file name.
# missing_roi: chamber name -> {roi: index of the round to blank for that whole roi}
missing_roi = dict(chamber_09={8: 2, 9: 2}, chamber_08={8:10, 9:10})
# missing_tiles: chamber name -> {tile name: index of the round to blank for that tile};
# tile names are compared with the 'tile' column of the spot table.
missing_tiles = dict(chamber_08={'10_2_4':5})

In [None]:
def _blank_round(spots, selected, round_):
    """Flag one round as missing for a subset of spots, in place.

    Args:
        spots: DataFrame with a 'sequence' column (one numeric array per spot)
            and a 'bases' column (one string per spot).
        selected: boolean Series aligned with `spots` marking the rows to edit.
        round_: 0-based index of the round to blank.

    Returns:
        The number of rows that were modified (0 when nothing matched).
    """
    n_selected = int(selected.sum())
    if not n_selected:
        return 0
    # Adding this mask leaves every round untouched except `round_`, which becomes NaN.
    # `x + mask` builds a new array, so the arrays of the source table are not modified.
    mask = np.zeros(len(spots.iloc[0]['sequence']), dtype=float)
    mask[round_] = np.nan
    spots.loc[selected, 'sequence'] = spots.loc[selected, 'sequence'].map(lambda x: x + mask)
    # The called base of the blanked round is replaced by 'N'
    spots.loc[selected, 'bases'] = spots.loc[selected, 'bases'].map(
        lambda x: x[:round_] + 'N' + x[round_ + 1:]
    )
    return n_selected


clean_barcodes = barcode_spots.copy()

print('Removing rounds with missing rois')
for chamber, rois in missing_roi.items():
    for roi, round_ in rois.items():
        print(f"Removing round {round_} from {chamber}, roi {roi}")
        bad = (clean_barcodes["chamber"] == chamber) & (clean_barcodes["roi"] == roi)
        if not _blank_round(clean_barcodes, bad, round_):
            # a selector that matches nothing is most likely a typo in the config
            print("  no spot matched, nothing changed")
print('Removing rounds with missing tiles')
for chamber, tiles in missing_tiles.items():
    for tile, round_ in tiles.items():
        print(f"Removing round {round_} from {chamber}, tile {tile}")
        bad = (clean_barcodes["chamber"] == chamber) & (clean_barcodes["tile"] == tile)
        if not _blank_round(clean_barcodes, bad, round_):
            print("  no spot matched, nothing changed")

In [None]:
# Re-estimate the per-round errors on the manually cleaned spots and plot them.
# NOTE(review): unlike the earlier call on `barcode_spots`, `sequence_column`
# and `filter_column` are left at their defaults here — confirm the defaults
# are 'sequence' and 'bases'.
errors_along_seq = bar.error_per_round(spot_df=clean_barcodes, edit_distance=1, spot_count_threshold=30)
figs = plot_error_along_sequence(errors_along_seq, nrows=2, plot_matrix=False, marker='o')

## Actual error correction part

In [None]:
import pandas as pd

# Set to True to recompute even when a cached result exists on disk
redo = False
weights = np.ones(14)
# weights[10:] = 0
err_corr_edt = {}
for edist in [3]:
    print(f"Correcting barcodes with edit distance {edist}")
    fname = analysis_folder / f"corrected_barcode_spots_edit{edist}.pkl"
    if not redo and fname.exists():
        # Reuse the cached table. `merge_dict` is NOT defined on this path.
        # Only unpickle files written by this notebook: pickle is unsafe on untrusted data.
        err_corr = pd.read_pickle(fname)
    else:
        err_corr, merge_dict = bar.correct_barcode_sequences(
            clean_barcodes,
            max_edit_distance=edist,
            weights=weights,
            return_merge_dict=True,
        )
        err_corr.to_pickle(fname)
    # Rolonies per distinct sequence, before and after correction
    non_corr = err_corr.bases.value_counts()
    corr = err_corr.corrected_bases.value_counts()
    print(f"Number of unique sequences before correction: {len(non_corr)}")
    print(f"Number of unique sequences after correction: {len(corr)}")
    err_corr_edt[edist] = err_corr



## Quality control

In [None]:
# The barcode counts do not depend on the abundance threshold, so build them
# once per edit distance instead of once per threshold.
counts_by_edist = {}
n_bc_with_N = {}
for k, v in err_corr_edt.items():
    # ''.join leaves values that are already plain strings unchanged
    v['corrected_bases'] = v['corrected_bases'].map(''.join)
    vc = v['corrected_bases'].value_counts()
    counts_by_edist[k] = vc
    # distinct corrected barcodes that still contain an undetermined base
    n_bc_with_N[k] = np.sum(['N' in i for i in vc.index])

# One curve per threshold: number of barcodes seen more than `threshold` times.
# `v`, `vc`, `n_bc` and `n_bc_with_N` stay defined after this cell, as before.
for threshold in range(10, 100, 10):
    n_bc = {k: int((counts > threshold).sum()) for k, counts in counts_by_edist.items()}
    plt.plot(list(n_bc.keys()), list(n_bc.values()), 'o-', label=f"Abundance > {threshold}")
plt.ylabel('Number of unique barcodes')
plt.xlabel('Edit distance')
plt.legend()

In [None]:
# Barcodes with at least 500 rolonies. `vc` is a value_counts() result, so the
# most abundant barcode comes first.
ab = vc[vc >= 500]
# Show up to 5 barcodes, fewer when not enough pass the abundance cut
n_examples = min(5, len(ab))

nchamber = len(v.chamber.unique())
# One row per ROI: size the grid from the data rather than assuming 10 ROIs
nrois = int(v.groupby('chamber')['roi'].nunique().max())
# squeeze=False keeps `axes` 2-D even with a single chamber or a single ROI
fig, axes = plt.subplots(nrois, nchamber, figsize=(20, 40), squeeze=False)
for ax in axes.flat:
    # hide every panel up front so chambers with fewer ROIs leave blank slots
    ax.axis('off')
for bci in range(n_examples):
    bc = v[v.corrected_bases == ab.index[bci]]
    for ich, (ch_name, ch) in enumerate(v.groupby('chamber')):
        for iroi, (roi_name, roi) in enumerate(ch.groupby('roi')):
            ax = axes[iroi, ich]
            bcc = bc[(bc.roi == roi_name) & (bc.chamber == ch_name)]
            if not bci:
                # background: every spot of this ROI, drawn only on the first pass
                ax.scatter(roi['y'], roi['x'], label=f"roi {iroi}", s=1, color='k', alpha=0.05)
            ax.scatter(bcc['y'], bcc['x'], label=f"roi {iroi}", s=5, color=f'C{bci}', alpha=0.5)
            ax.axis('equal')
            ax.axis('off')
fig.tight_layout()

In [None]:


# Number of distinct corrected barcodes containing an 'N', per edit distance
ax_n = plt.gca()
ax_n.plot(n_bc_with_N.keys(), n_bc_with_N.values(), 'o-', label="with N")
ax_n.set_ylabel('Number of unique barcodes')
ax_n.set_xlabel('Edit distance')
ax_n.legend()

In [None]:
def _n_with_N(counts):
    """Return how many entries of `counts.index` contain the character 'N'."""
    return np.sum(['N' in seq for seq in counts.index])


nan_to_start = _n_with_N(non_corr)
nan_after = _n_with_N(corr)
print(f"Number of 'unique' sequences with N before correction: {nan_to_start}")
print(f"Number of 'unique' sequences with N after correction: {nan_after}")
print(f"Number of sequences with N corrected: {nan_to_start-nan_after}")

In [None]:
# Histogram of rolonies per sequence, before and after correction.
# Left: full range in steps of 10. Right: counts below 100 in unit steps.
series = (('non-corrected', non_corr), ('corrected', corr))
panels = ((1, np.arange(0, corr.max(), 10)), (2, np.arange(100)))

fig = plt.figure(figsize=(7, 3))
fig.suptitle("# Rolonies per sequence")
for position, bins in panels:
    ax = fig.add_subplot(1, 2, position)
    for label, counts in series:
        ax.hist(counts.values, bins=bins, log=True, histtype='step', label=label)
# only the right-hand panel (the last one created) gets a legend
ax.legend(loc='upper right')
fig.tight_layout()

In [None]:
# Same data as normalised histograms.
# Left: cumulative distribution over the full range. Right: density below 100.
series = (('non-corrected', non_corr), ('corrected', corr))
panels = (
    (1, np.arange(corr.max()), dict(cumulative=True, density=True)),
    (2, np.arange(100), dict(density=True)),
)

fig = plt.figure(figsize=(7, 3))
fig.suptitle("# Rolonies per sequence")
for position, bins, extra in panels:
    ax = fig.add_subplot(1, 2, position)
    for label, counts in series:
        ax.hist(counts.values, bins=bins, log=False, histtype='step', label=label, **extra)
# only the right-hand panel (the last one created) gets a legend
ax.legend(loc='upper right')
fig.tight_layout()

In [None]:
from iss_preprocess.call import BASES

# Sequences within `th` mismatches of a barcode are inspected for that barcode
th = 1
# Barcodes observed more than 30 times after correction
good = corr[corr > 30].index
print(len(good))
sequences = np.stack(err_corr["corrected_sequence"].to_numpy())
n_positions = sequences.shape[1]
error_along_sequence = np.zeros((len(good), n_positions))
# position in this list = numeric code of the base; 'N' comes after the real bases
bases = [*BASES, 'N']
for ibar, barcode in enumerate(good):
    seq = [bases.index(b) for b in barcode]
    diff = sequences - seq
    mismatch = diff != 0
    edit_distance = mismatch.sum(axis=1)
    actual_errs = edit_distance <= th
    # 1 where at least one of the selected sequences differs at that position
    error_along_sequence[ibar] = mismatch[actual_errs].any(axis=0)


In [None]:
def plot_error_along_seq(error_along_sequence):
    """Show where errors fall along the barcode sequence.

    The left panel draws `error_along_sequence` as an image (y axis labelled
    'Barcode', x axis 'Position'); the right panel plots its mean over the
    first axis for every position.

    Args:
        error_along_sequence: 2-D array indexed as (barcode, position).

    Returns:
        The matplotlib figure holding both panels.
    """
    fig, (ax_matrix, ax_mean) = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle("Errors along sequence (edit distance 1)")

    image = ax_matrix.imshow(error_along_sequence, aspect='auto')
    ax_matrix.set_xlabel('Position')
    ax_matrix.set_ylabel('Barcode')
    fig.colorbar(image, ax=ax_matrix).set_label('Errors')

    ax_mean.plot(np.mean(error_along_sequence, axis=0))
    ax_mean.set_xlabel('Position')
    ax_mean.set_ylabel('Mean errors')

    fig.tight_layout()
    return fig
f = plot_error_along_seq(error_along_sequence)

In [None]:
# Rolonies per distinct sequence, before and after correction ...
non_corr = err_corr["bases"].value_counts()
corr = err_corr["corrected_bases"].value_counts()

# ... as flat tables with one row per sequence (`seq`) and its count (`cnt`)
df_non_corr = pd.DataFrame({"seq": non_corr.index, "cnt": non_corr.values})
df_corr = pd.DataFrame({"seq": corr.index, "cnt": corr.values})

print(df_non_corr.head())


In [None]:
# Fraction of all rolonies that belong to a sequence seen more than 50 times,
# first without and then with error correction
for table in (df_non_corr, df_corr):
    ok = table.cnt > 50
    print(table.loc[ok, "cnt"].sum() / table["cnt"].sum())

In [None]:
# Rank/abundance curve on log-log axes. The count tables are sorted by
# decreasing count with a default 0..n-1 index, so row position + 1 is the rank.
fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(1, 1, 1)
for label, table in (('non-corrected', df_non_corr), ('corrected', df_corr)):
    # ranks start at 1: the 0-based index would put the most abundant barcode
    # at log(0) = -inf and drop it from the plot
    rank = np.arange(1, len(table) + 1)
    ax.plot(np.log(rank), np.log(table["cnt"]), label=label)
plt.legend()
ax.set_xlabel('log(rank)')
ax.set_ylabel('log(count)')


In [None]:
# NOTE(review): `all_vs` is not assigned in any cell visible here — it presumably
# comes from a cell that was removed or run interactively; confirm before
# running the notebook top to bottom.
all_vs.shape

In [None]:
fig = plt.figure(figsize=(10, 3))
largest = all_vs.max()
print(largest)

# Left: full range of `all_vs` in steps of 10, log-scaled counts
ax = fig.add_subplot(1, 2, 1)
_ = ax.hist(all_vs, bins=np.arange(0, largest, 10), log=True)
ax.text(800, 1e4, f"Max: {largest:.0f}")
ax.set_ylabel('Number of rolonies before merge')

# Right: values 1..24 with one bin centred on each integer
ax = fig.add_subplot(1, 2, 2)
_ = ax.hist(all_vs, bins=np.arange(0.5, 25.5), log=False)
ax.set_xticks(np.arange(1, 25, 2))
ax.set_ylabel('Number of rolonies before merge')
fig.tight_layout()


# Look at starters

## Select example ROI

In [None]:
# Chamber and ROI used for the stitched images and masks below
chamber = 'chamber_08'
roi = 4
print(f"{data_path}/{chamber}")


In [None]:

# Stitch three registered acquisitions of the selected chamber/ROI
chamber_path = f"{data_path}/{chamber}"
ref = issp.pipeline.stitch_registered(
    chamber_path, prefix='hybridisation_round_2_1', roi=roi, channels=[2, 3]
)
barcode_round_1 = issp.pipeline.stitch_registered(
    chamber_path, prefix='barcode_round_1_1', roi=roi, channels=range(4)
)
mcherry = issp.pipeline.stitch_registered(
    chamber_path, prefix='mCherry_1', roi=roi, channels=2
)

In [None]:
# Three-channel overview: second reference channel in blue, mCherry in red and
# the maximum over the barcode round 1 channels in green.
reference_plane = ref[..., 1]
barcode_max = np.nanmax(barcode_round_1, axis=2)
st = np.dstack([reference_plane, mcherry, barcode_max])

# Per-channel display range: 1st to 99.9th percentile, ignoring NaN
vmin, vmax = np.nanpercentile(st, [1, 99.9], axis=(0, 1))
rgb = issp.vis.to_rgb(st, colors=[(0, 0, 1), (1, 0, 0), (0, 1, 0)], vmin=vmin, vmax=vmax)
plt.figure(figsize=(20, 20))
plt.imshow(rgb, interpolation='none')

## Segment rolonies

In [None]:
# Stitch the per-tile mCherry masks, then build the full-size mask with a
# mask expansion of 5
chamber_path = f"{data_path}/{chamber}"
stitched_mask = issp.pipeline.stitch_registered(
    chamber_path, prefix='mCherry_1_masks', roi=roi, projection='corrected'
)
bigmask = issp.pipeline.segment.get_big_masks(
    chamber_path, roi, stitched_mask, mask_expansion=5
)

In [None]:
# Keep only the spots of the selected chamber and ROI
import numpy as np

in_chamber = all_barcode_spots["chamber"] == chamber
in_roi = all_barcode_spots["roi"] == roi
df = all_barcode_spots[in_chamber & in_roi]


In [None]:
# helper plotting functions
def get_stack_part(stack, xlim, ylim):
    """Crop `stack` to a rectangular window.

    Rows (first axis) are selected with `ylim` and columns (second axis) with
    `xlim`. Each pair of limits may be given in either order; the smaller
    value is included and the larger one excluded.
    """
    rows = sorted(ylim)
    cols = sorted(xlim)
    window = (slice(rows[0], rows[1]), slice(cols[0], cols[1]))
    return stack[window]

def get_spot_part(df, xlim, ylim):
    """Return the rows of `df` whose 'x' and 'y' fall inside a window.

    Each pair of limits may be given in either order; the smaller value is
    included and the larger one excluded.
    """
    x_bounds = sorted(xlim)
    y_bounds = sorted(ylim)
    inside_x = (df["x"] >= x_bounds[0]) & (df["x"] < x_bounds[1])
    inside_y = (df["y"] >= y_bounds[0]) & (df["y"] < y_bounds[1])
    return df[inside_x & inside_y]

In [None]:
# Window to display. ylim is passed to imshow as (bottom, top), so with the
# larger value first the y axis increases downwards.
ylim = [16000, 8000]
xlim = [12000, 18000]
plt.figure(figsize=(20, 20))

# Background (label 0) becomes NaN so that it is not drawn
m = bigmask.astype(float)
m[m == 0] = np.nan

mask_window = get_stack_part(m, xlim, ylim)
plt.imshow(mask_window, interpolation='none', alpha=0.5, cmap='prism', extent=[*xlim, *ylim])

# Overlay the spots that fall inside the same window
sp = get_spot_part(df, xlim, ylim)
plt.scatter(sp["x"], sp["y"], alpha=0.5, s=1, color='k')


In [None]:
# Run segment_spots for the selected chamber/ROI, passing the expanded masks
# (cast to uint16) and the error-corrected barcode table
chamber_path = f'{data_path}/{chamber}'
segment_options = dict(
    mask_expansion=0,
    masks=bigmask.astype('uint16'),
    barcode_df=err_corr,
    barcode_dot_threshold=None,
    spot_score_threshold=0.1,
    hyb_score_threshold=0.8,
    load_genes=True,
    load_hyb=False,
    load_barcodes=True,
)
barcode_df, fused_df = issp.pipeline.segment.segment_spots(
    chamber_path, roi, **segment_options
)

In [None]:
# NOTE(review): in the visible cells `barcode` is only assigned as the loop
# variable of the error-along-sequence cell, so this shows the last barcode of
# `good`; possibly `barcode_df` was intended — confirm.
barcode

In [None]:
# Print the transform to the reference for mCherry tile (1, 1, 1)
chamber_path = f'{data_path}/{chamber}'
o = issp.pipeline.stitch.get_tform_to_ref(
    chamber_path, prefix='mCherry_1', tile_coors=(1, 1, 1), corrected_shifts=None
)
print(o)

In [None]:
# Build the cell table for the selected chamber/ROI from `bigmask`; no further
# expansion is requested here (mask_expansion=0).
cell_df = issp.pipeline.segment.make_cell_dataframe(data_path+ f'/{chamber}', roi, masks=bigmask, mask_expansion=0, atlas_size=10)

In [None]:
# Preview the first rows of `cell_df`
cell_df.head()

In [None]:


# Load the options of the selected chamber and pick its mCherry reference tile.
# Later cells use tile[0] as the ROI and tile[1], tile[2] to index the tile corners.
ops = issp.io.load_ops(data_path + f'/{chamber}')
tile = ops['mcherry_ref_tile']

In [None]:
# Load the corrected mCherry masks of the reference tile and display them
import numpy as np
import matplotlib.pyplot as plt

processed = issp.io.get_processed_path(data_path)
tname = "_".join(map(str, tile))
cells_folder = processed / chamber / "cells"
masks_file = cells_folder / f"mCherry_1_masks_corrected_{tname}.npy"
masks = np.load(masks_file)

# Background (label 0) becomes NaN so that it is not drawn
m = masks.astype(float)
m[m == 0] = np.nan
plt.imshow(m, interpolation='None', cmap='prism')

In [None]:
# Warp the tile masks into reference coordinates
cells_folder = processed / chamber / "cells"
masks_file = cells_folder / f"mCherry_1_masks_corrected_{tname}.npy"
masks = np.load(masks_file)
warped_mask, bad_pixels = issp.pipeline.stitch.warp_stack_to_ref(
    masks,
    f'{data_path}/{chamber}',
    prefix='mCherry_1',
    tile_coors=tile,
    interpolation=0,
)

# Background (label 0) becomes NaN so that it is not drawn
wm = warped_mask.astype(float)
wm[wm == 0] = np.nan
plt.imshow(wm[..., 0, 0], interpolation='None', cmap='prism')

In [None]:
# get tile corners
# NOTE(review): among the visible cells, `tile_origin` is only assigned in the
# cell that follows this one, so this fails on a fresh top-to-bottom run.

tile_origin

In [None]:
# Corrected spots that belong to the reference tile of the selected chamber
tile_name = '_'.join([str(t) for t in tile])
ok = (err_corr.chamber == chamber) & (err_corr['tile'] == tile_name)
bc = err_corr[ok]

# Corners of that tile; the first corner is used as the tile origin
tc = issp.pipeline.stitch.get_tile_corners(
    f'{data_path}/{chamber}', prefix=ops['reference_prefix'], roi=tile[0]
)
tc = tc[tile[1], tile[2]]
tile_origin = tc[:, 0]

# Overlay the spots, shifted by the tile origin, on the warped masks
plt.imshow(wm[..., 0, 0], interpolation='None', cmap='prism', alpha=0.5)
plt.scatter(bc.x - tile_origin[1], bc.y - tile_origin[0], s=1, color='k', alpha=0.5)

In [None]:
# NOTE(review): this call passes `data_path` without the chamber sub-folder and
# masks=None, unlike the earlier make_cell_dataframe call that uses
# data_path + f'/{chamber}' and `bigmask` — confirm this is intended.
issp.pipeline.segment.make_cell_dataframe(data_path, roi, masks=None, mask_expansion=5.0, atlas_size=10)