In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import holoviews as hv
from holoviews.operation.datashader import regrid
import hvplot.pandas
import panel as pn
import param
import pickle
import os
from scipy.integrate import simps
import segmentation
from matriarch_stub import RevImage

In [None]:
hv.extension("bokeh")

# Load data

In [None]:
# select file
# checkboxes: positions
# select: var estimator
# select: mean/median/p.095
# double-slider: size filtering (on histogram?)
# slider: bin size
# slider: threshold
# plot: downsampled log traces (checkbox: normalized to t0)
# plot: var as a function of p
# plot: integral as a function of p
# plot: spatial (select: color dimension: initial intensity, etc.)
# plot: counts per bin
# plot: regridded image+seg mask?

In [None]:
# Directory that holds the pickled datasets (despite the name, this is a
# directory, not a file name).
# NOTE(review): absolute, cluster-specific path; adjust when running elsewhere.
base_filename = "/n/groups/paulsson/jqs1/molecule-counting"
filename = "190304photobleaching.pickle"
# filename = '190328photobleaching_flatcorr.pickle'
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load pickles from a trusted source.
# `data` is indexed below by .nd2 path, and each entry is unpacked as
# (traces, regionprops, labels, img).
with open(os.path.join(base_filename, filename), "rb") as f:
    data = pickle.load(f)

In [None]:
# import pyarrow as pa
# buf = pa.serialize(data).to_buffer()
# with pa.OSFile(os.path.join(base_filename, '190328photobleaching_flatcorr.arrow'), 'wb') as f:
#    f.write(buf)

In [None]:
# fluctuation scatter plot/merge fluct+int plots

# colors: initial fluorescence, initial fluorescence bin, area, fluorescence bin gain
# plot update dependencies/cache computations

# fast simpson integration/benchmark
# biweight_midvariance

# check integral values

In [None]:
def _trace_plot(traces, downsample=10, normalize=True, colors=None):
    """Draw a subset of traces as log-scaled, colour-coded contours.

    Only the first ``traces.shape[0] // downsample`` rows of ``traces`` are
    plotted. When ``normalize`` is true every row is divided by its first
    sample. If ``colors`` is omitted each row gets a random permutation index
    and the categorical "Category20" palette; otherwise the supplied
    per-row values are mapped through "inferno".

    Returns the resulting ``hv.Contours`` element.
    """
    num_traces = traces.shape[0]
    if colors is not None:
        palette = "inferno"
    else:
        colors = np.random.permutation(num_traces)
        palette = "Category20"

    # Divide each row by its own first sample when normalisation is requested.
    values = traces / traces[:, 0, np.newaxis] if normalize else traces

    curves = []
    for row in range(num_traces // downsample):
        curves.append(
            {"x": np.arange(traces.shape[1]), "y": values[row], "i": colors[row]}
        )

    return hv.Contours(curves, vdims=["i"]).options(
        color_index="i", cmap=palette, logy=True
    )


def _fluctuation_plots(pbar, y, qs, nu_qs, colors=None):
    """Build the fluctuation and integral contour plots.

    Parameters
    ----------
    pbar, y : objects exposing ``.values`` (2-D) and ``.shape``
        Row ``i`` of ``1 - pbar.values`` is plotted against row ``i`` of
        ``y.values``.
    qs : sequence
        Shared x-values for the integral curves.
    nu_qs : object exposing ``.values`` (2-D) and ``.shape``
        Row ``i`` is plotted against ``qs``.
    colors : indexable, optional
        Per-row colour values (mapped through "inferno"). When omitted, a
        random permutation of row indices is used with "Category20".

    Returns
    -------
    tuple
        ``(fluctuation_plot, integral_plot)``, both ``hv.Contours``.
    """
    if colors is None:
        colors = np.random.permutation(y.shape[0])
        cmap = "Category20"
    else:
        cmap = "inferno"

    fluctuation_curves = [
        {"x": 1 - pbar.values[i], "y": y.values[i], "i": colors[i]}
        for i in range(y.shape[0])
    ]
    fluctuation_plot = hv.Contours(fluctuation_curves, vdims=["i"]).options(
        color_index="i", cmap=cmap
    )

    integral_curves = [
        {"x": qs, "y": nu_qs.values[i], "i": colors[i]} for i in range(nu_qs.shape[0])
    ]
    # Rename the y dimension so the two plots do not share a y axis label.
    integral_plot = (
        hv.Contours(integral_curves, vdims=["i"])
        .redim(y="integral")
        .options(color_index="i", cmap=cmap)
    )
    return fluctuation_plot, integral_plot


def _spatial_plots(regionprops, labels, img, colors=None):
    """Build the spatial overview: centroid scatter, image and label mask.

    Parameters
    ----------
    regionprops : pandas.DataFrame
        Must contain ``centroid_x`` and ``centroid_y`` columns. It is NOT
        modified; the plotting columns are added to a copy.
    labels : array-like
        Label image passed to ``segmentation.permute_labels`` before regridding.
    img : array-like
        Image shown via ``RevImage`` and ``regrid``.
    colors : array-like, optional
        Per-row colour values (mapped through "inferno"). When omitted, a
        random permutation of row indices is used with "Category20".

    Returns
    -------
    The holoviews layout ``xy_plot + image_plot + labels_plot``.
    """
    if colors is None:
        colors = np.random.permutation(regionprops.shape[0])
        cmap = "Category20"
    else:
        cmap = "inferno"
    # Use assign() so the caller's DataFrame (which lives in the shared `data`
    # dict) does not accumulate x/y/color columns on every render.
    points = regionprops.assign(
        x=regionprops["centroid_x"],
        y=regionprops["centroid_y"],
        color=colors,
    )
    xy_plot = hv.Points(points, kdims=["x", "y"], vdims=["color"]).options(
        color=hv.dim("color"), cmap=cmap
    )
    image_plot = regrid(RevImage(img))
    # "max" aggregation is requested for the label image when it is regridded.
    labels_plot = regrid(
        RevImage(segmentation.permute_labels(labels)), aggregator="max"
    ).redim(z="label")
    spatial_plots = xy_plot + image_plot + labels_plot
    return spatial_plots


class PhotobleachingViewer(param.Parameterized):
    """Interactive viewer for photobleaching traces.

    ``data`` maps a filename to a ``(traces, regionprops, labels, img)``
    tuple, where ``traces`` is itself a mapping from measurement name to a
    2-D array (indexed as ``traces[:, 0]`` for the first sample of every row).

    The ``filename`` and ``measurement`` selectors drive two watchers that
    cache the per-bin statistics (``_pbar``, ``_y``, ``_qs``, ``_nu_qs``);
    ``view`` assembles the plots from those caches.
    """

    def __init__(self, data, *args, **kwargs):
        self._data = data
        filenames = list(data.keys())
        # NOTE(review): the parameter assignments below run before
        # super().__init__(); confirm the installed param version supports
        # setting parameters on a not-yet-initialised instance.
        self.param["filename"].objects = filenames
        # self.param['filename'].default = filenames[0]
        self.filename = filenames[0]
        self.measurement = list(data[self.filename][0].keys())[0]
        # Populate the caches explicitly because the calls happen before
        # super().__init__() runs.
        self._update_measurements()
        self._process_traces()
        super().__init__(*args, **kwargs)

    filename = param.ObjectSelector()
    measurement = param.ObjectSelector()
    estimator = param.ObjectSelector(
        objects=["variance", "mad", "biweight_midvariance"],
        default="biweight_midvariance",
    )
    fluctuation_colormap = param.ObjectSelector(
        objects=["random", "bin", "gain"], default="bin"
    )
    xy_colormap = param.ObjectSelector(
        objects=["random", "fluorescence", "bin", "area", "gain"], default="bin"
    )
    traces_colormap = param.ObjectSelector(
        objects=["random", "fluorescence", "bin", "area", "gain"], default="bin"
    )
    downsample = param.ObjectSelector(objects=[1000, 100, 10, 1], default=100)
    normalize_traces = param.Boolean(True)

    @param.depends("filename", watch=True)
    def _update_measurements(self):
        """Refresh the measurement choices and per-file caches for ``filename``.

        Keeps the current measurement if the new file also provides it,
        otherwise falls back to the file's first measurement.
        """
        old_measurement = self.measurement
        measurements = list(self._data[self.filename][0].keys())
        self.param["measurement"].objects = measurements
        if old_measurement not in measurements:
            self.measurement = measurements[0]
        traces, regionprops, labels, img = self._data[self.filename]
        self._traces = traces[self.measurement]
        self._regionprops = regionprops
        self._labels = labels
        self._img = img

    @param.depends("filename", "measurement", watch=True)
    def _process_traces(self):
        """Recompute the per-bin statistics for the selected file/measurement.

        Rows are binned by their first sample (``numpy`` "auto" bin edges);
        bins with at most ``thresh`` rows are dropped. Stores:

        * ``_pbar``  - per-bin mean of the rows divided by their first sample
        * ``_y``     - per-bin variance (ddof=0) divided by the bin's mean
          first sample
        * ``_qs``    - the ``num_qs`` values of q from 0.1 to 1
        * ``_nu_qs`` - for each q, ``-1 / (q**2 / 2 - q**3 / 3)`` times the
          Simpson integral of ``_y`` over ``_pbar``, with points where
          ``_pbar < 1 - q`` zeroed out
        """
        thresh = 30  # TODO: parameterize
        num_qs = 10
        measurements = self._data[self.filename][0]
        if self.measurement not in measurements:
            # The file changed and the measurement has not been reconciled
            # yet; _update_measurements assigns a valid measurement, which
            # re-triggers this watcher.
            return
        # Read the traces for the *current* measurement directly. The cached
        # self._traces is only refreshed on filename changes, so relying on it
        # would reprocess stale data when only `measurement` changes.
        traces = measurements[self.measurement]
        self._traces = traces
        # bins = np.arange(data[0].min()-1,data[0].max()+1,50)
        # FROM: numpy.histogram
        traces0 = traces[:, 0]
        bin_edges = np.histogram_bin_edges(traces0, bins="auto")
        bins = pd.cut(traces0, bin_edges)
        traces_df = pd.DataFrame(traces)
        bin_count = pd.Series(bins).groupby(bins).size()
        bin_count.name = "bin_count"
        # p.join(bin_count, on='bins')
        pbar = (
            traces_df.div(traces_df.iloc[:, 0], axis=0)
            .groupby(bins)
            .mean()[bin_count > thresh]
        )
        mu = traces_df.groupby(bins).mean()[bin_count > thresh]
        sigma2 = traces_df.groupby(bins).var(ddof=0)[bin_count > thresh]
        y = sigma2.div(mu[0], axis="rows")
        qs = np.linspace(0.1, 1, num_qs)  # TODO: make start point adjustable?
        nu_qs = pd.DataFrame(
            np.array(
                [
                    -1
                    / (1 / 2 * q**2 - 1 / 3 * q**3)
                    * simps(y[pbar >= 1 - q].fillna(0), pbar, axis=1)
                    for q in qs
                ]
            ).T,
            index=y.index,
            columns=qs,
        )
        self._pbar = pbar
        self._y = y
        self._qs = qs
        self._nu_qs = nu_qs

    @param.depends(
        "filename",
        "measurement",
        "fluctuation_colormap",
        "xy_colormap",
        "traces_colormap",
        "downsample",
        "estimator",
        "normalize_traces",
    )
    def view(self):
        """Assemble the grid of trace, fluctuation/integral and spatial plots.

        Raises ``ValueError`` if ``traces_colormap`` holds an unknown option.
        """
        traces, regionprops, labels, img = self._data[self.filename]
        traces = traces[self.measurement]
        gs = pn.GridSpec(sizing_mode="stretch_both")
        if self.traces_colormap == "random":
            trace_colors = None
        elif self.traces_colormap == "fluorescence":
            trace_colors = np.log(traces[:, 0])
        elif self.traces_colormap == "bin":
            trace_colors = None  # TODO
        elif self.traces_colormap == "area":
            # This was previously bound to `colors`, leaving `trace_colors`
            # undefined for this option. Convert to a plain array because
            # _trace_plot indexes colours by position, not by label.
            trace_colors = np.asarray(regionprops["area"])
        elif self.traces_colormap == "gain":
            trace_colors = None  # TODO
        else:
            raise ValueError
        # TODO: fluctuation_colormap / xy_colormap are not wired up yet.
        fluctuation_colors = None
        spatial_colors = None
        gs[0:1, 0:2] = _trace_plot(
            traces,
            downsample=self.downsample,
            normalize=self.normalize_traces,
            colors=trace_colors,
        )
        fluct_plots = _fluctuation_plots(
            self._pbar, self._y, self._qs, self._nu_qs, colors=fluctuation_colors
        )
        gs[1:2, 0:1] = fluct_plots[0]
        gs[1:2, 1:2] = fluct_plots[1]
        gs[2:3, 0:2] = _spatial_plots(regionprops, labels, img, colors=spatial_colors)
        return gs


# Build the viewer on the loaded dataset and display its parameters together
# with the output of view() in a single column.
viewer = PhotobleachingViewer(data, name="Photobleaching")
pn.Column(viewer.param, viewer.view)

In [None]:
traces, regionprops, _, _ = data[
    "/n/scratch2/jqs1/fidelity/190301/jqs_photobleach_100ms_de32_mkate2_0002.nd2"
]
traces = traces["mean"]

In [None]:
# Manual, single-file version of the statistics computed in
# PhotobleachingViewer._process_traces (same statements, run at top level).
# bins = np.arange(data[0].min()-1,data[0].max()+1,50)
# FROM: numpy.histogram
# Bin every trace by its first sample.
traces0 = traces[:, 0]
bin_edges = np.histogram_bin_edges(traces0, bins="auto")
bins = pd.cut(traces0, bin_edges)
# traces_normed = traces / traces0[:,np.newaxis]
traces_df = pd.DataFrame(traces)
# Number of traces per bin; used below to drop sparsely populated bins.
bin_count = pd.Series(bins).groupby(bins).size()
bin_count.name = "bin_count"
# p.join(bin_count, on='bins')
thresh = 30
# Per-bin mean of each trace divided by its own first sample.
pbar = (
    traces_df.div(traces_df.iloc[:, 0], axis=0).groupby(bins).mean()[bin_count > thresh]
)
# Per-bin mean and population variance (ddof=0) of the raw traces.
mu = traces_df.groupby(bins).mean()[bin_count > thresh]
sigma2 = traces_df.groupby(bins).var(ddof=0)[bin_count > thresh]

In [None]:
# Integral estimate for a single q.
q = 0.5
# Prefactor applied to the integral for this q.
cq = -1 / (1 / 2 * q**2 - 1 / 3 * q**3)
# Per-bin variance divided by the bin's mean first sample.
y = sigma2.div(mu[0], axis="rows")
# Keep only points with pbar >= 1 - q; masked-out points become 0 via fillna
# before integrating along each row.
q_mask = pbar >= 1 - q
nus = pd.Series(cq * simps(y[q_mask].fillna(0), pbar, axis=1), index=y.index, name="nu")

In [None]:
# Same integral as the single-q cell, evaluated for ten q values from 0.1 to 1.
# The result has one row per bin (index of y) and one column per q.
qs = np.linspace(0.1, 1, 10)
nu_qs = pd.DataFrame(
    np.array(
        [
            -1
            / (1 / 2 * q**2 - 1 / 3 * q**3)
            * simps(y[pbar >= 1 - q].fillna(0), pbar, axis=1)
            for q in qs
        ]
    ).T,
    index=y.index,
    columns=qs,
)

In [None]:
nus = nu_qs.iloc[:, -1]

In [None]:
b = pd.Series(bins)

In [None]:
b.name = "bin"

In [None]:
# `b` (named "bin") and `nus` share no column name, so a bare pd.merge(b, nus)
# has no key to join on. Join each trace's bin to the per-bin value via the
# index of `nus` instead (inner join: traces in dropped bins are excluded).
pd.merge(b, nus, left_on="bin", right_index=True)

In [None]:
pd.merge(pd.Series(bins, name="bin"), nus, left_on="bin", right_index=True, how="outer")

In [None]:
nu_qs.values[0]

In [None]:
hv.Points(regionprops, kdims=["centroid_x", "centroid_y"], vdims=["area"]).options(
    color=hv.dim("area")
)

In [None]:
regionprops["centroid_x"]

In [None]:
# qs = np.arange(0, 1, 10)
q = 0.5
cq = -1 / (1 / 2 * q**2 - 1 / 3 * q**3)
q_mask = pbar >= 1 - q
# nu_qs = pd.Series(cq*simps(y[q_mask].fillna(0),pbar,axis=1), index=y.index, name='nu')
idxs = np.random.permutation(y.shape[0])
curves = [{"x": qs, "y": nu_qs.values[i], "i": idxs[i]} for i in range(nu_qs.shape[0])]
integral_plot = hv.Contours(curves, vdims=["i"]).options(
    color_index="i", cmap="Category20"
)

In [None]:
integral_plot

In [None]:
integral_plot

In [None]:
# plt.scatter(1-pbar.loc[name].values,y.loc[name].values, color = cmap(c),label = name)

In [None]:
idxs = np.random.permutation(y.shape[0])
curves = [
    {"x": 1 - pbar.values[i], "y": y.values[i], "i": idxs[i]} for i in range(y.shape[0])
]
plot = hv.Contours(curves, vdims=["i"]).options(color_index="i", cmap="Category20")
plot

In [None]:
y.fillna(0)

In [None]:
def fluct_plot(pbar, mu, sigma2, thresh):
    """Scatter the normalised variance against ``1 - pbar`` for every bin.

    Parameters
    ----------
    pbar, mu, sigma2 : pandas.DataFrame
        Per-bin statistics whose index is named "bin" and holds intervals
        (``.left`` of the first and last index entries sets the colour scale).
    thresh : number
        Only used for the printed bin-count table (bins with more than
        ``thresh`` rows).

    Notes
    -----
    Reads the notebook globals ``df`` (frame with a "bin" column) for the
    printed table and ``nu_df`` for the mean shown in the title.
    """
    hist_df = df.groupby(df["bin"]).size()
    hist_df = hist_df[hist_df.values > thresh]
    print(hist_df)

    plt.figure(figsize=(12, 8))
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
    # Matplotlib 3.9; it also avoids needing `cm` in scope.
    cmap = plt.get_cmap("coolwarm")

    y = sigma2.div(mu[0].values, axis="rows")
    imax = pbar.index[-1].left
    imin = pbar.index[0].left
    span = imax - imin
    q = 1 / 2
    cq = -1 / (1 / 2 * q**2 - 1 / 3 * q**3)
    # plt.vlines(q,0,3)

    for name, _ in pbar.groupby("bin"):
        # Position of this bin on the colour scale; a single bin maps to 0.
        c = (name.left - imin) / span if span else 0.0
        plt.scatter(
            1 - pbar.loc[name].values, y.loc[name].values, color=cmap(c), label=name
        )
    # Build the legend once, after all bins have been drawn.
    plt.legend()

    plt.title(
        r"$\nu = "
        + str(np.round(-cq))
        + r"\cdot \int\frac{\hat{\sigma}^2}{f_{max}}dp$ ="
        + str(np.round(nu_df.mean(), 2)),
        fontsize=20,
        pad=20,
    )
    plt.xlabel(r"$(1-\hat{p})$", fontsize=20)
    plt.ylabel(r"$\frac{\hat{\sigma}^2}{f_{max}}$", fontsize=20)


# NOTE(review): in the visible notebook, get_stats, nu_int and df are first
# defined in a later cell, so this cell fails on a fresh top-to-bottom run.
# Minimum bin population used by get_stats and fluct_plot.
thresh = 40

pbar, mu, sigma2 = get_stats(df, thresh)


# nu_df is read as a global by fluct_plot for the figure title.
nu_df = nu_int(pbar, mu, sigma2)
print(nu_df)

fluct_plot(pbar, mu, sigma2, thresh)

In [None]:
# NOTE(review): `d` is not defined anywhere in the visible notebook (the loaded
# dataset is bound to `data`); confirm which dictionary this is meant to be.
measurements, regionprops, labels, img = d[
    "/n/scratch2/jqs1/fidelity/190311/190311_mGFPmut2_100ms_laser100pct_006.nd2"
]

In [None]:
# regionprops.reset_index(inplace=True)
regionprops.head()

In [None]:
%%output size=150
#%%opts Image {+axiswise}
hv.Image(img / img.max()).options(cmap="gray") + hv.Image(labels != 0).options(
    cmap="blues"
)

# Plotting

In [None]:
list(measurements.keys())

In [None]:
# Concatenate the "mean" traces and regionprops of files ..._000 to ..._006
# into one table with an intensity-bin column.
# NOTE(review): `d` is not defined in the visible notebook; confirm its source.
traces = []
rp_list = []
# NOTE(review): this empty frame is overwritten by pd.concat below.
rp_df = pd.DataFrame()
for i in range(7):
    print(i)
    measurements, regionprops, labels, img = d[
        "/n/scratch2/jqs1/fidelity/190311/190311_mGFPmut2_100ms_laser100pct_00"
        + str(i)
        + ".nd2"
    ]
    # The first row of the "mean" array is dropped here while regionprops is
    # appended unsliced.
    # NOTE(review): confirm that rows still line up for the axis=1 concat below.
    traces.append(measurements["mean"][1:])
    rp_list.append(regionprops)
    print(measurements["mean"][1:].shape)
traces = np.concatenate(traces)
rp_df = pd.concat(rp_list, sort=False)
rp_df.reset_index(inplace=True)
print(traces.shape)
# NOTE: this rebinds `data` (previously the loaded pickle) to a DataFrame.
data = pd.DataFrame(traces)  # + np.random.normal(0,1,Gsamp.T.shape))

# Fixed-width (50) bins over the first sample of every trace.
bins = np.arange(data[0].min() - 1, data[0].max() + 1, 50)
data["bin"] = pd.cut(data[0], bins=bins)
data = pd.concat([data, rp_df], axis=1, sort=False)
# rp_df['initial_intensity'] = df[0]
# Number of time points; get_stats reads this global to select trace columns.
t_end = traces.shape[1]

In [None]:
%%output size=250
# traces = measurements['mean']
idxs = np.random.permutation(traces.shape[0])
downsample = (
    10  # set to 1 to show all traces (instead of 10%); this will make your browser slow
)
curves = [
    {"x": np.arange(traces.shape[1]), "y": traces[i], "i": idxs[i]}
    for i in range(traces.shape[0] // downsample)
]
hv.Contours(curves, vdims=["i"]).options(color_index="i", cmap="Category20", logy=True)

In [None]:
%matplotlib inline
import pylab as plt
import pandas as pd
import numpy as np
from matplotlib import cm
from scipy.integrate import simps


def filter_df(df, prop_dict):
    """Return the rows of ``df`` whose properties lie strictly inside bounds.

    ``prop_dict`` maps a column label to a ``[lower, upper]`` pair; a row is
    kept only if ``lower < value < upper`` holds for every listed column.
    The input frame is left untouched.
    """
    kept = df.copy()
    for column, bounds in prop_dict.items():
        lower, upper = bounds[0], bounds[1]
        values = kept[column]
        inside = (values > lower) & (values < upper)
        kept = kept[inside]
    return kept


def get_stats(df, thresh=30):
    """Compute per-bin statistics of the trace columns of ``df``.

    The first ``t_end`` columns of ``df`` (``t_end`` is a notebook global) are
    treated as the traces and ``df["bin"]`` as the grouping key. Bins with at
    most ``thresh`` rows are dropped.

    Returns
    -------
    pbar : per-bin mean of each row divided by its value in column ``0``
    mu : per-bin mean of the raw traces
    sigma2 : per-bin variance (ddof=0) of the raw traces
    """
    bin_labels = df["bin"]
    trace_cols = df.iloc[:, :t_end]

    # Divide every row by its own entry under column label 0.
    p = trace_cols.apply(lambda row: row / row[0], axis=1)
    p["bin"] = bin_labels

    # Boolean selector, indexed by bin, of bins with more than `thresh` rows.
    populated = df.groupby(bin_labels).size() > thresh

    pbar = p.groupby(p["bin"]).mean()[populated]
    mu = trace_cols.groupby(bin_labels).mean()[populated]
    sigma2 = trace_cols.groupby(bin_labels).var(ddof=0)[populated]
    return pbar, mu, sigma2


def nu_int(pbar, mu, sigma, q=1):
    """Integrate the normalised variance over ``pbar`` for every bin.

    Parameters
    ----------
    pbar : pandas.DataFrame
        Per-bin values used as integration abscissa; index must be named "bin".
    mu : pandas.DataFrame
        Per-bin means; column ``0`` is used to normalise ``sigma``.
    sigma : pandas.DataFrame
        Per-bin variances, same index as ``pbar``. (The argument itself is
        used; the function no longer reads a global ``sigma2``.)
    q : float, default 1
        Only points with ``pbar > 1 - q`` enter the integral.

    Returns
    -------
    pandas.Series
        For each bin, ``-1 / (q**2 / 2 - q**3 / 3)`` times the Simpson
        integral of ``sigma / mu[0]`` over ``pbar``.
    """
    cq = -1 / (1 / 2 * q**2 - 1 / 3 * q**3)
    y = sigma.div(mu[0].values, axis="rows")
    nu_dict = {}
    for name, _ in pbar.groupby("bin"):
        p_row = pbar.loc[name].values
        keep = p_row > 1 - q
        nu_dict[name] = cq * simps(y.loc[name].values[keep], p_row[keep])

    return pd.Series(nu_dict)


def fluct_plot(pbar, mu, sigma2, thresh):
    """Scatter the normalised variance against ``1 - pbar`` for every bin.

    Parameters
    ----------
    pbar, mu, sigma2 : pandas.DataFrame
        Per-bin statistics whose index is named "bin" and holds intervals
        (``.left`` of the first and last index entries sets the colour scale).
    thresh : number
        Only used for the printed bin-count table (bins with more than
        ``thresh`` rows).

    Notes
    -----
    Reads the notebook globals ``df`` (frame with a "bin" column) for the
    printed table and ``nu_df`` for the mean shown in the title.
    """
    hist_df = df.groupby(df["bin"]).size()
    hist_df = hist_df[hist_df.values > thresh]
    print(hist_df)

    plt.figure(figsize=(12, 8))

    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
    # Matplotlib 3.9.
    cmap = plt.get_cmap("coolwarm")

    y = sigma2.div(mu[0].values, axis="rows")
    imax = pbar.index[-1].left
    imin = pbar.index[0].left
    span = imax - imin
    q = 1 / 2
    cq = -1 / (1 / 2 * q**2 - 1 / 3 * q**3)
    # plt.vlines(q,0,3)

    for name, _ in pbar.groupby("bin"):
        # Position of this bin on the colour scale; a single bin maps to 0.
        c = (name.left - imin) / span if span else 0.0
        plt.scatter(
            1 - pbar.loc[name].values, y.loc[name].values, color=cmap(c), label=name
        )
    # Build the legend once, after all bins have been drawn.
    plt.legend()

    plt.title(
        r"$\nu = "
        + str(np.round(-cq))
        + r"\cdot \int\frac{\hat{\sigma}^2}{f_{max}}dp$ ="
        + str(np.round(nu_df.mean(), 2)),
        fontsize=20,
        pad=20,
    )
    plt.xlabel(r"$(1-\hat{p})$", fontsize=20)
    plt.ylabel(r"$\frac{\hat{\sigma}^2}{f_{max}}$", fontsize=20)


# Minimum bin population used by get_stats and fluct_plot.
thresh = 40

# Exclusive (lower, upper) bounds per column, applied by filter_df.
# Key 0 is the column holding the first sample of every trace.
prop_dict = {
    "centroid_x": [300, 1000],
    "centroid_y": [750, 1500],
    "area": [30, 150],
    0: [2000, 10000],
}
# prop_dict = {'area': [30,150],
#              0: [2000, 10000]}

# `df` is also read as a global inside fluct_plot.
df = filter_df(data, prop_dict)
pbar, mu, sigma2 = get_stats(df, thresh)


# `nu_df` is read as a global inside fluct_plot (figure title).
nu_df = nu_int(pbar, mu, sigma2)
print(nu_df)

fluct_plot(pbar, mu, sigma2, thresh)

In [None]:
%%output size=250
# traces = measurements['mean']
idxs = np.random.permutation(traces.shape[0])
downsample = (
    10  # set to 1 to show all traces (instead of 10%); this will make your browser slow
)
curves = [
    {"x": np.arange(traces.shape[1]), "y": df[i], "i": idxs[i]}
    for i in range(traces.shape[0] // downsample)
]
hv.Contours(curves, vdims=["i"]).options(color_index="i", cmap="Category20", logy=True)

In [None]:
I0 = traces[:, 0]
plt.hist(I0, bins=30)
print(np.mean(I0), np.var(I0))

In [None]:
# Scatter the centroids, coloured by intensity bin. Bins with more than
# `thresh` rows get a colour from the "coolwarm" map and a legend label taken
# from nu_df; all other rows are drawn in faint black.
# NOTE(review): Axes3D is only referenced by the commented-out 3-D axes line.
from mpl_toolkits.mplot3d import Axes3D

# NOTE(review): matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
# plt.get_cmap is the drop-in replacement.
cmap = cm.get_cmap("coolwarm")
fig = plt.figure(figsize=(12, 8))
# ax = fig.add_subplot(111, projection='3d')
# Colour scale runs from the left edge of the first kept bin to the last.
imax = pbar.index[-1].left
imin = pbar.index[0].left
for name, group in df.groupby("bin"):
    if group.shape[0] > thresh:
        c = (name.left - imin) / (imax - imin)
        #         c = (nu_dict[name] - min(nu_int))/(max(nu_int) - min(nu_int))
        # NOTE(review): z is not used by the 2-D scatter below.
        z = name.left * np.ones(group.centroid_x.shape)
        plt.scatter(
            group.centroid_x, group.centroid_y, color=cmap(c), label=nu_df[name]
        )
        # Axis limits come from the unfiltered `data`, not the filtered `df`.
        plt.xlim(data.centroid_x.min(), data.centroid_x.max())
        plt.ylim(data.centroid_y.min(), data.centroid_y.max())
    else:
        plt.scatter(group.centroid_x, group.centroid_y, color="k", alpha=0.1)