In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import holoviews as hv
from holoviews.operation.datashader import regrid
import hvplot.pandas
import panel as pn
import param
import pickle
import os
import pyarrow as pa
from scipy.integrate import simps
import segmentation
from matriarch_stub import RevImage
import astropy.stats
from cytoolz import partial

In [None]:
# Load the HoloViews "bokeh" plotting extension; the plots built below are rendered with it.
hv.extension("bokeh")

# Load data

In [None]:
# Directory holding the photobleaching data. Set the MOLECULE_COUNTING_DIR
# environment variable to point elsewhere instead of editing this cell.
base_filename = os.environ.get(
    "MOLECULE_COUNTING_DIR", "/n/groups/paulsson/jqs1/molecule-counting"
)
# Earlier dataset: "190304photobleaching.pickle"
filename = "190328photobleaching_flatcorr.pickle"
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open(os.path.join(base_filename, filename), "rb") as f:
    data = pickle.load(f)

In [None]:
# import pyarrow as pa
# buf = pa.serialize(data).to_buffer()
# with pa.OSFile(os.path.join(base_filename, '190328photobleaching_flatcorr.arrow')) as f:
#    f.write(buf)

In [None]:
# do p=0 correction
# M correction
# multi-position selector (multi-file?): just use a function to combine and add to data dict
# check that depends funcs aren't firing twice
# fix segmentation componentwise_normalization
# add sandwich scatterplot (when applicable)
# size filtering/size histogram?
# check integral values
# TODO: rescale values to count values??

In [None]:
def _trace_plot(traces, downsample=10, normalize=True, colors=None):
    """Overlay every ``downsample``-th trace as a curve on a log-y plot.

    Parameters
    ----------
    traces : numpy.ndarray
        2D array indexed as ``traces[trace, frame]``.
    downsample : int
        Step between plotted traces (rows 0, downsample, 2*downsample, ...).
    normalize : bool
        If true, each trace is divided by its first-frame value.
    colors : array-like, optional
        Per-trace color values, indexed like the rows of ``traces`` and drawn
        with the continuous "bmy" colormap. When omitted, a random permutation
        of row indices is used with the categorical "Category20" colormap.

    Returns
    -------
    holoviews.Contours
    """
    if colors is not None:
        cmap = "bmy"
    else:
        cmap = "Category20"
        colors = np.random.permutation(traces.shape[0])
    values = traces / traces[:, 0, np.newaxis] if normalize else traces
    n_traces, n_frames = traces.shape[0], traces.shape[1]
    curves = [
        dict(x=np.arange(n_frames), y=values[row], i=colors[row])
        for row in range(0, n_traces, downsample)
    ]
    return hv.Contours(curves, vdims=["i"]).options(
        color_index="i", cmap=cmap, logy=True
    )


def _fluctuation_plots(pbar, y, qs, nu_qs, colors=None):
    """Build the per-bin fluctuation plot and the per-bin integral plot.

    Parameters
    ----------
    pbar : pandas.DataFrame
        One row per bin; ``1 - pbar`` row ``i`` is the x axis of curve ``i``
        in the fluctuation plot.
    y : pandas.DataFrame
        Fluctuation values, row-aligned with ``pbar``.
    qs : array-like
        x values shared by every integral curve.
    nu_qs : pandas.DataFrame
        One row of integral values per bin, one column per entry of ``qs``.
    colors : array-like, optional
        One color value per row, drawn with the continuous "bmy" colormap.
        When omitted, a random permutation is used with "Category20".

    Returns
    -------
    tuple
        ``(fluctuation_plot, integral_plot)``, both ``hv.Contours``.
    """
    if colors is None:
        colors = np.random.permutation(y.shape[0])
        cmap = "Category20"
    else:
        cmap = "bmy"
    fluctuation_curves = [
        {"x": 1 - pbar.values[i], "y": y.values[i], "i": colors[i]}
        for i in range(y.shape[0])
    ]
    fluctuation_plot = hv.Contours(fluctuation_curves, vdims=["i"]).options(
        color_index="i", cmap=cmap
    )
    integral_curves = [
        {"x": qs, "y": nu_qs.values[i], "i": colors[i]} for i in range(nu_qs.shape[0])
    ]
    integral_plot = (
        hv.Contours(integral_curves, vdims=["i"])
        .redim(y="integral")
        .options(color_index="i", cmap=cmap)
    )
    return fluctuation_plot, integral_plot


def _spatial_plots(regionprops, labels, img, colors=None):
    """Lay out centroid scatter, segmentation image and label image side by side.

    Parameters
    ----------
    regionprops : pandas.DataFrame
        One row per region with ``centroid_x``/``centroid_y`` columns. It is
        not modified.
    labels : numpy.ndarray
        Label image; labels are permuted via ``segmentation.permute_labels``
        before display.
    img : numpy.ndarray
        Image shown underneath via ``RevImage``.
    colors : array-like, optional
        One color value per region, drawn with the continuous "bmy" colormap.
        When omitted, a random permutation is used with "Category20".

    Returns
    -------
    holoviews.Layout
        ``xy_plot + image_plot + labels_plot``.
    """
    if colors is None:
        colors = np.random.permutation(regionprops.shape[0])
        cmap = "Category20"
    else:
        cmap = "bmy"
    # Build the plot columns on a copy. The caller's frame may be shared data
    # (e.g. the regionprops stored in the loaded data dict), and writing into
    # it would mutate that state on every redraw.
    # Swap centroid axes to follow the RevImage convention. No flip is applied
    # here; presumably RevImage handles orientation -- confirm.
    points = regionprops.assign(
        x=regionprops["centroid_y"], y=regionprops["centroid_x"], color=colors
    )
    xy_plot = hv.Points(points, kdims=["x", "y"], vdims=["color"]).options(
        color=hv.dim("color"), cmap=cmap
    )
    image_plot = regrid(RevImage(img))
    labels_plot = regrid(
        RevImage(segmentation.permute_labels(labels)), aggregator="max"
    ).redim(z="label")
    return xy_plot + image_plot + labels_plot


class PhotobleachingViewer(param.Parameterized):
    """Interactive explorer for photobleaching traces.

    ``data`` is a nested dict indexed as ``data[filename][position]``. Each
    position entry provides ``"traces"`` (dict of measurement name -> 2D trace
    array whose row 0 is the background trace), ``"regionprops"``,
    ``"labels"`` and ``"segmentation_frame"``. Top-level keys starting with
    ``"_"`` are not offered as filenames.

    The ``watch=True`` callback ``_update_file`` reloads the arrays for the
    current selection. ``_process_traces`` recomputes the binned fluctuation
    statistics. ``view`` renders the current state.
    """

    def __init__(self, data, *args, **kwargs):
        self._data = data
        filenames = [k for k in data.keys() if not k.startswith("_")]
        self.param["filename"].objects = filenames
        self.filename = filenames[0]
        # Populate derived state explicitly. Presumably the watch=True
        # callbacks are only registered by super().__init__ -- confirm.
        self._update_file()
        self._process_traces()
        super().__init__(*args, **kwargs)

    filename = param.ObjectSelector()
    position = param.ObjectSelector()
    measurement = param.ObjectSelector()
    # Dispersion estimator used for the per-bin variance in _process_traces.
    estimator = param.ObjectSelector(
        objects=["variance", "mad", "biweight_midvariance"],
        default="biweight_midvariance",
    )
    # Colormap for the per-bin fluctuation/integral plots (see _bin_colors).
    fluctuation_colormap = param.ObjectSelector(
        objects=["random", "bin", "gain"], default="bin"
    )
    # Colormap for the trace and spatial plots (see _colors).
    traces_colormap = param.ObjectSelector(
        objects=["random", "fluorescence", "bin", "area", "gain"], default="bin"
    )
    # Plot every `downsample`-th trace in the trace plot.
    downsample = param.ObjectSelector(objects=[1000, 100, 30, 10, 1], default=30)
    # Divide each trace by its first frame before plotting.
    normalize_traces = param.Boolean(True)

    @param.depends("filename", "position", "measurement", watch=True)
    def _update_file(self):
        """Sync the position/measurement selectors and load the current arrays.

        Also reacts to "measurement" so that switching measurements reloads
        ``self._traces``.

        The arrays are stored *before* ``position``/``measurement`` are
        reassigned. Those assignments synchronously fire the watch=True
        callbacks (including ``_process_traces``), which must already see the
        new traces.
        """
        file_data = self._data[self.filename]
        # position
        positions = list(file_data.keys())
        self.param["position"].objects = positions
        position = self.position if self.position in positions else positions[0]
        # measurement
        measurements = list(file_data[position]["traces"].keys())
        self.param["measurement"].objects = measurements
        measurement = (
            self.measurement if self.measurement in measurements else measurements[0]
        )
        position_data = file_data[position]
        # skip background trace (index 0)
        self._traces = position_data["traces"][measurement][1:, :]
        self._regionprops = position_data["regionprops"]
        self._labels = position_data["labels"]
        self._img = position_data["segmentation_frame"]
        # Assign last: derived arrays are already consistent when watchers fire.
        self.position = position
        self.measurement = measurement

    @param.depends("filename", "position", "measurement", "estimator", watch=True)
    def _process_traces(self):
        """Bin traces by initial fluorescence and compute per-bin statistics.

        Sets the following attributes:

        - ``self._bins``: the bin of every trace, including traces later
          filtered out.
        - ``self._pbar``: the per-bin mean of traces normalized to frame 0.
        - ``self._y``: the per-bin dispersion divided by the bin's mean
          frame-0 value.
        - ``self._qs``: the q values.
        - ``self._nu_qs``: one integral per (bin, q).

        Raises
        ------
        ValueError
            If ``self.estimator`` is not a recognized estimator name.
        """
        thresh = 30  # minimum number of traces per bin; TODO: parameterize
        num_qs = 20
        initial_fluorescence_threshold = 3  # in sigma units
        traces = self._traces
        traces0 = traces[:, 0]
        # Bin edges FROM: numpy.histogram ("auto" rule). numpy's first bin is
        # closed on the left, but pd.cut defaults to half-open (a, b] intervals.
        # That would leave the dimmest trace (== bin_edges[0]) unbinned (NaN).
        # include_lowest=True keeps it.
        bin_edges = np.histogram_bin_edges(traces0, bins="auto")
        bins = pd.Series(pd.cut(traces0, bin_edges, include_lowest=True), name="bin")
        traces_df = pd.DataFrame(traces)
        # observed=False: count every bin (pandas' historical default, made
        # explicit so the result doesn't change with pandas' default).
        bin_counts = bins.groupby(bins, observed=False).size()
        bin_counts.name = "bin_count"
        trace_info = pd.merge(
            bins, bin_counts, left_on="bin", right_index=True, how="left"
        )
        # keep traces from well-populated bins...
        mask = trace_info["bin_count"] > thresh
        if initial_fluorescence_threshold:
            # ...whose initial fluorescence is within N robust sigmas of the center
            traces0_loc = astropy.stats.biweight.biweight_location(traces0)
            traces0_scale = astropy.stats.biweight.biweight_scale(
                traces0, modify_sample_size=True
            )
            lower = traces0_loc - initial_fluorescence_threshold * traces0_scale
            upper = traces0_loc + initial_fluorescence_threshold * traces0_scale
            mask &= (traces0 >= lower) & (traces0 <= upper)
        traces_df = traces_df[mask]
        trace_info = trace_info[mask]
        # observed=True is important in these groupbys, otherwise plotting colormaps get messed up by all the NaNs
        traces_by_bin = traces_df.groupby(trace_info["bin"], observed=True)
        pbar = (
            traces_df.div(traces_df.iloc[:, 0], axis=0)
            .groupby(trace_info["bin"], observed=True)
            .mean()
        )
        # TODO: nan handling
        mu = traces_by_bin.mean()
        if self.estimator == "variance":
            sigma2 = traces_by_bin.var(ddof=0)
        elif self.estimator == "mad":
            # mad_std is exposed by astropy.stats (not astropy.stats.biweight).
            # NaN-tolerant alternative: partial(astropy.stats.mad_std, ignore_nan=True)
            sigma2 = traces_by_bin.agg(astropy.stats.mad_std) ** 2
        elif self.estimator == "biweight_midvariance":
            sigma2 = traces_by_bin.agg(
                partial(
                    astropy.stats.biweight.biweight_midvariance, modify_sample_size=True
                )
            )
        else:
            raise ValueError(f"unknown estimator: {self.estimator!r}")
        y = sigma2.div(mu[0], axis="rows")
        qs = np.linspace(0.1, 1, num_qs)  # TODO: make start point adjustable?
        # Slow reference version (pandas indexing); note it also fillna(0)s NaNs,
        # which the fast loop below does not:
        # nu_qs = pd.DataFrame(np.array([-1/(1/2*q**2 - 1/3*q**3)*simps(y[pbar >= 1-q].fillna(0), pbar, axis=1) for q in qs]).T, index=y.index, columns=qs)
        # Fast version: visit q from largest to smallest so the zeroed region
        # (pbar < 1 - q) only grows, letting y_ary be masked cumulatively.
        nu_qs = []
        y_ary = y.values.copy()  # we need an unmodified copy of y to plot
        pbar_ary = pbar.values
        for q in qs[::-1]:
            y_ary[pbar_ary < 1 - q] = 0
            nu_q = -1 / (1 / 2 * q**2 - 1 / 3 * q**3) * simps(y_ary, pbar_ary, axis=1)
            nu_qs.append(nu_q)
        # restore ascending-q order: rows -> bins, columns -> qs
        nu_qs = pd.DataFrame(np.array(nu_qs)[::-1, :].T, index=y.index, columns=qs)
        self._pbar = pbar
        self._y = y
        self._qs = qs
        self._nu_qs = nu_qs
        self._bins = bins

    def _colors(self, colormap):
        """Return one color value per trace for ``colormap``.

        Returns None for "random" (the plot helpers then pick random colors).

        Raises
        ------
        ValueError
            If ``colormap`` is not recognized.
        """
        if colormap == "random":
            colors = None
        elif colormap == "fluorescence":
            colors = np.log(self._traces[:, 0])
        elif colormap == "bin":
            colors = self._bins.values.codes
        elif colormap == "area":
            colors = self._regionprops["area"].values
        elif colormap == "gain":
            # TODO: make cleaner
            # per-bin gain (largest q) broadcast back to every trace via its bin;
            # traces in filtered-out bins get NaN
            gain = self._nu_qs.iloc[:, -1]
            colors = pd.merge(
                self._bins, gain, left_on="bin", right_index=True, how="left"
            )[gain.name].values
        else:
            raise ValueError(f"unknown colormap: {colormap!r}")
        return colors

    def _bin_colors(self, colormap):
        """Return one color value per retained bin for ``colormap``.

        Returns None for "random".

        Raises
        ------
        ValueError
            If ``colormap`` is not recognized.
        """
        if colormap == "random":
            colors = None
        elif colormap == "bin":
            # categorical codes, consistent with _colors("bin")
            colors = self._nu_qs.index.values.codes
        elif colormap == "gain":
            colors = self._nu_qs.iloc[:, -1].values
        else:
            raise ValueError(f"unknown colormap: {colormap!r}")
        return colors

    @param.depends(
        "filename",
        "position",
        "measurement",
        "fluctuation_colormap",
        "traces_colormap",
        "downsample",
        "estimator",
        "normalize_traces",
    )
    def view(self):
        """Render traces, fluctuation/integral plots and spatial plots in a grid."""
        gs = pn.GridSpec(sizing_mode="stretch_both")
        trace_colors = self._colors(self.traces_colormap)
        fluctuation_colors = self._bin_colors(self.fluctuation_colormap)
        gs[0:1, 0:2] = _trace_plot(
            self._traces,
            downsample=self.downsample,
            normalize=self.normalize_traces,
            colors=trace_colors,
        )
        fluct_plots = _fluctuation_plots(
            self._pbar, self._y, self._qs, self._nu_qs, colors=fluctuation_colors
        )
        gs[1:2, 0:1] = fluct_plots[0]
        gs[1:2, 1:2] = fluct_plots[1]
        gs[2:3, 0:2] = _spatial_plots(
            self._regionprops, self._labels, self._img, colors=trace_colors
        )
        return gs


# Build the dashboard: parameter widgets stacked above the reactive plot grid.
# The Column is left as the cell's last expression so the notebook displays it.
viewer = PhotobleachingViewer(data=data, name="Photobleaching")
dashboard = pn.Column(viewer.param, viewer.view)
dashboard