In [None]:
from trenchripper.utils import kymo_handle, pandas_hdf5_handler, writedir
from trenchripper.segment import phase_segmentation, phase_segmentation_cluster
from trenchripper.cluster import dask_controller
import matplotlib.pyplot as plt
import numpy as np
import skimage as sk
import os
import h5py
import resource

In [None]:
import dask

In [None]:
# Root directory of the experiment; the segmenter and the dask working
# directories used in later cells are both derived from it.
# NOTE(review): hard-coded absolute path — edit before running on another dataset or machine.
headpath = "/n/scratch2/bj66/vibrio_37_mux_salt_concentration"

In [None]:
# Build the cluster segmenter for this experiment, segmenting on the "Phase" channel.
segmenter = phase_segmentation_cluster(headpath, seg_channel="Phase")
# NOTE(review): the effect of bit_max=None is not visible in this notebook —
# confirm against trenchripper.segment before relying on it.
segmenter.bit_max = None

In [None]:
# Load the trench array for one randomly chosen index in [0, 145).
# NOTE(review): 145 is a hard-coded bound — presumably the number of available files; confirm.
data = segmenter.load_trench_array_list(np.random.randint(145))

In [None]:
# Inspect the array dimensions (the next cell indexes axis 0 as trench and axis 1 as timepoint).
data.shape

In [None]:
# Pick one trench and one timepoint at random and display that frame.
# (To inspect a fixed trench instead, use: trench = data[0, :])
n_trenches = data.shape[0]
trench = data[np.random.randint(n_trenches), :]
n_timepoints = trench.shape[0]
timepoint = np.random.randint(n_timepoints)
img = trench[timepoint, :, :]

fig1, ax1 = plt.subplots(figsize=(10, 10))
ax1.imshow(img)

In [None]:
# Segment the single test frame; return_all=True yields the four intermediate
# results unpacked here, and show_plots=True asks the segmenter to plot as it goes.
conn_comp, trench_masks, img_mask, maxima = segmenter.segment(
    img, return_all=True, show_plots=True
)

In [None]:
# Side-by-side view of the raw frame and each intermediate returned by segment().
fig1, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(8, 10))

always_shown = (
    (ax1, img, "gray"),
    (ax2, conn_comp, "inferno_r"),
    (ax3, img_mask, "gray"),
    (ax4, maxima, "gray"),
)
for panel_ax, panel_img, panel_cmap in always_shown:
    panel_ax.imshow(panel_img, cmap=panel_cmap)

# trench_masks can be None; the last panel is left empty in that case.
if trench_masks is not None:
    ax5.imshow(trench_masks, cmap="gray")

In [None]:
# Start a dask controller for the segmentation run: 50 workers, 7.5GB each,
# 12 h walltime, with its working directory under the experiment folder.
# NOTE(review): local=False presumably submits workers as cluster jobs — confirm.
dc = dask_controller(
    walltime="12:00:00",
    local=False,
    n_workers=50,
    memory="7.5GB",
    death_timeout=120.0,
    working_directory=headpath + "/dask",
)
dc.startdask()
dc.daskcluster.start_workers()

In [None]:
# Display the controller's dashboard (presumably for monitoring worker progress).
dc.displaydashboard()

In [None]:
# Run the segmentation through the dask controller. The "Check results" cells
# below read segmentation_<idx>.hdf5 files from segmenter.phasesegmentationpath,
# which are presumably written by this step.
segmenter.dask_segment(dc)

In [None]:
# Shut the controller down once the segmentation job is finished.
dc.shutdown()

# Check results

In [None]:
# Read the full "data" dataset of one randomly chosen segmentation file into memory.
# NOTE(review): 200 is a hard-coded bound — presumably the number of output files; confirm.
file_idx = np.random.randint(200)
seg_file = (
    segmenter.phasesegmentationpath + "/segmentation_" + str(file_idx) + ".hdf5"
)
with h5py.File(seg_file, "r") as input_file:
    data = input_file["data"][:]

In [None]:
# Pick a random trench (axis 0) and timepoint (axis 1) from the loaded
# segmentation array, display that frame, and report which one was drawn.
trench_idx = np.random.randint(data.shape[0])
time_idx = np.random.randint(data.shape[1])
img = data[trench_idx, time_idx, :, :]
fig1, ax1 = plt.subplots(figsize=(10, 10))
ax1.imshow(img)
print("File %d, Trench %d, Time %d" % (file_idx, trench_idx, time_idx))

# Get loading fractions

In [None]:
# Start a fresh dask controller for the trench-loading step: 40 workers,
# 4GB each, 4 h walltime. This rebinds `dc`, replacing the controller used
# for segmentation.
dc = dask_controller(
    walltime="4:00:00",
    local=False,
    n_workers=40,
    memory="4GB",
    death_timeout=120.0,
    working_directory=headpath + "/dask",
)
dc.startdask()
dc.daskcluster.start_workers()

In [None]:
# Display the controller's dashboard (presumably for monitoring worker progress).
dc.displaydashboard()

In [None]:
# Submit the trench-loading characterization; the results are gathered from
# dc.futures["Trench Loading"] a few cells below.
segmenter.dask_characterize_trench_loading(dc)

In [None]:
# Post-process the trench loadings on the cluster.
# NOTE(review): presumably this writes the "Trench Loading" column read from the
# kymograph metadata later in this notebook — confirm.
segmenter.dask_postprocess_trench_loading(dc)

In [None]:
# Gather the "Trench Loading" futures from the workers and concatenate them
# along axis 0 into a single array for plotting.
trench_loadings = np.concatenate(
    dc.daskclient.gather(dc.futures["Trench Loading"]), axis=0
)

In [None]:
# Shut the controller down; the gathered loadings are already held locally.
dc.shutdown()

In [None]:
# Histogram of the gathered trench loading values (40 bins).
fig1, ax1 = plt.subplots(figsize=(10, 10))

ax1.hist(trench_loadings, bins=40)
ax1.set_ylabel("Count")
ax1.set_xlabel("Loading fraction")

# Test get cell properties

In [None]:
# Read the "kymograph" table together with its metadata from the segmenter's metadata handle.
kymodf = segmenter.meta_handle.read_df("kymograph", read_metadata=True)

In [None]:
# Keep a reference to the metadata before the frame is re-indexed in the next cell
# (presumably the attribute does not survive reset_index/set_index — confirm).
metadata = kymodf.metadata

In [None]:
# Re-index by (file, trench-within-file, timepoint) so that one file's rows
# can be selected with a single .loc lookup.
kymodf = kymodf.reset_index().set_index(
    ["File Index", "File Trench Index", "timepoints"]
)

In [None]:
# Select the rows whose first index level ("File Index" after the re-index above) is 0.
# NOTE(review): `test` is not used again in the visible notebook.
test = kymodf.loc[0]

In [None]:
# Number of distinct "File Trench Index" values across the whole table.
len(kymodf.index.unique("File Trench Index"))

In [None]:
# Pull the per-row columns for the file chosen in the "Check results" section
# (file_idx), as inputs to extract_cell_data below.
times = kymodf.loc[file_idx, "time (s)"]
global_trench_indices = kymodf.loc[file_idx, "trenchid"]
# Note: this rebinds trench_loadings, replacing the array gathered from dask earlier.
trench_loadings = kymodf.loc[file_idx, "Trench Loading"]

In [None]:
# Prepare the output directory for the extracted cell data.
# NOTE(review): overwrite=True presumably discards any existing contents — confirm before re-running.
writedir(segmenter.phasedatapath, overwrite=True)

In [None]:
# Per-cell measurements to extract, passed to extract_cell_data below.
# NOTE(review): the names look like skimage regionprops properties — confirm.
columns = [
    "area", "bbox", "centroid", "convex_area",
    "eccentricity", "equivalent_diameter", "extent", "label",
    "major_axis_length", "minor_axis_length",
    "orientation", "perimeter", "solidity",
]

In [None]:
# Run the cell-property extraction locally for the single test file, using the
# segmentation array and per-file columns prepared above. The next cells open
# data_<file_idx>.h5 under segmenter.phasedatapath to inspect the result.
segmenter.extract_cell_data(
    file_idx, data, times, global_trench_indices, trench_loadings, columns, metadata
)

In [None]:
from pandas import HDFStore

# Open the extracted cell-data file read-only: this is a pure inspection step,
# and the default mode ("a") would silently create an empty file if the
# extraction output were missing.
store = HDFStore(
    os.path.join(segmenter.phasedatapath, "data_%d.h5" % file_idx), mode="r"
)

In [None]:
# Read the "metrics" table, then close the store so the HDF5 file handle is
# not left open for the rest of the session.
testdf = store.get("metrics")
store.close()

In [None]:
# Size of the extracted metrics table (rows, columns).
testdf.shape

In [None]:
# Peek at the last 10 rows of the extracted metrics.
testdf.tail(10)

# Get cell properties

In [None]:
# Start a dask controller for the full cell-property extraction: 100 workers,
# 6GB each, 10 h walltime. This rebinds `dc` again.
dc = dask_controller(
    walltime="10:00:00",
    local=False,
    n_workers=100,
    memory="6GB",
    death_timeout=120.0,
    working_directory=headpath + "/dask",
)
dc.startdask()
dc.daskcluster.start_workers()

In [None]:
# Display the controller's dashboard (presumably for monitoring worker progress).
dc.displaydashboard()

In [None]:
# Same measurement list as in the single-file test above, redefined here so the
# cluster run does not depend on that cell having been executed.
columns = [
    "area",
    "bbox",
    "centroid",
    "convex_area",
    "eccentricity",
    "equivalent_diameter",
    "extent",
    "label",
    "major_axis_length",
    "minor_axis_length",
    "orientation",
    "perimeter",
    "solidity",
]
# Submit the extraction for these columns through the dask controller.
segmenter.dask_extract_cell_data(dc, columns)

In [None]:
# Shut the controller down; the cells below still inspect dc.futures afterwards.
dc.shutdown()

In [None]:
# List every future whose status is "lost" and report how many there are.
# (lost_count was previously initialised to 0 but never updated.)
lost_keys = [key for key, value in dc.futures.items() if value.status == "lost"]
lost_count = len(lost_keys)
for key in lost_keys:
    print(key)
lost_count

In [None]:
# Re-read the kymograph table (un-indexed again) and list the distinct file indices.
# NOTE(review): file_list and num_file_jobs are not used again in the visible notebook.
kymodf = segmenter.meta_handle.read_df("kymograph", read_metadata=True)
metadata = kymodf.metadata
#         width = metadata['kymograph_params']['trench_width_x']
#         height = metadata['kymograph_params']['ttl_len_y']
file_list = kymodf["File Index"].unique().tolist()
num_file_jobs = len(file_list)

In [None]:
# List every future whose status is "lost" and report how many there are.
# (lost_count was previously initialised to 0 but never updated.)
lost_keys = [key for key, value in dc.futures.items() if value.status == "lost"]
lost_count = len(lost_keys)
for key in lost_keys:
    print(key)
lost_count

# Check property extraction results

In [None]:
from pandas import HDFStore
import pandas as pd
import random

# Load the extracted cell data for one fixed file.
# The data_<idx>.h5 files are created and read under segmenter.phasedatapath
# elsewhere in this notebook, so read from there rather than from the
# segmentation directory.
# NOTE(review): confirm this matches where dask_extract_cell_data writes.
file_idx = 9
p = os.path.join(segmenter.phasedatapath, "data_%d.h5" % file_idx)
test_df = pd.read_hdf(p)

In [None]:
# Number of distinct values in the "time_s" index level.
# (The closing parenthesis was missing, which made this cell a SyntaxError.)
len(test_df.index.unique("time_s"))

In [None]:
# Peek at the first rows to check the index levels and columns.
test_df.head(3)

In [None]:
# Choose a random trench among those present in this file's table (shuffle,
# then take the first) and start from the first timepoint.
trenches = list(test_df.index.unique("file_trench_index"))
random.shuffle(trenches)
trench_idx = trenches[0]
time_idx = 0

In [None]:
# Show the segmentation mask for the chosen file / trench / timepoint.
seg_file = (
    segmenter.phasesegmentationpath + "/segmentation_" + str(file_idx) + ".hdf5"
)
# The dataset is copied fully into memory, so the file can be closed before plotting.
with h5py.File(seg_file, "r") as input_file:
    data = input_file["data"][:]

img = data[trench_idx, time_idx, :, :]
fig1, ax1 = plt.subplots(figsize=(10, 10))
ax1.imshow(img)
print("File %d, Trench %d, Time %d" % (file_idx, trench_idx, time_idx))

In [None]:
# Distinct "time_s" index values recorded for the chosen trench.
times = test_df.loc[trench_idx].index.unique("time_s")

In [None]:
# Number of timepoints available for this trench (the plot below needs at least time_idx + 15).
len(times)

In [None]:
# Overlay the extracted centroids on 15 consecutive segmentation frames,
# starting at time_idx, to check that the properties line up with the masks.
fig1, axes = plt.subplots(1, 15, figsize=(20, 10))
for offset, ax in enumerate(axes):
    frame_idx = time_idx + offset
    ax.imshow(data[trench_idx, frame_idx, :, :], cmap="inferno_r")
    # Rows for this trench at this timepoint; looked up once and used for both coordinates.
    cells_at_time = test_df.loc[trench_idx, times[frame_idx]]
    ax.scatter(cells_at_time["centy"], cells_at_time["centx"])

# Get division times

# Get lineage single-cell growth rates

In [None]:
from trenchripper.DetectPeaks import detect_peaks

In [None]:
import scipy.signal as signal

In [None]:
# Select, for the chosen trench, all values of the second index level and value 1
# of the third index level.
# NOTE(review): presumably the third level is the cell label and label 1 is the
# mother cell — confirm against the extraction code.
mother_cell = test_df.loc[trench_idx, :, 1]

In [None]:
# Major-axis-length trace of the selected cell as a plain NumPy array.
major_axis_length = np.array(mother_cell["major_axis_length"])

In [None]:
# Raw trace, before any smoothing.
plt.plot(major_axis_length)

In [None]:
# Smooth the trace with a Wiener filter (default settings) and plot it.
# Note: mal_smoothed is overwritten by the Savitzky-Golay result in the next cell.
mal_smoothed = signal.wiener(major_axis_length)
plt.plot(mal_smoothed)

In [None]:
# Alternative smoothing: Savitzky-Golay filter with window length 5 and polynomial order 2.
# This is the version of mal_smoothed used by the peak detection below.
mal_smoothed = signal.savgol_filter(major_axis_length, 5, 2)
plt.plot(mal_smoothed)

In [None]:
# Peak detection on the raw trace (mpd=3 — presumably the minimum distance
# between peaks, in samples; confirm against trenchripper.DetectPeaks).
detect_peaks(major_axis_length, mpd=3, show=True)

In [None]:
# Same peak detection on the smoothed trace, for comparison with the raw result.
detect_peaks(mal_smoothed, mpd=3, show=True)

In [None]:
# NOTE(review): exact duplicate of the previous cell — likely left over from
# re-running after changing the smoothing; safe to delete.
detect_peaks(mal_smoothed, mpd=3, show=True)