In [None]:
import pickle
import re
from datetime import datetime, timedelta
from pathlib import Path

import astropy.stats
import holoviews as hv
import hvplot.pandas
import hvplot.xarray
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from astropy.stats import mad_std, sigma_clip
from astropy.time import Time
from cytoolz import get_in, partial

# Short alias for pd.IndexSlice, used below to slice MultiIndex levels in `.loc`.
IDX = pd.IndexSlice

In [None]:
# Load IPython's autoreload extension; mode 2 reloads imported modules before
# each cell runs so edits to the project modules are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.io.metadata as metadata
import paulssonlab.projects.sigma_circuits.experiment as experiment
import paulssonlab.projects.sigma_circuits.matriarch_stub as matriarch_stub
import paulssonlab.projects.sigma_circuits.segmentation as segmentation

In [None]:
# Default matplotlib figure size as (width, height) in inches.
plt.rcParams["figure.figsize"] = (20, 10)
# Use Bokeh as the HoloViews rendering backend.
hv.extension("bokeh")

# Load data

In [None]:
# nd2_filename = "/n/scratch3/groups/hms/sysbio/paulsson/jqs1/210324/PWM_flipped.nd2"
# nd2_filename2 = "/n/scratch3/groups/hms/sysbio/paulsson/jqs1/210204/psc101_bistable.nd2"

In [None]:
# nd2 = matriarch_stub.get_nd2_reader(nd2_filename)

In [None]:
# Earlier experiments, kept for quick switching:
# filename = "/n/groups/paulsson/jqs1/sigma-circuits/210329_expt210324flipped.pickle"
# filename = "/n/groups/paulsson/jqs1/sigma-circuits/210329_expt210204nonflipped_psc101.pickle"
# filename = "/n/groups/paulsson/jqs1/sigma-circuits/210402_expt210331ramp.pickle"
# filename = "/n/groups/paulsson/jqs1/sigma-circuits/210507_expt210505.pickle"
# filename = "/n/groups/paulsson/jqs1/sigma-circuits/210514_expt210511.pickle"
filename = "/n/groups/paulsson/jqs1/sigma-circuits/210605_expt210602_2.pickle"

# NOTE: unpickling can execute arbitrary code; only load files you trust.
# The context manager closes the file handle once the payload is read
# (the previous bare open() left it open until garbage collection).
with open(filename, "rb") as f:
    data = pickle.load(f)

# Unpack the entries of the pickled dict that the rest of the notebook uses.
df = data["table"]
md = data["metadata"]
mux_log = data["mux_log"]
experiment_txt = data["experiment.txt"]
grid_df = data["grid"]

# Acquisition times

In [None]:
# Show which top-level entries the metadata mapping provides.
md.keys()

In [None]:
image_metadata_sequence = md["image_metadata_sequence"]

# dTimeAbsolute is interpreted as a Julian date and converted to a datetime.
# NOTE(review): Time.to_datetime() yields a naive datetime — confirm it is in
# the same timezone as the mux_log timestamps it is later subtracted from.
start_jd = get_in(("SLxPictureMetadata", "dTimeAbsolute"), image_metadata_sequence)
dtimeabsolute = Time(start_jd, format="jd").to_datetime()

# dTimeMSec is wrapped as a timedelta of that many milliseconds.
dtimemsec = get_in(("SLxPictureMetadata", "dTimeMSec"), image_metadata_sequence)
tdelta = timedelta(milliseconds=dtimemsec)

In [None]:
# Display the dTimeMSec offset in seconds.
tdelta.total_seconds()

In [None]:
# Take every len(grid_df)-th acquisition time and divide by 1e3.
# NOTE(review): assumes acquisition_times is in milliseconds and holds one
# entry per grid_df row per cycle, so the stride picks the first position of
# each cycle — confirm.
first_position_times = md["acquisition_times"][:: len(grid_df)] / 1e3  # seconds

In [None]:
# Index of the last cycle with a usable timestamp: one before the first pair of
# consecutive cycles whose timestamps are identical.
# np.argmax returns 0 for an all-False mask, which used to yield -1 here (and a
# negative seconds_per_timepoint downstream) whenever no timestamps repeat.
# In that case every cycle is good, so fall back to the final index.
repeated_timestamp = np.diff(first_position_times) == 0
last_good_cycle = (
    np.argmax(repeated_timestamp) - 1
    if repeated_timestamp.any()
    else len(first_position_times) - 1
)

In [None]:
# Average cycle duration: timestamp of the last good cycle divided by its index.
# NOTE(review): this treats the first cycle as starting at t = 0 and divides by
# zero when last_good_cycle is 0 — confirm both are acceptable for this data.
seconds_per_timepoint = first_position_times[last_good_cycle] / last_good_cycle

In [None]:
# The first element of each mux_log entry is used as that event's timestamp.
mux_times = [log_entry[0] for log_entry in mux_log]

In [None]:
# Express each mux event as a fractional timepoint: seconds elapsed since the
# acquisition start, divided by the average cycle duration.
mux_timepoints = [
    (event_time - dtimeabsolute).total_seconds() / seconds_per_timepoint
    for event_time in mux_times
]

In [None]:
# Display the fractional timepoints estimated from the average cycle duration.
mux_timepoints

In [None]:
# Alternative estimate: for each mux event, the index of the first cycle whose
# start time is at or after the event.
# NOTE(review): np.argmax returns 0 when no cycle satisfies the condition, so an
# event later than every cycle is indistinguishable from one before the first.
mux_timepoints2 = [
    np.argmax(first_position_times >= (event_time - dtimeabsolute).total_seconds())
    for event_time in mux_times
]

In [None]:
# Display the integer timepoints found by comparing against cycle start times.
mux_timepoints2

In [None]:
# Plot the first 5000 raw acquisition timestamps against their index.
plt.plot(md["acquisition_times"][:5000])

# Filtered

In [None]:
# Read the image width and height from the image attributes as integers.
sensor_width, sensor_height = (
    int(get_in(("image_attributes", "SLxImageAttributes", attribute), md))
    for attribute in ("uiWidth", "uiHeight")
)

In [None]:
%%time
# skip background: keep only entries whose third index level is >= 1, attach
# the grid annotations and flatten the index into ordinary columns.
cells = df.loc[IDX[:, :, 1:]].join(grid_df).reset_index()

# Euclidean distance of each centroid from the middle of the sensor.
# The sensor midpoint is at half the width/height; subtracting the full
# dimensions (as before) measured the distance from the far corner instead.
# NOTE(review): centroid-0 is paired with the width and centroid-1 with the
# height, as in the original code. If the centroids are in (row, col) order the
# two should be swapped — confirm against the segmentation output.
cells["dist_from_center"] = np.sqrt(
    (cells["centroid-0"] - sensor_width / 2) ** 2
    + (cells["centroid-1"] - sensor_height / 2) ** 2
)

In [None]:
# Histogram of cell areas restricted to the range [100, 1000].
cells.loc[cells["area"].between(100, 1e3), "area"].hvplot.hist(bins=50)

In [None]:
# Histogram of the dist_from_center column over all cells.
cells["dist_from_center"].hvplot.hist(bins=50)

In [None]:
# Histogram of RFP-PENTA values restricted to the range [100, 5000].
cells.loc[cells["RFP-PENTA"].between(100, 5e3), "RFP-PENTA"].hvplot.hist(bins=100)

In [None]:
# Histogram of GFP-PENTA values restricted to the range [100, 3500].
cells.loc[cells["GFP-PENTA"].between(100, 3.5e3), "GFP-PENTA"].hvplot.hist(bins=100)

In [None]:
# Log-count histogram of all RFP-PENTA values that are greater than 1 and not NaN.
rfp_values = cells["RFP-PENTA"]
rfp_values[(rfp_values > 1) & ~np.isnan(rfp_values)].hvplot.hist(logy=True)

In [None]:
%%time
# Keep rows whose area lies in [150, 400] and only the columns used downstream.
area_in_range = cells["area"].between(150, 400)
# Further criteria that are currently disabled:
#   & cells["GFP-PENTA"].between(100, 8000)
#   & cells["RFP-PENTA"].between(475, 4000)
#   & ~cells["row"].isin(["E"])
#   & (cells["dist_from_center"] < 3000)
#   & (cells["t"] <= 169)
filtered_cells = cells.loc[
    area_in_range, ["GFP-PENTA", "RFP-PENTA", "area", "row", "pos", "t"]
]

In [None]:
%%time
groupby = "row"  # row or pos

# Aggregate only the numeric columns that are not grouping keys. With
# groupby = "pos" the string-valued "row" column would otherwise be passed to
# median, which raises on pandas >= 2 (older versions silently dropped it).
value_columns = [
    column
    for column in filtered_cells.select_dtypes("number").columns
    if column not in (groupby, "t")
]

# Per (group, timepoint): the median and the median absolute deviation of each
# value column. The result has two column levels: (column, statistic).
# NOTE(review): astropy's median_absolute_deviation does not ignore NaN by
# default while pandas' median does — confirm the filtered data has no NaNs.
medians = filtered_cells.groupby([groupby, "t"])[value_columns].agg(
    ["median", astropy.stats.median_absolute_deviation]
)


def get_limits(x):
    """Return the median ± MAD band for one aggregated column.

    Parameters
    ----------
    x : pandas.DataFrame
        Frame with two column levels; after dropping the outer level it must
        contain the columns "median" and "median_absolute_deviation".

    Returns
    -------
    pandas.DataFrame
        Columns "lower" (median - MAD) and "upper" (median + MAD), on the
        same index as ``x``.
    """
    stats = x.droplevel(0, axis=1)
    center = stats["median"]
    spread = stats["median_absolute_deviation"]
    return pd.DataFrame({"lower": center - spread, "upper": center + spread})


# Median ± MAD band for every aggregated column, keyed by that column's name so
# the result keeps two column levels: (column, "lower"/"upper").
# Built with pd.concat instead of DataFrame.groupby(level=0, axis=1), which is
# deprecated since pandas 2.1; sorting the names reproduces the column order
# the grouped version produced.
limits = pd.concat(
    {
        name: get_limits(medians[[name]])
        for name in sorted(medians.columns.unique(level=0))
    },
    axis=1,
)

In [None]:
observable = "GFP-PENTA"

# Pick the statistics of the chosen observable and make the grouping key and
# "t" available for plotting.
if groupby != "pos":
    medians2 = medians[observable].reset_index()
    limits2 = limits[observable].reset_index()
else:
    # Per-position grouping: attach the grid annotations and keep one row only.
    medians2 = medians[observable][["median"]].join(grid_df)
    limits2 = limits[observable].join(grid_df)
    mask = medians2["row"].isin(["A"])  # TODO
    medians2 = medians2[mask]
    limits2 = limits2[mask]

# One median line and one lower/upper band per group, both on a log y-axis.
shared_plot_opts = dict(by=groupby, logy=True)
mean_plot = medians2.hvplot.line(x="t", y="median", **shared_plot_opts)
noise_plot = limits2.hvplot.area(
    x="t", y="lower", y2="upper", stacked=False, alpha=0.2, **shared_plot_opts
)

In [None]:
# Overlay the median lines and the shaded bands in a single 800x800 plot.
(mean_plot * noise_plot).opts(width=800, height=800)

## Heatmap

In [None]:
%%time
channel = "GFP-PENTA"
# 100 log-spaced bin edges spanning the observed range of the channel.
# NOTE(review): np.geomspace requires a non-zero start of the same sign as the
# stop — confirm the channel minimum is strictly positive.
bins = np.geomspace(filtered_cells[channel].min(), filtered_cells[channel].max(), 100)


def _row_histogram(group):
    """Count the cells of one (row, t) group per bin, indexed by left bin edge."""
    counts, _edges = np.histogram(group[channel], bins=bins)
    return pd.Series(counts, index=bins[:-1])


row_counts = filtered_cells.groupby(["row", "t"]).apply(_row_histogram)
row_counts.columns.name = channel
# One (channel bin x t) table per row label, collected into a Dataset.
heatmap = xr.Dataset(
    {
        row_label: row_counts.loc[row_label].T
        for row_label in row_counts.index.levels[0]
    }
)

In [None]:
# One log-scaled quadmesh per row A–H, stacked in a single column.
row_quadmesh_opts = dict(cmap="blues", logy=True, logz=True, clim=(1, 1e4))
hv.Layout(
    [heatmap[row].hvplot.quadmesh(**row_quadmesh_opts) for row in "ABCDEFGH"]
).cols(1)

In [None]:
# Heatmap for row D on its own.
heatmap["D"].hvplot.quadmesh(cmap="blues", logy=True, logz=True, clim=(1, 1e4))

In [None]:
%%time
channel = "GFP-PENTA"
# 100 log-spaced bin edges spanning the observed range of the channel.
bins = np.geomspace(filtered_cells[channel].min(), filtered_cells[channel].max(), 100)
bin_left_edges = bins[:-1]

# Per (pos, t) group: number of cells in each bin, indexed by left bin edge.
pos_counts = filtered_cells.groupby(["pos", "t"]).apply(
    lambda group: pd.Series(
        np.histogram(group[channel], bins=bins)[0], index=bin_left_edges
    )
)
pos_counts.columns.name = channel

# Stack one (channel bin x t) array per position along a new dimension taken
# from the first index level.
positions = pos_counts.index.levels[0]
heatmap_pos = xr.concat(
    [xr.DataArray(pos_counts.loc[position].T) for position in positions],
    dim=positions,
)

In [None]:
# Show the grid_df entries that belong to row D.
grid_df[grid_df["row"] == "D"]

In [None]:
# One log-scaled quadmesh per position 60–64, stacked in a single column.
pos_quadmesh_opts = dict(cmap="blues", logy=True, logz=True, clim=(1, 1e4))
hv.Layout(
    [
        heatmap_pos.sel(pos=pos).hvplot.quadmesh(**pos_quadmesh_opts)
        for pos in range(60, 65)
    ]
).cols(1)