In [None]:
import glob
import os

import dask
import dask.array as da
import holoviews as hv
import matplotlib.pyplot as plt
import nd2reader
import numpy as np
import zarr
from cytoolz import compose, partial
from dask import delayed
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster, progress

In [None]:
# Load HoloViews' Bokeh plotting backend for this notebook session.
hv.extension("bokeh")

In [None]:
import segmentation
from matriarch_stub import recursive_sequence_map

In [None]:
def nd2_to_dask(filename, channel, rechunk=True):
    """Wrap one channel of an ND2 file as a lazily loaded dask array.

    Every frame is read on demand through a pure ``delayed`` call to
    ``segmentation.get_nd2_frame(filename, position, channel, t)``. The
    per-frame arrays are nested as ``[position][t]`` and stacked with
    ``recursive_sequence_map(da.stack, ..., max_level=1)``. That presumably
    yields an array indexed ``(position, t, *frame_shape)``; confirm against
    ``matriarch_stub``.

    Parameters
    ----------
    filename : str
        Path to the ND2 file.
    channel : int
        Channel index forwarded to ``segmentation.get_nd2_frame``.
    rechunk : bool, default True
        When true, let dask pick chunk sizes (``"auto"``) along axis 1 if the
        file has at least 5 timepoints, otherwise along axis 0.

    Returns
    -------
    dask.array.Array
        The stacked (and optionally rechunked) lazy array.

    Notes
    -----
    The ``"v"`` (position) and ``"t"`` (time) sizes default to 1 when the
    reader's ``sizes`` mapping lacks them. Frame ``(position=0, t=0)`` is read
    eagerly once, only so that its dtype and shape can be declared for every
    lazy frame.
    """
    reader = segmentation.get_nd2_reader(filename)
    n_positions = reader.sizes.get("v", 1)
    n_timepoints = reader.sizes.get("t", 1)
    template = segmentation.get_nd2_frame(filename, 0, channel, 0)
    lazy_frame = delayed(segmentation.get_nd2_frame, pure=True)

    def _lazy(position, t):
        # One lazily evaluated frame, typed after the eagerly read template.
        return da.from_delayed(
            lazy_frame(filename, position, channel, t),
            dtype=template.dtype,
            shape=template.shape,
        )

    nested = []
    for position in range(n_positions):
        nested.append([_lazy(position, t) for t in range(n_timepoints)])
    stacked = recursive_sequence_map(da.stack, nested, max_level=1)

    if not rechunk:
        return stacked
    # With enough timepoints chunk along time (axis 1), otherwise along axis 0.
    axis = 1 if n_timepoints >= 5 else 0
    return stacked.rechunk({axis: "auto"})

# Cluster

In [None]:
# Optional worker memory thresholds (fractions of each worker's memory limit):
# dask.config.set({"distributed.worker.memory": {"target": 0.4,
#                                                "spill": 0.5,
#                                                "pause": 0.9,
#                                                "terminate": 0.95}})

# Allow a task to be involved in up to 20 worker deaths before the scheduler
# marks it as failed (presumably to tolerate SLURM jobs dying mid-run).
dask.config.set({"distributed.scheduler.allowed-failures": 20});

In [None]:
# Each SLURM job runs a single one-core worker process with 16 GB of memory,
# for up to 3 hours on the "short" queue; job logs go to the project log dir.
cluster = SLURMCluster(
    cores=1,
    processes=1,
    memory="16GB",
    queue="short",
    walltime="03:00:00",
    local_directory="/tmp",
    log_directory="/home/jqs1/projects/molecule-counting/log",
    # Options used in earlier runs:
    # job_extra=['--exclude=compute-e-16-181,compute-e-16-186'],
    # diagnostics_port=('127.0.0.1', 8787),
    # env_extra=['export PYTHONPATH="/home/jqs1/projects/matriarch"'],
)
client = Client(cluster)

In [None]:
cluster

In [None]:
cluster.scale(0)

# Analysis

## One field of view (FOV), long time series

In [None]:
a = nd2_to_dask("/n/scratch2/jqs1/190411/TADA_100pct_laser_300ms_000.nd2", 0)

In [None]:
a

## Big run (multi-position scan)

In [None]:
ary = nd2_to_dask("/n/scratch2/jqs1/190411/TADA_scan_300ms_100pct.nd2", 0)

In [None]:
ary2 = ary  # ary.rechunk({2: 512, 3: 512})

In [None]:
def _polyfit_image(ary, degree=1):
    """Fit a polynomial along axis 0 independently for every pixel.

    Parameters
    ----------
    ary : numpy.ndarray
        Stack of shape ``(k, *image_shape)``. Axis 0 is the independent
        variable, sampled at ``0, 1, ..., k - 1``.
    degree : int, default 1
        Polynomial degree passed to ``np.polyfit``.

    Returns
    -------
    numpy.ndarray
        Coefficients of shape ``(degree + 1, *image_shape)``, highest power
        first (``np.polyfit`` ordering). For ``degree=1``, index 0 is the slope
        and index 1 the intercept.
    """
    n_samples, *image_shape = ary.shape
    # np.polyfit fits every column of a 2-D ``y`` independently, so flatten the
    # pixels into columns and restore the image shape afterwards.
    coeffs = np.polyfit(np.arange(n_samples), ary.reshape((n_samples, -1)), degree)
    return coeffs.reshape((degree + 1, *image_shape))


def polyfit_image(ary, degree=1):
    """Lazily fit per-pixel polynomials along the third-from-last axis.

    Wraps :func:`_polyfit_image` with ``dask.array.apply_gufunc`` using the
    signature ``(k,i,j)->(p,i,j)``. The last three axes are the core
    dimensions (``k`` samples of an ``i x j`` image); any leading axes are
    looped over.

    Parameters
    ----------
    ary : dask.array.Array
        Array of shape ``(..., k, i, j)``.
    degree : int, default 1
        Polynomial degree.

    Returns
    -------
    dask.array.Array
        float32 coefficients of shape ``(..., degree + 1, i, j)``, highest
        power first.
    """
    return da.apply_gufunc(
        # Bind ``degree`` here so the fit actually uses it; otherwise it would
        # always run with degree=1 and contradict ``output_sizes``. Binding via
        # partial keeps the vectorized call to a single positional input.
        partial(_polyfit_image, degree=degree),
        "(k,i,j)->(p,i,j)",
        ary,
        output_dtypes=np.float32,
        output_sizes={"p": degree + 1},
        # Apply the fit separately for each index of the leading (loop) axes.
        vectorize=True,
        # Core axes must each be a single chunk; let dask merge them.
        allow_rechunk=True,
    )

In [None]:
# linear_fits = da.apply_gufunc(polyfit_image, "(k,i,j)->(i,j),(i,j)", np.log(ary[:10,:50]), output_dtypes=(np.float32,np.float32), vectorize=True, allow_rechunk=True)
linear_fits_futures = client.compute(polyfit_image(np.log(ary2[:, :10])))

In [None]:
p = client.gather(linear_fits_futures)

In [None]:
linear_fits_futures10 = client.compute(polyfit_image(np.log(ary2[:, :10])))
linear_fits_futures50 = client.compute(polyfit_image(np.log(ary2[:, :50])))
linear_fits_futures100 = client.compute(polyfit_image(np.log(ary2[:, :100])))

In [None]:
p10 = client.gather(linear_fits_futures10)
p50 = client.gather(linear_fits_futures50)
p100 = client.gather(linear_fits_futures100)

In [None]:
%store p10
%store p50
%store p100

In [None]:
plt.imshow(p10[1][0])

In [None]:
plt.figure(figsize=(14, 14))
plt.imshow(p10[21, 0])

In [None]:
plt.figure(figsize=(10, 10))
plt.imshow(p[0][6] / p[0][5])

In [None]:
plt.figure(figsize=(14, 14))
plt.imshow(-p[0][6])

In [None]:
plt.figure(figsize=(14, 14))
plt.imshow(-np.median(p[0], axis=0))

In [None]:
plt.figure(figsize=(14, 14))
plt.imshow(np.median(p[1] / p[1].mean(axis=(1, 2))[:, np.newaxis, np.newaxis], axis=0))

In [None]:
plt.figure(figsize=(14, 14))
plt.imshow(
    (np.median(p[1] / p[1].mean(axis=(1, 2))[:, np.newaxis, np.newaxis], axis=0))
    / -np.median(p[0], axis=0)
)

In [None]:
plt.figure(figsize=(12, 12))
plt.scatter(p[0][6].flat, p[0][3].flat, s=0.1, alpha=0.1)

In [None]:
plt.figure(figsize=(12, 12))
plt.scatter(p[1][6].flat, p[1][3].flat, s=0.1, alpha=0.1)

### Old exploratory analysis (superseded)

In [None]:
b = ary[:, :, ::64, ::64].compute()

In [None]:
plt.plot(b[0, :, 0, 0])

In [None]:
b.shape

In [None]:
bg = 0  # 4000

In [None]:
b.min()

In [None]:
c = b[:, :, 10, 10].T

In [None]:
# normed_traces = ((b[800:1300]-bg)/(b[800:850]-bg).mean(axis=0))[:,::32,::32]
# normed_traces = ((c[800:1300]-bg)/(c[800:850]-bg).mean(axis=0))
normed_traces = (c[:] - bg) / (c[:] - bg).mean(axis=0)

In [None]:
normed_traces.shape

In [None]:
plt.plot(normed_traces[:, 30])

In [None]:
normed_traces.T.shape

In [None]:
ary.shape

In [None]:
z1 = ary[3, 740].compute()
z2 = ary[3, 770].compute()

In [None]:
plt.figure(figsize=(12, 12))
plt.imshow(z2 - z1 > 10000)

In [None]:
plt.figure(figsize=(12, 12))
plt.plot(normed_traces[:, 3])

In [None]:
plt.figure(figsize=(30, 12))
plt.plot(np.log(normed_traces / normed_traces[-10, :]))
# plt.plot(np.log(normed_traces).reshape((normed_traces.shape[0],-1)));

In [None]:
((b[800:1300] - bg) / (b[800:850] - bg).mean(axis=0)).shape

In [None]:
c = np.log(((b[800:1300] - bg) / (b[800:850] - bg).mean(axis=0))[:, ::, ::])
log_traces = c.reshape((c.shape[0], -1))

In [None]:
log_traces.shape

In [None]:
p = np.polyfit(np.arange(log_traces.shape[0]), log_traces, 1)

In [None]:
x = int(np.sqrt(log_traces.shape[1]))
plt.figure(figsize=(12, 12))
plt.imshow(p[0].reshape((x, x)))

In [None]:
plt.figure(figsize=(18, 12))
plt.plot(log_traces)

In [None]:
plt.figure(figsize=(12, 12))
plt.imshow(b[0])

In [None]:
# TODO:
# - git commit the old exploratory code
# - update nd2_to_dask
# - background subtraction (to maximize linearity)
# - truncate traces to the exponential section
# - get best-fit slopes/intercepts
# - scatter plot comparing two positions' slopes [DONE]
# - compare slopes early vs. late (non-exponential vs. approximately exponential regime)