# TrenchRipper Master Notebook

## Introduction

This notebook contains the entire `TrenchRipper` pipeline, divided into simple steps. This pipeline is ideal for Mother <br>Machine image data where cells possess fluorescent segmentation markers. Segmentation on phase or brightfield data <br>is being developed, but is still an experimental feature.

The steps in this pipeline are as follows:
1. Extracting your Mother Machine data (.nd2) into hdf5 format
2. Identifying and cropping individual trenches into kymographs
3. Segmenting cells with a fluorescent marker
4. Determining lineages and object properties

In each step, the user will dynamically specify parameters using a series of interactive diagnostics on their dataset. <br>Following this, a parameter file will be written to disk and then used to deploy a parallel computation on the <br>dataset, either locally or on a SLURM cluster.


This is intended as an end-to-end solution to analyzing Mother Machine data. As such, **it is not trivial to plug data <br>directly into intermediate steps**, as it will lack the correct formatting and associated metadata. A notable <br>exception to this is using another program to segment data. The library references binary segmentation masks using <br>only metadata derived from their associated kymographs. As such, it is possible to generate segmentations on these <br>kymographs elsewhere and place them into the segmentation data path to have `TrenchRipper` act on those <br>segmentations instead. More on this in the segmentation section...

#### Imports

Run this section to import all relevant packages and libraries used in this notebook. You must run this every time you open a new Python kernel.

In [None]:
import paulssonlab.deaton.trenchripper.trenchripper as tr

import warnings

warnings.filterwarnings(action="once")

import matplotlib

matplotlib.rcParams["figure.figsize"] = [20, 10]

#### Specify Paths

Begin by defining the directory in which all processing will be done, as well as the initial nd2 file we will be <br>processing. This line should be run every time you open a new Python kernel.

The format should be: `headpath = "/path/to/folder"` and `nd2file = "/path/to/file.nd2"`

For example:
```
headpath = "/n/scratch2/de64/2019-05-31_validation_data"
nd2file = "/n/scratch2/de64/2019-05-31_validation_data/Main_Experiment.nd2"
```

Ideally, these files should be placed in a storage location with relatively fast I/O.
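Before launching a long extraction, it can help to confirm that both paths point where you expect. The helper below is a minimal sketch and is not part of `TrenchRipper`; the name `check_paths` is hypothetical and it uses only the Python standard library.

```python
from pathlib import Path


def check_paths(headpath, nd2file):
    """Return a list of human-readable problems; an empty list means the paths look usable."""
    problems = []
    head = Path(headpath)
    nd2 = Path(nd2file)
    if not head.is_dir():
        problems.append(f"headpath is not an existing directory: {head}")
    if not nd2.is_file():
        problems.append(f"nd2file does not exist: {nd2}")
    elif nd2.suffix.lower() != ".nd2":
        problems.append(f"nd2file does not have an .nd2 extension: {nd2}")
    return problems
```

Calling `check_paths(headpath, nd2file)` after defining the two variables returns an empty list when both paths look usable.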

In [None]:
headpath = "/n/scratch2/de64/2019-11-09_CN_Growth_Curve/"
nd2file = "/n/scratch2/de64/2019-11-09_CN_Growth_Curve/CN_Limited_GC_restart.nd2"

In [None]:
viewer = tr.hdf5_viewer(headpath)
viewer.view()

## Extract to hdf5 files

In this section, we will be extracting our image data. Currently this notebook only supports the `.nd2` format; however, <br>there are `.tiff` extractors in the TrenchRipper source files that will be added to `Master.ipynb` soon.

In the abstract, this step will take a single `.nd2` file and split it into a set of `.hdf5` files stored in <br>`headpath/hdf5`. Splitting the file up in this way will facilitate quick processing in later steps. Each field of <br>view will be split into one or more `.hdf5` files, depending on the number of images per file requested (more on <br>this later).

To keep track of which output files correspond to which FOVs, as well as to keep track of experiment metadata, the <br>extractor also outputs a `metadata.hdf5` file in the `headpath` folder. The data from this step is accessible in <br>that `metadata.hdf5` file under the `global` key. If you would like to look at this metadata, you may use the <br>`tr.utils.pandas_hdf5_handler` to read from this file. Later steps will add additional metadata under different <br>keys into the `metadata.hdf5` file.

#### Start Dask Workers

First, we start a `dask_controller` instance which will handle all of our parallel processing. The default parameters <br>here work well on O2. The critical arguments here are:

**walltime** : For a cluster, the length of time you will request each node for.

**local** : `True` if you want to perform computation locally. `False` if you want to perform it on a SLURM cluster.

**n_workers** : Number of nodes to request if on the cluster, or number of processes if computing locally.

**memory** : For a cluster, the amount of memory you will request each node for.

**working_directory** : For a cluster, the directory in which data will be spilled to disk. Usually set as a folder in <br>the `headpath`.
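To get a feel for what a given set of arguments requests in total, the sketch below multiplies the per-worker values by `n_workers`. It is not part of `TrenchRipper`: the helper names are hypothetical, and the accepted string formats (`"HH:MM:SS"` for walltime, a number followed by `MB`, `GB` or `TB` for memory) are assumptions based on the example values used in this notebook.

```python
def parse_walltime(walltime):
    """Convert an 'HH:MM:SS' string into hours as a float."""
    hours, minutes, seconds = (int(part) for part in walltime.split(":"))
    return hours + minutes / 60 + seconds / 3600


def parse_memory_gb(memory):
    """Convert a string such as '2GB' or '512MB' into gigabytes (decimal units)."""
    units = {"MB": 1e-3, "GB": 1.0, "TB": 1e3}
    text = memory.strip().upper()
    for suffix, factor in units.items():
        if text.endswith(suffix):
            return float(text[: -len(suffix)]) * factor
    raise ValueError(f"unrecognized memory string: {memory!r}")


def summarize_request(walltime, n_workers, memory):
    """Total memory (GB) and worker-hours implied by a set of controller arguments."""
    return {
        "total_memory_gb": n_workers * parse_memory_gb(memory),
        "worker_hours": n_workers * parse_walltime(walltime),
    }


# Values taken from the dask_controller call in this notebook
print(summarize_request(walltime="04:00:00", n_workers=20, memory="2GB"))
# {'total_memory_gb': 40.0, 'worker_hours': 80.0}
```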

In [None]:
dask_controller = tr.trcluster.dask_controller(
    walltime="04:00:00",
    local=False,
    n_workers=20,
    memory="2GB",
    working_directory=headpath + "/dask",
)
dask_controller.startdask()

After running the above line, you will have a running Dask client. Run the line below and click the link to supervise <br>the computation being administered by the scheduler. 

Don't be alarmed if the screen starts mostly blank; it may take time for your workers to spin up. If you get a 404 <br>error on a cluster, it is likely that your ports are not being forwarded properly. If this occurs, please report <br>the issue on GitHub.

In [None]:
dask_controller.daskclient

##### Perform Extraction

Now that we have our cluster scheduler spun up, it is time to convert files. This will be handled by the <br>`hdf5_extractor` object. This extractor will pull up each FOV and split it such that each derived `.hdf5` file <br>contains, at maximum, N timepoints of that FOV per file. The image data stored in these files takes the <br>form of `(N,Y,X)` arrays that are accessible using the desired channel name as a key. 

The arguments for this extractor are:

 - **nd2file** : The filepath to the `.nd2` file you intend to extract.
 
 - **headpath** : The folder in which processing is occurring. Should be the same for each step in the pipeline.

 - **tpts_per_file** : The maximum number of timepoints stored in each output `.hdf5` file. Typical values are between 25 <br>and 100.

 - **ignore_fovmetadata** : Used when `.nd2` data is corrupted and does not possess records for stage positions or <br>timepoints. Only set `True` if the extractor throws errors on metadata handling.

 - **nd2reader_override** : Overrides values in metadata recovered using the `nd2reader`. Currently set to <br>`{"z_levels":[],"z_coordinates":[]}` by default to correct a known issue where z coordinates are mistakenly <br>interpreted as a z stack. See the [nd2reader](https://rbnvrw.github.io/nd2reader/) documentation for more info.
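The effect of `tpts_per_file` can be illustrated with a small planning function. This is only a sketch of the splitting rule described above, not the extractor's actual code. The function name and the assumption that files are numbered FOV by FOV are hypothetical.

```python
def plan_output_files(n_fovs, n_timepoints, tpts_per_file):
    """List (file_idx, fov, t_start, t_stop) chunks, with t_stop exclusive."""
    plan = []
    file_idx = 0
    for fov in range(n_fovs):
        # Walk through this FOV's timepoints in blocks of at most tpts_per_file
        for t_start in range(0, n_timepoints, tpts_per_file):
            t_stop = min(t_start + tpts_per_file, n_timepoints)
            plan.append((file_idx, fov, t_start, t_stop))
            file_idx += 1
    return plan


# Hypothetical experiment: 2 FOVs, 120 timepoints, at most 50 timepoints per file
for entry in plan_output_files(n_fovs=2, n_timepoints=120, tpts_per_file=50):
    print(entry)
```

In this example each FOV yields three files, holding 50, 50 and 20 timepoints; the last file of a FOV is simply shorter when the timepoints do not divide evenly.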

In [None]:
hdf5_extractor = tr.ndextract.hdf5_fov_extractor(
    nd2file,
    headpath,
    tpts_per_file=50,
    ignore_fovmetadata=False,
    nd2reader_override={"z_levels": [], "z_coordinates": []},
)

##### Extraction Parameters

Here, you may set the time interval you want to extract. Useful for cropping data to the period exhibiting the dynamics of interest.

Optionally take notes to add to the `metadata.hdf5` file. Notes may also be taken directly in this notebook.

In [None]:
hdf5_extractor.inter_set_params()

##### Begin Extraction 

Running the following line will start the extraction process. This may be monitored by examining the `Dask Dashboard` <br> under the link displayed earlier. Once the computation is complete, move to the next line.

This step may take a long time, though it is possible to speed it up using additional workers.

In [None]:
hdf5_extractor.extract(dask_controller)

##### Shutdown Dask

Once extraction is complete, it is likely that you will want to shut down your `dask_controller` if you are on a <br>
cluster. This is because the specifications of the current `dask_controller` will not be optimal for later steps. <br>
To do this, run the following line and wait for it to complete. If it hangs, interrupt your kernel and re-run it. <br>
If this also fails to shut down your workers, you will have to manually shut them down using `scancel` in a terminal.

In [None]:
dask_controller.shutdown()

## Kymographs

Now that you have extracted your data into a series of `.hdf5` files, we will identify and crop <br>the individual trenches/growth channels present in the images. This algorithm assumes that your growth trenches <br>are vertically aligned and that they alternate in their orientation from top to bottom. See the example image for the <br>correct geometry:

![example_image](./resources/example_image.jpg)

The output of this step will be a set of `.hdf5` files stored in `headpath/kymograph`. The image data stored in these <br>files takes the form of `(K,T,Y,X)` arrays where K is the trench index, T is time, and Y,X are the crop dimensions. <br>These arrays are accessible using keys of the form `"[Image Channel]"`. For example, looking up phase channel <br>data of trenches in the topmost row of an image will require the key `"Phase"`
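To make the `(K,T,Y,X)` layout concrete, the sketch below crops a few trenches out of a synthetic `(T,Y,X)` stack with NumPy. It is not the trench detection algorithm used by `TrenchRipper`: the function name, the trench positions and the image size are all hypothetical, and the trench boundaries are supplied by hand rather than detected.

```python
import numpy as np


def crop_trenches(stack, x_edges, y_bounds):
    """Crop equal-width trenches from a (T, Y, X) stack into a (K, T, Y, X) array.

    x_edges is a list of (x_start, x_stop) pairs, one per trench, all of the
    same width; y_bounds is a single (y_start, y_stop) pair shared by the row.
    """
    y_start, y_stop = y_bounds
    crops = [stack[:, y_start:y_stop, x0:x1] for x0, x1 in x_edges]
    # Stacking along a new leading axis gives the trench index K
    return np.stack(crops, axis=0)


# Synthetic stand-in for one channel of one FOV: 10 timepoints of 200 x 300 pixels
stack = np.zeros((10, 200, 300), dtype="uint16")
kymographs = crop_trenches(
    stack, x_edges=[(20, 40), (60, 80), (100, 120)], y_bounds=(50, 150)
)
print(kymographs.shape)  # (3, 10, 100, 20)
```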

### Test Parameters



##### Initialize the interactive kymograph class

As a first step, initialize the `tr.interactive.kymograph_interactive` class that will help us choose the <br>parameters we will use to generate kymographs.

In [None]:
interactive_kymograph = tr.kymograph_interactive(headpath)

##### Examine Images

Here you can manually inspect images before beginning parameter tuning.

In [None]:
interactive_kymograph.view_image_interactive()

You will now want to select a few test FOVs to try out parameters on, the channel you want to detect trenches on, and <br>the time interval on which you will perform your processing.

The arguments for this step are:

- **seg_channel (string)** : The channel name that you would like to segment on.

- **invert (list)** : Whether or not you want to invert the image before detecting trenches. By default, it is assumed that <br>the trenches have a high pixel intensity relative to the background. This should be the case for Phase Contrast and <br>Fluorescence Imaging, but may not be the case for Brightfield Imaging, in which case you will want to invert the image.

- **fov_list (list)** : List of integers corresponding to the FOVs that you wish to make test kymographs of.

- **t_subsample_step (int)** : Step size to be used for subsampling input files in time. It is recommended that subsampling results in <br>between 5 and 10 timepoints for quick processing.

Hit the "Run Interact" button to lock in your parameters. The button will become transparent briefly and become solid again <br>when processing is complete. After that has occurred, move on to the next step.
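For `t_subsample_step`, one way to land in the recommended range of 5 to 10 timepoints is to derive the step from the experiment length. The helper below is a minimal sketch with a hypothetical name; the default target of 8 timepoints is an arbitrary choice within that range.

```python
import math


def suggest_subsample_step(n_timepoints, target=8):
    """Smallest step such that range(0, n_timepoints, step) has at most `target` entries."""
    return max(1, math.ceil(n_timepoints / target))


n_timepoints = 150  # hypothetical experiment length
step = suggest_subsample_step(n_timepoints)
print(step, len(range(0, n_timepoints, step)))  # 19 8
```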

# Napari Viewer

In [None]:
%gui qt5
# Note that this Magics command needs to be run in a cell
# before any of the Napari objects are instantiated to
# ensure it has time to finish executing before they are
# called

from skimage import data
import napari

In [None]:
%gui qt5

### Issue

Napari does not function inline in the Jupyter notebook, and it may be awkward to integrate with the remote framework the rest of trenchripper is built on. For now, just try working with holoviews/datashader.

# Holoviews Viewer

In [None]:
import pandas as pd
import numpy as np
import holoviews as hv
import dask.array as da
import xarray as xr
import h5py

from holoviews.operation.datashader import regrid

hv.extension("bokeh")

In [None]:
class hdf5_viewer:
    def __init__(
        self, headpath, compute_data=False, persist_data=False, select_fovs=None
    ):
        # The explicit "/" keeps the path valid whether or not headpath ends in "/"
        meta_handle = tr.pandas_hdf5_handler(headpath + "/metadata.hdf5")
        hdf5_df = meta_handle.read_df("global", read_metadata=True)
        metadata = hdf5_df.metadata
        index_df = pd.DataFrame(range(len(hdf5_df)), columns=["lookup index"])
        index_df.index = hdf5_df.index
        hdf5_df = hdf5_df.join(index_df)
        self.channels = metadata["channels"]
        # Use the requested FOVs if given; otherwise load every FOV in the metadata
        if select_fovs is not None and len(select_fovs) > 0:
            fov_indices = select_fovs
        else:
            fov_indices = hdf5_df.index.get_level_values("fov").unique().tolist()

        dask_arrays = []

        for fov_idx in fov_indices:
            fov_arrays = []
            fov_df = hdf5_df.loc[fov_idx:fov_idx]
            file_indices = fov_df["File Index"].unique().tolist()
            for channel in self.channels:
                channel_arrays = []
                for file_idx in file_indices:
                    infile = h5py.File(
                        headpath + "/hdf5/hdf5_" + str(file_idx) + ".hdf5", "r"
                    )
                    data = infile[channel]
                    array = da.from_array(
                        data, chunks=(1, data.shape[1], data.shape[2])
                    )
                    channel_arrays.append(array)
                da_channel_arrays = da.concatenate(channel_arrays, axis=0)
                fov_arrays.append(da_channel_arrays)
            da_fov_arrays = da.stack(fov_arrays, axis=0)
            dask_arrays.append(da_fov_arrays)
        self.main_array = da.stack(dask_arrays, axis=0)
        if compute_data:
            self.main_array = self.main_array.compute()
        elif persist_data:
            self.main_array = self.main_array.persist()

    def view(
        self, width=1000, height=1000, cmap="Greys_r", hist_on=False, hist_color="grey"
    ):
        # Wrap in xarray DataArray and label coordinates
        dims = [
            "FOV",
            "Channel",
            "time",
            "y",
            "x",
        ]
        coords = {d: np.arange(s) for d, s in zip(dims, self.main_array.shape)}
        coords["Channel"] = np.array(self.channels)
        xrstack = xr.DataArray(
            self.main_array, dims=dims, coords=coords, name="Data"
        ).astype("uint16")

        # Wrap in HoloViews Dataset
        ds = hv.Dataset(xrstack)

        # # Convert to stack of images with x/y-coordinates along axes
        image_stack = ds.to(hv.Image, ["x", "y"], dynamic=True)

        # # Apply regridding if each image is large
        regridded = regrid(image_stack)

        # # Set a global Intensity range
        # regridded = regridded.redim.range(Intensity=(0, 1000))

        # # Set plot options
        display_obj = regridded.opts(
            plot={
                "Image": dict(
                    colorbar=True, width=width, height=height, tools=["hover"]
                )
            }
        )
        display_obj = display_obj.opts(cmap=cmap)

        if hist_on:
            hist = hv.operation.histogram(image_stack, num_bins=30)
            hist = hist.opts(line_width=0, color=hist_color, width=200, height=height)
            return display_obj << hist
        else:
            return display_obj
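
The nested `da.concatenate` and `da.stack` calls in `__init__` are what produce the five-dimensional `(FOV, Channel, time, y, x)` array that `view` labels. The sketch below reproduces only that shape logic, using NumPy arrays of zeros as stand-ins for the lazily loaded hdf5 datasets. All sizes are hypothetical, and it assumes every FOV has the same number of timepoints, which stacking requires.

```python
import numpy as np

n_fovs, n_channels, n_files, tpts_per_file, height, width = 2, 3, 4, 5, 16, 16

fov_blocks = []
for fov in range(n_fovs):
    channel_blocks = []
    for channel in range(n_channels):
        # Each file holds a (time, y, x) block for this FOV and channel
        per_file = [
            np.zeros((tpts_per_file, height, width), dtype="uint16")
            for _ in range(n_files)
        ]
        # Join the files end to end along the time axis
        channel_blocks.append(np.concatenate(per_file, axis=0))
    # Stack channels into a new leading axis: (channel, time, y, x)
    fov_blocks.append(np.stack(channel_blocks, axis=0))
# Stack FOVs into a new leading axis: (fov, channel, time, y, x)
main_array = np.stack(fov_blocks, axis=0)
print(main_array.shape)  # (2, 3, 20, 16, 16)
```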

In [None]:
dims = [
    "FOV",
    "Channel",
    "time",
    "y",
    "x",
]
coords = {d: np.arange(s) for d, s in zip(dims, viewer.main_array.shape)}
coords["Channel"] = np.array(viewer.channels)
xrstack = xr.DataArray(viewer.main_array, dims=dims, coords=coords, name="Data").astype(
    "uint16"
)

In [None]:
xrstack

In [None]:
ds = hv.Dataset(xrstack)

In [None]:
viewer = hdf5_viewer(headpath)

In [None]:
viewer.view(hist_on=False)

In [None]:
hv.operation.datashader.regrid

In [None]:
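hdf5_df = tr.pandas_hdf5_handler(headpath + "/metadata.hdf5").read_df("global", read_metadata=True)  # global metadata table
fov_indices = hdf5_df.index.get_level_values("fov").unique().tolist()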

In [None]:
file_indices = hdf5_df["File Index"].unique().tolist()

In [None]:
dask_arrays = []

for fov_idx in fov_indices:
    fov_arrays = []
    fov_df = hdf5_df.loc[fov_idx:fov_idx]
    file_indices = fov_df["File Index"].unique().tolist()
    for channel in viewer.channels:
        channel_arrays = []
        for file_idx in file_indices:
            infile = h5py.File(headpath + "/hdf5/hdf5_" + str(file_idx) + ".hdf5", "r")
            data = infile[channel]
            array = da.from_array(data, chunks=(1, data.shape[1], data.shape[2]))
            channel_arrays.append(array)
        da_channel_arrays = da.concatenate(channel_arrays, axis=0)
        fov_arrays.append(da_channel_arrays)
    da_fov_arrays = da.stack(fov_arrays, axis=0)
    dask_arrays.append(da_fov_arrays)
x = da.stack(dask_arrays, axis=0)

In [None]:
x

In [None]:
hv.help(image_stack)

In [None]:
kdims = [
    hv.Dimension("phase", range=(0, np.pi)),
    hv.Dimension("frequency", values=[0.1, 1, 2, 5, 10]),
    hv.Dimension("amplitude", values=[0.5, 5, 10]),
]

In [None]:
image_stack.kdims  # inspect the key dimensions; `opts` does not accept a `kdims` argument

In [None]:
coords

In [None]:
# Wrap in xarray DataArray and label coordinates
dims = [
    "FOV",
    "Channel",
    "time",
    "y",
    "x",
]
coords = {d: np.arange(s) for d, s in zip(dims, x.shape)}
coords["Channel"] = np.array(viewer.channels)
xrstack = xr.DataArray(
    x,
    dims=dims,
    coords=coords,
).astype("uint16")

# Wrap in HoloViews Dataset
ds = hv.Dataset(xrstack)

# # Convert to stack of images with x/y-coordinates along axes
image_stack = ds.to(hv.Image, ["x", "y"], dynamic=True)

# # Apply regridding if each image is large
regridded = regrid(image_stack)

# # Set a global Intensity range
# regridded = regridded.redim.range(Intensity=(0, 1000))

# # Set plot options
display_obj = regridded.opts(
    plot={"Image": dict(colorbar=True, width=1000, height=1000, tools=["hover"])}
)
display_obj = display_obj.opts(cmap="Greys_r")

# hist = histogram(image_stack,num_bins=30)
# hist = hist.opts(line_width=0,color="grey", width=200,height=1000)

In [None]:
display_obj  # << hist

In [None]:
def selected_hist(x_range, y_range):
    # Apply current ranges
    obj = img.select(x=x_range, y=y_range) if x_range and y_range else img

    # Compute histogram
    return hv.operation.histogram(obj)


rangexy = hv.streams.RangeXY(source=display_obj)

display_obj << hv.DynamicMap(selected_hist, streams=[rangexy])

In [None]:
def selected_hist(x_range, y_range):
    # Apply current ranges
    obj = img.select(x=x_range, y=y_range) if x_range and y_range else img

    # Compute histogram
    return hv.operation.histogram(obj)


# Define a RangeXY stream linked to the image
rangexy = hv.streams.RangeXY(source=display_obj)

# Adjoin the dynamic histogram computed based on the current ranges
img << hv.DynamicMap(selected_hist, streams=[rangexy])

In [None]:
2**16

In [None]:
display_obj

In [None]:
display_obj + hist

In [None]:
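# Imports and renderer required by this example
import dask.dataframe as dd
from holoviews import opts
from holoviews.operation.datashader import aggregate

renderer = hv.renderer("bokeh")

# Set plot and style options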
opts.defaults(
    opts.Curve(
        xaxis=None,
        yaxis=None,
        show_grid=False,
        show_frame=False,
        color="orangered",
        framewise=True,
        width=100,
    ),
    opts.Image(
        width=800,
        height=400,
        shared_axes=False,
        logz=True,
        xaxis=None,
        yaxis=None,
        axiswise=True,
    ),
    opts.HLine(color="white", line_width=1),
    opts.Layout(shared_axes=False),
    opts.VLine(color="white", line_width=1),
)

# Read the parquet file
df = dd.read_parquet("./data/nyc_taxi_wide.parq").persist()

# Declare points
points = hv.Points(df, kdims=["pickup_x", "pickup_y"], vdims=[])

# Use datashader to rasterize and linked streams for interactivity
agg = aggregate(points, link_inputs=True, x_sampling=0.0001, y_sampling=0.0001)
pointerx = hv.streams.PointerX(x=np.mean(points.range("pickup_x")), source=points)
pointery = hv.streams.PointerY(y=np.mean(points.range("pickup_y")), source=points)
vline = hv.DynamicMap(lambda x: hv.VLine(x), streams=[pointerx])
hline = hv.DynamicMap(lambda y: hv.HLine(y), streams=[pointery])

sampled = hv.util.Dynamic(
    agg,
    operation=lambda obj, x: obj.sample(pickup_x=x),
    streams=[pointerx],
    link_inputs=False,
)

hvobj = (agg * hline * vline) << sampled

# Obtain Bokeh document and set the title
doc = renderer.server_doc(hvobj)
doc.title = "NYC Taxi Crosshair"